In [None]:
import sys

sys.path.append("..")
from core.DataLoader import (
    DataPreprocessor,
    DataConfig,
    LoadConfig,
    get_load_config_from_yaml,
    combine_train_datasets
)
import numpy as np
from importlib import reload
import matplotlib.pyplot as plt
import yaml
import core.assignment_models as Models
import core
import keras
import tensorflow as tf


# --- Workspace configuration -------------------------------------------------
# Output locations for figures and model checkpoints, and the YAML file that
# describes where the input datasets live.
PLOTS_DIR = "plots/mixed_model_HLF/"
MODEL_DIR = "models/mixed_model_HLF"
# Alternative checkpoint directory (disabled):
#MODEL_DIR = "../models/FeatureConcatTransformer_HLF_d256_l10_h8"
CONFIG_PATH = "../config/workspace_config_global_features.yaml"
# Upper bound on the number of events read from the nominal sample.
MAX_NOMINAL_EVENTS = 2_000_000

import os

# exist_ok=True already makes directory creation idempotent, so no separate
# os.path.exists() guard is needed. Without the guard, a path that exists as a
# regular file now fails here instead of much later when something is saved.
os.makedirs(PLOTS_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)

# --- Data loading ------------------------------------------------------------
# Both preprocessors are built from the same load configuration; one is used
# for the nominal sample and one for the toponium sample.
load_config = get_load_config_from_yaml(CONFIG_PATH)

DataProcessorNominal = DataPreprocessor(load_config)
DataProcessorToponium = DataPreprocessor(load_config)

# The same YAML file is also parsed directly to look up the dataset paths.
with open(CONFIG_PATH, "r") as file:
    data_configs = yaml.safe_load(file)

plt.rcParams.update({"font.size": 14})

# Only the nominal load's return value is kept; it is passed to the model
# constructor in the next cell. The toponium load passes no max_events.
data_config = DataProcessorNominal.load_from_npz(
    data_configs["data_path"]["nominal"],
    max_events=MAX_NOMINAL_EVENTS,
)
DataProcessorToponium.load_from_npz(data_configs["data_path"]["toponium"])

X_nominal, y_nominal = DataProcessorNominal.get_data()
X_toponium, y_toponium = DataProcessorToponium.get_data()

# Merge both samples into a single training set.
# NOTE(review): weights=[0.5, 0.5] presumably requests an equal mix of the two
# samples -- confirm against combine_train_datasets.
X_train, y_train = combine_train_datasets(
    [X_nominal, X_toponium],
    [y_nominal, y_toponium],
    weights=[0.5, 0.5],
)

In [None]:
# Re-import the project modules so edits made on disk are picked up without
# restarting the kernel.
reload(Models)
reload(core)

# Model wrapper constructed from the data configuration returned by the
# nominal loader; the raw string keeps the LaTeX in the display name intact.
TransformerMatcher = Models.CrossAttentionModel(
    data_config,
    name=r"Transformer + $\nu^2$-Flows",
)

# Alternative (disabled): load a previously saved model instead of building one.
#TransformerMatcher.load_model(f"{MODEL_DIR}/model.keras")

# Architecture hyperparameters forwarded to build_model.
architecture = dict(
    hidden_dim=64,
    num_layers=8,
    num_heads=8,
    dropout_rate=0.1,
)
TransformerMatcher.build_model(**architecture)

# Normalization layers are adapted on the combined training inputs.
TransformerMatcher.adapt_normalization_layers(X_train)

# Loss, optimizer and metric are created in the same order as they are passed.
assignment_loss = core.utils.AssignmentLoss(lambda_excl=0.0)
adamw = keras.optimizers.AdamW(learning_rate=1e-4, weight_decay=1e-4)
assignment_accuracy = core.utils.AssignmentAccuracy()

TransformerMatcher.compile_model(
    loss={"assignment": assignment_loss},
    optimizer=adamw,
    metrics={"assignment": [assignment_accuracy]},
)
TransformerMatcher.model.summary()

In [None]:
# Learning-rate schedule: multiply the LR by 0.5 whenever "val_loss" has not
# improved for 5 epochs, never going below 1e-6.
# NOTE(review): monitoring "val_loss" assumes train_model provides validation
# data to Keras -- confirm in the model wrapper.
plateau_lr_schedule = keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    mode="min",
    factor=0.5,
    patience=5,
    min_lr=1e-6,
    verbose=1,
)

# Train on the combined sample with per-event weights computed from the
# training inputs and targets.
TransformerMatcher.train_model(
    epochs=100,
    X_train=X_train,
    y_train=y_train,
    sample_weights=core.utils.compute_sample_weights(X_train, y_train),
    batch_size=1024,
    callbacks=[plateau_lr_schedule],
)

In [None]:
# Persist the trained model.
# Note: this rebinds MODEL_DIR, replacing the value assigned in the setup cell,
# so later cells see this directory.
MODEL_DIR = "models/mixed_cross_attention_model/"

# exist_ok=True makes the call idempotent; no separate existence check needed.
os.makedirs(MODEL_DIR, exist_ok=True)

# MODEL_DIR already ends with a separator, so build the file path with
# os.path.join rather than string formatting to avoid a doubled "//".
TransformerMatcher.save_model(os.path.join(MODEL_DIR, "model.keras"))