In [1]:
import math as m
import numpy as np
import random as r
import matplotlib.pyplot as plt
import sys

In [2]:
import torch
from torch import nn
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR

In [3]:
from nflows.flows.base import Flow
from nflows.distributions.uniform import BoxUniform
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedPiecewiseRationalQuadraticAutoregressiveTransform
from nflows.transforms.autoregressive import MaskedPiecewiseQuadraticAutoregressiveTransform
from nflows.transforms.permutations import ReversePermutation
from nflows.transforms.permutations import RandomPermutation
from nflows.transforms.splines.rational_quadratic import rational_quadratic_spline
from torch.utils.tensorboard import SummaryWriter

In [4]:
import subprocess
import time
import os
from copy import deepcopy
import math as m
import gc

## TensorBoard writer for loss logging

In [5]:
# TensorBoard writer; the training loop logs the per-epoch train and test losses to it.
# log_dir=None leaves the choice of run directory to SummaryWriter's default.
writer = SummaryWriter(log_dir=None)

## GPU/CPU selection

In [6]:
# Train on the first GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA instead of failing here.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

## Hyperparameters

In [7]:
# Seed torch's RNGs (CPU and CUDA) so that the random feature permutations, the
# network initialisation and the per-epoch batch shuffling are reproducible.
seed = 42
torch.manual_seed(seed)

# Flow architecture
n_RQS_knots = 10   # Resolution of each rational-quadratic spline (RQS); passed as `num_bins`
n_made_layers = 3  # Number of hidden layers in every MADE network
n_made_units = 500 # Number of units in every hidden layer of the MADE network
n_flow_layers = 12 # Number of (permutation + RQS) layers in the flow

# Optimisation
batch_size = 1024
n_epochs = 800
adam_lr = 0.001   # Learning rate for the Adam optimizer (PyTorch default: 0.001)

# Effective (unweighted-equivalent) dataset sizes. The number of events actually used
# is larger: both values are multiplied by the statistical factor c, computed from the
# fraction of negative weights, before the train/test split.
n_train = int(1e6)
n_test = int(1e5)

## Load the training data

In [8]:
# One row per event in the samples file; the weights file must hold one weight per event.
samples = np.genfromtxt("data/negative_weight_samples.csv", delimiter=',')
weights = np.genfromtxt("data/negative_weight_weights.csv", delimiter=',')

# Later cells read samples.shape[1] as the event dimension and multiply per-event
# log-probabilities by the weights. A 2-D (N, 1) weights array would silently broadcast
# that product to (N, N), so validate the shapes here.
assert samples.ndim == 2, "expected a 2-D samples array, got shape {}".format(samples.shape)
assert weights.ndim == 1, "expected a 1-D weights array, got shape {}".format(weights.shape)
assert samples.shape[0] == weights.shape[0], (
    "samples ({}) and weights ({}) have different numbers of events".format(samples.shape[0], weights.shape[0])
)

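## Find the fraction of negative-weight events and the required statistical factor

If a fraction $f$ of the events has weight $-1$ and the rest $+1$, the effective sample size of $N$ events is $(\sum w)^2/\sum w^2 = N(1-2f)^2$. Therefore $c = (1-2f)^{-2}$ times more events are needed to match the statistics of an unweighted sample.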

In [None]:
# Fraction of events carrying a negative weight.
f = np.mean(weights < 0)

# Written as `not f < 0.5` so that a NaN fraction (e.g. from an empty weights array) is also
# rejected. At f >= 0.5 the factor below is infinite or meaningless.
if not f < 0.5:
    raise ValueError("Negative-weight fraction must be below 0.5 to define the statistical factor, got {}".format(f))

# Once the weights are reduced to their sign (+-1), sum(w) = N(1-2f) and sum(w^2) = N.
# The effective sample size (sum w)^2 / sum(w^2) is therefore N(1-2f)^2. Scaling the
# event counts by c = (1-2f)^-2 restores the intended statistics.
c = (1-2*f)**-2
print("Fraction of negative events", f)
print("Statistical factor", c)

if (n_train + n_test)*c > samples.shape[0]:
    raise RuntimeError("Not enough training data")

Fraction of negative events 0.2394467380986031
Statistical factor 3.6825358174692755


## Normalise weights to their sign (±1)

In [None]:
# Replace every weight by its sign (+1, -1, or 0 for an exactly-zero weight), so that each
# event enters the weighted log-likelihood loss with unit magnitude.
# Rebinding `weights` is safe to re-run: sign(sign(w)) == sign(w).
weights = np.sign(weights)

## Split into training and test sets

In [None]:
# Scale the requested event counts by the statistical factor c, so that after the loss of
# effective statistics from negative weights the sets still correspond to n_train / n_test events.
n_train_with_stats = int(n_train*c)
n_test_with_stats = int(n_test*c)

# Contiguous, non-overlapping row ranges: the first rows for training, the next ones for testing.
train_rows = slice(0, n_train_with_stats)
test_rows = slice(n_train_with_stats, n_train_with_stats + n_test_with_stats)


def _to_device_tensor(array):
    """Copy a NumPy array into a float32 tensor on the selected device."""
    return torch.tensor(array, dtype=torch.float32, device=device)


train_samples = _to_device_tensor(samples[train_rows])
train_weights = _to_device_tensor(weights[train_rows])
test_samples = _to_device_tensor(samples[test_rows])
test_weights = _to_device_tensor(weights[test_rows])

assert train_samples.shape[0] == train_weights.shape[0] == n_train_with_stats
assert test_samples.shape[0] == test_weights.shape[0] == n_test_with_stats

# Free the host-side arrays now that the data lives on `device`.
# NOTE: re-running this cell requires re-running the data-loading cells first.
del samples, weights
gc.collect();

90

## Set up the flow

In [None]:
event_dim = train_samples.shape[1]

# Base distribution: uniform on the unit hypercube [0, 1]^event_dim. The bounds are allocated
# on `device` up front so they match the training tensors even if `Flow(...).to(device)` does
# not move them. NOTE(review): if this BoxUniform is a plain torch distribution rather than an
# nn.Module with buffers, `.to()` would leave its bounds on the CPU. Confirm for the nflows build in use.
base_dist = BoxUniform(torch.zeros(event_dim, device=device), torch.ones(event_dim, device=device))

# Each flow layer is a random permutation of the features followed by a masked
# autoregressive rational-quadratic spline (RQS) transform.
transforms = []
for _ in range(n_flow_layers):
    transforms.append(RandomPermutation(features=event_dim))
    transforms.append(MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
        features=event_dim,
        hidden_features=n_made_units,
        num_bins=n_RQS_knots,
        # NOTE(review): presumably MADE adds an input layer in front of its `num_blocks`
        # hidden blocks, giving n_made_layers hidden layers in total. Confirm.
        num_blocks=n_made_layers-1,
        tails="constrained",
        use_residual_blocks=False
    ))
transform = CompositeTransform(transforms)

flow = Flow(transform, base_dist).to(device)
optimizer = optim.Adam(flow.parameters(), lr=adam_lr)

# Halve the learning rate at each milestone epoch. scheduler.step() runs once per epoch, so with
# n_epochs = 800 the final milestone (800) only takes effect after training has finished.
scheduler = MultiStepLR(optimizer, milestones=[350, 425, 500, 575, 650, 725, 800], gamma=0.5)

## Training

In [None]:
data_size = train_samples.shape[0]
n_batches = m.ceil(data_size/batch_size)

data_size_validation = test_samples.shape[0]
n_batches_validate = m.ceil(data_size_validation/batch_size)


def compute_validation_loss():
    """Weighted negative log-likelihood of the flow, averaged per event over the test set.

    Sums -log p(x) * w over every test event and divides by the number of test events,
    so a short final batch is weighted correctly. The flow is switched to eval mode for
    the pass and back to train mode afterwards.

    Returns:
        float: the per-event validation loss, or NaN when the test set is empty.
    """
    if data_size_validation == 0:
        return float("nan")
    flow.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in range(n_batches_validate):
            batch_begin = batch*batch_size
            # Slice ends are exclusive: clamp to the size itself so the last event is included.
            batch_end = min((batch+1)*batch_size, data_size_validation)
            samples_batch = test_samples[batch_begin:batch_end]
            weights_batch = test_weights[batch_begin:batch_end]
            total_loss -= (flow.log_prob(samples_batch)*weights_batch).sum().item()
    flow.train()
    return total_loss/data_size_validation


best_loss = np.inf
for epoch in range(n_epochs):
    flow.train()
    permutation = torch.randperm(data_size, device=device)

    # Per-event running mean of the training loss over the batches seen so far this epoch
    cum_loss = 0.0
    n_seen = 0
    for batch in range(n_batches):
        # Set up the batch. Slice ends are exclusive, so clamp to data_size (not data_size-1).
        # Otherwise the last event is dropped every epoch, and an empty final batch
        # (NaN loss) is possible.
        batch_begin = batch*batch_size
        batch_end = min((batch+1)*batch_size, data_size)
        indices = permutation[batch_begin:batch_end]
        samples_batch = train_samples[indices]
        weights_batch = train_weights[indices]

        # Take a step on the weighted NLL. Events with w = -1 enter with the opposite sign,
        # so minimising the loss lowers the density at their location.
        optimizer.zero_grad()
        loss = -(flow.log_prob(inputs=samples_batch)*weights_batch).mean()
        loss.backward()
        optimizer.step()

        # Update the running mean, weighting each batch by its number of events
        n_batch = batch_end - batch_begin
        n_seen += n_batch
        cum_loss += (loss.item() - cum_loss)*n_batch/n_seen

        if batch%25 == 0:
            print("epoch = ", epoch, "batch = ", batch, "/", n_batches, "loss = ", cum_loss)

    scheduler.step()

    validation_loss = compute_validation_loss()
    print("Validation loss = ", validation_loss)

    writer.add_scalar("Loss_train", cum_loss, epoch)
    writer.add_scalar("Loss_test", validation_loss, epoch)

    # Keep the model with the lowest validation loss seen so far
    if validation_loss < best_loss:
        torch.save(flow, "best_flow_model.pt")
        best_loss = validation_loss

# Make sure every logged scalar reaches disk before the notebook moves on
writer.flush()
torch.save(flow, "final_flow_model.pt")

epoch =  0 batch =  0 / 3597 loss =  17.9576358795166
epoch =  0 batch =  25 / 3597 loss =  0.5931739480449599
epoch =  0 batch =  50 / 3597 loss =  -2.5938962186668437
epoch =  0 batch =  75 / 3597 loss =  -4.261602783085485
epoch =  0 batch =  100 / 3597 loss =  -5.302870060105135
epoch =  0 batch =  125 / 3597 loss =  -6.050047131876151
epoch =  0 batch =  150 / 3597 loss =  -6.653813500376726
epoch =  0 batch =  175 / 3597 loss =  -7.145457922819663
epoch =  0 batch =  200 / 3597 loss =  -7.599675016690843
epoch =  0 batch =  225 / 3597 loss =  -7.969954705409772
epoch =  0 batch =  250 / 3597 loss =  -8.289210155725007
epoch =  0 batch =  275 / 3597 loss =  -8.554372438108146
epoch =  0 batch =  300 / 3597 loss =  -8.79812773799778
epoch =  0 batch =  325 / 3597 loss =  -9.013570887590847
epoch =  0 batch =  350 / 3597 loss =  -9.203197745218917
epoch =  0 batch =  375 / 3597 loss =  -9.354217412346545
epoch =  0 batch =  400 / 3597 loss =  -9.499882419544862
epoch =  0 batch =  4

epoch =  0 batch =  3475 / 3597 loss =  -12.81066526694199
epoch =  0 batch =  3500 / 3597 loss =  -12.818650589678596
epoch =  0 batch =  3525 / 3597 loss =  -12.82636924849963
epoch =  0 batch =  3550 / 3597 loss =  -12.832992517962412
epoch =  0 batch =  3575 / 3597 loss =  -12.841424104855525
Validation loss =  -13.9771146774292
epoch =  1 batch =  0 / 3597 loss =  -14.246454238891602
epoch =  1 batch =  25 / 3597 loss =  -14.041108791644756
epoch =  1 batch =  50 / 3597 loss =  -14.156505079830394
epoch =  1 batch =  75 / 3597 loss =  -14.05435713968779
epoch =  1 batch =  100 / 3597 loss =  -14.039949709826177
epoch =  1 batch =  125 / 3597 loss =  -14.093002849155003
epoch =  1 batch =  150 / 3597 loss =  -14.119642983998684
epoch =  1 batch =  175 / 3597 loss =  -14.119377212090926
epoch =  1 batch =  200 / 3597 loss =  -14.094803079443784
epoch =  1 batch =  225 / 3597 loss =  -14.081807191393017
epoch =  1 batch =  250 / 3597 loss =  -14.0563746676502
epoch =  1 batch =  275 

epoch =  1 batch =  3325 / 3597 loss =  -14.337573185930465
epoch =  1 batch =  3350 / 3597 loss =  -14.339742073120554
epoch =  1 batch =  3375 / 3597 loss =  -14.340022903765547
epoch =  1 batch =  3400 / 3597 loss =  -14.341916379561251
epoch =  1 batch =  3425 / 3597 loss =  -14.344772187994018
epoch =  1 batch =  3450 / 3597 loss =  -14.346889278640681
epoch =  1 batch =  3475 / 3597 loss =  -14.347250139480082
epoch =  1 batch =  3500 / 3597 loss =  -14.348909794007394
epoch =  1 batch =  3525 / 3597 loss =  -14.352779027892321
epoch =  1 batch =  3550 / 3597 loss =  -14.355382529355143
epoch =  1 batch =  3575 / 3597 loss =  -14.357113652314649
Validation loss =  -14.616870880126953
epoch =  2 batch =  0 / 3597 loss =  -14.837501525878906
epoch =  2 batch =  25 / 3597 loss =  -14.700287635509785
epoch =  2 batch =  50 / 3597 loss =  -14.615548208648084
epoch =  2 batch =  75 / 3597 loss =  -14.615412787387246
epoch =  2 batch =  100 / 3597 loss =  -14.663581687625092
epoch =  2 

epoch =  2 batch =  3150 / 3597 loss =  -14.79252499937822
epoch =  2 batch =  3175 / 3597 loss =  -14.792928848218558
epoch =  2 batch =  3200 / 3597 loss =  -14.795544390751399
epoch =  2 batch =  3225 / 3597 loss =  -14.796652232203806
epoch =  2 batch =  3250 / 3597 loss =  -14.796315499138224
epoch =  2 batch =  3275 / 3597 loss =  -14.79568148911072
epoch =  2 batch =  3300 / 3597 loss =  -14.795369102029648
epoch =  2 batch =  3325 / 3597 loss =  -14.798789903476736
epoch =  2 batch =  3350 / 3597 loss =  -14.799766401360904
epoch =  2 batch =  3375 / 3597 loss =  -14.799943755588261
epoch =  2 batch =  3400 / 3597 loss =  -14.80244406823795
epoch =  2 batch =  3425 / 3597 loss =  -14.801963307046083
epoch =  2 batch =  3450 / 3597 loss =  -14.804576300358779
epoch =  2 batch =  3475 / 3597 loss =  -14.804631808821982
epoch =  2 batch =  3500 / 3597 loss =  -14.804732692885215
epoch =  2 batch =  3525 / 3597 loss =  -14.80584778842504
epoch =  2 batch =  3550 / 3597 loss =  -14.

epoch =  3 batch =  3000 / 3597 loss =  -15.02278028715058
epoch =  3 batch =  3025 / 3597 loss =  -15.023341060709024
epoch =  3 batch =  3050 / 3597 loss =  -15.023287990998925
epoch =  3 batch =  3075 / 3597 loss =  -15.022860536959763
epoch =  3 batch =  3100 / 3597 loss =  -15.021656931311574
epoch =  3 batch =  3125 / 3597 loss =  -15.02323330081737
epoch =  3 batch =  3150 / 3597 loss =  -15.024923563987178
epoch =  3 batch =  3175 / 3597 loss =  -15.026135816982471
epoch =  3 batch =  3200 / 3597 loss =  -15.028353724469248
epoch =  3 batch =  3225 / 3597 loss =  -15.028823765035158
epoch =  3 batch =  3250 / 3597 loss =  -15.031098364316952
epoch =  3 batch =  3275 / 3597 loss =  -15.029220966483502
epoch =  3 batch =  3300 / 3597 loss =  -15.030485417546016
epoch =  3 batch =  3325 / 3597 loss =  -15.028140841660123
epoch =  3 batch =  3350 / 3597 loss =  -15.029541417471654
epoch =  3 batch =  3375 / 3597 loss =  -15.027470055632117
epoch =  3 batch =  3400 / 3597 loss =  -1

epoch =  4 batch =  2825 / 3597 loss =  -15.165194915020962
epoch =  4 batch =  2850 / 3597 loss =  -15.164841786053085
epoch =  4 batch =  2875 / 3597 loss =  -15.16409544686117
epoch =  4 batch =  2900 / 3597 loss =  -15.163910172471336
epoch =  4 batch =  2925 / 3597 loss =  -15.164854743589867
epoch =  4 batch =  2950 / 3597 loss =  -15.16470019658191
epoch =  4 batch =  2975 / 3597 loss =  -15.16631504156256
epoch =  4 batch =  3000 / 3597 loss =  -15.165048338659043
epoch =  4 batch =  3025 / 3597 loss =  -15.166337831084961
epoch =  4 batch =  3050 / 3597 loss =  -15.167202263103471
epoch =  4 batch =  3075 / 3597 loss =  -15.1675673833284
epoch =  4 batch =  3100 / 3597 loss =  -15.16865845581517
epoch =  4 batch =  3125 / 3597 loss =  -15.16949273406582
epoch =  4 batch =  3150 / 3597 loss =  -15.169659325524911
epoch =  4 batch =  3175 / 3597 loss =  -15.168034217219507
epoch =  4 batch =  3200 / 3597 loss =  -15.167977552047484
epoch =  4 batch =  3225 / 3597 loss =  -15.166

epoch =  5 batch =  2675 / 3597 loss =  -15.272952343494543
epoch =  5 batch =  2700 / 3597 loss =  -15.274478148990362
epoch =  5 batch =  2725 / 3597 loss =  -15.274069173145362
epoch =  5 batch =  2750 / 3597 loss =  -15.2782363368138
epoch =  5 batch =  2775 / 3597 loss =  -15.278668407060913
epoch =  5 batch =  2800 / 3597 loss =  -15.279261132811953
epoch =  5 batch =  2825 / 3597 loss =  -15.279348164860101
epoch =  5 batch =  2850 / 3597 loss =  -15.27671886151065
epoch =  5 batch =  2875 / 3597 loss =  -15.277596800648949
epoch =  5 batch =  2900 / 3597 loss =  -15.276391437324067
epoch =  5 batch =  2925 / 3597 loss =  -15.27610117918159
epoch =  5 batch =  2950 / 3597 loss =  -15.275830533859162
epoch =  5 batch =  2975 / 3597 loss =  -15.277163187021849
epoch =  5 batch =  3000 / 3597 loss =  -15.279022233321719
epoch =  5 batch =  3025 / 3597 loss =  -15.281352892508192
epoch =  5 batch =  3050 / 3597 loss =  -15.282058812094217
epoch =  5 batch =  3075 / 3597 loss =  -15.

epoch =  6 batch =  2525 / 3597 loss =  -15.348485336922883
epoch =  6 batch =  2550 / 3597 loss =  -15.349668822443657
epoch =  6 batch =  2575 / 3597 loss =  -15.346315208429136
epoch =  6 batch =  2600 / 3597 loss =  -15.34763884993527
epoch =  6 batch =  2625 / 3597 loss =  -15.348665150630048
epoch =  6 batch =  2650 / 3597 loss =  -15.348239867024402
epoch =  6 batch =  2675 / 3597 loss =  -15.344642312359204
epoch =  6 batch =  2700 / 3597 loss =  -15.345520934719636
epoch =  6 batch =  2725 / 3597 loss =  -15.345947358683462
epoch =  6 batch =  2750 / 3597 loss =  -15.346779820529905
epoch =  6 batch =  2775 / 3597 loss =  -15.346150136131374
epoch =  6 batch =  2800 / 3597 loss =  -15.345817940441977
epoch =  6 batch =  2825 / 3597 loss =  -15.345103882089711
epoch =  6 batch =  2850 / 3597 loss =  -15.344826210929151
epoch =  6 batch =  2875 / 3597 loss =  -15.34432842701632
epoch =  6 batch =  2900 / 3597 loss =  -15.34516999541542
epoch =  6 batch =  2925 / 3597 loss =  -15

epoch =  7 batch =  2350 / 3597 loss =  -15.421384617096818
epoch =  7 batch =  2375 / 3597 loss =  -15.420729416388053
epoch =  7 batch =  2400 / 3597 loss =  -15.420834702186314
epoch =  7 batch =  2425 / 3597 loss =  -15.422254794310266
epoch =  7 batch =  2450 / 3597 loss =  -15.424347175184147
epoch =  7 batch =  2475 / 3597 loss =  -15.420837763630708
epoch =  7 batch =  2500 / 3597 loss =  -15.422609673171747
epoch =  7 batch =  2525 / 3597 loss =  -15.422640443320134
epoch =  7 batch =  2550 / 3597 loss =  -15.420370625309364
epoch =  7 batch =  2575 / 3597 loss =  -15.42130497486695
epoch =  7 batch =  2600 / 3597 loss =  -15.423632329173017
epoch =  7 batch =  2625 / 3597 loss =  -15.425962980884863
epoch =  7 batch =  2650 / 3597 loss =  -15.423533373443822
epoch =  7 batch =  2675 / 3597 loss =  -15.42329368070993
epoch =  7 batch =  2700 / 3597 loss =  -15.422039828182193
epoch =  7 batch =  2725 / 3597 loss =  -15.423335220213769
epoch =  7 batch =  2750 / 3597 loss =  -1

epoch =  8 batch =  2200 / 3597 loss =  -15.484633949224321
epoch =  8 batch =  2225 / 3597 loss =  -15.48455393003753
epoch =  8 batch =  2250 / 3597 loss =  -15.485148673055434
epoch =  8 batch =  2275 / 3597 loss =  -15.482038278780836
epoch =  8 batch =  2300 / 3597 loss =  -15.483301411810668
epoch =  8 batch =  2325 / 3597 loss =  -15.483021151066234
epoch =  8 batch =  2350 / 3597 loss =  -15.480084013604243
epoch =  8 batch =  2375 / 3597 loss =  -15.48152215151674
epoch =  8 batch =  2400 / 3597 loss =  -15.481403819524658
epoch =  8 batch =  2425 / 3597 loss =  -15.4789464762657
epoch =  8 batch =  2450 / 3597 loss =  -15.479583158146744
epoch =  8 batch =  2475 / 3597 loss =  -15.479938601060905
epoch =  8 batch =  2500 / 3597 loss =  -15.482374735042503
epoch =  8 batch =  2525 / 3597 loss =  -15.482019523732355
epoch =  8 batch =  2550 / 3597 loss =  -15.483009880077878
epoch =  8 batch =  2575 / 3597 loss =  -15.481658685281408
epoch =  8 batch =  2600 / 3597 loss =  -15.

epoch =  9 batch =  2050 / 3597 loss =  -15.546698233140846
epoch =  9 batch =  2075 / 3597 loss =  -15.543681814491405
epoch =  9 batch =  2100 / 3597 loss =  -15.550079456004344
epoch =  9 batch =  2125 / 3597 loss =  -15.550900569639449
epoch =  9 batch =  2150 / 3597 loss =  -15.556063199808076
epoch =  9 batch =  2175 / 3597 loss =  -15.55353645279127
epoch =  9 batch =  2200 / 3597 loss =  -15.552626712925594
epoch =  9 batch =  2225 / 3597 loss =  -15.551568081865199
epoch =  9 batch =  2250 / 3597 loss =  -15.552270198281423
epoch =  9 batch =  2275 / 3597 loss =  -15.552240067081417
epoch =  9 batch =  2300 / 3597 loss =  -15.55214607772595
epoch =  9 batch =  2325 / 3597 loss =  -15.551857163551443
epoch =  9 batch =  2350 / 3597 loss =  -15.553237138226509
epoch =  9 batch =  2375 / 3597 loss =  -15.554513766308023
epoch =  9 batch =  2400 / 3597 loss =  -15.55698483628762
epoch =  9 batch =  2425 / 3597 loss =  -15.554654678429637
epoch =  9 batch =  2450 / 3597 loss =  -15

epoch =  10 batch =  1850 / 3597 loss =  -15.565925069012428
epoch =  10 batch =  1875 / 3597 loss =  -15.562368883507084
epoch =  10 batch =  1900 / 3597 loss =  -15.56509430517591
epoch =  10 batch =  1925 / 3597 loss =  -15.569105954927819
epoch =  10 batch =  1950 / 3597 loss =  -15.574130655129343
epoch =  10 batch =  1975 / 3597 loss =  -15.573786655900932
epoch =  10 batch =  2000 / 3597 loss =  -15.575534225761265
epoch =  10 batch =  2025 / 3597 loss =  -15.573833719763797
epoch =  10 batch =  2050 / 3597 loss =  -15.566630053787566
epoch =  10 batch =  2075 / 3597 loss =  -15.566427347517656
epoch =  10 batch =  2100 / 3597 loss =  -15.565891568630096
epoch =  10 batch =  2125 / 3597 loss =  -15.566914098170828
epoch =  10 batch =  2150 / 3597 loss =  -15.569094860514948
epoch =  10 batch =  2175 / 3597 loss =  -15.56964320426478
epoch =  10 batch =  2200 / 3597 loss =  -15.56994318236334
epoch =  10 batch =  2225 / 3597 loss =  -15.570292135668787
epoch =  10 batch =  2250 /

epoch =  11 batch =  1625 / 3597 loss =  -15.627848286763621
epoch =  11 batch =  1650 / 3597 loss =  -15.623670609195186
epoch =  11 batch =  1675 / 3597 loss =  -15.622821348094712
epoch =  11 batch =  1700 / 3597 loss =  -15.624910856400568
epoch =  11 batch =  1725 / 3597 loss =  -15.629987413333437
epoch =  11 batch =  1750 / 3597 loss =  -15.62785888630482
epoch =  11 batch =  1775 / 3597 loss =  -15.627968948703629
epoch =  11 batch =  1800 / 3597 loss =  -15.629099355016663
epoch =  11 batch =  1825 / 3597 loss =  -15.624645311495561
epoch =  11 batch =  1850 / 3597 loss =  -15.62439226421519
epoch =  11 batch =  1875 / 3597 loss =  -15.625482164720482
epoch =  11 batch =  1900 / 3597 loss =  -15.623943673504584
epoch =  11 batch =  1925 / 3597 loss =  -15.62624183052673
epoch =  11 batch =  1950 / 3597 loss =  -15.62826493066253
epoch =  11 batch =  1975 / 3597 loss =  -15.629673808692438
epoch =  11 batch =  2000 / 3597 loss =  -15.632217982957984
epoch =  11 batch =  2025 / 

epoch =  12 batch =  1400 / 3597 loss =  -15.66678475823086
epoch =  12 batch =  1425 / 3597 loss =  -15.671903676438298
epoch =  12 batch =  1450 / 3597 loss =  -15.673294183059367
epoch =  12 batch =  1475 / 3597 loss =  -15.668541841713717
epoch =  12 batch =  1500 / 3597 loss =  -15.667484229441724
epoch =  12 batch =  1525 / 3597 loss =  -15.665133342517938
epoch =  12 batch =  1550 / 3597 loss =  -15.670289564716823
epoch =  12 batch =  1575 / 3597 loss =  -15.670123137193283
epoch =  12 batch =  1600 / 3597 loss =  -15.667398100715365
epoch =  12 batch =  1625 / 3597 loss =  -15.669963402061885
epoch =  12 batch =  1650 / 3597 loss =  -15.67456395722822
epoch =  12 batch =  1675 / 3597 loss =  -15.675938537980036
epoch =  12 batch =  1700 / 3597 loss =  -15.677371198888528
epoch =  12 batch =  1725 / 3597 loss =  -15.678508846397886
epoch =  12 batch =  1750 / 3597 loss =  -15.680394308829838
epoch =  12 batch =  1775 / 3597 loss =  -15.68007021820223
epoch =  12 batch =  1800 /

epoch =  13 batch =  1175 / 3597 loss =  -15.645859405297001
epoch =  13 batch =  1200 / 3597 loss =  -15.649242594081297
epoch =  13 batch =  1225 / 3597 loss =  -15.650716506248585
epoch =  13 batch =  1250 / 3597 loss =  -15.63969367337551
epoch =  13 batch =  1275 / 3597 loss =  -15.642116786544225
epoch =  13 batch =  1300 / 3597 loss =  -15.644262493435553
epoch =  13 batch =  1325 / 3597 loss =  -15.64456282841494
epoch =  13 batch =  1350 / 3597 loss =  -15.649485024763512
epoch =  13 batch =  1375 / 3597 loss =  -15.646761388279671
epoch =  13 batch =  1400 / 3597 loss =  -15.6461040372937
epoch =  13 batch =  1425 / 3597 loss =  -15.652580023815053
epoch =  13 batch =  1450 / 3597 loss =  -15.652970128023238
epoch =  13 batch =  1475 / 3597 loss =  -15.651660882360567
epoch =  13 batch =  1500 / 3597 loss =  -15.65871955203184
epoch =  13 batch =  1525 / 3597 loss =  -15.658726245985118
epoch =  13 batch =  1550 / 3597 loss =  -15.661945873041601
epoch =  13 batch =  1575 / 3

epoch =  14 batch =  950 / 3597 loss =  -15.708450358498109
epoch =  14 batch =  975 / 3597 loss =  -15.705537288892465
epoch =  14 batch =  1000 / 3597 loss =  -15.711894969006519
epoch =  14 batch =  1025 / 3597 loss =  -15.717381024685984
epoch =  14 batch =  1050 / 3597 loss =  -15.717648919937613
epoch =  14 batch =  1075 / 3597 loss =  -15.722024913170968
epoch =  14 batch =  1100 / 3597 loss =  -15.71853793262894
epoch =  14 batch =  1125 / 3597 loss =  -15.723384944414372
epoch =  14 batch =  1150 / 3597 loss =  -15.72108748143492
epoch =  14 batch =  1175 / 3597 loss =  -15.722320872910169
epoch =  14 batch =  1200 / 3597 loss =  -15.71906948248413
epoch =  14 batch =  1225 / 3597 loss =  -15.715675043631922
epoch =  14 batch =  1250 / 3597 loss =  -15.71463682096925
epoch =  14 batch =  1275 / 3597 loss =  -15.725374258424047
epoch =  14 batch =  1300 / 3597 loss =  -15.725041794831894
epoch =  14 batch =  1325 / 3597 loss =  -15.72441997067781
epoch =  14 batch =  1350 / 359

epoch =  15 batch =  725 / 3597 loss =  -15.790824787675842
epoch =  15 batch =  750 / 3597 loss =  -15.789902298491741
epoch =  15 batch =  775 / 3597 loss =  -15.789658778721524
epoch =  15 batch =  800 / 3597 loss =  -15.781851187478589
epoch =  15 batch =  825 / 3597 loss =  -15.77490482607419
epoch =  15 batch =  850 / 3597 loss =  -15.772353529790033
epoch =  15 batch =  875 / 3597 loss =  -15.777506834840121
epoch =  15 batch =  900 / 3597 loss =  -15.776980980122659
epoch =  15 batch =  925 / 3597 loss =  -15.783838131000362
epoch =  15 batch =  950 / 3597 loss =  -15.783155010074973
epoch =  15 batch =  975 / 3597 loss =  -15.771795673448532
epoch =  15 batch =  1000 / 3597 loss =  -15.767718103620318
epoch =  15 batch =  1025 / 3597 loss =  -15.776479984119854
epoch =  15 batch =  1050 / 3597 loss =  -15.773422754797677
epoch =  15 batch =  1075 / 3597 loss =  -15.773815706316867
epoch =  15 batch =  1100 / 3597 loss =  -15.773054643504517
epoch =  15 batch =  1125 / 3597 los

epoch =  16 batch =  500 / 3597 loss =  -15.82291775168535
epoch =  16 batch =  525 / 3597 loss =  -15.810234680828486
epoch =  16 batch =  550 / 3597 loss =  -15.790905341478961
epoch =  16 batch =  575 / 3597 loss =  -15.7892291860448
epoch =  16 batch =  600 / 3597 loss =  -15.79346542548816
epoch =  16 batch =  625 / 3597 loss =  -15.79059267653444
epoch =  16 batch =  650 / 3597 loss =  -15.797037643221666
epoch =  16 batch =  675 / 3597 loss =  -15.801777674601627
epoch =  16 batch =  700 / 3597 loss =  -15.808209338984033
epoch =  16 batch =  725 / 3597 loss =  -15.803351281431423
epoch =  16 batch =  750 / 3597 loss =  -15.80491761742197
epoch =  16 batch =  775 / 3597 loss =  -15.821291466349178
epoch =  16 batch =  800 / 3597 loss =  -15.81316182497289
epoch =  16 batch =  825 / 3597 loss =  -15.811314298800637
epoch =  16 batch =  850 / 3597 loss =  -15.80620881673452
epoch =  16 batch =  875 / 3597 loss =  -15.801319461979277
epoch =  16 batch =  900 / 3597 loss =  -15.7951

epoch =  17 batch =  275 / 3597 loss =  -15.792201519012451
epoch =  17 batch =  300 / 3597 loss =  -15.808445978006255
epoch =  17 batch =  325 / 3597 loss =  -15.806689525674457
epoch =  17 batch =  350 / 3597 loss =  -15.830597342249336
epoch =  17 batch =  375 / 3597 loss =  -15.848376479554684
epoch =  17 batch =  400 / 3597 loss =  -15.839655809568942
epoch =  17 batch =  425 / 3597 loss =  -15.818968947504608
epoch =  17 batch =  450 / 3597 loss =  -15.829259288813217
epoch =  17 batch =  475 / 3597 loss =  -15.838229900648614
epoch =  17 batch =  500 / 3597 loss =  -15.844233501457168
epoch =  17 batch =  525 / 3597 loss =  -15.832906487323486
epoch =  17 batch =  550 / 3597 loss =  -15.813544065679698
epoch =  17 batch =  575 / 3597 loss =  -15.813140557871925
epoch =  17 batch =  600 / 3597 loss =  -15.808795502102514
epoch =  17 batch =  625 / 3597 loss =  -15.809915265336205
epoch =  17 batch =  650 / 3597 loss =  -15.815100995076966
epoch =  17 batch =  675 / 3597 loss =  

epoch =  18 batch =  50 / 3597 loss =  -15.766240456525017
epoch =  18 batch =  75 / 3597 loss =  -15.738205282311691
epoch =  18 batch =  100 / 3597 loss =  -15.735605400387604
epoch =  18 batch =  125 / 3597 loss =  -15.765334167177715
epoch =  18 batch =  150 / 3597 loss =  -15.80202117187298
epoch =  18 batch =  175 / 3597 loss =  -15.78047183968804
epoch =  18 batch =  200 / 3597 loss =  -15.814707433406394
epoch =  18 batch =  225 / 3597 loss =  -15.805478703659192
epoch =  18 batch =  250 / 3597 loss =  -15.768800933047595
epoch =  18 batch =  275 / 3597 loss =  -15.758275733477827
epoch =  18 batch =  300 / 3597 loss =  -15.773126092067985
epoch =  18 batch =  325 / 3597 loss =  -15.784173710945925
epoch =  18 batch =  350 / 3597 loss =  -15.783456332323558
epoch =  18 batch =  375 / 3597 loss =  -15.787441476862481
epoch =  18 batch =  400 / 3597 loss =  -15.77723280033863
epoch =  18 batch =  425 / 3597 loss =  -15.785248293003566
epoch =  18 batch =  450 / 3597 loss =  -15.8

epoch =  18 batch =  3450 / 3597 loss =  -15.823553637192582
epoch =  18 batch =  3475 / 3597 loss =  -15.823285013951695
epoch =  18 batch =  3500 / 3597 loss =  -15.823793665404594
epoch =  18 batch =  3525 / 3597 loss =  -15.821759166868992
epoch =  18 batch =  3550 / 3597 loss =  -15.821386801360594
epoch =  18 batch =  3575 / 3597 loss =  -15.821955346421108
Validation loss =  -15.863177299499512
epoch =  19 batch =  0 / 3597 loss =  -17.2310791015625
epoch =  19 batch =  25 / 3597 loss =  -15.96720181978666
epoch =  19 batch =  50 / 3597 loss =  -16.0245626673979
epoch =  19 batch =  75 / 3597 loss =  -15.945761592764603
epoch =  19 batch =  100 / 3597 loss =  -15.880496431105207
epoch =  19 batch =  125 / 3597 loss =  -15.925000932481554
epoch =  19 batch =  150 / 3597 loss =  -15.913228413916581
epoch =  19 batch =  175 / 3597 loss =  -15.905136157165874
epoch =  19 batch =  200 / 3597 loss =  -15.923525307308974
epoch =  19 batch =  225 / 3597 loss =  -15.894025292016764
epoch

epoch =  19 batch =  3225 / 3597 loss =  -15.851855405132296
epoch =  19 batch =  3250 / 3597 loss =  -15.852217681002374
epoch =  19 batch =  3275 / 3597 loss =  -15.852589720189208
epoch =  19 batch =  3300 / 3597 loss =  -15.851664756941311
epoch =  19 batch =  3325 / 3597 loss =  -15.852592049918114
epoch =  19 batch =  3350 / 3597 loss =  -15.8548137510545
epoch =  19 batch =  3375 / 3597 loss =  -15.85508763677136
epoch =  19 batch =  3400 / 3597 loss =  -15.855241428084739
epoch =  19 batch =  3425 / 3597 loss =  -15.855704925802112
epoch =  19 batch =  3450 / 3597 loss =  -15.854974633954429
epoch =  19 batch =  3475 / 3597 loss =  -15.854990922952
epoch =  19 batch =  3500 / 3597 loss =  -15.856240395919421
epoch =  19 batch =  3525 / 3597 loss =  -15.856677782380034
epoch =  19 batch =  3550 / 3597 loss =  -15.855961405636727
epoch =  19 batch =  3575 / 3597 loss =  -15.852590946543136
Validation loss =  -16.00214385986328
epoch =  20 batch =  0 / 3597 loss =  -16.15784072875

epoch =  20 batch =  3000 / 3597 loss =  -15.86214469616034
epoch =  20 batch =  3025 / 3597 loss =  -15.86237572489081
epoch =  20 batch =  3050 / 3597 loss =  -15.861372490079237
epoch =  20 batch =  3075 / 3597 loss =  -15.861127681632967
epoch =  20 batch =  3100 / 3597 loss =  -15.863777024405342
epoch =  20 batch =  3125 / 3597 loss =  -15.863351974438492
epoch =  20 batch =  3150 / 3597 loss =  -15.861122789022772
epoch =  20 batch =  3175 / 3597 loss =  -15.862912418259782
epoch =  20 batch =  3200 / 3597 loss =  -15.863678126884825
epoch =  20 batch =  3225 / 3597 loss =  -15.864759359164632
epoch =  20 batch =  3250 / 3597 loss =  -15.864368231323674
epoch =  20 batch =  3275 / 3597 loss =  -15.862919140793611
epoch =  20 batch =  3300 / 3597 loss =  -15.86443728672595
epoch =  20 batch =  3325 / 3597 loss =  -15.866487070371498
epoch =  20 batch =  3350 / 3597 loss =  -15.864439599729588
epoch =  20 batch =  3375 / 3597 loss =  -15.866441525836692
epoch =  20 batch =  3400 /

epoch =  21 batch =  2775 / 3597 loss =  -15.893442312303812
epoch =  21 batch =  2800 / 3597 loss =  -15.891746326584085
epoch =  21 batch =  2825 / 3597 loss =  -15.886563608853933
epoch =  21 batch =  2850 / 3597 loss =  -15.887790776603643
epoch =  21 batch =  2875 / 3597 loss =  -15.886765806996607
epoch =  21 batch =  2900 / 3597 loss =  -15.888483762823274
epoch =  21 batch =  2925 / 3597 loss =  -15.88876012442296
epoch =  21 batch =  2950 / 3597 loss =  -15.889644434071524
epoch =  21 batch =  2975 / 3597 loss =  -15.889290309721424
epoch =  21 batch =  3000 / 3597 loss =  -15.891081710212909
epoch =  21 batch =  3025 / 3597 loss =  -15.890960555634432
epoch =  21 batch =  3050 / 3597 loss =  -15.891274936079627
epoch =  21 batch =  3075 / 3597 loss =  -15.89277240513825
epoch =  21 batch =  3100 / 3597 loss =  -15.893801755268548
epoch =  21 batch =  3125 / 3597 loss =  -15.891421625572981
epoch =  21 batch =  3150 / 3597 loss =  -15.89280868174348
epoch =  21 batch =  3175 /

epoch =  22 batch =  2550 / 3597 loss =  -15.903535974862201
epoch =  22 batch =  2575 / 3597 loss =  -15.8998217201381
