In [1]:
import math as m
import numpy as np
import random as r
import matplotlib.pyplot as plt
import sys

In [2]:
import torch
from torch import nn
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR

In [3]:
from nflows.flows.base import Flow
from nflows.distributions.uniform import BoxUniform
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedPiecewiseRationalQuadraticAutoregressiveTransform
from nflows.transforms.autoregressive import MaskedPiecewiseQuadraticAutoregressiveTransform
from nflows.transforms.permutations import ReversePermutation
from nflows.transforms.permutations import RandomPermutation
from nflows.transforms.splines.rational_quadratic import rational_quadratic_spline
from torch.utils.tensorboard import SummaryWriter

In [4]:
import subprocess
import time
import os
from copy import deepcopy
import math as m
import gc

## Tensorboard writer for loss logging

In [5]:
writer = SummaryWriter()

## GPU/CPU selection

In [6]:
device = torch.device("cuda:1")
#device = torch.device("cpu")

## Reweighting strategy - must be min, mean or max

In [7]:
reference_method = 'max'

## Hyperparameters

In [8]:
n_RQS_knots = 10   # Number of knots in RQS transform
n_made_layers = 1  # Number of hidden layers in every made network
n_made_units = 100 # Number of units in every layer of the made network
n_flow_layers = 6  # Number of layers in the flow

batch_size = 1024
n_epochs = 800
adam_lr = 0.001     # Learning rate for the ADAM optimizer (default: 0.001)

n_train = int(1e6)  # Number of training events
n_test = int(1e5)   # Number of testing events
n_sample = int(1e6) # Number of samples for ess evaluation

## Load the train data and reweight

In [9]:
# Load
train_samples = np.genfromtxt("data/weighted_samples.csv", delimiter=',')[:n_train]
train_weights = np.genfromtxt("data/weighted_weights.csv", delimiter=',')[:n_train]

# Get reference weight
reference_methods = {'min': np.amin, 'mean': np.mean, 'max': np.amax}
ref_weight = reference_methods[reference_method](train_weights)

# Reweighting and rejection sampling
train_weights = train_weights / ref_weight
p_rejection_sampling = np.random.rand(len(train_weights))
select = p_rejection_sampling < train_weights

train_samples = train_samples[select]
train_weights = train_weights[select]
train_weights[train_weights < 1.0] = 1.0

# Normalize weights
train_weights /= train_weights.mean()

# Convert to torch tensors
train_samples = torch.tensor(train_samples, dtype=torch.float32, device=device)
train_weights = torch.tensor(train_weights, dtype=torch.float32, device=device)

## Load the test data

In [10]:
test_samples = torch.tensor(np.genfromtxt("data/unweighted_samples.csv", delimiter=',')[:n_test], dtype=torch.float32, device=device)

## Set up the flow

In [11]:
event_dim = train_samples.shape[1]
base_dist = BoxUniform(torch.zeros(event_dim), torch.ones(event_dim))

transforms = []
for _ in range(n_flow_layers):
    transforms.append(RandomPermutation(features=event_dim))
    transforms.append(MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
        features=event_dim, 
        hidden_features=n_made_units,
        num_bins=n_RQS_knots,
        num_blocks=n_made_layers-1,
        tails="constrained",
        use_residual_blocks=False
    ))
transform = CompositeTransform(transforms)

flow = Flow(transform, base_dist).to(device)
optimizer = optim.Adam(flow.parameters(), lr=adam_lr)

scheduler = MultiStepLR(optimizer, milestones=[350, 425, 500, 575, 650, 725, 800], gamma=0.5)

## Training

In [12]:
data_size = train_samples.shape[0]
n_batches = m.ceil(data_size/batch_size)

data_size_validation = test_samples.shape[0]
n_batches_validate = m.ceil(data_size_validation/batch_size)

best_validation_loss = np.inf
best_ess = 0

for epoch in range(n_epochs):
    
    permutation = torch.randperm(data_size, device=device)    

    # Loop over batches
    cum_loss = 0
    for batch in range(n_batches):
        # Set up the batch
        batch_begin = batch*batch_size
        batch_end   = min( (batch+1)*batch_size, data_size-1 )
        indices = permutation[batch_begin:batch_end]
        samples_batch = train_samples[indices]
        weights_batch = train_weights[indices]
        
        # Take a step
        optimizer.zero_grad()
        loss = -(flow.log_prob(inputs=samples_batch)*weights_batch).mean()
        loss.backward()
        optimizer.step()

        # Compute cumulative loss
        cum_loss = (cum_loss*batch + loss.item())/(batch+1)

        if batch%25 == 0:
            print("epoch = ", epoch, "batch = ", batch, "/", n_batches, "loss = ", cum_loss)
    
    writer.add_scalar("Loss_train", cum_loss, epoch)
    scheduler.step()
    

    # ---------- Compute validation loss -----------
    validation_loss = 0
    for batch in range(n_batches_validate):
        batch_begin = batch*batch_size
        batch_end = min( (batch+1)*batch_size, data_size_validation-1 )
        samples_batch = test_samples[batch_begin:batch_end]
    
        with torch.no_grad():
            validation_loss = (validation_loss*batch - (flow.log_prob(samples_batch)).mean())/(batch+1)

    print("Validation loss = ", validation_loss.item())
    writer.add_scalar("Loss_test", validation_loss.item(), epoch)

    if validation_loss < best_validation_loss:
        torch.save(flow, "flow_model_weighted_{}_best_validation.pt".format(reference_method))
        best_validation_loss = validation_loss

    
    # ---------- Compute effective sample size ----------
    # generate samples and evaluate llhs
    with torch.no_grad():
        samples = flow.sample(n_sample)
        llhs = flow.log_prob(samples)

    # Store files
    np.savetxt("/tmp/samples_weighted_file.csv", samples.cpu().numpy(), delimiter=',')
    np.savetxt("/tmp/llhs_weighted_file.csv", np.exp(llhs.cpu().numpy()), delimiter=',')

    # Run the evaluator
    cmd = os.path.abspath(os.getcwd())+'/ME_VEGAS/compute_metrics_from_likelihoods /tmp/samples_weighted_file.csv /tmp/llhs_weighted_file.csv'
    b = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE).stdout
    lines = b.decode('ascii').split("\n")

    ess = float(lines[2].split(' ')[-1])
    
    print("Effective sample size = ", ess)
    writer.add_scalar("Effective_sample_size", ess, epoch)

    if ess > best_ess:
        torch.save(flow, "flow_model_weighted_{}_best_ess.pt".format(reference_method))
        best_ess = ess
        
torch.save(flow, "flow_model_weighted_{}_final.pt".format(reference_method))

epoch =  0 batch =  0 / 63 loss =  0.5265049934387207
epoch =  0 batch =  25 / 63 loss =  -4.674710771212211
epoch =  0 batch =  50 / 63 loss =  -7.249040857249615
Validation loss =  -15.671630859375
Effective sample size =  0.000161148
epoch =  1 batch =  0 / 63 loss =  -15.70964241027832
epoch =  1 batch =  25 / 63 loss =  -18.273696239177994
epoch =  1 batch =  50 / 63 loss =  -19.774563546274212
Validation loss =  -22.197216033935547
Effective sample size =  0.038571
epoch =  2 batch =  0 / 63 loss =  -22.309574127197266
epoch =  2 batch =  25 / 63 loss =  -22.544276604285606
epoch =  2 batch =  50 / 63 loss =  -22.68653211406633
Validation loss =  -22.896236419677734
Effective sample size =  0.127785
epoch =  3 batch =  0 / 63 loss =  -23.150585174560547
epoch =  3 batch =  25 / 63 loss =  -23.017963849581204
epoch =  3 batch =  50 / 63 loss =  -22.95378928091012
Validation loss =  -23.099830627441406
Effective sample size =  0.199399
epoch =  4 batch =  0 / 63 loss =  -23.2873573

epoch =  34 batch =  25 / 63 loss =  -23.854330869821403
epoch =  34 batch =  50 / 63 loss =  -23.827235764148192
Validation loss =  -23.669841766357422
Effective sample size =  0.491067
epoch =  35 batch =  0 / 63 loss =  -23.716535568237305
epoch =  35 batch =  25 / 63 loss =  -23.819158920874962
epoch =  35 batch =  50 / 63 loss =  -23.807133356730144
Validation loss =  -23.711997985839844
Effective sample size =  0.528474
epoch =  36 batch =  0 / 63 loss =  -23.878673553466797
epoch =  36 batch =  25 / 63 loss =  -23.8279794546274
epoch =  36 batch =  50 / 63 loss =  -23.786989810420017
Validation loss =  -23.606908798217773
Effective sample size =  0.414813
epoch =  37 batch =  0 / 63 loss =  -23.594127655029297
epoch =  37 batch =  25 / 63 loss =  -23.8110199708205
epoch =  37 batch =  50 / 63 loss =  -23.829611983953736
Validation loss =  -23.68359375
Effective sample size =  0.494058
epoch =  38 batch =  0 / 63 loss =  -23.930530548095703
epoch =  38 batch =  25 / 63 loss =  -2

epoch =  68 batch =  25 / 63 loss =  -23.893814233633186
epoch =  68 batch =  50 / 63 loss =  -23.890488680671243
Validation loss =  -23.7078857421875
Effective sample size =  0.511446
epoch =  69 batch =  0 / 63 loss =  -23.851409912109375
epoch =  69 batch =  25 / 63 loss =  -23.899106612572304
epoch =  69 batch =  50 / 63 loss =  -23.889161203421793
Validation loss =  -23.632160186767578
Effective sample size =  0.43597
epoch =  70 batch =  0 / 63 loss =  -23.83512306213379
epoch =  70 batch =  25 / 63 loss =  -23.852391169621395
epoch =  70 batch =  50 / 63 loss =  -23.8961139753753
Validation loss =  -23.621028900146484
Effective sample size =  0.437817
epoch =  71 batch =  0 / 63 loss =  -23.885143280029297
epoch =  71 batch =  25 / 63 loss =  -23.791781205397385
epoch =  71 batch =  50 / 63 loss =  -23.8532325146245
Validation loss =  -23.669994354248047
Effective sample size =  0.468255
epoch =  72 batch =  0 / 63 loss =  -23.839658737182617
epoch =  72 batch =  25 / 63 loss = 

epoch =  102 batch =  25 / 63 loss =  -23.909619698157677
epoch =  102 batch =  50 / 63 loss =  -23.92347062802782
Validation loss =  -23.695384979248047
Effective sample size =  0.484409
epoch =  103 batch =  0 / 63 loss =  -23.71314239501953
epoch =  103 batch =  25 / 63 loss =  -23.94952818063589
epoch =  103 batch =  50 / 63 loss =  -23.973639544318704
Validation loss =  -23.69082260131836
Effective sample size =  0.468394
epoch =  104 batch =  0 / 63 loss =  -23.832443237304688
epoch =  104 batch =  25 / 63 loss =  -23.891023855942947
epoch =  104 batch =  50 / 63 loss =  -23.907737170948703
Validation loss =  -23.704669952392578
Effective sample size =  0.483617
epoch =  105 batch =  0 / 63 loss =  -23.930498123168945
epoch =  105 batch =  25 / 63 loss =  -23.933949910677395
epoch =  105 batch =  50 / 63 loss =  -23.952912461523916
Validation loss =  -23.647663116455078
Effective sample size =  0.439676
epoch =  106 batch =  0 / 63 loss =  -23.827497482299805
epoch =  106 batch =

Effective sample size =  0.453826
epoch =  136 batch =  0 / 63 loss =  -24.155799865722656
epoch =  136 batch =  25 / 63 loss =  -24.010983833899864
epoch =  136 batch =  50 / 63 loss =  -23.983442418715534
Validation loss =  -23.56993293762207
Effective sample size =  0.394048
epoch =  137 batch =  0 / 63 loss =  -23.92264175415039
epoch =  137 batch =  25 / 63 loss =  -23.9925898038424
epoch =  137 batch =  50 / 63 loss =  -23.9881399566052
Validation loss =  -23.57893943786621
Effective sample size =  0.36668
epoch =  138 batch =  0 / 63 loss =  -23.878185272216797
epoch =  138 batch =  25 / 63 loss =  -23.949245893038235
epoch =  138 batch =  50 / 63 loss =  -23.978073868097045
Validation loss =  -23.671926498413086
Effective sample size =  0.461837
epoch =  139 batch =  0 / 63 loss =  -24.088685989379883
epoch =  139 batch =  25 / 63 loss =  -23.98609447479248
epoch =  139 batch =  50 / 63 loss =  -23.974415349025353
Validation loss =  -23.703815460205078
Effective sample size =  

epoch =  169 batch =  50 / 63 loss =  -24.01985845378801
Validation loss =  -23.710460662841797
Effective sample size =  0.482686
epoch =  170 batch =  0 / 63 loss =  -23.838436126708984
epoch =  170 batch =  25 / 63 loss =  -24.01750241793119
epoch =  170 batch =  50 / 63 loss =  -24.010229185515758
Validation loss =  -23.6807861328125
Effective sample size =  0.439981
epoch =  171 batch =  0 / 63 loss =  -23.991012573242188
epoch =  171 batch =  25 / 63 loss =  -24.010998505812424
epoch =  171 batch =  50 / 63 loss =  -24.005705814735563
Validation loss =  -23.535512924194336
Effective sample size =  0.365478
epoch =  172 batch =  0 / 63 loss =  -23.82413101196289
epoch =  172 batch =  25 / 63 loss =  -23.944464390094463
epoch =  172 batch =  50 / 63 loss =  -23.986298168406766
Validation loss =  -23.69319725036621
Effective sample size =  0.465088
epoch =  173 batch =  0 / 63 loss =  -24.142108917236328
epoch =  173 batch =  25 / 63 loss =  -24.02722505422739
epoch =  173 batch =  5

epoch =  203 batch =  25 / 63 loss =  -24.054405725919285
epoch =  203 batch =  50 / 63 loss =  -24.0297150331385
Validation loss =  -23.67188262939453
Effective sample size =  0.448519
epoch =  204 batch =  0 / 63 loss =  -24.108108520507812
epoch =  204 batch =  25 / 63 loss =  -24.071220764747032
epoch =  204 batch =  50 / 63 loss =  -24.06394760281432
Validation loss =  -23.68610382080078
Effective sample size =  0.453438
epoch =  205 batch =  0 / 63 loss =  -24.06205940246582
epoch =  205 batch =  25 / 63 loss =  -24.02853797032283
epoch =  205 batch =  50 / 63 loss =  -24.042429082533893
Validation loss =  -23.61603355407715
Effective sample size =  0.416492
epoch =  206 batch =  0 / 63 loss =  -24.059852600097656
epoch =  206 batch =  25 / 63 loss =  -24.055791781498836
epoch =  206 batch =  50 / 63 loss =  -24.064855014576633
Validation loss =  -23.733768463134766
Effective sample size =  0.474862
epoch =  207 batch =  0 / 63 loss =  -24.188644409179688
epoch =  207 batch =  25

Effective sample size =  0.449774
epoch =  237 batch =  0 / 63 loss =  -23.89727020263672
epoch =  237 batch =  25 / 63 loss =  -24.038950553307167
epoch =  237 batch =  50 / 63 loss =  -24.066268846100453
Validation loss =  -23.668548583984375
Effective sample size =  0.419494
epoch =  238 batch =  0 / 63 loss =  -23.809410095214844
epoch =  238 batch =  25 / 63 loss =  -24.048126220703125
epoch =  238 batch =  50 / 63 loss =  -24.044724296121036
Validation loss =  -23.68501091003418
Effective sample size =  0.170192
epoch =  239 batch =  0 / 63 loss =  -24.018024444580078
epoch =  239 batch =  25 / 63 loss =  -24.07668685913086
epoch =  239 batch =  50 / 63 loss =  -24.083032196643305
Validation loss =  -23.71567726135254
Effective sample size =  0.484911
epoch =  240 batch =  0 / 63 loss =  -23.877485275268555
epoch =  240 batch =  25 / 63 loss =  -24.059087093059834
epoch =  240 batch =  50 / 63 loss =  -24.08399346295525
Validation loss =  -23.662118911743164
Effective sample size

epoch =  270 batch =  50 / 63 loss =  -24.111820333144244
Validation loss =  -23.679874420166016
Effective sample size =  0.432702
epoch =  271 batch =  0 / 63 loss =  -24.18185806274414
epoch =  271 batch =  25 / 63 loss =  -24.08792018890381
epoch =  271 batch =  50 / 63 loss =  -24.10169085334329
Validation loss =  -23.628751754760742
Effective sample size =  0.391546
epoch =  272 batch =  0 / 63 loss =  -24.176959991455078
epoch =  272 batch =  25 / 63 loss =  -24.12147338573749
epoch =  272 batch =  50 / 63 loss =  -24.102398068297145
Validation loss =  -23.560823440551758
Effective sample size =  0.37011
epoch =  273 batch =  0 / 63 loss =  -24.16177749633789
epoch =  273 batch =  25 / 63 loss =  -24.059842109680176
epoch =  273 batch =  50 / 63 loss =  -24.059159970750997
Validation loss =  -23.627010345458984
Effective sample size =  0.405079
epoch =  274 batch =  0 / 63 loss =  -24.083213806152344
epoch =  274 batch =  25 / 63 loss =  -24.0931029686561
epoch =  274 batch =  50

epoch =  304 batch =  25 / 63 loss =  -24.103990921607384
epoch =  304 batch =  50 / 63 loss =  -24.10536605236577
Validation loss =  -23.633352279663086
Effective sample size =  0.364716
epoch =  305 batch =  0 / 63 loss =  -24.11505126953125
epoch =  305 batch =  25 / 63 loss =  -24.095263261061447
epoch =  305 batch =  50 / 63 loss =  -24.1188975315468
Validation loss =  -23.693689346313477
Effective sample size =  0.418236
epoch =  306 batch =  0 / 63 loss =  -24.387365341186523
epoch =  306 batch =  25 / 63 loss =  -24.155023941626915
epoch =  306 batch =  50 / 63 loss =  -24.151178546980315
Validation loss =  -23.659387588500977
Effective sample size =  0.423244
epoch =  307 batch =  0 / 63 loss =  -24.109485626220703
epoch =  307 batch =  25 / 63 loss =  -24.126172359173115
epoch =  307 batch =  50 / 63 loss =  -24.129290375055053
Validation loss =  -23.64651870727539
Effective sample size =  0.369551
epoch =  308 batch =  0 / 63 loss =  -24.129558563232422
epoch =  308 batch = 

Effective sample size =  0.446182
epoch =  338 batch =  0 / 63 loss =  -24.066242218017578
epoch =  338 batch =  25 / 63 loss =  -24.134327081533577
epoch =  338 batch =  50 / 63 loss =  -24.15033131019742
Validation loss =  -23.675052642822266
Effective sample size =  0.403751
epoch =  339 batch =  0 / 63 loss =  -23.99295425415039
epoch =  339 batch =  25 / 63 loss =  -24.165105966421272
epoch =  339 batch =  50 / 63 loss =  -24.148697983984853
Validation loss =  -23.60584259033203
Effective sample size =  0.377899
epoch =  340 batch =  0 / 63 loss =  -24.042041778564453
epoch =  340 batch =  25 / 63 loss =  -24.123057512136604
epoch =  340 batch =  50 / 63 loss =  -24.097656474393958
Validation loss =  -23.627920150756836
Effective sample size =  0.376889
epoch =  341 batch =  0 / 63 loss =  -24.143680572509766
epoch =  341 batch =  25 / 63 loss =  -24.1503507907574
epoch =  341 batch =  50 / 63 loss =  -24.151518877814798
Validation loss =  -23.650724411010742
Effective sample size

epoch =  371 batch =  50 / 63 loss =  -24.19512337329341
Validation loss =  -23.649368286132812
Effective sample size =  0.358562
epoch =  372 batch =  0 / 63 loss =  -24.206571578979492
epoch =  372 batch =  25 / 63 loss =  -24.19620550595797
epoch =  372 batch =  50 / 63 loss =  -24.20256341672411
Validation loss =  -23.670183181762695
Effective sample size =  0.419366
epoch =  373 batch =  0 / 63 loss =  -24.138446807861328
epoch =  373 batch =  25 / 63 loss =  -24.202644641582783
epoch =  373 batch =  50 / 63 loss =  -24.205748240152996
Validation loss =  -23.667781829833984
Effective sample size =  0.352033
epoch =  374 batch =  0 / 63 loss =  -24.341949462890625
epoch =  374 batch =  25 / 63 loss =  -24.222536160395695
epoch =  374 batch =  50 / 63 loss =  -24.209943023382447
Validation loss =  -23.670352935791016
Effective sample size =  0.426148
epoch =  375 batch =  0 / 63 loss =  -24.22637176513672
epoch =  375 batch =  25 / 63 loss =  -24.170695011432354
epoch =  375 batch =

epoch =  405 batch =  25 / 63 loss =  -24.240329155555138
epoch =  405 batch =  50 / 63 loss =  -24.208005718156404
Validation loss =  -23.623146057128906
Effective sample size =  0.339985
epoch =  406 batch =  0 / 63 loss =  -24.288164138793945
epoch =  406 batch =  25 / 63 loss =  -24.230159686161922
epoch =  406 batch =  50 / 63 loss =  -24.212947060080136
Validation loss =  -23.650720596313477
Effective sample size =  0.415254
epoch =  407 batch =  0 / 63 loss =  -24.33089256286621
epoch =  407 batch =  25 / 63 loss =  -24.21536269554725
epoch =  407 batch =  50 / 63 loss =  -24.21617780947218
Validation loss =  -23.630617141723633
Effective sample size =  0.398568
epoch =  408 batch =  0 / 63 loss =  -24.122148513793945
epoch =  408 batch =  25 / 63 loss =  -24.207088397099422
epoch =  408 batch =  50 / 63 loss =  -24.217079835779526
Validation loss =  -23.646596908569336
Effective sample size =  0.388403
epoch =  409 batch =  0 / 63 loss =  -24.37552833557129
epoch =  409 batch =

Effective sample size =  0.397823
epoch =  439 batch =  0 / 63 loss =  -24.144718170166016
epoch =  439 batch =  25 / 63 loss =  -24.29296508202186
epoch =  439 batch =  50 / 63 loss =  -24.24600399241728
Validation loss =  -23.654903411865234
Effective sample size =  0.419494
epoch =  440 batch =  0 / 63 loss =  -24.424150466918945
epoch =  440 batch =  25 / 63 loss =  -24.258911572969875
epoch =  440 batch =  50 / 63 loss =  -24.263265647140205
Validation loss =  -23.65079116821289
Effective sample size =  0.373664
epoch =  441 batch =  0 / 63 loss =  -24.221508026123047
epoch =  441 batch =  25 / 63 loss =  -24.25979108076829
epoch =  441 batch =  50 / 63 loss =  -24.254718817916572
Validation loss =  -23.63059425354004
Effective sample size =  0.352122
epoch =  442 batch =  0 / 63 loss =  -24.027549743652344
epoch =  442 batch =  25 / 63 loss =  -24.245762825012207
epoch =  442 batch =  50 / 63 loss =  -24.267671285891065
Validation loss =  -23.65104866027832
Effective sample size 

epoch =  472 batch =  50 / 63 loss =  -24.273486343084596
Validation loss =  -23.645673751831055
Effective sample size =  0.346752
epoch =  473 batch =  0 / 63 loss =  -24.491117477416992
epoch =  473 batch =  25 / 63 loss =  -24.284926634568436
epoch =  473 batch =  50 / 63 loss =  -24.259621638877718
Validation loss =  -23.641695022583008
Effective sample size =  0.348953
epoch =  474 batch =  0 / 63 loss =  -24.272674560546875
epoch =  474 batch =  25 / 63 loss =  -24.267077886141262
epoch =  474 batch =  50 / 63 loss =  -24.262699538586187
Validation loss =  -23.64421272277832
Effective sample size =  0.405358
epoch =  475 batch =  0 / 63 loss =  -24.053937911987305
epoch =  475 batch =  25 / 63 loss =  -24.268491451556862
epoch =  475 batch =  50 / 63 loss =  -24.265555138681446
Validation loss =  -23.647245407104492
Effective sample size =  0.220466
epoch =  476 batch =  0 / 63 loss =  -24.177108764648438
epoch =  476 batch =  25 / 63 loss =  -24.27993532327505
epoch =  476 batch

epoch =  506 batch =  25 / 63 loss =  -24.293442946213943
epoch =  506 batch =  50 / 63 loss =  -24.288813871495865
Validation loss =  -23.64175033569336
Effective sample size =  0.419453
epoch =  507 batch =  0 / 63 loss =  -24.59733009338379
epoch =  507 batch =  25 / 63 loss =  -24.314230698805588
epoch =  507 batch =  50 / 63 loss =  -24.294240839341107
Validation loss =  -23.643583297729492
Effective sample size =  0.400132
epoch =  508 batch =  0 / 63 loss =  -24.253498077392578
epoch =  508 batch =  25 / 63 loss =  -24.293355208176834
epoch =  508 batch =  50 / 63 loss =  -24.28297742207845
Validation loss =  -23.641239166259766
Effective sample size =  0.17887
epoch =  509 batch =  0 / 63 loss =  -24.452707290649414
epoch =  509 batch =  25 / 63 loss =  -24.28084813631498
epoch =  509 batch =  50 / 63 loss =  -24.290191650390625
Validation loss =  -23.63865852355957
Effective sample size =  0.392679
epoch =  510 batch =  0 / 63 loss =  -24.2437686920166
epoch =  510 batch =  25

Effective sample size =  0.348584
epoch =  540 batch =  0 / 63 loss =  -24.368974685668945
epoch =  540 batch =  25 / 63 loss =  -24.28566551208496
epoch =  540 batch =  50 / 63 loss =  -24.284935782937442
Validation loss =  -23.638057708740234
Effective sample size =  0.38629
epoch =  541 batch =  0 / 63 loss =  -24.27193260192871
epoch =  541 batch =  25 / 63 loss =  -24.297491293687088
epoch =  541 batch =  50 / 63 loss =  -24.286824693866805
Validation loss =  -23.63532257080078
Effective sample size =  0.106747
epoch =  542 batch =  0 / 63 loss =  -24.443496704101562
epoch =  542 batch =  25 / 63 loss =  -24.260041750394375
epoch =  542 batch =  50 / 63 loss =  -24.278472414203716
Validation loss =  -23.627534866333008
Effective sample size =  0.367269
epoch =  543 batch =  0 / 63 loss =  -24.453277587890625
epoch =  543 batch =  25 / 63 loss =  -24.302634972792404
epoch =  543 batch =  50 / 63 loss =  -24.28652026606541
Validation loss =  -23.639007568359375
Effective sample size

epoch =  573 batch =  50 / 63 loss =  -24.283111198275698
Validation loss =  -23.633275985717773
Effective sample size =  0.188818
epoch =  574 batch =  0 / 63 loss =  -24.457279205322266
epoch =  574 batch =  25 / 63 loss =  -24.279163360595703
epoch =  574 batch =  50 / 63 loss =  -24.278379440307617
Validation loss =  -23.63435173034668
Effective sample size =  0.263002
epoch =  575 batch =  0 / 63 loss =  -24.135116577148438
epoch =  575 batch =  25 / 63 loss =  -24.28250327477088
epoch =  575 batch =  50 / 63 loss =  -24.294652153463925
Validation loss =  -23.630430221557617
Effective sample size =  0.386338
epoch =  576 batch =  0 / 63 loss =  -24.423891067504883
epoch =  576 batch =  25 / 63 loss =  -24.31002000661997
epoch =  576 batch =  50 / 63 loss =  -24.30226666319604
Validation loss =  -23.63564682006836
Effective sample size =  0.364956
epoch =  577 batch =  0 / 63 loss =  -24.16214370727539
epoch =  577 batch =  25 / 63 loss =  -24.28742680182824
epoch =  577 batch =  5

epoch =  607 batch =  25 / 63 loss =  -24.294973300053524
epoch =  607 batch =  50 / 63 loss =  -24.300674887264478
Validation loss =  -23.633790969848633
Effective sample size =  0.288688
epoch =  608 batch =  0 / 63 loss =  -24.378671646118164
epoch =  608 batch =  25 / 63 loss =  -24.30705136519212
epoch =  608 batch =  50 / 63 loss =  -24.298097722670615
Validation loss =  -23.633838653564453
Effective sample size =  0.371888
epoch =  609 batch =  0 / 63 loss =  -24.33260154724121
epoch =  609 batch =  25 / 63 loss =  -24.30764990586501
epoch =  609 batch =  50 / 63 loss =  -24.295659457936
Validation loss =  -23.627099990844727
Effective sample size =  0.406679
epoch =  610 batch =  0 / 63 loss =  -24.122455596923828
epoch =  610 batch =  25 / 63 loss =  -24.30051950307993
epoch =  610 batch =  50 / 63 loss =  -24.286463344798367
Validation loss =  -23.625022888183594
Effective sample size =  0.347441
epoch =  611 batch =  0 / 63 loss =  -24.130496978759766
epoch =  611 batch =  2

Effective sample size =  0.338936
epoch =  641 batch =  0 / 63 loss =  -24.385730743408203
epoch =  641 batch =  25 / 63 loss =  -24.299401943500225
epoch =  641 batch =  50 / 63 loss =  -24.30034345739028
Validation loss =  -23.626163482666016
Effective sample size =  0.346454
epoch =  642 batch =  0 / 63 loss =  -24.41001319885254
epoch =  642 batch =  25 / 63 loss =  -24.323838380666878
epoch =  642 batch =  50 / 63 loss =  -24.31231846528894
Validation loss =  -23.62674903869629
Effective sample size =  0.378584
epoch =  643 batch =  0 / 63 loss =  -24.277095794677734
epoch =  643 batch =  25 / 63 loss =  -24.302058366628792
epoch =  643 batch =  50 / 63 loss =  -24.309597800759708
Validation loss =  -23.62357521057129
Effective sample size =  0.362046
epoch =  644 batch =  0 / 63 loss =  -24.191925048828125
epoch =  644 batch =  25 / 63 loss =  -24.27369389167199
epoch =  644 batch =  50 / 63 loss =  -24.29667080149931
Validation loss =  -23.630239486694336
Effective sample size =

epoch =  674 batch =  50 / 63 loss =  -24.305427962658452
Validation loss =  -23.631126403808594
Effective sample size =  0.044512
epoch =  675 batch =  0 / 63 loss =  -24.251554489135742
epoch =  675 batch =  25 / 63 loss =  -24.327580525324894
epoch =  675 batch =  50 / 63 loss =  -24.318885391833735
Validation loss =  -23.629728317260742
Effective sample size =  0.341992
epoch =  676 batch =  0 / 63 loss =  -24.262821197509766
epoch =  676 batch =  25 / 63 loss =  -24.291501778822678
epoch =  676 batch =  50 / 63 loss =  -24.315180423212986
Validation loss =  -23.629947662353516
Effective sample size =  0.3714
epoch =  677 batch =  0 / 63 loss =  -24.15227508544922
epoch =  677 batch =  25 / 63 loss =  -24.27300159747784
epoch =  677 batch =  50 / 63 loss =  -24.313009411680927
Validation loss =  -23.62763023376465
Effective sample size =  0.361385
epoch =  678 batch =  0 / 63 loss =  -24.133901596069336
epoch =  678 batch =  25 / 63 loss =  -24.308438081007736
epoch =  678 batch = 

epoch =  708 batch =  25 / 63 loss =  -24.28345826955942
epoch =  708 batch =  50 / 63 loss =  -24.311298594755286
Validation loss =  -23.628353118896484
Effective sample size =  0.390235
epoch =  709 batch =  0 / 63 loss =  -24.337318420410156
epoch =  709 batch =  25 / 63 loss =  -24.32306172297551
epoch =  709 batch =  50 / 63 loss =  -24.308885275148867
Validation loss =  -23.628915786743164
Effective sample size =  0.292245
epoch =  710 batch =  0 / 63 loss =  -24.240440368652344
epoch =  710 batch =  25 / 63 loss =  -24.329721230726975
epoch =  710 batch =  50 / 63 loss =  -24.308111564785822
Validation loss =  -23.628189086914062
Effective sample size =  0.384989
epoch =  711 batch =  0 / 63 loss =  -24.41765594482422
epoch =  711 batch =  25 / 63 loss =  -24.291439056396484
epoch =  711 batch =  50 / 63 loss =  -24.31165137945437
Validation loss =  -23.628910064697266
Effective sample size =  0.348716
epoch =  712 batch =  0 / 63 loss =  -24.43328857421875
epoch =  712 batch = 

Effective sample size =  0.333876
epoch =  742 batch =  0 / 63 loss =  -24.25192642211914
epoch =  742 batch =  25 / 63 loss =  -24.30359180157001
epoch =  742 batch =  50 / 63 loss =  -24.320946450326957
Validation loss =  -23.626497268676758
Effective sample size =  0.370091
epoch =  743 batch =  0 / 63 loss =  -24.26023292541504
epoch =  743 batch =  25 / 63 loss =  -24.31954112419715
epoch =  743 batch =  50 / 63 loss =  -24.3338114046583
Validation loss =  -23.628210067749023
Effective sample size =  0.363435
epoch =  744 batch =  0 / 63 loss =  -24.287227630615234
epoch =  744 batch =  25 / 63 loss =  -24.31111607184777
epoch =  744 batch =  50 / 63 loss =  -24.314895816877776
Validation loss =  -23.62639617919922
Effective sample size =  0.380502
epoch =  745 batch =  0 / 63 loss =  -24.29705047607422
epoch =  745 batch =  25 / 63 loss =  -24.317374742948093
epoch =  745 batch =  50 / 63 loss =  -24.30817147797229
Validation loss =  -23.627363204956055
Effective sample size =  0

epoch =  775 batch =  50 / 63 loss =  -24.323148764815986
Validation loss =  -23.62754249572754
Effective sample size =  0.321149
epoch =  776 batch =  0 / 63 loss =  -24.342388153076172
epoch =  776 batch =  25 / 63 loss =  -24.311338424682617
epoch =  776 batch =  50 / 63 loss =  -24.305698319977406
Validation loss =  -23.62751007080078
Effective sample size =  0.327603
epoch =  777 batch =  0 / 63 loss =  -24.44522476196289
epoch =  777 batch =  25 / 63 loss =  -24.330693391653213
epoch =  777 batch =  50 / 63 loss =  -24.315040850171854
Validation loss =  -23.62718963623047
Effective sample size =  0.362752
epoch =  778 batch =  0 / 63 loss =  -24.00171661376953
epoch =  778 batch =  25 / 63 loss =  -24.331462933466984
epoch =  778 batch =  50 / 63 loss =  -24.308538362091664
Validation loss =  -23.627901077270508
Effective sample size =  0.393212
epoch =  779 batch =  0 / 63 loss =  -24.375083923339844
epoch =  779 batch =  25 / 63 loss =  -24.35012986109807
epoch =  779 batch =  