In [None]:
from pathlib import Path
import tensorflow as tf

In [None]:
dataset_path = "../processed_data/transformer_dataset"

# Load dataset without any transformations
raw_ds = tf.data.Dataset.load(dataset_path)

# Print dataset structure
print("Dataset element specification:", raw_ds.element_spec)

In [27]:
print("Dataset cardinality:",
      tf.data.experimental.cardinality(raw_ds).numpy())   # should now print a number


Dataset cardinality: 592871


In [None]:
def filter_split(split_num):
    def _filter(meta, x, y):
        return tf.equal(meta[2], split_num)
    return _filter

def drop_meta(meta, x, y):
    return x, y

In [None]:
train_ds = (raw_ds
            .filter(filter_split(0))
            .map(drop_meta, num_parallel_calls=tf.data.AUTOTUNE)
            .shuffle(4096)
            .batch(64)
            .prefetch(tf.data.AUTOTUNE))

val_ds   = (raw_ds
            .filter(filter_split(1))
            .map(drop_meta, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(64)
            .prefetch(tf.data.AUTOTUNE))

test_ds  = (raw_ds
            .filter(filter_split(2))
            .map(drop_meta, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(64)
            .prefetch(tf.data.AUTOTUNE))

# Take one batch from the dataset
for x_batch, y_batch in train_ds.take(1):
    print("x_batch shape:", x_batch.shape)
    print("y_batch shape:", y_batch.shape)

for x_batch, y_batch in val_ds.take(1):
    print("x_batch shape:", x_batch.shape)
    print("y_batch shape:", y_batch.shape)

for x_batch, y_batch in test_ds.take(1):
    print("x_batch shape:", x_batch.shape)
    print("y_batch shape:", y_batch.shape)

2025-04-30 11:28:00.649907: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:79: Filling up shuffle buffer (this may take a while): 1220 of 4096
2025-04-30 11:28:16.262915: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:482] Shuffle buffer filled.


x_batch shape: (64, 100, 46)
y_batch shape: (64, 46)


2025-04-30 11:28:17.234248: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


x_batch shape: (64, 100, 46)
y_batch shape: (64, 46)
x_batch shape: (64, 100, 46)
y_batch shape: (64, 46)


In [29]:
# Check for GPU
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        print(f"✅ Using GPU: {gpus[0].name}")
        # Set mixed precision policy
        tf.keras.mixed_precision.set_global_policy('mixed_float16')
    except RuntimeError as e:
        print("Failed to set GPU memory growth:", e)
else:
    print("No GPU found. Using CPU.")

# Set logging
tf.debugging.set_log_device_placement(False) # Set it to True to make sure the GPU is used

✅ Using GPU: /physical_device:GPU:0


In [30]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

NUM_FEATS = 46          # x,y for 23 entities
MAX_LEN  = 100          # same value you used in dataset builder
D_MODEL  = 128          # transformer hidden size
N_HEADS  = 4
N_LAYERS = 4
D_FF     = 512
DROPOUT  = 0.1

In [31]:
# ╔═══════════════════╗
# ║ 2. Positional enc ║  (learnable 1‑D embedding)
# ╚═══════════════════╝
class PositionalEncoding(layers.Layer):
    def __init__(self, max_len, d_model):
        super().__init__()
        self.pos_emb = self.add_weight(
            name="pos_emb",
            shape=(max_len, d_model),
            initializer="uniform",
            trainable=True,
        )

    def call(self, x):
        return x + self.pos_emb


In [32]:
# ╔═══════════════════════════╗
# ║ 3. Padding‑mask function  ║
# ╚═══════════════════════════╝
class PaddingMask(layers.Layer):
    def call(self, x):
        # x:  (B, T, F) — zero‐padded on the left
        pad = tf.reduce_all(tf.equal(x, 0.0), axis=-1)      # → (B, T)
        # reshape to (B, 1, 1, T) for MultiHeadAttention
        return pad[:, tf.newaxis, tf.newaxis, :]



In [33]:
# ╔════════════════════════╗
# ║ 4. Transformer encoder ║
# ╚════════════════════════╝
def transformer_block(d_model, n_heads, d_ff, dropout):
    inputs   = layers.Input(shape=(None, d_model))
    padding  = layers.Input(shape=(1,1,None), dtype=tf.bool)  # mask

    x = layers.MultiHeadAttention(
        num_heads=n_heads, key_dim=d_model//n_heads, dropout=dropout
    )(inputs, inputs, attention_mask=padding)
    x = layers.Dropout(dropout)(x)
    x = layers.LayerNormalization(epsilon=1e-6)(inputs + x)

    y = layers.Dense(d_ff, activation="relu")(x)
    y = layers.Dense(d_model)(y)
    y = layers.Dropout(dropout)(y)
    y = layers.LayerNormalization(epsilon=1e-6)(x + y)

    return keras.Model([inputs, padding], y)


In [34]:
# ╔════════════════════════════════╗
# ║ 5. End‑to‑end prediction model ║
# ╚════════════════════════════════╝
def build_model(
    num_feats=NUM_FEATS,
    max_len=MAX_LEN,
    d_model=D_MODEL,
    n_heads=N_HEADS,
    n_layers=N_LAYERS,
    d_ff=D_FF,
    dropout=DROPOUT,
):
    seq_in  = layers.Input(shape=(max_len, num_feats), name="sequence")   # (B,T,F)

    # Linear projection to d_model
    x = layers.Dense(d_model)(seq_in)

    # Add learnable positional encodings
    x = PositionalEncoding(max_len, d_model)(x)

    # Build padding mask once
    pad_mask = PaddingMask()(seq_in)

    # Stack encoder layers
    for _ in range(n_layers):
        x = transformer_block(d_model, n_heads, d_ff, dropout)([x, pad_mask])

    # We need the hidden state that corresponds to *frame t* (the last row)
    # – that is always index -1 thanks to left padding.
    h_t = layers.Lambda(lambda t: t[:, -1])(x)          # (B, D)

    # Regress the 46 co‑ordinates
    out = layers.Dense(num_feats, name="pred_xy")(h_t)

    return keras.Model(seq_in, out, name="NFL_Frame_Predictor")

model = build_model()
model.summary()


In [35]:
# ╔════════════════════╗
# ║ 6. Compile & train ║
# ╚════════════════════╝

# ── 1)  Make sure we have a place to put checkpoints ─────────────────
WEIGHT_DIR = Path("../weights")
WEIGHT_DIR.mkdir(parents=True, exist_ok=True)

ckpt_cb = keras.callbacks.ModelCheckpoint(
    filepath=(WEIGHT_DIR /
              "epoch_{epoch:03d}-val{val_loss:.4f}.weights.h5").as_posix(),
    monitor="val_loss",
    save_best_only=False,      # save every epoch → “periodic” archive
    save_weights_only=True,    # just the weights, not optimizer state
    verbose=0,
)

# ── 2)  Early-stopping ───────────────────────────────────────────────
early_stop = keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=3,
    restore_best_weights=True,
    verbose=1,
)

# ── 3)  Compile the model ────────────────────────────────────────────
LR = 1e-4
model.compile(
    optimizer=keras.optimizers.Adam(LR),
    loss=keras.losses.MeanSquaredError(),
    metrics=[keras.metrics.MeanAbsoluteError()],
)

# ── 4)  Fit – stop early, save weights each epoch ────────────────────
EPOCHS = 100   # high ceiling; early-stop decides real count
history = model.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
    callbacks=[early_stop, ckpt_cb],
    verbose=1,
)

# optional: evaluate on test set after training
test_loss, test_mae = model.evaluate(test_ds, verbose=1)
print(f"\n✅  Test MSE: {test_loss:.5f}   |   Test MAE: {test_mae:.5f}")


Epoch 1/100


2025-04-30 11:28:37.228427: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:79: Filling up shuffle buffer (this may take a while): 1738 of 4096
2025-04-30 11:28:43.886839: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:482] Shuffle buffer filled.
I0000 00:00:1746008923.945567   98454 service.cc:152] XLA service 0x7f62c80040e0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1746008923.945635   98454 service.cc:160]   StreamExecutor device (0): NVIDIA GeForce RTX 4080 Laptop GPU, Compute Capability 8.9
2025-04-30 11:28:44.192019: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1746008925.807657   98454 cuda_dnn.cc:529] Loaded cuDNN version 90300











      1/Unknown [1m46s[0m 46s/step - loss: 1.7699 - mean_absolute_error: 1.0403

I0000 00:00:1746008945.177069   98454 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


   6439/Unknown [1m1915s[0m 290ms/step - loss: 0.0313 - mean_absolute_error: 0.1043





















   6443/Unknown [1m1941s[0m 294ms/step - loss: 0.0313 - mean_absolute_error: 0.1042

































2025-04-30 12:55:59.668833: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 13240305251111713949
2025-04-30 12:55:59.669175: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7154537782701321634


[1m6443/6443[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5261s[0m 810ms/step - loss: 0.0313 - mean_absolute_error: 0.1042 - val_loss: 4.1384e-04 - val_mean_absolute_error: 0.0150
Epoch 2/100


2025-04-30 12:56:12.513266: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:79: Filling up shuffle buffer (this may take a while): 1852 of 4096
2025-04-30 12:56:30.936402: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:79: Filling up shuffle buffer (this may take a while): 3917 of 4096
2025-04-30 12:56:36.415595: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:482] Shuffle buffer filled.


[1m6443/6443[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 454ms/step - loss: 7.9201e-04 - mean_absolute_error: 0.0214

2025-04-30 13:45:22.074583: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_4]]
2025-04-30 14:06:10.288949: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 13240305251111713949
2025-04-30 14:06:10.289326: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7154537782701321634


[1m6443/6443[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4210s[0m 648ms/step - loss: 7.9198e-04 - mean_absolute_error: 0.0214 - val_loss: 1.5941e-04 - val_mean_absolute_error: 0.0093
Epoch 3/100
[1m6443/6443[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 286ms/step - loss: 3.6120e-04 - mean_absolute_error: 0.0145

2025-04-30 14:59:24.710872: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 13240305251111713949
2025-04-30 14:59:24.710973: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7154537782701321634


[1m6443/6443[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3194s[0m 494ms/step - loss: 3.6120e-04 - mean_absolute_error: 0.0145 - val_loss: 8.1994e-05 - val_mean_absolute_error: 0.0069
Epoch 4/100


2025-04-30 14:59:35.565836: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:79: Filling up shuffle buffer (this may take a while): 2661 of 4096
2025-04-30 14:59:47.447850: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:79: Filling up shuffle buffer (this may take a while): 3478 of 4096
2025-04-30 14:59:57.663887: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:482] Shuffle buffer filled.


[1m6443/6443[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 200ms/step - loss: 2.2658e-04 - mean_absolute_error: 0.0115

2025-04-30 15:21:31.854102: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 13240305251111713949
2025-04-30 15:47:24.311715: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 13240305251111713949
2025-04-30 15:47:24.311821: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7154537782701321634


[1m6443/6443[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2880s[0m 441ms/step - loss: 2.2658e-04 - mean_absolute_error: 0.0115 - val_loss: 5.4679e-05 - val_mean_absolute_error: 0.0056
Epoch 5/100


2025-04-30 15:47:35.895312: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:79: Filling up shuffle buffer (this may take a while): 1852 of 4096
2025-04-30 15:47:43.446702: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:482] Shuffle buffer filled.


[1m4671/6443[0m [32m━━━━━━━━━━━━━━[0m[37m━━━━━━[0m [1m7:30[0m 254ms/step - loss: 1.6380e-04 - mean_absolute_error: 0.0097

KeyboardInterrupt: 

In [None]:
# ╔═══════════════╗
# ║ 7. Evaluation ║
# ╚═══════════════╝
# Simple end‑to‑end evaluation on a held‑out batch
for X_batch, y_batch in val_ds.take(1):
    y_pred = model(X_batch)
    mse = tf.reduce_mean(tf.square(y_pred - y_batch))
    print("Validation MSE (batch):", mse.numpy())
