In [None]:
import os
import sys

# Make the project-level `core` package importable from this notebook directory.
sys.path.append("..")
from core.DataLoader import (
    DataPreprocessor,
    DataConfig,
    LoadConfig,
    get_load_config_from_yaml,
)
import numpy as np
from importlib import reload
import matplotlib.pyplot as plt
import yaml
import core.assignment_models as Models
from core.evaluation import MLEvaluator
from core.reconstruction import MLReconstructorBase
import core
import keras


# --- Configuration ----------------------------------------------------------
PLOTS_DIR = "plots/plot_HLF_comp_histories/"
CONFIG_PATH = "../config/workspace_config_HLF.yaml"  # high-level-feature (HLF) workspace
CONFIG_PATH_BASELINE = "../config/workspace_config.yaml"
# Value passed as `max_events` to both loads below; shared so the baseline and
# HLF datasets are capped identically.
MAX_EVENTS = 400_000

# exist_ok=True avoids the check-then-create race and makes the cell re-runnable.
os.makedirs(PLOTS_DIR, exist_ok=True)


# --- Data loading -----------------------------------------------------------
load_config = get_load_config_from_yaml(CONFIG_PATH_BASELINE)
load_config_high_level = get_load_config_from_yaml(CONFIG_PATH)
DataProcessor = DataPreprocessor(load_config)
DataProcessor_high_level = DataPreprocessor(load_config_high_level)

with open(CONFIG_PATH_BASELINE, "r") as file:
    data_configs = yaml.safe_load(file)

# NOTE(review): both processors read the *baseline* YAML's nominal data path;
# presumably the HLF processor derives its extra features from the same file
# via its own load config — confirm this is intended.
nominal_data_path = data_configs["data_path"]["nominal"]

data_config = DataProcessor.load_from_npz(nominal_data_path, max_events=MAX_EVENTS)
X_val, y_val = DataProcessor.get_data()
data_config_high_level = DataProcessor_high_level.load_from_npz(
    nominal_data_path, max_events=MAX_EVENTS
)
X_val_high_level, y_val_high_level = DataProcessor_high_level.get_data()


# --- Models -----------------------------------------------------------------
# Parallel lists: entry k of every MODEL_* list describes the same model.
MODEL_DIRS = [
    "../models/CrossAttentionTransformer_d256_l10_h8/",
    "../models/FeatureConcatTransformer_d256_l10_h8/",
    "../models/FeatureConcatTransformer_d256_l10_h8_HLF/",
]
MODEL_CONFIGS = [data_config, data_config, data_config_high_level]
MODEL_NAMES = ["CrossAttention", "Transformer", "Transformer High Level Features"]
MODEL_X_TEST = [X_val, X_val, X_val_high_level]
MODEL_Y_TEST = [y_val, y_val, y_val_high_level]
assert (
    len(MODEL_DIRS)
    == len(MODEL_CONFIGS)
    == len(MODEL_NAMES)
    == len(MODEL_X_TEST)
    == len(MODEL_Y_TEST)
), "MODEL_* lists must have one entry per model"

ml_reconstructors = []
ml_evaluators = []  # NOTE(review): never populated in this notebook
for model_dir, model_config, model_name in zip(MODEL_DIRS, MODEL_CONFIGS, MODEL_NAMES):
    reconstructor = MLReconstructorBase(
        model_config, name=model_name, perform_regression=False
    )
    reconstructor.load_model(os.path.join(model_dir, "model.keras"))
    ml_reconstructors.append(reconstructor)

# Compare only the two Transformer models (baseline vs. high-level features):
# index 0 (CrossAttention) is loaded above but excluded from this evaluator.
evaluator = MLEvaluator(
    reconstructor=ml_reconstructors[1:],
    X_test=MODEL_X_TEST[1:],
    y_test=MODEL_Y_TEST[1:],
)
plt.rcParams.update({"font.size": 14})


2025-12-01 15:35:05.011575: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1764599705.034188 3437346 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1764599705.041465 3437346 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1764599705.059074 3437346 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1764599705.059092 3437346 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1764599705.059095 3437346 computation_placer.cc:177] computation placer alr

KeyboardInterrupt: 

: 

In [None]:
# Plot the training histories held by `evaluator` and save the figure under PLOTS_DIR.
history_plot_path = PLOTS_DIR + "training_histories_comparison.png"
fig, ax = evaluator.plot_training_history()
fig.savefig(history_plot_path)

In [None]:
# Raw feature name -> LaTeX-style display label. Words inside math mode are
# wrapped in \text{} so they render upright rather than as italic variables.
_FEATURE_LABELS = {
    "ordered_jet_b_tag": "b-tag",
    "ordered_jet_pt": r"$p_{T}(\text{jet})$",
    "ordered_jet_eta": r"$\eta(\text{jet})$",
    "ordered_jet_phi": r"$\phi(\text{jet})$",
    "ordered_jet_e": r"$E(\text{jet})$",
    "lep_pt": r"$p_{T}(\ell)$",
    "lep_eta": r"$\eta(\ell)$",
    "lep_phi": r"$\phi(\ell)$",
    "lep_e": r"$E(\ell)$",
    "met_met": r"$p_{T}(\text{miss})$",
    "met_phi": r"$\phi(\text{miss})$",
    "m_l1j": r"$m(\ell^+, \text{jet})$",
    "m_l2j": r"$m(\ell^-, \text{jet})$",
    "dR_l1j": r"$\Delta R(\ell^+, \text{jet})$",
    "dR_l2j": r"$\Delta R(\ell^-, \text{jet})$",
}


def rename_features(str):
    """Map a raw feature name to a human-readable, LaTeX-style plot label.

    Args:
        str: Raw feature name. The parameter name shadows the builtin ``str``;
            it is kept so existing keyword callers keep working.

    Returns:
        The label from ``_FEATURE_LABELS``, or the input unchanged when the
        feature has no dedicated label.
    """
    return _FEATURE_LABELS.get(str, str)

# Feature-importance plot for the models held by `evaluator`; raw feature names
# are turned into display labels by `rename_features`, and `save_dir` presumably
# controls where the figure is written.
# NOTE(review): num_repeats=1 keeps the run cheap; presumably more repeats give a
# lower-variance estimate — confirm against MLEvaluator.plot_feature_importance.
evaluator.plot_feature_importance(
    rename_features=rename_features,
    num_repeats=1,
    save_dir=PLOTS_DIR,
)