# PHYS805 Final Project Notebook

In [None]:
import datetime
import os
import random
from pprint import pprint

import awkward as ak
import matplotlib.pyplot as plt
import mplhep as hep
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import wandb
import yaml
from sklearn.decomposition import PCA
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import train_test_split
plt.style.use(hep.style.CMS)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# auto reload of imported modules
%load_ext autoreload
%autoreload 2

from utils import data_utils
from utils import dataloader
from utils import metric
from utils import model
from utils import training

%config InlineBackend.figure_format = 'retina'

torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
random.seed(0)

Setting up the data. This includes selecting the features, loading the data from the `.root` files specified in the configuration file, and constructing the data loaders using the `JetDataset` class.

In [None]:
# Dataset locations are read from the YAML config passed to data_utils.load_data.
with open("datasets.yaml", mode="r") as cfg_file:
    ds_cfg = yaml.safe_load(cfg_file)

# Per-jet input features fed to the models (in this column order).
features = ["Jet_pt", "Jet_eta", "Jet_phi", "Jet_mass", "Jet_nConstituents", "Jet_nSVs", "Jet_area"]
# Extra event-level branches loaded alongside, but not used as model inputs.
other_branches = ["nJet", "Pileup_nPU"]
branches = [*features, *other_branches]
num_ftrs = len(features)

# `test_split`: fraction of events held out from training.
# `val_split`: share of that held-out set that goes to *test*; the rest is validation.
test_split = 0.2
val_split = 0.5

In [None]:
# Signal (EMJ) is read with entry_stop=-1; background (QCD) is capped at 100k entries.
sig = data_utils.load_data(ds_cfg, "EMJ", filter_name=branches, entry_stop=-1)
bkg_unmatched = data_utils.load_data(ds_cfg, "QCD", filter_name=branches, entry_stop=100_000)
# Match the background pileup distribution to the signal's
# (the pileup plots further down compare before/after this step).
bkg = data_utils.match_pu(sig, bkg_unmatched)
del bkg_unmatched

In [None]:
# Per-event pileup, signal first then background: the same order in which the
# jet tensors are concatenated below, so rows line up.
nPU = torch.tensor(
    np.concatenate([ak.to_numpy(events["Pileup_nPU"]) for events in (sig, bkg)]),
    dtype=torch.float32,
)

# Fixed number of jet slots per event. (A data-driven alternative would be the
# largest nJet across both samples.) ak_to_torch appends the class label as the
# last feature column; the split cell reads it from there (1 = EMJ, 0 = QCD).
njets = 10
sig_tensor = data_utils.ak_to_torch(sig, features, njets, label=1)
bkg_tensor = data_utils.ak_to_torch(bkg, features, njets, label=0)
data_tensor = torch.cat([sig_tensor, bkg_tensor], dim=0)

# Shuffle events, applying one permutation to both tensors so pileup stays aligned.
rnd_idx = torch.randperm(data_tensor.size(0))
data_tensor, nPU = data_tensor[rnd_idx], nPU[rnd_idx]

In [None]:
# Split
# Split off `test_split` of the events. Inputs drop the trailing label column;
# labels are read from the label column of the leading jet slot.
X_train, X_temp, y_train, y_temp = train_test_split(
    data_tensor[..., :-1], data_tensor[:, 0, -1], test_size=test_split, shuffle=True
)

# Divide the held-out events: a `val_split` share becomes test, the rest validation.
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=val_split)

print(*(split.shape for split in (y_train, y_val, y_test)))
print(*(split.shape for split in (X_train, X_val, X_test)))

In [None]:
# Normalisation constants are computed on the training split only and then shared by
# every split, so no validation/test statistics leak into preprocessing.
norm_constants = data_utils.compute_norm_constants(X_train)
train_ds, val_ds, test_ds = (
    dataloader.JetDataset(inputs, norm_constants, targets)
    for inputs, targets in ((X_train, y_train), (X_val, y_val), (X_test, y_test))
)

In [None]:
batch_size = 512
# Only training batches are reshuffled each epoch. Validation/test keep dataset
# order so their outputs can be paired with dataset rows later on.
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    torch.utils.data.DataLoader(split_ds, batch_size=batch_size, shuffle=False)
    for split_ds in (val_ds, test_ds)
)

In [None]:
# Inspect the training-set normalisation constants.
pprint(object=norm_constants)

## Features Study

In this section, we explore the features selected above.

In [None]:
# Plot njets distribution: per-event jet multiplicity (nJet) for signal and background.
# Histograms are raw counts (density=False), so the y-axis is labelled as counts;
# "A.U." would imply normalised shapes. Integer bins 0..14; larger nJet is not drawn.
sig_njets = ak.to_numpy(sig['nJet'])
bkg_njets = ak.to_numpy(bkg['nJet'])
njet_bins = np.arange(0, 15)

fig, ax = plt.subplots()
ax.hist(sig_njets, bins=njet_bins, density=False, label="EMJ", histtype="step")
ax.hist(bkg_njets, bins=njet_bins, density=False, label="QCD", histtype="step")
ax.set_xlabel("Number of Jets")
ax.set_ylabel("Count")
ax.set_title("Jet Multiplicity Distribution")
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Plot pileup distribution: signal vs background Pileup_nPU, each normalised to unit
# area, before and after the background resampling performed by data_utils.match_pu.
# Both panels share 20 bins spanning 0 .. (max resampled pileup + 1).
bins = 20
x_min, x_max = 0, nPU.max().item() + 1

# Before resampling: reload only the pileup branch for both samples. QCD is read with
# entry_stop=-1 here, versus entry_stop=100_000 for the training sample.
sig_pu, bkg_pu = (
    data_utils.load_data(ds_cfg, sample, filter_name="Pileup_nPU", entry_stop=-1)["Pileup_nPU"].to_numpy()
    for sample in ("EMJ", "QCD")
)

fig, ax = plt.subplots()
ax.hist(sig_pu, bins=bins, range=(x_min, x_max), label="EMJ", histtype="step", density=True)
ax.hist(bkg_pu, bins=bins, range=(x_min, x_max), label="QCD", histtype="step", density=True)
ax.set_title("Pileup Distributions Before Resampling")
ax.set_xlabel("Pileup")
ax.set_ylabel("A.U.")
ax.legend()
ax.grid(True)
plt.show()
del sig_pu, bkg_pu  # only needed for this panel

# After resampling: the event samples actually used to build the model inputs.
fig, ax = plt.subplots()
for events, sample in ((sig, "EMJ"), (bkg, "QCD")):
    ax.hist(ak.to_numpy(events["Pileup_nPU"]), bins=bins, range=(x_min, x_max), density=True, label=sample, histtype="step")
ax.set_title("Pileup Distributions After Resampling")
ax.set_xlabel("Number of Pileup")
ax.set_ylabel("A.U.")
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Keep only events whose jet multiplicity is below `njets`, the slot count of the
# model tensors. NOTE(review): the strict `<` also drops events with exactly
# `njets` jets, which fit entirely in the tensor; confirm whether `<=` was intended.
sig_njets, bkg_njets = ak.to_numpy(sig['nJet']), ak.to_numpy(bkg['nJet'])
sig_jets, bkg_jets = sig[sig_njets < njets], bkg[bkg_njets < njets]

def plot_ftrs(sig, bkg, ftr_name, fig_title, nbins, xrange, xlabel, ylabel, log, normalized=True):
    """Overlay step histograms of one jet feature for signal (EMJ) and background (QCD).

    Args:
        sig, bkg: event records indexable by field name; the selected field is
            flattened with ``ak.flatten`` (presumably per-event jet lists) before plotting.
        ftr_name: field to histogram, e.g. ``"Jet_pt"``.
        fig_title, xlabel, ylabel: figure annotations.
        nbins: number of bins.
        xrange: ``(low, high)`` histogram range.
        log: if truthy, use a logarithmic y-axis; otherwise linear.
        normalized: forwarded to ``density``; if true each histogram has unit area.
    """
    fig, ax = plt.subplots()
    hist_style = dict(bins=nbins, range=xrange, density=normalized, alpha=0.5, histtype='step', linewidth=1.5)
    for sample, sample_label in ((sig, 'EMJ'), (bkg, 'QCD')):
        ax.hist(ak.flatten(sample[ftr_name]), label=sample_label, **hist_style)
    ax.set_title(fig_title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_yscale('log' if log else 'linear')
    ax.legend()
    ax.grid(True)
    plt.show()

In [None]:
# Per-jet feature distributions (raw counts) for the events selected above.
# LaTeX labels containing backslashes (\eta, \phi) are raw strings: in a normal string
# "\e" and "\p" are invalid escape sequences, which Python warns about (SyntaxWarning
# since 3.12).
plot_ftrs(sig_jets, bkg_jets, "Jet_pt", "Jet $p_T$", 20, (0, 3500), "Jet $p_T$ [GeV]", "Count", log=True, normalized=False)
plot_ftrs(sig_jets, bkg_jets, "Jet_eta", r"Jet $\eta$", 20, (-5, 5), r"Jet $\eta$", "Count", log=False, normalized=False)
plot_ftrs(sig_jets, bkg_jets, "Jet_phi", r"Jet $\phi$", 20, (-3.14, 3.14), r"Jet $\phi$", "Count", log=False, normalized=False)
plot_ftrs(sig_jets, bkg_jets, "Jet_mass", "Jet Mass", 20, (0, 300), "Jet Mass [GeV]", "Count", log=True, normalized=False)
plot_ftrs(sig_jets, bkg_jets, "Jet_nConstituents", "Jet Number of Constituents", 20, (0, 100), "Number of Constituents", "Count", log=True, normalized=False)
plot_ftrs(sig_jets, bkg_jets, "Jet_nSVs", "Jet Number of Secondary Vertices", 10, (0, 10), "Number of Secondary Vertices", "Count", log=True, normalized=False)
plot_ftrs(sig_jets, bkg_jets, "Jet_area", "Jet Area", 30, (0, 1), "Jet Area", "Count", log=True, normalized=False)

In [None]:
# Pearson correlation between each leading-jet feature (slot 0 of the shuffled,
# un-normalised tensor) and the event pileup. Column k is assumed to hold
# features[k], following the order passed to ak_to_torch.
lead_jets = data_tensor[:, 0, :]
corr_coefs = {}
for col, ftr_name in enumerate(features):
    pair = torch.stack((lead_jets[:, col], nPU))
    corr_coefs[ftr_name] = torch.corrcoef(pair)[0, 1].item()
    print(f"Correlation coefficient between {ftr_name} and nPU: {corr_coefs[ftr_name]:.4f}")

## Baseline: Training NN on Inputs

We now set up and train a baseline model consisting of a simple feed-forward NN. The same architecture is used later on, with the input dimension adapted to the expected inputs.

In [None]:
# Hyperparameters for the baseline MLP run (also sent to wandb). The baseline cells
# read n_epochs, learning_rate, hidden_size and patience; the other keys mirror the
# encoder configs.
config_nn = dict(
    n_epochs=500,
    learning_rate=1e-3,
    num_heads=2,
    num_layers=1,
    hidden_size=4,
    beta=0.5,
    patience=20,
    batch_size=batch_size,
)
pprint(config_nn)

In [None]:
# BCEWithLogitsLoss consumes raw logits; the sigmoid is used only to turn logits
# into 0/1 predictions for the metric tracker.
metrics = metric.metrics()
bce_loss = torch.nn.BCEWithLogitsLoss()
sigmoid = torch.nn.Sigmoid()

# Baseline MLP on each event's flattened (njets x features) input.
classifier_nn = model.MLPClassifier(
    input_dim=num_ftrs * njets,
    hidden_dim=config_nn["hidden_size"],
    hidden_layers=2,
    output_dim=1,
).to(device)

optimizer_nn = torch.optim.Adam(classifier_nn.parameters(), lr=config_nn["learning_rate"])

early_stopper = training.EarlyStopping(config_nn["patience"])

In [None]:
# Classifier train without encoder: the baseline MLP sees each event's jet features
# flattened into one vector. Metrics go to wandb; early stopping monitors validation BCE.
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
wandb.init(
    project="ml4phys_finalproj",
    name=f"run_classifier_nn_baseline_{timestamp}",
    config=config_nn
)

early_stopper = training.EarlyStopping(patience=config_nn["patience"]) # New instance

for epoch in range(config_nn["n_epochs"]):
    metrics.reset()
    classifier_nn.train()
    total_bce = 0.0
    for x, _, labels in train_loader:
        x = x.to(device)
        # Flatten jet slots and features into one input vector per event.
        x = x.view(x.size(0), -1)

        labels = labels.to(device)
        # reshape(-1) instead of squeeze(): squeeze() turns a final batch holding a
        # single event into a 0-d tensor whose shape no longer matches `labels`,
        # which BCEWithLogitsLoss rejects.
        logit = classifier_nn(x).reshape(-1)

        loss_bce = bce_loss(logit, labels)
        optimizer_nn.zero_grad()
        loss_bce.backward()
        optimizer_nn.step()
        # Weight by batch size so the epoch average is per event, not per batch.
        total_bce += loss_bce.item() * x.size(0)
        metrics.update(labels.cpu(), (sigmoid(logit) > 0.5).long().cpu())

    avg_bce = total_bce / len(train_ds)
    accuracy, precision, recall, f1_score = metrics.compute()

    # Eval
    metrics.reset()
    total_bce_val = 0.0
    classifier_nn.eval()
    with torch.no_grad():
        for x_val, _, labels_val in val_loader:
            x_val = x_val.to(device)
            x_val = x_val.view(x_val.size(0), -1)
            labels_val = labels_val.to(device)
            logit_val = classifier_nn(x_val).reshape(-1)

            loss_bce_val = bce_loss(logit_val, labels_val)
            total_bce_val += loss_bce_val.item() * x_val.size(0)
            metrics.update(labels_val.cpu(), (sigmoid(logit_val) > 0.5).long().cpu())

    avg_bce_val = total_bce_val / len(val_ds)
    accuracy_val, precision_val, recall_val, f1_score_val = metrics.compute()

    print(f"Epoch {epoch+1}/{config_nn['n_epochs']}")
    print(f"Loss: {avg_bce:.4f}, BCE: {avg_bce:.4f}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f}")
    print(f"Val Loss: {avg_bce_val:.4f}, Val BCE: {avg_bce_val:.4f}, Val Accuracy: {accuracy_val:.4f}, Val Precision: {precision_val:.4f}, Val Recall: {recall_val:.4f}, Val F1 Score: {f1_score_val:.4f}")

    wandb.log({
        "epoch": epoch + 1,

        "bce_loss": avg_bce,
        "val_bce_loss": avg_bce_val,

        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "val_accuracy": accuracy_val,
        "val_precision": precision_val,
        "val_recall": recall_val,
        "val_f1_score": f1_score_val,
    })

    if early_stopper(avg_bce_val, classifier_nn):
        break

wandb.finish()

In [None]:
# Load the model from best checkpoint: restore the weights early_stopper kept
# (presumably those from the epoch with the lowest validation BCE).
classifier_nn.load_state_dict(early_stopper.best_model_state)

# Save best model. Create the output directory first so a fresh checkout does not
# fail in torch.save.
os.makedirs("./models", exist_ok=True)
torch.save(classifier_nn.state_dict(), "./models/classifier_nn.pth")

## Jointly Trained Encoder + Classifier

This first approach trains the transformer encoder and the classifier jointly as a single model. Its loss is a linear combination of the BCE and the Supervised Contrastive Loss, $\mathcal{L} = (1-\beta)\,\mathcal{L}_{\text{BCE}} + \beta\,\mathcal{L}_{\text{SupCon}}$.

In [None]:
# Hyperparameters for the jointly trained encoder + classifier. `beta` weights the
# contrastive term against BCE in the combined loss; `contrastive_temp` is handed to
# SupConLoss.
config_joint = dict(
    n_epochs=500,
    learning_rate=1e-3,
    num_heads=2,
    num_layers=1,
    hidden_size=4,
    contrastive_temp=0.07,
    beta=0.5,
    patience=20,
    batch_size=batch_size,
)
pprint(config_joint)

In [None]:
# Losses and metric tracker for the joint model: BCE on classifier logits plus a
# supervised contrastive loss (at the configured temperature) on encoder embeddings.
metrics = metric.metrics()
bce_loss = torch.nn.BCEWithLogitsLoss()
contrastive_loss = training.SupConLoss(config_joint["contrastive_temp"])
sigmoid = torch.nn.Sigmoid()

# Transformer encoder producing a `hidden_size`-wide embedding per event.
encoder_joint = model.TransformerEncoder(
    num_features=num_ftrs,
    embed_size=config_joint["hidden_size"],
    num_heads=config_joint["num_heads"],
    num_layers=config_joint["num_layers"],
).to(device)

# MLP head mapping the embedding to a single logit.
classifier_joint = model.MLPClassifier(
    input_dim=config_joint["hidden_size"],
    hidden_dim=config_joint["hidden_size"],
    hidden_layers=2,
    output_dim=1,
).to(device)

# One optimizer updates both parts, since they are trained together.
optimizer_joint = torch.optim.Adam(
    [*encoder_joint.parameters(), *classifier_joint.parameters()],
    lr=config_joint["learning_rate"],
)

early_stopper = training.EarlyStopping(config_joint["patience"])

In [None]:
# Train loop for the jointly trained encoder + classifier. The per-batch loss is
# (1 - beta) * BCE + beta * SupCon; early stopping monitors the validation total loss.
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
wandb.init(
    project="ml4phys_finalproj",
    name=f"run_jointarch_{timestamp}",
    config=config_joint
)

for epoch in range(config_joint["n_epochs"]):
    metrics.reset()
    encoder_joint.train()
    classifier_joint.train()

    total_loss = total_bce = total_contrastive = 0.0
    for x, mask, labels in train_loader:
        x = x.to(device)
        mask = mask.to(device)
        labels = labels.to(device)

        latent = encoder_joint(x, mask)
        # reshape(-1) instead of squeeze(): squeeze() turns a final batch holding a
        # single event into a 0-d tensor whose shape no longer matches `labels`.
        logit = classifier_joint(latent).reshape(-1)

        loss_bce = bce_loss(logit, labels)
        # unsqueeze(1) adds a singleton axis before SupConLoss, presumably the
        # (batch, n_views, dim) layout it expects. TODO confirm against training.SupConLoss.
        loss_contrastive = contrastive_loss(latent.unsqueeze(1), labels)
        loss = (1 - config_joint["beta"]) * loss_bce + config_joint["beta"] * loss_contrastive
        optimizer_joint.zero_grad()
        loss.backward()
        optimizer_joint.step()

        # Weight by batch size so epoch averages are per event, not per batch.
        total_contrastive += loss_contrastive.item() * x.size(0)
        total_bce += loss_bce.item() * x.size(0)
        total_loss += loss.item() * x.size(0)
        metrics.update(labels.cpu(), (sigmoid(logit) > 0.5).long().cpu())

    avg_loss = total_loss / len(train_ds)
    avg_contrastive = total_contrastive / len(train_ds)
    avg_bce = total_bce / len(train_ds)
    accuracy, precision, recall, f1_score = metrics.compute()

    # Eval
    metrics.reset()
    total_loss_val = total_bce_val = total_contrastive_val = 0.0
    encoder_joint.eval()
    classifier_joint.eval()
    with torch.no_grad():
        for x_val, mask_val, labels_val in val_loader:
            x_val = x_val.to(device)
            mask_val = mask_val.to(device)
            labels_val = labels_val.to(device)

            latent_val = encoder_joint(x_val, mask_val)
            logit_val = classifier_joint(latent_val).reshape(-1)

            loss_contrastive_val = contrastive_loss(latent_val.unsqueeze(1), labels_val)
            loss_bce_val = bce_loss(logit_val, labels_val)
            loss_val = (1 - config_joint["beta"]) * loss_bce_val + config_joint["beta"] * loss_contrastive_val
            total_contrastive_val += loss_contrastive_val.item() * x_val.size(0)
            total_bce_val += loss_bce_val.item() * x_val.size(0)
            total_loss_val += loss_val.item() * x_val.size(0)
            # Accumulate metrics for every validation batch, not only the last one.
            metrics.update(labels_val.cpu(), (sigmoid(logit_val) > 0.5).long().cpu())

    avg_contrastive_val = total_contrastive_val / len(val_ds)
    avg_bce_val = total_bce_val / len(val_ds)
    avg_loss_val = total_loss_val / len(val_ds)
    accuracy_val, precision_val, recall_val, f1_score_val = metrics.compute()

    print(f"Epoch {epoch+1}/{config_joint['n_epochs']}")
    print(f"Loss: {avg_loss:.4f}, BCE: {avg_bce:.4f}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f}")
    print(f"Val Loss: {avg_loss_val:.4f}, Val BCE: {avg_bce_val:.4f}, Val Accuracy: {accuracy_val:.4f}, Val Precision: {precision_val:.4f}, Val Recall: {recall_val:.4f}, Val F1 Score: {f1_score_val:.4f}")

    wandb.log({
        "epoch": epoch + 1,

        "contrastive_loss": avg_contrastive,
        "bce_loss": avg_bce,
        "loss": avg_loss,
        "val_contrastive_loss": avg_contrastive_val,
        "val_bce_loss": avg_bce_val,
        "val_loss": avg_loss_val,

        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "val_accuracy": accuracy_val,
        "val_precision": precision_val,
        "val_recall": recall_val,
        "val_f1_score": f1_score_val,
    })

    if early_stopper(avg_loss_val, [encoder_joint, classifier_joint]):
        break

wandb.finish()

In [None]:
# Load the best model from early stopping. The stopper was given
# [encoder_joint, classifier_joint], so its saved states come back in that order.
best_model_states = early_stopper.best_model_state
encoder_joint.load_state_dict(best_model_states[0])
classifier_joint.load_state_dict(best_model_states[1])

# Save both parts in one checkpoint; ensure the output directory exists first.
os.makedirs("./models", exist_ok=True)
torch.save({
    "encoder_state": encoder_joint.state_dict(),
    "classifier_state": classifier_joint.state_dict(),
}, "./models/encoder_classifier_joint.pth")

## Encoder + Classifier w/ Encoder Pre-training

This second approach uses the same architecture as above, but the two models are trained separately. First, the encoder is trained using the Supervised Contrastive Loss. Then this component is frozen, and a small classifier NN is trained on its embeddings.

In [None]:
# Hyperparameters for the two-stage setup: encoder pre-training (contrastive loss only),
# then classifier training on the frozen encoder. Each stage has its own epoch budget,
# learning rate and early-stopping patience.
config_pretr = dict(
    n_epochs_encoder=500,
    n_epochs_classifier=500,
    learning_rate_encoder=1e-3,
    learning_rate_classifier=1e-3,
    num_heads=2,
    num_layers=1,
    hidden_size=4,
    contrastive_temp=0.07,
    patience_encoder=20,
    patience_classifier=20,
    batch_size=batch_size,
)
pprint(config_pretr)

In [None]:
# Losses, metric tracker and classifier head for the two-stage setup. The head is
# trained later on embeddings from the pre-trained encoder.
metrics = metric.metrics()
bce_loss = torch.nn.BCEWithLogitsLoss()
contrastive_loss = training.SupConLoss(config_pretr["contrastive_temp"])
sigmoid = torch.nn.Sigmoid()

classifier = model.MLPClassifier(
    input_dim=config_pretr["hidden_size"],
    hidden_dim=config_pretr["hidden_size"],
    hidden_layers=2,
    output_dim=1,
).to(device)

# This optimizer only sees classifier parameters, so the encoder is never updated by it.
optimizer_classifier = torch.optim.Adam(
    params=classifier.parameters(),
    lr=config_pretr["learning_rate_classifier"],
)

In [None]:
# Stage 1 of the two-stage approach: build a fresh transformer encoder and train it
# with the supervised contrastive loss alone. Runs are logged to wandb, and early
# stopping monitors the validation contrastive loss.
encoder = model.TransformerEncoder(
    num_features=num_ftrs,
    embed_size=config_pretr["hidden_size"],
    num_heads=config_pretr["num_heads"],
    num_layers=config_pretr["num_layers"]
).to(device)

# Optimizer over encoder parameters only; the classifier head is trained separately later.
optimizer_encoder = torch.optim.Adam(
    encoder.parameters(),
    lr=config_pretr["learning_rate_encoder"]
)

# Encoder pretrain
# `timestamp` is also reused to name the later classifier run, pairing the two runs.
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
wandb.init(
    project="ml4phys_finalproj",
    name=f"run_encoderpretrain_{timestamp}",
    config=config_pretr
)

early_stopper = training.EarlyStopping(config_pretr["patience_encoder"])
for epoch in range(config_pretr["n_epochs_encoder"]):
    encoder.train()
    total_contrastive = 0.0
    for x, mask, labels in train_loader:
        x = x.to(device)
        mask = mask.to(device)
        labels = labels.to(device)
        latent = encoder(x, mask).to(device)

        # unsqueeze(1) adds a singleton axis, presumably the (batch, n_views, dim)
        # layout SupConLoss expects. TODO confirm against training.SupConLoss.
        loss_contrastive = contrastive_loss(latent.unsqueeze(1), labels)
        optimizer_encoder.zero_grad()
        loss_contrastive.backward()
        optimizer_encoder.step()
        # Weight by batch size so the epoch average is per event, not per batch.
        total_contrastive += loss_contrastive.item() * x.size(0)
        
    avg_contrastive = total_contrastive / len(train_ds)

    # Eval
    encoder.eval()
    total_contrastive_val = 0.0
    with torch.no_grad():
        for x_val, mask_val, labels_val in val_loader:
            x_val = x_val.to(device)
            mask_val = mask_val.to(device)
            labels_val = labels_val.to(device)

            latent_val = encoder(x_val, mask_val).to(device)
            loss_contrastive_val = contrastive_loss(latent_val.unsqueeze(1), labels_val)
            total_contrastive_val += loss_contrastive_val.item() * x_val.size(0)

    avg_contrastive_val = total_contrastive_val / len(val_ds)

    print(f"Epoch {epoch+1}/{config_pretr['n_epochs_encoder']}")
    print(f"Contrastive Loss: {avg_contrastive:.4f}, Contrastive Loss Val: {avg_contrastive_val:.4f}")

    wandb.log({
        "epoch": epoch + 1,
        "contrastive_loss": avg_contrastive,
        "val_contrastive_loss": avg_contrastive_val,
    })
    
    # Stop once the validation contrastive loss has not improved for `patience` epochs
    # (as implemented by training.EarlyStopping); it also keeps the best encoder state.
    if early_stopper(avg_contrastive_val, encoder):
        break

wandb.finish()

In [None]:
# Restore the encoder weights kept by early_stopper (presumably from the epoch with the
# lowest validation contrastive loss). The encoder is written to disk together with
# the classifier head once that head has been trained.
best_encoder_state = early_stopper.best_model_state
encoder.load_state_dict(best_encoder_state)

In [None]:
# Classifier train with pretrained encoder: stage 2 of the two-stage approach. The
# encoder stays frozen. Its forward pass runs under torch.no_grad(), so no autograd
# graph is built through it and its parameters never accumulate gradients; only
# `classifier` is optimised. Early stopping monitors validation BCE.
wandb.init(
    project="ml4phys_finalproj",
    name=f"run_classifier_wpretrainedencoder_{timestamp}",
    config=config_pretr
)

encoder.eval()
early_stopper = training.EarlyStopping(patience=config_pretr["patience_classifier"]) # New instance

for epoch in range(config_pretr["n_epochs_classifier"]):
    metrics.reset()
    classifier.train()
    total_bce = 0.0
    for x, mask, labels in train_loader:
        x = x.to(device)
        mask = mask.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            latent = encoder(x, mask)
        # reshape(-1) instead of squeeze(): squeeze() turns a final batch holding a
        # single event into a 0-d tensor whose shape no longer matches `labels`.
        logit = classifier(latent).reshape(-1)

        loss_bce = bce_loss(logit, labels)
        optimizer_classifier.zero_grad()
        loss_bce.backward()
        optimizer_classifier.step()
        # Weight by batch size so the epoch average is per event, not per batch.
        total_bce += loss_bce.item() * x.size(0)
        metrics.update(labels.cpu(), (sigmoid(logit) > 0.5).long().cpu())

    avg_bce = total_bce / len(train_ds)
    accuracy, precision, recall, f1_score = metrics.compute()

    # Eval
    metrics.reset()
    total_bce_val = 0.0
    classifier.eval()
    with torch.no_grad():
        for x_val, mask_val, labels_val in val_loader:
            x_val = x_val.to(device)
            mask_val = mask_val.to(device)
            labels_val = labels_val.to(device)
            latent_val = encoder(x_val, mask_val)
            logit_val = classifier(latent_val).reshape(-1)

            loss_bce_val = bce_loss(logit_val, labels_val)
            total_bce_val += loss_bce_val.item() * x_val.size(0)
            metrics.update(labels_val.cpu(), (sigmoid(logit_val) > 0.5).long().cpu())

    avg_bce_val = total_bce_val / len(val_ds)
    accuracy_val, precision_val, recall_val, f1_score_val = metrics.compute()

    print(f"Epoch {epoch+1}/{config_pretr['n_epochs_classifier']}")
    print(f"Loss: {avg_bce:.4f}, BCE: {avg_bce:.4f}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f}")
    print(f"Val Loss: {avg_bce_val:.4f}, Val BCE: {avg_bce_val:.4f}, Val Accuracy: {accuracy_val:.4f}, Val Precision: {precision_val:.4f}, Val Recall: {recall_val:.4f}, Val F1 Score: {f1_score_val:.4f}")

    wandb.log({
        "epoch": epoch + 1,

        "bce_loss": avg_bce,
        "val_bce_loss": avg_bce_val,

        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "val_accuracy": accuracy_val,
        "val_precision": precision_val,
        "val_recall": recall_val,
        "val_f1_score": f1_score_val,
    })

    if early_stopper(avg_bce_val, classifier):
        break

wandb.finish()

In [None]:
# Load the model from best checkpoint: restore the classifier weights kept by
# early_stopper (presumably from the epoch with the lowest validation BCE).
classifier.load_state_dict(early_stopper.best_model_state)

# Save the pre-trained encoder and its classifier head in one checkpoint; ensure the
# output directory exists first.
os.makedirs("./models", exist_ok=True)
torch.save({
    "encoder_state": encoder.state_dict(),
    "classifier_state": classifier.state_dict(),
}, "./models/encoder_classifier_pretrained.pth")

## Evaluation

The models are re-instantiated, and the saved states are loaded.

In [None]:
# Loading models: rebuild each architecture from its config and load the saved weights.
# map_location=device lets checkpoints written on a GPU be restored on a CPU-only
# machine (and vice versa), instead of failing during deserialisation.

# Baseline NN
classifier_nn = model.MLPClassifier(
    input_dim=len(features) * njets,
    hidden_dim=config_nn["hidden_size"],
    hidden_layers=2,
    output_dim=1
).to(device)
classifier_nn.load_state_dict(torch.load("./models/classifier_nn.pth", map_location=device))

# Joint architecture
checkpoint_joint = torch.load("./models/encoder_classifier_joint.pth", map_location=device)
encoder_joint = model.TransformerEncoder(
    num_features=num_ftrs,
    embed_size=config_joint["hidden_size"],
    num_heads=config_joint["num_heads"],
    num_layers=config_joint["num_layers"]
).to(device)
encoder_joint.load_state_dict(checkpoint_joint["encoder_state"])

classifier_joint = model.MLPClassifier(
    input_dim=config_joint["hidden_size"],
    hidden_dim=config_joint["hidden_size"],
    hidden_layers=2,
    output_dim=1
).to(device)
classifier_joint.load_state_dict(checkpoint_joint["classifier_state"])

# Pretrained encoder + classifier
checkpoint_pretr = torch.load("./models/encoder_classifier_pretrained.pth", map_location=device)
encoder_pretr = model.TransformerEncoder(
    num_features=num_ftrs,
    embed_size=config_pretr["hidden_size"],
    num_heads=config_pretr["num_heads"],
    num_layers=config_pretr["num_layers"]
).to(device)
encoder_pretr.load_state_dict(checkpoint_pretr["encoder_state"])

classifier_pretr = model.MLPClassifier(
    input_dim=config_pretr["hidden_size"],
    hidden_dim=config_pretr["hidden_size"],
    hidden_layers=2,
    output_dim=1
).to(device)
classifier_pretr.load_state_dict(checkpoint_pretr["classifier_state"])

sigmoid = torch.nn.Sigmoid()

We now evaluate the trained models on the previously untouched test dataset.

In [None]:
def get_embeddings_and_logits(encoder, classifier, dataloader):
    """Run encoder + classifier over `dataloader`, print classification metrics, and
    return the collected outputs.

    Predictions are thresholded at 0.5 on the sigmoid, exactly as in the training
    loops, so the printed metrics are comparable to the logged train/val metrics.

    Args:
        encoder: module called as ``encoder(x, mask)`` to produce per-event embeddings.
        classifier: module mapping embeddings to logits.
        dataloader: yields ``(x, mask, labels)`` batches.

    Returns:
        ``(embeddings, logits, labels)``: CPU tensors concatenated over all batches
        in loader order.
    """
    metrics = metric.metrics()
    encoder.eval()
    classifier.eval()
    all_embeddings, all_logits, all_labels = [], [], []
    with torch.no_grad():
        for x, mask, labels in dataloader:
            x = x.to(device)
            mask = mask.to(device)

            latent = encoder(x, mask)
            logit = classifier(latent)

            all_embeddings.append(latent.cpu())
            all_logits.append(logit.cpu())
            all_labels.append(labels.cpu())

            # reshape(-1) keeps a 1-D prediction vector even for a batch of one event.
            preds = (sigmoid(logit).reshape(-1) > 0.5).long().cpu()
            metrics.update(labels.cpu(), preds)

    accuracy, precision, recall, f1_score = metrics.compute()
    print(f"Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f}")
    return torch.cat(all_embeddings), torch.cat(all_logits), torch.cat(all_labels)

print("Performance metrics for jointly trained model...")
embeddings_joint, logits_joint, labels_joint = get_embeddings_and_logits(encoder_joint, classifier_joint, test_loader)
print("Performance metrics for pretrained encoder + classifier model...")
embeddings_pretr, logits_pretr, labels_pretr = get_embeddings_and_logits(encoder_pretr, classifier_pretr, test_loader)

In [None]:
# Get logits for baseline NN
def get_logits_nn(classifier, dataloader):
    """Collect baseline-MLP logits and labels over `dataloader`.

    Each batch's jet tensor is flattened per event before being passed to
    `classifier`; the mask yielded by the loader is ignored.

    Returns:
        ``(logits, labels)`` as CPU tensors concatenated in loader order.
    """
    classifier.eval()
    logit_chunks, label_chunks = [], []
    with torch.no_grad():
        for batch_x, _, batch_labels in dataloader:
            flat_x = batch_x.to(device).view(batch_x.size(0), -1)
            logit_chunks.append(classifier(flat_x).cpu())
            label_chunks.append(batch_labels.cpu())
    return torch.cat(logit_chunks), torch.cat(label_chunks)

logits_nn, labels_nn = get_logits_nn(classifier_nn, test_loader)

In [None]:
# Signal probabilities (sigmoid of the test-set logits) as NumPy arrays, one per model.
prob_joint, prob_pretr, prob_nn = (
    torch.sigmoid(model_logits).squeeze().numpy()
    for model_logits in (logits_joint, logits_pretr, logits_nn)
)

In [None]:
test_data_tensor = test_ds.data.to(device)
test_data_labels = test_ds.labels.to(device)

# The embeddings were gathered from test_loader (shuffle=False). Check that they are
# in dataset order before pairing them with the dataset labels.
assert np.array_equal(test_data_labels.cpu().numpy(), labels_joint.numpy()), "joint embeddings misaligned with test labels"
assert np.array_equal(test_data_labels.cpu().numpy(), labels_pretr.numpy()), "pretrained embeddings misaligned with test labels"

# Get PCA: project each model's embeddings onto principal components (fit separately).
pca = PCA(n_components=4)
pca_joint = pca.fit_transform(embeddings_joint.cpu().numpy())
pca_pretr = pca.fit_transform(embeddings_pretr.cpu().numpy())


def _embedding_frame(embeddings, pca_proj):
    """Build a DataFrame of test labels, raw embedding dimensions and PCA components.

    Columns are ``label`` ("EJ" for 1, "QCD" for 0), then ``embed_1..embed_D``
    (one per embedding dimension, so any ``hidden_size`` works), then
    ``PCA_1..PCA_K`` (one per PCA component).
    """
    emb = embeddings.cpu().numpy()
    columns = {"label": test_data_labels.cpu().numpy()}
    columns.update({f"embed_{j + 1}": emb[:, j] for j in range(emb.shape[1])})
    columns.update({f"PCA_{j + 1}": pca_proj[:, j] for j in range(pca_proj.shape[1])})
    frame = pd.DataFrame(columns)
    # Convert labels to str (EJ = 1, QCD = 0)
    frame["label"] = frame["label"].map({1: "EJ", 0: "QCD"})
    return frame


embeddings_joint_df = _embedding_frame(embeddings_joint, pca_joint)
embeddings_pretr_df = _embedding_frame(embeddings_pretr, pca_pretr)

In [None]:
# Pairwise scatter plots and per-component histograms of the four PCA components of
# the joint model's test-set embeddings, coloured by true class.
sns.pairplot(
    data=embeddings_joint_df,
    hue="label",
    vars=[f"PCA_{k}" for k in range(1, 5)],
    diag_kind="hist",
    diag_kws=dict(bins=30, alpha=0.5),
    plot_kws=dict(alpha=0.5),
)
plt.suptitle("Joint Model Embeddings PCA Pairplot", y=1.02)
plt.show()

In [None]:
# Same PCA pairplot for the embeddings of the pre-trained (frozen) encoder.
sns.pairplot(
    data=embeddings_pretr_df,
    hue="label",
    vars=[f"PCA_{k}" for k in range(1, 5)],
    diag_kind="hist",
    diag_kws=dict(bins=30, alpha=0.5),
    plot_kws=dict(alpha=0.5),
)
plt.suptitle("Pretrained Encoder Model Embeddings PCA Pairplot", y=1.02)
plt.show()

### ROC Curves

In [None]:
# ROC curves on the test set, drawn as background rejection (1 - FPR) against signal
# efficiency (TPR). Each legend entry carries the model's AUC.
fpr_joint, tpr_joint, _ = roc_curve(labels_joint.numpy(), torch.sigmoid(logits_joint).squeeze().numpy())
fpr_pretr, tpr_pretr, _ = roc_curve(labels_pretr.numpy(), torch.sigmoid(logits_pretr).squeeze().numpy())
fpr_nn, tpr_nn, _ = roc_curve(labels_nn.numpy(), torch.sigmoid(logits_nn).squeeze().numpy())

auc_joint, auc_pretr, auc_nn = (
    auc(fpr, tpr) for fpr, tpr in ((fpr_joint, tpr_joint), (fpr_pretr, tpr_pretr), (fpr_nn, tpr_nn))
)

fig, ax = plt.subplots()
for fpr, tpr, curve_name, curve_auc in (
    (fpr_joint, tpr_joint, "Joint", auc_joint),
    (fpr_pretr, tpr_pretr, "Pretrained", auc_pretr),
    (fpr_nn, tpr_nn, "NN Baseline", auc_nn),
):
    ax.plot(tpr, 1 - fpr, label=f"{curve_name} AUC: {curve_auc:.2f}")
# Random-guess reference: rejection = 1 - efficiency.
ax.plot([1, 0], [0, 1], linestyle="--", color="gray")
ax.set_xlabel("Signal Efficiency")
ax.set_ylabel("Background Rejection")
ax.set_title("ROC Curves on Test Set")
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Plot scores
def plot_scores(logits, labels, title, ylog=False, nbins=50):
    """Histogram classifier scores (sigmoid of `logits`) split by true class.

    Args:
        logits: CPU tensor of model logits.
        labels: CPU tensor of true labels (1 = EJ signal, 0 = QCD background).
        title: figure title.
        ylog: if true, use a logarithmic y-axis.
        nbins: number of bins over the score range [0, 1].
    """
    scores = torch.sigmoid(logits).squeeze().numpy()
    label_arr = labels.numpy()

    fig, ax = plt.subplots()
    for cls, cls_name in ((1, "EJ"), (0, "QCD")):
        ax.hist(scores[label_arr == cls], bins=nbins, range=(0, 1), density=False, label=cls_name, histtype="step", linewidth=1.5)
    ax.set_title(title)
    ax.set_xlabel("Classifier Score")
    ax.set_ylabel("Count")
    ax.set_yscale('log' if ylog else 'linear')
    ax.legend()
    ax.grid(True)
    plt.show()

# Linear-scale panels for all three models first, then the same panels in log scale.
for use_log in (False, True):
    for model_logits, model_labels, model_title in (
        (logits_joint, labels_joint, "Joint Model Classifier Scores"),
        (logits_pretr, labels_pretr, "Pretrained Encoder Model Classifier Scores"),
        (logits_nn, labels_nn, "NN Baseline Classifier Scores"),
    ):
        plot_scores(model_logits, model_labels, model_title, ylog=use_log, nbins=25)