In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import math
from multiprocessing import Pool
from pathlib import Path

import numpy as np
import yaml
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from plot_utils import binned, hist_plot, profile_plot
from track_evaluate import load_events

# Global matplotlib styling shared by every figure in this notebook.
plt.rcParams.update(
    {
        "figure.dpi": 200,
        # LaTeX text rendering stays off: texlive on the Nikhef clusters is missing the
        # required font. Set to True on systems where it is available.
        "text.usetex": False,
        "font.family": "serif",
        "figure.constrained_layout.use": True,
    }
)

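# Tracking model evaluation

Compares tracking models on their test-set evaluation files: efficiency, fake and duplicate rates, and (for regression models) residuals of the regressed track parameters.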

## Plot parameters

In [None]:
# Line colour used for each model label in every plot below.
training_colours = {
    "Paper": "tab:green",
    "TRK-v0 0.9 GeV": "tab:orange",  # |eta| < 4.0
    "trackml 1 GeV": "tab:blue",  # |eta| < 2.5
    "trackml 0.6 GeV": "mediumvioletred",
    "thresh_low_iou": "tab:green",
    "thresh_low_iou2": "tab:green",
    "thresh_low": "cornflowerblue",
    "1 GeV": "mediumseagreen",
    "iou": "mediumseagreen",
    "baseline": "red",
}

# Bin edges of the efficiency / fake-rate profiles, keyed by particle quantity.
qty_bins = {
    "pt": np.array([0.6, 0.75, 1.0, 1.5, 2, 3, 4, 6, 10]),
    # "eta": np.array([-2.5, -2, -1.5, -1, -0.5, 0, 0.5, 1, 1.5, 2, 2.5]),
    "eta": np.array([-4, -3.5, -3, -2.5, -2, -1.5, -1, -0.5, 0, 0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4]),
    # Eight equal-width octants with edges at exact multiples of pi/4. The previous
    # hand-rounded edges (+-2.36, +-1.57, +-0.79) made the octants slightly uneven.
    "phi": np.linspace(-math.pi, math.pi, 9),
    "vz": np.array([-100, -50, -20, -10, 0, 10, 20, 50, 100]),
}

# LaTeX symbol and unit suffix used in axis labels for each quantity.
qty_symbols = {"pt": "p_\\mathrm{T}", "eta": "\\eta", "phi": "\\phi", "vz": "v_z"}
qty_units = {"pt": "[GeV]", "eta": "", "phi": "", "vz": "[mm]"}

# All figures are written here; created up front so savefig never fails on a missing dir.
out_dir = Path("plots/")
out_dir.mkdir(parents=True, exist_ok=True)

## Retrieve tracking model configuration

In [None]:
# All currently active entries below come from one TRK-v4 training run. They differ only
# in the `load_events` keyword arguments given in `load_events_kwargs`.
TRK_V4_IOU_RUN = "/lus/lfs1aip2/projects/u5du/svanstroud/logs/TRK-v4-iou_20251201-T105911"
TRK_V4_IOU_EVAL_EP12 = f"{TRK_V4_IOU_RUN}/ckpts/epoch=012-val_loss=0.82803_test_eval.h5"
TRK_V4_IOU_CONFIG = f"{TRK_V4_IOU_RUN}/config.yaml"

# Evaluation file (h5) of each model to compare, keyed by the label used in plots and
# legends. Older runs are kept commented out for reference.
tracking_fnames = {
    # "Paper": "/data/atlas/users/slin/hepattn/src/hepattn/experiments/trackml/logs/epoch=028-val_loss=1.29786__test_test.h5",
    # "TRK-v0 0.9 GeV": "/data/atlas/users/slin/myHepattn/hepattn/src/hepattn/experiments/trackml/logs/TRK-v0-full_20250906-T205842/ckpts/epoch=029-val_loss=50.09092_test_eval.h5",
    # "trackml 1 GeV": "/data/atlas/users/slin/hepattn/src/hepattn/experiments/trackml/logs/trackml_tracking_20250711-T162137/ckpts/epoch=029-val_loss=12.68400_test_eval.h5",
    # "trackml 0.6 GeV": "/data/atlas/users/slin/hepattn/src/hepattn/experiments/trackml/logs/trackml_tracking_20251103-T102312/ckpts/epoch=029-val_loss=3.11541_test_eval.h5",
    # "sam": "/home/u5du/svanstroud.u5du/hepattn/src/hepattn/experiments/trackml/logs/TRK-v2-actuallyHybrid-preNormFinalDecLayer_20251114-T092015/ckpts/epoch=001-val_loss=4.24806_test_eval.h5",
    "baseline": TRK_V4_IOU_EVAL_EP12,
    "thresh_low": TRK_V4_IOU_EVAL_EP12,
    "thresh_low_iou": TRK_V4_IOU_EVAL_EP12,
    # "thresh_low_iou2": TRK_V4_IOU_EVAL_EP12,
    # "iou": f"{TRK_V4_IOU_RUN}/ckpts/epoch=006-val_loss=0.94272_test_eval.h5",
}

# Per-model keyword arguments forwarded to `load_events`. The loading cell falls back to
# {"regression": False} for models without an entry.
load_events_kwargs = {
    "baseline": dict(regression=False, iou_threshold=0.0, track_valid_threshold=0.5),
    "thresh_low": dict(regression=False, iou_threshold=0.0, track_valid_threshold=0.1),
    "thresh_low_iou": dict(regression=False, iou_threshold=0.5, track_valid_threshold=0.1),
    # "thresh_low_iou2": dict(regression=False, iou_threshold=0.5, track_valid_threshold=0.2),
    # "iou": dict(regression=False, iou_threshold=0.5, track_valid_threshold=0.5),
}

# Training config of each model. Its `particle_min_pt` / `particle_max_abs_eta` become the
# evaluation cuts, so every key of `tracking_fnames` needs an entry here.
tracking_config_fname = {
    # "Paper": "/data/atlas/users/slin/hepattn/src/hepattn/experiments/trackml/logs/trackml_tracking_20251103-T102312/config.yaml",
    # "TRK-v0 0.9 GeV": "/data/atlas/users/slin/myHepattn/hepattn/src/hepattn/experiments/trackml/logs/TRK-v0-full_20250906-T205842/config.yaml",
    # "trackml 1 GeV": "/data/atlas/users/slin/hepattn/src/hepattn/experiments/trackml/logs/trackml_tracking_20250711-T162137/config.yaml",
    # "trackml 0.6 GeV": "/data/atlas/users/slin/hepattn/src/hepattn/experiments/trackml/logs/trackml_tracking_20251103-T102312/config.yaml",
    # "sam": "/home/u5du/svanstroud.u5du/hepattn/src/hepattn/experiments/trackml/logs/TRK-v2-actuallyHybrid-preNormFinalDecLayer_20251114-T092015/config.yaml",
    "baseline": TRK_V4_IOU_CONFIG,
    "thresh_low": TRK_V4_IOU_CONFIG,
    "thresh_low_iou": TRK_V4_IOU_CONFIG,
    "thresh_low_iou2": TRK_V4_IOU_CONFIG,
    # "iou": TRK_V4_IOU_CONFIG,
}
# Training-config parameters echoed for each model. They are used as evaluation cuts
# when the files are loaded.
tracking_params = ["particle_min_pt", "particle_max_abs_eta"]

# Every evaluated model needs a config. Fail here with a clear message instead of a bare
# KeyError in the loading cell.
missing_configs = [name for name in tracking_fnames if name not in tracking_config_fname]
if missing_configs:
    raise KeyError(f"No entry in tracking_config_fname for model(s): {missing_configs}")

# Only read configs of models that are actually evaluated; tracking_config_fname may list
# extra (currently disabled) models.
tracking_configs = {}
for name in tracking_fnames:
    with Path(tracking_config_fname[name]).open() as f:
        fconfig = yaml.safe_load(f)
    # Print the model key as well: several models can share one config file (and name).
    print(f"model: {name} | name: {fconfig['name']}")
    for param in tracking_params:
        print("> " + param + "\t: ", fconfig["data"][param])
    tracking_configs[name] = fconfig

# Particle targets are taken from the first model's config.
particle_targets = next(iter(tracking_configs.values()))["data"]["targets"]["particle"]
print("> particle targets: ", particle_targets)

## Load evaluation files

In [None]:
def load_single_model(
    name,
    fname,
    eta_cut,
    pt_cut,
    index_list,
    num_events,
    load_events_kwargs_for_model,
    particle_targets=None,
):
    """Load one model's evaluated events with ``load_events``.

    Kept as a plain module-level function so it can be dispatched to worker
    processes via ``Pool.starmap``.

    Args:
        name: Model label. ``"Paper"`` selects ``key_mode="old"``, any other label ``None``.
        fname: Path of the model's evaluation file.
        eta_cut: |eta| cut forwarded to ``load_events``.
        pt_cut: pT cut forwarded to ``load_events``.
        index_list: Explicit event keys to load, or ``None``.
        num_events: Forwarded to ``load_events`` as ``randomize``.
        load_events_kwargs_for_model: Extra keyword arguments forwarded to ``load_events``.
        particle_targets: Particle quantities to load. Defaults to ``["pt", "eta", "phi"]``
            (deliberately independent of the config-derived global of the same name).

    Returns:
        ``(name, result)``, where ``result`` is whatever ``load_events`` returns.
    """
    key_mode = "old" if name == "Paper" else None  # None or "old"
    if particle_targets is None:
        # Fresh list per call so callers can never share/mutate a default.
        particle_targets = ["pt", "eta", "phi"]
    print(f"Loading {name} model with PT > {pt_cut} and |eta| < {eta_cut}\n")
    result = load_events(
        fname=fname,
        eta_cut=eta_cut,
        index_list=index_list,
        pt_cut=pt_cut,
        randomize=num_events,
        particle_targets=particle_targets,
        key_mode=key_mode,
        **load_events_kwargs_for_model,
    )
    print(f"Finished {name}\n")
    return name, result


In [None]:
num_events = 100
index_list = None  # ["event_0", ..., "event_99"] or ["29800", ..., "29899"]

# One starmap argument tuple per model, in tracking_fnames order. The |eta| and pT cuts
# come from each model's training config; missing loader kwargs fall back to no regression.
args_list = [
    (
        name,
        fname,
        tracking_configs[name]["data"]["particle_max_abs_eta"],
        tracking_configs[name]["data"]["particle_min_pt"],
        index_list,
        num_events,
        load_events_kwargs.get(name, {"regression": False}),
    )
    for name, fname in tracking_fnames.items()
]

# Load all models in parallel, one worker process per model.
with Pool(processes=len(tracking_fnames)) as pool:
    loaded = pool.starmap(load_single_model, args_list)

# model label -> loaded result (unpacked downstream as (tracks, parts))
tracking_results = dict(loaded)

## Plot metrics

### Efficiency, fake rate and duplicate rate

In [None]:
plot_fr = True
# plot_trainings = {"baseline"}  # trainings to be included in plot
plot_trainings = tracking_fnames.keys()

# Per-quantity x-axis limits and tick spec (start, stop, step for np.arange).
eff_axis_style = {
    "pt": ((0, 10.5), (2, 11, 2)),
    "eta": ((-4.5, 4.5), (-4, 4.5, 1)),
    "phi": ((-3.5, 3.5), (-3, 3.5, 1)),
    "vz": ((-112, 112), (-100, 110, 25)),
}


def _true_qty_label(qty):
    """x-axis label for the true value of particle quantity ``qty``."""
    return rf"Particle ${qty_symbols[qty]}^\mathrm{{True}}$ {qty_units[qty]}"


def _draw_binned_profile(values, x_values, mask, qty, axes, colour, ls):
    """Profile ``values[mask]`` against ``x_values[mask]`` on ``qty_bins[qty]`` and draw it on ``axes``."""
    counts, errors = binned(
        values[mask],
        x_values[mask],
        qty_bins[qty],
        underflow=True,
        overflow=True,
        binomial=False,
    )
    profile_plot(counts, errors, qty_bins[qty], axes=axes, colour=colour, ls=ls)


for qty in particle_targets:
    if qty not in {"pt", "eta", "phi"}:
        continue

    # Left panel: efficiency. Optional right panel: fake and duplicate rates.
    if plot_fr:
        fig, (ax, ax1) = plt.subplots(ncols=2, figsize=(12, 4), constrained_layout=True)
        axlist = [ax, ax1]
        ax1.set_xlabel(_true_qty_label(qty))
        ax1.set_ylabel("Fake Rate")
    else:
        fig, ax = plt.subplots(ncols=1, figsize=(6, 4), constrained_layout=True)
        axlist = [ax]

    names = [name for name in tracking_results if name in plot_trainings]
    for name in names:
        tracks, parts = tracking_results[name]
        colour = training_colours[name]

        # Efficiency over reconstructable particles: double majority (solid), perfect (dotted).
        parts_mask = parts["reconstructable"]
        parts_x = parts["particle_" + qty]
        _draw_binned_profile(parts["eff_dm"], parts_x, parts_mask, qty, ax, colour, "solid")
        _draw_binned_profile(parts["eff_perfect"], parts_x, parts_mask, qty, ax, colour, "dotted")

        if not plot_fr:
            continue

        # Fake (solid) and duplicate (dotted) rates over reconstructable tracks.
        x_col = f"matched_{qty}" if f"matched_{qty}" in tracks.columns else f"part_{qty}"
        tracks_mask = tracks["reconstructable"]
        fakes = ~tracks["eff_dm"] & ~tracks["duplicate"]
        _draw_binned_profile(fakes, tracks[x_col], tracks_mask, qty, ax1, colour, "solid")
        _draw_binned_profile(tracks["duplicate"], tracks[x_col], tracks_mask, qty, ax1, colour, "dotted")

    # axis ranges
    ax.set_ylim(0.8, 1.1)
    ax.set_ylabel("Efficiency")
    ax.set_xlabel(_true_qty_label(qty))

    for axis in axlist:
        axis.grid(zorder=0, alpha=0.25, linestyle="--")
        if qty in eff_axis_style:
            (x_lo, x_hi), (tick_start, tick_stop, tick_step) = eff_axis_style[qty]
            axis.set_xlim([x_lo, x_hi])
            axis.set_xticks(np.arange(start=tick_start, stop=tick_stop, step=tick_step))

    # custom legends: one for the trainings (colour), one for the line-style meaning
    training_handles = [Line2D([0], [0], color=training_colours[training], label=training) for training in names]
    ax.add_artist(ax.legend(handles=training_handles, frameon=False, loc="upper left"))

    eff_handles = [Line2D([0], [0], color="black", label="DM"), Line2D([0], [0], color="black", ls="dotted", label="Perfect")]
    ax.add_artist(ax.legend(handles=eff_handles, frameon=False, loc="upper right"))

    if plot_fr:
        ax1.add_artist(ax1.legend(handles=training_handles, frameon=False, loc="upper left"))
        fake_handles = [Line2D([0], [0], color="black", label="Fake"), Line2D([0], [0], color="black", ls="dotted", label="Duplicate")]
        ax1.add_artist(ax1.legend(handles=fake_handles, frameon=False, loc="upper right"))
        ax1.set_ylim(0.0, 0.06 if qty == "pt" else 0.04)

    fig.savefig(out_dir / f"{qty}_eff.pdf")

### Regression residual plots

In [None]:
plot_regression = True
plot_trainings = {"Paper", "TRK-v0 0.9 GeV"}  # trainings to be included in plot

# Only trainings that were actually loaded can be drawn. Skip the figure (instead of
# drawing and saving empty axes) when none of the requested ones are loaded.
regression_trainings = [name for name in tracking_results if name in plot_trainings]
if plot_regression and not regression_trainings:
    print(f"Skipping regression residual plots: none of {sorted(plot_trainings)} are loaded")

if plot_regression and regression_trainings:
    nbins = 55
    # Residual histogram binning per regressed quantity.
    qty_res_bins = {
        "pt": np.linspace(-1, 1, nbins),
        "eta": np.linspace(-0.1, 0.1, nbins),
        "phi": np.linspace(-0.1, 0.1, nbins),
        "vz": np.linspace(-15, 15, nbins),
    }
    # Tick positions matching the residual ranges above.
    qty_res_ticks = {
        "pt": np.arange(-1, 1.1, 0.5).round(2),
        "eta": np.arange(-0.1, 0.11, 0.05).round(2),
        "phi": np.arange(-0.1, 0.11, 0.05).round(2),
        "vz": np.arange(-15, 16, 5).round(2),
    }
    fig, ax = plt.subplots(nrows=2, ncols=2, constrained_layout=True)
    fig.set_size_inches(10, 4)
    ax = ax.flatten()

    for i, qty in enumerate(["pt", "eta", "phi", "vz"]):
        labels = []
        colours = []
        bins = qty_res_bins[qty]
        for name in regression_trainings:
            tracks, _parts = tracking_results[name]
            colour = training_colours[name]
            # NOTE(review): the other cells select tracks with tracks["reconstructable"];
            # confirm "reconstructable_parts" is the intended column here.
            selected = tracks["eff_dm"] & tracks["reconstructable_parts"]
            # track physical quantity regression predicted value
            tracks_qty = tracks["track_" + qty][selected]
            # particle physical quantity true value
            parts_qty = tracks["matched_" + qty][selected]
            res = tracks_qty - parts_qty

            label = hist_plot(xs=res, bins=bins, xrange=(bins[0], bins[-1]), name=name, axes=ax[i], colour=colour)
            labels.append(label)
            colours.append(colour)

        ax[i].grid(zorder=0, alpha=0.25, linestyle="--")
        ax[i].set_xlabel(rf"${qty_symbols[qty]}^\mathrm{{Reco}} - {qty_symbols[qty]}^\mathrm{{True}}$ {qty_units[qty]}")
        ax[i].set_ylabel("Density")

        ticks = qty_res_ticks[qty]
        ax[i].set_xticks(ticks)
        ax[i].set_xticklabels(ticks)
        legend_elements = [Line2D([0], [0], color=c, label=lab) for c, lab in zip(colours, labels)]
        ax[i].legend(handles=legend_elements, frameon=False, loc="upper left", fontsize=8)

    # out_dir is a pathlib.Path: join with "/", not "+" (which raises TypeError).
    fig.savefig(out_dir / "regr_residuals.pdf")

### $p_\mathrm{T}$ response vs. true particle $p_\mathrm{T}$ and number of assigned hits

In [None]:
# pT response (reco / true pT) of tracks passing eff_dm, profiled against true pT (left)
# and against the number of assigned hits (right).
response_trainings = [name for name in tracking_results if name in plot_trainings]
if plot_regression and not response_trainings:
    print(f"Skipping pT response plots: none of {sorted(plot_trainings)} are loaded")

if plot_regression and response_trainings:
    fig, ax = plt.subplots(nrows=1, ncols=2, constrained_layout=True)
    fig.set_size_inches(10, 3)
    qty = "pt"
    response_label = rf"${qty_symbols[qty]}^\mathrm{{Reco}}/{qty_symbols[qty]}^\mathrm{{True}}$"
    for name in response_trainings:
        tracks, _parts = tracking_results[name]
        dm_mask = tracks["eff_dm"]
        parts_eff_qty = tracks["matched_" + qty][dm_mask]
        tracks_eff_qty = tracks["track_" + qty][dm_mask]
        n_assigned = tracks["n_pred_hits"][dm_mask]
        response = tracks_eff_qty / parts_eff_qty
        # response = np.clip(response, 0.5, 1.5)

        # response profile vs true pT
        bins = qty_bins[qty]
        ys, ys_err = binned(response, parts_eff_qty, bins, underflow=True, overflow=True, binomial=False)
        profile_plot(ys, ys_err, bins, axes=ax[0], colour=training_colours[name], ls="solid", label=name)

        # response profile vs number of assigned hits (unit-width bins from 3 to 10)
        bins = np.linspace(3, 10, 8)
        ys, ys_err = binned(response, n_assigned, bins, binomial=False)
        profile_plot(ys, ys_err, bins, axes=ax[1], colour=training_colours[name], ls="solid", label=name)

    ax[0].set_xlabel(rf"${qty_symbols[qty]}^\mathrm{{True}}$ {qty_units[qty]}")
    ax[0].set_ylabel(response_label)
    ax[0].grid(zorder=0, alpha=0.25, linestyle="--")
    ax[0].legend(frameon=False)
    ax[1].set_xlabel(r"Number of assigned hits")
    ax[1].set_ylabel(response_label)
    # ax[1].set_ylabel(rf"${qty_symbols[qty]}^\mathrm{{True}}$ {qty_units[qty]}")
    ax[1].grid(zorder=0, alpha=0.25, linestyle="--")
    ax[1].legend(frameon=False)
    # out_dir is a pathlib.Path: join with "/", not "+" (which raises TypeError).
    fig.savefig(out_dir / "pt_regr-nhits-response_paper.pdf")

### Integrated efficiency, fake rate and duplicate rate numbers

In [None]:
def _fake_rate(track_table):
    """Fraction of rows in ``track_table`` with neither ``eff_dm`` nor ``duplicate`` set.

    Both columns are read from the SAME table. Mixing a pT-selected subset with the full
    table makes pandas align on index, so the mean runs over all rows of the larger table.
    """
    return (~track_table.eff_dm & ~track_table.duplicate).mean()


for name, (tracks, parts) in tracking_results.items():
    print(name)
    # restrict to reconstructable particles / tracks
    tgts = parts[parts.reconstructable]
    trks = tracks[tracks.reconstructable]
    # compute high pt integrated metrics
    high_pt_parts = tgts[tgts.particle_pt > 1.0]
    high_pt_parts_900 = tgts[tgts.particle_pt > 0.9]
    high_pt_eff = high_pt_parts.eff_dm.mean()
    high_pt_eff_900 = high_pt_parts_900.eff_dm.mean()
    high_pt_tracks = trks[trks.matched_pt > 1.0]
    high_pt_tracks_900 = trks[trks.matched_pt > 0.9]
    # Previously `~high_pt_tracks.eff_dm & ~trks.duplicate`: index alignment with the full
    # `trks` table diluted the high-pT fake rate over all reconstructable tracks.
    high_pt_fr = _fake_rate(high_pt_tracks)
    high_pt_fr_900 = _fake_rate(high_pt_tracks_900)

    # compute the overall fake rate
    integrated_fr = _fake_rate(trks)

    # print summary
    # NOTE(review): 100 is assumed to be the full event count when num_events is None.
    print(f"N events: {100 if num_events is None else num_events}, N particles: {len(parts)}, N tracks: {len(tracks)}")
    print(f"DM Integrated efficiency: {tgts.eff_dm.mean():.1%}")
    print(f"DM Efficiency for pT > 1.0 GeV: {high_pt_eff:.1%}")
    print(f"DM Efficiency for pT > 0.9 GeV: {high_pt_eff_900:.1%}")
    print()
    print(f"DM Integrated fake rate: {integrated_fr:.1%}")
    print(f"DM Fake rate for pT > 1.0 GeV: {high_pt_fr:.1%}")
    print(f"DM Fake rate for pT > 0.9 GeV: {high_pt_fr_900:.1%}")
    print()
    print(f"Perfect integrated Efficiency: {tgts.eff_perfect.mean():.1%}")
    print(f"Perfect Efficiency for pT > 1.0 GeV: {high_pt_parts.eff_perfect.mean():.1%}")
    print(f"Perfect Efficiency for pT > 0.9 GeV: {high_pt_parts_900.eff_perfect.mean():.1%}")
    print()
    # NOTE(review): averaged over ALL tracks, whereas the duplicate-rate plot uses only
    # reconstructable tracks; confirm which denominator is intended.
    print(f"Duplicate rate: {tracks.duplicate.mean():.1%}")
    print("\n")