In [None]:
import os
import sys
from importlib import reload

# The project package is imported as `core.*` from one directory above this
# notebook, so the import path must be extended before those imports run.
sys.path.append("..")
from core.DataLoader import (
    DataPreprocessor,
    get_load_config_from_yaml,
)
import keras
import core.keras_models as regression_transformer
import core.utils as utils

# --- Configuration: everything a reader might want to tune ------------------
PLOTS_DIR = "plots/regression_transformer/"
MODEL_DIR = "models/regression_transformer/"
CONFIG_PATH = "../config/nominal_load_config.yaml"
MAX_EVENTS = 6_000_000   # upper bound passed to load_from_npz
EVENT_NUMBERS = "even"   # event-number selection passed to load_from_npz
SEED = 42

# Seed Python, NumPy and the Keras backend so that weight initialisation and
# dropout are repeatable across "Restart Kernel -> Run All".
keras.utils.set_random_seed(SEED)

# Create the output directories; exist_ok avoids the check-then-create race
# and makes the cell safe to re-run.
for output_dir in (PLOTS_DIR, MODEL_DIR):
    os.makedirs(output_dir, exist_ok=True)

# --- Load data ---------------------------------------------------------------
load_config = get_load_config_from_yaml(CONFIG_PATH)

data_preprocessor = DataPreprocessor(load_config)
data_config = data_preprocessor.load_from_npz(
    load_config.data_path, max_events=MAX_EVENTS, event_numbers=EVENT_NUMBERS
)

X, y = data_preprocessor.get_data()
# Drop the preprocessor once X and y are extracted (presumably to release the
# memory it holds -- only X, y and data_config are used by the cells below).
del data_preprocessor

In [None]:
# Re-import core.keras_models so edits made to the module since the kernel
# started are picked up before the model object is created.
reload(regression_transformer)

reconstructor_cls = regression_transformer.FeatureConcatBinnedReconstructor
Transformer = reconstructor_cls(data_config, name="Transformer")

In [None]:
# Hyperparameters forwarded verbatim (as keyword arguments) to build_model.
model_hyperparameters = dict(
    hidden_dim=128,
    num_layers=6,
    dropout_rate=0.2,
    regression_bins=20,
    use_global_event_inputs=True,
    log_variables=True,
)
Transformer.build_model(**model_hyperparameters)

In [None]:
# Fit the model's normalization layers on the input features before compiling.
Transformer.adapt_normalization_layers(X)

# One loss and one metric set per output head; the dictionary keys must match
# the output names "assignment" and "binned_regression".
head_losses = {
    "assignment": utils.AssignmentLoss(),
    "binned_regression": utils.BinnedRegressionLoss(),
}
adamw_optimizer = keras.optimizers.AdamW(learning_rate=1e-4, weight_decay=1e-4)
head_metrics = {
    "assignment": [utils.AssignmentAccuracy(name="accuracy")],
    "binned_regression": utils.BinnedRegressionAccuracy(),
}
# The binned-regression loss is weighted three times as strongly as the
# assignment loss in the combined objective.
head_loss_weights = {"assignment": 1.0, "binned_regression": 3.0}

Transformer.compile_model(
    loss=head_losses,
    optimizer=adamw_optimizer,
    metrics=head_metrics,
    loss_weights=head_loss_weights,
)

In [None]:
# Compute per-event weights from the inputs first, then hand them to
# prepare_training_data together with the features and targets.
initial_sample_weights = Transformer.compute_sample_weights(X)
X_train, y_train, sample_weights = Transformer.prepare_training_data(
    X, y, sample_weights=initial_sample_weights
)

In [None]:
# Halve the learning rate whenever val_loss has not improved for 5 epochs,
# never going below 1e-6.
# NOTE(review): with epochs=10 and patience=5 this can trigger at most about
# once per run -- confirm that is intended.
lr_on_plateau = keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.5,
    patience=5,
    verbose=1,
    mode="min",
    min_lr=1e-6,
)
# Abort training as soon as a NaN loss is encountered.
nan_guard = keras.callbacks.TerminateOnNaN()

Transformer.train_model(
    epochs=10,
    X=X_train,
    y=y_train,
    sample_weight=sample_weights,
    batch_size=1024,
    callbacks=[lr_on_plateau, nan_guard],
    validation_split=0.1,  # fraction held out for validation (provides val_loss)
)

In [None]:
# Build the path with os.path.join so it stays correct even if MODEL_DIR is
# ever changed to a value without a trailing slash.
# NOTE(review): the data cell loads event_numbers="even" while the file is
# named "odd_model" -- presumably the model trained on even events is meant to
# be applied to odd events; confirm the naming convention.
Transformer.save_model(os.path.join(MODEL_DIR, "odd_model.keras"))