<a href="https://colab.research.google.com/github/raktim711/AIMS-project---Anomaly-Detection/blob/main/RFClassification_after_processing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Classification power with varying jet $p_T$ thresholds

In this notebook we take a preprocessed file and test how well the data can be classified as the minimum jet $p_T$ threshold is varied. For each dataset, a Random Forest is trained at several thresholds, and the resulting plots are saved to a folder in Google Drive.

In [1]:
import os
import math
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

In [2]:
# Set plotting style at module level
plt.rcParams.update({
    # Font sizes
    'font.size': 18,
    'axes.labelsize': 18,
    'axes.titlesize': 18,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'legend.fontsize': 16,
    'legend.frameon': False,  # No box around legend
    'axes.grid': False,
    # Tick settings
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'xtick.major.size': 10,
    'ytick.major.size': 10,
    'xtick.minor.size': 5,
    'ytick.minor.size': 5,
    'xtick.major.width': 1,
    'ytick.major.width': 1,
    'xtick.top': True,
    'ytick.right': True,
    'xtick.minor.visible': True,
    'ytick.minor.visible': True
})
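`plt.rcParams.update` changes matplotlib's global defaults for the whole session, which is why the later `plt.rcParams['font.size'] = 12` assignment silently replaces the 18 pt set here. If a style should apply only to particular figures, matplotlib's `plt.rc_context` context manager is one option. This is a sketch of an alternative pattern, not what the notebook does; `ANALYSIS_STYLE` is a hypothetical name holding a subset of the settings above.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

# Hypothetical style dict: a subset of the notebook's settings
ANALYSIS_STYLE = {"font.size": 18, "xtick.direction": "in", "ytick.direction": "in"}

size_before = plt.rcParams["font.size"]

# Settings apply only inside the with-block and are restored on exit
with plt.rc_context(ANALYSIS_STYLE):
    fig, ax = plt.subplots()
    ax.plot([0, 1], [0, 1])
    size_inside = plt.rcParams["font.size"]
    plt.close(fig)

size_after = plt.rcParams["font.size"]
print(size_before, size_inside, size_after)
```

With this pattern, a later global assignment cannot accidentally change plots that were meant to use the scoped style.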

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [25]:
import pickle

with open("/content/drive/MyDrive/Datasets/balanced_dfs_no_dup_OR.pkl", "rb") as f:
    ML_dict = pickle.load(f)


## Details of the data
`ML_dict` is a dictionary of DataFrames with the following keys:

```
'all_signals', 'HAHMggf', 'HNLeemu', 'HtoSUEP',
'VBF_H125_a55a55_4b_ctau1_filtered', 'Znunu',
'ggF_H125_a16a16_4b_ctau10_filtered', 'hh_bbbb_vbf_novhh_5fs_l1cvv1cv1'
```
All of them have the same columns:

```
'j0pt', 'j0eta', 'j0phi', 'j1pt', 'j1eta', 'j1phi', 'j2pt', 'j2eta',
'j2phi', 'j3pt', 'j3eta', 'j3phi', 'j4pt', 'j4eta', 'j4phi', 'j5pt',
'j5eta', 'j5phi', 'e0pt', 'e0eta', 'e0phi', 'e1pt', 'e1eta', 'e1phi',
'e2pt', 'e2eta', 'e2phi', 'mu0pt', 'mu0eta', 'mu0phi', 'mu1pt',
'mu1eta', 'mu1phi', 'mu2pt', 'mu2eta', 'mu2phi', 'ph0pt', 'ph0eta',
'ph0phi', 'ph1pt', 'ph1eta', 'ph1phi', 'ph2pt', 'ph2eta', 'ph2phi',
'METpt', 'METeta', 'METphi', 'run_number', 'event_number', 'weight',
'target'
```
When loaded from `balanced_dfs_no_dup_processed.pkl`, the DataFrames contain only events with no duplicate objects. Events with undefined `METpt` have been removed, as have all events in which every object has zero $p_T$. Each DataFrame contains equal numbers of signal and background events, where background is `'target' == 'EB_test'`.
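The properties listed above can be checked directly on any of the DataFrames. Below is a minimal sketch on a toy frame; `check_preprocessing` is a hypothetical helper, and it assumes that "all objects" means the jet, electron, muon and photon $p_T$ columns (MET excluded) and that background rows carry `target == 'EB_test'`.

```python
import pandas as pd

def check_preprocessing(df, pt_cols, bkg_label="EB_test"):
    """Hypothetical sanity check for the preprocessing claims above.

    Assumes background rows have target == bkg_label and that pt_cols lists
    the object pT columns to test for all-zero events (MET excluded).
    """
    # No event may have an undefined (NaN) MET
    met_defined = bool(df["METpt"].notna().all())
    # Every event must have at least one object with non-zero pT
    no_empty_events = bool((df[pt_cols] != 0).any(axis=1).all())
    # Equal numbers of signal and background events
    n_bkg = int((df["target"] == bkg_label).sum())
    n_sig = int((df["target"] != bkg_label).sum())
    return {"met_defined": met_defined,
            "no_empty_events": no_empty_events,
            "balanced": n_sig == n_bkg}

# Toy frame with a subset of the real column names (values are illustrative)
toy = pd.DataFrame({
    "j0pt":  [40.0, 0.0, 25.0, 0.0],
    "mu0pt": [0.0, 12.0, 0.0, 30.0],
    "METpt": [55.0, 20.0, 31.0, 8.0],
    "target": ["HAHMggf", "EB_test", "HAHMggf", "EB_test"],
})
pt_cols = [c for c in toy.columns if c.endswith("pt") and c != "METpt"]
result = check_preprocessing(toy, pt_cols)
print(result)
```

On the real data the same column filter selects `j0pt` to `ph2pt`, so the call would be `check_preprocessing(ML_dict["HAHMggf"], pt_cols)` with `pt_cols` built from that DataFrame's columns.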



In [6]:
# Figure defaults for the analysis plots (this overrides the font.size of 18 set above)
plt.rcParams['figure.figsize'] = (8,6)
plt.rcParams['font.size'] = 12

# Jet pT thresholds (GeV) to scan
JET_PT_THRESHOLDS = [5, 15, 25, 45, 60, 80]

# Which jet columns to use
JET_PT_COLS = [f"j{i}pt" for i in range(6)]   # j0pt ... j5pt


In [26]:
PLOT_BASE_DIR = "/content/drive/MyDrive/Datasets/plots_with_OR"
os.makedirs(PLOT_BASE_DIR, exist_ok=True)

## Define all the necessary helper functions

In [8]:
def apply_jet_pt_threshold(df, threshold):
    """
    Returns a filtered dataframe where all nonzero jets have pt >= threshold.
    Condition: for each jet jX,
      keep event if (jXpt == 0) or (jXpt >= threshold)
    """
    mask = np.ones(len(df), dtype=bool)
    for col in JET_PT_COLS:
        if col in df.columns:
            mask &= (df[col] == 0) | (df[col] >= threshold)
    return df[mask]
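To make the event-level rule in `apply_jet_pt_threshold` concrete, the toy example below applies it to four made-up events. An event with a single jet below the threshold is dropped entirely (the soft jet is not just zeroed), and an event with no jets passes every threshold. The function body mirrors the cell above, with an explicit `.to_numpy()` added when updating the mask.

```python
import numpy as np
import pandas as pd

JET_PT_COLS = [f"j{i}pt" for i in range(6)]

def apply_jet_pt_threshold(df, threshold):
    # Same rule as the notebook: each jet is either absent (pt == 0)
    # or has pt >= threshold; otherwise the whole event is dropped
    mask = np.ones(len(df), dtype=bool)
    for col in JET_PT_COLS:
        if col in df.columns:
            mask &= ((df[col] == 0) | (df[col] >= threshold)).to_numpy()
    return df[mask]

# Four made-up events; only j0..j2 exist, so j3..j5 are skipped
toy = pd.DataFrame(
    {"j0pt": [120.0, 50.0, 30.0, 0.0],
     "j1pt": [40.0, 10.0, 0.0, 0.0],
     "j2pt": [0.0, 0.0, 0.0, 0.0]},
    index=["evtA", "evtB", "evtC", "evtD"],
)

passed_15 = apply_jet_pt_threshold(toy, 15)  # evtB dropped: its 10 GeV jet fails
passed_45 = apply_jet_pt_threshold(toy, 45)  # only evtD (no jets at all) survives
print(list(passed_15.index), list(passed_45.index))
```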


In [9]:
def prepare_dataset(df):
    """
    Drops bookkeeping columns, builds binary labels (1 = signal,
    0 = EB_test background) and performs a single stratified train/test split.
    """
    # Features = all physics columns except bookkeeping
    drop_cols = ["run_number", "event_number", "target", "weight"]
    feature_cols = [c for c in df.columns if c not in drop_cols]

    X = df[feature_cols].copy()
    # Signal = 1 so that counts_sig and the ROC true-positive rate refer to signal
    y = (df["target"] != "EB_test").astype(int)

    # Single train/test split to be reused for all thresholds
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, stratify=y, random_state=42
    )

    return X_train, X_test, y_train, y_test, feature_cols
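Because `prepare_dataset` splits once and the thresholds are applied afterwards, the labels for a filtered subset are recovered by index with `.loc`. A small sketch on synthetic data (column and class sizes are arbitrary) shows that `stratify=y` keeps the 50/50 balance in both halves and that filtering preserves the original index:

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
n = 200
# Synthetic stand-in for the features of one balanced dataset
X = pd.DataFrame({"j0pt": rng.uniform(0.0, 100.0, n)})
# Balanced binary labels; which class is 1 does not matter for stratification
y = pd.Series([0] * (n // 2) + [1] * (n // 2))

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=42
)

# Filtering keeps the original row labels, so y can be re-aligned with .loc,
# as run_full_analysis does after apply_jet_pt_threshold
X_test_cut = X_test[X_test["j0pt"] >= 30.0]
y_test_cut = y_test.loc[X_test_cut.index]

print(len(X_train), len(X_test), y_train.mean(), y_test.mean(), len(X_test_cut))
```

Note that the balance is only guaranteed before the cut; after a jet $p_T$ threshold, signal and background can pass at different rates, which is what the event-yield plot tracks.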


In [10]:
def train_and_evaluate_rf(X_train, y_train, X_test, y_test):
    """Trains a Random Forest and returns the fitted model, the test-set ROC curve (fpr, tpr) and its AUC."""

    rf = RandomForestClassifier(
        n_estimators=300,
        max_depth=None,
        min_samples_split=2,
        random_state=42,
        n_jobs=-1
    )
    rf.fit(X_train, y_train)

    # Probabilities for ROC
    y_score = rf.predict_proba(X_test)[:, 1]

    fpr, tpr, _ = roc_curve(y_test, y_score)
    roc_auc = auc(fpr, tpr)

    return rf, fpr, tpr, roc_auc
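AUC does not depend on which class is encoded as 1: swapping the labels and using the complementary score gives the same area. The label convention therefore only affects interpretation (which efficiency the true-positive-rate axis shows, and which yield curve is called "Signal"). A self-contained check, using synthetic data from `make_classification` as a stand-in for the physics features and the same `roc_curve`/`auc` calls as `train_and_evaluate_rf`:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1000, n_features=10, n_informative=5,
                           random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=42)

rf = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
rf.fit(X_train, y_train)
score = rf.predict_proba(X_test)[:, 1]  # P(class 1)

fpr, tpr, _ = roc_curve(y_test, score)
auc_class1 = auc(fpr, tpr)

# Treat class 0 as the positive class instead: flip labels and scores
fpr0, tpr0, _ = roc_curve(1 - y_test, 1 - score)
auc_class0 = auc(fpr0, tpr0)

print(f"AUC (class 1 positive) = {auc_class1:.3f}")
print(f"AUC (class 0 positive) = {auc_class0:.3f}")
```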


In [11]:
def plot_roc_curves(roc_results, dataset_name, save_dir):
    plt.figure()
    for T, (fpr, tpr, roc_auc) in roc_results.items():
        plt.plot(fpr, tpr, label=f"T={T} GeV (AUC={roc_auc:.3f})")

    plt.plot([0,1], [0,1], "k--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curves — {dataset_name}")
    plt.legend()
    plt.grid(True)

    plt.savefig(os.path.join(save_dir, "roc_curves.png"), dpi=200, bbox_inches="tight")
    plt.close()


In [12]:
def plot_auc_vs_threshold(roc_results, dataset_name, save_dir):
    thresholds = list(roc_results.keys())
    auc_vals = [roc_results[T][2] for T in thresholds]

    plt.figure()
    plt.plot(thresholds, auc_vals, marker="o")
    plt.xlabel("Jet $p_T$ Threshold [GeV]")
    plt.ylabel("AUC")
    plt.title(f"AUC vs Jet $p_T$ Threshold — {dataset_name}")
    plt.grid(True)

    plt.savefig(os.path.join(save_dir, "auc_vs_threshold.png"), dpi=200, bbox_inches="tight")
    plt.close()
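`plot_roc_curves` and `plot_auc_vs_threshold` both read a dictionary of the form `{T: (fpr, tpr, auc)}`, whose insertion order follows `JET_PT_THRESHOLDS`. The AUC values can also be pulled into a pandas Series for a quick numeric comparison next to the plots. A sketch with two hand-made ROC curves; `auc_table` is a hypothetical helper, not part of the notebook:

```python
import numpy as np
import pandas as pd
from sklearn.metrics import auc

# Hand-made ROC points, stored in the {T: (fpr, tpr, auc)} layout of run_full_analysis
curves = {
    15: (np.array([0.0, 0.2, 1.0]), np.array([0.0, 0.7, 1.0])),
    60: (np.array([0.0, 0.1, 1.0]), np.array([0.0, 0.8, 1.0])),
}
roc_results = {T: (fpr, tpr, auc(fpr, tpr)) for T, (fpr, tpr) in curves.items()}

def auc_table(roc_results):
    """Hypothetical helper: AUC per threshold as a pandas Series sorted by T."""
    s = pd.Series({T: res[2] for T, res in roc_results.items()}, name="AUC")
    s.index.name = "jet_pt_threshold_GeV"
    return s.sort_index()

table = auc_table(roc_results)
print(table)
```

Inside `run_full_analysis`, something like `auc_table(roc_results).to_csv(os.path.join(save_dir, "auc_vs_threshold.csv"))` would store the numbers beside the PNG (the file name is hypothetical).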


In [13]:
def plot_event_yields(counts_sig, counts_bkg, dataset_name, save_dir):
    plt.figure()
    plt.plot(list(counts_sig.keys()), list(counts_sig.values()),
             marker="o", label="Signal")
    plt.plot(list(counts_bkg.keys()), list(counts_bkg.values()),
             marker="s", label="Background")

    plt.xlabel("Jet $p_T$ Threshold [GeV]")
    plt.ylabel("Test-set events passing selection")
    plt.title(f"Event Yields — {dataset_name}")
    plt.legend()
    plt.grid(True)

    plt.savefig(os.path.join(save_dir, "event_yields.png"), dpi=200, bbox_inches="tight")
    plt.close()


In [14]:
def compare_feature_importances(
        rf_low, rf_high, feature_cols, dataset_name, save_dir,
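        low_T=15, high_T=60, topN=12):
    """Overlays the importances of the low- and high-threshold Random Forests
    for the topN features, ranked by the high-threshold model."""

    importances_low = rf_low.feature_importances_
    importances_high = rf_high.feature_importances_

    # Indices of the topN most important features in the high-threshold model
    idx = np.argsort(importances_high)[::-1][:topN]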

    plt.figure(figsize=(9,6))
    plt.barh(
        [feature_cols[i] for i in idx],
        importances_high[idx],
        alpha=0.7,
        label=f"T={high_T} GeV"
    )
    plt.barh(
        [feature_cols[i] for i in idx],
        importances_low[idx],
        alpha=0.7,
        label=f"T={low_T} GeV"
    )

    plt.gca().invert_yaxis()
    plt.xlabel("Feature Importance")
    plt.title(f"Feature Importance Comparison — {dataset_name}")
    plt.legend()

    plt.savefig(os.path.join(save_dir, "feature_importances.png"), dpi=200, bbox_inches="tight")
    plt.close()
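The overlaid bars in `compare_feature_importances` can hide small differences, so a numeric table of both importance vectors can make the per-feature shift explicit. The sketch below uses two forests trained on synthetic data as stand-ins for `rf_models[15]` and `rf_models[60]`; `importance_table` and the feature names are hypothetical. Keep in mind that scikit-learn's impurity-based `feature_importances_` are computed from the training data and can favour continuous, high-cardinality features; `sklearn.inspection.permutation_importance` on the test set is a common cross-check.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=500, n_features=6, n_informative=3,
                           random_state=0)
feature_cols = [f"f{i}" for i in range(X.shape[1])]  # placeholder feature names

# Two forests standing in for rf_models[15] and rf_models[60]
rf_low = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)
rf_high = RandomForestClassifier(n_estimators=50, random_state=1).fit(X[:250], y[:250])

def importance_table(rf_low, rf_high, feature_cols, topN=12):
    """Hypothetical helper: both importance vectors, ranked by rf_high."""
    table = pd.DataFrame(
        {"high_T": rf_high.feature_importances_,
         "low_T": rf_low.feature_importances_},
        index=feature_cols,
    )
    # Positive delta = feature gained importance at the higher threshold
    table["delta"] = table["high_T"] - table["low_T"]
    return table.sort_values("high_T", ascending=False).head(topN)

table = importance_table(rf_low, rf_high, feature_cols)
print(table.round(3))
```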


## Main function

In [15]:
def run_full_analysis(dataset_name):
    print(f"=== Running full analysis for: {dataset_name} ===")

    df = ML_dict[dataset_name].copy()

    # Create directory for this dataset's plots
    save_dir = os.path.join(PLOT_BASE_DIR, dataset_name)
    os.makedirs(save_dir, exist_ok=True)

    # Prepare once
    X_train_all, X_test_all, y_train_all, y_test_all, feature_cols = prepare_dataset(df)

    roc_results = {}
    counts_sig = {}
    counts_bkg = {}
    rf_models = {}

    # Loop thresholds
    for T in JET_PT_THRESHOLDS:
        print(f"\n→ Applying jet pt threshold T = {T} GeV")

        X_train = apply_jet_pt_threshold(X_train_all, T)
        y_train = y_train_all.loc[X_train.index]

        X_test = apply_jet_pt_threshold(X_test_all, T)
        y_test = y_test_all.loc[X_test.index]

        counts_sig[T] = (y_test == 1).sum()
        counts_bkg[T] = (y_test == 0).sum()

        rf, fpr, tpr, roc_auc = train_and_evaluate_rf(
            X_train, y_train, X_test, y_test
        )

        roc_results[T] = (fpr, tpr, roc_auc)
        rf_models[T] = rf

    # Save all plots
    plot_roc_curves(roc_results, dataset_name, save_dir)
    plot_auc_vs_threshold(roc_results, dataset_name, save_dir)
    plot_event_yields(counts_sig, counts_bkg, dataset_name, save_dir)

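    # Compare importances between a low and a high threshold
    # (both values must be present in JET_PT_THRESHOLDS)
    low_T, high_T = 15, 60
    compare_feature_importances(
        rf_low=rf_models[low_T],
        rf_high=rf_models[high_T],
        feature_cols=feature_cols,
        dataset_name=dataset_name,
        save_dir=save_dir,
        low_T=low_T,
        high_T=high_T
    )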

    print(f" Completed. Plots saved in: {save_dir}")
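A property worth checking in `run_full_analysis` is that, for every threshold, the filtered test events come only from the single held-out split and never overlap the training events. Because an event passes threshold $T$ exactly when its softest non-zero jet has $p_T \ge T$, the selections are also nested: raising $T$ can only remove events. A sketch with synthetic jet $p_T$ values (drawn from a few fixed values purely for illustration); `passes` is a hypothetical vectorised version of the same rule:

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(1)
n = 400
# Synthetic jet pT values (GeV) drawn from a few fixed values, 0 = no jet
X_all = pd.DataFrame({f"j{i}pt": rng.choice([0.0, 10.0, 30.0, 70.0], n)
                      for i in range(3)})
y_all = pd.Series(rng.integers(0, 2, n))

X_train_all, X_test_all, y_train_all, y_test_all = train_test_split(
    X_all, y_all, test_size=0.3, stratify=y_all, random_state=42)

def passes(df, T):
    """Hypothetical vectorised form of the apply_jet_pt_threshold rule."""
    return ((df == 0) | (df >= T)).all(axis=1)

test_sets = {}
for T in [5, 15, 25, 45, 60, 80]:
    X_test = X_test_all[passes(X_test_all, T)]
    test_sets[T] = set(X_test.index)
    # No event tested at threshold T was ever used for training
    assert test_sets[T].isdisjoint(X_train_all.index)

print({T: len(idx) for T, idx in test_sets.items()})
```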


## Run the analysis for each dataset

In [28]:
run_full_analysis("HAHMggf")

=== Running full analysis for: HAHMggf ===

→ Applying jet pt threshold T = 5 GeV

→ Applying jet pt threshold T = 15 GeV

→ Applying jet pt threshold T = 25 GeV

→ Applying jet pt threshold T = 45 GeV

→ Applying jet pt threshold T = 60 GeV

→ Applying jet pt threshold T = 80 GeV
 Completed. Plots saved in: /content/drive/MyDrive/Datasets/plots_with_OR/HAHMggf


In [29]:
run_full_analysis("HNLeemu")

=== Running full analysis for: HNLeemu ===

→ Applying jet pt threshold T = 5 GeV

→ Applying jet pt threshold T = 15 GeV

→ Applying jet pt threshold T = 25 GeV

→ Applying jet pt threshold T = 45 GeV

→ Applying jet pt threshold T = 60 GeV

→ Applying jet pt threshold T = 80 GeV
 Completed. Plots saved in: /content/drive/MyDrive/Datasets/plots_with_OR/HNLeemu


In [30]:
run_full_analysis("HtoSUEP")

=== Running full analysis for: HtoSUEP ===

→ Applying jet pt threshold T = 5 GeV

→ Applying jet pt threshold T = 15 GeV

→ Applying jet pt threshold T = 25 GeV

→ Applying jet pt threshold T = 45 GeV

→ Applying jet pt threshold T = 60 GeV

→ Applying jet pt threshold T = 80 GeV
 Completed. Plots saved in: /content/drive/MyDrive/Datasets/plots_with_OR/HtoSUEP


In [31]:
run_full_analysis("VBF_H125_a55a55_4b_ctau1_filtered")

=== Running full analysis for: VBF_H125_a55a55_4b_ctau1_filtered ===

→ Applying jet pt threshold T = 5 GeV

→ Applying jet pt threshold T = 15 GeV

→ Applying jet pt threshold T = 25 GeV

→ Applying jet pt threshold T = 45 GeV

→ Applying jet pt threshold T = 60 GeV

→ Applying jet pt threshold T = 80 GeV
 Completed. Plots saved in: /content/drive/MyDrive/Datasets/plots_with_OR/VBF_H125_a55a55_4b_ctau1_filtered


In [32]:
run_full_analysis("Znunu")

=== Running full analysis for: Znunu ===

→ Applying jet pt threshold T = 5 GeV

→ Applying jet pt threshold T = 15 GeV

→ Applying jet pt threshold T = 25 GeV

→ Applying jet pt threshold T = 45 GeV

→ Applying jet pt threshold T = 60 GeV

→ Applying jet pt threshold T = 80 GeV
 Completed. Plots saved in: /content/drive/MyDrive/Datasets/plots_with_OR/Znunu


In [33]:
run_full_analysis("ggF_H125_a16a16_4b_ctau10_filtered")

=== Running full analysis for: ggF_H125_a16a16_4b_ctau10_filtered ===

→ Applying jet pt threshold T = 5 GeV

→ Applying jet pt threshold T = 15 GeV

→ Applying jet pt threshold T = 25 GeV

→ Applying jet pt threshold T = 45 GeV

→ Applying jet pt threshold T = 60 GeV

→ Applying jet pt threshold T = 80 GeV
 Completed. Plots saved in: /content/drive/MyDrive/Datasets/plots_with_OR/ggF_H125_a16a16_4b_ctau10_filtered


In [34]:
run_full_analysis("hh_bbbb_vbf_novhh_5fs_l1cvv1cv1")

=== Running full analysis for: hh_bbbb_vbf_novhh_5fs_l1cvv1cv1 ===

→ Applying jet pt threshold T = 5 GeV

→ Applying jet pt threshold T = 15 GeV

→ Applying jet pt threshold T = 25 GeV

→ Applying jet pt threshold T = 45 GeV

→ Applying jet pt threshold T = 60 GeV

→ Applying jet pt threshold T = 80 GeV
 Completed. Plots saved in: /content/drive/MyDrive/Datasets/plots_with_OR/hh_bbbb_vbf_novhh_5fs_l1cvv1cv1


In [35]:
run_full_analysis("all_signals")

=== Running full analysis for: all_signals ===

→ Applying jet pt threshold T = 5 GeV

→ Applying jet pt threshold T = 15 GeV

→ Applying jet pt threshold T = 25 GeV

→ Applying jet pt threshold T = 45 GeV

→ Applying jet pt threshold T = 60 GeV

→ Applying jet pt threshold T = 80 GeV
 Completed. Plots saved in: /content/drive/MyDrive/Datasets/plots_with_OR/all_signals
