# PHYS805 Final Project Notebook

- Overview:
    - Supervised
        - Use transformer over jets
        - Attach classifier head
        - Train end-to-end to predict signal vs. background using a binary cross-entropy (BCE) loss
    - Self-supervised + supervised probe
        - Pretrain: 
            - First train using InfoNCE. 
            - No labels, no pretext objective. 
            - Just finding good representation. 
            - Would NOT use labels here.
        - Freeze the transformer encoder, then train a small classification head to predict signal vs. background. Labels are used in this step.

<!-- ```
ml4phys$ eosls /store/group/lpcemj/EMJAnalysis2025
QCD_PT-1000to1400_TuneCP5_13p6TeV_pythia8
QCD_PT-100to1400_TuneCP5_13p6TeV_pythia8
QCD_PT-120to170_TuneCP5_13p6TeV_pythia8
QCD_PT-1400to1800_TuneCP5_13p6TeV_pythia8
QCD_PT-15to30_TuneCP5_13p6TeV_pythia8
QCD_PT-170to300_TuneCP5_13p6TeV_pythia8
QCD_PT-1800to2400_TuneCP5_13p6TeV_pythia8
QCD_PT-2400to3200_TuneCP5_13p6TeV_pythia8
QCD_PT-300to470_TuneCP5_13p6TeV_pythia8
QCD_PT-30to50_TuneCP5_13p6TeV_pythia8
QCD_PT-3200_TuneCP5_13p6TeV_pythia8
QCD_PT-470to600_TuneCP5_13p6TeV_pythia8
QCD_PT-50to80_TuneCP5_13p6TeV_pythia8
QCD_PT-600to800_TuneCP5_13p6TeV_pythia8
QCD_PT-800to1000_TuneCP5_13p6TeV_pythia8
QCD_PT-80to120_TuneCP5_13p6TeV_pythia8
``` -->

In [None]:
import awkward as ak
import numpy as np
import matplotlib.pyplot as plt
import yaml
import torch
import wandb
from sklearn.model_selection import train_test_split
import mplhep as hep
import datetime
from pprint import pprint
plt.style.use(hep.style.CMS)

# Select the compute device once; later cells move tensors and models onto it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# auto reload of imported modules
%load_ext autoreload
%autoreload 2

from utils import data_utils
from utils import dataloader
from utils import metrics
from utils import model
from utils import training

In [None]:
# Dataset configuration: sample definitions later passed to data_utils.load_data.
with open("datasets.yaml", mode="r") as cfg_file:
    ds_cfg = yaml.safe_load(cfg_file)

# Per-jet input features for the model. Order presumably matches the column order of the
# tensors built by data_utils.ak_to_torch (the PU-correlation cell indexes them by position).
features = [
    "Jet_pt",
    "Jet_eta",
    "Jet_phi",
    "Jet_mass",
    "Jet_nConstituents",
    "Jet_nSVs",
    "Jet_area",
]
# Event-level branches loaded alongside the features; used for plots, not as model inputs.
other_branches = ["nJet", "Pileup_nPU"]
branches = [*features, *other_branches]
num_ftrs = len(features)

# Fraction of events held out from training (the held-out set is split again below).
test_split = 0.2
# Fraction of the *held-out* set that goes to the TEST split (passed as test_size);
# the remainder becomes validation.
val_split = 0.5
# Number of jets per event kept when building the jet tensor (passed to ak_to_torch).
njets = 10

In [None]:
# Load signal (EMJ) and background (QCD) events and build one shuffled tensor with labels.
max_events_per_class = 20_000  # passed to load_data as entry_stop
shuffle_seed = 42  # fixed seed so the event shuffle (and therefore the splits) is reproducible

sig = data_utils.load_data(ds_cfg, "EMJ", filter_name=branches, entry_stop=max_events_per_class)
bkg = data_utils.load_data(ds_cfg, "QCD", filter_name=branches, entry_stop=max_events_per_class)

# Event-level pileup, concatenated in the same signal-then-background order as data_tensor.
nPU = torch.tensor(
    np.concatenate([
        ak.to_numpy(sig["Pileup_nPU"]),
        ak.to_numpy(bkg["Pileup_nPU"]),
    ]),
    dtype=torch.float32,
)

sig_tensor = data_utils.ak_to_torch(sig, features, njets, label=1)
bkg_tensor = data_utils.ak_to_torch(bkg, features, njets, label=0)
data_tensor = torch.cat([sig_tensor, bkg_tensor], dim=0)

# nPU is matched to data_tensor purely by row position. If ak_to_torch ever dropped or added
# events, the two would silently fall out of sync, so check before shuffling both together.
assert data_tensor.size(0) == nPU.size(0), (
    f"event count mismatch: data_tensor has {data_tensor.size(0)} rows, nPU has {nPU.size(0)}"
)

rnd_idx = torch.randperm(
    data_tensor.size(0),
    generator=torch.Generator().manual_seed(shuffle_seed),
)
data_tensor = data_tensor[rnd_idx]
nPU = nPU[rnd_idx]

In [None]:
# Split into train / held-out, then split the held-out set into val / test.
# Stratify on the label so signal and background keep the same proportions in every split,
# and fix random_state so the split is reproducible across kernel restarts.
split_seed = 42
# The label sits in the last column; it is read from the leading-jet slot
# (assumes ak_to_torch writes the event label into every jet row — TODO confirm).
all_labels = data_tensor[:, 0, -1]

X_train, X_temp, y_train, y_temp = train_test_split(
    data_tensor[..., :-1],
    all_labels,
    test_size=test_split,
    shuffle=True,
    stratify=all_labels,
    random_state=split_seed,
)

# val_split is the fraction of the held-out set assigned to the TEST split.
X_val, X_test, y_val, y_test = train_test_split(
    X_temp,
    y_temp,
    test_size=val_split,
    stratify=y_temp,
    random_state=split_seed,
)

print(y_train.shape, y_val.shape, y_test.shape)
print(X_train.shape, X_val.shape, X_test.shape)

In [None]:
# Normalisation constants come from the training split only, so val/test statistics do not leak in.
norm_constants = data_utils.compute_norm_constants(X_train)
train_ds, val_ds, test_ds = (
    dataloader.JetDataset(split_X, norm_constants, split_y)
    for split_X, split_y in ((X_train, y_train), (X_val, y_val), (X_test, y_test))
)

## Feature Study

In [None]:
# Plot the jet multiplicity (nJet) per event for signal and background, unit-normalised.
# Half-integer bin edges centre each integer multiplicity in its own bin. With integer edges
# (np.arange(0, 15)) every value sits on a bin edge, and numpy's closed last bin [13, 14]
# merged 13- and 14-jet events into one bin.
max_njets_shown = 14
njet_bins = np.arange(-0.5, max_njets_shown + 1.5)  # edges -0.5, 0.5, ..., 14.5

fig, ax = plt.subplots()
sig_njets = ak.to_numpy(sig['nJet'])
bkg_njets = ak.to_numpy(bkg['nJet'])
ax.hist(sig_njets, bins=njet_bins, density=True, label='Signal (EMJ)', histtype='step')
ax.hist(bkg_njets, bins=njet_bins, density=True, label='Background (QCD)', histtype='step')
ax.set_xlabel('Number of Jets')
ax.set_ylabel('A.U.')
ax.set_title('Jet Multiplicity Distribution')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Per-event Pileup_nPU for signal vs. background, unit-normalised, sharing one binning/range.
bins = 20
x_min, x_max = 0, nPU.max().item() + 1
sig_pu = ak.to_numpy(sig['Pileup_nPU'])
bkg_pu = ak.to_numpy(bkg['Pileup_nPU'])

fig, ax = plt.subplots()
for pu_values, pu_label in ((sig_pu, 'Signal (EMJ)'), (bkg_pu, 'Background (QCD)')):
    ax.hist(pu_values, bins=bins, range=(x_min, x_max), density=True, label=pu_label, histtype='step')
ax.set(xlabel='Number of Pileup', ylabel='A.U.', title='Pileup Distribution')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Select the events whose per-jet features are plotted below (awkward-array boolean masks).
# NOTE(review): `>` keeps only events with MORE than `njets` jets (i.e. at least njets + 1).
# If the intent was "events that fill all njets slots of the jet tensor", this should be `>=` — confirm.
sig_njets = ak.to_numpy(sig['nJet'])
bkg_njets = ak.to_numpy(bkg['nJet'])
sig_jets, bkg_jets = sig[sig_njets > njets], bkg[bkg_njets > njets]

def plot_ftrs(sig, bkg, ftr_name, fig_title, nbins, xrange, xlabel, ylabel, log, normalized=True):
    """Overlay the distribution of one branch for signal (EMJ) and background (QCD).

    Args:
        sig, bkg: awkward event arrays containing the branch `ftr_name`; the branch is
            flattened one level with ak.flatten, so it is expected to be nested (e.g. per-jet).
        ftr_name: branch name, e.g. "Jet_pt".
        fig_title, xlabel, ylabel: figure title and axis labels.
        nbins: number of histogram bins.
        xrange: (low, high) histogram range; entries outside it are not drawn.
        log: if truthy use a log y-axis, otherwise a linear one.
        normalized: forwarded to `density`; True normalises each histogram to unit area.
    """
    per_sample = ((sig[ftr_name], 'EMJ'), (bkg[ftr_name], 'QCD'))

    fig, ax = plt.subplots()
    for values, sample_label in per_sample:
        ax.hist(
            ak.flatten(values),
            bins=nbins,
            range=xrange,
            density=normalized,
            alpha=0.5,
            label=sample_label,
            histtype='step',
            linewidth=1.5,
        )
    ax.set_title(fig_title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_yscale('log' if log else 'linear')
    ax.legend()
    ax.grid(True)
    plt.show()

In [None]:
# Per-jet feature distributions, signal vs. background.
# - plot_ftrs defaults to density=True, so the y-axis is normalised ("A.U."), not a raw count.
# - LaTeX labels use raw strings: "\eta" / "\phi" in a normal string are invalid Python
#   escape sequences (SyntaxWarning on recent Pythons).
# - The phi range is the full (-pi, pi]; (-3.14, 3.14) silently dropped jets with |phi| > 3.14.
feature_plot_specs = [
    # (branch, title, nbins, xrange, xlabel, log y-axis)
    ("Jet_pt", r"Jet $p_T$ Distribution", 20, (0, 2000), r"Jet $p_T$ [GeV]", True),
    ("Jet_eta", r"Jet $\eta$ Distribution", 20, (-5, 5), r"Jet $\eta$", False),
    ("Jet_phi", r"Jet $\phi$ Distribution", 20, (-np.pi, np.pi), r"Jet $\phi$", False),
    ("Jet_mass", "Jet Mass Distribution", 20, (0, 250), "Jet Mass [GeV]", True),
    ("Jet_nConstituents", "Jet Number of Constituents Distribution", 20, (0, 20), "Number of Constituents", True),
    ("Jet_nSVs", "Jet Number of Secondary Vertices Distribution", 20, (0, 20), "Number of Secondary Vertices", True),
    ("Jet_area", "Jet Area Distribution", 30, (0, 1), "Jet Area", True),
]
for branch, title, nbins, xrange, xlabel, log_y in feature_plot_specs:
    plot_ftrs(sig_jets, bkg_jets, branch, title, nbins, xrange, xlabel, "A.U.", log=log_y)

In [None]:
# Pearson correlation of each leading-jet feature with per-event pileup (nPU), via torch.corrcoef.
# Row 0 of the stacked 2 x n_events matrix is the feature, row 1 is nPU, so [0, 1] is their correlation.
lead_jets = data_tensor[:, 0, :]
corr_coefs = {
    name: torch.corrcoef(torch.stack((lead_jets[:, col], nPU)))[0, 1].item()
    for col, name in enumerate(features)
}
for name, coef in corr_coefs.items():
    print(f"Correlation coefficient between {name} and nPU: {coef:.4f}")

## Model Instantiation & Training

In [None]:
# Hyperparameters (also passed to wandb.init in the training cell).
config = {
    "n_epochs": 20,
    "learning_rate": 1e-3,
    "batch_size": 512,
    "num_heads": 2,
    "num_layers": 1,
    "hidden_size": 4,
    "infonce_temp": 0.07,
    "beta": 0.5,  # weight of BCE in the total loss; InfoNCE gets (1 - beta)
}
pprint(config)

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True)
val_loader = torch.utils.data.DataLoader(val_ds, batch_size=config["batch_size"], shuffle=False)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=config["batch_size"], shuffle=False)

encoder = model.TransformerEncoder(
    num_features=num_ftrs,
    embed_size=config["hidden_size"],
    num_heads=config["num_heads"],
    num_layers=config["num_layers"],
).to(device)

classifier = model.MLPClassifier(
    input_dim=config["hidden_size"],
    hidden_dim=config["hidden_size"],
    hidden_layers=2,
    output_dim=1,
).to(device)

bce_loss = torch.nn.BCEWithLogitsLoss()
infonce_loss = training.InfoNCELoss(config["infonce_temp"])
sigmoid = torch.nn.Sigmoid()
# One optimizer over both modules: encoder and classifier are trained jointly.
optimizer = torch.optim.Adam(
    [*encoder.parameters(), *classifier.parameters()],
    lr=config["learning_rate"],
)

# `metrics` is rebound below to a tracker *instance* (later cells call reset/update/compute on it).
# Reach the module through a separate alias. Writing `metrics = metrics.metrics()` shadows the
# module, so re-running this cell failed with AttributeError.
from utils import metrics as metrics_lib
metrics = metrics_lib.metrics()

print("")
print("Classifier arch:")
pprint(classifier)
print("")
print("Encoder arch:")
pprint(encoder)

In [None]:
# Train loop: joint objective beta * BCE + (1 - beta) * InfoNCE, with a validation pass each epoch.

def run_epoch(loader, n_samples, train):
    """Run one full pass over `loader`; return averaged losses and classification metrics.

    Args:
        loader: DataLoader yielding (x, mask, labels) batches.
        n_samples: number of samples in the loader's dataset; per-sample losses are averaged over it.
        train: if True, put the models in training mode and take one optimizer step per batch.
            If False, evaluate in eval mode with gradients disabled.

    Returns:
        (avg_loss, avg_bce, avg_infonce, accuracy, precision, recall, f1_score)
    """
    metrics.reset()
    encoder.train(train)
    classifier.train(train)

    total_loss = total_bce = total_infonce = 0.0
    with torch.set_grad_enabled(train):
        for x, mask, labels in loader:
            x = x.to(device)
            mask = mask.to(device)
            labels = labels.to(device)

            latent = encoder(x, mask)
            # squeeze(-1) rather than squeeze(): assumes the classifier returns (batch, 1)
            # (output_dim=1). A bare squeeze() would also drop the batch dim for a final
            # batch of size 1 and break the BCE input/target shapes.
            logit = classifier(latent).squeeze(-1)

            loss_infonce = infonce_loss(latent, labels)
            loss_bce = bce_loss(logit, labels)
            loss = config["beta"] * loss_bce + (1 - config["beta"]) * loss_infonce

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Update metrics inside the batch loop so they cover the whole split,
            # not just the final batch.
            metrics.update(labels.cpu(), (sigmoid(logit) > 0.5).long().cpu())

            batch_size = x.size(0)
            total_loss += loss.item() * batch_size
            total_bce += loss_bce.item() * batch_size
            total_infonce += loss_infonce.item() * batch_size

    accuracy, precision, recall, f1_score = metrics.compute()
    return (
        total_loss / n_samples,
        total_bce / n_samples,
        total_infonce / n_samples,
        accuracy,
        precision,
        recall,
        f1_score,
    )


timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
wandb.init(
    project="ml4phys_finalproj",
    name=f"run_encoderclassifier_{timestamp}",
    config=config,
)

try:
    for epoch in range(config["n_epochs"]):
        (avg_loss, avg_bce, avg_infonce,
         accuracy, precision, recall, f1_score) = run_epoch(train_loader, len(train_ds), train=True)
        (avg_loss_val, avg_bce_val, avg_infonce_val,
         accuracy_val, precision_val, recall_val, f1_score_val) = run_epoch(val_loader, len(val_ds), train=False)

        print(f"Epoch {epoch+1}/{config['n_epochs']}")
        print(f"Loss: {avg_loss:.4f}, BCE: {avg_bce:.4f}, InfoNCE: {avg_infonce:.4f}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f}")
        print(f"Val Loss: {avg_loss_val:.4f}, Val BCE: {avg_bce_val:.4f}, Val InfoNCE: {avg_infonce_val:.4f}, Val Accuracy: {accuracy_val:.4f}, Val Precision: {precision_val:.4f}, Val Recall: {recall_val:.4f}, Val F1 Score: {f1_score_val:.4f}")

        wandb.log({
            "epoch": epoch + 1,
            "loss": avg_loss,
            "bce_loss": avg_bce,
            "infonce_loss": avg_infonce,
            "accuracy": accuracy,
            "val_loss": avg_loss_val,
            "val_bce_loss": avg_bce_val,
            "val_infonce_loss": avg_infonce_val,
            "val_accuracy": accuracy_val,
            "precision": precision,
            "recall": recall,
            "f1_score": f1_score,
            "val_precision": precision_val,
            "val_recall": recall_val,
            "val_f1_score": f1_score_val,
        })
finally:
    # Close the wandb run even if training raises, so a failed run is not left open.
    wandb.finish()

In [None]:
# Run evaluation on the held-out test split; collect labels and probabilities for the ROC/score plots.
true_labels = []
predicted_probs = []

metrics.reset()
encoder.eval()
classifier.eval()

with torch.no_grad():
    for x_test, mask_test, label_test in test_loader:
        x_test = x_test.to(device)
        mask_test = mask_test.to(device)

        latent_test = encoder(x_test, mask_test)
        # squeeze(-1) keeps a 1-D (batch,) tensor even for a final batch of size 1 (assumes the
        # classifier returns (batch, 1)). A bare squeeze() gives a 0-d tensor there; its
        # .tolist() is a plain float, and list.extend(float) raises TypeError.
        probs = sigmoid(classifier(latent_test)).squeeze(-1).cpu()
        preds = (probs > 0.5).long()

        label_test = label_test.cpu()
        metrics.update(label_test, preds)
        true_labels.extend(label_test.tolist())
        predicted_probs.extend(probs.tolist())

accuracy_test, precision_test, recall_test, f1_score_test = metrics.compute()
print(f"Test Accuracy: {accuracy_test:.4f}, Test Precision: {precision_test:.4f}, Test Recall: {recall_test:.4f}, Test F1 Score: {f1_score_test:.4f}")

In [None]:
# ROC curve and classifier-score distributions on the test split.
# The name `utils` was never bound: the import cell only does `from utils import ...`, which
# binds the submodules, not the package, so this cell raised NameError on a fresh kernel.
# NOTE(review): assumes plot_roc / plot_scores are exposed by utils/__init__.py — confirm.
import utils

sig_scores = [prob for prob, label in zip(predicted_probs, true_labels) if label == 1]
bkg_scores = [prob for prob, label in zip(predicted_probs, true_labels) if label == 0]

utils.plot_roc(true_labels, predicted_probs)
utils.plot_scores(
    sig_scores=sig_scores,
    bkg_scores=bkg_scores,
    bins=50,
    range=(0, 1),
    logscale=True,
)

---

Scratch code (exploratory; not part of the main analysis)

In [None]:
# Per-nPU-bin correlation between each leading-jet feature and nPU.
# Undefined correlations are stored as NaN, which imshow leaves blank. This covers both bins
# with < 2 events and bins where the feature or nPU has zero variance, where corrcoef itself
# returns NaN. The old code plotted the first case as 0.0, which looks like "no correlation".
# NOTE(review): if nPU is integer-valued, narrow bins often hold a single nPU value, so the
# in-bin correlation is undefined there — consider wider bins.
n_pu_bins = 50
PU_bins = torch.linspace(0, nPU.max().item() + 1, steps=n_pu_bins + 1)
lead_jet_ftrs = data_tensor[:, 0, :]

per_feature_bin_corrs = []  # eventual shape: (num_features, num_bins)
for i, ftr_name in enumerate(features):
    bin_corrs = []
    for lo, hi in zip(PU_bins[:-1], PU_bins[1:]):
        bin_mask = (nPU >= lo) & (nPU < hi)
        if bin_mask.sum() < 2:
            bin_corrs.append(float("nan"))
            continue
        corr_matrix = torch.corrcoef(torch.stack((lead_jet_ftrs[bin_mask, i], nPU[bin_mask]), dim=0))
        bin_corrs.append(corr_matrix[0, 1].item())
    per_feature_bin_corrs.append(bin_corrs)
per_feature_bin_corrs = torch.tensor(per_feature_bin_corrs)
print(per_feature_bin_corrs.shape)

# Heat map: x-axis = nPU bins, y-axis = features, colour = correlation coefficient.
# TODO: combine into one figure (top = nPU histogram, bottom = heatmap).
fig, ax = plt.subplots(figsize=(10, 6))
cax = ax.imshow(per_feature_bin_corrs.numpy(), aspect='auto', cmap='coolwarm', vmin=-1, vmax=1)
ax.set_xticks(np.arange(n_pu_bins))
ax.set_xticklabels(
    [f"{float(lo):.1f}-{float(hi):.1f}" for lo, hi in zip(PU_bins[:-1], PU_bins[1:])],
    rotation=45,
)
ax.set_yticks(np.arange(len(features)))
ax.set_yticklabels([ftr.replace("Jet_", "") for ftr in features])
ax.set_xlabel('nPU Bins')
ax.set_ylabel('Features')
ax.set_title('Feature Correlation with Pileup across nPU Bins')
fig.colorbar(cax, ax=ax, label='Correlation Coefficient')
plt.show()

# nPU histogram on the same binning as the heat map.
fig_pu, ax_pu = plt.subplots()
ax_pu.hist(nPU.numpy(), bins=PU_bins.numpy(), alpha=0.7, color='blue', histtype='step')
ax_pu.set_xlabel('Number of Pileup (nPU)')
ax_pu.set_ylabel('Frequency')
ax_pu.set_title('Histogram of nPU')
plt.show()