In [None]:
import sys
sys.path.append("..") # Ensure the parent directory is in the path

import core.AssignmentTransformer as TransformerAnalysis
from core.DataLoader import DataPreprocessor
import core.AssignmentKFold as KFold
import numpy as np
from importlib import reload
import keras
import tensorflow as tf
import matplotlib.pyplot as plt
# Notebook-wide configuration.
# MAX_JETS is forwarded to DataPreprocessor as `max_jets` further down this cell.
MAX_JETS = 6
# Directory for figures, relative to the notebook's working directory.
# (A plain string: the original f-string had no placeholders.)
PLOTS_DIR = "plots/"
import os
# exist_ok=True makes the cell idempotent and avoids the check-then-create race
# of a separate os.path.exists() test.
os.makedirs(PLOTS_DIR, exist_ok=True)

# Branch names handed to the preprocessor, grouped by the keyword they feed.
jet_feature_names = [
    "ordered_jet_pt",
    "ordered_jet_e",
    "ordered_jet_phi",
    "ordered_jet_eta",
    "ordered_jet_b_tag",
    "m_l1j",
    "m_l2j",
    "dR_l1j",
    "dR_l2j",
]
lepton_feature_names = ["lep_pt", "lep_e", "lep_eta", "lep_phi"]
global_feature_names = ["met_met_NOSYS", "met_phi_NOSYS"]
# Passed as `non_training_features`; presumably kept for later studies rather
# than used as model inputs — confirm in core.DataLoader.
non_training_names = ["truth_ttbar_mass", "truth_ttbar_pt", "N_jets"]

DataProcessor = DataPreprocessor(
    jet_features=jet_feature_names,
    lepton_features=lepton_feature_names,
    jet_truth_label="ordered_event_jet_truth_idx",
    lepton_truth_label="event_lepton_truth_idx",
    global_features=global_feature_names,
    max_leptons=2,
    max_jets=MAX_JETS,
    non_training_features=non_training_names,
    event_weight="weight_mc_NOSYS",
)

# Pipeline steps, run in the same order as before: load -> prepare -> normalise -> split.
DataProcessor.load_data(
    "/data/dust/group/atlas/ttreco/full_training.root",
    "reco",
    max_events=100_000,
)
DataProcessor.prepare_data()
DataProcessor.normalise_data()
# NOTE(review): 0.2 is presumably the held-out fraction — confirm against split_data.
DataProcessor.split_data(0.2)

In [None]:
# Re-import the project modules so edits to them are picked up without
# restarting the kernel.
reload(TransformerAnalysis)
reload(KFold)

# Seed Python, NumPy and the Keras backend RNGs so that weight initialisation,
# dropout and batch shuffling are repeatable on Restart & Run All.
SEED = 42
keras.utils.set_random_seed(SEED)

# Build the matcher around the prepared data processor and hand it the data.
TransformerMatcher = TransformerAnalysis.FeatureConcatTransformer(DataProcessor)
TransformerMatcher.load_data(*DataProcessor.get_data())

TransformerMatcher.build_model(
    hidden_dim=32,
    num_heads=8,
    num_layers=3,
    dropout_rate=0.1,
)
TransformerMatcher.compile_model(
    lambda_excl=0,
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
)
TransformerMatcher.model.summary()

# NOTE(review): the data was already loaded right after construction; this second
# call is kept to preserve behaviour but looks redundant — confirm and drop one.
TransformerMatcher.load_data(*DataProcessor.get_data())

In [None]:
# Early stopping on validation accuracy: wait up to 50 epochs without
# improvement, then roll back to the best weights seen.
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_accuracy",
    mode="max",
    patience=50,
    restore_best_weights=True,
)

# NOTE(review): the callback is passed as a single object, not a list; this
# relies on train_model accepting that form — confirm in core.AssignmentTransformer.
TransformerMatcher.train_model(
    epochs=100,
    batch_size=128,
    verbose=1,
    # weight="sample",  # left disabled, as in the original cell
    callbacks=early_stopping,
)


In [None]:
# Call the matcher's plot_history(); presumably draws the curves recorded by
# train_model in the previous cell — confirm in core.AssignmentTransformer.
TransformerMatcher.plot_history()

In [None]:
# Call the matcher's plot_confusion_matrix() with its default arguments.
# NOTE(review): which data split it evaluates on is not visible here — confirm.
TransformerMatcher.plot_confusion_matrix()

In [None]:
# Value forwarded as `shuffle_number`; presumably the number of shuffles per
# feature — confirm in core.AssignmentTransformer.
n_shuffles = 5
TransformerMatcher.plot_permutation_importance(shuffle_number=n_shuffles)