In [None]:
import torch
from JetGraphProducer import JetGraphProducer
import numpy as np, awkward as ak
import uproot

We need to calculate the transverse mass MT on the fly, since the version computed with AK15 jets is not stored in the ntuples. The function below is passed to the `JetGraphProducer`, which runs it event by event.

In [None]:
def mt(event, jet_index=1, jet_collection="JetsAK15"):
    """
    Calculates the transverse mass MT of one jet combined with the missing transverse energy.

    The jet is treated as massive and the MET as massless:

        MT^2 = (E_T,jet + MET)^2 - (px_jet + px_MET)^2 - (py_jet + py_MET)^2,
        E_T,jet = sqrt(E^2 - pz^2) = sqrt(m^2 + pT^2).

    Parameters
    ----------
    event :
        Per-event record exposing ``MET`` and ``METPhi`` as attributes and the jet
        branches "<jet_collection>/<jet_collection>.fCoordinates.f{Pt,Eta,Phi,E}"
        through item access.
    jet_index : int, default 1
        Position of the jet in the collection. The default, 1, selects the
        second jet (the subleading one if the collection is pT-ordered).
    jet_collection : str, default "JetsAK15"
        Branch prefix of the jet collection.

    Returns
    -------
    The transverse mass MT.

    Notes
    -----
    Indexing fails if the event has fewer than ``jet_index + 1`` jets.
    """
    def jet_branch(coordinate):
        return event[f"{jet_collection}/{jet_collection}.fCoordinates.f{coordinate}"][jet_index]

    jet_pt = jet_branch("Pt")
    jet_eta = jet_branch("Eta")
    jet_phi = jet_branch("Phi")
    jet_e = jet_branch("E")
    met = event.MET

    met_x = np.cos(event.METPhi) * met
    met_y = np.sin(event.METPhi) * met
    jet_x = np.cos(jet_phi) * jet_pt
    jet_y = np.sin(jet_phi) * jet_pt

    # m^2 + pT^2 = E^2 - pT^2 - pz^2 + pT^2 = E^2 - pz^2
    jet_pz = jet_pt * np.sinh(jet_eta)
    jet_et = np.sqrt(jet_e**2 - jet_pz**2)

    mt_squared = (jet_et + met)**2 - (jet_x + met_x)**2 - (jet_y + met_y)**2
    # MT^2 >= 0 analytically (E_T >= pT and the triangle inequality); clamp tiny
    # negative round-off (e.g. a massless jet collinear with MET) so it gives 0, not NaN.
    return np.sqrt(np.maximum(mt_squared, 0.))

We load the signal and background data as graphs. The first time the data is loaded, the ROOT files are processed into `torch_geometric.data.InMemoryDataset` objects, which are also stored on disk.

In [None]:
# Signal sample (label=1.), built as graphs from the input found under "test_data".
signal = JetGraphProducer(
    "test_data",
    n_store_jets=2,  # presumably the number of jets per event turned into graphs — TODO confirm
    use_lund_decomp=True,
    n_lund_vars=5,  # same value used as `input_dims` of the LundNet model later in the notebook
    weights="xsec",  # presumably cross-section event weights, read back later as `.w` — TODO confirm
    extra_obs_to_load=["MET", "METPhi"],  # branches read by `mt` (event.MET, event.METPhi)
    extra_obs_to_compute_per_event=[mt],
    input_format="TreeMaker2",
    jet_collection="JetsAK15",
    verbose=True,
    mask=True,
    max_events_to_process=3000,  # presumably caps the number of processed events
    label=1.,  # presumably stored as each graph's target `y` (signal = 1)
)

In [None]:

# Background sample (label=0.), built with the same settings as the signal, from "test_data_bkg".
background = JetGraphProducer(
    "test_data_bkg",
    n_store_jets=2,  # presumably the number of jets per event turned into graphs — TODO confirm
    use_lund_decomp=True,
    n_lund_vars=5,  # same value used as `input_dims` of the LundNet model later in the notebook
    weights="xsec",  # presumably cross-section event weights, read back later as `.w` — TODO confirm
    extra_obs_to_load=["MET", "METPhi"],  # branches read by `mt` (event.MET, event.METPhi)
    extra_obs_to_compute_per_event=[mt],
    input_format="TreeMaker2",
    jet_collection="JetsAK15",
    verbose=True,
    mask=True,
    max_events_to_process=3000,  # presumably caps the number of processed events
    label=0.,  # presumably stored as each graph's target `y` (background = 0)
)

We now preprocess the data: we split each sample into training and testing sets, normalize the node features, and merge signal and background into labelled training and testing datasets.

In [None]:
from LundTreeUtilities import OnTheFlyNormalizer
from torch.utils.data import ConcatDataset

TRAIN_FRACTION = 0.8


def _weighted_node_moments(datasets):
    """
    Per-feature weighted mean and standard deviation of the node features ``x``.

    Every node of a graph carries that graph's weight ``w``. Both the mean and the
    variance are therefore divided by sum(w * n_nodes) over all graphs, which
    matches how the sums are accumulated (over nodes).

    Returns ``(means, stds)``. Features with zero (or undefined) spread get std 1
    so that normalizing does not divide by zero.
    """
    feature_sum, total_weight = 0., 0.
    for dataset in datasets:
        for graph in dataset:
            feature_sum = feature_sum + graph.x.sum(dim=0) * graph.w
            total_weight = total_weight + graph.x.shape[0] * graph.w
    means = feature_sum / total_weight

    squared_dev_sum = 0.
    for dataset in datasets:
        for graph in dataset:
            squared_dev_sum = squared_dev_sum + ((graph.x - means) ** 2).sum(dim=0) * graph.w
    stds = torch.sqrt(squared_dev_sum / total_weight)
    stds = torch.where(stds > 0, stds, torch.ones_like(stds))
    return means, stds


def _normalized(w):
    """Scale a weight tensor so that it sums to one."""
    return w / w.sum()


# NOTE(review): only every second signal graph is used while the background is not
# strided — confirm this is intended (e.g. in relation to n_store_jets=2).
signal_subset = signal[::2]
n_signal_train = int(len(signal_subset) * TRAIN_FRACTION)
signal_training = signal_subset[:n_signal_train]
signal_testing = signal_subset[n_signal_train:]

n_background_train = int(len(background) * TRAIN_FRACTION)
background_training = background[:n_background_train]
background_testing = background[n_background_train:]

# Normalization statistics come from the training graphs only.
means, stds = _weighted_node_moments((signal_training, background_training))

# Careful that the normalizer is applied only once: slices in torch_geometric are actually only masks,
# so _training and _testing objects share the same underlying tensor
# (normalizing the training slice therefore also normalizes the testing slice).
normalizer = OnTheFlyNormalizer(["x"], means, stds)
normalizer(signal_training.data)
normalizer(background_training.data)

data_training = ConcatDataset((signal_training, background_training))
data_testing = ConcatDataset((signal_testing, background_testing))
# Each class is scaled to unit total weight so that signal and background contribute
# equally; the concatenation order matches the ConcatDataset order above.
weights = torch.cat((_normalized(signal_training.w), _normalized(background_training.w)))
weights_testing = torch.cat((_normalized(signal_testing.w), _normalized(background_testing.w)))

Define the model

In [None]:
from architectures import LundNet

# LundNet hyperparameters; n_lund_vars mirrors the value given to JetGraphProducer.
n_lund_vars = 5
add_fractions = True
num_classes = 1  # single output column; training and evaluation read y_pred[:, 0]

# One [width, width] pair per block (presumably convolution layer sizes): two blocks each of 32, 64, 128.
conv_params = [[width, width] for width in (32, 32, 64, 64, 128, 128)]
# Fully connected head as (units, value) pairs — presumably (units, dropout rate); confirm in architectures.
fc_params = [(128, 0.8)]

model = LundNet(
    conv_params=conv_params,
    fc_params=fc_params,
    input_dims=n_lund_vars,
    use_fusion=True,
    num_classes=num_classes,
    add_fractions_to_lund=add_fractions,
)

Start training

In [None]:
from torch_geometric.loader import DataLoader
from torch.utils.data import WeightedRandomSampler
from tqdm.auto import tqdm

# Signal and background weights were each scaled to sum to one, so sampling with
# replacement draws both classes with equal total probability. A fixed seed makes
# the sequence of sampled batches reproducible.
sampler = WeightedRandomSampler(
    weights/weights.sum(),
    len(data_training),
    replacement=True,
    generator=torch.Generator().manual_seed(0),
)
loader = DataLoader(data_training, batch_size=128, sampler=sampler)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# BCELoss expects probabilities in [0, 1], so the model output is assumed to be
# sigmoid-activated already — TODO confirm in architectures.LundNet.
loss_fn = torch.nn.BCELoss()

n_epochs = 5
loss_history = []

# Ensure training-time behavior (e.g. dropout) is active even if this cell is
# re-run after the model was switched to evaluation mode.
model.train()
for epoch in range(n_epochs):
    for batch in tqdm(loader, desc=f"Training epoch {epoch}", leave=False):
        y_pred = model(batch)
        loss = loss_fn(y_pred[:, 0], batch.y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_history.append(loss.item())

In [None]:
import matplotlib.pyplot as plt

# Training loss, one entry per optimizer step.
fig, ax = plt.subplots()
ax.plot(loss_history)
ax.set(xlabel="Training step", ylabel="Loss")
plt.show()

Check performance on the test dataset

In [None]:
# Resample the test set with the per-class-normalized weights (seeded for
# reproducibility); the whole resampled set is evaluated as a single batch.
sampler_testing = WeightedRandomSampler(
    weights_testing/weights_testing.sum(),
    len(data_testing),
    replacement = True,
    generator=torch.Generator().manual_seed(1),
)

loader_testing = DataLoader(data_testing, sampler=sampler_testing, batch_size=len(data_testing))

scores_testing = []
labels_testing = []

# Evaluate in inference mode (disables e.g. dropout, if the model has any),
# then restore whatever mode the model was in before.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        for batch in loader_testing:
            scores_testing.append(model(batch))
            labels_testing.append(batch.y)
finally:
    model.train(was_training)

In [None]:
# The test loader yields a single batch (batch_size=len(data_testing)); keep its first
# output column as a 1-D score tensor. Guarded so that re-running this cell does not
# index the already-flattened tensor a second time.
if isinstance(scores_testing, list):
    scores_testing = scores_testing[0][:, 0]

In [None]:
from sklearn.metrics import roc_curve, roc_auc_score

# Labels of the single test batch, scored against the flattened model outputs.
true_labels = labels_testing[0]
auc = roc_auc_score(true_labels, scores_testing)
fpr, tpr, _ = roc_curve(true_labels, scores_testing)

In [None]:
# ROC curve with the AUC reported in the legend.
fig, ax = plt.subplots()
ax.plot(fpr, tpr, label=f"ROC AUC: {auc:.3f}")
ax.set_xlabel("fpr")
ax.set_ylabel("tpr")
ax.legend()
plt.show()

In [None]:
# Score distributions per class. Both histograms share the same bin edges so they
# are directly comparable, and are drawn as outlines so neither hides the other.
scores_array = np.asarray(scores_testing)
is_signal = np.asarray(labels_testing[0] == 1.)
is_background = np.asarray(labels_testing[0] == 0.)
score_bins = np.histogram_bin_edges(scores_array, bins=30)

fig, ax = plt.subplots()
ax.hist(scores_array[is_signal], bins=score_bins, histtype="step", label="Signal")
ax.hist(scores_array[is_background], bins=score_bins, histtype="step", label="Background")
ax.set_xlabel("LundNET score")
ax.legend()
plt.show()