In [1]:
import math as m
import numpy as np
import random as r
import matplotlib.pyplot as plt
import sys

In [2]:
import torch
from torch import nn
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR

In [3]:
from nflows.flows.base import Flow
from nflows.distributions.uniform import BoxUniform
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedPiecewiseRationalQuadraticAutoregressiveTransform
from nflows.transforms.autoregressive import MaskedPiecewiseQuadraticAutoregressiveTransform
from nflows.transforms.permutations import ReversePermutation
from nflows.transforms.permutations import RandomPermutation
from nflows.transforms.splines.rational_quadratic import rational_quadratic_spline
from torch.utils.tensorboard import SummaryWriter

In [4]:
import subprocess
import time
import os
from copy import deepcopy
import math as m
import gc

## Tensorboard writer for loss logging

In [5]:
writer = SummaryWriter()

## GPU/CPU selection

In [6]:
device = torch.device("cuda:0")
#device = torch.device("cpu")

## Hyperparameters

In [7]:
n_RQS_knots = 5     # Number of knots in RQS transform
n_made_layers = 1   # Number of hidden layers in every made network
n_made_units = 100  # Number of units in every layer of the made network
n_flow_layers = 6   # Number of layers in the flow

batch_size = 1024
n_epochs = 800
adam_lr = 0.001     # Learning rate for the ADAM optimizer (default: 0.001)

n_train = int(1e6)  # Number of training events
n_test = int(1e5)   # Number of testing events
n_sample = int(1e6) # Number of samples for ess evaluation

## Load the training data

In [None]:
samples = np.genfromtxt("data/unweighted_samples.csv", delimiter=',')
if (n_train + n_test > samples.shape[0]):
    raise Exception("Not enough training data")

## Split to a train and test set

In [None]:
train_samples = torch.tensor(samples[:n_train], dtype=torch.float32, device=device)
test_samples = torch.tensor(samples[n_train:n_train+n_test], dtype=torch.float32, device=device)

del samples
gc.collect()

59

## Set up the flow

In [None]:
event_dim = train_samples.shape[1]
base_dist = BoxUniform(torch.zeros(event_dim), torch.ones(event_dim))

transforms = []
for _ in range(n_flow_layers):
    transforms.append(RandomPermutation(features=event_dim))
    transforms.append(MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
        features=event_dim, 
        hidden_features=n_made_units,
        num_bins=n_RQS_knots,
        num_blocks=n_made_layers-1,
        tails="constrained",
        use_residual_blocks=False
    ))
transform = CompositeTransform(transforms)

flow = Flow(transform, base_dist).to(device)
optimizer = optim.Adam(flow.parameters(), lr=adam_lr)

scheduler = MultiStepLR(optimizer, milestones=[350, 425, 500, 575, 650, 725, 800], gamma=0.5)

## Training

In [None]:
data_size = train_samples.shape[0]
n_batches = m.ceil(data_size/batch_size)

data_size_validation = test_samples.shape[0]
n_batches_validate = m.ceil(data_size_validation/batch_size)

best_validation_loss = np.inf
best_ess = 0

for epoch in range(n_epochs):
    
    permutation = torch.randperm(data_size, device=device)    

    # Loop over batches
    cum_loss = 0
    for batch in range(n_batches):
        # Set up the batch
        batch_begin = batch*batch_size
        batch_end   = min( (batch+1)*batch_size, data_size-1 )
        indices = permutation[batch_begin:batch_end]
        samples_batch = train_samples[indices]
        
        # Take a step
        optimizer.zero_grad()
        loss = -(flow.log_prob(inputs=samples_batch)).mean()
        loss.backward()
        optimizer.step()

        # Compute cumulative loss
        cum_loss = (cum_loss*batch + loss.item())/(batch+1)

        if batch%25 == 0:
            print("epoch = ", epoch, "batch = ", batch, "/", n_batches, "loss = ", cum_loss)
    
    writer.add_scalar("Loss_train", cum_loss, epoch)
    scheduler.step()


    # ---------- Compute validation loss -----------
    validation_loss = 0
    for batch in range(n_batches_validate):
        batch_begin = batch*batch_size
        batch_end = min( (batch+1)*batch_size, data_size_validation-1 )
        samples_batch = test_samples[batch_begin:batch_end]
    
        with torch.no_grad():
            validation_loss = (validation_loss*batch - (flow.log_prob(samples_batch)).mean())/(batch+1)

    print("Validation loss = ", validation_loss.item())
    writer.add_scalar("Loss_test", validation_loss.item(), epoch)

    if validation_loss < best_validation_loss:
        torch.save(flow, "flow_model_unweighted_best_validation-cuda0.pt")
        best_validation_loss = validation_loss

    
    # ---------- Compute effective sample size ----------
    # generate samples and evaluate llhs
    samples = None
    llhs = None
    with torch.no_grad():
        for i in range(10):
            s = flow.sample(int(n_sample/10))
            l = flow.log_prob(s)
            if samples is None:
                samples, llhs = s.cpu().numpy(), l.cpu().numpy()
            else:
                samples = np.vstack((samples, s.cpu().numpy()))
                llhs = np.vstack((llhs, l.cpu().numpy()))

    # Store files
    np.savetxt("/tmp/samples_file-cuda0.csv", samples, delimiter=',')
    np.savetxt("/tmp/llhs_file-cuda0.csv", np.exp(llhs), delimiter=',')

    # Run the evaluator
    cmd = os.path.abspath(os.getcwd())+'/ME_VEGAS/compute_metrics_from_likelihoods /tmp/samples_file.csv /tmp/llhs_file.csv'
    b = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE).stdout
    lines = b.decode('ascii').split("\n")

    ess = float(lines[2].split(' ')[-1])
    
    print("Effective sample size = ", ess)
    writer.add_scalar("Effective_sample_size", ess, epoch)

    if ess > best_ess:
        torch.save(flow, "flow_model_unweighted_best_ess-cuda0.pt")
        best_ess = ess
        
torch.save(flow, "flow_model_unweighted_final-cuda0.pt")

epoch =  0 batch =  0 / 977 loss =  -1.0416027307510376
epoch =  0 batch =  25 / 977 loss =  -6.431734199707325
epoch =  0 batch =  50 / 977 loss =  -10.011829846045552
epoch =  0 batch =  75 / 977 loss =  -12.376327942860758
epoch =  0 batch =  100 / 977 loss =  -14.109898804437998
epoch =  0 batch =  125 / 977 loss =  -15.341521960402293
epoch =  0 batch =  150 / 977 loss =  -16.27941198617417
epoch =  0 batch =  175 / 977 loss =  -17.0213098207658
epoch =  0 batch =  200 / 977 loss =  -17.586062104547796
epoch =  0 batch =  225 / 977 loss =  -18.07432940206697
epoch =  0 batch =  250 / 977 loss =  -18.505274564146525
epoch =  0 batch =  275 / 977 loss =  -18.873152864584025
epoch =  0 batch =  300 / 977 loss =  -19.190932302775963
epoch =  0 batch =  325 / 977 loss =  -19.458550260476535
epoch =  0 batch =  350 / 977 loss =  -19.690921658803934
epoch =  0 batch =  375 / 977 loss =  -19.899228441271386
epoch =  0 batch =  400 / 977 loss =  -20.080563714379398
epoch =  0 batch =  425 

epoch =  3 batch =  475 / 977 loss =  -23.338598840376918
epoch =  3 batch =  500 / 977 loss =  -23.338677442478332
epoch =  3 batch =  525 / 977 loss =  -23.338741063165127
epoch =  3 batch =  550 / 977 loss =  -23.338269508901835
epoch =  3 batch =  575 / 977 loss =  -23.33954894873831
epoch =  3 batch =  600 / 977 loss =  -23.33828739477276
epoch =  3 batch =  625 / 977 loss =  -23.338403025374248
epoch =  3 batch =  650 / 977 loss =  -23.340315603440818
epoch =  3 batch =  675 / 977 loss =  -23.34132449302449
epoch =  3 batch =  700 / 977 loss =  -23.34230420796918
epoch =  3 batch =  725 / 977 loss =  -23.34447663152187
epoch =  3 batch =  750 / 977 loss =  -23.344384882961247
epoch =  3 batch =  775 / 977 loss =  -23.34434564334832
epoch =  3 batch =  800 / 977 loss =  -23.344279321391948
epoch =  3 batch =  825 / 977 loss =  -23.344032871809784
epoch =  3 batch =  850 / 977 loss =  -23.346117967724396
epoch =  3 batch =  875 / 977 loss =  -23.34768231492066
epoch =  3 batch =  9

epoch =  6 batch =  950 / 977 loss =  -23.427965324634506
epoch =  6 batch =  975 / 977 loss =  -23.42775317684549
Validation loss =  -23.452537536621094
Effective sample size =  0.103185
epoch =  7 batch =  0 / 977 loss =  -23.55179214477539
epoch =  7 batch =  25 / 977 loss =  -23.40766290517954
epoch =  7 batch =  50 / 977 loss =  -23.409695793600644
epoch =  7 batch =  75 / 977 loss =  -23.429346486141807
epoch =  7 batch =  100 / 977 loss =  -23.433382506417757
epoch =  7 batch =  125 / 977 loss =  -23.44013728792705
epoch =  7 batch =  150 / 977 loss =  -23.437871465619818
epoch =  7 batch =  175 / 977 loss =  -23.445224664428007
epoch =  7 batch =  200 / 977 loss =  -23.446305260729424
epoch =  7 batch =  225 / 977 loss =  -23.441382956715803
epoch =  7 batch =  250 / 977 loss =  -23.445586200729302
epoch =  7 batch =  275 / 977 loss =  -23.445194603740298
epoch =  7 batch =  300 / 977 loss =  -23.443523704807614
epoch =  7 batch =  325 / 977 loss =  -23.444330765425786
epoch = 

epoch =  10 batch =  400 / 977 loss =  -23.4910737891447
epoch =  10 batch =  425 / 977 loss =  -23.491747041263498
epoch =  10 batch =  450 / 977 loss =  -23.49205063610543
epoch =  10 batch =  475 / 977 loss =  -23.489685459297252
epoch =  10 batch =  500 / 977 loss =  -23.486665592460106
epoch =  10 batch =  525 / 977 loss =  -23.48715352105551
epoch =  10 batch =  550 / 977 loss =  -23.486584585504833
epoch =  10 batch =  575 / 977 loss =  -23.48527737127411
epoch =  10 batch =  600 / 977 loss =  -23.486282796113947
epoch =  10 batch =  625 / 977 loss =  -23.487138294183413
epoch =  10 batch =  650 / 977 loss =  -23.486703430269547
epoch =  10 batch =  675 / 977 loss =  -23.486612901179747
epoch =  10 batch =  700 / 977 loss =  -23.48752373329413
epoch =  10 batch =  725 / 977 loss =  -23.48759303867981
epoch =  10 batch =  750 / 977 loss =  -23.48917290119928
epoch =  10 batch =  775 / 977 loss =  -23.488890470917692
epoch =  10 batch =  800 / 977 loss =  -23.489044894291073
epoch

epoch =  13 batch =  825 / 977 loss =  -23.508750114256188
epoch =  13 batch =  850 / 977 loss =  -23.508220979665897
epoch =  13 batch =  875 / 977 loss =  -23.508015094826757
epoch =  13 batch =  900 / 977 loss =  -23.50950988998159
epoch =  13 batch =  925 / 977 loss =  -23.507980719498374
epoch =  13 batch =  950 / 977 loss =  -23.50785200678588
epoch =  13 batch =  975 / 977 loss =  -23.50831168792287
Validation loss =  -23.422786712646484
Effective sample size =  1.37446e-06
epoch =  14 batch =  0 / 977 loss =  -23.513599395751953
epoch =  14 batch =  25 / 977 loss =  -23.499762755173904
epoch =  14 batch =  50 / 977 loss =  -23.52512411977731
epoch =  14 batch =  75 / 977 loss =  -23.514254369233786
epoch =  14 batch =  100 / 977 loss =  -23.525361372692753
epoch =  14 batch =  125 / 977 loss =  -23.52399268982903
epoch =  14 batch =  150 / 977 loss =  -23.516786928997924
epoch =  14 batch =  175 / 977 loss =  -23.51387610218742
epoch =  14 batch =  200 / 977 loss =  -23.5144515

epoch =  17 batch =  200 / 977 loss =  -23.52750646297017
epoch =  17 batch =  225 / 977 loss =  -23.532703138030723
epoch =  17 batch =  250 / 977 loss =  -23.53477137592208
epoch =  17 batch =  275 / 977 loss =  -23.540935606196292
epoch =  17 batch =  300 / 977 loss =  -23.541941576225792
epoch =  17 batch =  325 / 977 loss =  -23.539885269352247
epoch =  17 batch =  350 / 977 loss =  -23.538352705474573
epoch =  17 batch =  375 / 977 loss =  -23.53591046434766
epoch =  17 batch =  400 / 977 loss =  -23.534909621735743
epoch =  17 batch =  425 / 977 loss =  -23.538337962728136
epoch =  17 batch =  450 / 977 loss =  -23.538333588323134
epoch =  17 batch =  475 / 977 loss =  -23.536813335258405
epoch =  17 batch =  500 / 977 loss =  -23.536160272990383
epoch =  17 batch =  525 / 977 loss =  -23.53581134085418
epoch =  17 batch =  550 / 977 loss =  -23.533491418495785
epoch =  17 batch =  575 / 977 loss =  -23.533760706583635
epoch =  17 batch =  600 / 977 loss =  -23.532414022182245
e

epoch =  20 batch =  625 / 977 loss =  -23.556768490483595
epoch =  20 batch =  650 / 977 loss =  -23.556561629709933
epoch =  20 batch =  675 / 977 loss =  -23.55468368812427
epoch =  20 batch =  700 / 977 loss =  -23.5540804407226
epoch =  20 batch =  725 / 977 loss =  -23.553264938438577
epoch =  20 batch =  750 / 977 loss =  -23.552355968206143
epoch =  20 batch =  775 / 977 loss =  -23.552769707650267
epoch =  20 batch =  800 / 977 loss =  -23.550890335578313
epoch =  20 batch =  825 / 977 loss =  -23.550393651819128
epoch =  20 batch =  850 / 977 loss =  -23.54977525919502
epoch =  20 batch =  875 / 977 loss =  -23.54923513490861
epoch =  20 batch =  900 / 977 loss =  -23.547531828631584
epoch =  20 batch =  925 / 977 loss =  -23.54816640890959
epoch =  20 batch =  950 / 977 loss =  -23.546976861642875
epoch =  20 batch =  975 / 977 loss =  -23.54713245884318
Validation loss =  -23.40059471130371
Effective sample size =  0.112078
epoch =  21 batch =  0 / 977 loss =  -23.412302017

epoch =  24 batch =  25 / 977 loss =  -23.53301128974328
epoch =  24 batch =  50 / 977 loss =  -23.52524035584693
epoch =  24 batch =  75 / 977 loss =  -23.539735417616992
epoch =  24 batch =  100 / 977 loss =  -23.54765803271001
epoch =  24 batch =  125 / 977 loss =  -23.55465879894438
epoch =  24 batch =  150 / 977 loss =  -23.558596427866952
epoch =  24 batch =  175 / 977 loss =  -23.556398706002668
epoch =  24 batch =  200 / 977 loss =  -23.55067104130835
epoch =  24 batch =  225 / 977 loss =  -23.55034396078734
epoch =  24 batch =  250 / 977 loss =  -23.55111083376455
epoch =  24 batch =  275 / 977 loss =  -23.548700381016385
epoch =  24 batch =  300 / 977 loss =  -23.55066952119238
epoch =  24 batch =  325 / 977 loss =  -23.54956433816922
epoch =  24 batch =  350 / 977 loss =  -23.55218522704904
epoch =  24 batch =  375 / 977 loss =  -23.55135179580526
epoch =  24 batch =  400 / 977 loss =  -23.55313829531396
epoch =  24 batch =  425 / 977 loss =  -23.55273163934269
epoch =  24 b

epoch =  27 batch =  450 / 977 loss =  -23.560136147983314
epoch =  27 batch =  475 / 977 loss =  -23.55892982002065
epoch =  27 batch =  500 / 977 loss =  -23.55630236947369
epoch =  27 batch =  525 / 977 loss =  -23.555428900192894
epoch =  27 batch =  550 / 977 loss =  -23.55392010198963
epoch =  27 batch =  575 / 977 loss =  -23.555733253558476
epoch =  27 batch =  600 / 977 loss =  -23.555313964056694
epoch =  27 batch =  625 / 977 loss =  -23.55447150647832
epoch =  27 batch =  650 / 977 loss =  -23.555469070528357
epoch =  27 batch =  675 / 977 loss =  -23.55548874160951
epoch =  27 batch =  700 / 977 loss =  -23.55622787584425
epoch =  27 batch =  725 / 977 loss =  -23.556256840708528
epoch =  27 batch =  750 / 977 loss =  -23.5574533967934
epoch =  27 batch =  775 / 977 loss =  -23.55997123423314
epoch =  27 batch =  800 / 977 loss =  -23.559706334317696
epoch =  27 batch =  825 / 977 loss =  -23.558608840221957
epoch =  27 batch =  850 / 977 loss =  -23.557742707177443
epoch 

epoch =  30 batch =  875 / 977 loss =  -23.573421835355028
epoch =  30 batch =  900 / 977 loss =  -23.572737644038913
epoch =  30 batch =  925 / 977 loss =  -23.572617934538023
epoch =  30 batch =  950 / 977 loss =  -23.57317160958372
epoch =  30 batch =  975 / 977 loss =  -23.571788248468625
Validation loss =  -23.573701858520508
Effective sample size =  0.0364412
epoch =  31 batch =  0 / 977 loss =  -23.459257125854492
epoch =  31 batch =  25 / 977 loss =  -23.567093188946064
epoch =  31 batch =  50 / 977 loss =  -23.55039188908595
epoch =  31 batch =  75 / 977 loss =  -23.566100773058434
epoch =  31 batch =  100 / 977 loss =  -23.580639565345084
epoch =  31 batch =  125 / 977 loss =  -23.58408384474497
epoch =  31 batch =  150 / 977 loss =  -23.583476818160506
epoch =  31 batch =  175 / 977 loss =  -23.578006809408013
epoch =  31 batch =  200 / 977 loss =  -23.58050785728948
epoch =  31 batch =  225 / 977 loss =  -23.58463590335002
epoch =  31 batch =  250 / 977 loss =  -23.58258645

epoch =  34 batch =  275 / 977 loss =  -23.57713254983875
epoch =  34 batch =  300 / 977 loss =  -23.576962258728646
epoch =  34 batch =  325 / 977 loss =  -23.57850080033754
epoch =  34 batch =  350 / 977 loss =  -23.58240349815783
epoch =  34 batch =  375 / 977 loss =  -23.581629555276113
epoch =  34 batch =  400 / 977 loss =  -23.581402714413016
epoch =  34 batch =  425 / 977 loss =  -23.583264391187225
epoch =  34 batch =  450 / 977 loss =  -23.58419567201197
epoch =  34 batch =  475 / 977 loss =  -23.584388055721263
epoch =  34 batch =  500 / 977 loss =  -23.584146834657112
epoch =  34 batch =  525 / 977 loss =  -23.584974695067903
epoch =  34 batch =  550 / 977 loss =  -23.586680857109727
epoch =  34 batch =  575 / 977 loss =  -23.587727887762927
epoch =  34 batch =  600 / 977 loss =  -23.584957154538998
epoch =  34 batch =  625 / 977 loss =  -23.584873945949195
epoch =  34 batch =  650 / 977 loss =  -23.5847914998982
epoch =  34 batch =  675 / 977 loss =  -23.584214811494377
epo

epoch =  37 batch =  700 / 977 loss =  -23.81952721926352
epoch =  37 batch =  725 / 977 loss =  -23.82196896660754
epoch =  37 batch =  750 / 977 loss =  -23.823941966030148
epoch =  37 batch =  775 / 977 loss =  -23.82489166800508
epoch =  37 batch =  800 / 977 loss =  -23.827779555588616
epoch =  37 batch =  825 / 977 loss =  -23.82715027557446
epoch =  37 batch =  850 / 977 loss =  -23.829068021404467
epoch =  37 batch =  875 / 977 loss =  -23.830295948133067
epoch =  37 batch =  900 / 977 loss =  -23.831321146856
epoch =  37 batch =  925 / 977 loss =  -23.833089332086214
epoch =  37 batch =  950 / 977 loss =  -23.83246376687419
epoch =  37 batch =  975 / 977 loss =  -23.832697950425686
Validation loss =  -23.854013442993164
Effective sample size =  0.110289
epoch =  38 batch =  0 / 977 loss =  -23.809261322021484
epoch =  38 batch =  25 / 977 loss =  -23.884732026320236
epoch =  38 batch =  50 / 977 loss =  -23.886147442985983
epoch =  38 batch =  75 / 977 loss =  -23.880729800776

epoch =  41 batch =  100 / 977 loss =  -23.929797710758624
epoch =  41 batch =  125 / 977 loss =  -23.924843455117845
epoch =  41 batch =  150 / 977 loss =  -23.932473694251865
epoch =  41 batch =  175 / 977 loss =  -23.935811205343768
epoch =  41 batch =  200 / 977 loss =  -23.94115922462881
epoch =  41 batch =  225 / 977 loss =  -23.94198712205465
epoch =  41 batch =  250 / 977 loss =  -23.942350737127175
epoch =  41 batch =  275 / 977 loss =  -23.93804772003837
epoch =  41 batch =  300 / 977 loss =  -23.936332772340485
epoch =  41 batch =  325 / 977 loss =  -23.931333746646793
epoch =  41 batch =  350 / 977 loss =  -23.9353785066523
epoch =  41 batch =  375 / 977 loss =  -23.93542699103659
epoch =  41 batch =  400 / 977 loss =  -23.935735673975756
epoch =  41 batch =  425 / 977 loss =  -23.93452524355319
epoch =  41 batch =  450 / 977 loss =  -23.933150230120123
epoch =  41 batch =  475 / 977 loss =  -23.9311746028291
epoch =  41 batch =  500 / 977 loss =  -23.93308301076679
epoch =

epoch =  44 batch =  525 / 977 loss =  -23.928814082997835
epoch =  44 batch =  550 / 977 loss =  -23.928806993792584
epoch =  44 batch =  575 / 977 loss =  -23.926767888996338
epoch =  44 batch =  600 / 977 loss =  -23.927263002824063
epoch =  44 batch =  625 / 977 loss =  -23.929196970150493
epoch =  44 batch =  650 / 977 loss =  -23.928171131468037
epoch =  44 batch =  675 / 977 loss =  -23.928113731406842
epoch =  44 batch =  700 / 977 loss =  -23.92715357781816
epoch =  44 batch =  725 / 977 loss =  -23.928568845281248
epoch =  44 batch =  750 / 977 loss =  -23.928416923898823
epoch =  44 batch =  775 / 977 loss =  -23.928972062376346
epoch =  44 batch =  800 / 977 loss =  -23.928943557834497
epoch =  44 batch =  825 / 977 loss =  -23.92888457780888
epoch =  44 batch =  850 / 977 loss =  -23.92821556700942
epoch =  44 batch =  875 / 977 loss =  -23.927767089512788
epoch =  44 batch =  900 / 977 loss =  -23.927564309783826
epoch =  44 batch =  925 / 977 loss =  -23.92760168244462
e

epoch =  47 batch =  950 / 977 loss =  -23.93803185544181
epoch =  47 batch =  975 / 977 loss =  -23.938155012052583
Validation loss =  -23.910417556762695
Effective sample size =  0.0752158
epoch =  48 batch =  0 / 977 loss =  -23.755380630493164
epoch =  48 batch =  25 / 977 loss =  -23.941856384277344
epoch =  48 batch =  50 / 977 loss =  -23.966986338297527
epoch =  48 batch =  75 / 977 loss =  -23.974610228287546
epoch =  48 batch =  100 / 977 loss =  -23.971874350368388
epoch =  48 batch =  125 / 977 loss =  -23.964886831858802
epoch =  48 batch =  150 / 977 loss =  -23.95416391132683
epoch =  48 batch =  175 / 977 loss =  -23.9555968913165
epoch =  48 batch =  200 / 977 loss =  -23.95164466022852
epoch =  48 batch =  225 / 977 loss =  -23.953644668106485
epoch =  48 batch =  250 / 977 loss =  -23.950498284571673
epoch =  48 batch =  275 / 977 loss =  -23.950628957886625
epoch =  48 batch =  300 / 977 loss =  -23.952314053659034
epoch =  48 batch =  325 / 977 loss =  -23.95358598

epoch =  51 batch =  350 / 977 loss =  -23.943307338616773
epoch =  51 batch =  375 / 977 loss =  -23.942332222106593
epoch =  51 batch =  400 / 977 loss =  -23.944065431704235
epoch =  51 batch =  425 / 977 loss =  -23.943821593629348
epoch =  51 batch =  450 / 977 loss =  -23.943696417459616
epoch =  51 batch =  475 / 977 loss =  -23.9424337058508
epoch =  51 batch =  500 / 977 loss =  -23.94232012887676
epoch =  51 batch =  525 / 977 loss =  -23.944658199643897
epoch =  51 batch =  550 / 977 loss =  -23.94308861803446
epoch =  51 batch =  575 / 977 loss =  -23.943521516190632
epoch =  51 batch =  600 / 977 loss =  -23.94416690071093
epoch =  51 batch =  625 / 977 loss =  -23.94581759699617
epoch =  51 batch =  650 / 977 loss =  -23.944650389265533
epoch =  51 batch =  675 / 977 loss =  -23.942581997820596
epoch =  51 batch =  700 / 977 loss =  -23.942220097431598
epoch =  51 batch =  725 / 977 loss =  -23.942869273099017
epoch =  51 batch =  750 / 977 loss =  -23.9411821657427
epoch

epoch =  54 batch =  775 / 977 loss =  -23.945777177810672
epoch =  54 batch =  800 / 977 loss =  -23.945345875029265
epoch =  54 batch =  825 / 977 loss =  -23.94536541853344
epoch =  54 batch =  850 / 977 loss =  -23.944827213130466
epoch =  54 batch =  875 / 977 loss =  -23.946019640796266
epoch =  54 batch =  900 / 977 loss =  -23.945038798646582
epoch =  54 batch =  925 / 977 loss =  -23.945089408182433
epoch =  54 batch =  950 / 977 loss =  -23.94467230274349
epoch =  54 batch =  975 / 977 loss =  -23.945789001027094
Validation loss =  -23.918855667114258
Effective sample size =  0.00438931
epoch =  55 batch =  0 / 977 loss =  -23.991535186767578
epoch =  55 batch =  25 / 977 loss =  -23.995381061847393
epoch =  55 batch =  50 / 977 loss =  -23.971227084889133
epoch =  55 batch =  75 / 977 loss =  -23.95617934277183
epoch =  55 batch =  100 / 977 loss =  -23.95669812495165
epoch =  55 batch =  125 / 977 loss =  -23.949582538907485
epoch =  55 batch =  150 / 977 loss =  -23.950700

epoch =  58 batch =  175 / 977 loss =  -23.938422647389498
epoch =  58 batch =  200 / 977 loss =  -23.936646181552565
epoch =  58 batch =  225 / 977 loss =  -23.94250939799621
epoch =  58 batch =  250 / 977 loss =  -23.944734694948234
epoch =  58 batch =  275 / 977 loss =  -23.940178332121477
epoch =  58 batch =  300 / 977 loss =  -23.93928705893483
epoch =  58 batch =  325 / 977 loss =  -23.938121409503967
epoch =  58 batch =  350 / 977 loss =  -23.939945655670602
epoch =  58 batch =  375 / 977 loss =  -23.941890402043118
epoch =  58 batch =  400 / 977 loss =  -23.94174712316651
epoch =  58 batch =  425 / 977 loss =  -23.945326814069436
epoch =  58 batch =  450 / 977 loss =  -23.947243426168573
epoch =  58 batch =  475 / 977 loss =  -23.950709751674108
epoch =  58 batch =  500 / 977 loss =  -23.952216725149555
epoch =  58 batch =  525 / 977 loss =  -23.953502201762035
epoch =  58 batch =  550 / 977 loss =  -23.955438032340655
epoch =  58 batch =  575 / 977 loss =  -23.95476171705458
e

epoch =  61 batch =  600 / 977 loss =  -23.960983228762814
epoch =  61 batch =  625 / 977 loss =  -23.959587779669715
epoch =  61 batch =  650 / 977 loss =  -23.95907140986711
epoch =  61 batch =  675 / 977 loss =  -23.958907683220144
epoch =  61 batch =  700 / 977 loss =  -23.95963392652221
epoch =  61 batch =  725 / 977 loss =  -23.959577429064712
epoch =  61 batch =  750 / 977 loss =  -23.959110567318938
epoch =  61 batch =  775 / 977 loss =  -23.959640195689254
epoch =  61 batch =  800 / 977 loss =  -23.95912418294043
epoch =  61 batch =  825 / 977 loss =  -23.959006935285892
epoch =  61 batch =  850 / 977 loss =  -23.959495082725233
epoch =  61 batch =  875 / 977 loss =  -23.960668311271498
epoch =  61 batch =  900 / 977 loss =  -23.96020048439966
epoch =  61 batch =  925 / 977 loss =  -23.959992470545597
epoch =  61 batch =  950 / 977 loss =  -23.95901678388177
epoch =  61 batch =  975 / 977 loss =  -23.959148170518098
Validation loss =  -23.941669464111328
Effective sample size 

Effective sample size =  0.114305
epoch =  65 batch =  0 / 977 loss =  -24.020709991455078
epoch =  65 batch =  25 / 977 loss =  -23.98424926170936
epoch =  65 batch =  50 / 977 loss =  -23.96741122825473
epoch =  65 batch =  75 / 977 loss =  -23.980699965828343
epoch =  65 batch =  100 / 977 loss =  -23.97799914898259
epoch =  65 batch =  125 / 977 loss =  -23.978321817186146
epoch =  65 batch =  150 / 977 loss =  -23.97787553900915
epoch =  65 batch =  175 / 977 loss =  -23.97401947324926
epoch =  65 batch =  200 / 977 loss =  -23.973451889569485
epoch =  65 batch =  225 / 977 loss =  -23.974624338403213
epoch =  65 batch =  250 / 977 loss =  -23.974052915535125
epoch =  65 batch =  275 / 977 loss =  -23.973753832388617
epoch =  65 batch =  300 / 977 loss =  -23.975605498912728
epoch =  65 batch =  325 / 977 loss =  -23.976816317786486
epoch =  65 batch =  350 / 977 loss =  -23.977262279586597
epoch =  65 batch =  375 / 977 loss =  -23.976914060876744
epoch =  65 batch =  400 / 977 l

epoch =  68 batch =  400 / 977 loss =  -23.960574575790435
epoch =  68 batch =  425 / 977 loss =  -23.963246618638006
epoch =  68 batch =  450 / 977 loss =  -23.96520708088335
epoch =  68 batch =  475 / 977 loss =  -23.963399278015643
epoch =  68 batch =  500 / 977 loss =  -23.962749458358665
epoch =  68 batch =  525 / 977 loss =  -23.961196152429615
epoch =  68 batch =  550 / 977 loss =  -23.95914105797852
epoch =  68 batch =  575 / 977 loss =  -23.959357629219692
epoch =  68 batch =  600 / 977 loss =  -23.958788730538824
epoch =  68 batch =  625 / 977 loss =  -23.96041530389755
epoch =  68 batch =  650 / 977 loss =  -23.96117448001046
epoch =  68 batch =  675 / 977 loss =  -23.96053780324359
epoch =  68 batch =  700 / 977 loss =  -23.96250809684458
epoch =  68 batch =  725 / 977 loss =  -23.96286496774551
epoch =  68 batch =  750 / 977 loss =  -23.962877339592932
epoch =  68 batch =  775 / 977 loss =  -23.962209866218945
epoch =  68 batch =  800 / 977 loss =  -23.96313534664005
epoch

epoch =  71 batch =  825 / 977 loss =  -23.9699816715342
epoch =  71 batch =  850 / 977 loss =  -23.97030603787593
epoch =  71 batch =  875 / 977 loss =  -23.970455537647954
epoch =  71 batch =  900 / 977 loss =  -23.96973223162279
epoch =  71 batch =  925 / 977 loss =  -23.968400870749544
epoch =  71 batch =  950 / 977 loss =  -23.96748944512175
epoch =  71 batch =  975 / 977 loss =  -23.967136654697473
Validation loss =  -23.97859764099121
Effective sample size =  0.0264105
epoch =  72 batch =  0 / 977 loss =  -23.808425903320312
epoch =  72 batch =  25 / 977 loss =  -23.984956594613884
epoch =  72 batch =  50 / 977 loss =  -23.990587982476928
epoch =  72 batch =  75 / 977 loss =  -23.980964710837917
epoch =  72 batch =  100 / 977 loss =  -23.974390709754264
epoch =  72 batch =  125 / 977 loss =  -23.965516423422194
epoch =  72 batch =  150 / 977 loss =  -23.960718533850667
epoch =  72 batch =  175 / 977 loss =  -23.960393125360664
epoch =  72 batch =  200 / 977 loss =  -23.960121667

epoch =  75 batch =  225 / 977 loss =  -23.97393843557982
epoch =  75 batch =  250 / 977 loss =  -23.975912132111205
epoch =  75 batch =  275 / 977 loss =  -23.972732232964557
epoch =  75 batch =  300 / 977 loss =  -23.97520905554889
epoch =  75 batch =  325 / 977 loss =  -23.972318222186324
epoch =  75 batch =  350 / 977 loss =  -23.971491099083195
epoch =  75 batch =  375 / 977 loss =  -23.968627503577707
epoch =  75 batch =  400 / 977 loss =  -23.969490094077862
epoch =  75 batch =  425 / 977 loss =  -23.971526226527264
epoch =  75 batch =  450 / 977 loss =  -23.971173292781728
epoch =  75 batch =  475 / 977 loss =  -23.969682845748782
epoch =  75 batch =  500 / 977 loss =  -23.969526218558983
epoch =  75 batch =  525 / 977 loss =  -23.970151440725587
epoch =  75 batch =  550 / 977 loss =  -23.971146270281185
epoch =  75 batch =  575 / 977 loss =  -23.970503336853458
epoch =  75 batch =  600 / 977 loss =  -23.970422078289737
epoch =  75 batch =  625 / 977 loss =  -23.97240712162787


epoch =  78 batch =  650 / 977 loss =  -23.974297396048943
epoch =  78 batch =  675 / 977 loss =  -23.97364268500421
epoch =  78 batch =  700 / 977 loss =  -23.974136981066184
epoch =  78 batch =  725 / 977 loss =  -23.97159017872873
epoch =  78 batch =  750 / 977 loss =  -23.971907842968786
epoch =  78 batch =  775 / 977 loss =  -23.971805867460557
epoch =  78 batch =  800 / 977 loss =  -23.9706955622793
epoch =  78 batch =  825 / 977 loss =  -23.970233300696247
epoch =  78 batch =  850 / 977 loss =  -23.970045629875614
epoch =  78 batch =  875 / 977 loss =  -23.971678518269137
epoch =  78 batch =  900 / 977 loss =  -23.972334198099635
epoch =  78 batch =  925 / 977 loss =  -23.972790036294143
epoch =  78 batch =  950 / 977 loss =  -23.9719291711079
epoch =  78 batch =  975 / 977 loss =  -23.971178430025674
Validation loss =  -23.938068389892578
Effective sample size =  0.116162
epoch =  79 batch =  0 / 977 loss =  -23.976123809814453
epoch =  79 batch =  25 / 977 loss =  -23.96963897

epoch =  82 batch =  50 / 977 loss =  -23.981649286606732
epoch =  82 batch =  75 / 977 loss =  -23.975516018114586
epoch =  82 batch =  100 / 977 loss =  -23.98830880268965
epoch =  82 batch =  125 / 977 loss =  -23.99387205214727
epoch =  82 batch =  150 / 977 loss =  -24.000494698025527
epoch =  82 batch =  175 / 977 loss =  -24.003285158764232
epoch =  82 batch =  200 / 977 loss =  -24.00060123234839
epoch =  82 batch =  225 / 977 loss =  -24.00308830126197
epoch =  82 batch =  250 / 977 loss =  -23.999089343614312
epoch =  82 batch =  275 / 977 loss =  -23.9976961578148
epoch =  82 batch =  300 / 977 loss =  -23.99485662846866
epoch =  82 batch =  325 / 977 loss =  -23.99487319314406
epoch =  82 batch =  350 / 977 loss =  -23.991567095460375
epoch =  82 batch =  375 / 977 loss =  -23.987817875882413
epoch =  82 batch =  400 / 977 loss =  -23.985100748533025
epoch =  82 batch =  425 / 977 loss =  -23.981547078056515
epoch =  82 batch =  450 / 977 loss =  -23.981620839324073
epoch =

epoch =  85 batch =  475 / 977 loss =  -23.97322819413258
epoch =  85 batch =  500 / 977 loss =  -23.972269385636693
epoch =  85 batch =  525 / 977 loss =  -23.97416325848366
epoch =  85 batch =  550 / 977 loss =  -23.97601241857734
epoch =  85 batch =  575 / 977 loss =  -23.977059470282672
epoch =  85 batch =  600 / 977 loss =  -23.977298723877766
epoch =  85 batch =  625 / 977 loss =  -23.976234369003734
epoch =  85 batch =  650 / 977 loss =  -23.975355611235678
epoch =  85 batch =  675 / 977 loss =  -23.975105319502784
epoch =  85 batch =  700 / 977 loss =  -23.974178518276236
epoch =  85 batch =  725 / 977 loss =  -23.974595608461655
epoch =  85 batch =  750 / 977 loss =  -23.975088812857265
epoch =  85 batch =  775 / 977 loss =  -23.975307626822556
epoch =  85 batch =  800 / 977 loss =  -23.97518530141994
epoch =  85 batch =  825 / 977 loss =  -23.974131868191545
epoch =  85 batch =  850 / 977 loss =  -23.974287380482977
epoch =  85 batch =  875 / 977 loss =  -23.97366647633243
ep