In [None]:
import os
import os.path as path
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch

sys.path.append(path.abspath("src/"))

from norm_flow.realnvp import RealNVP_2D, dual_layer
from norm_flow.utils import *
# reproducibility: seed both RNGs this notebook has imported.
# NOTE(review): the helpers pulled in from norm_flow.utils are not visible here;
# if they sample with NumPy, seeding torch alone would not make runs repeatable.
seed = 123
torch.manual_seed(seed)
np.random.seed(seed)

# hyperparameters
masks = dual_layer * 2  # `dual_layer` repeated twice; passed to RealNVP_2D as its masks
hidden_dim = 32         # second argument of RealNVP_2D
max_iter = 40000        # passed to train(); also sets the LR decay period
batch_size = 2048       # passed to train()
eval_step = 1000        # passed to train()

In [None]:
# model, optimizer
realNVP = RealNVP_2D(masks, hidden_dim).to(device)
optimizer = torch.optim.AdamW(realNVP.parameters(), lr=0.001)

# Learning-rate schedule: a linear warm-up chained with a step decay.
# The warm-up scales the LR factor from 1e-3 up to 1 over the first 8000 scheduler steps.
warm_up = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1e-3, end_factor=1, total_iters=8000
)
# StepLR expects an integer step_size. The float produced by `max_iter / 4`
# only works when max_iter is divisible by 4: otherwise the scheduler's
# `last_epoch % step_size` check is not hit on most boundaries and the decay is
# skipped. Floor division keeps the period integral; max(1, ...) guards against
# a zero period for very small max_iter.
decay = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=max(1, max_iter // 4), gamma=0.5
)
scheduler = torch.optim.lr_scheduler.ChainedScheduler([warm_up, decay])

# number of trainable parameters (shown as the cell output)
sum(p.numel() for p in realNVP.parameters() if p.requires_grad)

In [None]:
# data
# Directory holding the input ROOT file; the training cell also writes the
# trained weights there. Set the TTBAR_RUN_PATH environment variable to point
# the notebook at another location without editing it; the default is the
# original hard-coded path.
run_path = os.environ.get(
    "TTBAR_RUN_PATH", "/home/zhangbw/Documents/projects/ttbar-unfolding/run"
)
file_name = "reco_analysis__ttbar_nlo_ATLAS_PileUp.root"
tree_name = "truth" # "reco"

df = get_dateframe(path.join(run_path, file_name), tree_name)
display(df)

# processing
# NOTE(review): the trailing underscore suggests these helpers rescale the
# column of `df` in place — confirm in norm_flow.utils.
quantile_scaling_(df, "ST_truth")
quantile_scaling_(df, "t0_truth_Pt")

# Normalise the event weights by their median; compute it once since it is
# both reported and used as the divisor.
weight_median = df["weight"].median()
print(f"weight_median = {weight_median}")
df["weight"] /= weight_median

# keep only events with a strictly positive weight
df = df[df.weight > 0]

# two input features and the per-event weight, handed to the train/validation split
x = df[["ST_truth", "t0_truth_Pt"]].to_numpy()
w = df["weight"].to_numpy()
x, w, x_val, w_val, N = train_val_split(df, x, w, device=device)

In [None]:
# Fit the flow. `train` hands back three values which the next cell plots
# against each other (presumably the evaluation steps and the train/validation
# losses recorded at them — confirm in norm_flow.utils).
training_history = train(
    realNVP,
    x,
    w,
    x_val,
    w_val,
    max_iter,
    eval_step,
    batch_size,
    optimizer,
    scheduler,
)
step, loss_train_step, loss_val_step = training_history

# store the trained weights in the run directory
checkpoint_path = path.join(run_path, "realnvp_truth.pth")
torch.save(realNVP.state_dict(), checkpoint_path)

In [None]:
# Loss curves returned by the training cell. The labels and legend make the
# two lines distinguishable when the notebook is skimmed.
steps = np.array(step)  # convert once, shared by both curves
fig, ax = plt.subplots()
ax.plot(steps, np.array(loss_train_step), label="train")
ax.plot(steps, np.array(loss_val_step), label="validation")
ax.set_xlabel("step")
ax.set_ylabel("loss")
ax.legend();

In [None]:
# Draw a validation batch under its own name. Rebinding `x` here would replace
# the training data produced by the data cell (and then turn it into a NumPy
# array), so the training cell could not be re-run without reloading the data.
x_sample, _, _ = get_batch(x_val, w_val, 10000)

# Inference only: no autograd graph is needed to plot the mapped samples.
with torch.no_grad():
    z, log_det = realNVP.inverse(x_sample)

# move both sides to host memory as NumPy arrays for plotting
z = z.detach().cpu().numpy()
x_sample = x_sample.detach().cpu().numpy()
draw_dist2d2(x_sample, z)