In [7]:
from pathlib import Path
import tensorflow as tf

In [None]:
# Check for GPU
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        print(f"✅ Using GPU: {gpus[0].name}")
        # Set mixed precision policy
        tf.keras.mixed_precision.set_global_policy('mixed_float16')
    except RuntimeError as e:
        print("Failed to set GPU memory growth:", e)
else:
    print("No GPU found. Using CPU.")

# Set logging
tf.debugging.set_log_device_placement(False) # Set it to True to make sure the GPU is used

Visible GPUs: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [9]:
dataset_path = "../processed_data/transformer_dataset_9"

# Load dataset without any transformations
raw_ds = tf.data.Dataset.load(dataset_path)

# Print dataset structure
print("Dataset element specification:", raw_ds.element_spec)

Dataset element specification: (TensorSpec(shape=(4,), dtype=tf.int32, name=None), TensorSpec(shape=(100, 46), dtype=tf.float32, name=None), TensorSpec(shape=(46,), dtype=tf.float32, name=None))


In [10]:
print("Dataset cardinality:",
      tf.data.experimental.cardinality(raw_ds).numpy())   # should now print a number


Dataset cardinality: 2547197


In [11]:
def filter_split(split_num):
    def _filter(meta, x, y):
        return tf.equal(meta[2], split_num)
    return _filter

def drop_meta(meta, x, y):
    return x, y

In [None]:
BATCH_SIZE = 128

train_ds = (raw_ds
            .filter(filter_split(0))
            .map(drop_meta, num_parallel_calls=tf.data.AUTOTUNE)
            .shuffle(4096)
            .batch(BATCH_SIZE)
            .prefetch(tf.data.AUTOTUNE))

val_ds   = (raw_ds
            .filter(filter_split(1))
            .map(drop_meta, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(BATCH_SIZE)
            .prefetch(tf.data.AUTOTUNE))

test_ds  = (raw_ds
            .filter(filter_split(2))
            .map(drop_meta, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(BATCH_SIZE)
            .prefetch(tf.data.AUTOTUNE))

# Take one batch from the dataset
for x_batch, y_batch in train_ds.take(1):
    print("x_batch shape:", x_batch.shape)
    print("y_batch shape:", y_batch.shape)

for x_batch, y_batch in val_ds.take(1):
    print("x_batch shape:", x_batch.shape)
    print("y_batch shape:", y_batch.shape)

for x_batch, y_batch in test_ds.take(1):
    print("x_batch shape:", x_batch.shape)
    print("y_batch shape:", y_batch.shape)

2025-05-01 02:00:35.595410: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:23: Filling up shuffle buffer (this may take a while): 1519 of 4096
2025-05-01 02:00:45.614771: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:23: Filling up shuffle buffer (this may take a while): 2385 of 4096
2025-05-01 02:01:05.596084: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:23: Filling up shuffle buffer (this may take a while): 3842 of 4096
2025-05-01 02:01:08.398560: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:482] Shuffle buffer filled.


x_batch shape: (2048, 100, 46)
y_batch shape: (2048, 46)


2025-05-01 02:01:53.532377: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


x_batch shape: (2048, 100, 46)
y_batch shape: (2048, 46)


2025-05-01 02:02:30.087571: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


x_batch shape: (2048, 100, 46)
y_batch shape: (2048, 46)


In [14]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

NUM_FEATS = 46          # x,y for 23 entities
MAX_LEN  = 100          # same value you used in dataset builder
D_MODEL  = 128          # transformer hidden size
N_HEADS  = 4
N_LAYERS = 4
D_FF     = 512
DROPOUT  = 0.1

In [15]:
# ╔═══════════════════╗
# ║ 2. Positional enc ║  (learnable 1‑D embedding)
# ╚═══════════════════╝
class PositionalEncoding(layers.Layer):
    def __init__(self, max_len, d_model):
        super().__init__()
        self.pos_emb = self.add_weight(
            name="pos_emb",
            shape=(max_len, d_model),
            initializer="uniform",
            trainable=True,
        )

    def call(self, x):
        return x + self.pos_emb


In [16]:
# ╔═══════════════════════════╗
# ║ 3. Padding‑mask function  ║
# ╚═══════════════════════════╝
class PaddingMask(layers.Layer):
    def call(self, x):
        # x:  (B, T, F) — zero‐padded on the left
        pad = tf.reduce_all(tf.equal(x, 0.0), axis=-1)      # → (B, T)
        # reshape to (B, 1, 1, T) for MultiHeadAttention
        return pad[:, tf.newaxis, tf.newaxis, :]



In [17]:
# ╔════════════════════════╗
# ║ 4. Transformer encoder ║
# ╚════════════════════════╝
def transformer_block(d_model, n_heads, d_ff, dropout):
    inputs   = layers.Input(shape=(None, d_model))
    padding  = layers.Input(shape=(1,1,None), dtype=tf.bool)  # mask

    x = layers.MultiHeadAttention(
        num_heads=n_heads, key_dim=d_model//n_heads, dropout=dropout
    )(inputs, inputs, attention_mask=padding)
    x = layers.Dropout(dropout)(x)
    x = layers.LayerNormalization(epsilon=1e-6)(inputs + x)

    y = layers.Dense(d_ff, activation="relu")(x)
    y = layers.Dense(d_model)(y)
    y = layers.Dropout(dropout)(y)
    y = layers.LayerNormalization(epsilon=1e-6)(x + y)

    return keras.Model([inputs, padding], y)


In [18]:
# ╔════════════════════════════════╗
# ║ 5. End‑to‑end prediction model ║
# ╚════════════════════════════════╝
def build_model(
    num_feats=NUM_FEATS,
    max_len=MAX_LEN,
    d_model=D_MODEL,
    n_heads=N_HEADS,
    n_layers=N_LAYERS,
    d_ff=D_FF,
    dropout=DROPOUT,
):
    seq_in  = layers.Input(shape=(max_len, num_feats), name="sequence")   # (B,T,F)

    # Linear projection to d_model
    x = layers.Dense(d_model)(seq_in)

    # Add learnable positional encodings
    x = PositionalEncoding(max_len, d_model)(x)

    # Build padding mask once
    pad_mask = PaddingMask()(seq_in)

    # Stack encoder layers
    for _ in range(n_layers):
        x = transformer_block(d_model, n_heads, d_ff, dropout)([x, pad_mask])

    # We need the hidden state that corresponds to *frame t* (the last row)
    # – that is always index -1 thanks to left padding.
    h_t = layers.Lambda(lambda t: t[:, -1])(x)          # (B, D)

    # Regress the 46 co‑ordinates
    out = layers.Dense(num_feats, name="pred_xy")(h_t)

    return keras.Model(seq_in, out, name="NFL_Frame_Predictor")

model = build_model()
model.summary()


In [None]:
# ╔════════════════════╗
# ║ 6. Compile & train ║
# ╚════════════════════╝

# ── 1)  Make sure we have a place to put checkpoints ─────────────────
WEIGHT_DIR = Path("../weights")
WEIGHT_DIR.mkdir(parents=True, exist_ok=True)



ckpt_cb = keras.callbacks.ModelCheckpoint(
    filepath=(WEIGHT_DIR /
              "epoch_{epoch:03d}-val{val_loss:.6f}.weights.h5").as_posix(),
    monitor="val_loss",
    save_best_only=False,      # save every epoch → “periodic” archive
    save_weights_only=True,    # just the weights, not optimizer state
    verbose=0,
)

# ── 2)  Early-stopping ───────────────────────────────────────────────
PATIENCE = 5

early_stop = keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=3,
    restore_best_weights=True,
    verbose=1,
)

# ── 3)  Compile the model ────────────────────────────────────────────
LR = 1e-4
model.compile(
    optimizer=keras.optimizers.Adam(LR),
    loss=keras.losses.MeanSquaredError(),
    metrics=[keras.metrics.MeanAbsoluteError()],
)

# ── 4)  Fit – stop early, save weights each epoch ────────────────────
EPOCHS = 10_000   # high ceiling; early-stop decides real count
STEPS_PER_EPOCH = 1_000
VAL_STEPS_PER_EPCH = 100
history = model.fit(
    train_ds,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_steps=VAL_STEPS_PER_EPCH,
    validation_data=val_ds,
    callbacks=[early_stop, ckpt_cb],
    verbose=1,
)

# optional: evaluate on test set after training
test_loss, test_mae = model.evaluate(test_ds, verbose=1)
print(f"\n✅  Test MSE: {test_loss:.5f}   |   Test MAE: {test_mae:.5f}")


Epoch 1/10000


2025-05-01 02:03:22.317981: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:23: Filling up shuffle buffer (this may take a while): 2851 of 4096
2025-05-01 02:03:24.003989: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:482] Shuffle buffer filled.
I0000 00:00:1746061407.204001   35478 service.cc:152] XLA service 0x7f0350001770 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1746061407.204079   35478 service.cc:160]   StreamExecutor device (0): NVIDIA GeForce RTX 4080 Laptop GPU, Compute Capability 8.9
2025-05-01 02:03:27.500516: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1746061409.644903   35478 cuda_dnn.cc:529] Loaded cuDNN version 90300










I0000 00:00:1746061432.621921   35478 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the l

[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9s/step - loss: 0.3856 - mean_absolute_error: 0.4294

















[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4041s[0m 40s/step - loss: 0.3835 - mean_absolute_error: 0.4280 - val_loss: 0.0128 - val_mean_absolute_error: 0.0869
Epoch 2/10000
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3940s[0m 40s/step - loss: 0.0545 - mean_absolute_error: 0.1849 - val_loss: 0.0077 - val_mean_absolute_error: 0.0682
Epoch 3/10000
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4049s[0m 41s/step - loss: 0.0366 - mean_absolute_error: 0.1518 - val_loss: 0.0056 - val_mean_absolute_error: 0.0577
Epoch 4/10000
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3989s[0m 40s/step - loss: 0.0270 - mean_absolute_error: 0.1304 - val_loss: 0.0052 - val_mean_absolute_error: 0.0551
Epoch 5/10000
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3923s[0m 40s/step - loss: 0.0209 - mean_absolute_error: 0.1145 - val_loss: 0.0046 - val_mean_absolute_error: 0.0519
Epoch 6/10000
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m

In [None]:
# ╔═══════════════╗
# ║ 7. Evaluation ║ -
# ╚═══════════════╝
# Simple end‑to‑end evaluation on a held‑out batch
for X_batch, y_batch in val_ds.take(1):
    y_pred = model(X_batch)
    mse = tf.reduce_mean(tf.square(y_pred - y_batch))
    print("Validation MSE (batch):", mse.numpy())
