In [None]:
import os

import h5py
import matplotlib as mpl
import numpy as np
import pandas as pd
import yaml
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D

# Global matplotlib defaults applied to every figure in this notebook.
plt.rcParams.update(
    {
        "figure.dpi": 400,
        # LaTeX text rendering is kept off: texlive on the Nikhef clusters is
        # missing a font, so enabling it fails there.
        "text.usetex": False,
        "font.family": "serif",
        "figure.constrained_layout.use": True,
    }
)

# Evaluation file structure

In [None]:
# Evaluation HDF5 file whose layout is inspected in the cells below.
# NOTE(review): despite the name this is the path of a file, not of a directory.
hdf_dir = (
    "/data/atlas/users/slin/myHepattn/hepattn/src/hepattn/experiments/trackml/"
    "logs/HC-v3_20250818-T211321/ckpts/"
    "epoch=029-val_loss=0.34545_test_eval.h5"
)

In [None]:
# Open the evaluation file read-only and list its top-level entries.
# The handle is left open on purpose: the following cells keep reading from `f`.
f = h5py.File(hdf_dir, "r")
for key in f:
    print(f"Key: {key}  \tType: {type(f[key])}")

In [None]:
# Take the third top-level entry of the file and, for each of its members,
# print the member name followed by the keys stored underneath it.
group_key = list(f.keys())[2]
members = list(f[group_key])
for member in members:
    print(member)
    print("\t", f[group_key][member].keys())

## `inputs` group

In [None]:
# List the members of the `inputs` group, then whatever iterating over its
# `hit_eta` entry yields.
inputs_group = f[group_key]["inputs"]

print("\ninputs")
print("\t", list(inputs_group))

print("\n============================")

print("\ninputs/hit_eta")
print("\t", list(inputs_group["hit_eta"]))

## `outputs` group

In [None]:
# Walk down outputs/final -> hit_filter -> hit_logit, printing the listing
# obtained at each level.
outputs_final = f[group_key]["outputs"]["final"]
outputs_hit_filter = outputs_final["hit_filter"]

print("\noutputs/final/")
print("\t", list(outputs_final))

print("\n============================")

print("\noutputs/final/hit_filter/")
print("\t", list(outputs_hit_filter))
print("\noutputs/final/hit_filter/hit_logit")
print("\t", list(outputs_hit_filter["hit_logit"]))



## `preds` group

In [None]:
# Walk down preds/final -> hit_filter -> hit_on_valid_particle, printing the
# listing obtained at each level.
preds_final = f[group_key]["preds"]["final"]
preds_hit_filter = preds_final["hit_filter"]

print("\npreds/final/")
print("\t", list(preds_final))

print("\n============================")

print("\npreds/final/hit_filter/")
print("\t", list(preds_hit_filter))
print("\npreds/final/hit_filter/hit_on_valid_particle")
print("\t", list(preds_hit_filter["hit_on_valid_particle"]))


## `targets` group

In [None]:
# List the members of the `targets` group, then print what iterating over each
# of the entries named below yields.
targets_group = f[group_key]["targets"]

print("\ntargets/")
print("\t", list(targets_group))

print("\n============================")

for entry in (
    "hit_on_valid_particle",
    "hit_valid",
    "particle_hit_valid",
    "particle_pt",
    "particle_valid",
    "sample_id",
):
    print("\ntargets/" + entry)
    print("\t", list(targets_group[entry]))

# Filter model evaluation

## Plot parameters

In [None]:
# Line colour of each training, keyed by the label used for that training.
# The "0.9 GeV" entry belongs to the |eta| < 4.0 setup; the earlier
# |eta| < 2.5 setup used the key "1 GeV" with the same colour.
training_colours = {
    "600 MeV": "mediumvioletred",
    "750 MeV": "cornflowerblue",
    "0.9 GeV": "mediumseagreen",
}

# Bin edges per binned quantity. The eta edges span [-4, 4] in steps of 0.5
# (the |eta| < 2.5 setup stopped at +-2.5).
qty_bins = dict(
    pt=np.array([0.6, 0.75, 1.0, 1.5, 2, 3, 4, 6, 10]),
    eta=np.array([-4, -3.5, -3, -2.5, -2, -1.5, -1, -0.5, 0, 0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4]),
    phi=np.array([-3.14, -2.36, -1.57, -0.79, 0, 0.79, 1.57, 2.36, 3.14]),
    vz=np.array([-100, -50, -20, -10, 0, 10, 20, 50, 100]),
)

# Math-mode symbol and unit suffix used when building axis labels.
qty_symbols = dict(pt=r"p_\mathrm{T}", eta=r"\eta", phi=r"\phi", vz="v_z")
qty_units = dict(pt="[GeV]", eta="", phi="", vz="[mm]")

# Output location for saved figures.
out_dir = "plots"

## Retrieve filtering model configuration

In [None]:
# Load the YAML configuration of the filtering model and extract the settings
# the plots below rely on.
#
# The file handle is called `cfg_file` rather than `f`: `f` is the open HDF5
# handle used by the structure-inspection cells, and rebinding it here would
# make those cells fail when re-run after this one.
with open("/data/atlas/users/slin/myHepattn/hepattn/src/hepattn/experiments/trackml/configs/filtering.yaml", "r") as cfg_file:
    fconfig = yaml.safe_load(cfg_file)

filter_params = ["particle_min_pt", "particle_max_abs_eta"]

print("name: " + fconfig["name"])
for param in filter_params:
    print("> " + param + "\t: ", fconfig["data"][param])

# One entry per training, keyed by the label that also indexes `training_colours`.
filtering_configs = {
    "0.9 GeV": fconfig
}

# The evaluation file is the one already defined as `hdf_dir`; reuse it so the
# path lives in exactly one place.
filtering_fnames = {
    "0.9 GeV": hdf_dir
}

# Arguments of the first task module of the model.
# NOTE(review): assumes modules[0] is the hit-filter task — confirm against the YAML.
filter_task_args = filtering_configs["0.9 GeV"]["model"]["model"]["init_args"]["tasks"]["init_args"]["modules"][0]["init_args"]
filter_threshold = filter_task_args["threshold"]
print("> threshold\t\t: ", filter_threshold)

# Configured hit input fields, each prefixed with "hits_".
filter_inputs = ["hits_" + field for field in filtering_configs["0.9 GeV"]["data"]["inputs"]["hit"]]
print("> inputs: ", filter_inputs)

## Load evaluation file

In [None]:
# Import the local helper modules and reload them, so that edits made to them
# since they were first imported take effect without restarting the kernel.
import importlib
import plot_utils
import hit_evaluate
importlib.reload(plot_utils)
importlib.reload(hit_evaluate)
# These from-imports must come after the reloads: they bind the names to the
# objects of the freshly reloaded modules.
from plot_utils import binned, profile_plot
from hit_evaluate import load_events

In [None]:
# NOTE(review): these imports repeat the from-imports at the end of the
# previous cell; when the notebook is run top to bottom this cell looks
# redundant — confirm whether it is still needed.
from hit_evaluate import load_events
from plot_utils import binned, profile_plot

In [None]:
%%time
# Load the evaluation output of every training listed in `filtering_fnames`,
# passing each training's configured threshold to `load_events`.
filtering_results = {}
num_events = None  # forwarded to load_events as its `randomize` argument
for name, fname in filtering_fnames.items():
    task_args = filtering_configs[name]["model"]["model"]["init_args"]["tasks"]["init_args"]["modules"][0]["init_args"]
    filter_threshold = task_args["threshold"]
    filtering_results[name] = load_events(
        fname=fname,
        randomize=num_events,
        write_inputs=None,
        write_parts=True,
        threshold=filter_threshold,
    )

## Plotting metrics

### Discriminant

In [None]:
# One figure per training: normalised score distributions of hits that are /
# are not on a valid particle, with the configured threshold overlaid.
for name, (hits, targets, parts, metrics) in filtering_results.items():
    # Look the threshold up for *this* training. The global `filter_threshold`
    # only holds the value of whichever training was processed last.
    task_args = filtering_configs[name]["model"]["model"]["init_args"]["tasks"]["init_args"]["modules"][0]["init_args"]
    threshold = task_args["threshold"]

    on_valid = targets["hit_on_valid_particle"]
    # Histogram settings shared by both populations so they stay comparable.
    hist_style = dict(range=[0, 1], bins=40, density=True, alpha=0.5)

    fig, ax = plt.subplots(figsize=(5, 3), constrained_layout=True)
    ax.hist(hits["score_sigmoid"][on_valid], color="C0", label="Valid hits", **hist_style)
    ax.hist(hits["score_sigmoid"][~on_valid], color="C1", label="Invalid hits", **hist_style)

    ax.axvline(threshold,
               color="r", ls="dashed", label="Threshold: %.1f" % (threshold))
    ax.set_xlabel("Discriminant score")
    ax.set_ylabel("Normalized counts")
    ax.set_xlim(-0.025, 1.025)
    ax.grid(which="both")
    ax.grid(zorder=0, alpha=0.25, linestyle="--")
    ax.legend()

### Receiver operating characteristic

In [None]:
# ROC curve of every training on shared axes, with a marker at the operating
# point selected by that training's configured threshold.
fig, ax = plt.subplots(figsize=(5, 3), constrained_layout=True)
for name, (hits, targets, parts, metrics) in filtering_results.items():
    ax.plot(metrics["roc_fpr"], metrics["roc_tpr"],
            color=training_colours[name],
            label="%s %s\nAUC: %.4f" % (filtering_configs[name]["name"], name, metrics["roc_fpr_tpr_auc"])
            )

    # Mark the curve point whose threshold is closest to the one configured
    # for this training (previously a hard-coded 0.1, which silently goes out
    # of sync if the configuration changes).
    task_args = filtering_configs[name]["model"]["model"]["init_args"]["tasks"]["init_args"]["modules"][0]["init_args"]
    threshold = task_args["threshold"]
    thid = np.argmin(np.abs(metrics["roc_fpr_tpr_thr"] - threshold))
    ax.scatter(metrics["roc_fpr"][thid], metrics["roc_tpr"][thid],
               color=training_colours[name], s=100)

ax.set_xlabel("False positive rate")
ax.set_ylabel("True positive rate")
ax.set_xlim(-0.05, 1.01)
ax.set_ylim(0.0, 1.05)
ax.grid(which="both")
ax.grid(zorder=0, alpha=0.25, linestyle="--")
ax.legend()

### Efficiency purity plot

In [None]:
# Hit purity versus hit efficiency for every training on shared axes, with a
# marker at the operating point selected by the configured threshold.
fig, ax = plt.subplots(figsize=(5, 3), constrained_layout=True)
for name, (hits, targets, parts, metrics) in filtering_results.items():
    ax.plot(metrics["roc_eff"], metrics["roc_pur"],
            color=training_colours[name],
            label="%s %s\nAUC: %.4f" % (filtering_configs[name]["name"], name, metrics["roc_eff_pur_auc"])
            )

    # Mark the curve point whose threshold is closest to the one configured
    # for this training (previously a hard-coded 0.1).
    task_args = filtering_configs[name]["model"]["model"]["init_args"]["tasks"]["init_args"]["modules"][0]["init_args"]
    threshold = task_args["threshold"]
    thid = np.argmin(np.abs(metrics["roc_eff_pur_thr"] - threshold))
    ax.scatter(metrics["roc_eff"][thid], metrics["roc_pur"][thid],
               color=training_colours[name], s=100)

ax.set_xlabel("Hit Efficiency")
ax.set_ylabel("Hit Purity")
ax.set_xlim(0.9, 1.01)
ax.set_ylim(0.3, 1.01)
ax.grid(which="both")
# The grid styling was applied twice with identical arguments; once is enough.
ax.grid(zorder=0, alpha=0.25, linestyle="--")
ax.legend()

### Particle efficiency (pT binned)

In [None]:
# Fraction of reconstructable particles in bins of particle pT, one profile
# per training.
fig, ax = plt.subplots(figsize=(5, 3), constrained_layout=True)

for name, (hits, targets, parts, metrics) in filtering_results.items():
    # A particle counts as reconstructable when it has at least 3 predicted
    # hits and passes the valid-particle selection.
    reconstructable = (np.asarray(parts["pred_hits"]) >= 3) & parts["valid"]
    # NaN pT marks excess entries (events with fewer particles than
    # n_max_particles); drop them.
    valid = ~np.isnan(parts["particle_pt"])
    bin_count, bin_error = binned(reconstructable[valid], parts["particle_pt"][valid], qty_bins["pt"])
    # Colour by the training being plotted; a fixed "0.9 GeV" colour would
    # draw every training in the same colour.
    profile_plot(bin_count, bin_error, qty_bins["pt"], axes=ax, color=training_colours[name])

ax.set_xlabel(r'Particle $%s$ %s' % (qty_symbols["pt"], qty_units["pt"]))
ax.set_ylabel('Reconstructable particles')
ax.set_ylim(0.97, 1)
ax.set_xticks(np.arange(start=2, stop=11, step=2))
ax.grid(which='both')
ax.grid(zorder=0, alpha=0.25, linestyle="--")
plt.show()

### Combined plot

In [None]:
# Two-panel summary figure: hit purity versus efficiency on the left,
# reconstructable-particle fraction versus pT on the right. Saved as a PDF
# under `out_dir`.
fig, ax = plt.subplots(ncols=2, figsize=(10, 3), constrained_layout=True)
for name, (hits, targets, parts, metrics) in filtering_results.items():
    curve_label = "%s %s" % (filtering_configs[name]["name"], name)

    ax[0].plot(metrics["roc_eff"], metrics["roc_pur"],
               color=training_colours[name],
               label="%s %s\nAUC: %.4f" % (filtering_configs[name]["name"], name, metrics["roc_eff_pur_auc"])
               )

    # Mark the curve point whose threshold is closest to the one configured
    # for this training (previously a hard-coded 0.1).
    task_args = filtering_configs[name]["model"]["model"]["init_args"]["tasks"]["init_args"]["modules"][0]["init_args"]
    threshold = task_args["threshold"]
    thid = np.argmin(np.abs(metrics["roc_eff_pur_thr"] - threshold))
    ax[0].scatter(metrics["roc_eff"][thid], metrics["roc_pur"][thid],
                  color=training_colours[name], s=100)

    # A particle counts as reconstructable when it has at least 3 predicted
    # hits and passes the valid-particle selection.
    reconstructable = (np.asarray(parts["pred_hits"]) >= 3) & parts["valid"]
    # NaN pT marks excess entries (events with fewer particles than
    # n_max_particles); drop them.
    valid = ~np.isnan(parts["particle_pt"])
    bin_count, bin_error = binned(reconstructable[valid], parts["particle_pt"][valid], qty_bins["pt"])
    # Colour by the training being plotted so both panels use matching colours.
    profile_plot(bin_count, bin_error, qty_bins["pt"], axes=ax[1], color=training_colours[name], label=curve_label)

ax[0].set_xlabel("Hit Efficiency")
ax[0].set_ylabel("Hit Purity")
ax[0].set_xlim(0.9, 1.01)
ax[0].set_ylim(0.3, 1.01)
ax[0].grid(which="both")
ax[0].grid(zorder=0, alpha=0.25, linestyle="--")
ax[0].legend(loc=3)

ax[1].set_xlabel(r'Particle $%s$ %s' % (qty_symbols["pt"], qty_units["pt"]))
ax[1].set_ylabel('Reconstructable Particles')
ax[1].set_ylim(0.97, 1)
ax[1].set_xticks(np.arange(start=2, stop=11, step=2))
ax[1].grid(which='both')
ax[1].grid(zorder=0, alpha=0.25, linestyle="--")
ax[1].legend(loc=3)

# Create the output directory first: savefig raises if it does not exist.
os.makedirs(out_dir, exist_ok=True)
fig.savefig(os.path.join(out_dir, "filter_response.pdf"))
plt.show()