In [1]:
import math as m
import numpy as np
import random as r
import matplotlib.pyplot as plt
import sys

In [2]:
import torch
from torch import nn
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR

In [3]:
from nflows.flows.base import Flow
from nflows.distributions.uniform import BoxUniform
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedPiecewiseRationalQuadraticAutoregressiveTransform
from nflows.transforms.autoregressive import MaskedPiecewiseQuadraticAutoregressiveTransform
from nflows.transforms.permutations import ReversePermutation
from nflows.transforms.permutations import RandomPermutation
from nflows.transforms.splines.rational_quadratic import rational_quadratic_spline
from torch.utils.tensorboard import SummaryWriter

In [4]:
import subprocess
import time
import os
from copy import deepcopy
import math as m
import gc

## Tensorboard writer for loss logging

In [5]:
# TensorBoard summary writer, created with its default arguments. The training
# loop below logs the per-epoch train and test loss to it through add_scalar.
writer = SummaryWriter()

## GPU/CPU selection

In [6]:
# Use the first CUDA GPU when one is available and fall back to the CPU
# otherwise, so the notebook also runs on machines without a GPU.
# To force CPU execution, replace the expression with torch.device("cpu").
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Hyperparameters

In [7]:
# --- Flow architecture ---
n_RQS_knots = 16     # knots of each rational-quadratic spline (RQS) transform
n_made_layers = 3    # hidden layers inside every MADE network
n_made_units = 200   # units per hidden layer of the MADE network
n_flow_layers = 8    # number of layers stacked in the flow

# --- Optimisation ---
batch_size = 1024
n_epochs = 800
adam_lr = 1e-3       # ADAM learning rate (library default: 0.001)

# --- Data set sizes ---
# Nominal event counts: the statistical factor that accounts for negative
# weights is NOT included here, it is applied when the data is split.
n_train = 10**6
n_test = 10**5

## Load the training data

In [9]:
def _read_csv(path):
    """Parse the comma-separated text file at `path` into a NumPy array."""
    return np.genfromtxt(path, delimiter=',')

# Event features: only the first 8 columns of every row are kept.
samples = _read_csv("data/negative_weight_samples.csv")[:, :8]
# Event weights, stored in a separate file
# (presumably one weight per row of `samples` -- TODO confirm).
weights = _read_csv("data/negative_weight_weights.csv")

In [10]:
# Sanity check: (number of rows loaded, number of columns kept)
print(samples.shape)

(7995490, 8)


## Find the fraction of negative events and the required statistical factor

In [11]:
# f: share of events that carry a negative weight
n_negative = np.sum(weights < 0)
f = n_negative/len(weights)

# c = (1-2f)^-2: factor by which the nominal event counts are scaled up
# to account for the negative weights
c = (1-2*f)**-2

print("Fraction of negative events", f)
print("Statistical factor", c)

# Stop early if the file holds fewer events than the scaled
# train + test sets require
n_available = samples.shape[0]
n_required = (n_train + n_test)*c
if n_required > n_available:
    raise Exception("Not enough training data")

Fraction of negative events 0.2394467380986031
Statistical factor 3.6825358174692755


## Normalise weights to their sign

In [12]:
# Keep only the sign of every weight: np.sign maps negative values to -1,
# positive values to +1 and exact zeros to 0.
weights = np.sign(weights)

## Split to a train and test set

In [13]:
# Nominal split sizes scaled by the statistical factor c (truncated to int)
n_train_with_stats = int(n_train*c)
n_test_with_stats = int(n_test*c)


def _as_device_tensor(array):
    """Copy a NumPy array into a float32 tensor placed on `device`."""
    return torch.tensor(array, dtype=torch.float32, device=device)


# The first n_train_with_stats rows form the training set, the
# n_test_with_stats rows that follow form the test set.
train_rows = slice(None, n_train_with_stats)
test_rows = slice(n_train_with_stats, n_train_with_stats+n_test_with_stats)

train_samples = _as_device_tensor(samples[train_rows])
train_weights = _as_device_tensor(weights[train_rows])
test_samples = _as_device_tensor(samples[test_rows])
test_weights = _as_device_tensor(weights[test_rows])

# The NumPy arrays are no longer needed once the tensors exist;
# drop the names and trigger a garbage collection.
del samples, weights
gc.collect()

7

## Set up the flow

In [14]:
# Dimensionality of one event = number of feature columns
event_dim = train_samples.shape[1]

# Base distribution: BoxUniform built from a vector of zeros (lower edge)
# and a vector of ones (upper edge), one entry per feature.
lower_edge = torch.zeros(event_dim)
upper_edge = torch.ones(event_dim)
base_dist = BoxUniform(lower_edge, upper_edge)


def _make_flow_layer():
    """Build one flow layer: a random feature permutation followed by a
    masked autoregressive rational-quadratic spline transform."""
    shuffle = RandomPermutation(features=event_dim)
    spline = MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
        features=event_dim,
        hidden_features=n_made_units,
        num_bins=n_RQS_knots,
        num_blocks=n_made_layers-1,
        tails="constrained",
        use_residual_blocks=False,
    )
    return [shuffle, spline]


# Stack n_flow_layers (permutation, spline) pairs into a single transform
transforms = []
for _ in range(n_flow_layers):
    transforms.extend(_make_flow_layer())
transform = CompositeTransform(transforms)

flow = Flow(transform, base_dist).to(device)
optimizer = optim.Adam(flow.parameters(), lr=adam_lr)

# Learning-rate schedule with gamma=0.5 and milestones every 75 epochs
# from 350 up to and including 800
lr_milestones = list(range(350, 801, 75))
scheduler = MultiStepLR(optimizer, milestones=lr_milestones, gamma=0.5)

## Training

In [None]:
# Train the flow by minimising the weighted negative log-likelihood,
#     loss = -mean(weight * log_prob(sample)),
# with mini-batch ADAM steps. After every epoch the same quantity is evaluated
# on the test set, both losses are logged to TensorBoard, and the model with
# the lowest validation loss so far is written to "best_flow_model.pt".
# The model after the last epoch is written to "final_flow_model.pt".

# Number of events per split and number of mini-batches needed to cover them
# (the last batch of a split may hold fewer than batch_size events).
data_size = train_samples.shape[0]
n_batches = m.ceil(data_size/batch_size)

data_size_validation = test_samples.shape[0]
n_batches_validate = m.ceil(data_size_validation/batch_size)

best_loss = np.inf
for epoch in range(n_epochs):

    # Visit the training events in a new random order every epoch
    permutation = torch.randperm(data_size, device=device)

    # Loop over batches
    cum_loss = 0
    for batch in range(n_batches):
        # Set up the batch. Slice ends are exclusive, so the upper bound is
        # data_size itself: clamping to data_size-1 would drop one event per
        # epoch and produce an empty batch (NaN loss, NaN gradients) whenever
        # data_size % batch_size == 1.
        batch_begin = batch*batch_size
        batch_end   = min( (batch+1)*batch_size, data_size )
        indices = permutation[batch_begin:batch_end]
        samples_batch = train_samples[indices]
        weights_batch = train_weights[indices]

        # Take a step
        optimizer.zero_grad()
        loss = -(flow.log_prob(inputs=samples_batch)*weights_batch).mean()
        loss.backward()
        optimizer.step()

        # Running mean of the per-batch losses seen so far in this epoch
        cum_loss = (cum_loss*batch + loss.item())/(batch+1)

        if batch%25 == 0:
            print("epoch = ", epoch, "batch = ", batch, "/", n_batches, "loss = ", cum_loss)

    scheduler.step()

    # Compute validation loss as the mean over ALL test events. The weighted
    # log-probabilities are summed and divided by the event count once, so a
    # shorter final batch is not over-weighted; the sum is accumulated as a
    # Python float so the loss can be logged and compared without holding on
    # to device tensors.
    validation_sum = 0.0
    with torch.no_grad():
        for batch in range(n_batches_validate):
            batch_begin = batch*batch_size
            batch_end = min( (batch+1)*batch_size, data_size_validation )
            samples_batch = test_samples[batch_begin:batch_end]
            weights_batch = test_weights[batch_begin:batch_end]

            validation_sum -= (flow.log_prob(samples_batch)*weights_batch).sum().item()
    validation_loss = validation_sum/data_size_validation

    print("Validation loss = ", validation_loss)

    writer.add_scalar("Loss_train", cum_loss, epoch)
    writer.add_scalar("Loss_test", validation_loss, epoch)

    # Checkpoint on improvement.
    # NOTE(review): torch.save(flow, ...) pickles the whole module, so the
    # file must be loaded with the same class definitions available and only
    # from a trusted source. Kept as is because readers of these files are
    # not visible from here.
    if validation_loss < best_loss:
        torch.save(flow, "best_flow_model.pt")
        best_loss = validation_loss

torch.save(flow, "final_flow_model.pt")

# Make sure every logged scalar has reached the event file on disk
writer.flush()

epoch =  0 batch =  0 / 3597 loss =  3.236827850341797
epoch =  0 batch =  25 / 3597 loss =  -0.5628551657383258
epoch =  0 batch =  50 / 3597 loss =  -1.5663445486741905
epoch =  0 batch =  75 / 3597 loss =  -2.1702778511925747
epoch =  0 batch =  100 / 3597 loss =  -2.530372174659578
epoch =  0 batch =  125 / 3597 loss =  -2.7436772916052075
epoch =  0 batch =  150 / 3597 loss =  -2.9068196917211764
epoch =  0 batch =  175 / 3597 loss =  -3.043420422483574
epoch =  0 batch =  200 / 3597 loss =  -3.151381560819066
epoch =  0 batch =  225 / 3597 loss =  -3.252307253073802
epoch =  0 batch =  250 / 3597 loss =  -3.3355220983702822
epoch =  0 batch =  275 / 3597 loss =  -3.404759450667146
epoch =  0 batch =  300 / 3597 loss =  -3.463435308877812
epoch =  0 batch =  325 / 3597 loss =  -3.5197199784173554
epoch =  0 batch =  350 / 3597 loss =  -3.5701684051769074
epoch =  0 batch =  375 / 3597 loss =  -3.6196508899014046
epoch =  0 batch =  400 / 3597 loss =  -3.6584784904323016
epoch =  0

epoch =  0 batch =  3500 / 3597 loss =  -4.565562218912855
epoch =  0 batch =  3525 / 3597 loss =  -4.566952724552789
epoch =  0 batch =  3550 / 3597 loss =  -4.568034326315117
epoch =  0 batch =  3575 / 3597 loss =  -4.570064053272763
Validation loss =  -4.777291297912598
epoch =  1 batch =  0 / 3597 loss =  -4.485310077667236
epoch =  1 batch =  25 / 3597 loss =  -4.765618782777053
epoch =  1 batch =  50 / 3597 loss =  -4.768677178551169
epoch =  1 batch =  75 / 3597 loss =  -4.768467627073589
epoch =  1 batch =  100 / 3597 loss =  -4.798830084281391
epoch =  1 batch =  125 / 3597 loss =  -4.797404497388808
epoch =  1 batch =  150 / 3597 loss =  -4.791207376694835
epoch =  1 batch =  175 / 3597 loss =  -4.79660339518027
epoch =  1 batch =  200 / 3597 loss =  -4.788188614062408
epoch =  1 batch =  225 / 3597 loss =  -4.7744540978322005
epoch =  1 batch =  250 / 3597 loss =  -4.78298174337562
epoch =  1 batch =  275 / 3597 loss =  -4.776787820069687
epoch =  1 batch =  300 / 3597 loss 

epoch =  1 batch =  3375 / 3597 loss =  -4.781784652180591
epoch =  1 batch =  3400 / 3597 loss =  -4.78300388719503
epoch =  1 batch =  3425 / 3597 loss =  -4.783600207385593
epoch =  1 batch =  3450 / 3597 loss =  -4.784383395277093
epoch =  1 batch =  3475 / 3597 loss =  -4.784761549482696
epoch =  1 batch =  3500 / 3597 loss =  -4.784735959994738
epoch =  1 batch =  3525 / 3597 loss =  -4.785197779122101
epoch =  1 batch =  3550 / 3597 loss =  -4.785001984182404
epoch =  1 batch =  3575 / 3597 loss =  -4.785553956258494
Validation loss =  -4.815684795379639
epoch =  2 batch =  0 / 3597 loss =  -5.163106441497803
epoch =  2 batch =  25 / 3597 loss =  -4.769309777479906
epoch =  2 batch =  50 / 3597 loss =  -4.753867990830365
epoch =  2 batch =  75 / 3597 loss =  -4.773739588888066
epoch =  2 batch =  100 / 3597 loss =  -4.754236466813794
epoch =  2 batch =  125 / 3597 loss =  -4.773835553063285
epoch =  2 batch =  150 / 3597 loss =  -4.7789648368658595
epoch =  2 batch =  175 / 3597