In [1]:
import math as m
import numpy as np
import random as r
import matplotlib.pyplot as plt
import sys

In [2]:
import torch
from torch import nn
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR

In [3]:
from nflows.flows.base import Flow
from nflows.distributions.uniform import BoxUniform
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedPiecewiseRationalQuadraticAutoregressiveTransform
from nflows.transforms.autoregressive import MaskedPiecewiseQuadraticAutoregressiveTransform
from nflows.transforms.permutations import ReversePermutation
from nflows.transforms.permutations import RandomPermutation
from nflows.transforms.splines.rational_quadratic import rational_quadratic_spline
from torch.utils.tensorboard import SummaryWriter

In [4]:
import subprocess
import time
import os
from copy import deepcopy
import math as m
import gc

## Tensorboard writer for loss logging

In [5]:
# TensorBoard writer used by the training loop to log one scalar per epoch for
# the training loss, the validation loss and the effective sample size.
# It is constructed without arguments, so SummaryWriter chooses its default
# log location.
writer = SummaryWriter()

## GPU/CPU selection

In [6]:
# Select the compute device for the data tensors and the flow.
# The second GPU ("cuda:1") is preferred, exactly as before. When it does not
# exist, fall back to the first GPU, and finally to the CPU, so the notebook
# still runs on machines with fewer devices instead of failing at the first
# tensor allocation.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

## Hyperparameters

In [7]:
# --- Flow architecture -------------------------------------------------------
n_flow_layers = 6    # permutation + autoregressive spline layers stacked in the flow
n_RQS_knots = 5      # passed as `num_bins` to each rational-quadratic spline transform
n_made_units = 10    # passed as `hidden_features` to each MADE network
n_made_layers = 0    # hidden layers per MADE network; `n_made_layers - 1` is passed as `num_blocks`

# --- Optimisation ------------------------------------------------------------
n_epochs = 800
batch_size = 1024
adam_lr = 1e-3       # learning rate handed to the Adam optimizer

# --- Dataset sizes -----------------------------------------------------------
n_train = 10**6      # events used for training
n_test = 10**5       # events held out for the validation loss
n_sample = 10**6     # events drawn from the flow for each effective-sample-size evaluation

## Load the training data

In [8]:
# Read the comma-separated sample file into a NumPy array; its rows are later
# sliced into the training and test sets.
samples = np.genfromtxt("data/unweighted_samples.csv", delimiter=",")

# The two sets are taken from consecutive, non-overlapping row ranges, so the
# file must provide at least n_train + n_test rows.
if samples.shape[0] < n_train + n_test:
    raise Exception("Not enough training data")

## Split into training and test sets

In [None]:
# The first n_train rows become the training set and the next n_test rows the
# test set. Both are copied into float32 tensors on the selected device.
train_samples = torch.tensor(
    samples[:n_train],
    device=device,
    dtype=torch.float32,
)
test_samples = torch.tensor(
    samples[n_train:][:n_test],
    device=device,
    dtype=torch.float32,
)

# Both tensors now exist, so drop the NumPy array and run the garbage collector.
del samples
gc.collect()

59

## Set up the flow

In [None]:
# Number of features per event = number of columns of the training tensor.
event_dim = train_samples.shape[1]

# Base distribution: uniform between 0 and 1 in every dimension.
# NOTE(review): these bound tensors are created on the CPU; confirm that the
# installed nflows version moves them along with `Flow.to(device)`.
base_dist = BoxUniform(torch.zeros(event_dim), torch.ones(event_dim))

# Each flow layer is a random feature permutation followed by a masked
# autoregressive rational-quadratic spline transform.
# NOTE(review): with n_made_layers = 0 the `num_blocks` argument evaluates to
# -1; confirm nflows treats a negative block count as "no blocks".
transforms = []
for _ in range(n_flow_layers):
    transforms.extend([
        RandomPermutation(features=event_dim),
        MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
            features=event_dim,
            hidden_features=n_made_units,
            num_bins=n_RQS_knots,
            num_blocks=n_made_layers-1,
            tails="constrained",
            use_residual_blocks=False,
        ),
    ])
transform = CompositeTransform(transforms)

# Assemble the flow on the selected device and attach the Adam optimizer.
flow = Flow(transform, base_dist).to(device)
optimizer = optim.Adam(flow.parameters(), lr=adam_lr)

# Halve the learning rate at epoch 350 and then every 75 epochs up to 800.
scheduler = MultiStepLR(optimizer, milestones=list(range(350, 801, 75)), gamma=0.5)

## Training

In [None]:
# Maximum-likelihood training of the flow.
#
# Every epoch:
#   1. runs one pass over the shuffled training set, minimising the mean
#      negative log-likelihood with the optimizer, then steps the LR scheduler;
#   2. computes the mean negative log-likelihood on the test set;
#   3. draws n_sample events from the flow, writes them and their likelihoods
#      to CSV files and lets the external evaluator compute the effective
#      sample size (ESS).
# The model is checkpointed whenever the validation loss or the ESS improves,
# and once more after the last epoch.
#
# NOTE(review): torch.save(flow, ...) pickles the whole module, so loading
# requires the same class definitions; consider saving flow.state_dict().
data_size = train_samples.shape[0]
n_batches = m.ceil(data_size/batch_size)

data_size_validation = test_samples.shape[0]
n_batches_validate = m.ceil(data_size_validation/batch_size)

best_validation_loss = np.inf
best_ess = 0

# External evaluator binary and the files used to hand data over to it.
evaluator = os.path.join(os.path.abspath(os.getcwd()), "ME_VEGAS", "compute_metrics_from_likelihoods")
samples_file = "/tmp/samples_file.csv"
llhs_file = "/tmp/llhs_file.csv"

for epoch in range(n_epochs):

    # Fresh random ordering of the training events for this epoch
    permutation = torch.randperm(data_size, device=device)

    # Loop over batches
    cum_loss = 0
    for batch in range(n_batches):
        # Set up the batch. The slice end is exclusive, so it is clamped to
        # data_size (not data_size-1); otherwise the last index is never used.
        batch_begin = batch*batch_size
        batch_end = min(batch_begin + batch_size, data_size)
        indices = permutation[batch_begin:batch_end]
        samples_batch = train_samples[indices]

        # Take a step
        optimizer.zero_grad()
        loss = -(flow.log_prob(inputs=samples_batch)).mean()
        loss.backward()
        optimizer.step()

        # Running mean of the batch losses seen so far in this epoch
        cum_loss = (cum_loss*batch + loss.item())/(batch+1)

        if batch%25 == 0:
            print("epoch = ", epoch, "batch = ", batch, "/", n_batches, "loss = ", cum_loss)

    writer.add_scalar("Loss_train", cum_loss, epoch)
    scheduler.step()


    # ---------- Compute validation loss -----------
    # Sum the log-likelihoods over all test events and divide by the number of
    # events. This gives the exact per-event mean, independent of the size of
    # the (shorter) final batch. Accumulation is done in float64.
    log_prob_sum = torch.zeros((), dtype=torch.float64, device=test_samples.device)
    with torch.no_grad():
        for batch in range(n_batches_validate):
            batch_begin = batch*batch_size
            batch_end = min(batch_begin + batch_size, data_size_validation)
            samples_batch = test_samples[batch_begin:batch_end]
            log_prob_sum += flow.log_prob(samples_batch).sum(dtype=torch.float64)
    validation_loss = -log_prob_sum / data_size_validation

    print("Validation loss = ", validation_loss.item())
    writer.add_scalar("Loss_test", validation_loss.item(), epoch)

    if validation_loss < best_validation_loss:
        torch.save(flow, "flow_model_unweighted_best_validation.pt")
        best_validation_loss = validation_loss


    # ---------- Compute effective sample size ----------
    # Generate samples and evaluate their log-likelihoods
    with torch.no_grad():
        samples = flow.sample(n_sample)
        llhs = flow.log_prob(samples)

    # Store files: the evaluator receives likelihoods, not log-likelihoods
    np.savetxt(samples_file, samples.cpu().numpy(), delimiter=',')
    np.savetxt(llhs_file, np.exp(llhs.cpu().numpy()), delimiter=',')

    # Run the evaluator. An argument list (no shell) keeps paths containing
    # spaces or shell metacharacters intact.
    result = subprocess.run(
        [evaluator, samples_file, llhs_file],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    lines = result.stdout.decode('ascii').split("\n")

    # The ESS is the last space-separated token on the third output line.
    try:
        ess = float(lines[2].split(' ')[-1])
    except (IndexError, ValueError) as err:
        raise RuntimeError(
            "Could not parse the effective sample size from the evaluator output "
            "(exit code {}).\nstdout:\n{}\nstderr:\n{}".format(
                result.returncode,
                result.stdout.decode('ascii', errors='replace'),
                result.stderr.decode('ascii', errors='replace'),
            )
        ) from err

    print("Effective sample size = ", ess)
    writer.add_scalar("Effective_sample_size", ess, epoch)

    if ess > best_ess:
        torch.save(flow, "flow_model_unweighted_best_ess.pt")
        best_ess = ess

torch.save(flow, "flow_model_unweighted_final.pt")

# Make sure all logged scalars reach the event file on disk
writer.flush()

epoch =  0 batch =  0 / 977 loss =  0.36687445640563965
epoch =  0 batch =  25 / 977 loss =  -2.0629490261467605
epoch =  0 batch =  50 / 977 loss =  -3.618684484239887
epoch =  0 batch =  75 / 977 loss =  -5.101591120130921
epoch =  0 batch =  100 / 977 loss =  -6.667597164390701
epoch =  0 batch =  125 / 977 loss =  -7.950408336011663
epoch =  0 batch =  150 / 977 loss =  -9.043609794648676
epoch =  0 batch =  175 / 977 loss =  -10.07322311828929
epoch =  0 batch =  200 / 977 loss =  -11.037217718044031
epoch =  0 batch =  225 / 977 loss =  -11.898104333409433
epoch =  0 batch =  250 / 977 loss =  -12.648688003984821
epoch =  0 batch =  275 / 977 loss =  -13.305771851156289
epoch =  0 batch =  300 / 977 loss =  -13.882847493792891
epoch =  0 batch =  325 / 977 loss =  -14.40396438431612
epoch =  0 batch =  350 / 977 loss =  -14.872637476443058
epoch =  0 batch =  375 / 977 loss =  -15.303768905652491
epoch =  0 batch =  400 / 977 loss =  -15.702826970328864
epoch =  0 batch =  425 / 

epoch =  3 batch =  475 / 977 loss =  -23.69041650034801
epoch =  3 batch =  500 / 977 loss =  -23.690577920087566
epoch =  3 batch =  525 / 977 loss =  -23.692342384686494
epoch =  3 batch =  550 / 977 loss =  -23.69088866965958
epoch =  3 batch =  575 / 977 loss =  -23.690294189585586
epoch =  3 batch =  600 / 977 loss =  -23.69194695318797
epoch =  3 batch =  625 / 977 loss =  -23.690567613790584
epoch =  3 batch =  650 / 977 loss =  -23.689800069079418
epoch =  3 batch =  675 / 977 loss =  -23.68904068625184
epoch =  3 batch =  700 / 977 loss =  -23.68867304389725
epoch =  3 batch =  725 / 977 loss =  -23.68964217845402
epoch =  3 batch =  750 / 977 loss =  -23.690529809334624
epoch =  3 batch =  775 / 977 loss =  -23.69125203496402
epoch =  3 batch =  800 / 977 loss =  -23.6910340336527
epoch =  3 batch =  825 / 977 loss =  -23.69153913400941
epoch =  3 batch =  850 / 977 loss =  -23.69215629747976
epoch =  3 batch =  875 / 977 loss =  -23.69311002182634
epoch =  3 batch =  900 / 

epoch =  6 batch =  950 / 977 loss =  -23.759752676690017
epoch =  6 batch =  975 / 977 loss =  -23.759892115827444
Validation loss =  -23.76146697998047
Effective sample size =  0.588077
epoch =  7 batch =  0 / 977 loss =  -23.683998107910156
epoch =  7 batch =  25 / 977 loss =  -23.7843258197491
epoch =  7 batch =  50 / 977 loss =  -23.795803331861308
epoch =  7 batch =  75 / 977 loss =  -23.801229803185716
epoch =  7 batch =  100 / 977 loss =  -23.797364659828716
epoch =  7 batch =  125 / 977 loss =  -23.790553214058043
epoch =  7 batch =  150 / 977 loss =  -23.78041722127144
epoch =  7 batch =  175 / 977 loss =  -23.782222997058522
epoch =  7 batch =  200 / 977 loss =  -23.775672694343832
epoch =  7 batch =  225 / 977 loss =  -23.774108355024218
epoch =  7 batch =  250 / 977 loss =  -23.771767863239425
epoch =  7 batch =  275 / 977 loss =  -23.770421283832494
epoch =  7 batch =  300 / 977 loss =  -23.77311010772604
epoch =  7 batch =  325 / 977 loss =  -23.773515718846237
epoch =  

epoch =  10 batch =  400 / 977 loss =  -23.795546771879504
epoch =  10 batch =  425 / 977 loss =  -23.794366518656414
epoch =  10 batch =  450 / 977 loss =  -23.79383095616512
epoch =  10 batch =  475 / 977 loss =  -23.794841209379566
epoch =  10 batch =  500 / 977 loss =  -23.795906413339093
epoch =  10 batch =  525 / 977 loss =  -23.79535579681397
epoch =  10 batch =  550 / 977 loss =  -23.79637775663456
epoch =  10 batch =  575 / 977 loss =  -23.79573469029533
epoch =  10 batch =  600 / 977 loss =  -23.79641585580125
epoch =  10 batch =  625 / 977 loss =  -23.79745897165122
epoch =  10 batch =  650 / 977 loss =  -23.79727296492289
epoch =  10 batch =  675 / 977 loss =  -23.798857054061454
epoch =  10 batch =  700 / 977 loss =  -23.79937880457555
epoch =  10 batch =  725 / 977 loss =  -23.80007571753722
epoch =  10 batch =  750 / 977 loss =  -23.80014764961011
epoch =  10 batch =  775 / 977 loss =  -23.79868963084272
epoch =  10 batch =  800 / 977 loss =  -23.79919515983593
epoch =  

epoch =  13 batch =  825 / 977 loss =  -23.813917402493733
epoch =  13 batch =  850 / 977 loss =  -23.815979178727847
epoch =  13 batch =  875 / 977 loss =  -23.814537727669478
epoch =  13 batch =  900 / 977 loss =  -23.814293593598244
epoch =  13 batch =  925 / 977 loss =  -23.813726647626233
epoch =  13 batch =  950 / 977 loss =  -23.81303695574417
epoch =  13 batch =  975 / 977 loss =  -23.81292539541837
Validation loss =  -23.799545288085938
Effective sample size =  0.609028
epoch =  14 batch =  0 / 977 loss =  -23.822444915771484
epoch =  14 batch =  25 / 977 loss =  -23.824701969440166
epoch =  14 batch =  50 / 977 loss =  -23.843042897243127
epoch =  14 batch =  75 / 977 loss =  -23.8301887512207
epoch =  14 batch =  100 / 977 loss =  -23.8123352692859
epoch =  14 batch =  125 / 977 loss =  -23.813177063351585
epoch =  14 batch =  150 / 977 loss =  -23.80861771185667
epoch =  14 batch =  175 / 977 loss =  -23.810181303457792
epoch =  14 batch =  200 / 977 loss =  -23.81432901448