In [None]:
import os
import os.path as path
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch

sys.path.append(path.abspath("src/"))

from norm_flow.realnvp import RealNVP_2D, dual_layer
from norm_flow.utils import *
# Seed every RNG that may be used: torch (model init, torch-based sampling) and numpy
# (in case the helpers from norm_flow.utils draw from it) so runs are reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# hyperparameters
# NOTE(review): if `dual_layer` is a list of masks, `* 2` repeats the pattern twice;
# if it is a tensor it would instead multiply values by 2 — confirm in norm_flow.realnvp.
masks = dual_layer * 2
hidden_dim = 16
max_iter = 40000      # total number of optimizer steps
batch_size = 1024
eval_step = 1000      # log train/validation loss every `eval_step` iterations

# training history, appended to by the training loop
step = []             # iteration index of each evaluation
loss_train = []       # per-iteration training loss
loss_train_step = []  # mean training loss over the last `eval_step` iterations
loss_val_step = []    # validation loss at each evaluation

In [None]:
# model, optimizer
realNVP = RealNVP_2D(masks, hidden_dim)
optimizer = torch.optim.AdamW(realNVP.parameters(), lr=0.001)

# LR schedule: linear warm-up from 1e-3 * lr to lr over the first `warm_up_iters`
# steps, then halve the LR every `max_iter // 4` steps. ChainedScheduler steps both
# schedulers on every call, so their multiplicative factors combine.
warm_up_iters = 4000
warm_up = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1e-3, end_factor=1.0, total_iters=warm_up_iters
)
# StepLR documents `step_size` as an int; `max_iter / 4` would pass a float.
decay = torch.optim.lr_scheduler.StepLR(optimizer, step_size=max_iter // 4, gamma=0.5)
scheduler = torch.optim.lr_scheduler.ChainedScheduler([warm_up, decay])

# number of trainable parameters
sum(p.numel() for p in realNVP.parameters() if p.requires_grad)

In [None]:
# data
# `run_path` can be overridden via the TTBAR_RUN_PATH environment variable so the
# notebook is runnable outside the author's machine; the default is unchanged.
run_path = os.environ.get(
    "TTBAR_RUN_PATH", "/home/zhangbw/Documents/projects/ttbar-unfolding/run"
)
file_name = "reco_analysis__ttbar_nlo_ATLAS_PileUp.root"
tree_name = "reco"  # alternative: "truth"

# `get_dateframe` (sic) presumably comes from the star import of norm_flow.utils
df = get_dateframe(path.join(run_path, file_name), tree_name)
df.head()

In [None]:
# processing
# Drop events with non-positive weight (e.g. negative NLO weights) *before* computing
# any statistics, so the weight normalization and the quantile scaling are fitted on
# exactly the events used for training. `.copy()` makes `df` an independent frame,
# so the in-place column updates below do not write into a slice of the raw frame.
df = df[df.weight > 0].copy()
assert len(df) > 0, "no events with positive weight"

# rescale weights so that the median weight is 1
print(f"weight_median = {df.weight.median()}")
df["weight"] /= df["weight"].median()

# presumably quantile_scaling_ transforms the column of `df` in place (trailing `_`,
# result not assigned) — confirm in norm_flow.utils
quantile_scaling_(df, "ST")
quantile_scaling_(df, "b1_Pt")

In [None]:
# features (2-D) and per-event weights as float32 tensors
x = torch.from_numpy(df[["ST", "b1_Pt"]].to_numpy()).float()
w = torch.from_numpy(df["weight"].to_numpy()).float()
assert x.shape[0] == w.shape[0]

# Shuffle before the 80/20 split: event order in the tree may not be random, and a
# plain head/tail split would then give a biased validation set. A dedicated
# generator keeps the split reproducible without consuming the global torch RNG.
N_total = x.shape[0]
perm = torch.randperm(N_total, generator=torch.Generator().manual_seed(42))
x, w = x[perm], w[perm]

N_split = int(0.8 * N_total)
x_val, w_val = x[N_split:], w[N_split:]
x, w = x[:N_split], w[:N_split]
# N is the number of training events (the old `N -= N_split` left the validation
# count here while `x`/`w` held the training set)
N = x.shape[0]

N, x.shape, w.shape, x_val.shape

In [None]:
def _log_eval(it):
    """Record and print the mean training loss over the last `eval_step`
    iterations and the validation loss of the current model at iteration `it`."""
    with torch.no_grad():
        train_loss = np.array(loss_train[-eval_step:]).mean().item()
        val_loss = validate(realNVP, x_val, w_val)
    print(
        f"{it:6d} / {max_iter:6d}, "
        f"train_loss={train_loss:.6f}, "
        f"val_loss={val_loss:.6f}, "
        f"lr = {scheduler.get_last_lr()[0]:.6f}"
    )
    step.append(it)
    loss_train_step.append(train_loss)
    loss_val_step.append(val_loss)


# NOTE: re-running this cell continues training the same model/optimizer/scheduler,
# but `i` restarts at 0, so the history lists get non-monotonic `step` entries.
for i in range(max_iter):
    if i > 0 and i % eval_step == 0:
        _log_eval(i)
    xb, wb = get_batch(x, w, batch_size)
    z, log_det = realNVP.inverse(xb)
    # Weighted negative log-likelihood under a 2-D standard normal latent:
    # log(2*pi) + 0.5*|z|^2 - log_det. Assumes `two_pi` (from norm_flow.utils) is a
    # tensor holding 2*pi and `log_det` is the log-Jacobian of x -> z — TODO confirm.
    loss = torch.log(two_pi) + torch.mean(wb * (torch.sum(0.5 * z**2, -1) - log_det))
    loss_train.append(loss.item())
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    scheduler.step()

# The loop stops at max_iter - 1, so evaluate the final model explicitly.
_log_eval(max_iter)

torch.save(realNVP.state_dict(), path.join(run_path, "realnvp_reco.pth"))

fig, ax = plt.subplots()
ax.plot(step, loss_train_step, label="train")
ax.plot(step, loss_val_step, label="validation")
ax.set(xlabel="iteration", ylabel="loss", title="RealNVP training history")
ax.legend();

In [None]:
x, _ = get_batch(x_val, w_val, 10000)
z, log_det = realNVP.inverse(x)
z = z.detach().numpy()
x = x.detach().numpy()
x, z

In [None]:
fig = plt.figure(2, figsize = (10, 4))
plt.subplot(1,2,1)
plt.plot(x[:, 0], x[:, 1], ".")
plt.title("Observed distribution")
plt.xlabel(r"$S_{T}$")
plt.ylabel(r"sub b-jet $p_T$")

plt.subplot(1,2,2)
plt.plot(z[:, 0], z[:, 1], ".")
plt.title("Latent distribution")
plt.xlabel(r"$z_{0}$")
plt.ylabel(r"$z_{1}$")
plt.xlim([-4, 4])
plt.ylim([-4, 4])
plt.show()