In [1]:
import math as m
import numpy as np
import random as r
import matplotlib.pyplot as plt
import sys

In [2]:
import torch
from torch import nn
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR

In [3]:
from nflows.flows.base import Flow
from nflows.distributions.uniform import BoxUniform
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedPiecewiseRationalQuadraticAutoregressiveTransform
from nflows.transforms.autoregressive import MaskedPiecewiseQuadraticAutoregressiveTransform
from nflows.transforms.permutations import ReversePermutation
from nflows.transforms.permutations import RandomPermutation
from nflows.transforms.splines.rational_quadratic import rational_quadratic_spline
from torch.utils.tensorboard import SummaryWriter

In [4]:
import subprocess
import time
import os
from copy import deepcopy
import math as m
import gc

## Tensorboard writer for loss logging

In [5]:
# TensorBoard writer used by the training cell to log the per-epoch
# "Loss_train" and "Loss_test" scalars. No log directory is passed, so the
# library default location is used.
writer = SummaryWriter()

## GPU/CPU selection

In [6]:
# Use the first CUDA GPU when one is available and fall back to the CPU
# otherwise, so the notebook also runs on machines without a GPU instead of
# failing at the first tensor allocation.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Hyperparameters

In [7]:
# --- Flow architecture ---
# Number of knots in each rational-quadratic spline (RQS) transform
n_RQS_knots = 16
# Number of hidden layers in every MADE network
n_made_layers = 3
# Number of units in every hidden layer of the MADE network
n_made_units = 250
# Number of layers in the flow
n_flow_layers = 12

# --- Optimisation ---
batch_size = 1024
n_epochs = 800
# Learning rate for the ADAM optimizer (default: 0.001)
adam_lr = 1e-3

# --- Data set sizes ---
# Nominal sizes only: the statistical factor that accounts for negative
# weights is not included here; it is applied when the data are split.
n_train = 10**6
n_test = 10**5

## Load the training data

In [8]:
# Read the event samples and keep only the first 8 columns of each row.
samples = np.genfromtxt("data/negative_weight_samples.csv", delimiter=',')
samples = samples[:, :8]

# Read the event weights that accompany the samples.
weights = np.genfromtxt("data/negative_weight_weights.csv", delimiter=',')

In [9]:
# Sanity check: show (number of events, number of columns kept per event).
print(samples.shape)

(7995490, 8)


## Find the fraction of negative events and the required statistical factor

In [10]:
# f: fraction of events that carry a negative weight.
f = (weights < 0).sum()/len(weights)
# c: statistical factor (1 - 2f)^-2; the nominal train/test sizes are
# multiplied by it to account for the negative weights.
c = (1 - 2*f)**-2
print("Fraction of negative events", f)
print("Statistical factor", c)
# Abort early when the file holds fewer events than the scaled
# train + test request.
if samples.shape[0] < (n_train + n_test)*c:
    raise Exception("Not enough training data")

Fraction of negative events 0.2394467380986031
Statistical factor 3.6825358174692755


## Normalise weights (keep only the sign of each weight)

In [11]:
# Replace every weight by its sign: np.sign gives -1 for negative weights,
# +1 for positive weights and 0 for weights that are exactly zero.
# The negative fraction f and the factor c were computed before this step.
weights = np.sign(weights)

## Split to a train and test set

In [12]:
# Scale the nominal set sizes by the statistical factor c.
n_train_with_stats = int(c*n_train)
n_test_with_stats = int(c*n_test)

# The first n_train_with_stats events form the training set ...
train_samples, train_weights = (
    torch.tensor(array[:n_train_with_stats], dtype=torch.float32, device=device)
    for array in (samples, weights)
)
# ... and the n_test_with_stats events directly after them form the test set.
test_samples, test_weights = (
    torch.tensor(
        array[n_train_with_stats:n_train_with_stats + n_test_with_stats],
        dtype=torch.float32,
        device=device,
    )
    for array in (samples, weights)
)

# The numpy arrays are no longer needed once the tensors exist; drop them
# and ask the garbage collector to reclaim the memory.
del samples, weights
gc.collect()

90

## Set up the flow

In [13]:
# Build the normalizing flow: a uniform base distribution on the unit
# hypercube, followed by n_flow_layers pairs of (random permutation,
# masked autoregressive rational-quadratic spline transform).
event_dim = train_samples.shape[1]

# Create the bounds of the base distribution directly on `device`.
# nn.Module.to() only moves registered parameters and buffers; BoxUniform
# appears to be a plain torch.distributions object rather than a module
# (TODO confirm for the installed nflows build), in which case
# Flow(...).to(device) would leave CPU bounds behind while the data live on
# `device`. Allocating them on `device` is correct in either case.
base_dist = BoxUniform(
    torch.zeros(event_dim, device=device),
    torch.ones(event_dim, device=device),
)

transforms = []
for _ in range(n_flow_layers):
    # Shuffle the feature order so that successive autoregressive layers
    # condition on different orderings of the features.
    transforms.append(RandomPermutation(features=event_dim))
    # NOTE(review): num_blocks is set to n_made_layers-1; presumably the
    # MADE network contributes one layer of its own - confirm.
    # NOTE(review): tails="constrained" is not one of the options I know
    # from stock nflows (None / "linear"); presumably this relies on a
    # customised nflows build - confirm before upgrading the library.
    transforms.append(MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
        features=event_dim,
        hidden_features=n_made_units,
        num_bins=n_RQS_knots,
        num_blocks=n_made_layers-1,
        tails="constrained",
        use_residual_blocks=False
    ))
transform = CompositeTransform(transforms)

flow = Flow(transform, base_dist).to(device)
optimizer = optim.Adam(flow.parameters(), lr=adam_lr)

# Halve the learning rate each time the epoch counter reaches a milestone.
scheduler = MultiStepLR(optimizer, milestones=[350, 425, 500, 575, 650, 725, 800], gamma=0.5)

## Training

In [None]:
# Train the flow by minimising the sign-weighted negative log-likelihood
#     loss = -mean(weight * log_prob(sample))
# over shuffled mini-batches. After every epoch the same loss is evaluated
# on the test set, both losses are logged to TensorBoard, and the model with
# the lowest validation loss so far is saved.
data_size = train_samples.shape[0]
n_batches = m.ceil(data_size/batch_size)

data_size_validation = test_samples.shape[0]
n_batches_validate = m.ceil(data_size_validation/batch_size)

best_loss = np.inf
for epoch in range(n_epochs):

    # Fresh random ordering of the training events for every epoch
    permutation = torch.randperm(data_size, device=device)

    # Loop over batches
    flow.train()
    cum_loss = 0.0
    n_seen_train = 0
    for batch in range(n_batches):
        # Set up the batch. The slice end is exclusive, so it is clamped to
        # data_size (not data_size-1): clamping to data_size-1 silently
        # dropped the last event and produced an empty batch - and therefore
        # a NaN loss and NaN gradients - when data_size % batch_size == 1.
        batch_begin = batch*batch_size
        batch_end   = min( (batch+1)*batch_size, data_size )
        indices = permutation[batch_begin:batch_end]
        samples_batch = train_samples[indices]
        weights_batch = train_weights[indices]

        # Take a step
        optimizer.zero_grad()
        loss = -(flow.log_prob(inputs=samples_batch)*weights_batch).mean()
        loss.backward()
        optimizer.step()

        # Running per-sample mean of the loss over the epoch so far. Each
        # batch is weighted by its size so the shorter final batch does not
        # count as much as a full one.
        n_in_batch = batch_end - batch_begin
        cum_loss = (cum_loss*n_seen_train + loss.item()*n_in_batch)/(n_seen_train + n_in_batch)
        n_seen_train += n_in_batch

        if batch%25 == 0:
            print("epoch = ", epoch, "batch = ", batch, "/", n_batches, "loss = ", cum_loss)

    # One scheduler step per epoch (the milestones are expressed in epochs)
    scheduler.step()

    # Compute validation loss as a per-sample mean over the whole test set.
    # It is accumulated as a plain Python float so that best_loss and the
    # logged scalar do not hold on to device tensors.
    flow.eval()
    validation_loss = 0.0
    n_seen_validate = 0
    with torch.no_grad():
        for batch in range(n_batches_validate):
            batch_begin = batch*batch_size
            batch_end = min( (batch+1)*batch_size, data_size_validation )
            samples_batch = test_samples[batch_begin:batch_end]
            weights_batch = test_weights[batch_begin:batch_end]

            n_in_batch = batch_end - batch_begin
            batch_loss = -(flow.log_prob(samples_batch)*weights_batch).mean().item()
            validation_loss = (validation_loss*n_seen_validate + batch_loss*n_in_batch)/(n_seen_validate + n_in_batch)
            n_seen_validate += n_in_batch

    print("Validation loss = ", validation_loss)

    writer.add_scalar("Loss_train", cum_loss, epoch)
    writer.add_scalar("Loss_test", validation_loss, epoch)

    # Keep a copy of the best model seen so far. torch.save on the module
    # pickles the whole object, so loading it later needs the same class
    # definitions to be importable.
    if validation_loss < best_loss:
        torch.save(flow, "best_flow_model.pt")
        best_loss = validation_loss

torch.save(flow, "final_flow_model.pt")

# Make sure all logged scalars are written to disk
writer.flush()

epoch =  0 batch =  0 / 3597 loss =  4.816425800323486
epoch =  0 batch =  25 / 3597 loss =  -0.29749193930855167
epoch =  0 batch =  50 / 3597 loss =  -1.4781019426151818
epoch =  0 batch =  75 / 3597 loss =  -2.091771349801045
epoch =  0 batch =  100 / 3597 loss =  -2.4474620946710655
epoch =  0 batch =  125 / 3597 loss =  -2.686489720961877
epoch =  0 batch =  150 / 3597 loss =  -2.8532665240152792
epoch =  0 batch =  175 / 3597 loss =  -3.015312798194249
epoch =  0 batch =  200 / 3597 loss =  -3.1332063347620163
epoch =  0 batch =  225 / 3597 loss =  -3.2452271737949516
epoch =  0 batch =  250 / 3597 loss =  -3.3405088280361017
epoch =  0 batch =  275 / 3597 loss =  -3.4274140591260753
epoch =  0 batch =  300 / 3597 loss =  -3.500318817447784
epoch =  0 batch =  325 / 3597 loss =  -3.5649235712284324
epoch =  0 batch =  350 / 3597 loss =  -3.6228328467137114
epoch =  0 batch =  375 / 3597 loss =  -3.6799428268989667
epoch =  0 batch =  400 / 3597 loss =  -3.729265578908665
epoch = 

epoch =  0 batch =  3500 / 3597 loss =  -4.6053382648536445
epoch =  0 batch =  3525 / 3597 loss =  -4.60694008287394
epoch =  0 batch =  3550 / 3597 loss =  -4.608264728681911
epoch =  0 batch =  3575 / 3597 loss =  -4.609863153133758
Validation loss =  -4.777769088745117
epoch =  1 batch =  0 / 3597 loss =  -5.016357421875
epoch =  1 batch =  25 / 3597 loss =  -4.727318360255316
epoch =  1 batch =  50 / 3597 loss =  -4.7302153624740315
epoch =  1 batch =  75 / 3597 loss =  -4.764875016714399
epoch =  1 batch =  100 / 3597 loss =  -4.752459563831293
epoch =  1 batch =  125 / 3597 loss =  -4.742090164668978
epoch =  1 batch =  150 / 3597 loss =  -4.740787300842489
epoch =  1 batch =  175 / 3597 loss =  -4.744164873253217
epoch =  1 batch =  200 / 3597 loss =  -4.745584165278954
epoch =  1 batch =  225 / 3597 loss =  -4.758629691284318
epoch =  1 batch =  250 / 3597 loss =  -4.757839349161584
epoch =  1 batch =  275 / 3597 loss =  -4.75127010414566
epoch =  1 batch =  300 / 3597 loss = 

epoch =  1 batch =  3400 / 3597 loss =  -4.797848409149652
epoch =  1 batch =  3425 / 3597 loss =  -4.798250543396162
epoch =  1 batch =  3450 / 3597 loss =  -4.79735694833853
epoch =  1 batch =  3475 / 3597 loss =  -4.797332464482619
epoch =  1 batch =  3500 / 3597 loss =  -4.7979651457648425
epoch =  1 batch =  3525 / 3597 loss =  -4.797671083151716
epoch =  1 batch =  3550 / 3597 loss =  -4.797123134992248
epoch =  1 batch =  3575 / 3597 loss =  -4.79773310700252
Validation loss =  -4.805451393127441
epoch =  2 batch =  0 / 3597 loss =  -4.502457141876221
epoch =  2 batch =  25 / 3597 loss =  -4.817321703984187
epoch =  2 batch =  50 / 3597 loss =  -4.821781541786942
epoch =  2 batch =  75 / 3597 loss =  -4.8196306228637695
epoch =  2 batch =  100 / 3597 loss =  -4.825655701136827
epoch =  2 batch =  125 / 3597 loss =  -4.83210279449584
epoch =  2 batch =  150 / 3597 loss =  -4.824747846616025
epoch =  2 batch =  175 / 3597 loss =  -4.822286370125682
epoch =  2 batch =  200 / 3597 l

epoch =  2 batch =  3275 / 3597 loss =  -4.8166916363114005
epoch =  2 batch =  3300 / 3597 loss =  -4.816549361239621
epoch =  2 batch =  3325 / 3597 loss =  -4.816512301107224
epoch =  2 batch =  3350 / 3597 loss =  -4.816554518608283
epoch =  2 batch =  3375 / 3597 loss =  -4.8163662373454565
epoch =  2 batch =  3400 / 3597 loss =  -4.816659034122189
epoch =  2 batch =  3425 / 3597 loss =  -4.816585473129078
epoch =  2 batch =  3450 / 3597 loss =  -4.816488034826739
epoch =  2 batch =  3475 / 3597 loss =  -4.81595945893277
epoch =  2 batch =  3500 / 3597 loss =  -4.815615499812982
epoch =  2 batch =  3525 / 3597 loss =  -4.815980384651389
epoch =  2 batch =  3550 / 3597 loss =  -4.816478559246394
epoch =  2 batch =  3575 / 3597 loss =  -4.817062934373042
Validation loss =  -4.825273036956787
epoch =  3 batch =  0 / 3597 loss =  -5.256582260131836
epoch =  3 batch =  25 / 3597 loss =  -4.81933225118197
epoch =  3 batch =  50 / 3597 loss =  -4.881012467777028
epoch =  3 batch =  75 / 

epoch =  3 batch =  3175 / 3597 loss =  -4.826763181151908
epoch =  3 batch =  3200 / 3597 loss =  -4.8265606465469215
epoch =  3 batch =  3225 / 3597 loss =  -4.826982191018413
epoch =  3 batch =  3250 / 3597 loss =  -4.827285436145052
epoch =  3 batch =  3275 / 3597 loss =  -4.8273758953744235
epoch =  3 batch =  3300 / 3597 loss =  -4.8271672141367805
epoch =  3 batch =  3325 / 3597 loss =  -4.827300250637769
epoch =  3 batch =  3350 / 3597 loss =  -4.827979334857138
epoch =  3 batch =  3375 / 3597 loss =  -4.828348203292958
epoch =  3 batch =  3400 / 3597 loss =  -4.828356641210274
epoch =  3 batch =  3425 / 3597 loss =  -4.828169307441513
epoch =  3 batch =  3450 / 3597 loss =  -4.827930078061553
epoch =  3 batch =  3475 / 3597 loss =  -4.8276175012248315
epoch =  3 batch =  3500 / 3597 loss =  -4.827506110450945
epoch =  3 batch =  3525 / 3597 loss =  -4.827044108207432
epoch =  3 batch =  3550 / 3597 loss =  -4.82698756787046
epoch =  3 batch =  3575 / 3597 loss =  -4.8270345652

epoch =  4 batch =  3050 / 3597 loss =  -4.829011174273538
epoch =  4 batch =  3075 / 3597 loss =  -4.829643444370071
epoch =  4 batch =  3100 / 3597 loss =  -4.829259608107889
epoch =  4 batch =  3125 / 3597 loss =  -4.829479936598503
epoch =  4 batch =  3150 / 3597 loss =  -4.830352507633308
epoch =  4 batch =  3175 / 3597 loss =  -4.830684967695604
epoch =  4 batch =  3200 / 3597 loss =  -4.83159179994367
epoch =  4 batch =  3225 / 3597 loss =  -4.831963325389325
epoch =  4 batch =  3250 / 3597 loss =  -4.831211255096123
epoch =  4 batch =  3275 / 3597 loss =  -4.831952613788634
epoch =  4 batch =  3300 / 3597 loss =  -4.8325045655547125
epoch =  4 batch =  3325 / 3597 loss =  -4.832793295992748
epoch =  4 batch =  3350 / 3597 loss =  -4.833517717482674
epoch =  4 batch =  3375 / 3597 loss =  -4.833172746744189
epoch =  4 batch =  3400 / 3597 loss =  -4.833063942164181
epoch =  4 batch =  3425 / 3597 loss =  -4.833154417378671
epoch =  4 batch =  3450 / 3597 loss =  -4.8328926675874

epoch =  5 batch =  2925 / 3597 loss =  -4.8371849088577745
epoch =  5 batch =  2950 / 3597 loss =  -4.836651137868109
epoch =  5 batch =  2975 / 3597 loss =  -4.836694717327128
epoch =  5 batch =  3000 / 3597 loss =  -4.83669092678223
epoch =  5 batch =  3025 / 3597 loss =  -4.836216459901475
epoch =  5 batch =  3050 / 3597 loss =  -4.836142062437583
epoch =  5 batch =  3075 / 3597 loss =  -4.836357838233223
epoch =  5 batch =  3100 / 3597 loss =  -4.836209389825914
epoch =  5 batch =  3125 / 3597 loss =  -4.836017913034323
epoch =  5 batch =  3150 / 3597 loss =  -4.836041795757238
epoch =  5 batch =  3175 / 3597 loss =  -4.8360907382118405
epoch =  5 batch =  3200 / 3597 loss =  -4.8365720764840745
epoch =  5 batch =  3225 / 3597 loss =  -4.8369329913536205
epoch =  5 batch =  3250 / 3597 loss =  -4.836746667649569
epoch =  5 batch =  3275 / 3597 loss =  -4.836403120495625
epoch =  5 batch =  3300 / 3597 loss =  -4.836330011735302
epoch =  5 batch =  3325 / 3597 loss =  -4.8364556947

epoch =  6 batch =  2800 / 3597 loss =  -4.842608869437523
epoch =  6 batch =  2825 / 3597 loss =  -4.842454629845419
epoch =  6 batch =  2850 / 3597 loss =  -4.8420914116762015
epoch =  6 batch =  2875 / 3597 loss =  -4.841724868775732
epoch =  6 batch =  2900 / 3597 loss =  -4.8411411065473375
epoch =  6 batch =  2925 / 3597 loss =  -4.840917018204615
epoch =  6 batch =  2950 / 3597 loss =  -4.840975930592999
epoch =  6 batch =  2975 / 3597 loss =  -4.8405948694675125
epoch =  6 batch =  3000 / 3597 loss =  -4.840729664659861
epoch =  6 batch =  3025 / 3597 loss =  -4.840953551454406
epoch =  6 batch =  3050 / 3597 loss =  -4.841310355437457
epoch =  6 batch =  3075 / 3597 loss =  -4.841906562118129
epoch =  6 batch =  3100 / 3597 loss =  -4.841788439241691
epoch =  6 batch =  3125 / 3597 loss =  -4.841348017322196
epoch =  6 batch =  3150 / 3597 loss =  -4.841657425024815
epoch =  6 batch =  3175 / 3597 loss =  -4.841802402017092
epoch =  6 batch =  3200 / 3597 loss =  -4.8410376885

epoch =  7 batch =  2675 / 3597 loss =  -4.847231661792302
epoch =  7 batch =  2700 / 3597 loss =  -4.847755976581264
epoch =  7 batch =  2725 / 3597 loss =  -4.847904214243292
epoch =  7 batch =  2750 / 3597 loss =  -4.848241543778513
epoch =  7 batch =  2775 / 3597 loss =  -4.848084980026111
epoch =  7 batch =  2800 / 3597 loss =  -4.848364917349291
epoch =  7 batch =  2825 / 3597 loss =  -4.848707526943328
epoch =  7 batch =  2850 / 3597 loss =  -4.849019128036106
epoch =  7 batch =  2875 / 3597 loss =  -4.848156184579808
epoch =  7 batch =  2900 / 3597 loss =  -4.848540042935066
epoch =  7 batch =  2925 / 3597 loss =  -4.84767943298826
epoch =  7 batch =  2950 / 3597 loss =  -4.847757920213824
epoch =  7 batch =  2975 / 3597 loss =  -4.847541425977993
epoch =  7 batch =  3000 / 3597 loss =  -4.847361762298828
epoch =  7 batch =  3025 / 3597 loss =  -4.847546773483175
epoch =  7 batch =  3050 / 3597 loss =  -4.846912374343308
epoch =  7 batch =  3075 / 3597 loss =  -4.84623381995411

epoch =  8 batch =  2550 / 3597 loss =  -4.847024901340357
epoch =  8 batch =  2575 / 3597 loss =  -4.846554554952602
epoch =  8 batch =  2600 / 3597 loss =  -4.845596458489327
epoch =  8 batch =  2625 / 3597 loss =  -4.845681618300451
epoch =  8 batch =  2650 / 3597 loss =  -4.845698153914168
epoch =  8 batch =  2675 / 3597 loss =  -4.845579896094377
epoch =  8 batch =  2700 / 3597 loss =  -4.8466159902824755
epoch =  8 batch =  2725 / 3597 loss =  -4.8459347148022065
epoch =  8 batch =  2750 / 3597 loss =  -4.846082814950503
epoch =  8 batch =  2775 / 3597 loss =  -4.845185370026132
epoch =  8 batch =  2800 / 3597 loss =  -4.845673069278077
epoch =  8 batch =  2825 / 3597 loss =  -4.845521931266924
epoch =  8 batch =  2850 / 3597 loss =  -4.8453858772607035
epoch =  8 batch =  2875 / 3597 loss =  -4.845980530488147
epoch =  8 batch =  2900 / 3597 loss =  -4.845584528317168
epoch =  8 batch =  2925 / 3597 loss =  -4.845472086478161
epoch =  8 batch =  2950 / 3597 loss =  -4.8456785686

epoch =  9 batch =  2500 / 3597 loss =  -4.843708905540715
epoch =  9 batch =  2525 / 3597 loss =  -4.844555525674007
epoch =  9 batch =  2550 / 3597 loss =  -4.845045616373456
epoch =  9 batch =  2575 / 3597 loss =  -4.844812617631428
epoch =  9 batch =  2600 / 3597 loss =  -4.844982653844813
epoch =  9 batch =  2625 / 3597 loss =  -4.84478263217623
epoch =  9 batch =  2650 / 3597 loss =  -4.844441077970901
epoch =  9 batch =  2675 / 3597 loss =  -4.844053831602952
epoch =  9 batch =  2700 / 3597 loss =  -4.843131091231897
epoch =  9 batch =  2725 / 3597 loss =  -4.84365208456818
epoch =  9 batch =  2750 / 3597 loss =  -4.843992221229766
epoch =  9 batch =  2775 / 3597 loss =  -4.844296398269336
epoch =  9 batch =  2800 / 3597 loss =  -4.844125003484431
epoch =  9 batch =  2825 / 3597 loss =  -4.844178738614299
epoch =  9 batch =  2850 / 3597 loss =  -4.84402017039777
epoch =  9 batch =  2875 / 3597 loss =  -4.843387859338507
epoch =  9 batch =  2900 / 3597 loss =  -4.8426460398759765

epoch =  10 batch =  2350 / 3597 loss =  -4.845570274030123
epoch =  10 batch =  2375 / 3597 loss =  -4.846129844887079
epoch =  10 batch =  2400 / 3597 loss =  -4.846690889697718
epoch =  10 batch =  2425 / 3597 loss =  -4.847331863538537
epoch =  10 batch =  2450 / 3597 loss =  -4.8465812893995395
epoch =  10 batch =  2475 / 3597 loss =  -4.846951358152693
epoch =  10 batch =  2500 / 3597 loss =  -4.84623147239211
epoch =  10 batch =  2525 / 3597 loss =  -4.846502026205285
epoch =  10 batch =  2550 / 3597 loss =  -4.846251873912075
epoch =  10 batch =  2575 / 3597 loss =  -4.84600193289495
epoch =  10 batch =  2600 / 3597 loss =  -4.846784807452316
epoch =  10 batch =  2625 / 3597 loss =  -4.846764385291652
epoch =  10 batch =  2650 / 3597 loss =  -4.846613864006511
epoch =  10 batch =  2675 / 3597 loss =  -4.846982123249478
epoch =  10 batch =  2700 / 3597 loss =  -4.847305277372796
epoch =  10 batch =  2725 / 3597 loss =  -4.8468622701271125
epoch =  10 batch =  2750 / 3597 loss = 