In [None]:
import torch
import uproot
import numpy as np
import matplotlib.pyplot as plt
import os.path as path
import sys
sys.path.append(path.abspath("src/"))

from norm_flow.realnvp import RealNVP, dual_layer
from norm_flow.utils import *
# One seed shared by every RNG used below (torch for the flows, NumPy's global
# generator for the weight permutation).
SEED = 123
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Load the two trained RealNVP flows. Going by the checkpoint names, the encoder
# was trained on reco-level features and the decoder on truth-level features.
# NOTE(review): machine-specific absolute path — make it configurable before sharing.
run_path = "/home/zhangbw/Documents/projects/ttbar-unfolding/run"
encoder_model = "realnvp_reco.pth"
decoder_model = "realnvp_truth.pth"

# The architecture must match the one used at training time.
# NOTE(review): meaning of the (dual_layer * 2, 32) arguments presumed to be
# (coupling layers, hidden width) — confirm against norm_flow.realnvp.RealNVP.
encoder = RealNVP(dual_layer * 2, 32)
decoder = RealNVP(dual_layer * 2, 32)

# map_location="cpu": a checkpoint saved from a GPU run still deserializes on a
# CPU-only machine (and never occupies GPU memory); load_state_dict then copies
# the weights into the CPU-resident modules created above.
encoder.load_state_dict(torch.load(path.join(run_path, encoder_model), map_location="cpu"))
decoder.load_state_dict(torch.load(path.join(run_path, decoder_model), map_location="cpu"))

# The flows are only used for inference in this notebook: switch off any
# training-only behaviour (dropout, batch-norm statistics updates, ...).
encoder.eval()
decoder.eval();

In [None]:
file_name = "reco_analysis__ttbar_nlo_ATLAS_PileUp_Test_Sel.root"

# Reco- and truth-level events are read from two trees of the same ROOT file
# (get_dateframe comes from the star import of norm_flow.utils).
df, df_truth = (get_dateframe(path.join(run_path, file_name), tree) for tree in ("reco", "truth"))

# Quantile-scale the flow inputs. The reco-level quantiles are discarded; only the
# truth-level ones are kept, because they are used later to map the decoder output
# back to physical ST.
(st, _), (b0_pt, _) = (quantile_scaling(df, col) for col in ("ST", "b0_Pt"))
st_truth, quantile = quantile_scaling(df_truth, "ST_truth")

In [None]:
# Two-column float32 feature matrix (scaled ST, scaled b0_Pt), one row per reco
# event, plus the matching per-event weights.
x = torch.from_numpy(np.stack((st.ravel(), b0_pt.ravel()), axis=1)).float()
w = torch.from_numpy(df.weight.to_numpy()).float()

# Select a batch whose size is the number of non-null truth weights.
# NOTE(review): get_batch (norm_flow.utils) is opaque here; idx is presumably the
# tensor of selected row indices — confirm.
x, w, idx = get_batch(x, w, df_truth["weight"].count())
N = len(x)

In [None]:
# Unfold the reco batch: map it to the latent space with the reco flow
# (encoder.inverse), then map the latents through the truth flow (decoder.forward).
# Decoder-only validation (noted to perform worse than going through the encoder):
# replace z by Gaussian samples instead, e.g. z = torch.randn(x.shape).
with torch.no_grad():  # inference only: no autograd graph needs to be recorded
    z, log_det = encoder.inverse(x)
    x_trans, log_det = decoder.forward(z)
# NOTE: .numpy() returns arrays that share memory with the tensors — never write
# into these arrays in place, or the tensors change too.
z_ = z.detach().numpy()
x_ = x.detach().numpy()
x_trans_ = x_trans.detach().numpy()

In [None]:
# Visual check of the flows on the batch: reco inputs (x_), unfolded output of
# decoder(encoder.inverse(x)) (x_trans_) and the intermediate latents (z_).
# Must run before the rescaling cell below, which rebinds x_ and x_trans_.
# NOTE(review): draw_dist2d3 comes from the star import of norm_flow.utils;
# presumably a three-panel 2-D distribution plot — confirm.
draw_dist2d3(x_, x_trans_, z_)

In [None]:
# Map the unfolded ST column back to physical units with the truth-level quantiles.
# NOTE(review): assumes quantile_scaling computed (v - q50) / (q75 - q25) — confirm
# against norm_flow.utils.
q25, q50, q75 = quantile

x_truth_ = df_truth["ST_truth"]
x_ = df["ST"]
# Build a fresh array instead of assigning into x_trans_[:, 0]. x_trans_ shares memory
# with the x_trans tensor, so an in-place write would silently rescale x_trans too.
# Rebinding x_trans_ to a 1-D column would also make this cell crash when re-run.
# Column 0 is the ST feature; the result keeps the tensor's dtype, as the in-place
# write did.
st_unfold_scaled = x_trans[:, 0].detach().numpy()
x_trans_ = ((q75 - q25) * st_unfold_scaled + q50).astype(st_unfold_scaled.dtype)

In [None]:
# Event weights for the reco, unfolded and truth distributions.
w = df["weight"]
w_t = df_truth["weight"]
# The unfolded events reuse the reco-level weights of the batch rows selected by
# get_batch, in a randomly permuted order. (Shuffling the truth weights instead was
# judged poor practice.) np.random.permutation returns a shuffled copy and consumes
# the global RNG exactly like np.random.shuffle. This leaves idx untouched, so the
# cell can be re-run; each run draws a new permutation.
# NOTE(review): if get_batch returns rows x[idx], then x_trans_[i] belongs to reco
# event idx[i], and permuting idx breaks that event-to-weight pairing. Confirm the
# decorrelation is intended; the matched weights would be w.to_numpy()[idx.numpy()].
w_unfold = w.to_numpy()[np.random.permutation(idx.numpy())]

In [None]:
# Compare the weighted ST distributions: reco (x_, w), NF-unfolded
# (x_trans_, w_unfold) and truth (x_truth_, w_t).
# NOTE(review): draw_hist3 comes from the star import of norm_flow.utils; the
# argument pairing above is inferred from the call site — confirm.
draw_hist3(x_, x_trans_, x_truth_, w, w_unfold, w_t)

In [None]:
# Write the reco, unfolded and truth ST values with their weights to a new ROOT file
# (uproot.recreate overwrites an existing file of the same name). The `with` block
# closes the file, so it is fully written before anything else reads it.
with uproot.recreate(path.join(run_path, "realnvp_output.root")) as output:
    output["reco"] = {"ST": x_, "weight": w}
    output["unfold"] = {"ST_NF": x_trans_, "weight": w_unfold}
    output["truth"] = {"ST_truth": x_truth_, "weight": w_t}