In [None]:
import sys

sys.path.append("..")
from core.DataLoader import (
    DataPreprocessor,
    DataConfig,
    LoadConfig,
    get_load_config_from_yaml,
)
import numpy as np
from importlib import reload
import matplotlib.pyplot as plt
import yaml
import core.keras_models as Models
from core.evaluation import MLEvaluator
from core.reconstruction import KerasFFRecoBase
import core
import keras as keras


import os

# Output directory for every figure/table produced by this notebook, and the
# workspace configuration the data loader is built from.
PLOTS_DIR = "plots/plot_HLF_comp_histories/"
CONFIG_PATH = "../config/workspace_config.yaml"

# exist_ok=True makes the cell safe to re-run and avoids the check-then-create
# race of testing os.path.exists() before calling os.makedirs().
os.makedirs(PLOTS_DIR, exist_ok=True)


# Build the load configuration from the workspace YAML and hand it to the
# preprocessor that serves the arrays used for evaluation below.
load_config = get_load_config_from_yaml(CONFIG_PATH)
DataProcessor = DataPreprocessor(load_config)

# Load the "nominal" sample with max_events=2_000_000 and event_numbers="odd".
# NOTE(review): presumably the odd-numbered events are the held-out split for
# the "odd_model.keras" checkpoints loaded later — confirm against training.
nominal_npz_path = load_config.data_path["nominal"]
data_config = DataProcessor.load_from_npz(
    nominal_npz_path,
    event_numbers="odd",
    max_events=2_000_000,
)
X_val, y_val = DataProcessor.get_data()

In [None]:
# Parallel lists: entry k of each list describes the k-th model to compare.
MODEL_DIRS = [
    "../models/regression_transformer_PtEtaPhi/",
    "../models/regression_transformer_PtEtaPhiE/",
]
MODEL_CONFIGS = [data_config, data_config]
MODEL_NAMES = ["Transformer (PtEtaPhi)", "Transformer (PtEtaPhiE)"]
MODEL_FILENAME = "odd_model.keras"

# Fail early if the parallel lists drift apart, instead of hitting an
# IndexError (or silently ignoring trailing entries) halfway through loading.
assert len(MODEL_DIRS) == len(MODEL_CONFIGS) == len(MODEL_NAMES), (
    "MODEL_DIRS, MODEL_CONFIGS and MODEL_NAMES must have the same length"
)

ml_reconstructors = []
ml_evaluators = []  # not used in the cells shown here; kept in case later cells rely on it
for model_dir, model_config, model_name in zip(MODEL_DIRS, MODEL_CONFIGS, MODEL_NAMES):
    # NOTE(review): perform_regression=False is passed although the model
    # directories are named "regression_transformer_*" — confirm this is intended.
    reconstructor = KerasFFRecoBase(model_config, name=model_name, perform_regression=False)
    # os.path.join works whether or not the directory string ends with a slash.
    reconstructor.load_model(os.path.join(model_dir, MODEL_FILENAME))
    ml_reconstructors.append(reconstructor)

# A single evaluator receives the whole list of reconstructors together with
# the validation arrays loaded earlier.
evaluator = MLEvaluator(
    reconstructor=ml_reconstructors,
    X_test=X_val,
    y_test=y_val,
)

# Global matplotlib font size for all figures produced after this point.
plt.rcParams.update({"font.size": 14})


In [None]:
# Show the architecture summary of the first reconstructor's model.
first_model = evaluator.reconstructors[0].model
first_model.summary()

In [None]:
# Plot the training history via the evaluator and store the figure as a PDF
# inside the plots directory.
history_pdf_path = PLOTS_DIR + "training_histories_comparison.pdf"
fig, ax = evaluator.plot_training_history()
fig.savefig(history_pdf_path)


In [None]:
# Write the accuracy LaTeX table into the plots directory.
# NOTE(review): n_bootstrap=10 looks low for a 95% confidence interval —
# confirm it is not a leftover from a quick test run.
accuracy_table_options = {
    "confidence": 0.95,
    "n_bootstrap": 10,
    "save_dir": PLOTS_DIR,
}
evaluator.save_accuracy_latex_table(**accuracy_table_options)

In [None]:
# Raw feature name -> LaTeX (mathtext) label used on plots.
# Plain words ("jet", "miss", "reco", ...) are wrapped in \text{} so they are
# typeset as text instead of as a product of italic math symbols.
_FEATURE_LABELS = {
    "jet_b_tag": "$b$-tag",
    "jet_pt": r"$p_{T}(\text{jet})$",
    "jet_eta": r"$\eta(\text{jet})$",
    "jet_phi": r"$\phi(\text{jet})$",
    "jet_e": r"$E(\text{jet})$",
    "lep_pt": r"$p_{T}(\ell)$",
    "lep_eta": r"$\eta(\ell)$",
    "lep_phi": r"$\phi(\ell)$",
    "lep_e": r"$E(\ell)$",
    "met_met": r"$p_{T}(\text{miss})$",
    "met_phi": r"$\phi(\text{miss})$",
    "m_l1j": r"$m(\ell^+, \text{jet})$",
    "m_l2j": r"$m(\ell^-, \text{jet})$",
    "dR_l1j": r"$\Delta R(\ell^+, \text{jet})$",
    "dR_l2j": r"$\Delta R(\ell^-, \text{jet})$",
    "N_jets": r"$N(\text{jets})$",
    "N_bjets": r"$N(b\text{-jets})$",
    "reco_mllbb": r"$m_{\text{reco}}(\ell\ell bb)$",
}


def rename_features(str):
    """Return the plot label for a raw feature name.

    Parameters
    ----------
    str : str
        Raw feature name, e.g. ``"jet_pt"``. The parameter name shadows the
        builtin ``str``; it is kept unchanged so existing callers keep working.

    Returns
    -------
    str
        The LaTeX label for a known feature name; any other name is returned
        unchanged.
    """
    return _FEATURE_LABELS.get(str, str)

# Produce the feature-importance plots in the plots directory, passing
# rename_features so feature names can be mapped to their LaTeX labels.
# NOTE(review): num_repeats=1 is presumably a speed trade-off; a single repeat
# may give noisy importances — confirm.
evaluator.plot_feature_importance(
    rename_features=rename_features,
    num_repeats=1,
    save_dir=PLOTS_DIR,
)