In [7]:
import numpy as np
import torch
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib as mpl
import time

In [2]:
from src.probabilistic_dag_model.probabilistic_dag import ProbabilisticDAG
from src.probabilistic_dag_model.train_dag import train

# Sampling time of a 100-node ProbabilisticDAG

In [8]:
# Benchmark: mean time to draw one adjacency-matrix sample from a 100-node
# ProbabilisticDAG that uses the 'topk' ordering ('sinkhorn' is the other
# order_type this notebook has tried).
n_samples = 30
sampling_times = np.zeros(n_samples)
prob_dag_model = ProbabilisticDAG(n_nodes=100,
                                  order_type='topk',
                                  initial_adj=None,
                                  seed=100)
for i in range(n_samples):
    # perf_counter() is monotonic and uses the highest-resolution clock
    # available; time.time() is wall-clock time that can jump and may be too
    # coarse for millisecond-scale intervals like this one.
    t0 = time.perf_counter()
    # .detach().cpu().numpy() converts the sample to a host NumPy array, so
    # the measured interval includes that conversion.
    A = prob_dag_model.sample().detach().cpu().numpy()
    sampling_times[i] = time.perf_counter() - t0
print('Mean sampling time: ', sampling_times.mean())

Mean sampling time:  0.0016643206278483074


# DAG learning with a ground-truth DAG

In [9]:
# Use the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Ground truth: an all-ones matrix with everything on and below the main
# diagonal zeroed (strictly upper-triangular). Assuming rows are sources and
# columns are targets, this is every edge i -> j with i < j.
n_nodes = 10
true_dag_adj = torch.ones(n_nodes, n_nodes, device=device).triu(diagonal=1)

# NOTE(review): true_dag_adj lives on `device`, but the model below is not
# explicitly moved there -- confirm train() handles device placement.
model = ProbabilisticDAG(
    n_nodes=n_nodes,
    hard=True,
    order_type='topk',  # 'sinkhorn' is the alternative ordering tried before
    lr=1e-2,
    seed=0,
)

In [12]:
# Fit the model against the known ground-truth adjacency. train() returns
# three values: the model plus two histories that are presumably per-epoch
# losses (TODO confirm against train()). The captured log shows a progress line
# every 10 epochs and "Model saved" when prob_abs_loss decreases.
model, losses, sampled_mse_losses = train(
    model,
    true_dag_adj=true_dag_adj,
    max_epochs=30000,
    patience=100,
)

Epoch 0 -> prob_abs_loss 0.06914740800857544 | sampled_mse_loss 1.0
Model saved
Epoch 10 -> prob_abs_loss 0.06894344091415405 | sampled_mse_loss 1.0
Model saved
Epoch 20 -> prob_abs_loss 0.06835818290710449 | sampled_mse_loss 4.0
Model saved
Epoch 30 -> prob_abs_loss 0.06802433729171753 | sampled_mse_loss 7.0
Model saved
Epoch 40 -> prob_abs_loss 0.06790167093276978 | sampled_mse_loss 3.0
Model saved
Epoch 50 -> prob_abs_loss 0.0678589940071106 | sampled_mse_loss 5.0
Model saved
Epoch 60 -> prob_abs_loss 0.06750613451004028 | sampled_mse_loss 5.0
Model saved
Epoch 70 -> prob_abs_loss 0.06730824708938599 | sampled_mse_loss 4.0
Model saved
Epoch 80 -> prob_abs_loss 0.06708943843841553 | sampled_mse_loss 4.0
Model saved
Epoch 90 -> prob_abs_loss 0.06689667701721191 | sampled_mse_loss 6.0
Model saved
Epoch 100 -> prob_abs_loss 0.06656724214553833 | sampled_mse_loss 7.0
Model saved
Epoch 110 -> prob_abs_loss 0.06631773710250854 | sampled_mse_loss 4.0
Model saved
Epoch 120 -> prob_abs_loss 0

Model saved
Epoch 1030 -> prob_abs_loss 0.05301368236541748 | sampled_mse_loss 2.0
Model saved
Epoch 1040 -> prob_abs_loss 0.05298799276351929 | sampled_mse_loss 6.0
Model saved
Epoch 1050 -> prob_abs_loss 0.0528489351272583 | sampled_mse_loss 0.0
Model saved
Epoch 1060 -> prob_abs_loss 0.05279666185379028 | sampled_mse_loss 3.0
Model saved
Epoch 1070 -> prob_abs_loss 0.05252760648727417 | sampled_mse_loss 3.0
Model saved
Epoch 1080 -> prob_abs_loss 0.05216771364212036 | sampled_mse_loss 3.0
Model saved
Epoch 1090 -> prob_abs_loss 0.05197864770889282 | sampled_mse_loss 2.0
Model saved
Epoch 1100 -> prob_abs_loss 0.05191540718078613 | sampled_mse_loss 5.0
Model saved
Epoch 1110 -> prob_abs_loss 0.05172300338745117 | sampled_mse_loss 1.0
Model saved
Epoch 1120 -> prob_abs_loss 0.051579415798187256 | sampled_mse_loss 8.0
Model saved
Epoch 1130 -> prob_abs_loss 0.05153089761734009 | sampled_mse_loss 4.0
Model saved
Epoch 1140 -> prob_abs_loss 0.051514387130737305 | sampled_mse_loss 5.0
Mod

Epoch 2060 -> prob_abs_loss 0.043060362339019775 | sampled_mse_loss 6.0
Model saved
Epoch 2070 -> prob_abs_loss 0.04280298948287964 | sampled_mse_loss 4.0
Model saved
Epoch 2080 -> prob_abs_loss 0.04244166612625122 | sampled_mse_loss 6.0
Model saved
Epoch 2090 -> prob_abs_loss 0.04221218824386597 | sampled_mse_loss 2.0
Model saved
Epoch 2100 -> prob_abs_loss 0.0421367883682251 | sampled_mse_loss 7.0
Model saved
Epoch 2110 -> prob_abs_loss 0.04211074113845825 | sampled_mse_loss 3.0
Model saved
Epoch 2120 -> prob_abs_loss 0.042101919651031494 | sampled_mse_loss 2.0
Model saved
Epoch 2130 -> prob_abs_loss 0.04209417104721069 | sampled_mse_loss 2.0
Model saved
Epoch 2140 -> prob_abs_loss 0.0419653058052063 | sampled_mse_loss 3.0
Model saved
Epoch 2150 -> prob_abs_loss 0.04189032316207886 | sampled_mse_loss 3.0
Model saved
Epoch 2160 -> prob_abs_loss 0.041731059551239014 | sampled_mse_loss 1.0
Model saved
Epoch 2170 -> prob_abs_loss 0.04168045520782471 | sampled_mse_loss 7.0
Model saved
Epo

Epoch 3100 -> prob_abs_loss 0.036210834980010986 | sampled_mse_loss 3.0
Epoch 3110 -> prob_abs_loss 0.03614550828933716 | sampled_mse_loss 6.0
Model saved
Epoch 3120 -> prob_abs_loss 0.03601109981536865 | sampled_mse_loss 9.0
Model saved
Epoch 3130 -> prob_abs_loss 0.0357816219329834 | sampled_mse_loss 3.0
Model saved
Epoch 3140 -> prob_abs_loss 0.035633087158203125 | sampled_mse_loss 3.0
Model saved
Epoch 3150 -> prob_abs_loss 0.035584092140197754 | sampled_mse_loss 5.0
Model saved
Epoch 3160 -> prob_abs_loss 0.035567402839660645 | sampled_mse_loss 1.0
Model saved
Epoch 3170 -> prob_abs_loss 0.035561442375183105 | sampled_mse_loss 4.0
Model saved
Epoch 3180 -> prob_abs_loss 0.035472095012664795 | sampled_mse_loss 3.0
Model saved
Epoch 3190 -> prob_abs_loss 0.03543102741241455 | sampled_mse_loss 2.0
Model saved
Epoch 3200 -> prob_abs_loss 0.035417020320892334 | sampled_mse_loss 3.0
Model saved
Epoch 3210 -> prob_abs_loss 0.0354122519493103 | sampled_mse_loss 1.0
Model saved
Epoch 3220 

Epoch 4100 -> prob_abs_loss 0.030581414699554443 | sampled_mse_loss 6.0
Model saved
Epoch 4110 -> prob_abs_loss 0.030293524265289307 | sampled_mse_loss 4.0
Model saved
Epoch 4120 -> prob_abs_loss 0.030169427394866943 | sampled_mse_loss 5.0
Model saved
Epoch 4130 -> prob_abs_loss 0.030086517333984375 | sampled_mse_loss 5.0
Model saved
Epoch 4140 -> prob_abs_loss 0.03005880117416382 | sampled_mse_loss 9.0
Model saved
Epoch 4150 -> prob_abs_loss 0.030049264430999756 | sampled_mse_loss 1.0
Model saved
Epoch 4160 -> prob_abs_loss 0.030045926570892334 | sampled_mse_loss 5.0
Model saved
Epoch 4170 -> prob_abs_loss 0.0299718976020813 | sampled_mse_loss 2.0
Model saved
Epoch 4180 -> prob_abs_loss 0.02976226806640625 | sampled_mse_loss 2.0
Model saved
Epoch 4190 -> prob_abs_loss 0.029688596725463867 | sampled_mse_loss 1.0
Model saved
Epoch 4200 -> prob_abs_loss 0.029663681983947754 | sampled_mse_loss 3.0
Model saved
Epoch 4210 -> prob_abs_loss 0.029655098915100098 | sampled_mse_loss 3.0
Model sa

Epoch 5130 -> prob_abs_loss 0.026048362255096436 | sampled_mse_loss 4.0
Model saved
Epoch 5140 -> prob_abs_loss 0.02590346336364746 | sampled_mse_loss 4.0
Model saved
Epoch 5150 -> prob_abs_loss 0.025744318962097168 | sampled_mse_loss 3.0
Model saved
Epoch 5160 -> prob_abs_loss 0.02563035488128662 | sampled_mse_loss 3.0
Model saved
Epoch 5170 -> prob_abs_loss 0.025494873523712158 | sampled_mse_loss 4.0
Model saved
Epoch 5180 -> prob_abs_loss 0.02545088529586792 | sampled_mse_loss 3.0
Model saved
Epoch 5190 -> prob_abs_loss 0.02543550729751587 | sampled_mse_loss 11.0
Model saved
Epoch 5200 -> prob_abs_loss 0.025430262088775635 | sampled_mse_loss 2.0
Model saved
Epoch 5210 -> prob_abs_loss 0.025395691394805908 | sampled_mse_loss 4.0
Model saved
Epoch 5220 -> prob_abs_loss 0.025342106819152832 | sampled_mse_loss 4.0
Model saved
Epoch 5230 -> prob_abs_loss 0.02532482147216797 | sampled_mse_loss 3.0
Model saved
Epoch 5240 -> prob_abs_loss 0.025318622589111328 | sampled_mse_loss 3.0
Model sa

Model saved
Epoch 6140 -> prob_abs_loss 0.02302306890487671 | sampled_mse_loss 5.0
Model saved
Epoch 6150 -> prob_abs_loss 0.02294546365737915 | sampled_mse_loss 3.0
Model saved
Epoch 6160 -> prob_abs_loss 0.022840023040771484 | sampled_mse_loss 5.0
Model saved
Epoch 6170 -> prob_abs_loss 0.02273041009902954 | sampled_mse_loss 4.0
Model saved
Epoch 6180 -> prob_abs_loss 0.02268749475479126 | sampled_mse_loss 2.0
Model saved
Epoch 6190 -> prob_abs_loss 0.022624850273132324 | sampled_mse_loss 7.0
Model saved
Epoch 6200 -> prob_abs_loss 0.022544562816619873 | sampled_mse_loss 3.0
Model saved
Epoch 6210 -> prob_abs_loss 0.02251899242401123 | sampled_mse_loss 3.0
Model saved
Epoch 6220 -> prob_abs_loss 0.022510409355163574 | sampled_mse_loss 6.0
Model saved
Epoch 6230 -> prob_abs_loss 0.022507429122924805 | sampled_mse_loss 6.0
Model saved
Epoch 6240 -> prob_abs_loss 0.022506356239318848 | sampled_mse_loss 4.0
Model saved
Epoch 6250 -> prob_abs_loss 0.022505998611450195 | sampled_mse_loss 7

Epoch 7150 -> prob_abs_loss 0.02039855718612671 | sampled_mse_loss 5.0
Epoch 7160 -> prob_abs_loss 0.02039855718612671 | sampled_mse_loss 8.0
Epoch 7170 -> prob_abs_loss 0.02039855718612671 | sampled_mse_loss 6.0
Epoch 7180 -> prob_abs_loss 0.02039855718612671 | sampled_mse_loss 2.0
Epoch 7190 -> prob_abs_loss 0.020224690437316895 | sampled_mse_loss 3.0
Model saved
Epoch 7200 -> prob_abs_loss 0.020099163055419922 | sampled_mse_loss 2.0
Model saved
Epoch 7210 -> prob_abs_loss 0.020058929920196533 | sampled_mse_loss 0.0
Model saved
Epoch 7220 -> prob_abs_loss 0.020045220851898193 | sampled_mse_loss 2.0
Model saved
Epoch 7230 -> prob_abs_loss 0.020040512084960938 | sampled_mse_loss 4.0
Model saved
Epoch 7240 -> prob_abs_loss 0.020038843154907227 | sampled_mse_loss 2.0
Model saved
Epoch 7250 -> prob_abs_loss 0.020038247108459473 | sampled_mse_loss 3.0
Model saved
Epoch 7260 -> prob_abs_loss 0.020007550716400146 | sampled_mse_loss 5.0
Model saved
Epoch 7270 -> prob_abs_loss 0.01998120546340

Epoch 8200 -> prob_abs_loss 0.01812732219696045 | sampled_mse_loss 1.0
Model saved
Epoch 8210 -> prob_abs_loss 0.018126845359802246 | sampled_mse_loss 2.0
Model saved
Epoch 8220 -> prob_abs_loss 0.018126726150512695 | sampled_mse_loss 5.0
Model saved
Epoch 8230 -> prob_abs_loss 0.01808035373687744 | sampled_mse_loss 1.0
Model saved
Epoch 8240 -> prob_abs_loss 0.018065929412841797 | sampled_mse_loss 6.0
Model saved
Epoch 8250 -> prob_abs_loss 0.018061041831970215 | sampled_mse_loss 7.0
Model saved
Epoch 8260 -> prob_abs_loss 0.018059372901916504 | sampled_mse_loss 3.0
Model saved
Epoch 8270 -> prob_abs_loss 0.01802128553390503 | sampled_mse_loss 6.0
Model saved
Epoch 8280 -> prob_abs_loss 0.017988979816436768 | sampled_mse_loss 4.0
Model saved
Epoch 8290 -> prob_abs_loss 0.017978131771087646 | sampled_mse_loss 1.0
Model saved
Epoch 8300 -> prob_abs_loss 0.017931878566741943 | sampled_mse_loss 4.0
Model saved
Epoch 8310 -> prob_abs_loss 0.017857730388641357 | sampled_mse_loss 5.0
Model s

Epoch 9240 -> prob_abs_loss 0.016716063022613525 | sampled_mse_loss 1.0
Epoch 9250 -> prob_abs_loss 0.016716063022613525 | sampled_mse_loss 5.0
Epoch 9260 -> prob_abs_loss 0.016716063022613525 | sampled_mse_loss 2.0
Epoch 9270 -> prob_abs_loss 0.016716063022613525 | sampled_mse_loss 5.0
Epoch 9280 -> prob_abs_loss 0.016716063022613525 | sampled_mse_loss 3.0
Epoch 9290 -> prob_abs_loss 0.016716063022613525 | sampled_mse_loss 4.0
Epoch 9300 -> prob_abs_loss 0.016709744930267334 | sampled_mse_loss 6.0
Model saved
Epoch 9310 -> prob_abs_loss 0.01667565107345581 | sampled_mse_loss 8.0
Model saved
Epoch 9320 -> prob_abs_loss 0.01666468381881714 | sampled_mse_loss 12.0
Model saved
Epoch 9330 -> prob_abs_loss 0.016660988330841064 | sampled_mse_loss 5.0
Model saved
Epoch 9340 -> prob_abs_loss 0.016659677028656006 | sampled_mse_loss 2.0
Model saved
Epoch 9350 -> prob_abs_loss 0.016659319400787354 | sampled_mse_loss 4.0
Model saved
Epoch 9360 -> prob_abs_loss 0.016659080982208252 | sampled_mse_lo

Epoch 10300 -> prob_abs_loss 0.015589416027069092 | sampled_mse_loss 5.0
Model saved
Epoch 10310 -> prob_abs_loss 0.015567243099212646 | sampled_mse_loss 7.0
Model saved
Epoch 10320 -> prob_abs_loss 0.015559852123260498 | sampled_mse_loss 2.0
Model saved
Epoch 10330 -> prob_abs_loss 0.01555722951889038 | sampled_mse_loss 5.0
Model saved
Epoch 10340 -> prob_abs_loss 0.015556395053863525 | sampled_mse_loss 2.0
Model saved
Epoch 10350 -> prob_abs_loss 0.015556156635284424 | sampled_mse_loss 0.0
Model saved
Epoch 10360 -> prob_abs_loss 0.015556037425994873 | sampled_mse_loss 6.0
Model saved
Epoch 10370 -> prob_abs_loss 0.015555918216705322 | sampled_mse_loss 5.0
Model saved
Epoch 10380 -> prob_abs_loss 0.015555918216705322 | sampled_mse_loss 4.0
Epoch 10390 -> prob_abs_loss 0.015555918216705322 | sampled_mse_loss 10.0
Epoch 10400 -> prob_abs_loss 0.015506148338317871 | sampled_mse_loss 4.0
Model saved
Epoch 10410 -> prob_abs_loss 0.015450835227966309 | sampled_mse_loss 1.0
Model saved
Epoc

Epoch 11380 -> prob_abs_loss 0.01421266794204712 | sampled_mse_loss 6.0
Epoch 11390 -> prob_abs_loss 0.014167964458465576 | sampled_mse_loss 3.0
Model saved
Epoch 11400 -> prob_abs_loss 0.014129877090454102 | sampled_mse_loss 5.0
Model saved
Epoch 11410 -> prob_abs_loss 0.01411736011505127 | sampled_mse_loss 3.0
Model saved
Epoch 11420 -> prob_abs_loss 0.013979494571685791 | sampled_mse_loss 3.0
Model saved
Epoch 11430 -> prob_abs_loss 0.01393735408782959 | sampled_mse_loss 5.0
Model saved
Epoch 11440 -> prob_abs_loss 0.013923287391662598 | sampled_mse_loss 5.0
Model saved
Epoch 11450 -> prob_abs_loss 0.013918399810791016 | sampled_mse_loss 1.0
Model saved
Epoch 11460 -> prob_abs_loss 0.013916730880737305 | sampled_mse_loss 4.0
Model saved
Epoch 11470 -> prob_abs_loss 0.01391613483428955 | sampled_mse_loss 2.0
Model saved
Epoch 11480 -> prob_abs_loss 0.013916015625 | sampled_mse_loss 1.0
Model saved
Epoch 11490 -> prob_abs_loss 0.01391589641571045 | sampled_mse_loss 5.0
Model saved
Epo

Epoch 12450 -> prob_abs_loss 0.01341867446899414 | sampled_mse_loss 7.0
Model saved
Epoch 12460 -> prob_abs_loss 0.013410210609436035 | sampled_mse_loss 5.0
Model saved
Epoch 12470 -> prob_abs_loss 0.013407230377197266 | sampled_mse_loss 4.0
Model saved
Epoch 12480 -> prob_abs_loss 0.013406157493591309 | sampled_mse_loss 3.0
Model saved
Epoch 12490 -> prob_abs_loss 0.013405680656433105 | sampled_mse_loss 4.0
Model saved
Epoch 12500 -> prob_abs_loss 0.013405680656433105 | sampled_mse_loss 4.0
Epoch 12510 -> prob_abs_loss 0.013405561447143555 | sampled_mse_loss 2.0
Model saved
Epoch 12520 -> prob_abs_loss 0.013405561447143555 | sampled_mse_loss 2.0
Epoch 12530 -> prob_abs_loss 0.013405561447143555 | sampled_mse_loss 2.0
Epoch 12540 -> prob_abs_loss 0.013405561447143555 | sampled_mse_loss 3.0
Epoch 12550 -> prob_abs_loss 0.013405561447143555 | sampled_mse_loss 7.0
Epoch 12560 -> prob_abs_loss 0.013356506824493408 | sampled_mse_loss 5.0
Model saved
Epoch 12570 -> prob_abs_loss 0.0133306384

Epoch 13510 -> prob_abs_loss 0.012724995613098145 | sampled_mse_loss 2.0
Epoch 13520 -> prob_abs_loss 0.012724995613098145 | sampled_mse_loss 6.0
Epoch 13530 -> prob_abs_loss 0.012724995613098145 | sampled_mse_loss 4.0
Epoch 13540 -> prob_abs_loss 0.012724995613098145 | sampled_mse_loss 6.0
Epoch 13550 -> prob_abs_loss 0.012724995613098145 | sampled_mse_loss 0.0
Epoch 13560 -> prob_abs_loss 0.012724995613098145 | sampled_mse_loss 6.0
Epoch 13570 -> prob_abs_loss 0.012724995613098145 | sampled_mse_loss 5.0
Epoch 13580 -> prob_abs_loss 0.012724995613098145 | sampled_mse_loss 2.0
Epoch 13590 -> prob_abs_loss 0.012724995613098145 | sampled_mse_loss 5.0
Epoch 13600 -> prob_abs_loss 0.012724995613098145 | sampled_mse_loss 8.0
Epoch 13610 -> prob_abs_loss 0.012709736824035645 | sampled_mse_loss 2.0
Model saved
Epoch 13620 -> prob_abs_loss 0.012692689895629883 | sampled_mse_loss 2.0
Model saved
Epoch 13630 -> prob_abs_loss 0.012590289115905762 | sampled_mse_loss 7.0
Model saved
Epoch 13640 -> 

Epoch 14640 -> prob_abs_loss 0.011923611164093018 | sampled_mse_loss 3.0
Model saved
Epoch 14650 -> prob_abs_loss 0.011922657489776611 | sampled_mse_loss 2.0
Model saved
Epoch 14660 -> prob_abs_loss 0.011922299861907959 | sampled_mse_loss 4.0
Model saved
Epoch 14670 -> prob_abs_loss 0.011922180652618408 | sampled_mse_loss 6.0
Model saved
Epoch 14680 -> prob_abs_loss 0.011922180652618408 | sampled_mse_loss 5.0
Epoch 14690 -> prob_abs_loss 0.011922180652618408 | sampled_mse_loss 4.0
Epoch 14700 -> prob_abs_loss 0.011903464794158936 | sampled_mse_loss 2.0
Model saved
Epoch 14710 -> prob_abs_loss 0.01184624433517456 | sampled_mse_loss 2.0
Model saved
Epoch 14720 -> prob_abs_loss 0.011813580989837646 | sampled_mse_loss 5.0
Model saved
Epoch 14730 -> prob_abs_loss 0.011802911758422852 | sampled_mse_loss 6.0
Model saved
Epoch 14740 -> prob_abs_loss 0.011799216270446777 | sampled_mse_loss 6.0
Model saved
Epoch 14750 -> prob_abs_loss 0.01179802417755127 | sampled_mse_loss 8.0
Model saved
Epoch 

Model saved
Epoch 15660 -> prob_abs_loss 0.010864853858947754 | sampled_mse_loss 3.0
Model saved
Epoch 15670 -> prob_abs_loss 0.010857820510864258 | sampled_mse_loss 7.0
Model saved
Epoch 15680 -> prob_abs_loss 0.010855317115783691 | sampled_mse_loss 4.0
Model saved
Epoch 15690 -> prob_abs_loss 0.010854482650756836 | sampled_mse_loss 5.0
Model saved
Epoch 15700 -> prob_abs_loss 0.010854125022888184 | sampled_mse_loss 3.0
Model saved
Epoch 15710 -> prob_abs_loss 0.010854125022888184 | sampled_mse_loss 2.0
Epoch 15720 -> prob_abs_loss 0.010854005813598633 | sampled_mse_loss 3.0
Model saved
Epoch 15730 -> prob_abs_loss 0.010854005813598633 | sampled_mse_loss 3.0
Epoch 15740 -> prob_abs_loss 0.010854005813598633 | sampled_mse_loss 3.0
Epoch 15750 -> prob_abs_loss 0.010752439498901367 | sampled_mse_loss 4.0
Model saved
Epoch 15760 -> prob_abs_loss 0.01070857048034668 | sampled_mse_loss 3.0
Model saved
Epoch 15770 -> prob_abs_loss 0.01069408655166626 | sampled_mse_loss 5.0
Model saved
Epoch 

Epoch 16730 -> prob_abs_loss 0.010134696960449219 | sampled_mse_loss 3.0
Epoch 16740 -> prob_abs_loss 0.010134696960449219 | sampled_mse_loss 2.0
Epoch 16750 -> prob_abs_loss 0.010134696960449219 | sampled_mse_loss 2.0
Epoch 16760 -> prob_abs_loss 0.010134696960449219 | sampled_mse_loss 5.0
Epoch 16770 -> prob_abs_loss 0.010134696960449219 | sampled_mse_loss 7.0
Epoch 16780 -> prob_abs_loss 0.010134696960449219 | sampled_mse_loss 3.0
Epoch 16790 -> prob_abs_loss 0.010134696960449219 | sampled_mse_loss 7.0
Epoch 16800 -> prob_abs_loss 0.010134696960449219 | sampled_mse_loss 4.0
Epoch 16810 -> prob_abs_loss 0.010134696960449219 | sampled_mse_loss 3.0
Epoch 16820 -> prob_abs_loss 0.010134696960449219 | sampled_mse_loss 4.0
Epoch 16830 -> prob_abs_loss 0.010134696960449219 | sampled_mse_loss 5.0
Epoch 16840 -> prob_abs_loss 0.010134696960449219 | sampled_mse_loss 5.0
Epoch 16850 -> prob_abs_loss 0.010134696960449219 | sampled_mse_loss 2.0
Epoch 16860 -> prob_abs_loss 0.010134696960449219 |

Epoch 17800 -> prob_abs_loss 0.009307622909545898 | sampled_mse_loss 7.0
Model saved
Epoch 17810 -> prob_abs_loss 0.009259700775146484 | sampled_mse_loss 5.0
Model saved
Epoch 17820 -> prob_abs_loss 0.009227633476257324 | sampled_mse_loss 3.0
Model saved
Epoch 17830 -> prob_abs_loss 0.009217023849487305 | sampled_mse_loss 3.0
Model saved
Epoch 17840 -> prob_abs_loss 0.00921332836151123 | sampled_mse_loss 3.0
Model saved
Epoch 17850 -> prob_abs_loss 0.009212136268615723 | sampled_mse_loss 3.0
Model saved
Epoch 17860 -> prob_abs_loss 0.00921165943145752 | sampled_mse_loss 4.0
Model saved
Epoch 17870 -> prob_abs_loss 0.009211540222167969 | sampled_mse_loss 2.0
Model saved
Epoch 17880 -> prob_abs_loss 0.009211540222167969 | sampled_mse_loss 4.0
Epoch 17890 -> prob_abs_loss 0.009211421012878418 | sampled_mse_loss 1.0
Model saved
Epoch 17900 -> prob_abs_loss 0.009199142456054688 | sampled_mse_loss 3.0
Model saved
Epoch 17910 -> prob_abs_loss 0.00918567180633545 | sampled_mse_loss 4.0
Model s

Epoch 18860 -> prob_abs_loss 0.008806228637695312 | sampled_mse_loss 4.0
Epoch 18870 -> prob_abs_loss 0.008806228637695312 | sampled_mse_loss 6.0
Epoch 18880 -> prob_abs_loss 0.008806228637695312 | sampled_mse_loss 3.0
Epoch 18890 -> prob_abs_loss 0.008806228637695312 | sampled_mse_loss 6.0
Epoch 18900 -> prob_abs_loss 0.008806228637695312 | sampled_mse_loss 4.0
Epoch 18910 -> prob_abs_loss 0.008806228637695312 | sampled_mse_loss 12.0
Epoch 18920 -> prob_abs_loss 0.008806228637695312 | sampled_mse_loss 5.0
Epoch 18930 -> prob_abs_loss 0.008806228637695312 | sampled_mse_loss 6.0
Epoch 18940 -> prob_abs_loss 0.008806228637695312 | sampled_mse_loss 4.0
Epoch 18950 -> prob_abs_loss 0.008806228637695312 | sampled_mse_loss 2.0
Epoch 18960 -> prob_abs_loss 0.008806228637695312 | sampled_mse_loss 4.0
Epoch 18970 -> prob_abs_loss 0.008806228637695312 | sampled_mse_loss 2.0
Epoch 18980 -> prob_abs_loss 0.008806228637695312 | sampled_mse_loss 8.0
Epoch 18990 -> prob_abs_loss 0.008783221244812012 

Epoch 19970 -> prob_abs_loss 0.008296728134155273 | sampled_mse_loss 4.0
Epoch 19980 -> prob_abs_loss 0.008275210857391357 | sampled_mse_loss 4.0
Model saved
Epoch 19990 -> prob_abs_loss 0.00825732946395874 | sampled_mse_loss 4.0
Model saved
Epoch 20000 -> prob_abs_loss 0.008251488208770752 | sampled_mse_loss 3.0
Model saved
Epoch 20010 -> prob_abs_loss 0.00824958086013794 | sampled_mse_loss 8.0
Model saved
Epoch 20020 -> prob_abs_loss 0.008248865604400635 | sampled_mse_loss 3.0
Model saved
Epoch 20030 -> prob_abs_loss 0.008248627185821533 | sampled_mse_loss 6.0
Model saved
Epoch 20040 -> prob_abs_loss 0.008248507976531982 | sampled_mse_loss 2.0
Model saved
Epoch 20050 -> prob_abs_loss 0.008248507976531982 | sampled_mse_loss 2.0
Epoch 20060 -> prob_abs_loss 0.008230984210968018 | sampled_mse_loss 11.0
Model saved
Epoch 20070 -> prob_abs_loss 0.008211672306060791 | sampled_mse_loss 5.0
Model saved
Epoch 20080 -> prob_abs_loss 0.008178174495697021 | sampled_mse_loss 3.0
Model saved
Epoch

Epoch 21010 -> prob_abs_loss 0.007817625999450684 | sampled_mse_loss 7.0
Epoch 21020 -> prob_abs_loss 0.007782280445098877 | sampled_mse_loss 6.0
Model saved
Epoch 21030 -> prob_abs_loss 0.007769525051116943 | sampled_mse_loss 1.0
Model saved
Epoch 21040 -> prob_abs_loss 0.007765233516693115 | sampled_mse_loss 3.0
Model saved
Epoch 21050 -> prob_abs_loss 0.007763803005218506 | sampled_mse_loss 7.0
Model saved
Epoch 21060 -> prob_abs_loss 0.007763326168060303 | sampled_mse_loss 2.0
Model saved
Epoch 21070 -> prob_abs_loss 0.007763087749481201 | sampled_mse_loss 5.0
Model saved
Epoch 21080 -> prob_abs_loss 0.007763087749481201 | sampled_mse_loss 3.0
Epoch 21090 -> prob_abs_loss 0.007763087749481201 | sampled_mse_loss 3.0
Epoch 21100 -> prob_abs_loss 0.007763087749481201 | sampled_mse_loss 1.0
Epoch 21110 -> prob_abs_loss 0.007763087749481201 | sampled_mse_loss 2.0
Epoch 21120 -> prob_abs_loss 0.007763087749481201 | sampled_mse_loss 5.0
Epoch 21130 -> prob_abs_loss 0.007763087749481201 | 

Epoch 22130 -> prob_abs_loss 0.0074043869972229 | sampled_mse_loss 2.0
Model saved
Epoch 22140 -> prob_abs_loss 0.007366001605987549 | sampled_mse_loss 5.0
Model saved
Epoch 22150 -> prob_abs_loss 0.007353544235229492 | sampled_mse_loss 2.0
Model saved
Epoch 22160 -> prob_abs_loss 0.0073490142822265625 | sampled_mse_loss 5.0
Model saved
Epoch 22170 -> prob_abs_loss 0.007347702980041504 | sampled_mse_loss 3.0
Model saved
Epoch 22180 -> prob_abs_loss 0.00734710693359375 | sampled_mse_loss 5.0
Model saved
Epoch 22190 -> prob_abs_loss 0.0073468685150146484 | sampled_mse_loss 5.0
Model saved
Epoch 22200 -> prob_abs_loss 0.0073403120040893555 | sampled_mse_loss 1.0
Model saved
Epoch 22210 -> prob_abs_loss 0.00730520486831665 | sampled_mse_loss 3.0
Model saved
Epoch 22220 -> prob_abs_loss 0.007294356822967529 | sampled_mse_loss 4.0
Model saved
Epoch 22230 -> prob_abs_loss 0.007290661334991455 | sampled_mse_loss 3.0
Model saved
Epoch 22240 -> prob_abs_loss 0.007289469242095947 | sampled_mse_lo

Epoch 23180 -> prob_abs_loss 0.006969749927520752 | sampled_mse_loss 6.0
Epoch 23190 -> prob_abs_loss 0.006969749927520752 | sampled_mse_loss 5.0
Epoch 23200 -> prob_abs_loss 0.006969749927520752 | sampled_mse_loss 3.0
Epoch 23210 -> prob_abs_loss 0.006969749927520752 | sampled_mse_loss 2.0
Epoch 23220 -> prob_abs_loss 0.006969749927520752 | sampled_mse_loss 10.0
Epoch 23230 -> prob_abs_loss 0.006969749927520752 | sampled_mse_loss 4.0
Epoch 23240 -> prob_abs_loss 0.006969749927520752 | sampled_mse_loss 0.0
Epoch 23250 -> prob_abs_loss 0.006969749927520752 | sampled_mse_loss 3.0
Epoch 23260 -> prob_abs_loss 0.006969749927520752 | sampled_mse_loss 7.0
Epoch 23270 -> prob_abs_loss 0.006969749927520752 | sampled_mse_loss 2.0
Epoch 23280 -> prob_abs_loss 0.006969749927520752 | sampled_mse_loss 1.0
Epoch 23290 -> prob_abs_loss 0.006969749927520752 | sampled_mse_loss 4.0
Epoch 23300 -> prob_abs_loss 0.006969749927520752 | sampled_mse_loss 3.0
Epoch 23310 -> prob_abs_loss 0.006969749927520752 

Epoch 24280 -> prob_abs_loss 0.006631314754486084 | sampled_mse_loss 1.0
Model saved
Epoch 24290 -> prob_abs_loss 0.006631314754486084 | sampled_mse_loss 7.0
Epoch 24300 -> prob_abs_loss 0.006631314754486084 | sampled_mse_loss 4.0
Epoch 24310 -> prob_abs_loss 0.006631314754486084 | sampled_mse_loss 2.0
Epoch 24320 -> prob_abs_loss 0.006618916988372803 | sampled_mse_loss 2.0
Model saved
Epoch 24330 -> prob_abs_loss 0.006599485874176025 | sampled_mse_loss 3.0
Model saved
Epoch 24340 -> prob_abs_loss 0.006593286991119385 | sampled_mse_loss 3.0
Model saved
Epoch 24350 -> prob_abs_loss 0.0065912604331970215 | sampled_mse_loss 5.0
Model saved
Epoch 24360 -> prob_abs_loss 0.006585657596588135 | sampled_mse_loss 6.0
Model saved
Epoch 24370 -> prob_abs_loss 0.00657278299331665 | sampled_mse_loss 5.0
Model saved
Epoch 24380 -> prob_abs_loss 0.006568610668182373 | sampled_mse_loss 4.0
Model saved
Epoch 24390 -> prob_abs_loss 0.0065672993659973145 | sampled_mse_loss 5.0
Model saved
Epoch 24400 -> 

Epoch 25400 -> prob_abs_loss 0.006482958793640137 | sampled_mse_loss 5.0
Epoch 25410 -> prob_abs_loss 0.006482958793640137 | sampled_mse_loss 8.0
Epoch 25420 -> prob_abs_loss 0.006482958793640137 | sampled_mse_loss 1.0
Epoch 25430 -> prob_abs_loss 0.006482958793640137 | sampled_mse_loss 4.0
Epoch 25440 -> prob_abs_loss 0.006482958793640137 | sampled_mse_loss 3.0
Epoch 25450 -> prob_abs_loss 0.006482958793640137 | sampled_mse_loss 3.0
Epoch 25460 -> prob_abs_loss 0.006482958793640137 | sampled_mse_loss 12.0
Epoch 25470 -> prob_abs_loss 0.006482958793640137 | sampled_mse_loss 9.0
Epoch 25480 -> prob_abs_loss 0.006482958793640137 | sampled_mse_loss 3.0
Epoch 25490 -> prob_abs_loss 0.006482958793640137 | sampled_mse_loss 4.0
Epoch 25500 -> prob_abs_loss 0.006482958793640137 | sampled_mse_loss 2.0
Epoch 25510 -> prob_abs_loss 0.006482958793640137 | sampled_mse_loss 6.0
Epoch 25520 -> prob_abs_loss 0.006482958793640137 | sampled_mse_loss 2.0
Epoch 25530 -> prob_abs_loss 0.006482958793640137 

Epoch 26540 -> prob_abs_loss 0.006246685981750488 | sampled_mse_loss 5.0
Epoch 26550 -> prob_abs_loss 0.006246685981750488 | sampled_mse_loss 3.0
Epoch 26560 -> prob_abs_loss 0.006246685981750488 | sampled_mse_loss 4.0
Epoch 26570 -> prob_abs_loss 0.006246685981750488 | sampled_mse_loss 2.0
Epoch 26580 -> prob_abs_loss 0.006246685981750488 | sampled_mse_loss 1.0
Epoch 26590 -> prob_abs_loss 0.006246685981750488 | sampled_mse_loss 3.0
Epoch 26600 -> prob_abs_loss 0.006246685981750488 | sampled_mse_loss 4.0
Epoch 26610 -> prob_abs_loss 0.006246685981750488 | sampled_mse_loss 5.0
Epoch 26620 -> prob_abs_loss 0.006246685981750488 | sampled_mse_loss 4.0
Epoch 26630 -> prob_abs_loss 0.006246685981750488 | sampled_mse_loss 3.0
Epoch 26640 -> prob_abs_loss 0.006246685981750488 | sampled_mse_loss 3.0
Epoch 26650 -> prob_abs_loss 0.006246685981750488 | sampled_mse_loss 5.0
Epoch 26660 -> prob_abs_loss 0.006246685981750488 | sampled_mse_loss 5.0
Epoch 26670 -> prob_abs_loss 0.006246685981750488 |

Epoch 27630 -> prob_abs_loss 0.005967199802398682 | sampled_mse_loss 3.0
Epoch 27640 -> prob_abs_loss 0.005967199802398682 | sampled_mse_loss 5.0
Epoch 27650 -> prob_abs_loss 0.005967199802398682 | sampled_mse_loss 4.0
Epoch 27660 -> prob_abs_loss 0.005967199802398682 | sampled_mse_loss 4.0
Epoch 27670 -> prob_abs_loss 0.005967199802398682 | sampled_mse_loss 3.0
Epoch 27680 -> prob_abs_loss 0.005967199802398682 | sampled_mse_loss 4.0
Epoch 27690 -> prob_abs_loss 0.005967199802398682 | sampled_mse_loss 5.0
Epoch 27700 -> prob_abs_loss 0.005967199802398682 | sampled_mse_loss 3.0
Epoch 27710 -> prob_abs_loss 0.005967199802398682 | sampled_mse_loss 6.0
Epoch 27720 -> prob_abs_loss 0.005967199802398682 | sampled_mse_loss 4.0
Epoch 27730 -> prob_abs_loss 0.005967199802398682 | sampled_mse_loss 5.0
Epoch 27740 -> prob_abs_loss 0.005967199802398682 | sampled_mse_loss 5.0
Epoch 27750 -> prob_abs_loss 0.005967199802398682 | sampled_mse_loss 4.0
Epoch 27760 -> prob_abs_loss 0.005967199802398682 |

Epoch 28770 -> prob_abs_loss 0.005836844444274902 | sampled_mse_loss 9.0
Epoch 28780 -> prob_abs_loss 0.005836844444274902 | sampled_mse_loss 5.0
Epoch 28790 -> prob_abs_loss 0.005836844444274902 | sampled_mse_loss 7.0
Epoch 28800 -> prob_abs_loss 0.005836844444274902 | sampled_mse_loss 3.0
Epoch 28810 -> prob_abs_loss 0.005830168724060059 | sampled_mse_loss 7.0
Model saved
Epoch 28820 -> prob_abs_loss 0.00582277774810791 | sampled_mse_loss 5.0
Model saved
Epoch 28830 -> prob_abs_loss 0.0058203935623168945 | sampled_mse_loss 6.0
Model saved
Epoch 28840 -> prob_abs_loss 0.005819559097290039 | sampled_mse_loss 6.0
Model saved
Epoch 28850 -> prob_abs_loss 0.0058193206787109375 | sampled_mse_loss 5.0
Model saved
Epoch 28860 -> prob_abs_loss 0.005819201469421387 | sampled_mse_loss 4.0
Model saved
Epoch 28870 -> prob_abs_loss 0.005819201469421387 | sampled_mse_loss 3.0
Epoch 28880 -> prob_abs_loss 0.005819201469421387 | sampled_mse_loss 2.0
Epoch 28890 -> prob_abs_loss 0.005790412425994873 |

Epoch 29860 -> prob_abs_loss 0.005583345890045166 | sampled_mse_loss 3.0
Epoch 29870 -> prob_abs_loss 0.005583345890045166 | sampled_mse_loss 5.0
Epoch 29880 -> prob_abs_loss 0.005583345890045166 | sampled_mse_loss 9.0
Epoch 29890 -> prob_abs_loss 0.005583345890045166 | sampled_mse_loss 1.0
Epoch 29900 -> prob_abs_loss 0.005583345890045166 | sampled_mse_loss 4.0
Epoch 29910 -> prob_abs_loss 0.005564868450164795 | sampled_mse_loss 5.0
Model saved
Epoch 29920 -> prob_abs_loss 0.005549490451812744 | sampled_mse_loss 8.0
Model saved
Epoch 29930 -> prob_abs_loss 0.005544602870941162 | sampled_mse_loss 6.0
Model saved
Epoch 29940 -> prob_abs_loss 0.005542933940887451 | sampled_mse_loss 3.0
Model saved
Epoch 29950 -> prob_abs_loss 0.005542337894439697 | sampled_mse_loss 5.0
Model saved
Epoch 29960 -> prob_abs_loss 0.005542099475860596 | sampled_mse_loss 2.0
Model saved
Epoch 29970 -> prob_abs_loss 0.005542099475860596 | sampled_mse_loss 6.0
Epoch 29980 -> prob_abs_loss 0.005541980266571045 | 