In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import copy
import math
import pathlib

import numpy as np
import yaml
from matplotlib import pyplot as plt

# Global matplotlib defaults applied to every figure in this notebook.
# LaTeX text rendering ("text.usetex") is kept off because a font is missing
# from the texlive installation on the Nikhef clusters.
plt.rcParams.update(
    {
        "figure.dpi": 400,
        "text.usetex": False,
        "font.family": "serif",
        "figure.constrained_layout.use": True,
    }
)

# Filter model evaluation

## Plot parameters

In [None]:
# Line/marker colour for each training, keyed by the same labels used for the
# evaluation files further down.
training_colours = {
    "600 MeV eta 4": "mediumvioletred",
    "600 MeV eta 2.5": "cornflowerblue",
    "900 MeV eta 4": "mediumseagreen",  # |eta| < 4.0
}

# Bin edges used when profiling a quantity.
qty_bins = {
    "pt": np.array([0.6, 0.9, 1.0, 1.5, 2, 3, 4, 6, 10]),
    # 17 edges from -4 to 4 in steps of 0.5
    # (an |eta| <= 2.5 variant would be np.linspace(-2.5, 2.5, 11)).
    "eta": np.linspace(-4, 4, 17),
    "phi": np.array([-math.pi, -2.36, -1.57, -0.79, 0, 0.79, 1.57, 2.36, math.pi]),
    "vz": np.array([-100, -50, -20, -10, 0, 10, 20, 50, 100]),
}

# LaTeX symbol and unit suffix for each quantity, used to build axis labels.
qty_symbols = {
    "pt": "p_\\mathrm{T}",
    "eta": "\\eta",
    "phi": "\\phi",
    "vz": "v_z",
}
qty_units = {
    "pt": "[GeV]",
    "eta": "",
    "phi": "",
    "vz": "[mm]",
}

# Directory that saved figures are written to.
out_dir = "plots"

## Retrieve filtering model configuration

In [None]:
# Read the filtering experiment configuration.
# NOTE(review): this is an absolute, user-specific path; the notebook will only
# run where this file exists.
filter_config_path = pathlib.Path("/home/u5du/svanstroud.u5du/hepattn/src/hepattn/experiments/trackml/configs/filtering.yaml")
with filter_config_path.open() as f:
    fconfig = yaml.safe_load(f)

# Data-selection parameters echoed below as a quick sanity check.
filter_params = ["particle_min_pt", "particle_max_abs_eta"]

print("name: " + fconfig["name"])
for param in filter_params:
    print("> " + param + "\t: ", fconfig["data"][param])


# Evaluation output file for each trained model, keyed by the training label.
filtering_fnames = {
    "900 MeV eta 4": "/lus/lfs1aip2/projects/u5du/svanstroud/logs/HF-900MeV-eta4_20251208-T205635/ckpts/epoch=079-val_loss=0.36101_test_eval.h5",
    "600 MeV eta 4": "/lus/lfs1aip2/projects/u5du/svanstroud/logs/HF-600MeV-eta4_20251208-T205514/ckpts/epoch=079-val_loss=0.26907_test_eval.h5",
    "600 MeV eta 2.5": "/lus/lfs1aip2/projects/u5du/svanstroud/logs/HF-600MeV-eta2.5_20251208-T205400/ckpts/epoch=079-val_loss=0.11373_test_eval.h5",
}
# First label; used only to read the (shared) list of hit input fields.
key = next(iter(filtering_fnames))

# One independent config per model. A deep copy is required because the config
# is a nested dict: a shallow copy would share the nested sections between
# models, so editing one model's settings would silently change all of them.
# NOTE(review): every entry starts from the same YAML file, so all models
# report the same config name and threshold; confirm this matches how each
# model was actually trained.
filtering_configs = {label: copy.deepcopy(fconfig) for label in filtering_fnames}

# Names of the hit-level input fields, prefixed with "hits_".
filter_inputs = ["hits_" + field for field in filtering_configs[key]["data"]["inputs"]["hit"]]
print("> inputs: ", filter_inputs)

## Load evaluation file

In [None]:
from hit_evaluate import load_events
from plot_utils import binned, profile_plot

In [None]:
# Load the evaluation output of every model into `filtering_results`,
# keyed by training label.
filtering_results = {}
num_events = None  # forwarded to load_events as `randomize`
threshold = 0.1  # forwarded to load_events as `threshold`
for name, fname in filtering_fnames.items():
    task_args = filtering_configs[name]["model"]["model"]["init_args"]["tasks"]["init_args"]
    # Threshold stored in the model config. It is not passed to load_events,
    # which receives `threshold` instead.
    filter_threshold = task_args["modules"][0]["init_args"]["threshold"]
    filtering_results[name] = load_events(
        fname=fname,
        randomize=num_events,
        write_inputs=None,
        write_parts=True,
        threshold=threshold,
    )

## Plotting metrics

### Discriminant

In [None]:
# One histogram figure per model: discriminant score for hits on valid
# particles versus all other hits, with the configured threshold overlaid.
for name, (hits, targets, _parts, _metrics) in filtering_results.items():
    # Look the threshold up for the model being plotted instead of relying on
    # whatever value another cell's loop happened to leave behind.
    filter_threshold = filtering_configs[name]["model"]["model"]["init_args"]["tasks"]["init_args"]["modules"][0]["init_args"]["threshold"]

    scores = hits["score_sigmoid"]
    on_valid_particle = targets["hit_on_valid_particle"]
    hist_style = {"range": [0, 1], "bins": 40, "density": True, "alpha": 0.5}

    fig, ax = plt.subplots(figsize=(5, 3), constrained_layout=True)
    ax.hist(scores[on_valid_particle], color="C0", label="Valid hits", **hist_style)
    ax.hist(scores[~on_valid_particle], color="C1", label="Invalid hits", **hist_style)

    ax.axvline(filter_threshold, color="r", ls="dashed", label=f"{filtering_configs[name]['name']} Threshold: {filter_threshold:.1f}")
    # Title with the training label so the per-model figures can be told apart.
    ax.set_title(name)
    ax.set_xlabel("Discriminant score")
    ax.set_ylabel("Normalized counts")
    ax.set_xlim(-0.025, 1.025)
    ax.grid(which="both")
    ax.grid(zorder=0, alpha=0.25, linestyle="--")
    ax.legend()

### Receiver operating characteristic

In [None]:
# ROC curve of every model on shared axes; the large dot marks the curve point
# whose threshold is closest to `threshold`.
fig, ax = plt.subplots(figsize=(5, 3), constrained_layout=True)
for name, (*_unused, metrics) in filtering_results.items():
    colour = training_colours[name]
    fpr, tpr = metrics["roc_fpr"], metrics["roc_tpr"]
    ax.plot(fpr, tpr, color=colour, label=f"{filtering_configs[name]['name']} {name}\nAUC: {metrics['roc_fpr_tpr_auc']:.4f}")

    thid = np.abs(metrics["roc_fpr_tpr_thr"] - threshold).argmin()
    ax.scatter(fpr[thid], tpr[thid], color=colour, s=100)

ax.set(
    xlabel="False positive rate",
    ylabel="True positive rate",
    xlim=(-0.05, 1.01),
    ylim=(0.0, 1.05),
)
ax.grid(which="both")
ax.grid(zorder=0, alpha=0.25, linestyle="--")
ax.legend()

### Efficiency purity plot

In [None]:
# Hit purity versus hit efficiency for every model on shared axes; the large
# dot marks the curve point whose threshold is closest to `threshold`.
fig, ax = plt.subplots(figsize=(5, 3), constrained_layout=True)
for name, (_hits, _targets, _parts, metrics) in filtering_results.items():
    ax.plot(
        metrics["roc_eff"],
        metrics["roc_pur"],
        color=training_colours[name],
        label=f"{filtering_configs[name]['name']} {name}\nAUC: {metrics['roc_eff_pur_auc']:.4f}",
    )

    thid = np.argmin(np.abs(metrics["roc_eff_pur_thr"] - threshold))
    ax.scatter(metrics["roc_eff"][thid], metrics["roc_pur"][thid], color=training_colours[name], s=100)

ax.set_xlabel("Hit Efficiency")
ax.set_ylabel("Hit Purity")
ax.set_xlim(0.9, 1.01)
ax.set_ylim(0.3, 1.01)
ax.grid(which="both")
# The grid styling is applied once; the original repeated this exact call.
ax.grid(zorder=0, alpha=0.25, linestyle="--")
ax.legend()

### Particle efficiency (pT binned)

In [None]:
# Fraction of reconstructable particles in bins of particle pT, one profile
# per model on shared axes.
fig, ax = plt.subplots(figsize=(5, 3), constrained_layout=True)

pt_edges = qty_bins["pt"]
for name, (_hits, _targets, parts, _metrics) in filtering_results.items():
    particle_pt = parts["particle_pt"]
    # A particle counts as reconstructable when it has >= 3 predicted hits
    # and passes the valid-particle selection.
    reconstructable = (np.asarray(parts["pred_hits"]) >= 3) & parts["valid"]
    # NaN pT marks excess entries (events with fewer than n_max_particles particles).
    valid = ~np.isnan(particle_pt)
    bin_count, bin_error = binned(reconstructable[valid], particle_pt[valid], pt_edges)
    profile_plot(
        bin_count,
        bin_error,
        pt_edges,
        axes=ax,
        colour=training_colours[name],
        label=f"{filtering_configs[name]['name']} {name}",
    )

ax.set_xlabel(rf"Particle ${qty_symbols['pt']}$ {qty_units['pt']}")
ax.set_ylabel("Reconstructable particles")
ax.set_ylim(0.97, 1)
ax.set_xticks(np.arange(start=2, stop=11, step=2))
ax.grid(which="both")
ax.grid(zorder=0, alpha=0.25, linestyle="--")
ax.legend(loc=3)
plt.show()

### Combined plot

In [None]:
# Two-panel summary figure, saved to <out_dir>/filter_response.pdf:
#   left  - hit purity versus hit efficiency, with a dot at the curve point
#           whose threshold is closest to `threshold`;
#   right - fraction of reconstructable particles in bins of particle pT.
fig, ax = plt.subplots(ncols=2, figsize=(10, 3), constrained_layout=True)
for name, (_hits, _targets, parts, metrics) in filtering_results.items():
    ax[0].plot(
        metrics["roc_eff"],
        metrics["roc_pur"],
        color=training_colours[name],
        label=f"{filtering_configs[name]['name']} {name}\nAUC: {metrics['roc_eff_pur_auc']:.4f}",
    )
    thid = np.argmin(np.abs(metrics["roc_eff_pur_thr"] - threshold))
    ax[0].scatter(metrics["roc_eff"][thid], metrics["roc_pur"][thid], color=training_colours[name], s=100)

    # reconstructable particles must have >=3 hits and pass the valid_particle
    # selection (the comparison already yields a boolean mask, so no np.where)
    reconstructable = (np.asarray(parts["pred_hits"]) >= 3) & parts["valid"]
    # remove excess entries (particles in event less than n_max_particles)
    valid = ~np.isnan(parts["particle_pt"])
    bin_count, bin_error = binned(reconstructable[valid], parts["particle_pt"][valid], qty_bins["pt"])
    profile_plot(bin_count, bin_error, qty_bins["pt"], axes=ax[1], colour=training_colours[name], label=f"{filtering_configs[name]['name']} {name}")

ax[0].set_xlabel("Hit Efficiency")
ax[0].set_ylabel("Hit Purity")
ax[0].set_xlim(0.96, 1.0)
ax[0].set_ylim(0.5, 1.01)
ax[0].grid(which="both")
ax[0].grid(zorder=0, alpha=0.25, linestyle="--")
ax[0].legend(loc=3)

ax[1].set_xlabel(rf"Particle ${qty_symbols['pt']}$ {qty_units['pt']}")
ax[1].set_ylabel("Reconstructable Particles")
ax[1].set_ylim(0.97, 1)
ax[1].set_xticks(np.arange(start=2, stop=11, step=2))
ax[1].grid(which="both")
ax[1].grid(zorder=0, alpha=0.25, linestyle="--")
ax[1].legend(loc=3)

# Create the output directory first: savefig does not create missing parent
# directories and would fail on a fresh checkout.
out_path = pathlib.Path(out_dir)
out_path.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path / "filter_response.pdf")
plt.show()

In [None]:
# For each model, report the threshold at the curve point whose hit
# efficiency is closest to 99%.
target_efficiency = 0.99
for name, (*_unused, metrics) in filtering_results.items():
    thid = np.abs(metrics["roc_eff"] - target_efficiency).argmin()
    threshold_at_target = metrics["roc_eff_pur_thr"][thid]
    print(f"{name} threshold for 99% hit efficiency: {threshold_at_target:.4f}")

In [None]:
# For each model, report the share of hits whose score is at or above the
# threshold stored in that model's config.
for name, (hits, *_unused) in filtering_results.items():
    task_args = filtering_configs[name]["model"]["model"]["init_args"]["tasks"]["init_args"]
    filter_threshold = task_args["modules"][0]["init_args"]["threshold"]
    passes_filter = hits["score_sigmoid"] >= filter_threshold
    fraction_remaining = np.sum(passes_filter) / hits.shape[0]
    print(f"{name} fraction of hits remaining after filtering: {fraction_remaining:.4f}")