In [1]:
import math as m
import numpy as np
import random as r
import matplotlib.pyplot as plt
import sys

In [2]:
import torch
from torch import nn
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR

In [3]:
from nflows.flows.base import Flow
from nflows.distributions.uniform import BoxUniform
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedPiecewiseRationalQuadraticAutoregressiveTransform
from nflows.transforms.autoregressive import MaskedPiecewiseQuadraticAutoregressiveTransform
from nflows.transforms.permutations import ReversePermutation
from nflows.transforms.permutations import RandomPermutation
from nflows.transforms.splines.rational_quadratic import rational_quadratic_spline
from torch.utils.tensorboard import SummaryWriter

In [4]:
import subprocess
import time
import os
from copy import deepcopy
import math as m
import gc

## TensorBoard writer for logging the losses and the effective sample size

In [5]:
# Shared TensorBoard writer; the training loop logs "Loss_train", "Loss_test"
# and "Effective_sample_size" to it once per epoch.
writer = SummaryWriter()

## GPU/CPU selection

In [6]:
# Prefer the second GPU ("cuda:1", the original choice). If fewer than two GPUs
# are visible fall back to "cuda:0", and to the CPU when CUDA is unavailable,
# so the notebook still runs on smaller machines instead of failing later.
# To force the CPU, replace this block with: device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")
# Make a silent fallback visible in the cell output.
print("Using device:", device)

## Reweighting strategy - must be 'min', 'mean' or 'max'

In [7]:
# Which statistic of the training weights serves as the reference weight
# during reweighting: one of "min", "mean" or "max".
reference_method = "max"

## Hyperparameters

In [8]:
# --- Flow architecture ---
# Number of knots in RQS transform
n_RQS_knots = 10
# Number of hidden layers in every made network
n_made_layers = 1
# Number of units in every layer of the made network
n_made_units = 100
# Number of layers in the flow
n_flow_layers = 6

# --- Optimisation ---
batch_size = 1024
n_epochs = 800
# Learning rate for the ADAM optimizer (default: 0.001)
adam_lr = 1e-3

# --- Dataset sizes ---
# Number of training events
n_train = 1_000_000
# Number of testing events
n_test = 100_000
# Number of samples for ess evaluation
n_sample = 1_000_000

## Load the train data and reweight

In [None]:
# Read the weighted events and their weights, keeping at most n_train rows
train_samples = np.genfromtxt("data/weighted_samples.csv", delimiter=',')[:n_train]
train_weights = np.genfromtxt("data/weighted_weights.csv", delimiter=',')[:n_train]

# Look up the reduction selected by reference_method and apply it to the weights
reference_methods = {'min': np.amin, 'mean': np.mean, 'max': np.amax}
pick_reference = reference_methods[reference_method]
ref_weight = pick_reference(train_weights)

# Express the weights in units of the reference weight, then keep every event
# whose rescaled weight exceeds a uniform random number drawn from [0, 1)
train_weights = train_weights / ref_weight
p_rejection_sampling = np.random.rand(len(train_weights))
select = train_weights > p_rejection_sampling

# Surviving events: rescaled weights below one are raised to exactly one
train_samples = train_samples[select]
train_weights = np.maximum(train_weights[select], 1.0)

# Rescale so that the surviving weights average to one
train_weights = train_weights / train_weights.mean()

# Move both arrays onto the chosen device as float32 tensors
train_samples, train_weights = (
    torch.tensor(array, dtype=torch.float32, device=device)
    for array in (train_samples, train_weights)
)

## Load the test data

In [None]:
# Read the unweighted events and keep the first n_test rows as a float32
# tensor on the chosen device; they serve as the validation set during training.
test_array = np.genfromtxt("data/unweighted_samples.csv", delimiter=',')
test_samples = torch.tensor(test_array[:n_test], dtype=torch.float32, device=device)
# Drop the temporary numpy array so only the tensor copy stays alive
del test_array

## Set up the flow

In [None]:
# Number of features per event, read off the training tensor
event_dim = train_samples.shape[1]

# Base density: uniform between 0 and 1 in every feature.
# NOTE(review): the bounds are created on the CPU; confirm that the nflows build
# in use moves them to `device` when `Flow.to(device)` is called below.
unit_low, unit_high = torch.zeros(event_dim), torch.ones(event_dim)
base_dist = BoxUniform(unit_low, unit_high)

# Stack n_flow_layers pairs of (random feature permutation, masked
# autoregressive rational-quadratic spline). The permutation is always built
# before its spline, layer by layer.
transforms = []
for _layer in range(n_flow_layers):
    shuffle = RandomPermutation(features=event_dim)
    spline = MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
        features=event_dim,
        hidden_features=n_made_units,
        num_bins=n_RQS_knots,
        num_blocks=n_made_layers-1,
        tails="constrained",
        use_residual_blocks=False
    )
    transforms += [shuffle, spline]
transform = CompositeTransform(transforms)

# Assemble the flow on the target device and attach the Adam optimizer
flow = Flow(transform, base_dist).to(device)
optimizer = optim.Adam(flow.parameters(), lr=adam_lr)

# Halve the learning rate at epochs 350, 425, ..., 800 (one milestone every 75 epochs)
scheduler = MultiStepLR(optimizer, milestones=list(range(350, 801, 75)), gamma=0.5)

## Training

In [None]:
# Train the flow for n_epochs epochs. After every epoch:
#   1. log the running mean of the weighted training loss,
#   2. compute the mean negative log-likelihood on the test samples,
#   3. draw n_sample events from the flow and let the external evaluator
#      compute the effective sample size (ESS),
# and checkpoint the model whenever the validation loss or the ESS improves.

# Mini-batch bookkeeping for the training and validation sets
data_size = train_samples.shape[0]
n_batches = m.ceil(data_size/batch_size)

data_size_validation = test_samples.shape[0]
n_batches_validate = m.ceil(data_size_validation/batch_size)

# Best metrics seen so far (plain Python numbers, no device tensors)
best_validation_loss = np.inf
best_ess = 0

# External evaluator and the files used to hand the samples over to it
evaluator_path = os.path.join(os.path.abspath(os.getcwd()), "ME_VEGAS", "compute_metrics_from_likelihoods")
samples_file = "/tmp/samples_weighted_file.csv"
llhs_file = "/tmp/llhs_weighted_file.csv"

for epoch in range(n_epochs):

    # Fresh random ordering of the training events for this epoch
    permutation = torch.randperm(data_size, device=device)

    # Loop over batches
    cum_loss = 0
    for batch in range(n_batches):
        # Set up the batch. Slice ends are exclusive, so the upper bound is
        # clamped to data_size (not data_size-1): otherwise the last event of
        # the permutation is never used and the final batch can be empty.
        batch_begin = batch*batch_size
        batch_end   = min( (batch+1)*batch_size, data_size )
        indices = permutation[batch_begin:batch_end]
        samples_batch = train_samples[indices]
        weights_batch = train_weights[indices]

        # Take a step on the weighted negative log-likelihood
        optimizer.zero_grad()
        loss = -(flow.log_prob(inputs=samples_batch)*weights_batch).mean()
        loss.backward()
        optimizer.step()

        # Running mean of the per-batch losses seen so far in this epoch
        cum_loss = (cum_loss*batch + loss.item())/(batch+1)

        if batch%25 == 0:
            print("epoch = ", epoch, "batch = ", batch, "/", n_batches, "loss = ", cum_loss)

    writer.add_scalar("Loss_train", cum_loss, epoch)
    scheduler.step()


    # ---------- Compute validation loss -----------
    # Sum the negative log-likelihood over all test events and divide by their
    # number, so a shorter final batch is weighted by its actual size.
    validation_nll_sum = 0.0
    with torch.no_grad():
        for batch in range(n_batches_validate):
            batch_begin = batch*batch_size
            batch_end = min( (batch+1)*batch_size, data_size_validation )
            samples_batch = test_samples[batch_begin:batch_end]
            validation_nll_sum -= flow.log_prob(samples_batch).sum().item()
    validation_loss = validation_nll_sum / max(data_size_validation, 1)

    print("Validation loss = ", validation_loss)
    writer.add_scalar("Loss_test", validation_loss, epoch)

    if validation_loss < best_validation_loss:
        torch.save(flow, "flow_model_weighted_{}_best_validation.pt".format(reference_method))
        best_validation_loss = validation_loss


    # ---------- Compute effective sample size ----------
    # generate samples and evaluate llhs
    with torch.no_grad():
        samples = flow.sample(n_sample)
        llhs = flow.log_prob(samples)

    # Store files (the second file holds exp(log_prob), i.e. likelihoods)
    np.savetxt(samples_file, samples.cpu().numpy(), delimiter=',')
    np.savetxt(llhs_file, np.exp(llhs.cpu().numpy()), delimiter=',')

    # Run the evaluator. Arguments are passed as a list (no shell), so a working
    # directory containing spaces or shell metacharacters cannot break the call.
    result = subprocess.run(
        [evaluator_path, samples_file, llhs_file],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    lines = result.stdout.decode('ascii').split("\n")

    # The ESS is read from the last space-separated token on the third output line
    try:
        ess = float(lines[2].split(' ')[-1])
    except (IndexError, ValueError) as err:
        raise RuntimeError(
            "Could not parse the effective sample size from the evaluator output "
            "(exit code {}). stderr: {}".format(
                result.returncode, result.stderr.decode('ascii', errors='replace')
            )
        ) from err

    print("Effective sample size = ", ess)
    writer.add_scalar("Effective_sample_size", ess, epoch)

    if ess > best_ess:
        torch.save(flow, "flow_model_weighted_{}_best_ess.pt".format(reference_method))
        best_ess = ess

torch.save(flow, "flow_model_weighted_{}_final.pt".format(reference_method))

# Make sure every logged scalar has been written to disk
writer.flush()

epoch =  0 batch =  0 / 63 loss =  -0.4645250141620636
epoch =  0 batch =  25 / 63 loss =  -6.297495146210376
epoch =  0 batch =  50 / 63 loss =  -10.057873146791087
Validation loss =  -18.169567108154297
Effective sample size =  0.000225586
epoch =  1 batch =  0 / 63 loss =  -18.132858276367188
epoch =  1 batch =  25 / 63 loss =  -19.33813711313101
epoch =  1 batch =  50 / 63 loss =  -20.12898908876905
Validation loss =  -21.515731811523438
Effective sample size =  0.0182868
epoch =  2 batch =  0 / 63 loss =  -21.639537811279297
epoch =  2 batch =  25 / 63 loss =  -21.644849116985615
epoch =  2 batch =  50 / 63 loss =  -21.7620132296693
Validation loss =  -22.00618553161621
Effective sample size =  0.0370023
epoch =  3 batch =  0 / 63 loss =  -22.028186798095703
epoch =  3 batch =  25 / 63 loss =  -22.100346638606144
epoch =  3 batch =  50 / 63 loss =  -22.156181036257276
Validation loss =  -22.239023208618164
Effective sample size =  0.0489558
epoch =  4 batch =  0 / 63 loss =  -22.3

epoch =  34 batch =  25 / 63 loss =  -23.03918280968299
epoch =  34 batch =  50 / 63 loss =  -23.04936674529431
Validation loss =  -22.898420333862305
Effective sample size =  0.137145
epoch =  35 batch =  0 / 63 loss =  -22.949687957763672
epoch =  35 batch =  25 / 63 loss =  -23.03973212608924
epoch =  35 batch =  50 / 63 loss =  -23.011451309802485
Validation loss =  -22.915550231933594
Effective sample size =  0.139481
epoch =  36 batch =  0 / 63 loss =  -23.028106689453125
epoch =  36 batch =  25 / 63 loss =  -22.96799711080698
epoch =  36 batch =  50 / 63 loss =  -22.97469845940085
Validation loss =  -22.982295989990234
Effective sample size =  0.15824
epoch =  37 batch =  0 / 63 loss =  -23.19687843322754
epoch =  37 batch =  25 / 63 loss =  -23.013101431039665
epoch =  37 batch =  50 / 63 loss =  -23.033437093098964
Validation loss =  -22.97371482849121
Effective sample size =  0.159378
epoch =  38 batch =  0 / 63 loss =  -22.92870330810547
epoch =  38 batch =  25 / 63 loss =  

epoch =  68 batch =  25 / 63 loss =  -23.130938383249138
epoch =  68 batch =  50 / 63 loss =  -23.12338241876341
Validation loss =  -23.010623931884766
Effective sample size =  0.164685
epoch =  69 batch =  0 / 63 loss =  -23.307445526123047
epoch =  69 batch =  25 / 63 loss =  -23.13263232891376
epoch =  69 batch =  50 / 63 loss =  -23.121554468192308
Validation loss =  -23.02635955810547
Effective sample size =  0.16904
epoch =  70 batch =  0 / 63 loss =  -23.241802215576172
epoch =  70 batch =  25 / 63 loss =  -23.10629353156457
epoch =  70 batch =  50 / 63 loss =  -23.128745770921896
Validation loss =  -23.023773193359375
Effective sample size =  0.156361
epoch =  71 batch =  0 / 63 loss =  -23.018375396728516
epoch =  71 batch =  25 / 63 loss =  -23.107207518357498
epoch =  71 batch =  50 / 63 loss =  -23.12294010087555
Validation loss =  -23.05532455444336
Effective sample size =  0.181021
epoch =  72 batch =  0 / 63 loss =  -23.2298583984375
epoch =  72 batch =  25 / 63 loss =  

epoch =  102 batch =  25 / 63 loss =  -23.167951803940994
epoch =  102 batch =  50 / 63 loss =  -23.145322724884632
Validation loss =  -23.097145080566406
Effective sample size =  0.193805
epoch =  103 batch =  0 / 63 loss =  -23.26131820678711
epoch =  103 batch =  25 / 63 loss =  -23.150866215045635
epoch =  103 batch =  50 / 63 loss =  -23.139224781709558
Validation loss =  -23.08690071105957
Effective sample size =  0.184646
epoch =  104 batch =  0 / 63 loss =  -23.300537109375
epoch =  104 batch =  25 / 63 loss =  -23.180049969599796
epoch =  104 batch =  50 / 63 loss =  -23.177788379145603
Validation loss =  -22.971410751342773
Effective sample size =  0.155994
epoch =  105 batch =  0 / 63 loss =  -22.91793441772461
epoch =  105 batch =  25 / 63 loss =  -23.210590802706204
epoch =  105 batch =  50 / 63 loss =  -23.20238629509421
Validation loss =  -23.098913192749023
Effective sample size =  0.192206
epoch =  106 batch =  0 / 63 loss =  -23.408803939819336
epoch =  106 batch =  2