In [1]:
import math as m
import numpy as np
import random as r
import matplotlib.pyplot as plt
import sys

In [2]:
import torch
from torch import nn
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR

In [3]:
from nflows.flows.base import Flow
from nflows.distributions.uniform import BoxUniform
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedPiecewiseRationalQuadraticAutoregressiveTransform
from nflows.transforms.autoregressive import MaskedPiecewiseQuadraticAutoregressiveTransform
from nflows.transforms.permutations import ReversePermutation
from nflows.transforms.permutations import RandomPermutation
from nflows.transforms.splines.rational_quadratic import rational_quadratic_spline
from torch.utils.tensorboard import SummaryWriter

In [4]:
import subprocess
import time
import os
from copy import deepcopy
import math as m
import gc

## TensorBoard writer for loss logging

In [5]:
# TensorBoard writer for the per-epoch "Loss_train" / "Loss_test" scalars logged
# in the training loop. No log_dir is passed, so SummaryWriter uses its default
# output directory.
writer = SummaryWriter()

## GPU/CPU selection

In [6]:
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU instead of failing later on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Hyperparameters

In [7]:
# Flow architecture
n_flow_layers = 12   # number of (permutation, RQS autoregressive transform) layers in the flow
n_RQS_knots = 10     # number of spline bins (num_bins) in every rational-quadratic spline
n_made_layers = 3    # hidden layers in each MADE conditioner network
n_made_units = 500   # units per hidden layer of each MADE network

# Optimisation
batch_size = 1024
n_epochs = 800
adam_lr = 1e-3       # Adam learning rate (Adam's default is 0.001)

# Target sample sizes. The number of events actually drawn is larger: both counts
# are multiplied by the statistical factor c that compensates for negative weights.
n_train = 1_000_000
n_test = 100_000

## Load the training data

In [8]:
# Event samples (one row per event) and their per-event weights, both read as
# comma-separated text.
samples, weights = (
    np.genfromtxt(csv_path, delimiter=",")
    for csv_path in ("data/negative_weight_samples.csv", "data/negative_weight_weights.csv")
)

## Find the fraction $f$ of negatively weighted events and the required statistical factor $c = (1-2f)^{-2}$

In [None]:
# Fraction f of negatively weighted events and the statistical factor
# c = (1 - 2f)^-2 by which the requested train/test sizes are scaled up
# to compensate for the negative weights.
if weights.shape[0] == 0 or weights.shape[0] != samples.shape[0]:
    raise ValueError("Samples and weights must be non-empty and have the same number of rows")
f = np.mean(weights < 0)
if f >= 0.5:
    # c diverges at f = 0.5 and is meaningless beyond it (the net weight is <= 0)
    raise ValueError("Fraction of negative events must be below 0.5")
c = (1-2*f)**-2
print("Fraction of negative events", f)
print("Statistical factor", c)
if (n_train + n_test)*c > samples.shape[0]:
    raise Exception("Not enough training data")

Fraction of negative events 0.2394467380986031
Statistical factor 3.6825358174692755


## Normalise weights (replace each weight by its sign)

In [None]:
# Replace every weight by its sign (+1 / -1, and 0 for a zero weight) so each
# event enters the weighted likelihood with unit magnitude and only its sign is kept.
weights = np.sign(weights)

## Split to a train and test set

In [None]:
# Number of events actually used per split: the target counts scaled by the
# statistical factor c to compensate for negative weights.
n_train_with_stats = int(n_train*c)
n_test_with_stats = int(n_test*c)
test_end = n_train_with_stats + n_test_with_stats

def _as_device_tensor(array):
    """Copy a NumPy array onto `device` as a float32 torch tensor."""
    return torch.tensor(array, dtype=torch.float32, device=device)

# Train: the first n_train_with_stats rows; test: the next n_test_with_stats rows.
train_samples = _as_device_tensor(samples[:n_train_with_stats])
train_weights = _as_device_tensor(weights[:n_train_with_stats])
test_samples = _as_device_tensor(samples[n_train_with_stats:test_end])
test_weights = _as_device_tensor(weights[n_train_with_stats:test_end])

# Free the host-side NumPy arrays. NOTE: after this the cell cannot be re-run
# without re-loading `samples` / `weights`.
del samples, weights
gc.collect()

90

## Set up the flow

In [None]:
event_dim = train_samples.shape[1]
# Uniform base distribution on the unit hypercube [0, 1]^event_dim. The bounds are
# created directly on `device`. Module.to() only moves parameters/buffers of
# nn.Module objects, so if BoxUniform is not an nn.Module (not visible from here)
# its bounds would otherwise stay on the CPU while the flow runs on `device`.
base_dist = BoxUniform(torch.zeros(event_dim, device=device), torch.ones(event_dim, device=device))

# Each flow layer is a random feature permutation followed by a masked
# autoregressive rational-quadratic spline transform.
transforms = []
for _ in range(n_flow_layers):
    transforms.append(RandomPermutation(features=event_dim))
    transforms.append(MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
        features=event_dim,
        hidden_features=n_made_units,
        num_bins=n_RQS_knots,
        # presumably the MADE net has one initial hidden layer plus `num_blocks`
        # more, hence the -1 — TODO confirm against the nflows version in use
        num_blocks=n_made_layers-1,
        tails="constrained",
        use_residual_blocks=False
    ))
transform = CompositeTransform(transforms)

flow = Flow(transform, base_dist).to(device)
optimizer = optim.Adam(flow.parameters(), lr=adam_lr)

# Halve the learning rate at each milestone epoch. scheduler.step() is called once
# per epoch after training, so with n_epochs == 800 the final milestone (800) only
# takes effect after the last epoch has finished.
lr_milestones = [350, 425, 500, 575, 650, 725, 800]
scheduler = MultiStepLR(optimizer, milestones=lr_milestones, gamma=0.5)

## Training

In [None]:
data_size = train_samples.shape[0]
n_batches = m.ceil(data_size/batch_size)

data_size_validation = test_samples.shape[0]
n_batches_validate = m.ceil(data_size_validation/batch_size)

best_loss = np.inf
for epoch in range(n_epochs):

    # Fresh shuffle of the training set every epoch
    permutation = torch.randperm(data_size, device=device)

    # Loop over batches; cum_loss is the running mean of the per-batch losses
    cum_loss = 0
    for batch in range(n_batches):
        # Slice ends are exclusive, so clamp to data_size (not data_size-1).
        # Clamping to data_size-1 dropped one sample per epoch and produced an
        # empty (NaN-loss) final batch whenever data_size % batch_size == 1.
        batch_begin = batch*batch_size
        batch_end   = min( (batch+1)*batch_size, data_size )
        indices = permutation[batch_begin:batch_end]
        samples_batch = train_samples[indices]
        weights_batch = train_weights[indices]

        # Take a step on the weighted negative log-likelihood (signed event weights)
        optimizer.zero_grad()
        loss = -(flow.log_prob(inputs=samples_batch)*weights_batch).mean()
        loss.backward()
        optimizer.step()

        # Update the running mean of the batch losses
        cum_loss = (cum_loss*batch + loss.item())/(batch+1)

        if batch%25 == 0:
            print("epoch = ", epoch, "batch = ", batch, "/", n_batches, "loss = ", cum_loss)

    scheduler.step()

    # Validation loss: per-sample mean of the weighted NLL over the whole test set.
    # Summing per batch and dividing once avoids over-weighting a short last batch.
    weighted_log_prob_sum = 0.0
    with torch.no_grad():
        for batch in range(n_batches_validate):
            batch_begin = batch*batch_size
            batch_end = min( (batch+1)*batch_size, data_size_validation )
            samples_batch = test_samples[batch_begin:batch_end]
            weights_batch = test_weights[batch_begin:batch_end]
            weighted_log_prob_sum += (flow.log_prob(samples_batch)*weights_batch).sum().item()
    validation_loss = -weighted_log_prob_sum/max(data_size_validation, 1)

    print("Validation loss = ", validation_loss)

    writer.add_scalar("Loss_train", cum_loss, epoch)
    writer.add_scalar("Loss_test", validation_loss, epoch)

    # Checkpoint the whole model whenever the validation loss improves
    if validation_loss < best_loss:
        torch.save(flow, "best_flow_model.pt")
        best_loss = validation_loss

torch.save(flow, "final_flow_model.pt")

epoch =  0 batch =  0 / 3597 loss =  17.9576358795166
epoch =  0 batch =  25 / 3597 loss =  0.5931739480449599
epoch =  0 batch =  50 / 3597 loss =  -2.5938962186668437
epoch =  0 batch =  75 / 3597 loss =  -4.261602783085485
epoch =  0 batch =  100 / 3597 loss =  -5.302870060105135
epoch =  0 batch =  125 / 3597 loss =  -6.050047131876151
epoch =  0 batch =  150 / 3597 loss =  -6.653813500376726
epoch =  0 batch =  175 / 3597 loss =  -7.145457922819663
epoch =  0 batch =  200 / 3597 loss =  -7.599675016690843
epoch =  0 batch =  225 / 3597 loss =  -7.969954705409772
epoch =  0 batch =  250 / 3597 loss =  -8.289210155725007
epoch =  0 batch =  275 / 3597 loss =  -8.554372438108146
epoch =  0 batch =  300 / 3597 loss =  -8.79812773799778
epoch =  0 batch =  325 / 3597 loss =  -9.013570887590847
epoch =  0 batch =  350 / 3597 loss =  -9.203197745218917
epoch =  0 batch =  375 / 3597 loss =  -9.354217412346545
epoch =  0 batch =  400 / 3597 loss =  -9.499882419544862
epoch =  0 batch =  4

epoch =  0 batch =  3475 / 3597 loss =  -12.81066526694199
epoch =  0 batch =  3500 / 3597 loss =  -12.818650589678596
epoch =  0 batch =  3525 / 3597 loss =  -12.82636924849963
epoch =  0 batch =  3550 / 3597 loss =  -12.832992517962412
epoch =  0 batch =  3575 / 3597 loss =  -12.841424104855525
Validation loss =  -13.9771146774292
epoch =  1 batch =  0 / 3597 loss =  -14.246454238891602
epoch =  1 batch =  25 / 3597 loss =  -14.041108791644756
epoch =  1 batch =  50 / 3597 loss =  -14.156505079830394
epoch =  1 batch =  75 / 3597 loss =  -14.05435713968779
epoch =  1 batch =  100 / 3597 loss =  -14.039949709826177
epoch =  1 batch =  125 / 3597 loss =  -14.093002849155003
epoch =  1 batch =  150 / 3597 loss =  -14.119642983998684
epoch =  1 batch =  175 / 3597 loss =  -14.119377212090926
epoch =  1 batch =  200 / 3597 loss =  -14.094803079443784
epoch =  1 batch =  225 / 3597 loss =  -14.081807191393017
epoch =  1 batch =  250 / 3597 loss =  -14.0563746676502
epoch =  1 batch =  275 

epoch =  1 batch =  3325 / 3597 loss =  -14.337573185930465
epoch =  1 batch =  3350 / 3597 loss =  -14.339742073120554
epoch =  1 batch =  3375 / 3597 loss =  -14.340022903765547
epoch =  1 batch =  3400 / 3597 loss =  -14.341916379561251
epoch =  1 batch =  3425 / 3597 loss =  -14.344772187994018
epoch =  1 batch =  3450 / 3597 loss =  -14.346889278640681
epoch =  1 batch =  3475 / 3597 loss =  -14.347250139480082
epoch =  1 batch =  3500 / 3597 loss =  -14.348909794007394
epoch =  1 batch =  3525 / 3597 loss =  -14.352779027892321
epoch =  1 batch =  3550 / 3597 loss =  -14.355382529355143
epoch =  1 batch =  3575 / 3597 loss =  -14.357113652314649
Validation loss =  -14.616870880126953
epoch =  2 batch =  0 / 3597 loss =  -14.837501525878906
epoch =  2 batch =  25 / 3597 loss =  -14.700287635509785
epoch =  2 batch =  50 / 3597 loss =  -14.615548208648084
epoch =  2 batch =  75 / 3597 loss =  -14.615412787387246
epoch =  2 batch =  100 / 3597 loss =  -14.663581687625092
epoch =  2 

epoch =  2 batch =  3150 / 3597 loss =  -14.79252499937822
epoch =  2 batch =  3175 / 3597 loss =  -14.792928848218558
epoch =  2 batch =  3200 / 3597 loss =  -14.795544390751399
epoch =  2 batch =  3225 / 3597 loss =  -14.796652232203806
epoch =  2 batch =  3250 / 3597 loss =  -14.796315499138224
epoch =  2 batch =  3275 / 3597 loss =  -14.79568148911072
epoch =  2 batch =  3300 / 3597 loss =  -14.795369102029648
epoch =  2 batch =  3325 / 3597 loss =  -14.798789903476736
epoch =  2 batch =  3350 / 3597 loss =  -14.799766401360904
epoch =  2 batch =  3375 / 3597 loss =  -14.799943755588261
epoch =  2 batch =  3400 / 3597 loss =  -14.80244406823795
epoch =  2 batch =  3425 / 3597 loss =  -14.801963307046083
epoch =  2 batch =  3450 / 3597 loss =  -14.804576300358779
epoch =  2 batch =  3475 / 3597 loss =  -14.804631808821982
epoch =  2 batch =  3500 / 3597 loss =  -14.804732692885215
epoch =  2 batch =  3525 / 3597 loss =  -14.80584778842504
epoch =  2 batch =  3550 / 3597 loss =  -14.

epoch =  3 batch =  3000 / 3597 loss =  -15.02278028715058
epoch =  3 batch =  3025 / 3597 loss =  -15.023341060709024
epoch =  3 batch =  3050 / 3597 loss =  -15.023287990998925
epoch =  3 batch =  3075 / 3597 loss =  -15.022860536959763
epoch =  3 batch =  3100 / 3597 loss =  -15.021656931311574
epoch =  3 batch =  3125 / 3597 loss =  -15.02323330081737
epoch =  3 batch =  3150 / 3597 loss =  -15.024923563987178
epoch =  3 batch =  3175 / 3597 loss =  -15.026135816982471
epoch =  3 batch =  3200 / 3597 loss =  -15.028353724469248
epoch =  3 batch =  3225 / 3597 loss =  -15.028823765035158
epoch =  3 batch =  3250 / 3597 loss =  -15.031098364316952
epoch =  3 batch =  3275 / 3597 loss =  -15.029220966483502
epoch =  3 batch =  3300 / 3597 loss =  -15.030485417546016
epoch =  3 batch =  3325 / 3597 loss =  -15.028140841660123
epoch =  3 batch =  3350 / 3597 loss =  -15.029541417471654
epoch =  3 batch =  3375 / 3597 loss =  -15.027470055632117
epoch =  3 batch =  3400 / 3597 loss =  -1

epoch =  4 batch =  2825 / 3597 loss =  -15.165194915020962
epoch =  4 batch =  2850 / 3597 loss =  -15.164841786053085
epoch =  4 batch =  2875 / 3597 loss =  -15.16409544686117
epoch =  4 batch =  2900 / 3597 loss =  -15.163910172471336
epoch =  4 batch =  2925 / 3597 loss =  -15.164854743589867
epoch =  4 batch =  2950 / 3597 loss =  -15.16470019658191
epoch =  4 batch =  2975 / 3597 loss =  -15.16631504156256
epoch =  4 batch =  3000 / 3597 loss =  -15.165048338659043
epoch =  4 batch =  3025 / 3597 loss =  -15.166337831084961
epoch =  4 batch =  3050 / 3597 loss =  -15.167202263103471
epoch =  4 batch =  3075 / 3597 loss =  -15.1675673833284
epoch =  4 batch =  3100 / 3597 loss =  -15.16865845581517
epoch =  4 batch =  3125 / 3597 loss =  -15.16949273406582
epoch =  4 batch =  3150 / 3597 loss =  -15.169659325524911
epoch =  4 batch =  3175 / 3597 loss =  -15.168034217219507
epoch =  4 batch =  3200 / 3597 loss =  -15.167977552047484
epoch =  4 batch =  3225 / 3597 loss =  -15.166

epoch =  5 batch =  2675 / 3597 loss =  -15.272952343494543
epoch =  5 batch =  2700 / 3597 loss =  -15.274478148990362
epoch =  5 batch =  2725 / 3597 loss =  -15.274069173145362
epoch =  5 batch =  2750 / 3597 loss =  -15.2782363368138
epoch =  5 batch =  2775 / 3597 loss =  -15.278668407060913
epoch =  5 batch =  2800 / 3597 loss =  -15.279261132811953
epoch =  5 batch =  2825 / 3597 loss =  -15.279348164860101
epoch =  5 batch =  2850 / 3597 loss =  -15.27671886151065
epoch =  5 batch =  2875 / 3597 loss =  -15.277596800648949
epoch =  5 batch =  2900 / 3597 loss =  -15.276391437324067
epoch =  5 batch =  2925 / 3597 loss =  -15.27610117918159
epoch =  5 batch =  2950 / 3597 loss =  -15.275830533859162
epoch =  5 batch =  2975 / 3597 loss =  -15.277163187021849
epoch =  5 batch =  3000 / 3597 loss =  -15.279022233321719
epoch =  5 batch =  3025 / 3597 loss =  -15.281352892508192
epoch =  5 batch =  3050 / 3597 loss =  -15.282058812094217
epoch =  5 batch =  3075 / 3597 loss =  -15.