In [1]:
import sys

sys.path.append("..")
from core.DataLoader import (
    DataPreprocessor,
    DataConfig,
    LoadConfig,
    get_load_config_from_yaml,
)
import numpy as np
from importlib import reload
import matplotlib.pyplot as plt
import yaml
import core.assignment_models as Models
import core
import keras


import os

# --- Configuration -----------------------------------------------------------
# Directory (relative to the notebook's working directory) the model is saved to.
MODEL_DIR = "models/assignment_transformer"
# Workspace YAML; passed to get_load_config_from_yaml and also parsed directly
# below to look up the data path.
CONFIG_PATH = "../config/workspace_config.yaml"
# Value forwarded as `max_events` when loading the nominal sample.
MAX_EVENTS = 4_000_000
# Seed for Python, NumPy and the Keras backend RNGs (weight init, shuffling,
# dropout). NOTE(review): bitwise-identical results on GPU are still not
# guaranteed; this only pins the random streams.
SEED = 42

# exist_ok=True creates the directory (and parents) if needed and is a no-op
# when it already exists, avoiding the check-then-create race of
# `if not os.path.exists(...)`.
os.makedirs(MODEL_DIR, exist_ok=True)

keras.utils.set_random_seed(SEED)

plt.rcParams.update({"font.size": 14})

# --- Data loading ------------------------------------------------------------
load_config = get_load_config_from_yaml(CONFIG_PATH)
data_processor = DataPreprocessor(load_config)

# Explicit encoding so parsing does not depend on the platform default.
with open(CONFIG_PATH, "r", encoding="utf-8") as config_file:
    data_configs = yaml.safe_load(config_file)

data_config = data_processor.load_from_npz(
    data_configs["data_path"]["nominal"], max_events=MAX_EVENTS
)
# Only the first two returned values are used in this notebook; the remaining
# two are discarded.
X_train, y_train, _, _ = data_processor.split_even_odd()
# Drop the notebook's reference to the preprocessor once the split is extracted.
del data_processor

2025-12-17 14:18:19.744678: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765977499.963366 3592645 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765977499.993329 3592645 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1765977500.135209 3592645 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1765977500.135245 3592645 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1765977500.135249 3592645 computation_placer.cc:177] computation placer alr

In [2]:
# Wrap the loaded data configuration in the transformer-based assignment model.
TransformerMatcher = Models.FeatureConcatTransformer(data_config, name="Transformer")

# Architecture hyperparameters, gathered in one place so they are easy to tune.
architecture = {
    "hidden_dim": 32,
    "num_layers": 4,
    "num_heads": 8,
    "dropout_rate": 0.1,
    "compute_HLF": False,
}
TransformerMatcher.build_model(**architecture)

# Adapt the model's normalization layers on the training inputs before compiling.
TransformerMatcher.adapt_normalization_layers(X_train)

# Loss, optimizer and metric objects; loss and metrics are keyed by "assignment".
assignment_loss = core.utils.AssignmentLoss()
optimizer = keras.optimizers.AdamW(learning_rate=1e-4, weight_decay=1e-4)
assignment_metrics = [core.utils.AssignmentAccuracy()]

TransformerMatcher.compile_model(
    loss={"assignment": assignment_loss},
    optimizer=optimizer,
    metrics={"assignment": assignment_metrics},
)
TransformerMatcher.model.summary()

FeatureConcatTransformer is designed for classification tasks; regression targets will be ignored.


I0000 00:00:1765977613.529864 3592645 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 15511 MB memory:  -> device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:65:00.0, compute capability: 6.0


Building model without regression output.
Submodel inputs:  [<KerasTensor shape=(None, 6, 5), dtype=float32, sparse=False, ragged=False, name=jet_inputs>, <KerasTensor shape=(None, 2, 4), dtype=float32, sparse=False, ragged=False, name=lep_inputs>, <KerasTensor shape=(None, 1, 2), dtype=float32, sparse=False, ragged=False, name=met_inputs>]
Input tensor:  <KerasTensor shape=(None, 6, 5), dtype=float32, sparse=False, ragged=False, name=jet_inputs> jet_inputs
Input tensor:  <KerasTensor shape=(None, 2, 4), dtype=float32, sparse=False, ragged=False, name=lep_inputs> lep_inputs
Input tensor:  <KerasTensor shape=(None, 1, 2), dtype=float32, sparse=False, ragged=False, name=met_inputs> met_inputs


I0000 00:00:1765977615.632235 3595366 service.cc:152] XLA service 0x7f87dc004960 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1765977615.632267 3595366 service.cc:160]   StreamExecutor device (0): Tesla P100-PCIE-16GB, Compute Capability 6.0
2025-12-17 14:20:15.642023: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:39] Ignoring Assert operator pre_met_input_normalization_model_1/met_input_transform_1/assert_equal_1/Assert/Assert
2025-12-17 14:20:15.659932: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1765977615.717639 3595366 cuda_dnn.cc:529] Loaded cuDNN version 91500


[1m151/943[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m0s[0m 1ms/step

I0000 00:00:1765977616.018100 3595366 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m943/943[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 1ms/step


2025-12-17 14:20:16.995351: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:39] Ignoring Assert operator pre_met_input_normalization_model_1/met_input_transform_1/assert_equal_1/Assert/Assert


Adapted normalization layer:  met_input_normalization
Submodel inputs:  [<KerasTensor shape=(None, 6, 5), dtype=float32, sparse=False, ragged=False, name=jet_inputs>, <KerasTensor shape=(None, 2, 4), dtype=float32, sparse=False, ragged=False, name=lep_inputs>, <KerasTensor shape=(None, 1, 2), dtype=float32, sparse=False, ragged=False, name=met_inputs>]
Input tensor:  <KerasTensor shape=(None, 6, 5), dtype=float32, sparse=False, ragged=False, name=jet_inputs> jet_inputs
Input tensor:  <KerasTensor shape=(None, 2, 4), dtype=float32, sparse=False, ragged=False, name=lep_inputs> lep_inputs
Input tensor:  <KerasTensor shape=(None, 1, 2), dtype=float32, sparse=False, ragged=False, name=met_inputs> met_inputs
[1m943/943[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1ms/step
Adapted normalization layer:  lep_input_normalization
Submodel inputs:  [<KerasTensor shape=(None, 6, 5), dtype=float32, sparse=False, ragged=False, name=jet_inputs>, <KerasTensor shape=(None, 2, 4), dtype=float32

In [3]:
# Per-sample weights computed from the training inputs and targets.
train_weights = core.utils.compute_sample_weights(X_train, y_train)

# Multiply the learning rate by 0.5 after 5 epochs without val_loss improvement,
# never going below 1e-6.
lr_on_plateau = keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss", factor=0.5, patience=5, min_lr=1e-6, verbose=1
)
# Stop after 8 epochs without val_loss improvement and restore the best weights.
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=8, restore_best_weights=True, verbose=1
)

# Left as a bare expression so the cell still displays the returned History object.
TransformerMatcher.train_model(
    epochs=50,
    X_train=X_train,
    y_train=y_train,
    sample_weights=train_weights,
    batch_size=1024,
    callbacks=[lr_on_plateau, early_stopping],
)

Epoch 1/50


2025-12-17 14:20:34.343771: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:39] Ignoring Assert operator FeatureConcatTransformerModel_1/met_input_transform_1/assert_equal_1/Assert/Assert


[1m1560/1562[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 8ms/step - accuracy: 0.1265 - loss: 0.1568

2025-12-17 14:20:58.999662: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:39] Ignoring Assert operator FeatureConcatTransformerModel_1/met_input_transform_1/assert_equal_1/Assert/Assert


[1m1562/1562[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.1267 - loss: 0.1568

2025-12-17 14:21:10.111812: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:39] Ignoring Assert operator FeatureConcatTransformerModel_1/met_input_transform_1/assert_equal_1/Assert/Assert
2025-12-17 14:21:12.288331: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:39] Ignoring Assert operator FeatureConcatTransformerModel_1/met_input_transform_1/assert_equal_1/Assert/Assert


[1m1562/1562[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m50s[0m 18ms/step - accuracy: 0.1268 - loss: 0.1568 - val_accuracy: 0.4714 - val_loss: 0.1190 - learning_rate: 1.0000e-04
Epoch 2/50
[1m1562/1562[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 9ms/step - accuracy: 0.4841 - loss: 0.1241 - val_accuracy: 0.6245 - val_loss: 0.1004 - learning_rate: 1.0000e-04
Epoch 3/50
[1m1562/1562[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 9ms/step - accuracy: 0.5906 - loss: 0.1098 - val_accuracy: 0.6731 - val_loss: 0.0936 - learning_rate: 1.0000e-04
Epoch 4/50
[1m1562/1562[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 9ms/step - accuracy: 0.6249 - loss: 0.1039 - val_accuracy: 0.6888 - val_loss: 0.0899 - learning_rate: 1.0000e-04
Epoch 5/50
[1m1562/1562[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 9ms/step - accuracy: 0.6425 - loss: 0.1006 - val_accuracy: 0.6985 - val_loss: 0.0879 - learning_rate: 1.0000e-04
Epoch 6/50
[1m1562/1562[0m [32m━━━━━━━━━━━

<keras.src.callbacks.history.History at 0x7f8940106eb0>

In [4]:
# Save the trained model into MODEL_DIR.
model_path = f"{MODEL_DIR}/model.keras"
TransformerMatcher.save_model(model_path)

Model saved to models/assignment_transformer/model.keras
