# Classification power and AD with varying jet $p_T$ thresholds

In this notebook we take a preprocessed file and test how well the data can be classified under varying minimum jet $p_T$ thresholds. The plots are stored in a folder in Google Drive.

In [1]:
# Standard library
import os
import math
import json
import pickle
from datetime import datetime

# Third-party scientific stack
import numpy as np
import pandas as pd
import itertools

# Plotting
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Graph utilities
import networkx as nx

# Scikit-learn
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA

# Dimensionality reduction
# import umap

# Colab drive (only needed if you actually run in Google Colab)
# from google.colab import drive
# drive.mount('/content/drive')


All PyTorch-related imports are done in the section "GNN: Supervised Classification".

In [2]:
# Set plotting style at module level
# These defaults apply to every figure created later in the notebook:
# large fonts, frameless legends, no grid, and inward-pointing major and
# minor ticks drawn on all four sides of the axes.
plt.rcParams.update({
    # Font sizes
    'font.size': 18,
    'axes.labelsize': 18,
    'axes.titlesize': 18,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'legend.fontsize': 16,
    'legend.frameon': False,  # No box around legend
    'axes.grid': False,
    # Tick settings
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'xtick.major.size': 10,
    'ytick.major.size': 10,
    'xtick.minor.size': 5,
    'ytick.minor.size': 5,
    'xtick.major.width': 1,
    'ytick.major.width': 1,
    'xtick.top': True,
    'ytick.right': True,
    'xtick.minor.visible': True,
    'ytick.minor.visible': True
})

In [3]:
# Load the dictionary of preprocessed dataframes (one entry per sample name).
# NOTE(review): pickle.load can execute arbitrary code stored in the file, so
# only load this pickle from a source you trust.
with open("/kaggle/input/balanced-with-dr-or/balanced_dfs_no_dup_OR.pkl", "rb") as f:
    ML_dict = pickle.load(f)


## Details of the data
ML_dict is a dictionary with the dataframes

```
'all_signals', 'HAHMggf', 'HNLeemu', 'HtoSUEP',
'VBF_H125_a55a55_4b_ctau1_filtered', 'Znunu',
'ggF_H125_a16a16_4b_ctau10_filtered', 'hh_bbbb_vbf_novhh_5fs_l1cvv1cv1'
```
All of them have the same columns:

```
'j0pt', 'j0eta', 'j0phi', 'j1pt', 'j1eta', 'j1phi', 'j2pt', 'j2eta',
       'j2phi', 'j3pt', 'j3eta', 'j3phi', 'j4pt', 'j4eta', 'j4phi', 'j5pt',
       'j5eta', 'j5phi', 'e0pt', 'e0eta', 'e0phi', 'e1pt', 'e1eta', 'e1phi',
       'e2pt', 'e2eta', 'e2phi', 'mu0pt', 'mu0eta', 'mu0phi', 'mu1pt',
       'mu1eta', 'mu1phi', 'mu2pt', 'mu2eta', 'mu2phi', 'ph0pt', 'ph0eta',
       'ph0phi', 'ph1pt', 'ph1eta', 'ph1phi', 'ph2pt', 'ph2eta', 'ph2phi',
       'METpt', 'METeta', 'METphi', 'run_number', 'event_number', 'weight',
       'target'
```
When loaded with `balanced_dfs_no_dup_processed.pkl`, the dataframes contain only events with no duplicate objects. Events with undefined METpt have been removed, as have events in which every object has zero pt. Each dataframe contains equal numbers of signal and background events; background events are the ones with `target == 'EB_test'`.

When loaded with `balanced_dfs_no_dup_OR.pkl`, in addition to the above mentioned processing, overlap removal has also been performed.

## Random Forest

In [None]:
# Consistent style
# NOTE: these two settings override the global rcParams set at the top of the
# notebook (font.size is 18 there) for every figure created after this cell.
plt.rcParams['figure.figsize'] = (8,6)
plt.rcParams['font.size'] = 12

# Jet-pt thresholds you want to test
# (in GeV; the scan trains one classifier per value)
JET_PT_THRESHOLDS = [5, 15, 25, 45, 60, 80]

# Which jet columns to use
# (read by apply_jet_pt_threshold when it builds the selection mask)
JET_PT_COLS = [f"j{i}pt" for i in range(6)]   # j0pt ... j5pt


In [None]:
# Root folder for every plot written by the Random Forest section; one
# sub-folder per dataset is created underneath it.
# NOTE(review): this is a Google Drive path as mounted in Colab, but the drive
# mount in the imports cell is commented out and the input data is read from
# /kaggle/input — confirm this is the intended location for the platform used.
PLOT_BASE_DIR = "/content/drive/MyDrive/Datasets/plots_with_OR"
os.makedirs(PLOT_BASE_DIR, exist_ok=True)

### Define all the necessary helper functions

In [None]:
def apply_jet_pt_threshold(df, threshold, jet_cols=None):
    """
    Keep only the events in which every non-empty jet slot passes a pT cut.

    A jet slot with pt == 0 is treated as "no jet" and never vetoes an event,
    so the per-jet condition is ``(jXpt == 0) or (jXpt >= threshold)``.

    Parameters
    ----------
    df : pd.DataFrame
        Event table. Jet-pT columns that are not present are ignored.
    threshold : float
        Minimum pT required of every non-zero jet.
    jet_cols : iterable of str, optional
        Names of the jet-pT columns to test. Defaults to the notebook-level
        ``JET_PT_COLS`` so existing two-argument calls behave as before.

    Returns
    -------
    pd.DataFrame
        The rows of ``df`` that pass the selection, with their original index.
    """
    if jet_cols is None:
        jet_cols = JET_PT_COLS

    keep = np.ones(len(df), dtype=bool)
    for col in jet_cols:
        if col in df.columns:
            # Work on the raw values so the mask is purely positional.
            pt = df[col].to_numpy()
            keep &= (pt == 0) | (pt >= threshold)
    return df[keep]


In [None]:
def prepare_dataset(df):
    """
    Build the feature matrix and binary labels, then split them once.

    Labels follow the convention used in the GNN section of this notebook:
    rows with ``target == "EB_test"`` are background (0) and every other row
    is signal (1).

    Parameters
    ----------
    df : pd.DataFrame
        Event table containing a ``target`` column.

    Returns
    -------
    X_train, X_test : pd.DataFrame
        Feature tables (70% / 30% of the events).
    y_train, y_test : pd.Series
        Matching 0/1 labels, sharing the index of the feature tables.
    feature_cols : list of str
        Names of the columns used as features, in order.
    """
    # Features = all physics columns except bookkeeping
    drop_cols = ["run_number", "event_number", "target", "weight"]
    feature_cols = [c for c in df.columns if c not in drop_cols]

    X = df[feature_cols].copy()
    # EB_test is the background sample, so "signal" is everything else.
    y = (df["target"] != "EB_test").astype(int)

    # Single stratified split, reused for all thresholds; the fixed seed keeps
    # the split identical from run to run.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, stratify=y, random_state=42
    )

    return X_train, X_test, y_train, y_test, feature_cols


In [None]:
def train_and_evaluate_rf(X_train, y_train, X_test, y_test):
    """Fit a 300-tree Random Forest and score it on the test split.

    Returns the fitted model, followed by the ROC curve (fpr, tpr) and its AUC
    computed from the second column of ``predict_proba``.
    """
    forest_settings = dict(
        n_estimators=300,
        max_depth=None,
        min_samples_split=2,
        random_state=42,
        n_jobs=-1,
    )
    forest = RandomForestClassifier(**forest_settings)
    forest.fit(X_train, y_train)

    # Second probability column = class 1 when both classes were seen in training
    class1_prob = forest.predict_proba(X_test)[:, 1]
    false_pos_rate, true_pos_rate, _ = roc_curve(y_test, class1_prob)
    area = auc(false_pos_rate, true_pos_rate)

    return forest, false_pos_rate, true_pos_rate, area


In [None]:
def plot_roc_curves(roc_results, dataset_name, save_dir):
    """Overlay one ROC curve per threshold and save the figure as roc_curves.png.

    ``roc_results`` maps each threshold to a ``(fpr, tpr, auc)`` tuple.
    """
    fig, ax = plt.subplots()
    for cut, curve in roc_results.items():
        false_pos, true_pos, area = curve
        ax.plot(false_pos, true_pos, label=f"T={cut} GeV (AUC={area:.3f})")

    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title(f"ROC Curves — {dataset_name}")
    ax.legend()
    ax.grid(True)

    target_path = os.path.join(save_dir, "roc_curves.png")
    fig.savefig(target_path, dpi=200, bbox_inches="tight")
    plt.close(fig)


In [None]:
def plot_auc_vs_threshold(roc_results, dataset_name, save_dir):
    """Plot AUC against the jet-pT threshold and save it as auc_vs_threshold.png.

    ``roc_results`` maps each threshold to a tuple whose third entry is the AUC.
    """
    cut_values = list(roc_results)
    auc_values = [roc_results[cut][2] for cut in cut_values]

    fig, ax = plt.subplots()
    ax.plot(cut_values, auc_values, marker="o")
    ax.set_xlabel("Jet $p_T$ Threshold [GeV]")
    ax.set_ylabel("AUC")
    ax.set_title(f"AUC vs Jet $p_T$ Threshold — {dataset_name}")
    ax.grid(True)

    target_path = os.path.join(save_dir, "auc_vs_threshold.png")
    fig.savefig(target_path, dpi=200, bbox_inches="tight")
    plt.close(fig)


In [None]:
def plot_event_yields(counts_sig, counts_bkg, dataset_name, save_dir):
    """Plot the number of selected events per threshold; saved as event_yields.png.

    Both arguments are dicts mapping threshold -> event count.
    """
    fig, ax = plt.subplots()

    # (counts, marker, legend label) for each curve, drawn in this order
    curves = (
        (counts_sig, "o", "Signal"),
        (counts_bkg, "s", "Background"),
    )
    for counts, marker, label in curves:
        ax.plot(list(counts.keys()), list(counts.values()),
                marker=marker, label=label)

    ax.set_xlabel("Jet $p_T$ Threshold [GeV]")
    ax.set_ylabel("Events passing selection")
    ax.set_title(f"Event Yields — {dataset_name}")
    ax.legend()
    ax.grid(True)

    target_path = os.path.join(save_dir, "event_yields.png")
    fig.savefig(target_path, dpi=200, bbox_inches="tight")
    plt.close(fig)


In [None]:
def compare_feature_importances(
        rf_low, rf_high, feature_cols, dataset_name, save_dir,
        low_T=15, high_T=60, topN=12):
    """
    Compare the feature importances of two fitted forests side by side.

    The ``topN`` features are chosen and ordered by their importance in
    ``rf_high``; for each of them the importance in both models is drawn as a
    pair of adjacent horizontal bars, so neither bar hides the other.
    The figure is saved as ``feature_importances.png`` in ``save_dir``.

    Parameters
    ----------
    rf_low, rf_high : fitted estimators exposing ``feature_importances_``
        Models trained with the low and the high threshold.
    feature_cols : sequence of str
        Feature names, in the column order the models were trained on.
    dataset_name : str
        Used in the plot title.
    save_dir : str
        Directory the figure is written to.
    low_T, high_T : number
        Threshold values shown in the legend labels.
    topN : int
        Number of features to display.
    """
    importances_low = np.asarray(rf_low.feature_importances_)
    importances_high = np.asarray(rf_high.feature_importances_)

    # Rank by the high-threshold model, most important first
    idx = np.argsort(importances_high)[::-1][:topN]
    names = [feature_cols[i] for i in idx]

    # Numeric y positions let the two series sit next to each other instead of
    # being drawn on top of one another at the same categorical position.
    y_pos = np.arange(len(idx))
    bar_height = 0.4

    fig, ax = plt.subplots(figsize=(9, 6))
    ax.barh(
        y_pos - bar_height / 2,
        importances_high[idx],
        height=bar_height,
        alpha=0.7,
        label=f"T={high_T} GeV"
    )
    ax.barh(
        y_pos + bar_height / 2,
        importances_low[idx],
        height=bar_height,
        alpha=0.7,
        label=f"T={low_T} GeV"
    )

    ax.set_yticks(y_pos)
    ax.set_yticklabels(names)
    ax.invert_yaxis()  # most important feature at the top
    ax.set_xlabel("Feature Importance")
    ax.set_title(f"Feature Importance Comparison — {dataset_name}")
    ax.legend()

    fig.savefig(os.path.join(save_dir, "feature_importances.png"), dpi=200, bbox_inches="tight")
    plt.close(fig)


### Main function

In [None]:
def run_full_analysis(dataset_name):
    """
    Run the Random Forest jet-pT threshold scan for one sample and save plots.

    ``ML_dict[dataset_name]`` is split once into train and test halves. Each
    value in ``JET_PT_THRESHOLDS`` is then applied to both halves, a forest is
    trained and evaluated, and the ROC, AUC-vs-threshold, event-yield and
    feature-importance figures are written to ``PLOT_BASE_DIR/<dataset_name>``.

    A threshold that leaves fewer than two classes in the train or test half
    is reported and skipped instead of aborting the whole scan. The feature
    importance comparison uses the 15 and 60 GeV models when they exist and
    otherwise falls back to the lowest and highest trained thresholds.

    Parameters
    ----------
    dataset_name : str
        Key of the sample in ``ML_dict``.
    """
    print(f"=== Running full analysis for: {dataset_name} ===")

    # The dataframe is only read below, so no defensive copy is needed.
    df = ML_dict[dataset_name]

    # Create directory for this dataset's plots
    save_dir = os.path.join(PLOT_BASE_DIR, dataset_name)
    os.makedirs(save_dir, exist_ok=True)

    # Prepare once
    X_train_all, X_test_all, y_train_all, y_test_all, feature_cols = prepare_dataset(df)

    roc_results = {}
    counts_sig = {}
    counts_bkg = {}
    rf_models = {}

    # Loop thresholds
    for T in JET_PT_THRESHOLDS:
        print(f"\n→ Applying jet pt threshold T = {T} GeV")

        X_train = apply_jet_pt_threshold(X_train_all, T)
        y_train = y_train_all.loc[X_train.index]

        X_test = apply_jet_pt_threshold(X_test_all, T)
        y_test = y_test_all.loc[X_test.index]

        # Label 1 is counted as signal and label 0 as background
        counts_sig[T] = (y_test == 1).sum()
        counts_bkg[T] = (y_test == 0).sum()

        # A ROC curve needs both classes in training and in testing.
        if y_train.nunique() < 2 or y_test.nunique() < 2:
            print(f"  Skipping T = {T} GeV: fewer than two classes survive the cut")
            continue

        rf, fpr, tpr, roc_auc = train_and_evaluate_rf(
            X_train, y_train, X_test, y_test
        )

        roc_results[T] = (fpr, tpr, roc_auc)
        rf_models[T] = rf

    # Save all plots
    plot_event_yields(counts_sig, counts_bkg, dataset_name, save_dir)

    if roc_results:
        plot_roc_curves(roc_results, dataset_name, save_dir)
        plot_auc_vs_threshold(roc_results, dataset_name, save_dir)

        # Prefer the 15 / 60 GeV pair; fall back to the extremes that were
        # actually trained so a changed threshold list cannot raise KeyError.
        trained = sorted(rf_models)
        low_T = 15 if 15 in rf_models else trained[0]
        high_T = 60 if 60 in rf_models else trained[-1]

        compare_feature_importances(
            rf_low=rf_models[low_T],
            rf_high=rf_models[high_T],
            feature_cols=feature_cols,
            dataset_name=dataset_name,
            save_dir=save_dir,
            low_T=low_T,
            high_T=high_T,
        )
    else:
        print(" No threshold could be evaluated; only the event-yield plot was saved.")

    print(f" Completed. Plots saved in: {save_dir}")


### Run the analysis for different datasets

In [None]:
run_full_analysis("HAHMggf")

=== Running full analysis for: HAHMggf ===

→ Applying jet pt threshold T = 5 GeV

→ Applying jet pt threshold T = 15 GeV

→ Applying jet pt threshold T = 25 GeV

→ Applying jet pt threshold T = 45 GeV

→ Applying jet pt threshold T = 60 GeV

→ Applying jet pt threshold T = 80 GeV
 Completed. Plots saved in: /content/drive/MyDrive/Datasets/plots_with_OR/HAHMggf


In [None]:
run_full_analysis("HNLeemu")

=== Running full analysis for: HNLeemu ===

→ Applying jet pt threshold T = 5 GeV

→ Applying jet pt threshold T = 15 GeV

→ Applying jet pt threshold T = 25 GeV

→ Applying jet pt threshold T = 45 GeV

→ Applying jet pt threshold T = 60 GeV

→ Applying jet pt threshold T = 80 GeV
 Completed. Plots saved in: /content/drive/MyDrive/Datasets/plots_with_OR/HNLeemu


In [None]:
run_full_analysis("HtoSUEP")

=== Running full analysis for: HtoSUEP ===

→ Applying jet pt threshold T = 5 GeV

→ Applying jet pt threshold T = 15 GeV

→ Applying jet pt threshold T = 25 GeV

→ Applying jet pt threshold T = 45 GeV

→ Applying jet pt threshold T = 60 GeV

→ Applying jet pt threshold T = 80 GeV
 Completed. Plots saved in: /content/drive/MyDrive/Datasets/plots_with_OR/HtoSUEP


In [None]:
run_full_analysis("VBF_H125_a55a55_4b_ctau1_filtered")

=== Running full analysis for: VBF_H125_a55a55_4b_ctau1_filtered ===

→ Applying jet pt threshold T = 5 GeV

→ Applying jet pt threshold T = 15 GeV

→ Applying jet pt threshold T = 25 GeV

→ Applying jet pt threshold T = 45 GeV

→ Applying jet pt threshold T = 60 GeV

→ Applying jet pt threshold T = 80 GeV
 Completed. Plots saved in: /content/drive/MyDrive/Datasets/plots_with_OR/VBF_H125_a55a55_4b_ctau1_filtered


In [None]:
run_full_analysis("Znunu")

=== Running full analysis for: Znunu ===

→ Applying jet pt threshold T = 5 GeV

→ Applying jet pt threshold T = 15 GeV

→ Applying jet pt threshold T = 25 GeV

→ Applying jet pt threshold T = 45 GeV

→ Applying jet pt threshold T = 60 GeV

→ Applying jet pt threshold T = 80 GeV
 Completed. Plots saved in: /content/drive/MyDrive/Datasets/plots_with_OR/Znunu


In [None]:
run_full_analysis("ggF_H125_a16a16_4b_ctau10_filtered")

=== Running full analysis for: ggF_H125_a16a16_4b_ctau10_filtered ===

→ Applying jet pt threshold T = 5 GeV

→ Applying jet pt threshold T = 15 GeV

→ Applying jet pt threshold T = 25 GeV

→ Applying jet pt threshold T = 45 GeV

→ Applying jet pt threshold T = 60 GeV

→ Applying jet pt threshold T = 80 GeV
 Completed. Plots saved in: /content/drive/MyDrive/Datasets/plots_with_OR/ggF_H125_a16a16_4b_ctau10_filtered


In [None]:
run_full_analysis("hh_bbbb_vbf_novhh_5fs_l1cvv1cv1")

=== Running full analysis for: hh_bbbb_vbf_novhh_5fs_l1cvv1cv1 ===

→ Applying jet pt threshold T = 5 GeV

→ Applying jet pt threshold T = 15 GeV

→ Applying jet pt threshold T = 25 GeV

→ Applying jet pt threshold T = 45 GeV

→ Applying jet pt threshold T = 60 GeV

→ Applying jet pt threshold T = 80 GeV
 Completed. Plots saved in: /content/drive/MyDrive/Datasets/plots_with_OR/hh_bbbb_vbf_novhh_5fs_l1cvv1cv1


In [None]:
run_full_analysis("all_signals")

=== Running full analysis for: all_signals ===

→ Applying jet pt threshold T = 5 GeV

→ Applying jet pt threshold T = 15 GeV

→ Applying jet pt threshold T = 25 GeV

→ Applying jet pt threshold T = 45 GeV

→ Applying jet pt threshold T = 60 GeV

→ Applying jet pt threshold T = 80 GeV
 Completed. Plots saved in: /content/drive/MyDrive/Datasets/plots_with_OR/all_signals


## GNN: Supervised Classification

In [4]:
!pip install -q torch_geometric

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m22.6 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h

In [5]:
# imports for GNN

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_networkx
from torch_geometric.nn import (
    GCNConv,
    SAGEConv,
    GINConv,
    global_mean_pool,
)

# Use the GPU when PyTorch reports one as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


In [6]:
# fixed object ordering: [MET, e, j, mu, ph] for one-hot encoding
OBJ_TYPES  = ["MET", "e", "j", "mu", "ph"]
# One-hot vector for each object type. The position of the 1 equals the index
# of that type in OBJ_TYPES, so an argmax over the one-hot part of a node's
# features gives back its index in OBJ_TYPES.
ONEHOT_MAP = {
    "MET": np.array([1, 0, 0, 0, 0], dtype=np.float32),
    "e"  : np.array([0, 1, 0, 0, 0], dtype=np.float32),
    "j"  : np.array([0, 0, 1, 0, 0], dtype=np.float32),
    "mu" : np.array([0, 0, 0, 1, 0], dtype=np.float32),
    "ph" : np.array([0, 0, 0, 0, 1], dtype=np.float32),
}


def _complete_graph_edge_index(n_nodes: int) -> torch.Tensor:
    """
    Build a fully-connected, directed edge_index without self-loops.
    For undirected GNNs, having i->j and j->i is fine.
    """
    rows = []
    cols = []
    for i in range(n_nodes):
        for j in range(n_nodes):
            if i == j:
                continue
            rows.append(i)
            cols.append(j)
    return torch.tensor([rows, cols], dtype=torch.long)


def build_obj_df_and_pyg_dataset(df: pd.DataFrame):
    """
    Turn an event-level dataframe into a per-object table and PyG graphs.

    Parameters
    ----------
    df : pd.DataFrame
        One of the ML_dict dataframes (e.g. ML_dict['all_signals']).
        Required columns: j0pt..j5phi, e0pt..e2phi, mu0pt..mu2phi,
        ph0pt..ph2phi, METpt, METeta, METphi and target.
        Other columns (run_number, event_number, weight) are not read.

    Returns
    -------
    obj_df : pd.DataFrame
        One row per kept object, with columns
        [pT, eta, phi, obj, event, target, obj_MET, obj_e, obj_j, obj_mu, obj_ph].
        ``event`` is the positional row number of the event in ``df``.

    data_list : list[torch_geometric.data.Data]
        One graph per event that has at least one kept object; events with
        none are skipped. Each graph holds
          - x: node features [pT, eta, phi, one-hot(obj-type)]  (shape [N_nodes, 8])
          - edge_index: complete graph between all nodes
          - y: event label (0 = EB_test/background, 1 = signal)
        Nodes are ordered jets, electrons, muons, photons, MET.
    """
    # (column prefix, object-type name) for every object slot, in node order
    object_slots = (
        [(f"j{k}", "j") for k in range(6)]
        + [(f"e{k}", "e") for k in range(3)]
        + [(f"mu{k}", "mu") for k in range(3)]
        + [(f"ph{k}", "ph") for k in range(3)]
        + [("MET", "MET")]
    )

    object_records = []   # becomes the per-object pandas table
    graphs = []           # one PyG Data object per non-empty event

    for event_idx, (_, row) in enumerate(df.iterrows()):
        # Event label: string targets map EB_test -> 0 and anything else -> 1;
        # non-string targets are taken as already numeric.
        raw_label = row["target"]
        if isinstance(raw_label, str):
            target = 1 if raw_label != "EB_test" else 0
        else:
            target = int(raw_label)

        event_nodes = []
        for prefix, obj_name in object_slots:
            pt = row[f"{prefix}pt"]
            eta = row[f"{prefix}eta"]
            phi = row[f"{prefix}phi"]

            # An all-zero (pt, eta, phi) triple marks an empty slot
            if pt == 0 and eta == 0 and phi == 0:
                continue

            type_onehot = ONEHOT_MAP[obj_name]
            event_nodes.append(np.concatenate([[pt, eta, phi], type_onehot], axis=0))

            object_records.append({
                "pT": float(pt),
                "eta": float(eta),
                "phi": float(phi),
                "obj": obj_name,
                "event": event_idx,
                "target": int(target),
                "obj_MET": int(obj_name == "MET"),
                "obj_e"  : int(obj_name == "e"),
                "obj_j"  : int(obj_name == "j"),
                "obj_mu" : int(obj_name == "mu"),
                "obj_ph" : int(obj_name == "ph"),
            })

        # Nothing to build a graph from
        if not event_nodes:
            continue

        node_matrix = torch.tensor(np.vstack(event_nodes), dtype=torch.float32)
        graphs.append(
            Data(
                x=node_matrix,
                edge_index=_complete_graph_edge_index(node_matrix.size(0)),
                y=torch.tensor([target], dtype=torch.long),
            )
        )

    return pd.DataFrame(object_records), graphs


In [None]:
# Build the per-object table and the list of PyG graphs for the combined
# "all_signals" sample; the supervised GNN cells below use pyg_dataset.
df_all = ML_dict["all_signals"]
obj_df, pyg_dataset = build_obj_df_and_pyg_dataset(df_all)

In [None]:
print(obj_df.head())
print(len(pyg_dataset), "graphs")
print(pyg_dataset[0])   # first PyG Data object

In [None]:
# Cell 2: train/val/test split and DataLoaders

# Extract labels to stratify by event-level target
y_all = np.array([g.y.item() for g in pyg_dataset])

# First split off the test set. Positions into pyg_dataset are split (not the
# graphs themselves) so the index arrays can be used to select graphs below.
idx_train_val, idx_test = train_test_split(
    np.arange(len(pyg_dataset)),
    test_size=0.15,
    stratify=y_all,
    random_state=42,
)

y_train_val = y_all[idx_train_val]

# Then split the remainder into train and validation.
idx_train, idx_val = train_test_split(
    idx_train_val,
    test_size=0.1765,  # 0.85 * 0.1765 ≈ 0.15 → 70/15/15 split overall
    stratify=y_train_val,
    random_state=42,
)

train_graphs = [pyg_dataset[i] for i in idx_train]
val_graphs   = [pyg_dataset[i] for i in idx_val]
test_graphs  = [pyg_dataset[i] for i in idx_test]

print(f"Train: {len(train_graphs)}, Val: {len(val_graphs)}, Test: {len(test_graphs)}")

batch_size = 64

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_graphs,   batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(test_graphs,  batch_size=batch_size, shuffle=False)


In [None]:
# Cell 3: visualise one example graph

from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib as mpl

# Choose an example event. The requested position is clamped to the last
# training graph so a training set with fewer graphs does not raise IndexError.
example_idx = min(721, len(train_graphs) - 1)  # change 721 to look at other events
example_data = train_graphs[example_idx]

# Convert to NetworkX graph (nodes are 0 .. N-1, same order as example_data.x)
G = to_networkx(example_data, to_undirected=True)

# Deduce object type from the one-hot part of the node features (x[:, 3:])
# Uses the SAME ordering as in build_obj_df_and_pyg_dataset:
# OBJ_TYPES = ["MET", "e", "j", "mu", "ph"]
if "OBJ_TYPES" not in globals():
    # Fallback when the cell defining OBJ_TYPES has not been run
    OBJ_TYPES = ["MET", "e", "j", "mu", "ph"]

x = example_data.x.cpu().numpy()          # shape [N_nodes, 8]
one_hot = x[:, 3:]                        # shape [N_nodes, 5]
obj_type_idx = one_hot.argmax(axis=1)     # integers 0..4

# Map each node to a specific color (colormap entry i <-> OBJ_TYPES[i])
cmap = mpl.colormaps.get_cmap("tab10")
node_colors = [cmap(i) for i in obj_type_idx]

fig, ax = plt.subplots(figsize=(5, 5))
pos = nx.spring_layout(G, seed=42)        # fixed seed -> same layout every run

nx.draw(
    G,
    pos,
    node_color=node_colors,
    with_labels=False,
    node_size=300,
    edge_color="lightgray",
    ax=ax,
)

# Build a matching legend from empty scatter handles, one per object type
for i, name in enumerate(OBJ_TYPES):
    ax.scatter([], [], color=cmap(i), label=name, s=80)
ax.legend(title="Object type", bbox_to_anchor=(1.05, 1), loc="upper left")

ax.set_title(f"Example event (train index = {example_idx}) as a graph")
plt.tight_layout()
plt.show()

# (optional) print a quick summary of how many objects of each type are in this event
unique, counts = np.unique(obj_type_idx, return_counts=True)
print("Object counts in this event:")
for idx, cnt in zip(unique, counts):
    print(f"  {OBJ_TYPES[idx]}: {cnt}")


In [None]:
# Cell 4: define a simple GNN for event classification

from torch_geometric.nn import GCNConv, global_mean_pool
import torch.nn as nn

class EventGNN(torch.nn.Module):
    """Graph-level classifier: two GCN layers, mean pooling, two linear layers.

    ``forward`` takes a batched PyG ``Data`` object (``x``, ``edge_index`` and
    the ``batch`` vector added by the DataLoader) and returns one row of
    ``num_classes`` logits per graph.
    """

    def __init__(self, in_channels=8, hidden_channels=32, num_classes=2):
        super().__init__()
        # Message-passing layers acting on the nodes
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        # Classification head acting on the pooled graph embedding
        self.lin1  = nn.Linear(hidden_channels, hidden_channels)
        self.lin2  = nn.Linear(hidden_channels, num_classes)

    def forward(self, data):
        edges = data.edge_index

        node_h = F.relu(self.conv1(data.x, edges))
        node_h = F.relu(self.conv2(node_h, edges))

        # Average node embeddings per graph → [num_graphs, hidden_channels]
        graph_h = global_mean_pool(node_h, data.batch)

        # Raw logits; no softmax is applied here
        return self.lin2(F.relu(self.lin1(graph_h)))

# Seed PyTorch before the weights are drawn so that the initialisation is
# repeatable from run to run (42 matches the random_state used for the splits).
torch.manual_seed(42)

# Input width is read from the data so it follows the node-feature layout.
model = EventGNN(in_channels=train_graphs[0].x.size(1), hidden_channels=64, num_classes=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

print(model)


In [None]:
# Cell 5 (fixed): pictorial representation of the GNN architecture (schematic)

import matplotlib.patches as patches

# Box labels, drawn left to right.
# NOTE: these are hard-coded strings, not read from `model`; update them by
# hand if the layer sizes passed to EventGNN change.
layers = [
    "Input\n(node features)",
    "GCNConv(8 → 64)",
    "ReLU",
    "GCNConv(64 → 64)",
    "ReLU",
    "global_mean_pool",
    "Linear(64 → 64)",
    "ReLU",
    "Linear(64 → 2)\n(logits)",
]

# Layout constants, in data coordinates of the (hidden) axes
n_layers   = len(layers)
box_width  = 1.4
box_height = 0.8
gap        = 0.5
x0         = 0.5   # left margin
y0         = 0.2   # bottom margin

# Figure width scales with number of layers so everything is visible
fig_width = max(12, n_layers * (box_width + gap) * 0.6)
fig, ax = plt.subplots(figsize=(fig_width+2, 2.8))
ax.axis("off")

for i, name in enumerate(layers):
    # Left edge of the i-th box
    x_left = x0 + i * (box_width + gap)

    rect = patches.FancyBboxPatch(
        (x_left, y0),
        box_width,
        box_height,
        boxstyle="round,pad=0.15",
        edgecolor="black",
        facecolor="lightblue",
    )
    ax.add_patch(rect)

    # Label centred inside the box
    ax.text(
        x_left + box_width / 2.0,
        y0 + box_height / 2.0,
        name,
        ha="center",
        va="center",
        fontsize=8,
    )

    # Draw arrows between blocks
    if i < n_layers - 1:
        x_start = x_left + box_width
        x_end   = x0 + (i + 1) * (box_width + gap)
        ax.annotate(
            "",
            xy=(x_end, y0 + box_height / 2.0),
            xytext=(x_start, y0 + box_height / 2.0),
            arrowprops=dict(arrowstyle="->"),
        )

# Make sure everything is in view
x_max = x0 + n_layers * (box_width + gap)
ax.set_xlim(0, x_max + 0.5)
ax.set_ylim(0, y0 + box_height + 0.5)

plt.title("Schematic of EventGNN architecture", pad=20)
plt.tight_layout()
plt.show()


Training takes approximately 17 minutes for 20 epochs; note that the cell below is currently set to `num_epochs = 10`.

In [None]:
# Cell 6: training & validation loop

from torch_geometric.loader import DataLoader

num_epochs = 10

# Per-epoch history, read by the loss-curve cell
train_losses = []
val_losses   = []
val_accuracies = []

for epoch in range(1, num_epochs + 1):
    # --- Training ---
    model.train()
    total_loss = 0.0

    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch)                # [batch_size_graphs, 2]
        loss = criterion(out, batch.y)    # batch.y shape [batch_size_graphs]
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; weighting by the number of graphs
        # makes the division below a true per-graph average even when the
        # last batch is smaller.
        total_loss += loss.item() * batch.num_graphs

    avg_train_loss = total_loss / len(train_graphs)
    train_losses.append(avg_train_loss)

    # --- Validation ---
    model.eval()
    total_val_loss = 0.0
    correct = 0

    # No gradients are needed for evaluation
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            out = model(batch)
            loss = criterion(out, batch.y)
            total_val_loss += loss.item() * batch.num_graphs

            # Predicted class = index of the larger logit
            preds = out.argmax(dim=1)
            correct += (preds == batch.y).sum().item()

    avg_val_loss = total_val_loss / len(val_graphs)
    val_losses.append(avg_val_loss)

    val_acc = correct / len(val_graphs)
    val_accuracies.append(val_acc)

    print(f"Epoch {epoch:02d} | "
          f"Train loss: {avg_train_loss:.4f} | "
          f"Val loss: {avg_val_loss:.4f} | "
          f"Val acc: {val_acc:.3f}")


In [None]:
# Cell 7: training & validation loss curves

# The training log numbers epochs from 1, so plot against the same numbering
# instead of the 0-based list index. Each curve gets its own x range in case
# the two histories differ in length (e.g. after an interrupted run).
plt.figure(figsize=(6, 4))
plt.plot(range(1, len(train_losses) + 1), train_losses, label="Train loss")
plt.plot(range(1, len(val_losses) + 1),   val_losses,   label="Val loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training / Validation loss")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()


In [None]:
# Cell 8: evaluation on test set and ROC curve

model.eval()
all_probs = []
all_labels = []

# Collect the signal probability and the true label of every test graph
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        logits = model(batch)                      # [B, 2]
        probs = F.softmax(logits, dim=1)[:, 1]     # probability of class 1 (signal)
        all_probs.append(probs.cpu())
        all_labels.append(batch.y.cpu())

# Per-batch tensors -> one flat NumPy array each
all_probs = torch.cat(all_probs).numpy()
all_labels = torch.cat(all_labels).numpy()

# Basic accuracy
# (a graph is called signal when its class-1 probability is at least 0.5)
pred_labels = (all_probs >= 0.5).astype(int)
test_acc = (pred_labels == all_labels).mean()
print(f"Test accuracy: {test_acc:.3f}")

# ROC curve
fpr, tpr, thresholds = roc_curve(all_labels, all_probs)
roc_auc = auc(fpr, tpr)
print(f"AUC: {roc_auc:.3f}")

plt.figure(figsize=(6, 6))
plt.plot(fpr, tpr, label=f"ROC curve (AUC = {roc_auc:.3f})")
plt.plot([0, 1], [0, 1], "k--", label="Random")   # diagonal reference line
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC curve (test set)")
plt.legend(loc="lower right")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()


## GNN: Unsupervised Anomaly Detection (GCN / GraphSAGE / GIN)

In this section we train a simple Graph Autoencoder (GAE) in an
unsupervised way on background (EB_test) events and compare three
convolution operators in PyTorch Geometric:

- GCNConv
- SAGEConv (GraphSAGE)
- GINConv

We keep the same graph construction (`build_obj_df_and_pyg_dataset`) and
optionally apply jet $p_T$ cuts. We support:

- **No jet $p_T$ cut** (whole dataset),
- **One fixed jet $p_T$ cut**, or
- **A scan over several jet $p_T$ thresholds.**

For each configuration and convolution, we:
1. Build PyG graphs.
2. Train a GAE on background-only graphs.
3. Compute graph-level reconstruction errors on a mixed test set.
4. Build ROC curves / AUC for anomaly detection.
5. Visualise latent space.
6. Plot input vs reconstructed $(\eta,\phi)$ for a few events using  
   `plot_eta_phi_input_vs_reco_and_save`.


Run the cell for `build_obj_df_and_pyg_dataset(df)` in the previous section. Takes $\sim$ 4 minutes.

In [None]:
# Cell 1: build PyG graphs and split background vs signal

df = ML_dict["Znunu"]

# Uses the previously defined function (EB_test -> 0, signal -> 1)
# NOTE: this rebinds obj_df, which the supervised section filled from the
# "all_signals" sample.
obj_df, pyg_list = build_obj_df_and_pyg_dataset(df)

print(f"Total events (graphs): {len(pyg_list)}")
print("Labels (0=EB_test/background, 1=signal):",
      {int(y): int(sum(g.y.item() == y for g in pyg_list)) for y in [0,1]})

# Re-evaluated here so this section can run without the supervised cells
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Extract labels
# (bkg_idx / signal_idx are positions into pyg_list)
labels = np.array([g.y.item() for g in pyg_list])
bkg_idx    = np.where(labels == 0)[0]   # EB_test
signal_idx = np.where(labels == 1)[0]   # everything else
print(f"Background graphs: {len(bkg_idx)}, Signal graphs: {len(signal_idx)}")


Total events (graphs): 17242
Labels (0=EB_test/background, 1=signal): {0: 8017, 1: 9225}
Using device: cpu
Background graphs: 8017, Signal graphs: 9225


### All helper functions

In [7]:
# Cell 1: global config for plots & jet pT selection

# ------------------------------------------------------------------
# Base directory for plots
# ------------------------------------------------------------------
# NOTE: the output location depends on kernel state — if the Random Forest
# section has already defined PLOT_BASE_DIR, that value is reused; otherwise
# the Kaggle working directory below is used.
if "PLOT_BASE_DIR" not in globals():
    # Adjust if you use a different path
    PLOT_BASE_DIR = "/kaggle/working/GAEplots"

GNN_UNSUP_BASE_DIR = os.path.join(PLOT_BASE_DIR, "GAE_detailedstudy")
os.makedirs(GNN_UNSUP_BASE_DIR, exist_ok=True)
print("GNN unsupervised results will be saved under:", GNN_UNSUP_BASE_DIR)

# ------------------------------------------------------------------
# Jet pT thresholds & selection
# ------------------------------------------------------------------
# The bare names below raise NameError when either constant is undefined,
# which triggers the fallback definitions.
try:
    JET_PT_THRESHOLDS
    JET_PT_COLS
except NameError:
    # Fallback if RF section hasn't run
    JET_PT_THRESHOLDS = [5, 15, 25, 45, 60, 80]
    JET_PT_COLS = [f"j{i}pt" for i in range(6)]

def apply_jet_pt_threshold(df, threshold, jet_cols=None):
    """
    Keep only the events in which every non-empty jet slot passes a pT cut.

    This repeats the helper from the Random Forest section so that the GNN
    section can run when that section has not been executed.

    A jet slot with pt == 0 is treated as "no jet" and never vetoes an event,
    so the per-jet condition is ``(jXpt == 0) or (jXpt >= threshold)``.

    Parameters
    ----------
    df : pd.DataFrame
        Event table. Jet-pT columns that are not present are ignored.
    threshold : float
        Minimum pT required of every non-zero jet.
    jet_cols : iterable of str, optional
        Names of the jet-pT columns to test. Defaults to the notebook-level
        ``JET_PT_COLS`` so existing two-argument calls behave as before.

    Returns
    -------
    pd.DataFrame
        The rows of ``df`` that pass the selection, with their original index.
    """
    if jet_cols is None:
        jet_cols = JET_PT_COLS

    keep = np.ones(len(df), dtype=bool)
    for col in jet_cols:
        if col in df.columns:
            # Work on the raw values so the mask is purely positional.
            pt = df[col].to_numpy()
            keep &= (pt == 0) | (pt >= threshold)
    return df[keep]

# Object types for η–φ plots (node one-hot order)
# Defined here only when no earlier cell has already set OBJ_TYPES.
if "OBJ_TYPES" not in globals():
    OBJ_TYPES = ["MET", "e", "j", "mu", "ph"]


GNN unsupervised results will be saved under: /kaggle/working/GAEplots/GAE_detailedstudy


In [8]:
print("PLOT_BASE_DIR      =", PLOT_BASE_DIR)
print("GNN_UNSUP_BASE_DIR =", GNN_UNSUP_BASE_DIR)

PLOT_BASE_DIR      = /kaggle/working/GAEplots
GNN_UNSUP_BASE_DIR = /kaggle/working/GAEplots/GAE_detailedstudy


In [9]:
from torch.nn import Sequential, Linear, ReLU

class GraphAutoEncoder(nn.Module):
    """
    Graph autoencoder that reconstructs per-node features.

    Encoder: three message-passing stages of the same family (GCN, SAGE or
    GIN), each applied as

        conv -> (BatchNorm, if enabled) -> ReLU -> Dropout

    with widths ``in_channels -> hidden_channels -> hidden_channels // 2
    -> latent_channels`` (the middle width is at least 1).

    Decoder: ``Linear(latent, hidden) -> ReLU -> Dropout -> Linear(hidden, in)``.
    The decoder does not use ``edge_index``.

    The input width is taken from ``in_channels``; nothing is hard-coded to a
    particular number of node features.

    Parameters
    ----------
    in_channels : int
        Number of input (and reconstructed) features per node.
    hidden_channels : int
        Width of the first encoder stage and of the decoder hidden layer.
    latent_channels : int
        Width of the latent node embedding.
    dropout : float
        Dropout probability used after every encoder stage and in the decoder.
    conv_type : str
        ``"gcn"``, ``"sage"`` or ``"gin"`` (case-insensitive).
    sage_aggr : str
        Aggregation passed to ``SAGEConv``; only used when conv_type is "sage".
    use_batchnorm : bool
        If True, a ``BatchNorm1d`` is applied after each encoder convolution.

    Raises
    ------
    ValueError
        If ``conv_type`` is not one of the supported names.
    """
    def __init__(
        self,
        in_channels,
        hidden_channels=64,
        latent_channels=16,
        dropout=0.1,
        conv_type="gcn",
        sage_aggr="mean",
        use_batchnorm=False,
    ):
        super().__init__()

        self.conv_type = conv_type.lower()
        self.dropout = dropout
        self.use_batchnorm = use_batchnorm
        # Width of the intermediate encoder stage (never below 1).
        self.mid_channels = max(1, hidden_channels // 2)

        # (input width, output width) of the three encoder stages, in order.
        stage_widths = [
            (in_channels, hidden_channels),
            (hidden_channels, self.mid_channels),
            (self.mid_channels, latent_channels),
        ]

        # ----- Encoder convolutions -----
        if self.conv_type == "gcn":
            stages = [GCNConv(n_in, n_out) for n_in, n_out in stage_widths]
        elif self.conv_type == "sage":
            stages = [SAGEConv(n_in, n_out, aggr=sage_aggr) for n_in, n_out in stage_widths]
        elif self.conv_type == "gin":
            # Each GIN stage wraps a two-layer MLP: Linear -> ReLU -> Linear.
            mlps = [
                Sequential(Linear(n_in, n_out), ReLU(), Linear(n_out, n_out))
                for n_in, n_out in stage_widths
            ]
            stages = [GINConv(mlp) for mlp in mlps]
        else:
            raise ValueError(f"Unknown conv_type '{conv_type}'. Use 'gcn', 'sage', or 'gin'.")

        self.conv1, self.conv_mid, self.conv2 = stages

        # Normalisation layers exist only when requested.
        if self.use_batchnorm:
            self.bn1 = nn.BatchNorm1d(hidden_channels)
            self.bn_mid = nn.BatchNorm1d(self.mid_channels)
            self.bn2 = nn.BatchNorm1d(latent_channels)

        # ----- Decoder -----
        self.dec_lin1 = nn.Linear(latent_channels, hidden_channels)
        self.dec_lin2 = nn.Linear(hidden_channels, in_channels)

    def _encoder_stage(self, conv, norm_name, h, edge_index):
        """Apply one encoder stage: conv -> (BatchNorm) -> ReLU -> Dropout."""
        h = conv(h, edge_index)
        if self.use_batchnorm:
            h = getattr(self, norm_name)(h)
        h = F.relu(h)
        return F.dropout(h, p=self.dropout, training=self.training)

    def encode(self, x, edge_index):
        """Map node features ``x`` to latent node embeddings."""
        h = x
        for conv, norm_name in (
            (self.conv1, "bn1"),
            (self.conv_mid, "bn_mid"),
            (self.conv2, "bn2"),
        ):
            h = self._encoder_stage(conv, norm_name, h, edge_index)
        return h

    def decode(self, z):
        """Reconstruct node features from latent node embeddings ``z``."""
        hidden = F.relu(self.dec_lin1(z))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.dec_lin2(hidden)

    def forward(self, data):
        """Return ``(x_hat, z)`` for an object exposing ``.x`` and ``.edge_index``."""
        z = self.encode(data.x, data.edge_index)
        return self.decode(z), z


In [10]:
# Cell 4: Train/val/test splits, training loop, scoring

def make_gae_splits_and_loaders(pyg_list, batch_size=128, random_state=42):
    """
    Split a list of PyG graphs into train/val/test for unsupervised anomaly detection.

    - Training uses *only* background graphs (y == 0).
    - Validation and test contain both background and signal (y == 1).
    - Background is split 60/20/20 into train/val/test, signal 50/50 into
      val/test. Graphs whose label is neither 0 nor 1 are not used.

    Parameters
    ----------
    pyg_list : list
        Graphs exposing a scalar label through ``g.y.item()``.
    batch_size : int
        Batch size for all three loaders.
    random_state : int
        Seed passed to every ``train_test_split`` call.

    Returns
    -------
    (train_loader, val_loader, test_loader, test_labels, split_info), or
    None when there are fewer than 10 background or fewer than 10 signal graphs.
    ``test_labels`` is ordered like the (unshuffled) test loader: background
    graphs first, then signal graphs.
    """
    labels = np.array([int(g.y.item()) for g in pyg_list])
    bkg_idx = np.where(labels == 0)[0]
    sig_idx = np.where(labels == 1)[0]

    n_total = len(pyg_list)
    n_bkg = len(bkg_idx)
    n_sig = len(sig_idx)
    print(f"Total graphs={n_total}, background={n_bkg}, signal={n_sig}")

    if n_bkg < 10 or n_sig < 10:
        print("Not enough background or signal graphs for a meaningful ROC → aborting.")
        return None

    # 60/20/20 split of background graphs
    bkg_train_idx, bkg_tmp_idx = train_test_split(
        bkg_idx, test_size=0.4, random_state=random_state
    )
    bkg_val_idx, bkg_test_idx = train_test_split(
        bkg_tmp_idx, test_size=0.5, random_state=random_state
    )

    # 50/50 split of signal graphs into val/test. The guard above guarantees
    # at least 10 signal graphs here, so no single-graph fallback is needed.
    sig_val_idx, sig_test_idx = train_test_split(
        sig_idx, test_size=0.5, random_state=random_state
    )

    train_graphs = [pyg_list[i] for i in bkg_train_idx]
    val_graphs   = [pyg_list[i] for i in np.concatenate([bkg_val_idx, sig_val_idx])]
    test_idx     = np.concatenate([bkg_test_idx, sig_test_idx])
    test_graphs  = [pyg_list[i] for i in test_idx]
    test_labels  = labels[test_idx]

    # Only the training loader is shuffled; val/test keep a fixed order so that
    # per-graph outputs line up with test_labels.
    train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
    val_loader   = DataLoader(val_graphs,   batch_size=batch_size, shuffle=False)
    test_loader  = DataLoader(test_graphs,  batch_size=batch_size, shuffle=False)

    print(f"Train bkg graphs: {len(train_graphs)}")
    print(f"Val graphs:       {len(val_graphs)} (bkg={len(bkg_val_idx)}, sig={len(sig_val_idx)})")
    print(f"Test graphs:      {len(test_graphs)} (bkg={len(bkg_test_idx)}, sig={len(sig_test_idx)})")

    split_info = {
        "n_total_graphs": int(n_total),
        "n_bkg_graphs":   int(n_bkg),
        "n_sig_graphs":   int(n_sig),
        "n_train_bkg":    int(len(bkg_train_idx)),
        "n_val_bkg":      int(len(bkg_val_idx)),
        "n_test_bkg":     int(len(bkg_test_idx)),
        "n_val_sig":      int(len(sig_val_idx)),
        "n_test_sig":     int(len(sig_test_idx)),
    }

    return train_loader, val_loader, test_loader, test_labels, split_info


def train_gae(
    model,
    train_loader,
    val_loader,
    device,
    num_epochs=20,
    lr=1e-3,
    weight_decay=1e-5,
    verbose=False,
):
    """
    Fit the autoencoder by minimising the MSE between reconstructed and input
    node features, using Adam.

    Each epoch makes one optimisation pass over ``train_loader`` followed by a
    gradient-free pass over ``val_loader``. The reported loss of a pass is the
    plain average of its per-batch losses (0.0 when the loader yields nothing).

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(batch)`` and expected to return ``(x_hat, z)``.
    train_loader, val_loader : iterable
        Yield batches that support ``.to(device)`` and expose ``.x``.
    device : torch.device or str
        Where the model and batches are moved.
    num_epochs : int
    lr, weight_decay : float
        Adam hyperparameters.
    verbose : bool
        If True, print one line with both losses per epoch.

    Returns
    -------
    model, history   (history has 'train_loss', 'val_loss', one value per epoch)
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    model.to(device)

    def _run_pass(loader, update_weights):
        # One sweep over `loader`; returns the mean per-batch MSE.
        running_loss = 0.0
        batches_seen = 0
        for batch in loader:
            batch = batch.to(device)
            if update_weights:
                optimizer.zero_grad()

            x_hat, _ = model(batch)
            loss = F.mse_loss(x_hat, batch.x)

            if update_weights:
                loss.backward()
                optimizer.step()

            running_loss += loss.item()
            batches_seen += 1
        return running_loss / max(batches_seen, 1)

    history = {"train_loss": [], "val_loss": []}

    for epoch in range(1, num_epochs + 1):
        # ---- Train ----
        model.train()
        avg_train_loss = _run_pass(train_loader, update_weights=True)

        # ---- Validation ----
        model.eval()
        with torch.no_grad():
            avg_val_loss = _run_pass(val_loader, update_weights=False)

        history["train_loss"].append(avg_train_loss)
        history["val_loss"].append(avg_val_loss)

        if verbose:
            print(
                f"[Epoch {epoch:02d}/{num_epochs}] "
                f"train_loss={avg_train_loss:.4e}, val_loss={avg_val_loss:.4e}"
            )

    return model, history


def compute_gae_scores_and_latent(model, loader, device):
    """
    Evaluate the model on every batch of ``loader`` and collect per-graph results.

    Returns three NumPy arrays, concatenated over batches in loader order:

    - scores : squared reconstruction error, averaged over features for each
               node and then mean-pooled over the nodes of each graph
    - latent : node latent vectors mean-pooled per graph
    - labels : the contents of ``batch.y``

    The model is put in eval mode and no gradients are tracked.
    """
    model.eval()
    collected = {"scores": [], "latent": [], "labels": []}

    with torch.no_grad():
        for batch in loader:
            # Labels are read before the batch is moved to `device`.
            graph_labels = batch.y.cpu().numpy()
            batch = batch.to(device)

            x_hat, z = model(batch)

            # Element-wise squared error -> mean over features -> mean over nodes.
            node_error = F.mse_loss(x_hat, batch.x, reduction="none").mean(dim=1)
            graph_error = global_mean_pool(node_error, batch.batch)
            graph_latent = global_mean_pool(z, batch.batch)

            collected["scores"].append(graph_error.cpu().numpy())
            collected["latent"].append(graph_latent.cpu().numpy())
            collected["labels"].append(graph_labels)

    scores, latent, labels = (
        np.concatenate(collected[key], axis=0)
        for key in ("scores", "latent", "labels")
    )
    return scores, latent, labels


In [11]:
# Cell 5: ROC computation and generic plotting helpers

def compute_roc_auc(scores, labels, pos_label=1):
    """
    ROC curve and its area for a set of anomaly scores.

    Higher scores are taken to mean "more anomalous", i.e. more like the
    class given by ``pos_label``.

    Returns
    -------
    fpr, tpr, roc_auc
    """
    curve = roc_curve(labels, scores, pos_label=pos_label)
    false_pos_rate, true_pos_rate = curve[0], curve[1]  # thresholds (curve[2]) unused
    return false_pos_rate, true_pos_rate, auc(false_pos_rate, true_pos_rate)


def plot_gae_loss_curves(history, label_str, out_dir, conv_tag):
    """
    Draw training and validation loss against epoch number and save as PNG.

    ``history`` must provide 'train_loss' and 'val_loss' sequences. The file
    is written to ``out_dir`` (created if missing) as
    ``<conv_tag>_loss_<label_str with spaces replaced by underscores>.png``.
    """
    curves = (("train", history["train_loss"]), ("val", history["val_loss"]))
    epoch_axis = np.arange(1, len(curves[0][1]) + 1)

    fig, ax = plt.subplots()
    for legend_name, values in curves:
        ax.plot(epoch_axis, values, label=legend_name)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("MSE loss")
    ax.set_title(f"GAE loss ({conv_tag}) — {label_str}")
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    os.makedirs(out_dir, exist_ok=True)
    file_name = f"{conv_tag}_loss_{label_str.replace(' ', '_')}.png"
    out_path = os.path.join(out_dir, file_name)
    fig.savefig(out_path)
    plt.close(fig)
    print("Saved loss curve to:", out_path)


def plot_anomaly_score_distribution(scores, labels, label_str, out_dir, conv_tag):
    """
    Overlay normalised histograms of the anomaly score for background
    (label 0) and signal (label 1) and save the figure as PNG.

    Both histograms share 60 bin edges spanning the 1st to 99th percentile of
    all scores; if that interval is empty or inverted, the full score range is
    used instead. The file is written to ``out_dir`` (created if missing) as
    ``<conv_tag>_score_hist_<label_str with spaces -> underscores>.png``.
    """
    scores = np.asarray(scores)
    labels = np.asarray(labels)

    range_lo, range_hi = np.percentile(scores, 1), np.percentile(scores, 99)
    if range_hi <= range_lo:
        range_lo, range_hi = scores.min(), scores.max()
    bin_edges = np.linspace(range_lo, range_hi, 60)

    fig, ax = plt.subplots()
    for cls, legend_name in ((0, "bkg"), (1, "sig")):
        ax.hist(
            scores[labels == cls],
            bins=bin_edges,
            histtype="step",
            density=True,
            label=legend_name,
        )
    ax.set_xlabel("Graph-level reconstruction error")
    ax.set_ylabel("Density")
    ax.set_title(f"Anomaly score distribution ({conv_tag}) — {label_str}")
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    os.makedirs(out_dir, exist_ok=True)
    file_name = f"{conv_tag}_score_hist_{label_str.replace(' ', '_')}.png"
    out_path = os.path.join(out_dir, file_name)
    fig.savefig(out_path)
    plt.close(fig)
    print("Saved score histogram to:", out_path)


def plot_latent_space(latent, labels, label_str, out_dir, conv_tag):
    """
    Scatter the first two principal components of the graph-level latent
    vectors, with background (label 0) and signal (label 1) drawn separately.

    Nothing is plotted when the latent vectors have fewer than two
    components. A class with no entries is left out of the plot. The figure
    is saved to ``out_dir`` (created if missing) as
    ``<conv_tag>_latent_<label_str with spaces -> underscores>.png``.
    """
    latent = np.asarray(latent)
    labels = np.asarray(labels)

    if latent.shape[1] < 2:
        print("Latent dimension < 2; skipping latent space plot.")
        return

    projected = PCA(n_components=2).fit_transform(latent)

    fig, ax = plt.subplots()
    class_styles = ((0, "bkg", "."), (1, "sig", "^"))
    for cls, legend_name, marker in class_styles:
        in_class = labels == cls
        if not in_class.any():
            continue
        ax.scatter(
            projected[in_class, 0],
            projected[in_class, 1],
            s=10,
            alpha=0.6,
            label=legend_name,
            marker=marker,
        )

    ax.set_xlabel("PC1")
    ax.set_ylabel("PC2")
    ax.set_title(f"Latent space PCA ({conv_tag}) — {label_str}")
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    os.makedirs(out_dir, exist_ok=True)
    file_name = f"{conv_tag}_latent_{label_str.replace(' ', '_')}.png"
    out_path = os.path.join(out_dir, file_name)
    fig.savefig(out_path)
    plt.close(fig)
    print("Saved latent space plot to:", out_path)


def plot_roc_curve_single(fpr, tpr, roc_auc, label_str, out_dir, conv_tag):
    """
    Plot one ROC curve together with the diagonal of a random classifier and
    save it as PNG.

    The legend entry shows ``conv_tag`` and the AUC to three decimals. The
    file goes to ``out_dir`` (created if missing) as
    ``<conv_tag>_roc_<label_str with spaces -> underscores>.png``.
    """
    fig, ax = plt.subplots()
    ax.plot(fpr, tpr, label=f"{conv_tag} (AUC={roc_auc:.3f})")
    ax.plot([0, 1], [0, 1], "k--", label="random")
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title(f"ROC — {label_str}")
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    os.makedirs(out_dir, exist_ok=True)
    file_name = f"{conv_tag}_roc_{label_str.replace(' ', '_')}.png"
    out_path = os.path.join(out_dir, file_name)
    fig.savefig(out_path)
    plt.close(fig)
    print("Saved ROC curve to:", out_path)


def plot_gae_roc_all_thresholds(roc_results, dataset_name, save_dir, conv_tag):
    """
    pT scan summary: draw the ROC curves of all thresholds on one set of axes
    for a single conv_type, in increasing threshold order.

    roc_results: dict[pt_threshold] = (fpr, tpr, auc)

    Prints a message and returns without plotting when ``roc_results`` is
    empty. Otherwise saves ``<conv_tag>_roc_all_thresholds.png`` in
    ``save_dir`` (created if missing).
    """
    if not roc_results:
        print("No ROC results to plot.")
        return

    fig, ax = plt.subplots()
    for threshold in sorted(roc_results):
        fpr, tpr, roc_auc = roc_results[threshold]
        ax.plot(fpr, tpr, label=f"T={threshold} GeV (AUC={roc_auc:.3f})")

    ax.plot([0, 1], [0, 1], "k--", label="random")
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title(f"GAE ROC vs jet $p_T$ — {dataset_name} ({conv_tag})")
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    os.makedirs(save_dir, exist_ok=True)
    out_path = os.path.join(save_dir, f"{conv_tag}_roc_all_thresholds.png")
    fig.savefig(out_path)
    plt.close(fig)
    print("Saved multi-threshold ROC plot to:", out_path)


def plot_gae_auc_vs_threshold(roc_results, dataset_name, save_dir, conv_tag):
    """
    Plot the AUC as a function of the jet pT threshold for a single conv_type.

    ``roc_results`` maps threshold -> (fpr, tpr, auc); only the third entry
    is used. Prints a message and returns without plotting when it is empty.
    Otherwise saves ``<conv_tag>_auc_vs_threshold.png`` in ``save_dir``
    (created if missing).
    """
    if not roc_results:
        print("No ROC results to plot.")
        return

    ordered_thresholds = sorted(roc_results.keys())
    auc_per_threshold = [roc_results[threshold][2] for threshold in ordered_thresholds]

    fig, ax = plt.subplots()
    ax.plot(ordered_thresholds, auc_per_threshold, marker="o")
    ax.set_xlabel("Jet $p_T$ threshold [GeV]")
    ax.set_ylabel("AUC")
    ax.set_title(f"GAE AUC vs jet $p_T$ — {dataset_name} ({conv_tag})")
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    os.makedirs(save_dir, exist_ok=True)
    out_path = os.path.join(save_dir, f"{conv_tag}_auc_vs_threshold.png")
    fig.savefig(out_path)
    plt.close(fig)
    print("Saved AUC vs threshold plot to:", out_path)


In [12]:
# Cell 6 (updated): η–φ visualisation of input vs reconstruction
import copy

def plot_eta_phi_input_vs_reco_and_save(
    data,
    model,
    device,
    exp_dir,
    filename_prefix,
    title_suffix="",
):
    """
    Visualise one event as η–φ scatter plots (input vs reconstruction) and save it.

    Node feature columns are interpreted as: 0 = pT (marker size),
    1 = η (x axis), 2 = φ (y axis), 3+ = object-type one-hot (marker colour,
    taken from the *input* features for both panels). Both panels share the
    same axis limits.

    Object-type names for the legend come from the notebook-level ``OBJ_TYPES``
    list when it exists, otherwise from a built-in default.

    IMPORTANT: `data` is deep-copied before being moved to `device` to avoid
    mutating the original dataset graphs (which would cause CPU/GPU mixing
    inside DataLoader).

    Parameters
    ----------
    data : graph object with ``.x`` (and whatever ``model`` reads from it)
    model : callable returning ``(x_hat, z)`` for a graph; put in eval mode here
    device : device the copied graph is moved to
    exp_dir : output directory (created if missing)
    filename_prefix : the figure is saved as ``<filename_prefix>.png``
    title_suffix : text appended to both panel titles
    """
    model.eval()
    with torch.no_grad():
        # Make a deep copy so we do NOT modify the original graph in test_graphs
        batch = copy.deepcopy(data).to(device)
        x_hat, _ = model(batch)
        x_hat = x_hat.cpu().numpy()

    # Use the original graph for inputs
    x_in = data.x.cpu().numpy()

    # Input features
    pt_in, eta_in, phi_in = x_in[:, 0], x_in[:, 1], x_in[:, 2]

    # Reconstructed features
    pt_out, eta_out, phi_out = x_hat[:, 0], x_hat[:, 1], x_hat[:, 2]

    # Look the global up explicitly. Assigning to a local named OBJ_TYPES in a
    # fallback branch would make every read of that name inside this function
    # local, so the notebook-level list would never be seen.
    obj_types = globals().get("OBJ_TYPES", ["MET", "e", "j", "mu", "ph"])

    # Object type from one-hot
    type_idx = x_in[:, 3:].argmax(axis=1)

    cmap = mpl.colormaps.get_cmap("tab10")
    node_colors = [cmap(int(i)) for i in type_idx]

    def pt_to_size(pt):
        # Marker area grows linearly with pT relative to the largest pT in the
        # panel; negative values are clipped to 0. A panel without positive pT
        # gets a constant size.
        pt = np.clip(pt, 0, None)
        if pt.max() <= 0:
            return np.full_like(pt, 30.0)
        pt_norm = pt / (pt.max() + 1e-8)
        return 20.0 + 80.0 * pt_norm

    # Common axis limits so the two panels are directly comparable.
    eta_min = min(eta_in.min(), eta_out.min())
    eta_max = max(eta_in.max(), eta_out.max())
    phi_min = min(phi_in.min(), phi_out.min())
    phi_max = max(phi_in.max(), phi_out.max())

    fig, axes = plt.subplots(1, 2, figsize=(10, 4), constrained_layout=True)

    panels = (
        # ---- Left: input ----
        (axes[0], eta_in, phi_in, pt_to_size(pt_in), "Input features "),
        # ---- Right: reconstruction ----
        (axes[1], eta_out, phi_out, pt_to_size(pt_out), "Reconstructed features "),
    )
    for ax, eta, phi, sizes, title in panels:
        ax.scatter(
            eta,
            phi,
            s=sizes,
            c=node_colors,
            alpha=0.8,
            edgecolors="k",
        )
        ax.set_xlabel(r"$\eta$")
        ax.set_ylabel(r"$\phi$")
        ax.set_title(title + title_suffix)
        ax.set_xlim(eta_min, eta_max)
        ax.set_ylim(phi_min, phi_max)
        ax.grid(True, alpha=0.3)

    # Legend for object types: one proxy marker per type, coloured like the nodes
    handles = [
        plt.Line2D(
            [], [],
            marker="o",
            linestyle="",
            markersize=8,
            markerfacecolor=cmap(i),
            markeredgecolor="k",
        )
        for i in range(len(obj_types))
    ]

    fig.legend(
        handles,
        list(obj_types),
        title="Object type",
        loc="upper center",
        bbox_to_anchor=(0.5, 0.02),
        ncol=len(obj_types),
    )

    os.makedirs(exp_dir, exist_ok=True)
    save_path = os.path.join(exp_dir, f"{filename_prefix}.png")
    # Save through the figure handle rather than the implicit "current figure".
    fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved {filename_prefix} plot to: {save_path}")


def plot_eta_phi_examples(
    model,
    device,
    test_graphs,
    test_labels,
    exp_dir,
    max_bkg=3,
    max_sig=3,
):
    """
    Make η–φ input-vs-reconstruction plots for the first few test graphs of
    each class: at most `max_bkg` with label 0, then at most `max_sig` with
    label 1.

    ``test_labels[i]`` must be the label of ``test_graphs[i]``. Each plot is
    produced by `plot_eta_phi_input_vs_reco_and_save` and written to
    ``exp_dir`` as ``bkg_example_<k>_eta_phi.png`` / ``sig_example_<k>_eta_phi.png``.
    """
    label_array = np.array(test_labels)
    bkg_positions = np.where(label_array == 0)[0]
    sig_positions = np.where(label_array == 1)[0]
    print(f"Found {len(bkg_positions)} background and {len(sig_positions)} signal test events.")

    n_bkg_shown = min(max_bkg, len(bkg_positions))
    n_sig_shown = min(max_sig, len(sig_positions))
    print(f"Showing {n_bkg_shown} background and {n_sig_shown} signal events.")

    # (positions in the test set, how many to draw, display name, file tag)
    class_plan = (
        (bkg_positions, n_bkg_shown, "Background", "bkg"),
        (sig_positions, n_sig_shown, "Signal", "sig"),
    )
    for positions, n_shown, display_name, file_tag in class_plan:
        for k in range(n_shown):
            idx = positions[k]
            print(f"\n{display_name} example {k+1} (test index = {idx})")
            plot_eta_phi_input_vs_reco_and_save(
                test_graphs[idx],
                model,
                device,
                exp_dir,
                filename_prefix=f"{file_tag}_example_{k+1}_eta_phi",
                title_suffix=f"({display_name})",
            )


In [13]:
# Cell 7: Core runner on a PyG graph list (single dataset + selection + conv)

def run_gae_on_graph_list(
    pyg_list,
    dataset_name,
    selection_label,
    conv_type="gcn",
    num_epochs=20,
    batch_size=128,
    hidden_channels=64,
    latent_channels=16,
    lr=1e-3,
    weight_decay=1e-5,
    dropout=0.1,
    random_state=42,
    sage_aggr="mean",
    use_batchnorm=False,
    n_bkg_examples=3,
    n_sig_examples=3,
):
    """
    End-to-end GAE experiment on one list of PyG graphs (with .x, .edge_index, .y).

    Steps, in order:
      1. split into train/val/test loaders (training on background only),
      2. build a GraphAutoEncoder whose input width is read from the first graph,
      3. train it,
      4. score the test set and compute the ROC curve / AUC,
      5. write loss, score-histogram, latent-PCA, ROC and η–φ example plots to
         ``GNN_UNSUP_BASE_DIR/<dataset_name>__<CONV>/<selection_label>/``.

    A failure while drawing the η–φ examples is reported with a warning and
    does not abort the run. The model is placed on the notebook-level `device`.

    Returns
    -------
    None if the split step declines (too few graphs of either class),
    otherwise a dict with keys
      'model', 'history', 'scores', 'latent', 'test_labels', 'fpr', 'tpr',
      'roc_auc', 'summary', 'selection_dir', 'conv_tag', 'sel_label_str'.
    'summary' holds dataset/selection/conv names, the AUC as a float and the
    split counts.
    """
    assert len(pyg_list) > 0, "pyg_list is empty."

    conv_type = conv_type.lower()
    conv_tag = conv_type.upper()
    run_tag = f"[{dataset_name}, {selection_label}, {conv_tag}]"

    graph_labels = np.array([int(g.y.item()) for g in pyg_list])
    n_background = int((graph_labels == 0).sum())
    n_signal = int((graph_labels == 1).sum())
    print(f"{run_tag} Graphs total={len(pyg_list)}, background={n_background}, signal={n_signal}")

    # --- Splits ---
    splits = make_gae_splits_and_loaders(
        pyg_list,
        batch_size=batch_size,
        random_state=random_state,
    )
    if splits is None:
        return None

    # The labels returned by the splitter are not needed: they are re-read
    # from the test loader when the scores are computed below.
    train_loader, val_loader, test_loader, _, split_info = splits
    test_graphs = list(test_loader.dataset)  # underlying list of graphs

    # --- Model ---
    in_channels = pyg_list[0].x.size(1)
    print("Inferred in_channels:", in_channels)

    model_kwargs = dict(
        in_channels=in_channels,
        hidden_channels=hidden_channels,
        latent_channels=latent_channels,
        dropout=dropout,
        conv_type=conv_type,
        sage_aggr=sage_aggr,
        use_batchnorm=use_batchnorm,
    )
    model = GraphAutoEncoder(**model_kwargs).to(device)

    # --- Train ---
    model, history = train_gae(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        num_epochs=num_epochs,
        lr=lr,
        weight_decay=weight_decay,
        verbose=True,
    )

    # --- Scores / latent on test ---
    scores, latent, test_labels = compute_gae_scores_and_latent(model, test_loader, device)

    # --- ROC ---
    fpr, tpr, roc_auc = compute_roc_auc(scores, test_labels, pos_label=1)
    print(f"{run_tag} ROC AUC = {roc_auc:.3f}")

    # --- Directories & label strings ---
    sel_label_str = f"{dataset_name}, {selection_label}"
    selection_dir = os.path.join(
        GNN_UNSUP_BASE_DIR,
        f"{dataset_name}__{conv_tag}",
        selection_label,
    )
    os.makedirs(selection_dir, exist_ok=True)

    # --- Plots --- (all four share the same label / directory / tag arguments)
    plot_context = (sel_label_str, selection_dir, conv_tag)
    plot_gae_loss_curves(history, *plot_context)
    plot_anomaly_score_distribution(scores, test_labels, *plot_context)
    plot_latent_space(latent, test_labels, *plot_context)
    plot_roc_curve_single(fpr, tpr, roc_auc, *plot_context)

    # η–φ examples (up to a few background + signal events); best effort only
    try:
        plot_eta_phi_examples(
            model,
            device,
            test_graphs,
            test_labels,
            selection_dir,
            max_bkg=n_bkg_examples,
            max_sig=n_sig_examples,
        )
    except Exception as e:
        print("Warning: η–φ plotting failed with error:", e)

    # --- Summary dict ---
    summary = {
        "dataset_name": dataset_name,
        "selection_label": selection_label,
        "conv_type": conv_type,
        "roc_auc": float(roc_auc),
    }
    summary.update(split_info)

    return {
        "model": model,
        "history": history,
        "scores": scores,
        "latent": latent,
        "test_labels": test_labels,
        "fpr": fpr,
        "tpr": tpr,
        "roc_auc": roc_auc,
        "summary": summary,
        "selection_dir": selection_dir,
        "conv_tag": conv_tag,
        "sel_label_str": sel_label_str,
    }


In [14]:
# Cell 8: Single-dataset driver with optional jet pT threshold

def run_gae_unsupervised_single(
    dataset_name="Znunu",
    conv_type="gcn",
    pt_threshold=None,
    num_epochs=20,
    batch_size=128,
    hidden_channels=64,
    latent_channels=16,
    lr=1e-3,
    weight_decay=1e-5,
    dropout=0.1,
    random_state=42,
    sage_aggr="mean",
    use_batchnorm=False,
    max_events=None,
):
    """
    One GAE anomaly-detection experiment: one dataset, one convolution type,
    and optionally one jet pT threshold.

    - ``pt_threshold=None``: all events of ``ML_dict[dataset_name]`` are used
      and the selection is labelled "noPtCut".
    - otherwise: events are filtered with apply_jet_pt_threshold and the
      selection is labelled "T<pt_threshold>GeV".
    - ``max_events``: if set and exceeded, only the first ``max_events`` rows
      (after the pT selection) are kept.
    - Training is background-only; evaluation uses held-out background + signal.

    Requires the notebook-level ``ML_dict`` and ``build_obj_df_and_pyg_dataset``.

    Returns
    -------
    dict from run_gae_on_graph_list, or None if no events survive the
    selection, fewer than 20 graphs are built, or the runner itself returns None.
    """
    assert "ML_dict" in globals(), "ML_dict not found; please run the cell that loads the pickle."
    assert "build_obj_df_and_pyg_dataset" in globals(), "build_obj_df_and_pyg_dataset not defined; run that cell first."

    banner = "#" * 80
    events = ML_dict[dataset_name].copy()
    print("\n" + banner)
    print(f"GAE unsupervised run on dataset: {dataset_name} (conv={conv_type})")

    if pt_threshold is None:
        selection_label = "noPtCut"
        print(f"No jet pT threshold applied → events: {len(events)}")
    else:
        selection_label = f"T{pt_threshold}GeV"
        events = apply_jet_pt_threshold(events, pt_threshold)
        print(f"Applying jet pT threshold: {pt_threshold} GeV → events surviving: {len(events)}")

    if len(events) == 0:
        print("No events left after selection; skipping.")
        return None

    if max_events is not None and len(events) > max_events:
        events = events.iloc[:max_events].copy()
        print(f"Downsampling to first {max_events} events.")

    # Build PyG dataset (the per-object dataframe is not needed here)
    _, pyg_list = build_obj_df_and_pyg_dataset(events)

    if len(pyg_list) < 20:
        print("Not enough graphs to train/evaluate; skipping.")
        return None

    # Hyperparameters are forwarded to the core runner unchanged.
    hyperparams = dict(
        conv_type=conv_type,
        num_epochs=num_epochs,
        batch_size=batch_size,
        hidden_channels=hidden_channels,
        latent_channels=latent_channels,
        lr=lr,
        weight_decay=weight_decay,
        dropout=dropout,
        random_state=random_state,
        sage_aggr=sage_aggr,
        use_batchnorm=use_batchnorm,
    )
    result = run_gae_on_graph_list(
        pyg_list=pyg_list,
        dataset_name=dataset_name,
        selection_label=selection_label,
        **hyperparams,
    )

    print(banner + "\n")
    return result


In [15]:
# Cell 9: Compare GCN / GraphSAGE / GIN at a single jet pT selection

def run_gae_compare_convs_single(
    dataset_name="Znunu",
    conv_types=("gcn", "sage", "gin"),
    pt_threshold=None,
    num_epochs=20,
    batch_size=128,
    hidden_channels=64,
    latent_channels=16,
    lr=1e-3,
    weight_decay=1e-5,
    dropout=0.1,
    random_state=42,
    sage_aggr="mean",
    use_batchnorm=False,
    max_events=None,
):
    """
    Run the same GAE experiment once per convolution type, on a fixed dataset
    and (optional) jet pT selection, then print the ROC AUC of each.

    Every entry of ``conv_types`` is passed to run_gae_unsupervised_single
    together with the remaining arguments, which are identical for all runs.

    Returns
    -------
    results: dict[conv_type] -> dict (output of run_gae_unsupervised_single),
             or None for a conv_type whose run produced no result.
    """
    # Settings that do not change from one convolution type to the next.
    shared_settings = dict(
        dataset_name=dataset_name,
        pt_threshold=pt_threshold,
        num_epochs=num_epochs,
        batch_size=batch_size,
        hidden_channels=hidden_channels,
        latent_channels=latent_channels,
        lr=lr,
        weight_decay=weight_decay,
        dropout=dropout,
        random_state=random_state,
        sage_aggr=sage_aggr,
        use_batchnorm=use_batchnorm,
        max_events=max_events,
    )
    rule = "=" * 80

    results = {}
    for conv_type in conv_types:
        print("\n" + rule)
        print(f"Running conv_type = {conv_type}")
        print(rule)
        results[conv_type] = run_gae_unsupervised_single(conv_type=conv_type, **shared_settings)

    # Quick summary table
    print("\nSummary of ROC AUCs:")
    for conv_type in conv_types:
        outcome = results.get(conv_type)
        if outcome is not None:
            auc_val = outcome["roc_auc"]
            print(f"  {conv_type:>4s}:  AUC = {auc_val:.3f}")
        else:
            print(f"  {conv_type:>4s}:  (no result)")

    return results


In [16]:
# Cell 10: Full jet pT scan for multiple convolutions

def run_gae_compare_convs_pt_scan(
    dataset_name="Znunu",
    conv_types=("gcn", "sage", "gin"),
    pt_thresholds=None,
    num_epochs=20,
    batch_size=128,
    hidden_channels=64,
    latent_channels=16,
    lr=1e-3,
    weight_decay=1e-5,
    dropout=0.1,
    random_state=42,
    sage_aggr="mean",
    use_batchnorm=False,
    max_events=None,
):
    """
    Scan jet pT thresholds and compare GAE convolution types at each one.

    For every threshold the dataframe ``ML_dict[dataset_name]`` is passed
    through ``apply_jet_pt_threshold``, optionally truncated to the first
    ``max_events`` rows, and converted to graphs once with
    ``build_obj_df_and_pyg_dataset``. The same graph list is then given to
    ``run_gae_on_graph_list`` for each convolution type. A threshold is
    skipped when no events survive the cut or when fewer than 20 graphs
    carry label 0 (background) or label 1 (signal).

    After the scan, per-convolution ROC and AUC-vs-threshold plots are written
    under ``GNN_UNSUP_BASE_DIR/<dataset>__<CONV>/pt_scan`` and the threshold
    with the highest AUC is printed.

    Parameters
    ----------
    dataset_name : str
        Key into the global ``ML_dict``.
    conv_types : iterable of str
        Convolution types to train at every threshold.
    pt_thresholds : iterable, optional
        Thresholds to scan; defaults to the global ``JET_PT_THRESHOLDS``.
    max_events : int, optional
        If set, keep only the first ``max_events`` surviving events.
    num_epochs, batch_size, hidden_channels, latent_channels, lr,
    weight_decay, dropout, random_state, sage_aggr, use_batchnorm
        Forwarded unchanged to ``run_gae_on_graph_list``.

    Returns
    -------
    all_results : dict
        ``all_results[conv_type][threshold]`` is the result returned by
        ``run_gae_on_graph_list``; skipped thresholds have no entry.
    """
    if pt_thresholds is None:
        pt_thresholds = JET_PT_THRESHOLDS

    assert "ML_dict" in globals(), "ML_dict not found; please run the cell that loads the pickle."
    assert "build_obj_df_and_pyg_dataset" in globals(), "build_obj_df_and_pyg_dataset not defined; run that cell first."

    hash_rule = "#" * 80
    dash_rule = "-" * 80

    df_full = ML_dict[dataset_name].copy()
    print("\n" + hash_rule)
    print(f"Full pT scan for dataset: {dataset_name}")
    print(f"Total events before any jet-pt selection: {len(df_full)}")

    # Hyperparameters shared by every (threshold, conv_type) training run.
    train_kwargs = dict(
        dataset_name=dataset_name,
        num_epochs=num_epochs,
        batch_size=batch_size,
        hidden_channels=hidden_channels,
        latent_channels=latent_channels,
        lr=lr,
        weight_decay=weight_decay,
        dropout=dropout,
        random_state=random_state,
        sage_aggr=sage_aggr,
        use_batchnorm=use_batchnorm,
    )

    all_results = {}
    for conv in conv_types:
        all_results[conv] = {}

    for thr in pt_thresholds:
        print("\n" + dash_rule)
        print(f"Jet pT threshold T = {thr} GeV")
        print(dash_rule)

        # The helper gets its own copy so df_full is never touched.
        df_cut = apply_jet_pt_threshold(df_full.copy(), thr)
        n_surviving = len(df_cut)
        print(f"Events surviving after jet pT cut: {n_surviving}")

        if n_surviving == 0:
            print("No events survive; skipping this threshold.")
            continue

        if max_events is not None and n_surviving > max_events:
            df_cut = df_cut.iloc[:max_events].copy()
            print(f"Downsampling to first {max_events} events.")

        # Graphs are built once per threshold and shared by all conv types.
        _, graphs = build_obj_df_and_pyg_dataset(df_cut)
        labels = np.array([graph.y.item() for graph in graphs])

        n_graphs = len(graphs)
        n_bkg = int(np.count_nonzero(labels == 0))
        n_sig = int(np.count_nonzero(labels == 1))
        print(f"[T={thr}] Graphs total={n_graphs}, background={n_bkg}, signal={n_sig}")

        if min(n_bkg, n_sig) < 20:
            print(f"[T={thr}] Not enough background or signal graphs; skipping this threshold.")
            continue

        selection_label = f"T{thr}GeV"
        for conv in conv_types:
            print(f"\n>>> Threshold T={thr} GeV, conv_type={conv}")
            all_results[conv][thr] = run_gae_on_graph_list(
                pyg_list=graphs,
                selection_label=selection_label,
                conv_type=conv,
                **train_kwargs,
            )

    # After the scan, make AUC-vs-threshold and multi-ROC plots per convolution
    for conv in conv_types:
        roc_by_threshold = {}
        for thr, res in all_results[conv].items():
            if res is None:
                continue
            roc_by_threshold[thr] = (res["fpr"], res["tpr"], res["roc_auc"])

        if len(roc_by_threshold) == 0:
            continue

        conv_tag = conv.upper()
        save_dir = os.path.join(
            GNN_UNSUP_BASE_DIR,
            f"{dataset_name}__{conv_tag}",
            "pt_scan",
        )

        plot_gae_roc_all_thresholds(roc_by_threshold, dataset_name, save_dir, conv_tag)
        plot_gae_auc_vs_threshold(roc_by_threshold, dataset_name, save_dir, conv_tag)

        # Report the threshold whose ROC AUC (tuple index 2) is largest.
        best_thr = max(roc_by_threshold, key=lambda t: roc_by_threshold[t][2])
        best_auc = roc_by_threshold[best_thr][2]
        print(
            f"\nBest jet pT threshold for conv={conv} on {dataset_name}: "
            f"T={best_thr} GeV with AUC={best_auc:.3f}"
        )

    print("\n" + hash_rule + "\nFinished full pT scan.\n" + hash_rule)
    return all_results


### Hyperparameter scan

Takes approximately 1.5 hrs

In [17]:
# Cell X : Hyperparameter scan on a subset of "all_signals"
#          and show 3 best configs per conv_type
#
# Every configuration of the grid below is trained with run_gae_on_graph_list
# on the same graph subset. The ROC AUC of each successful run is collected in
# `search_results` and ranked at the end of the cell.

# Only needed to silence the per-run training output during the scan.
import contextlib
import io

# --------------------------------------------------------------------------------------
# 1. Build or reuse a PyG graph subset for the "all_signals" dataset
# --------------------------------------------------------------------------------------

MAX_SUBSET_GRAPHS = 2000   # small subset for quick scans; change if you want
RANDOM_STATE = 42

if "pyg_list_all_signals_subset" in globals():
    # Reuse existing subset to avoid recomputing graphs.
    # NOTE: the cached subset is reused as-is even if MAX_SUBSET_GRAPHS or
    # RANDOM_STATE were edited; `del pyg_list_all_signals_subset` to rebuild it.
    pyg_subset = pyg_list_all_signals_subset
    print(f"Reusing existing pyg_list_all_signals_subset with {len(pyg_subset)} graphs.")
else:
    # Build graphs from the "all_signals" dataframe
    if "ML_dict" not in globals():
        raise RuntimeError("ML_dict is not defined. Make sure you ran the data-loading cells.")

    print("Building PyG graphs for 'all_signals'...")
    df_all = ML_dict["all_signals"]
    obj_df_all, pyg_list_all_signals = build_obj_df_and_pyg_dataset(df_all)
    print(f"Total graphs in 'all_signals': {len(pyg_list_all_signals)}")

    # Seeded draw without replacement so the subset is reproducible.
    rng = np.random.default_rng(RANDOM_STATE)
    subset_size = min(MAX_SUBSET_GRAPHS, len(pyg_list_all_signals))
    subset_indices = rng.choice(len(pyg_list_all_signals), size=subset_size, replace=False)
    pyg_subset = [pyg_list_all_signals[i] for i in subset_indices]

    # Cache for later cells, so you don't rebuild
    pyg_list_all_signals_subset = pyg_subset
    print(f"Using subset of {subset_size} graphs for hyperparameter search.")

# --------------------------------------------------------------------------------------
# 2. Define hyperparameter grid for the scan (includes lr_list)
# --------------------------------------------------------------------------------------

conv_types       = ["gcn", "sage", "gin"]
hidden_list      = [32, 64, 128]
latent_list      = [4, 8, 16]
dropout_list     = [0.0, 0.1, 0.2]
batchnorm_list   = [False, True]
sage_aggr_list   = ["mean", "max", "sum"]   # only used for SAGE
lr_list          = [1e-4, 5e-4, 1e-3]   # learning rates to scan

NUM_EPOCHS   = 10
BATCH_SIZE   = 100
WEIGHT_DECAY = 1e-5

search_results = []
config_id = 0

print("\nStarting hyperparameter scan on 'all_signals' subset...\n")

# Total number of configurations, used for the end-of-scan success count.
total_configs = 0
for conv_type in conv_types:
    num_aggrs = len(sage_aggr_list) if conv_type == "sage" else 1
    total_configs += len(hidden_list) * len(latent_list) * len(dropout_list) * len(batchnorm_list) * len(lr_list) * num_aggrs

# --------------------------------------------------------------------------------------
# 3. Run the scan using your existing run_gae_on_graph_list
# --------------------------------------------------------------------------------------

for conv_type in conv_types:
    for hidden_channels, latent_channels, dropout, use_batchnorm, lr in itertools.product(
        hidden_list, latent_list, dropout_list, batchnorm_list, lr_list
    ):
        # For SAGE we scan over aggregators; for GCN/GIN we just fix "mean"
        aggr_list = sage_aggr_list if conv_type == "sage" else ["mean"]

        for sage_aggr in aggr_list:
            config_id += 1

            cfg = dict(
                conv_type=conv_type,
                hidden_channels=hidden_channels,
                latent_channels=latent_channels,
                dropout=dropout,
                use_batchnorm=use_batchnorm,
                sage_aggr=sage_aggr,
                lr=lr,
                num_epochs=NUM_EPOCHS,
                batch_size=BATCH_SIZE,
                weight_decay=WEIGHT_DECAY,
            )

            # Silence the training output of run_gae_on_graph_list. Any exception
            # is only *recorded* inside the redirect and reported after it has
            # been left: a print() issued while stdout is redirected would land
            # in the discarded buffer and the failure would never be shown.
            failure = None
            result = None
            with contextlib.redirect_stdout(io.StringIO()):
                try:
                    result = run_gae_on_graph_list(
                        pyg_list=pyg_subset,
                        dataset_name="all_signals",
                        selection_label=f"subset_{len(pyg_subset)}",
                        random_state=RANDOM_STATE,
                        # Avoid generating tons of eta-phi plots during the scan
                        n_bkg_examples=0,
                        n_sig_examples=0,
                        **cfg,
                    )
                except Exception as exc:
                    failure = exc

            if failure is not None:
                # Now on the real stdout, so the message is actually visible.
                print(f"--> Config {config_id} FAILED with error: {failure}")
                continue

            if result is None:
                # Other callers in this notebook guard against a None result;
                # do the same here instead of crashing the scan on result.get().
                print(f"--> Config {config_id} returned no result; skipping.")
                continue

            entry = dict(cfg)
            entry["roc_auc"] = float(result.get("roc_auc", np.nan))
            search_results.append(entry)

print(f"\nScan finished: {len(search_results)}/{total_configs} configurations succeeded.")

# --------------------------------------------------------------------------------------
# 4. Collect and print the best hyperparameters
# --------------------------------------------------------------------------------------

if not search_results:
    print("\nNo successful configurations in the scan.")
else:
    hp_df = pd.DataFrame(search_results)

    # 3 best per conv_type
    best_hparams_all_signals_subset_by_conv = {}

    print("\n================ Best 3 configurations per conv_type ================\n")
    for conv in conv_types:
        sub = hp_df[hp_df["conv_type"] == conv]
        if sub.empty:
            print(f"\nNo successful configs for conv_type = {conv}")
            continue

        # sort_values places NaN AUCs last, so they never rank as "best"
        # unless every run of this conv_type produced NaN.
        sub_sorted = sub.sort_values("roc_auc", ascending=False).reset_index(drop=True)
        top_k = sub_sorted.head(3)

        print(f"\n--- {conv.upper()} ---")
        print(top_k)

        # Store the single best config for this conv_type
        best_hparams_all_signals_subset_by_conv[conv] = top_k.iloc[0].to_dict()

    # Overall best configuration (across all conv_types)
    hp_df_sorted_global = hp_df.sort_values("roc_auc", ascending=False).reset_index(drop=True)
    best_global = hp_df_sorted_global.iloc[0]

    print("\n================ Best configuration overall (all conv_types) ================\n")
    for col in hp_df_sorted_global.columns:
        print(f"{col}: {best_global[col]}")

    # Also store global-best config if you want to reuse it
    best_hparams_all_signals_subset = best_global.to_dict()


Building PyG graphs for 'all_signals'...
Total graphs in 'all_signals': 604611
Using subset of 2000 graphs for hyperparameter search.

Starting hyperparameter scan on 'all_signals' subset...



  self.explained_variance_ratio_ = self.explained_variance_ / total_var
  self.explained_variance_ratio_ = self.explained_variance_ / total_var
  self.explained_variance_ratio_ = self.explained_variance_ / total_var
  self.explained_variance_ratio_ = self.explained_variance_ / total_var





--- GCN ---
  conv_type  hidden_channels  latent_channels  dropout  use_batchnorm  \
0       gcn               64               16      0.1          False   
1       gcn               64               16      0.0          False   
2       gcn               64                8      0.1          False   

  sage_aggr      lr  num_epochs  batch_size  weight_decay   roc_auc  
0      mean  0.0005          10         100       0.00001  0.905542  
1      mean  0.0005          10         100       0.00001  0.904822  
2      mean  0.0010          10         100       0.00001  0.904502  

--- SAGE ---
  conv_type  hidden_channels  latent_channels  dropout  use_batchnorm  \
0      sage              128               16      0.1          False   
1      sage               32                4      0.0          False   
2      sage              128               16      0.0          False   

  sage_aggr      lr  num_epochs  batch_size  weight_decay   roc_auc  
0       sum  0.0010          10    

### Final scan

In [None]:
# Cell 11: Example usage with hyperparameter control

# Example 1: NO pT cut, compare GCN / SAGE / GIN on "Znunu"
#            with dropout=0.2, the "max" SAGE aggregator, and BatchNorm enabled.
# gae_conv_results_all_signals_nopt = run_gae_compare_convs_single(
#     dataset_name="Znunu",
#     conv_types=("gcn", "sage", "gin"),
#     pt_threshold=None,
#     num_epochs=20,
#     hidden_channels=64,
#     latent_channels=16,
#     lr=1e-3,
#     weight_decay=1e-5,
#     dropout=0.2,
#     random_state=42,
#     sage_aggr="max",
#     use_batchnorm=True,
# )

# Example 2: ONE fixed pT cut (25 GeV) on "all_signals"; 10 epochs and the
#            "mean" SAGE aggregator, other hyperparameters as in Example 1
# gae_conv_results_all_signals_T25 = run_gae_compare_convs_single(
#     dataset_name="all_signals",
#     conv_types=("gcn", "sage", "gin"),
#     pt_threshold=25,
#     num_epochs=10,
#     hidden_channels=64,
#     latent_channels=16,
#     lr=1e-3,
#     weight_decay=1e-5,
#     dropout=0.2,
#     random_state=42,
#     sage_aggr="mean",
#     use_batchnorm=True,
# )

# Example 3: FULL pT scan for all three convolutions (optional, slower)
# gae_conv_results_all_signals_scan = run_gae_compare_convs_pt_scan(
#     dataset_name="all_signals",
#     conv_types=("gcn", "sage", "gin"),
#     pt_thresholds=JET_PT_THRESHOLDS,
#     num_epochs=10,
#     batch_size=128,
#     hidden_channels=64,
#     latent_channels=16,
#     lr=1e-3,
#     weight_decay=1e-5,
#     dropout=0.2,
#     random_state=42,
#     sage_aggr="mean",
#     use_batchnorm=True,
# )


In [None]:
# Dataset tags that the following cells loop over; each one is passed as
# `dataset_name` to the GAE comparison functions.
hlt_data_all_tags = [
    "all_signals",
    "HAHMggf",
    "HNLeemu",
    "HtoSUEP",
    "VBF_H125_a55a55_4b_ctau1_filtered",
    "Znunu",
    "ggF_H125_a16a16_4b_ctau10_filtered",
    "hh_bbbb_vbf_novhh_5fs_l1cvv1cv1",
]

In [None]:
# Compare GCN / GraphSAGE / GIN without any jet pT cut, once per dataset tag.
# Results are stored per tag: the original loop re-assigned one variable on
# every iteration, so only the last dataset's results survived.
gae_conv_results_nopt_by_tag = {}

for tag in hlt_data_all_tags:
    gae_conv_results_nopt_by_tag[tag] = run_gae_compare_convs_single(
        dataset_name=tag,
        conv_types=("gcn", "sage", "gin"),
        pt_threshold=None,
        num_epochs=20,
        hidden_channels=64,
        latent_channels=16,
        lr=1e-3,
        weight_decay=1e-5,
        dropout=0.2,
        random_state=42,
        sage_aggr="mean",
        use_batchnorm=True,
    )

# Backward-compatible alias: despite its name, this variable has always ended
# up holding the results of the LAST tag in the list, not of "all_signals".
if gae_conv_results_nopt_by_tag:
    gae_conv_results_all_signals_nopt = gae_conv_results_nopt_by_tag[tag]

In [None]:
# Full jet pT scan for GCN / GraphSAGE / GIN, once per dataset tag.
# Results are stored per tag: the original loop re-assigned one variable on
# every iteration, so only the last dataset's scan survived.
gae_conv_results_scan_by_tag = {}

for tag in hlt_data_all_tags:
    gae_conv_results_scan_by_tag[tag] = run_gae_compare_convs_pt_scan(
        dataset_name=tag,
        conv_types=("gcn", "sage", "gin"),
        pt_thresholds=JET_PT_THRESHOLDS,
        num_epochs=10,
        batch_size=128,
        hidden_channels=64,
        latent_channels=16,
        lr=1e-3,
        weight_decay=1e-5,
        dropout=0.2,
        random_state=42,
        sage_aggr="mean",
        use_batchnorm=True,
    )

# Backward-compatible alias: despite its name, this variable has always ended
# up holding the scan of the LAST tag in the list, not of "all_signals".
if gae_conv_results_scan_by_tag:
    gae_conv_results_all_signals_scan = gae_conv_results_scan_by_tag[tag]

**To load the model in a different notebook:**

```
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import os

# --- Step 1: Mount Google Drive (if your model is saved there) ---
# This is necessary if your model is in MyDrive. Skip if not applicable.
from google.colab import drive
drive.mount('/content/drive')

# --- Step 2: Re-define the GraphAutoEncoder class (EXACTLY as it was trained) ---
# This class definition must match the one from cell TgGOtSB2GnKe when the model was trained.
class GraphAutoEncoder(nn.Module):
    """
    Shallow graph auto-encoder: one GCNConv encoder layer, one linear decoder.

    The attribute names ``conv1`` and ``dec_lin1`` become the parameter keys of
    the state dict, so they must stay as they are for a saved checkpoint of
    this architecture to load.
    """

    def __init__(
        self,
        in_channels=8,
        hidden_channels1=64, # These parameters are part of the __init__ signature
        hidden_channels2=32, # but were not used in the 'shallow' model's layers.
        latent_channels=16,
    ):
        """
        Build the encoder and decoder layers.

        Parameters
        ----------
        in_channels : int
            Number of input node features; also the decoder output size.
        hidden_channels1, hidden_channels2 : int
            Accepted but not used by any layer created here.
        latent_channels : int
            Size of the latent node embedding.
        """
        super().__init__()

        # --- Encoder GNN: 1 layer ---
        self.conv1 = GCNConv(in_channels, latent_channels)

        # --- Decoder MLP: 1 layer ---
        self.dec_lin1 = nn.Linear(latent_channels, in_channels)

    def encode(self, x, edge_index):
        """Return latent node embeddings: ReLU applied to the GCNConv output."""
        x = self.conv1(x, edge_index)
        z = F.relu(x)  # Latent node embeddings
        return z

    def decode(self, z):
        """Map latent embeddings back to feature space with the linear decoder."""
        x_hat = self.dec_lin1(z)
        return x_hat

    def forward(self, data):
        """
        Encode then decode the node features of ``data``.

        ``data`` must expose ``x`` and ``edge_index`` attributes.
        Returns the tuple ``(x_hat, z)``: reconstruction and latent embedding.
        """
        x, edge_index = data.x, data.edge_index
        z = self.encode(x, edge_index)
        x_hat = self.decode(z)
        return x_hat, z

# Define the device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Step 3: Instantiate the model with the EXACT parameters used during training ---
# These values must match those found in cell `5mQeb_tVI3dq` of the training notebook.
# Only in_channels and latent_channels determine the layer shapes of the class
# defined above; the two hidden sizes are accepted by __init__ but unused.
TRAINING_IN_CHANNELS = 8
TRAINING_HIDDEN_CHANNELS1 = 128
TRAINING_HIDDEN_CHANNELS2 = 64
TRAINING_LATENT_CHANNELS = 4

loaded_model = GraphAutoEncoder(
    in_channels=TRAINING_IN_CHANNELS,
    hidden_channels1=TRAINING_HIDDEN_CHANNELS1,
    hidden_channels2=TRAINING_HIDDEN_CHANNELS2,
    latent_channels=TRAINING_LATENT_CHANNELS,
).to(device)

# --- Step 4: Define the path to your saved model file ---
# IMPORTANT: You MUST update this path to where YOUR model is saved.
# The example path below uses the last saved path from this notebook.
model_save_path = "/content/drive/MyDrive/gae_experiments/exp_20251125_100949/gae_model_state_dict.pth"

# --- Step 5 & 6: Load the state dictionary and set to evaluation mode ---
if not os.path.exists(model_save_path):
    # In this branch nothing is loaded: `loaded_model` keeps the freshly
    # initialised (untrained) weights it was constructed with.
    print(f"Error: Model file not found at {model_save_path}")
    print("Please ensure Google Drive is mounted and the path is correct.")
else:
    # Load the state dictionary, mapping to the current device
    # NOTE(review): torch.load unpickles the file, so only load checkpoints you
    # trust. Recent PyTorch versions accept weights_only=True for plain state
    # dicts - confirm against the installed version before relying on it.
    loaded_state_dict = torch.load(model_save_path, map_location=device)

    # Apply the loaded state dictionary to the model
    loaded_model.load_state_dict(loaded_state_dict)

    # Set the model to evaluation mode (important for inference)
    loaded_model.eval()

    print(f"Model successfully loaded from: {model_save_path}")
    print("Loaded Model Architecture:")
    print(loaded_model)

    # You can now use `loaded_model` for predictions in your new notebook.
```

