In [None]:
import os
import sys
from importlib import reload

sys.path.append("..")  # make the project-local `core` package importable
from core.DataLoader import (
    DataPreprocessor,
    get_load_config_from_yaml,
)
import core
# NOTE(review): keras is imported after `core` in case `core` configures the
# Keras backend (e.g. KERAS_BACKEND) before Keras loads — confirm before reordering.
import keras
import core.keras_models as keras_models
import core.utils as utils

# --- Configuration -----------------------------------------------------------
# Directory for trained artefacts. No trailing slash: file names are joined
# onto it, so a trailing "/" would produce "dir//file" paths.
MODEL_DIR = "models/assignment_transformer"
CONFIG_PATH = "../CONDOR/train_regression/load_config.yaml"
MAX_TRAIN_EVENTS = 4_000_000  # passed as `max_events` to load_from_npz
TRAIN_EVENT_PARITY = "even"   # passed as `event_numbers` to load_from_npz
SEED = 42

# Seed Python `random`, NumPy and the Keras backend so data shuffling and
# weight initialisation are reproducible across Restart & Run All.
keras.utils.set_random_seed(SEED)

# exist_ok=True makes this idempotent and avoids the exists()/makedirs() race.
os.makedirs(MODEL_DIR, exist_ok=True)

# --- Load training data ------------------------------------------------------
load_config = get_load_config_from_yaml(CONFIG_PATH)
DataProcessor = DataPreprocessor(load_config)

# The returned config is passed to the model constructor later in the notebook.
data_config = DataProcessor.load_from_npz(
    load_config.data_path["nominal"],
    max_events=MAX_TRAIN_EVENTS,
    event_numbers=TRAIN_EVENT_PARITY,
)

X_train, y_train = DataProcessor.get_data()
del DataProcessor  # Free memory

In [None]:
# Instantiate the assignment model from the data configuration returned by
# `load_from_npz` in the data-loading cell.
Transformer = keras_models.FeatureConcatAssigner(
    data_config,
    name="Transformer",
)

In [None]:
# Architecture hyper-parameters for the assignment model.
# Optional: `predict_confidence=True` can also be passed here; presumably it
# adds a confidence output that then needs its own loss / loss weight
# (e.g. "confidence_loss_output") in the compile cell — TODO confirm.
Transformer.build_model(
    hidden_dim=64, num_layers=6,
    dropout_rate=0.2,
    use_global_event_inputs=True,
)

In [None]:
# Adapt the input-normalisation layers to the training features (presumably
# computes per-feature statistics from X_train), then compile the model with
# the assignment loss/accuracy and an AdamW optimiser.
Transformer.adapt_normalization_layers(X_train)
Transformer.compile_model(
    loss={"assignment": utils.AssignmentLoss()},
    optimizer=keras.optimizers.AdamW(learning_rate=1e-4, weight_decay=1e-5),
    metrics={
        "assignment": [utils.AssignmentAccuracy(name="assignment_accuracy")]
    },
    # When a confidence output is enabled in build_model, presumably also pass
    # loss_weights={"assignment": 1.0, "confidence_loss_output": 1.0}
    # together with a matching loss entry — TODO confirm the output key names.
)

In [None]:
# Print the architecture summary of the model that will be trained (`trainable_model`).
Transformer.trainable_model.summary()

In [None]:
# Train with per-event sample weights and a 10% validation split.
# NOTE(review): only X_train is passed as `data`; y_train is used solely for the
# sample weights — presumably train_model obtains targets elsewhere; confirm.
# NOTE(review): if `validation_split` is forwarded to keras `fit`, Keras takes the
# *last* 10% of the arrays before shuffling — assumes X_train is not ordered
# by sample/process; confirm.
Transformer.train_model(
    epochs=10,
    data=X_train,
    sample_weights=utils.compute_sample_weights(X_train, y_train),
    batch_size=1024,
    callbacks=[
        # Halve the LR after 5 epochs without val_loss improvement (floor 1e-6).
        # NOTE(review): with epochs=10 and patience=5 this can fire at most once.
        keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss", factor=0.5, patience=5,
            verbose=1, mode="min", min_lr=1e-6,
        )
    ],
    validation_split=0.1,
)

In [None]:
# Save the trained model in the native Keras format. os.path.join produces a
# clean path whether or not MODEL_DIR ends with "/".
# NOTE(review): training data was loaded with event_numbers="even"; the
# "odd_model" name presumably means "applied to odd events" — confirm convention.
Transformer.save_model(os.path.join(MODEL_DIR, "odd_model.keras"))

In [None]:
# Export the single model trained in this notebook to ONNX, under the same
# "odd_model" name as the .keras checkpoint.
# Only one model is trained here, so it must not also be written as
# even_model.onnx: that file would be a mislabelled copy of these weights.
# Produce the complementary model by re-running training on the other event
# subset and exporting it under its own name.
Transformer.export_to_onnx(os.path.join(MODEL_DIR, "odd_model.onnx"))