# Graph Autoencoder (GAE) for Anomaly Detection on Event Graphs

This notebook builds a flexible **Graph Autoencoder** for anomaly detection,
using **PyTorch Geometric**.

**Design choices:**

- Events → graphs:
  - Nodes = physics objects (jets, e, μ, γ, MET).
  - Node features: `[pT, η, φ, one-hot(type)]`.
  - Empty objects are dropped (based on `pT <= 0`).
  - Fully connected directed graphs (all pairs `i ≠ j`).
  - Edge attributes include:
    - `Δη`, `Δφ` (wrapped to `[-π, π]`), `ΔR`,
    - `log(pT_i + 1)`, `log(pT_j + 1)`,
    - `(pT_i - pT_j) / (pT_i + pT_j + ε)`.

- Labels:
  - `target == "EB_test"` → **background** (label 0).
  - Everything else → **signal** (label 1).
  - Train **only on background**, test on background + signal.

- Convolution types (configurable):
  - `SAGEConv` (ignores edge_attr, baseline).
  - `NNConv` (edge_attr → dynamic filters).
  - `GINEConv` (edge-aware GIN).
  - `TransformerConv` (attention with edge_attr).

- Loss:
  - Node feature **MSE** + optional **exponential loss**:

$$
L = \text{MSE} + \alpha \, \mathbb{E}\big[\exp(\beta \cdot \text{MSE}_\text{node}) - 1\big]
$$
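
As a rough numerical illustration of the formula above (a sketch in plain Python rather than the notebook's PyTorch implementation; `toy_loss` and the sample error values are made up for this example), the exponential term stays close to `β · MSE_node` for well-reconstructed nodes and grows quickly once a single node is badly reconstructed:

```python
import math

def toy_loss(mse_per_node, alpha=1.0, beta=1.0):
    """L = mean(MSE_node) + alpha * mean(exp(beta * MSE_node) - 1)."""
    n = len(mse_per_node)
    mse = sum(mse_per_node) / n
    exp_term = sum(math.exp(beta * m) - 1.0 for m in mse_per_node) / n
    return mse + alpha * exp_term, mse, exp_term

# Hypothetical per-node errors: four good nodes vs. three good nodes and one outlier.
well_reconstructed = [0.01, 0.02, 0.01, 0.02]
with_outlier = [0.01, 0.02, 0.01, 3.0]

for name, errs in [("no outlier", well_reconstructed), ("one outlier", with_outlier)]:
    loss, mse, exp_term = toy_loss(errs)
    print(f"{name}: mse={mse:.3f} exp_term={exp_term:.3f} loss={loss:.3f}")
```

Setting `alpha=0` recovers the plain MSE loss.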


- Visualisations:
  - Training / validation loss curves.
  - Anomaly score histograms + ROC AUC.
  - Latent space (PCA, optional t-SNE).
  - η–φ plots: **input vs reconstruction** for a few example events.
  - All plots are saved under a chosen `exp_dir` directory.

- Hyperparameter grid search:
  - Simple loop over a small config grid (conv type, hidden size, latent size, lr, etc.).
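
A grid of this kind might be enumerated as below. This is only a sketch: the grid values are placeholders, and `expand_grid` is a hypothetical helper that is not defined elsewhere in this notebook. The keys mirror keyword arguments of `run_full_pipeline`.

```python
from itertools import product

# Placeholder grid; keys mirror keyword arguments of run_full_pipeline.
param_grid = {
    "conv_type": ["sage", "gine"],
    "hidden_dim": [32, 64],
    "latent_dim": [8, 16],
    "lr": [1e-3],
}

def expand_grid(grid):
    """Yield one config dict per combination of the grid values."""
    keys = list(grid)
    for values in product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))

configs = list(expand_grid(param_grid))
print(len(configs), "configurations")
print(configs[0])
```

Each dict could then be passed on as `run_full_pipeline(ML_dict, **cfg)`.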


In [1]:
# All non-PyTorch imports here

import os
import sys
import math
import copy

import pickle

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib as mpl

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import roc_auc_score, roc_curve

# For reproducibility
import random

SEED = 42
np.random.seed(SEED)
random.seed(SEED)


# Colab drive (only needed if you actually run in Google Colab)
# from google.colab import drive
# drive.mount('/content/drive')


In [2]:
modules_to_check = [
    "torch_scatter",
    "torch_sparse",
    "torch_cluster",
    "torch_spline_conv",
    "torch_geometric",
]

print("--- Checking PyTorch Geometric related module installations ---")
for module_name in modules_to_check:
    try:
        __import__(module_name)
        print(f"{module_name}: Installed")
    except ImportError:
        print(f"{module_name}: NOT installed")
print("----------------------------------------------------------")

--- Checking PyTorch Geometric related module installations ---
torch_scatter: NOT installed
torch_sparse: NOT installed
torch_cluster: NOT installed
torch_spline_conv: NOT installed
torch_geometric: NOT installed
----------------------------------------------------------


In [3]:
# Install PyTorch Geometric and dependencies

# Typical Colab install sequence; you may adjust CUDA versions if needed.
!pip install -q torch-scatter torch-sparse torch-cluster torch-geometric

import torch
print("Torch version (after install):", torch.__version__)


  Preparing metadata (setup.py) ... done
  Preparing metadata (setup.py) ... done
  Preparing metadata (setup.py) ... done
  Building wheel for torch-scatter (setup.py) ... done
  Building wheel for torch-sparse (setup.py) ... done
  Building wheel for torch-cluster (setup.py) ... done
Torch version (aft

In [4]:
# All PyTorch / PyTorch Geometric related imports here

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import random_split

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import (
    SAGEConv,
    NNConv,
    GINEConv,
    TransformerConv,
    global_mean_pool,
)

# Seed PyTorch's RNG as well (weight init, dropout, DataLoader shuffling)
torch.manual_seed(SEED)


In [5]:
# Set plotting style at module level
plt.rcParams.update({
    # Font sizes
    'font.size': 18,
    'axes.labelsize': 18,
    'axes.titlesize': 18,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'legend.fontsize': 16,
    'legend.frameon': False,  # No box around legend
    'axes.grid': False,
    # Tick settings
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'xtick.major.size': 10,
    'ytick.major.size': 10,
    'xtick.minor.size': 5,
    'ytick.minor.size': 5,
    'xtick.major.width': 1,
    'ytick.major.width': 1,
    'xtick.top': True,
    'ytick.right': True,
    'xtick.minor.visible': True,
    'ytick.minor.visible': True
})

In [6]:
with open("/kaggle/input/balanced-with-dr-or/balanced_dfs_no_dup_OR.pkl", "rb") as f:
    ML_dict = pickle.load(f)

## Details of the data
`ML_dict` is a dictionary of DataFrames with the following keys:

```
'all_signals', 'HAHMggf', 'HNLeemu', 'HtoSUEP',
'VBF_H125_a55a55_4b_ctau1_filtered', 'Znunu',
'ggF_H125_a16a16_4b_ctau10_filtered', 'hh_bbbb_vbf_novhh_5fs_l1cvv1cv1'
```
All of them have the same columns:

```
'j0pt', 'j0eta', 'j0phi', 'j1pt', 'j1eta', 'j1phi', 'j2pt', 'j2eta',
       'j2phi', 'j3pt', 'j3eta', 'j3phi', 'j4pt', 'j4eta', 'j4phi', 'j5pt',
       'j5eta', 'j5phi', 'e0pt', 'e0eta', 'e0phi', 'e1pt', 'e1eta', 'e1phi',
       'e2pt', 'e2eta', 'e2phi', 'mu0pt', 'mu0eta', 'mu0phi', 'mu1pt',
       'mu1eta', 'mu1phi', 'mu2pt', 'mu2eta', 'mu2phi', 'ph0pt', 'ph0eta',
       'ph0phi', 'ph1pt', 'ph1eta', 'ph1phi', 'ph2pt', 'ph2eta', 'ph2phi',
       'METpt', 'METeta', 'METphi', 'run_number', 'event_number', 'weight',
       'target'
```
In `balanced_dfs_no_dup_processed.pkl`, the DataFrames contain only events with no duplicate objects. Events with undefined `METpt` and events in which every object has zero pT have been removed. Each DataFrame contains equal numbers of signal and background (`'target' == 'EB_test'`) events.

In `balanced_dfs_no_dup_OR.pkl` (the file loaded above), overlap removal has been applied in addition to the processing described above.
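
The labeling rule used throughout the notebook (`target == "EB_test"` is background, everything else is signal) can be illustrated on a few made-up `target` values. This sketch uses plain Python lists rather than the DataFrames in `ML_dict`, and `label_from_target` is a hypothetical helper name:

```python
BACKGROUND_TARGET = "EB_test"

def label_from_target(target):
    """Return 0 for background (EB_test) and 1 for any signal sample."""
    return 0 if target == BACKGROUND_TARGET else 1

# Made-up target values for illustration; real ones come from the 'target' column.
targets = ["EB_test", "HAHMggf", "EB_test", "Znunu"]
labels = [label_from_target(t) for t in targets]
print(labels)  # [0, 1, 0, 1]

# Training uses background only; evaluation uses both classes.
train_candidates = [t for t, y in zip(targets, labels) if y == 0]
print(train_candidates)  # ['EB_test', 'EB_test']
```

On a pandas DataFrame the same rule could be written as `(df["target"] != "EB_test").astype(int)`.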

In [7]:
# Define object types and layout of columns in the DataFrame

OBJ_TYPES = ["MET", "e", "j", "mu", "ph"]  # order matters for one-hot and plots

# How many of each object (based on your column naming)
N_JETS = 6   # j0..j5
N_ELEC = 3   # e0..e2
N_MU   = 3   # mu0..mu2
N_PH   = 3   # ph0..ph2

# Specification of objects and their (pt, eta, phi) columns
OBJECT_SPECS = [
    ("MET", [("METpt", "METeta", "METphi")]),
    ("e",   [(f"e{i}pt",  f"e{i}eta",  f"e{i}phi")  for i in range(N_ELEC)]),
    ("j",   [(f"j{i}pt",  f"j{i}eta",  f"j{i}phi")  for i in range(N_JETS)]),
    ("mu",  [(f"mu{i}pt", f"mu{i}eta", f"mu{i}phi") for i in range(N_MU)]),
    ("ph",  [(f"ph{i}pt", f"ph{i}eta", f"ph{i}phi") for i in range(N_PH)]),
]

def object_type_index(obj_type):
    return OBJ_TYPES.index(obj_type)

print("OBJECT_SPECS defined with types:", [t for t, _ in OBJECT_SPECS])


OBJECT_SPECS defined with types: ['MET', 'e', 'j', 'mu', 'ph']


## All helper functions

### Fit node feature scaler (pT, η, φ)

We will standardize node features `[pT, η, φ]` over all **background training** nodes.

Node features per node:
- `[scaled_pT, scaled_η, scaled_φ, one-hot(type)]`

Edge attributes use **log pT** and geometric differences. We keep them in natural units (no extra scaler)
for simplicity, but this can be extended.
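
The standardization and one-hot layout described above can be sketched without scikit-learn. The sketch assumes the default `StandardScaler` behaviour (subtract the mean, divide by the population standard deviation). The node values are invented, `standardize_columns` and `one_hot` are hypothetical helpers, and for brevity the statistics are fitted and applied on the same three nodes, whereas the notebook fits them on background training nodes only.

```python
from statistics import fmean, pstdev

OBJ_TYPES = ["MET", "e", "j", "mu", "ph"]

def standardize_columns(rows):
    """Column-wise (x - mean) / std, using the population standard deviation."""
    cols = list(zip(*rows))
    means = [fmean(c) for c in cols]
    stds = [pstdev(c) or 1.0 for c in cols]  # guard against zero variance
    return [[(v - m) / s for v, m, s in zip(r, means, stds)] for r in rows]

def one_hot(obj_type):
    """One-hot vector for an object type, in OBJ_TYPES order."""
    return [1.0 if t == obj_type else 0.0 for t in OBJ_TYPES]

# Invented nodes: (type, pT, eta, phi)
nodes = [("MET", 40.0, 0.0, 1.0), ("j", 120.0, -1.2, 2.5), ("j", 60.0, 0.8, -0.7)]
scaled = standardize_columns([n[1:] for n in nodes])
features = [s + one_hot(n[0]) for s, n in zip(scaled, nodes)]
print(len(features), "nodes x", len(features[0]), "features")
```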


In [8]:
# --- Node feature scaler helpers ---

def is_empty_object(pt, eta, phi):
    """
    Criterion to drop empty objects.
    Here we use pT <= 0 as 'empty' (adjust if needed).
    """
    return pt <= 0

def collect_node_features_for_scaler(df):
    """
    Collect [pt, eta, phi] from all non-empty nodes in a dataframe
    to fit the StandardScaler on background training events.
    """
    all_feats = []
    for _, row in df.iterrows():
        for obj_type, fields in OBJECT_SPECS:
            for pt_col, eta_col, phi_col in fields:
                pt = row[pt_col]
                eta = row[eta_col]
                phi = row[phi_col]
                if is_empty_object(pt, eta, phi):
                    continue
                all_feats.append([pt, eta, phi])
    if len(all_feats) == 0:
        return np.zeros((0, 3))
    return np.array(all_feats)

def fit_node_feature_scaler(df_bg_train):
    """
    Fit a StandardScaler for [pt, eta, phi] on background training events.
    """
    X_node_bg_train = collect_node_features_for_scaler(df_bg_train)
    print("Collected node feature samples for scaler:", X_node_bg_train.shape)
    scaler = StandardScaler()
    scaler.fit(X_node_bg_train)
    print("Node feature scaler mean:", scaler.mean_)
    print("Node feature scaler scale:", scaler.scale_)
    return scaler

# We'll assign node_feature_scaler inside the driver
node_feature_scaler = None


### Graph construction utilities

We now define:

- `build_event_graph(row, edge_attr_mode)`:
  - builds `torch_geometric.data.Data` for a single event.

- `build_graph_dataset(df, edge_attr_mode)`:
  - builds a list of graphs from a DataFrame.

Edge attribute modes:
- `"geo"`: `[Δη, Δφ, ΔR]`
- `"geo_pt"`: `[Δη, Δφ, ΔR, log(pT_i + 1), log(pT_j + 1), frac_diff]`, where `frac_diff = (pT_i - pT_j) / (pT_i + pT_j + ε)`
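
For a single directed edge, the two modes could be computed as follows. This is a standalone sketch in plain Python with invented kinematics; `edge_features` is a hypothetical helper, and the `log(pT + 1)` offset follows what `build_event_graph` does.

```python
import math

def wrap_delta_phi(dphi):
    """Wrap an angle difference into [-pi, pi)."""
    return (dphi + math.pi) % (2 * math.pi) - math.pi

def edge_features(src, dst, mode="geo_pt", eps=1e-6):
    """src and dst are (pt, eta, phi) tuples; returns attributes for the edge src -> dst."""
    pt_s, eta_s, phi_s = src
    pt_d, eta_d, phi_d = dst
    deta = eta_s - eta_d
    dphi = wrap_delta_phi(phi_s - phi_d)
    geo = [deta, dphi, math.hypot(deta, dphi)]
    if mode == "geo":
        return geo
    if mode == "geo_pt":
        frac_diff = (pt_s - pt_d) / (pt_s + pt_d + eps)
        return geo + [math.log(pt_s + 1.0), math.log(pt_d + 1.0), frac_diff]
    raise ValueError(f"Unknown mode: {mode}")

# Invented objects on either side of the phi = +/- pi boundary.
a = (50.0, 0.5, 3.0)
b = (25.0, -0.5, -3.0)
print(edge_features(a, b, mode="geo"))
print(edge_features(a, b, mode="geo_pt"))
```

Without the wrap, the raw difference `3.0 - (-3.0) = 6.0` would make two nearly adjacent objects look far apart in φ.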


In [9]:
def wrap_delta_phi(dphi):
    """
    Wrap Δφ into [-π, π).
    """
    return (dphi + np.pi) % (2 * np.pi) - np.pi


def build_event_graph(row, edge_attr_mode="geo_pt"):
    """
    Build a PyG Data object for a single event (row from DataFrame).
    Node features: [scaled_pT, scaled_eta, scaled_phi, one-hot(type)]
    Edge attributes: depends on edge_attr_mode.
    """
    global node_feature_scaler
    if node_feature_scaler is None:
        raise RuntimeError("node_feature_scaler is not fitted yet. Call fit_node_feature_scaler first.")

    node_raw_feats = []   # [ [pt, eta, phi, type_idx], ... ]
    labels = int(row["y"])
    weight = float(row.get("weight", 1.0))

    # Collect nodes
    for obj_type, fields in OBJECT_SPECS:
        t_idx = object_type_index(obj_type)
        for pt_col, eta_col, phi_col in fields:
            pt = float(row[pt_col])
            eta = float(row[eta_col])
            phi = float(row[phi_col])
            if is_empty_object(pt, eta, phi):
                continue
            node_raw_feats.append([pt, eta, phi, t_idx])

    if len(node_raw_feats) == 0:
        # If an event has no valid nodes, skip it (rare, but just in case)
        return None

    node_raw_feats = np.array(node_raw_feats)  # [N, 4]
    pts = node_raw_feats[:, 0]
    etas = node_raw_feats[:, 1]
    phis = node_raw_feats[:, 2]
    t_indices = node_raw_feats[:, 3].astype(int)

    # Scale [pt, eta, phi]
    scaled_pep = node_feature_scaler.transform(node_raw_feats[:, :3])  # [N, 3]

    # One-hot encode types
    num_node_types = len(OBJ_TYPES)
    one_hot = np.zeros((len(t_indices), num_node_types), dtype=np.float32)
    one_hot[np.arange(len(t_indices)), t_indices] = 1.0

    x_np = np.concatenate([scaled_pep, one_hot], axis=1).astype(np.float32)  # [N, 3+5]
    x = torch.tensor(x_np, dtype=torch.float32)

    num_nodes = x.shape[0]

    # Fully connected directed edges (i != j)
    src = []
    dst = []
    for i in range(num_nodes):
        for j in range(num_nodes):
            if i == j:
                continue
            src.append(i)
            dst.append(j)
    edge_index = torch.tensor([src, dst], dtype=torch.long)

    # Edge attributes
    edge_attrs = []
    eps = 1e-6

    for s, d in zip(src, dst):
        deta = etas[s] - etas[d]
        dphi = wrap_delta_phi(phis[s] - phis[d])
        dR = math.sqrt(deta**2 + dphi**2)

        if edge_attr_mode == "geo":
            edge_attrs.append([deta, dphi, dR])
        elif edge_attr_mode == "geo_pt":
            log_pt_s = math.log(pts[s] + 1.0)
            log_pt_d = math.log(pts[d] + 1.0)
            frac_diff = (pts[s] - pts[d]) / (pts[s] + pts[d] + eps)
            edge_attrs.append([deta, dphi, dR, log_pt_s, log_pt_d, frac_diff])
        else:
            raise ValueError(f"Unknown edge_attr_mode: {edge_attr_mode}")

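    # Reshape so that graphs without edges (a single node) still get shape [0, edge_dim]
    edge_dim = 3 if edge_attr_mode == "geo" else 6
    edge_attr = torch.tensor(
        np.array(edge_attrs, dtype=np.float32).reshape(-1, edge_dim),
        dtype=torch.float32,
    )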

    data = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=torch.tensor([labels], dtype=torch.long),
        weight=torch.tensor([weight], dtype=torch.float32),
    )

    return data


def build_graph_dataset(df, edge_attr_mode="geo_pt"):
    graphs = []
    labels = []
    sample_labels = []
    for _, row in df.iterrows():
        g = build_event_graph(row, edge_attr_mode=edge_attr_mode)
        if g is None:
            continue
        graphs.append(g)
        labels.append(int(row["y"]))
        sample_labels.append(row.get("sample_label", "unknown"))
    return graphs, np.array(labels), np.array(sample_labels)


In [10]:
def build_conv_layer(conv_type, in_channels, out_channels, edge_dim=None, conv_config=None):
    """
    Factory for the supported graph conv layers. Conv-specific hyperparameters
    are passed via conv_config (a dict).

    Supported conv_type values:
      - 'sage'
      - 'nn'
      - 'gine'
      - 'transformer'
    """
    conv_type = conv_type.lower()
    cfg = conv_config or {}

    if conv_type == "sage":
        # SAGEConv-specific hyperparameters
        sage_aggr        = cfg.get("sage_aggr", "mean")      # 'mean', 'max', 'add', ...
        sage_normalize   = cfg.get("sage_normalize", False)  # L2-normalize embeddings
        sage_root_weight = cfg.get("sage_root_weight", True)

        # Signature in typical PyG versions: SAGEConv(in_channels, out_channels, aggr='mean', normalize=False, root_weight=True...)
        conv = SAGEConv(
            in_channels,
            out_channels,
            aggr=sage_aggr,
            normalize=sage_normalize,
            root_weight=sage_root_weight,
        )
        return conv

    elif conv_type == "nn":
        # NNConv-specific hyperparameters
        if edge_dim is None:
            raise ValueError("NNConv requires edge_dim when using edge_attr.")

        edge_hidden_dim   = cfg.get("nn_edge_hidden_dim", max(16, edge_dim * 2))
        edge_mlp_layers   = cfg.get("nn_edge_mlp_layers", 2)
        nn_aggr           = cfg.get("nn_aggr", "mean")  # 'mean' or 'add' typically

        # Build edge MLP: edge_attr -> (in_channels * out_channels)
        layers = []
        in_dim = edge_dim
        for i in range(edge_mlp_layers - 1):
            layers.append(nn.Linear(in_dim, edge_hidden_dim))
            layers.append(nn.ReLU())
            in_dim = edge_hidden_dim
        layers.append(nn.Linear(in_dim, in_channels * out_channels))
        edge_mlp = nn.Sequential(*layers)

        conv = NNConv(
            in_channels,
            out_channels,
            edge_mlp,
            aggr=nn_aggr,
        )
        return conv

    elif conv_type == "gine":
        # GINEConv-specific hyperparameters
        gine_mlp_hidden_dim   = cfg.get("gine_mlp_hidden_dim", out_channels)
        gine_mlp_layers       = cfg.get("gine_mlp_layers", 2)
        gine_eps_init         = cfg.get("gine_eps_init", 0.0)
        gine_train_eps        = cfg.get("gine_train_eps", False)

        # Build MLP on node features
        layers = []
        in_dim = in_channels
        for i in range(gine_mlp_layers - 1):
            layers.append(nn.Linear(in_dim, gine_mlp_hidden_dim))
            layers.append(nn.ReLU())
            in_dim = gine_mlp_hidden_dim
        layers.append(nn.Linear(in_dim, out_channels))
        mlp = nn.Sequential(*layers)

        conv = GINEConv(
            nn=mlp,
            eps=gine_eps_init,
            train_eps=gine_train_eps,
            edge_dim=edge_dim,
        )
        return conv

    elif conv_type == "transformer":
        # TransformerConv-specific hyperparameters
        tr_heads   = cfg.get("tr_heads", 1)
        tr_concat  = cfg.get("tr_concat", False)  # for simplicity keep False so out_dim == out_channels
        tr_dropout = cfg.get("tr_dropout", 0.0)
        tr_beta    = cfg.get("tr_beta", False)

        conv = TransformerConv(
            in_channels,
            out_channels,
            heads=tr_heads,
            concat=tr_concat,
            beta=tr_beta,
            dropout=tr_dropout,
            edge_dim=edge_dim,
        )
        return conv

    else:
        raise ValueError(f"Unknown conv_type: {conv_type}")


class GraphEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        edge_dim,
        hidden_dim=64,
        latent_dim=16,
        num_layers=3,
        conv_type="gine",
        dropout=0.0,
        conv_config=None,
    ):
        super().__init__()
        self.conv_type = conv_type.lower()
        self.dropout_layer = nn.Dropout(dropout)
        self.activ = nn.ReLU()
        self.conv_config = conv_config or {}

        layers = []
        # Example: dims = [in_channels, hidden_dim, ..., latent_dim]
        dims = [in_channels] + [hidden_dim] * (num_layers - 1) + [latent_dim]

        for l in range(len(dims) - 1):
            in_dim = dims[l]
            out_dim = dims[l + 1]
            conv = build_conv_layer(
                self.conv_type,
                in_dim,
                out_dim,
                edge_dim=edge_dim,
                conv_config=self.conv_config,
            )
            layers.append(conv)

        self.layers = nn.ModuleList(layers)

    def forward(self, x, edge_index, edge_attr=None):
        for conv in self.layers:
            if self.conv_type == "sage":
                x = conv(x, edge_index)
            else:
                x = conv(x, edge_index, edge_attr)
            x = self.activ(x)
            x = self.dropout_layer(x)
        return x  # node-level latent features


class GAEModel(nn.Module):
    def __init__(
        self,
        in_channels,
        edge_dim,
        hidden_dim=64,
        latent_dim=16,
        num_layers=3,
        conv_type="gine",
        dropout=0.0,
        decoder_hidden_dim=64,
        conv_config=None,
    ):
        super().__init__()
        self.encoder = GraphEncoder(
            in_channels=in_channels,
            edge_dim=edge_dim,
            hidden_dim=hidden_dim,
            latent_dim=latent_dim,
            num_layers=num_layers,
            conv_type=conv_type,
            dropout=dropout,
            conv_config=conv_config,
        )

        # Simple MLP decoder: latent -> reconstructed node features
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, decoder_hidden_dim),
            nn.ReLU(),
            nn.Linear(decoder_hidden_dim, in_channels),
        )

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        edge_attr = getattr(data, "edge_attr", None)
        z = self.encoder(x, edge_index, edge_attr)
        x_hat = self.decoder(z)
        return x_hat, z

    def encode(self, data):
        x, edge_index = data.x, data.edge_index
        edge_attr = getattr(data, "edge_attr", None)
        return self.encoder(x, edge_index, edge_attr)


In [11]:
# ---- Reconstruction loss: MSE + optional exponential term ----

def reconstruction_loss(x, x_hat, use_exponential=True, alpha=1.0, beta=1.0):
    """
    x, x_hat: [N_nodes, n_features]
    L = MSE + alpha * E[exp(beta * MSE_node) - 1]
    where MSE_node is the per-node mean squared error over features.
    """
    # Per-node MSE over features
    mse_per_node = torch.mean((x_hat - x) ** 2, dim=-1)   # [N_nodes]
    mse = mse_per_node.mean()

    if use_exponential:
        exp_term = torch.mean(torch.exp(beta * mse_per_node) - 1.0)
        loss = mse + alpha * exp_term
    else:
        # Keep a tensor on the same device for consistency
        exp_term = torch.zeros(1, device=x.device)
        loss = mse

    details = {
        "mse": mse.item(),
        "exp_term": exp_term.item(),
    }
    return loss, details


In [26]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

def train_epoch(model, loader, optimizer, use_exponential=False, alpha=1.0, beta=1.0):
    model.train()
    total_loss = 0.0
    total_mse = 0.0
    total_exp = 0.0
    count = 0

    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        x_hat, z = model(data)
        loss, details = reconstruction_loss(
            data.x, x_hat,
            use_exponential=use_exponential,
            alpha=alpha,
            beta=beta,
        )
        loss.backward()
        optimizer.step()

        batch_size_nodes = data.x.size(0)
        total_loss += loss.item() * batch_size_nodes
        total_mse  += details["mse"] * batch_size_nodes
        # reconstruction_loss reports the exponential term under the key "exp_term"
        total_exp += details["exp_term"] * batch_size_nodes
        count += batch_size_nodes

    avg_loss = total_loss / count
    avg_mse  = total_mse  / count
    avg_exp  = total_exp  / count if use_exponential else 0.0

    return {
        "loss": avg_loss,
        "mse": avg_mse,
        "exp": avg_exp,
    }


def eval_epoch(model, loader, use_exponential=False, alpha=1.0, beta=1.0):
    model.eval()
    total_loss = 0.0
    total_mse = 0.0
    total_exp = 0.0
    count = 0

    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            x_hat, z = model(data)
            loss, details = reconstruction_loss(
                data.x, x_hat,
                use_exponential=use_exponential,
                alpha=alpha,
                beta=beta,
            )

            batch_size_nodes = data.x.size(0)
            total_loss += loss.item() * batch_size_nodes
            total_mse  += details["mse"] * batch_size_nodes
            # reconstruction_loss reports the exponential term under the key "exp_term"
            total_exp += details["exp_term"] * batch_size_nodes
            count += batch_size_nodes

    avg_loss = total_loss / count
    avg_mse  = total_mse  / count
    avg_exp  = total_exp  / count if use_exponential else 0.0

    return {
        "loss": avg_loss,
        "mse": avg_mse,
        "exp": avg_exp,
    }


def train_model(
    model,
    train_loader,
    val_loader,
    epochs=50,
    lr=1e-3,
    weight_decay=0.0,
    use_exponential=False,
    alpha=1.0,
    beta=1.0,
):
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    history = {
        "train_loss": [],
        "val_loss": [],
        "train_mse": [],
        "val_mse": [],
    }

    best_val_loss = float("inf")
    best_state = None

    for epoch in range(1, epochs + 1):
        train_stats = train_epoch(
            model,
            train_loader,
            optimizer,
            use_exponential=use_exponential,
            alpha=alpha,
            beta=beta,
        )
        val_stats = eval_epoch(
            model,
            val_loader,
            use_exponential=use_exponential,
            alpha=alpha,
            beta=beta,
        )

        history["train_loss"].append(train_stats["loss"])
        history["val_loss"].append(val_stats["loss"])
        history["train_mse"].append(train_stats["mse"])
        history["val_mse"].append(val_stats["mse"])

        if val_stats["loss"] < best_val_loss:
            best_val_loss = val_stats["loss"]
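            # Snapshot the weights: state_dict() returns references that later updates overwrite
            best_state = copy.deepcopy(model.state_dict())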

        # print(
        #     f"Epoch {epoch:03d} | "
        #     f"Train loss: {train_stats['loss']:.4f}, Val loss: {val_stats['loss']:.4f}"
        # )

    # Load best weights
    if best_state is not None:
        model.load_state_dict(best_state)

    return model, history


Using device: cuda


### Anomaly scores and ROC AUC

We compute:
- per-node reconstruction error,
- per-graph anomaly score (mean reconstruction error),
- ROC AUC for background (0) vs signal (1) on the test set,
- histograms of anomaly scores for BG vs signal.
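
The pooling and scoring steps listed above can be mimicked in plain Python. In this sketch the node errors, the `batch` assignment vector, and the labels are invented. `mean_pool` and `rank_auc` are hypothetical stand-ins for `global_mean_pool` and `roc_auc_score`, with the AUC computed as the fraction of (signal, background) pairs in which the signal event has the higher score, ties counting one half.

```python
def mean_pool(node_values, batch):
    """Average node values per graph; batch[i] is the graph index of node i."""
    n_graphs = max(batch) + 1
    sums = [0.0] * n_graphs
    counts = [0] * n_graphs
    for v, g in zip(node_values, batch):
        sums[g] += v
        counts[g] += 1
    return [s / c for s, c in zip(sums, counts)]

def rank_auc(labels, scores):
    """Probability that a random signal event outscores a random background event."""
    sig = [s for s, y in zip(scores, labels) if y == 1]
    bkg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum((s > b) + 0.5 * (s == b) for s in sig for b in bkg)
    return wins / (len(sig) * len(bkg))

# Invented per-node MSE values for four graphs (2, 3, 1 and 2 nodes).
node_mse = [0.25, 0.75, 0.5, 0.5, 0.5, 2.0, 1.0, 3.0]
batch = [0, 0, 1, 1, 1, 2, 3, 3]
graph_labels = [0, 0, 1, 1]

scores = mean_pool(node_mse, batch)
print(scores)                          # [0.5, 0.5, 2.0, 2.0]
print(rank_auc(graph_labels, scores))  # 1.0
```

Because the score is a mean over nodes, a graph with many well-reconstructed nodes can dilute one badly reconstructed node.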


In [13]:
def compute_anomaly_scores(model, loader):
    """
    Compute per-graph anomaly scores (mean node-wise MSE) and labels.
    """
    model.eval()
    all_scores = []
    all_labels = []

    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            x_hat, z = model(data)

            # Node-wise MSE
            mse_per_node = ((data.x - x_hat) ** 2).mean(dim=1)  # [N_nodes]
            # Pool per graph
            graph_scores = global_mean_pool(mse_per_node, data.batch)  # [N_graphs]

            all_scores.extend(graph_scores.cpu().numpy().tolist())
            all_labels.extend(data.y.cpu().numpy().tolist())

    return np.array(all_scores), np.array(all_labels)


# test_scores, test_labels_arr = compute_anomaly_scores(model, test_loader)
# print("Anomaly scores shape:", test_scores.shape)
# print("Test labels shape:", test_labels_arr.shape)

# # Compute ROC AUC (signal label = 1)
# auc = roc_auc_score(test_labels_arr, test_scores)
# print("Test ROC AUC:", auc)

# # Plot histogram of scores
# plt.figure(figsize=(6, 4))
# mask_bg = (test_labels_arr == 0)
# mask_sig = (test_labels_arr == 1)
# plt.hist(test_scores[mask_bg], bins=50, alpha=0.6, label="Background", density=True)
# plt.hist(test_scores[mask_sig], bins=50, alpha=0.6, label="Signal",     density=True)
# plt.xlabel("Anomaly score (mean node MSE)")
# plt.ylabel("Density")
# plt.title(f"Anomaly Score Distribution (AUC = {auc:.3f})")
# plt.legend()
# plt.grid(alpha=0.3)

# scores_hist_path = os.path.join(exp_dir, f"anomaly_scores_hist_{config['conv_type']}.png")
# plt.savefig(scores_hist_path, dpi=150, bbox_inches="tight")
# plt.close()
# print("Saved anomaly score histogram to:", scores_hist_path)


In [14]:
def extract_graph_latent(model, loader):
    """
    Returns:
        Z_graph: [num_graphs, latent_dim]
        labels: [num_graphs]
    """
    model.eval()
    Z_graph = []
    labels = []

    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            z_node = model.encode(data)  # [N_nodes, latent_dim]
            z_graph = global_mean_pool(z_node, data.batch)  # [N_graphs, latent_dim]
            Z_graph.append(z_graph.cpu().numpy())
            labels.append(data.y.cpu().numpy())

    Z_graph = np.concatenate(Z_graph, axis=0)
    labels = np.concatenate(labels, axis=0)
    return Z_graph, labels

# Z_graph, labels_graph = extract_graph_latent(model, test_loader)
# print("Graph latent shape:", Z_graph.shape)

# # PCA to 2D
# pca = PCA(n_components=2, random_state=SEED)
# Z_pca = pca.fit_transform(Z_graph)

# plt.figure(figsize=(6, 5))
# mask_bg = (labels_graph == 0)
# mask_sig = (labels_graph == 1)

# plt.scatter(Z_pca[mask_bg, 0], Z_pca[mask_bg, 1], s=20, alpha=0.6, label="Background")
# plt.scatter(Z_pca[mask_sig, 0], Z_pca[mask_sig, 1], s=20, alpha=0.6, label="Signal")

# plt.xlabel("PC1")
# plt.ylabel("PC2")
# plt.title(f"Latent PCA (conv = {config['conv_type']})")
# plt.legend()
# plt.grid(alpha=0.3)

# pca_plot_path = os.path.join(exp_dir, f"latent_pca_{config['conv_type']}.png")
# plt.savefig(pca_plot_path, dpi=150, bbox_inches="tight")
# plt.close()
# print("Saved PCA latent plot to:", pca_plot_path)

# # Optional: t-SNE (can be slow on large sets). Uncomment if needed.
# # tsne = TSNE(n_components=2, random_state=SEED, perplexity=30)
# # Z_tsne = tsne.fit_transform(Z_graph)
# #
# # plt.figure(figsize=(6, 5))
# # plt.scatter(Z_tsne[mask_bg, 0], Z_tsne[mask_bg, 1], s=20, alpha=0.6, label="Background")
# # plt.scatter(Z_tsne[mask_sig, 0], Z_tsne[mask_sig, 1], s=20, alpha=0.6, label="Signal")
# # plt.xlabel("t-SNE 1")
# # plt.ylabel("t-SNE 2")
# # plt.title(f"Latent t-SNE (conv = {config['conv_type']})")
# # plt.legend()
# # plt.grid(alpha=0.3)
# #
# # tsne_plot_path = os.path.join(exp_dir, f"latent_tsne_{config['conv_type']}.png")
# # plt.savefig(tsne_plot_path, dpi=150, bbox_inches="tight")
# # plt.close()
# # print("Saved t-SNE latent plot to:", tsne_plot_path)


In [15]:
def plot_eta_phi_input_vs_reco_and_save(
    data,
    model,
    device,
    exp_dir,
    filename_prefix,
    title_suffix="",
):
    """
    Visualise one event as η–φ scatter plots (input vs reconstruction) and save it.

    IMPORTANT: we clone `data` before moving to device to avoid mutating the
    original dataset graphs (which would cause CPU/GPU mixing inside DataLoader).
    """
    model.eval()
    with torch.no_grad():
        # Make a deep copy so we do NOT modify the original graph in test_graphs
        batch = copy.deepcopy(data).to(device)
        x_hat, z = model(batch)
        x_hat = x_hat.cpu().numpy()

    # Use the original CPU graph for inputs
    x_in = data.x.cpu().numpy()

    # Input features (data.x as built by build_event_graph is standardized, so values are in scaled units)
    pt_in  = x_in[:, 0]
    eta_in = x_in[:, 1]
    phi_in = x_in[:, 2]

    # Reconstructed features
    pt_out  = x_hat[:, 0]
    eta_out = x_hat[:, 1]
    phi_out = x_hat[:, 2]

    # Object type from one-hot (column order follows the module-level OBJ_TYPES)

    onehot_in = x_in[:, 3:]
    type_idx  = onehot_in.argmax(axis=1)

    cmap = mpl.colormaps.get_cmap("tab10")

    def pt_to_size(pt):
        pt = np.clip(pt, 0, None)
        if pt.max() <= 0:
            return np.full_like(pt, 30.0)
        pt_norm = pt / (pt.max() + 1e-8)
        return 20.0 + 80.0 * pt_norm

    sizes_in  = pt_to_size(pt_in)
    sizes_out = pt_to_size(pt_out)

    eta_min = min(eta_in.min(), eta_out.min())
    eta_max = max(eta_in.max(), eta_out.max())
    phi_min = min(phi_in.min(), phi_out.min())
    phi_max = max(phi_in.max(), phi_out.max())

    node_colors = [cmap(int(i)) for i in type_idx]

    fig, axes = plt.subplots(1, 2, figsize=(10, 4), constrained_layout=True)

    # ---- Left: input ----
    ax = axes[0]
    ax.scatter(
        eta_in,
        phi_in,
        s=sizes_in,
        c=node_colors,
        alpha=0.8,
        edgecolors="k",
    )
    ax.set_xlabel(r"$\eta$")
    ax.set_ylabel(r"$\phi$")
    ax.set_title("Input features " + title_suffix)
    ax.set_xlim(eta_min, eta_max)
    ax.set_ylim(phi_min, phi_max)
    ax.grid(True, alpha=0.3)

    # ---- Right: reconstruction ----
    ax = axes[1]
    ax.scatter(
        eta_out,
        phi_out,
        s=sizes_out,
        c=node_colors,
        alpha=0.8,
        edgecolors="k",
    )
    ax.set_xlabel(r"$\eta$")
    ax.set_ylabel(r"$\phi$")
    ax.set_title("Reconstructed features " + title_suffix)
    ax.set_xlim(eta_min, eta_max)
    ax.set_ylim(phi_min, phi_max)
    ax.grid(True, alpha=0.3)

    # Legend for object types
    handles = []
    labels  = []
    for i, name in enumerate(OBJ_TYPES):
        handles.append(
            plt.Line2D(
                [], [],
                marker="o",
                linestyle="",
                markersize=8,
                markerfacecolor=cmap(i),
                markeredgecolor="k",
            )
        )
        labels.append(name)

    fig.legend(
        handles,
        labels,
        title="Object type",
        loc="upper center",
        bbox_to_anchor=(0.5, 0.02),
        ncol=len(OBJ_TYPES),
    )

    os.makedirs(exp_dir, exist_ok=True)
    save_path = os.path.join(exp_dir, f"{filename_prefix}.png")
    plt.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved {filename_prefix} plot to: {save_path}")


def plot_eta_phi_examples(
    model,
    device,
    test_graphs,
    test_labels,
    exp_dir,
    max_bkg=3,
    max_sig=3,
):
    """
    Pick up to `max_bkg` background and `max_sig` signal graphs from the test set
    and make η–φ plots for each.

    Note: relies on `plot_eta_phi_input_vs_reco_and_save`, which works on a clone of each graph.
    """
    test_labels = np.array(test_labels)

    bkg_indices = np.where(test_labels == 0)[0]
    sig_indices = np.where(test_labels == 1)[0]

    print(f"Found {len(bkg_indices)} background and {len(sig_indices)} signal test events.")

    n_bkg_to_show = min(max_bkg, len(bkg_indices))
    n_sig_to_show = min(max_sig, len(sig_indices))

    print(f"Showing {n_bkg_to_show} background and {n_sig_to_show} signal events.")

    for k in range(n_bkg_to_show):
        idx = bkg_indices[k]
        print(f"\nBackground example {k+1} (test index = {idx})")
        example_bkg = test_graphs[idx]
        plot_eta_phi_input_vs_reco_and_save(
            example_bkg,
            model,
            device,
            exp_dir,
            filename_prefix=f"bkg_example_{k+1}_eta_phi",
            title_suffix="(Background)",
        )

    for k in range(n_sig_to_show):
        idx = sig_indices[k]
        print(f"\nSignal example {k+1} (test index = {idx})")
        example_sig = test_graphs[idx]
        plot_eta_phi_input_vs_reco_and_save(
            example_sig,
            model,
            device,
            exp_dir,
            filename_prefix=f"sig_example_{k+1}_eta_phi",
            title_suffix="(Signal)",
        )


In [16]:
def run_full_pipeline(
    ML_dict,
    dataset="all_signals",
    conv_type="gine",          # "sage", "nn", "gine", "transformer"
    edge_attr_mode="geo_pt",   # "geo" or "geo_pt"
    exp_root="/content/drive/MyDrive/Datasets/GAEwEdge",
    batch_size=64,
    hidden_dim=64,
    latent_dim=16,
    num_layers=3,
    dropout=0.0,
    decoder_hidden_dim=64,
    epochs=50,
    lr=1e-3,
    weight_decay=0.0,
    use_exponential=True,
    alpha=1.0,
    beta=1.0,
    max_bkg_eta_phi=3,
    max_sig_eta_phi=3,
    conv_config=None,
):
    """
    Main driver function: run the full GAE pipeline for a given dataset key in ML_dict.
    """

    print(f"\n=== Running pipeline for dataset = '{dataset}' ===")
    if dataset not in ML_dict:
        raise KeyError(f"Dataset key '{dataset}' not found in ML_dict. Available: {list(ML_dict.keys())}")

    # 1) Select dataframe for this dataset
    df = ML_dict[dataset].copy()
    print("Dataset shape:", df.shape)

    # 2) Background vs signal split
    BG_LABEL = "EB_test"
    df_bg = df[df["target"] == BG_LABEL].reset_index(drop=True)
    df_sig = df[df["target"] != BG_LABEL].reset_index(drop=True)

    df_bg["y"] = 0
    df_sig["y"] = 1

    print("Background events:", len(df_bg))
    print("Signal events:", len(df_sig))

    # 3) Train/val/test split (BG only for train/val, BG+SIG for test)
    df_bg_train, df_bg_temp = train_test_split(
        df_bg, test_size=0.3, random_state=SEED, shuffle=True
    )
    df_bg_val, df_bg_test_bg = train_test_split(
        df_bg_temp, test_size=0.5, random_state=SEED, shuffle=True
    )

    df_test = pd.concat([df_bg_test_bg, df_sig], ignore_index=True)

    print("Train BG:", len(df_bg_train))
    print("Val BG:", len(df_bg_val))
    print("Test BG:", len(df_bg_test_bg))
    print("Total test events:", len(df_test))
    print("Test label counts:\n", df_test["y"].value_counts())

    # 4) Fit node feature scaler on background training events
    global node_feature_scaler
    node_feature_scaler = fit_node_feature_scaler(df_bg_train)

    # 5) Build graphs
    train_graphs, train_labels, train_samples = build_graph_dataset(df_bg_train, edge_attr_mode=edge_attr_mode)
    val_graphs,   val_labels,   val_samples   = build_graph_dataset(df_bg_val,   edge_attr_mode=edge_attr_mode)
    test_graphs,  test_labels,  test_samples  = build_graph_dataset(df_test,     edge_attr_mode=edge_attr_mode)

    print(f"Train graphs: {len(train_graphs)}")
    print(f"Val graphs:   {len(val_graphs)}")
    print(f"Test graphs:  {len(test_graphs)}")

    # 6) DataLoaders
    train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
    val_loader   = DataLoader(val_graphs,   batch_size=batch_size, shuffle=False)
    test_loader  = DataLoader(test_graphs,  batch_size=batch_size, shuffle=False)

    # 7) Experiment directory for this dataset + conv_type
    exp_dir = os.path.join(exp_root, f"{dataset}_{conv_type}")
    os.makedirs(exp_dir, exist_ok=True)
    print("Experiment directory:", exp_dir)

    # 8) Infer input & edge dimensions
    in_channels = train_graphs[0].x.shape[1]
    edge_dim = train_graphs[0].edge_attr.shape[1]

    print("in_channels (node features):", in_channels)
    print("edge_dim (edge attributes):", edge_dim)

    # 9) Build and train model
    model = GAEModel(
        in_channels=in_channels,
        edge_dim=edge_dim,
        hidden_dim=hidden_dim,
        latent_dim=latent_dim,
        num_layers=num_layers,
        conv_type=conv_type,
        dropout=dropout,
        decoder_hidden_dim=decoder_hidden_dim,
        conv_config=conv_config,  # forward conv-specific settings (previously accepted but never used)
    )

    config = {
        "conv_type": conv_type,
        "hidden_dim": hidden_dim,
        "latent_dim": latent_dim,
        "num_layers": num_layers,
        "dropout": dropout,
        "decoder_hidden_dim": decoder_hidden_dim,
        "epochs": epochs,
        "lr": lr,
        "weight_decay": weight_decay,
        "use_exponential": use_exponential,
        "alpha": alpha,
        "beta": beta,
        "edge_attr_mode": edge_attr_mode,
        "dataset": dataset,
        "batch_size": batch_size,
        "conv_config": conv_config,  # record conv-specific settings alongside the rest
    }

    print("\nTraining configuration:", config)

    model, history = train_model(
        model,
        train_loader,
        val_loader,
        epochs=epochs,
        lr=lr,
        weight_decay=weight_decay,
        use_exponential=use_exponential,
        alpha=alpha,
        beta=beta,
    )

    # 10) Plot & save training/validation loss
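    # Use the recorded history length so the x-axis matches even if fewer than `epochs` were logged
    epochs_arr = np.arange(1, len(history["train_loss"]) + 1)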
    plt.figure(figsize=(6, 4))
    plt.plot(epochs_arr, history["train_loss"], label="Train loss")
    plt.plot(epochs_arr, history["val_loss"],   label="Val loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(f"Training vs Validation Loss ({conv_type}, dataset={dataset})")
    plt.legend()
    plt.grid(alpha=0.3)

    loss_plot_path = os.path.join(exp_dir, f"loss_curve_{dataset}_{conv_type}.png")
    plt.savefig(loss_plot_path, dpi=150, bbox_inches="tight")
    plt.close()
    print("Saved loss curve to:", loss_plot_path)

    # 11) Anomaly scores, ROC, and histogram
    test_scores, test_labels_arr = compute_anomaly_scores(model, test_loader)
    auc = roc_auc_score(test_labels_arr, test_scores)
    print("Test ROC AUC:", auc)

    plt.figure(figsize=(6, 4))
    mask_bg = (test_labels_arr == 0)
    mask_sig = (test_labels_arr == 1)
    plt.hist(test_scores[mask_bg], bins=50, alpha=0.6, label="Background", density=True)
    plt.hist(test_scores[mask_sig], bins=50, alpha=0.6, label="Signal",     density=True)
    plt.xlabel("Anomaly score (mean node MSE)")
    plt.ylabel("Density")
    plt.title(f"Anomaly Score Dist (AUC = {auc:.3f}, {dataset}, {conv_type})")
    plt.legend()
    plt.grid(alpha=0.3)

    scores_hist_path = os.path.join(exp_dir, f"anomaly_scores_hist_{dataset}_{conv_type}.png")
    plt.savefig(scores_hist_path, dpi=150, bbox_inches="tight")
    plt.close()
    print("Saved anomaly score histogram to:", scores_hist_path)

    # 12) Latent PCA
    Z_graph, labels_graph = extract_graph_latent(model, test_loader)
    print("Graph latent shape:", Z_graph.shape)

    pca = PCA(n_components=2, random_state=SEED)
    Z_pca = pca.fit_transform(Z_graph)

    plt.figure(figsize=(6, 5))
    mask_bg_p = (labels_graph == 0)
    mask_sig_p = (labels_graph == 1)
    plt.scatter(Z_pca[mask_bg_p, 0], Z_pca[mask_bg_p, 1], s=20, alpha=0.6, label="Background")
    plt.scatter(Z_pca[mask_sig_p, 0], Z_pca[mask_sig_p, 1], s=20, alpha=0.6, label="Signal")

    plt.xlabel("PC1")
    plt.ylabel("PC2")
    plt.title(f"Latent PCA ({dataset}, {conv_type})")
    plt.legend()
    plt.grid(alpha=0.3)

    pca_plot_path = os.path.join(exp_dir, f"latent_pca_{dataset}_{conv_type}.png")
    plt.savefig(pca_plot_path, dpi=150, bbox_inches="tight")
    plt.close()
    print("Saved PCA latent plot to:", pca_plot_path)

    # 13) η–φ examples (input vs reconstruction)
    plot_eta_phi_examples(
        model,
        device,
        test_graphs,
        test_labels,
        exp_dir,
        max_bkg=max_bkg_eta_phi,
        max_sig=max_sig_eta_phi,
    )

    # 14) Return everything useful
    results = {
        "model": model,
        "exp_dir": exp_dir,
        "config": config,
        "history": history,
        "test_auc": auc,
        "test_scores": test_scores,
        "test_labels": test_labels_arr,
    }
    print("\n=== Pipeline finished for dataset =", dataset, "===")
    return results
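
Step 11 above scores each test graph by its mean node reconstruction error and summarises the separation with `roc_auc_score`. As an illustration of what that number means, the sketch below computes the same quantity from its pairwise definition: the probability that a randomly chosen signal event scores higher than a randomly chosen background event, with ties counted as one half. The helper name `pairwise_auc` and the toy scores are made up for this example; the notebook itself uses scikit-learn.

```python
def pairwise_auc(labels, scores):
    """ROC AUC as P(score_sig > score_bkg) + 0.5 * P(tie).

    Quadratic in the number of events, so for illustration only.
    """
    bkg = [s for y, s in zip(labels, scores) if y == 0]
    sig = [s for y, s in zip(labels, scores) if y == 1]
    if not bkg or not sig:
        raise ValueError("AUC needs at least one background and one signal event")

    wins = 0.0
    for s in sig:
        for b in bkg:
            if s > b:
                wins += 1.0      # signal ranked above background
            elif s == b:
                wins += 0.5      # ties count as half a win
    return wins / (len(sig) * len(bkg))


# Toy anomaly scores (hypothetical values, not results from this notebook).
toy_labels = [0, 0, 0, 1, 1, 1]
toy_scores = [0.10, 0.20, 0.40, 0.30, 0.50, 0.90]
toy_auc = pairwise_auc(toy_labels, toy_scores)
print(f"toy AUC = {toy_auc:.3f}")
```

An AUC of 0.5 means the scores do not separate the classes at all; 1.0 means every signal event scores above every background event.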


## Hyperparameter grid search

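We now define a grid search over:
- `conv_type` ∈ {`"sage"`, `"nn"`, `"gine"`, `"transformer"`}
- global hyperparameters: `hidden_dim`, `latent_dim`, `num_layers`, `lr`, `dropout`
- two convolution-specific settings per `conv_type`

Metric: **validation reconstruction loss** (background-only), taken as the minimum over the epochs of each run.
With the default ranges this is 162 global combinations × 2 conv-specific settings × 4 conv types = 1,296 short runs, so extend or shrink the grid as needed.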
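
To gauge the cost before launching, the sketch below counts the runs implied by the default ranges hard-coded in `run_extended_grid_search` (defined in a later cell). The ranges are copied from that function; `conv_specific_sizes` is a name introduced here to record that each convolution type has two conv-specific settings.

```python
from itertools import product

# Ranges copied from run_extended_grid_search in this notebook.
hidden_dims = [32, 64, 128]
latent_dims = [8, 16, 32]
num_layers_l = [2, 3]
lrs = [1e-4, 5e-4, 1e-3]
dropouts = [0.0, 0.1, 0.2]

# Number of conv-specific settings tried per convolution type (illustrative name).
conv_specific_sizes = {"sage": 2, "nn": 2, "gine": 2, "transformer": 2}

n_global = len(list(product(hidden_dims, latent_dims, num_layers_l, lrs, dropouts)))
n_runs = sum(n_global * n for n in conv_specific_sizes.values())
epochs_per_run = 10  # default `epochs` of run_single_config

print("global combinations:", n_global)
print("total runs:", n_runs)
print("total training epochs:", n_runs * epochs_per_run)
```

Dropping one value from any range shrinks the total multiplicatively, which is usually the quickest way to fit the search into a Colab session.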


In [22]:
def prepare_loaders_for_dataset(
    ML_dict,
    dataset="all_signals",
    edge_attr_mode="geo_pt",
    batch_size=64,
    max_bg_events_for_gs=2000,   # cap on background events used for grid search (None = use all)
):
    """
    Prepare train/val DataLoaders, in_channels, and edge_dim for a given dataset key in ML_dict,
    using ONLY a subset of background events (for fast grid search).

    - We randomly sample up to `max_bg_events_for_gs` background events.
    - Then we split that subset into train/val and build graphs.

    This is meant specifically for grid search, not for final training.
    """
    print(f"\n=== Preparing loaders for dataset = '{dataset}' (for grid search) ===")
    if dataset not in ML_dict:
        raise KeyError(f"Dataset key '{dataset}' not found in ML_dict. Available: {list(ML_dict.keys())}")

    df = ML_dict[dataset].copy()
    BG_LABEL = "EB_test"

    df_bg = df[df["target"] == BG_LABEL].reset_index(drop=True)
    df_sig = df[df["target"] != BG_LABEL].reset_index(drop=True)

    df_bg["y"] = 0
    df_sig["y"] = 1  # signal unused here, but kept for consistency

    print("Total background events available:", len(df_bg))
    print("Total signal events (unused for grid search):", len(df_sig))

    # ---- Subsample background for grid search ----
    if max_bg_events_for_gs is not None and len(df_bg) > max_bg_events_for_gs:
        df_bg = df_bg.sample(
            n=max_bg_events_for_gs,
            random_state=SEED,
            replace=False,
        ).reset_index(drop=True)
        print(f"Subsampled background events to: {len(df_bg)} for grid search.")
    else:
        print("Using all available background events for grid search subset.")

    # ---- Train/val split on this subset ----
    df_bg_train, df_bg_val = train_test_split(
        df_bg, test_size=0.3, random_state=SEED, shuffle=True
    )

    print("Grid-search Train BG:", len(df_bg_train))
    print("Grid-search Val BG:", len(df_bg_val))

    # Fit scaler on training background subset
    global node_feature_scaler
    node_feature_scaler = fit_node_feature_scaler(df_bg_train)

    # Build graphs
    train_graphs, _, _ = build_graph_dataset(df_bg_train, edge_attr_mode=edge_attr_mode)
    val_graphs,   _, _ = build_graph_dataset(df_bg_val,   edge_attr_mode=edge_attr_mode)

    print(f"Train graphs: {len(train_graphs)}")
    print(f"Val graphs:   {len(val_graphs)}")

    train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
    val_loader   = DataLoader(val_graphs,   batch_size=batch_size, shuffle=False)

    in_channels = train_graphs[0].x.shape[1]
    edge_dim    = train_graphs[0].edge_attr.shape[1]

    print("in_channels (node features):", in_channels)
    print("edge_dim (edge attributes):", edge_dim)

    return train_loader, val_loader, in_channels, edge_dim
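
The subsample-then-split logic of `prepare_loaders_for_dataset` can be sketched with the standard library alone. This is a simplified stand-in, not the notebook's code path: `subsample_and_split` is a hypothetical helper, it works on bare event ids instead of DataFrame rows, and it rounds the validation size with `round`, which may differ by one event from `train_test_split` when the sizes do not divide evenly.

```python
import random


def subsample_and_split(event_ids, max_events=2000, val_fraction=0.3, seed=42):
    """Cap the number of events, then split the capped set into train/val ids."""
    rng = random.Random(seed)
    ids = list(event_ids)

    # Subsample without replacement only if there are more events than the cap.
    if max_events is not None and len(ids) > max_events:
        ids = rng.sample(ids, max_events)

    rng.shuffle(ids)
    n_val = round(val_fraction * len(ids))
    return ids[n_val:], ids[:n_val]  # (train ids, val ids)


# 317657 is the background count this notebook reports for "all_signals".
train_ids, val_ids = subsample_and_split(range(317657))
print(len(train_ids), len(val_ids))
```

With the default cap of 2000 and a 30% validation fraction this gives 1400 training and 600 validation events, the same sizes the notebook prints for its grid-search subset.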


In [18]:
from itertools import product

def run_extended_grid_search(
    train_loader,
    val_loader,
    in_channels,
    edge_dim,
    dataset_name="all_signals",
    use_exponential=True,
    alpha=1.0,
    beta=1.0,
):
    """
    Extended grid search over:
      - Global hyperparameters: hidden_dim, latent_dim, num_layers, lr, dropout
      - Convolution-specific hyperparameters (small grids per conv_type)

    Returns:
      results_df: DataFrame of all runs
      best_by_conv: dict conv_type -> best row (as dict)
    """
    conv_types = ["sage", "nn", "gine", "transformer"]

    # Global hyperparameter ranges (keep modest for Colab)
    hidden_dims   = [32, 64, 128]
    latent_dims   = [8, 16, 32]
    num_layers_l  = [2, 3]
    lrs           = [1e-4, 5e-4, 1e-3]
    dropouts      = [0.0, 0.1, 0.2]

    results = []

    def run_single_config(
        conv_type,
        global_cfg,
        conv_cfg,
        epochs=10,  # shorten for grid search
    ):
        """
        Train a single configuration and return best validation loss.
        """
        print(
            f"\n=== conv={conv_type}, "
            f"hid={global_cfg['hidden_dim']}, lat={global_cfg['latent_dim']}, "
            f"layers={global_cfg['num_layers']}, lr={global_cfg['lr']}, "
            f"drop={global_cfg['dropout']}, conv_cfg={conv_cfg} ==="
        )

        model_gs = GAEModel(
            in_channels=in_channels,
            edge_dim=edge_dim,
            hidden_dim=global_cfg["hidden_dim"],
            latent_dim=global_cfg["latent_dim"],
            num_layers=global_cfg["num_layers"],
            conv_type=conv_type,
            dropout=global_cfg["dropout"],
            decoder_hidden_dim=global_cfg["hidden_dim"],
            conv_config=conv_cfg,
        )

        model_gs, hist_gs = train_model(
            model_gs,
            train_loader,
            val_loader,
            epochs=epochs,
            lr=global_cfg["lr"],
            weight_decay=0.0,
            use_exponential=use_exponential,
            alpha=alpha,
            beta=beta,
        )

        best_val_loss = min(hist_gs["val_loss"])
        print(f"--> Best val loss for this config: {best_val_loss:.6f}")
        return best_val_loss

    for conv_type in conv_types:
        # Define a small conv-specific grid for this conv_type
        if conv_type == "sage":
            conv_specific_grid = [
                {"sage_aggr": "mean"},
                {"sage_aggr": "max"},
            ]
        elif conv_type == "nn":
            conv_specific_grid = [
                {"nn_edge_hidden_dim": 32, "nn_edge_mlp_layers": 1, "nn_aggr": "mean"},
                {"nn_edge_hidden_dim": 64, "nn_edge_mlp_layers": 2, "nn_aggr": "mean"},
            ]
        elif conv_type == "gine":
            conv_specific_grid = [
                {
                    "gine_mlp_hidden_dim": None,  # None -> replaced by hidden_dim in the loop below
                    "gine_mlp_layers": 2,
                    "gine_train_eps": False,
                    "gine_eps_init": 0.0,
                },
                {
                    "gine_mlp_hidden_dim": None,
                    "gine_mlp_layers": 3,
                    "gine_train_eps": True,
                    "gine_eps_init": 0.0,
                },
            ]
        elif conv_type == "transformer":
            conv_specific_grid = [
                {"tr_heads": 1, "tr_concat": False, "tr_dropout": 0.0, "tr_beta": False},
                {"tr_heads": 2, "tr_concat": False, "tr_dropout": 0.2, "tr_beta": True},
            ]
        else:
            conv_specific_grid = [{}]

        # Global hyperparameters
        for hdim, ldim, nl, lr, do in product(
            hidden_dims, latent_dims, num_layers_l, lrs, dropouts
        ):
            global_cfg = {
                "hidden_dim": hdim,
                "latent_dim": ldim,
                "num_layers": nl,
                "lr": lr,
                "dropout": do,
            }

            for conv_cfg_base in conv_specific_grid:
                # Fill in defaults where needed
                conv_cfg = dict(conv_cfg_base)  # copy
                # If GINE-specific MLP hidden dim is None, set to hidden_dim
                if conv_type == "gine" and conv_cfg.get("gine_mlp_hidden_dim") is None:
                    conv_cfg["gine_mlp_hidden_dim"] = hdim

                best_val_loss = run_single_config(conv_type, global_cfg, conv_cfg)

                results.append({
                    "dataset": dataset_name,
                    "conv_type": conv_type,
                    "hidden_dim": hdim,
                    "latent_dim": ldim,
                    "num_layers": nl,
                    "lr": lr,
                    "dropout": do,
                    "best_val_loss": best_val_loss,
                    # Conv-specific stuff (some columns will be NaN for other convs)
                    "sage_aggr": conv_cfg.get("sage_aggr"),
                    "nn_edge_hidden_dim": conv_cfg.get("nn_edge_hidden_dim"),
                    "nn_edge_mlp_layers": conv_cfg.get("nn_edge_mlp_layers"),
                    "nn_aggr": conv_cfg.get("nn_aggr"),
                    "gine_mlp_hidden_dim": conv_cfg.get("gine_mlp_hidden_dim"),
                    "gine_mlp_layers": conv_cfg.get("gine_mlp_layers"),
                    "gine_train_eps": conv_cfg.get("gine_train_eps"),
                    "gine_eps_init": conv_cfg.get("gine_eps_init"),
                    "tr_heads": conv_cfg.get("tr_heads"),
                    "tr_concat": conv_cfg.get("tr_concat"),
                    "tr_dropout": conv_cfg.get("tr_dropout"),
                    "tr_beta": conv_cfg.get("tr_beta"),
                })

    results_df = pd.DataFrame(results)

    print("\n===== Full extended grid search results (sorted globally) =====")
    print(results_df.sort_values("best_val_loss").head(20))  # top 20 for brevity

    # Best configuration per conv_type
    print("\n===== Best configuration per conv_type =====")
    best_by_conv = {}
    for ct in conv_types:
        sub = results_df[results_df["conv_type"] == ct]
        if len(sub) == 0:
            continue
        best_row = sub.loc[sub["best_val_loss"].idxmin()]
        best_by_conv[ct] = best_row.to_dict()
        print(
            f"\nConv type: {ct}\n"
            f"  hidden_dim      = {best_row['hidden_dim']}\n"
            f"  latent_dim      = {best_row['latent_dim']}\n"
            f"  num_layers      = {best_row['num_layers']}\n"
            f"  lr              = {best_row['lr']}\n"
            f"  dropout         = {best_row['dropout']}\n"
            f"  best_val_loss   = {best_row['best_val_loss']:.6f}"
        )
        if ct == "sage":
            print(f"  sage_aggr       = {best_row['sage_aggr']}")
        elif ct == "nn":
            print(
                f"  nn_edge_hidden_dim = {best_row['nn_edge_hidden_dim']}, "
                f"nn_edge_mlp_layers = {best_row['nn_edge_mlp_layers']}, "
                f"nn_aggr = {best_row['nn_aggr']}"
            )
        elif ct == "gine":
            print(
                f"  gine_mlp_hidden_dim = {best_row['gine_mlp_hidden_dim']}, "
                f"gine_mlp_layers = {best_row['gine_mlp_layers']}, "
                f"gine_train_eps = {best_row['gine_train_eps']}, "
                f"gine_eps_init = {best_row['gine_eps_init']}"
            )
        elif ct == "transformer":
            print(
                f"  tr_heads = {best_row['tr_heads']}, "
                f"tr_concat = {best_row['tr_concat']}, "
                f"tr_dropout = {best_row['tr_dropout']}, "
                f"tr_beta = {best_row['tr_beta']}"
            )

    return results_df, best_by_conv
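
The "best configuration per `conv_type`" step above boils down to a grouped minimum. A dependency-free sketch of that selection is shown below; `best_per_group` and the toy rows are hypothetical, and the loss values are invented for illustration. Runs whose loss is missing or NaN are skipped, which mirrors the default NaN-skipping behaviour of pandas' `idxmin`.

```python
import math


def best_per_group(rows, group_key="conv_type", metric="best_val_loss"):
    """Return {group: row} keeping, for each group, the row with the lowest metric."""
    best = {}
    for row in rows:
        value = row[metric]
        if value is None or math.isnan(value):
            continue  # skip runs whose loss is missing or diverged to NaN
        group = row[group_key]
        if group not in best or value < best[group][metric]:
            best[group] = row
    return best


# Toy results (made-up losses, not taken from an actual run).
toy_rows = [
    {"conv_type": "sage", "lr": 1e-4, "best_val_loss": 0.42},
    {"conv_type": "sage", "lr": 5e-4, "best_val_loss": 0.11},
    {"conv_type": "gine", "lr": 1e-4, "best_val_loss": 0.30},
    {"conv_type": "gine", "lr": 1e-3, "best_val_loss": float("nan")},
]

toy_best = best_per_group(toy_rows)
for conv, row in toy_best.items():
    print(conv, row["lr"], row["best_val_loss"])
```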


In [25]:
# 1) Prepare loaders for the dataset you want to tune on
train_loader, val_loader, in_channels, edge_dim = prepare_loaders_for_dataset(
    ML_dict,
    dataset="all_signals",
    edge_attr_mode="geo_pt",
    batch_size=64,
    max_bg_events_for_gs=2000,   # ← here you control how many BG events to use
)



=== Preparing loaders for dataset = 'all_signals' (for grid search) ===
Total background events available: 317657
Total signal events (unused for grid search): 317657
Subsampled background events to: 2000 for grid search.
Grid-search Train BG: 1400
Grid-search Val BG: 600
Collected node feature samples for scaler: (7887, 3)
Node feature scaler mean: [2.10414636e+01 8.35084668e-03 4.12605667e-02]
Node feature scaler scale: [9.87374194 1.42385701 1.80853129]
Train graphs: 1272
Val graphs:   528
in_channels (node features): 8
edge_dim (edge attributes): 6
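
The output above shows the scaler being fitted on three continuous columns of the training subset only (the shape `(7887, 3)` suggests one row per node, presumably pT, η, φ). The sketch below illustrates the fit-on-train, apply-everywhere pattern for a single feature using only the standard library. `fit_standardizer` and `standardize` are hypothetical names and the pT values are invented; the notebook's own `fit_node_feature_scaler` is not reproduced here.

```python
import statistics


def fit_standardizer(values):
    """Return (mean, scale) computed from training values only."""
    mean = statistics.fmean(values)
    scale = statistics.pstdev(values)  # population std (ddof=0), as StandardScaler uses
    return mean, (scale if scale > 0 else 1.0)  # avoid dividing by zero for constant features


def standardize(values, mean, scale):
    """Apply a previously fitted (mean, scale) to any split."""
    return [(v - mean) / scale for v in values]


# Invented pT values for illustration.
train_pt = [12.0, 18.0, 25.0, 40.0]
val_pt = [15.0, 60.0]

mean, scale = fit_standardizer(train_pt)      # statistics come from the training split
train_z = standardize(train_pt, mean, scale)
val_z = standardize(val_pt, mean, scale)      # validation reuses them unchanged

print("train z-scores:", [round(z, 3) for z in train_z])
print("val z-scores:  ", [round(z, 3) for z in val_z])
```

Reusing the training statistics on validation and test data keeps unusually hard objects visibly far from zero instead of being normalised away, which matters when the reconstruction error is the anomaly score.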


In [23]:
def pretty_print_best_by_conv(best_by_conv):
    """
    Nicely print best_by_conv as a small summary table.

    best_by_conv is the dict returned by run_extended_grid_search, of the form:
        { conv_type: row_dict_from_results_df, ... }
    """
    if not best_by_conv:
        print("best_by_conv is empty.")
        return None

    # Turn values (row dicts) into a DataFrame
    rows = list(best_by_conv.values())
    df = pd.DataFrame(rows)

    # Choose a sensible column order
    preferred_cols = [
        "dataset",
        "conv_type",
        "hidden_dim",
        "latent_dim",
        "num_layers",
        "lr",
        "dropout",
        "best_val_loss",
        # conv-specific (will be NaN where not applicable)
        "sage_aggr",
        "nn_edge_hidden_dim",
        "nn_edge_mlp_layers",
        "nn_aggr",
        "gine_mlp_hidden_dim",
        "gine_mlp_layers",
        "gine_train_eps",
        "gine_eps_init",
        "tr_heads",
        "tr_concat",
        "tr_dropout",
        "tr_beta",
    ]

    cols = [c for c in preferred_cols if c in df.columns]
    df = df[cols]

    # Sort by best_val_loss ascending
    if "best_val_loss" in df.columns:
        df = df.sort_values("best_val_loss")

    print("\n===== best_by_conv summary =====")
    print(df.to_string(index=False))

    return df


def build_conv_config_from_best(conv_type, best_row):
    """
    Extract conv-specific hyperparameters from a best_row dict
    (one entry of best_by_conv) and build a conv_config dict
    suitable for passing into GAEModel / GraphEncoder.
    """
    conv_type = conv_type.lower()
    cfg = {}

    if conv_type == "sage":
        cfg["sage_aggr"] = best_row.get("sage_aggr", "mean")

    elif conv_type == "nn":
        if not pd.isna(best_row.get("nn_edge_hidden_dim", np.nan)):
            cfg["nn_edge_hidden_dim"] = int(best_row["nn_edge_hidden_dim"])
        if not pd.isna(best_row.get("nn_edge_mlp_layers", np.nan)):
            cfg["nn_edge_mlp_layers"] = int(best_row["nn_edge_mlp_layers"])
        if not pd.isna(best_row.get("nn_aggr", np.nan)):
            cfg["nn_aggr"] = best_row["nn_aggr"]

    elif conv_type == "gine":
        if not pd.isna(best_row.get("gine_mlp_hidden_dim", np.nan)):
            cfg["gine_mlp_hidden_dim"] = int(best_row["gine_mlp_hidden_dim"])
        if not pd.isna(best_row.get("gine_mlp_layers", np.nan)):
            cfg["gine_mlp_layers"] = int(best_row["gine_mlp_layers"])
        if not pd.isna(best_row.get("gine_train_eps", np.nan)):
            cfg["gine_train_eps"] = bool(best_row["gine_train_eps"])
        if not pd.isna(best_row.get("gine_eps_init", np.nan)):
            cfg["gine_eps_init"] = float(best_row["gine_eps_init"])

    elif conv_type == "transformer":
        if not pd.isna(best_row.get("tr_heads", np.nan)):
            cfg["tr_heads"] = int(best_row["tr_heads"])
        if not pd.isna(best_row.get("tr_concat", np.nan)):
            cfg["tr_concat"] = bool(best_row["tr_concat"])
        if not pd.isna(best_row.get("tr_dropout", np.nan)):
            cfg["tr_dropout"] = float(best_row["tr_dropout"])
        if not pd.isna(best_row.get("tr_beta", np.nan)):
            cfg["tr_beta"] = bool(best_row["tr_beta"])

    return cfg


def run_full_from_best(
    ML_dict,
    best_by_conv,
    conv_type,
    dataset="all_signals",
    edge_attr_mode="geo_pt",
    exp_root="gae_experiments",
    epochs=50,
    batch_size=64,
    use_exponential=True,
    alpha=1.0,
    beta=1.0,
):
    """
    Convenience helper:
    Take best_by_conv[conv_type] from the extended grid search,
    and run the full pipeline (training + all plots) with matching hyperparameters.

    NOTE:
    - Assumes run_full_pipeline(...) has a 'conv_config' argument and passes it to GAEModel.
    """
    ct = conv_type.lower()
    if ct not in best_by_conv:
        raise KeyError(f"conv_type '{ct}' not found in best_by_conv. Keys: {list(best_by_conv.keys())}")

    best_row = best_by_conv[ct]

    # Extract global hyperparameters
    hidden_dim = int(best_row["hidden_dim"])
    latent_dim = int(best_row["latent_dim"])
    num_layers = int(best_row["num_layers"])
    lr = float(best_row["lr"])
    dropout = float(best_row["dropout"])

    # Build conv-specific config dict
    conv_config = build_conv_config_from_best(ct, best_row)

    print(f"\n>>> Running full pipeline for dataset='{dataset}', conv_type='{ct}'")
    print("    Using best global hyperparameters from grid search:")
    print(f"      hidden_dim  = {hidden_dim}")
    print(f"      latent_dim  = {latent_dim}")
    print(f"      num_layers  = {num_layers}")
    print(f"      lr          = {lr}")
    print(f"      dropout     = {dropout}")
    print(f"    Conv-specific config: {conv_config}")

    # IMPORTANT:
    # Make sure your run_full_pipeline signature includes 'conv_config=None'
    # and passes it to GAEModel(..., conv_config=conv_config).
    results = run_full_pipeline(
        ML_dict,
        dataset=dataset,
        conv_type=ct,
        edge_attr_mode=edge_attr_mode,
        exp_root=exp_root,
        batch_size=batch_size,
        hidden_dim=hidden_dim,
        latent_dim=latent_dim,
        num_layers=num_layers,
        dropout=dropout,
        decoder_hidden_dim=hidden_dim,  # you can change if you want
        epochs=epochs,
        lr=lr,
        weight_decay=0.0,
        use_exponential=use_exponential,
        alpha=alpha,
        beta=beta,
        max_bkg_eta_phi=3,
        max_sig_eta_phi=3,
        conv_config=conv_config,
    )

    return results
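
The `pd.isna` checks and explicit `int(...)` / `bool(...)` casts in `build_conv_config_from_best` exist because each row of `best_by_conv` comes out of a DataFrame: hyperparameters that do not apply to a given convolution type arrive as `None` or NaN, and integer columns that contain missing values are stored as floats. The sketch below shows the failure mode and a generic clean-up using only the standard library. `drop_missing` and the toy row are hypothetical and not part of the notebook.

```python
import math


def drop_missing(cfg):
    """Return a copy of `cfg` without None or NaN values."""
    def is_missing(value):
        return value is None or (isinstance(value, float) and math.isnan(value))

    return {key: value for key, value in cfg.items() if not is_missing(value)}


# Toy row shaped like an entry of best_by_conv for a GINE run (invented values).
toy_row = {
    "conv_type": "gine",
    "gine_mlp_layers": 3.0,      # integer stored as float after a DataFrame round trip
    "tr_heads": float("nan"),    # transformer-only setting, not applicable here
    "sage_aggr": None,           # SAGE-only setting, not applicable here
}

clean = drop_missing(toy_row)
print(clean)

# Casting a NaN directly fails, which is why missing entries are filtered first.
try:
    int(toy_row["tr_heads"])
    nan_cast_failed = False
except ValueError:
    nan_cast_failed = True
print("int(NaN) raises ValueError:", nan_cast_failed)
```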


In [27]:
results_df, best_by_conv = run_extended_grid_search(
    train_loader,
    val_loader,
    in_channels,
    edge_dim,
    use_exponential=False,
    dataset_name="all_signals",
)

# Neatly print best_by_conv
summary_df = pretty_print_best_by_conv(best_by_conv)

# Then, say you want to run the full pipeline for the best GINE config:
# results_gine = run_full_from_best(
#     ML_dict,
#     best_by_conv,
#     conv_type="gine",
#     dataset="all_signals",
#     edge_attr_mode="geo_pt",
#     epochs=50,
# )



=== conv=sage, hid=32, lat=8, layers=2, lr=0.0001, drop=0.0, conv_cfg={'sage_aggr': 'mean'} ===
--> Best val loss for this config: 0.421556

=== conv=sage, hid=32, lat=8, layers=2, lr=0.0001, drop=0.0, conv_cfg={'sage_aggr': 'max'} ===
--> Best val loss for this config: 0.413530

=== conv=sage, hid=32, lat=8, layers=2, lr=0.0001, drop=0.1, conv_cfg={'sage_aggr': 'mean'} ===
--> Best val loss for this config: 0.445864

=== conv=sage, hid=32, lat=8, layers=2, lr=0.0001, drop=0.1, conv_cfg={'sage_aggr': 'max'} ===
--> Best val loss for this config: 0.394501

=== conv=sage, hid=32, lat=8, layers=2, lr=0.0001, drop=0.2, conv_cfg={'sage_aggr': 'mean'} ===
--> Best val loss for this config: 0.441793

=== conv=sage, hid=32, lat=8, layers=2, lr=0.0001, drop=0.2, conv_cfg={'sage_aggr': 'max'} ===
--> Best val loss for this config: 0.405721

=== conv=sage, hid=32, lat=8, layers=2, lr=0.0005, drop=0.0, conv_cfg={'sage_aggr': 'mean'} ===
--> Best val loss for this config: 0.110466

=== conv=sage, 



## Driver: run_full_pipeline(ML_dict, dataset=..., ...)

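This function is the main entry point.  
Call it once per dataset and convolution choice; it:

1. Picks `df = ML_dict[dataset]`.
2. Splits into background (`target == "EB_test"`) and signal.
3. Splits the background 70/15/15 into train/val/test and adds all signal events to the test set.
4. Fits the node feature scaler on the background training events and builds the graphs.
5. Trains a GAE (background only) with the chosen convolution.
6. Produces and saves, under `exp_root/{dataset}_{conv_type}`:
   - training/validation loss curve,
   - anomaly score histogram + ROC AUC,
   - latent PCA plot,
   - η–φ input vs reconstruction plots for a few BG & signal events.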
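
The labelling and splitting performed by the driver can be sketched on a toy list of `target` strings. This is a standard-library stand-in for the pandas / `train_test_split` code in `run_full_pipeline`: `split_events` is a hypothetical helper, the event counts are invented, and integer arithmetic is used for the split sizes, so exact sizes may differ by one event from scikit-learn for counts that do not divide evenly.

```python
import random

BG_LABEL = "EB_test"


def split_events(targets, seed=42):
    """Return (train, val, test) lists of (event_index, label) pairs.

    Background (label 0) is split 70/15/15; all signal (label 1) goes to test.
    """
    bkg = [i for i, t in enumerate(targets) if t == BG_LABEL]
    sig = [i for i, t in enumerate(targets) if t != BG_LABEL]

    random.Random(seed).shuffle(bkg)
    n_train = len(bkg) * 70 // 100
    n_val = (len(bkg) - n_train) // 2

    train = [(i, 0) for i in bkg[:n_train]]
    val = [(i, 0) for i in bkg[n_train:n_train + n_val]]
    test = [(i, 0) for i in bkg[n_train + n_val:]] + [(i, 1) for i in sig]
    return train, val, test


# Toy event list: 100 background events plus two signal samples (invented counts).
toy_targets = [BG_LABEL] * 100 + ["Znunu"] * 20 + ["HNLeemu"] * 10
train, val, test = split_events(toy_targets)

print("train:", len(train), "val:", len(val), "test:", len(test))
print("signal events in test:", sum(label for _, label in test))
```

Because training and validation contain background only, the validation loss measures reconstruction quality, while the ROC AUC on the mixed test set measures how well that reconstruction error separates signal from background.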


In [None]:
# Example: run on the combined dataset 'all_signals'
results_all = run_full_pipeline(
    ML_dict,
    dataset="all_signals",   # or "Znunu", "HNLeemu", etc.
    conv_type="gine",        # "sage", "nn", "gine", "transformer"
    edge_attr_mode="geo_pt", # geometric + log pT edge attributes
    epochs=50,
    hidden_dim=64,
    latent_dim=16,
)

# Example: run on a specific dataset, e.g. only Znunu vs EB_test
# results_znunu = run_full_pipeline(ML_dict, dataset="Znunu", conv_type="nn")
