In [125]:
from pathlib import Path
import tensorflow as tf
import numpy as np
import statistics as st

In [126]:
dataset_path = "../processed_data/transformer_dataset"

# Load dataset without any transformations
raw_ds = tf.data.Dataset.load(dataset_path)

# Print dataset structure
print("Dataset element specification:", raw_ds.element_spec)

# Examine first 3 examples
print("\nFirst 3 examples:")
for i, example in enumerate(raw_ds.take(3)):
    # Each example contains 3 components:
    meta_tensor = example[0]  # Metadata (gameId, playId, split_id, firstFrameId)
    x_tensor = example[1]     # Input sequence (padded frames)
    y_tensor = example[2]     # Target vector
    
    print(f"\nExample {i+1}:")
    print("Metadata tensor:", meta_tensor)
    print(f"Metadata values: {meta_tensor.numpy()}")
    print(f"Input shape: {x_tensor.shape} | dtype: {x_tensor.dtype}")
    print(f"Target shape: {y_tensor.shape} | dtype: {y_tensor.dtype}")
    
    # First 5 elements of first frame's features
    print("First frame features (first 5 values):", x_tensor[0, :5].numpy())
    print("Target values (first 5):", y_tensor[:5].numpy())

Dataset element specification: (TensorSpec(shape=(4,), dtype=tf.int32, name=None), TensorSpec(shape=(100, 46), dtype=tf.float32, name=None), TensorSpec(shape=(46,), dtype=tf.float32, name=None))

First 3 examples:

Example 1:
Metadata tensor: tf.Tensor([2022091809         98          0          1], shape=(4,), dtype=int32)
Metadata values: [2022091809         98          0          1]
Input shape: (100, 46) | dtype: <dtype: 'float32'>
Target shape: (46,) | dtype: <dtype: 'float32'>
First frame features (first 5 values): [0. 0. 0. 0. 0.]
Target values (first 5): [0.65816665 0.45928705 0.67158335 0.41782364 0.66908336]

Example 2:
Metadata tensor: tf.Tensor([2022091809         98          0          1], shape=(4,), dtype=int32)
Metadata values: [2022091809         98          0          1]
Input shape: (100, 46) | dtype: <dtype: 'float32'>
Target shape: (46,) | dtype: <dtype: 'float32'>
First frame features (first 5 values): [0. 0. 0. 0. 0.]
Target values (first 5): [0.657      0.4596623

In [127]:
"""# Count examples per split
split_counts = {"train": 0, "val": 0, "test": 0}
for meta, *_ in raw_ds:
    split_id = meta[2].numpy()
    split_counts["train" if split_id==0 else "val" if split_id==1 else "test"] += 1
print(split_counts)"""

'# Count examples per split\nsplit_counts = {"train": 0, "val": 0, "test": 0}\nfor meta, *_ in raw_ds:\n    split_id = meta[2].numpy()\n    split_counts["train" if split_id==0 else "val" if split_id==1 else "test"] += 1\nprint(split_counts)'

In [128]:
"""game_splits = {}
for meta, *_ in raw_ds:
    game_id = meta[0].numpy()
    split_id = meta[2].numpy()
    if game_id in game_splits:
        assert game_splits[game_id] == split_id, f"Game {game_id} in multiple splits!"
    else:
        game_splits[game_id] = split_id"""

'game_splits = {}\nfor meta, *_ in raw_ds:\n    game_id = meta[0].numpy()\n    split_id = meta[2].numpy()\n    if game_id in game_splits:\n        assert game_splits[game_id] == split_id, f"Game {game_id} in multiple splits!"\n    else:\n        game_splits[game_id] = split_id'

In [129]:
"""from collections import Counter
seq_lengths = []
for meta, x, _ in raw_ds:
    seq_len = tf.math.count_nonzero(tf.reduce_any(x != 0, axis=1)).numpy()
    seq_lengths.append(seq_len)
print("Sequence length distribution:", Counter(seq_lengths))"""

'from collections import Counter\nseq_lengths = []\nfor meta, x, _ in raw_ds:\n    seq_len = tf.math.count_nonzero(tf.reduce_any(x != 0, axis=1)).numpy()\n    seq_lengths.append(seq_len)\nprint("Sequence length distribution:", Counter(seq_lengths))'

In [130]:
# Load the dataset and filter based on split_id
def filter_split(split_num):
    def _filter(meta, x, y):
        return tf.equal(meta[2], split_num)
    return _filter

# Split the dataset into train, val, test using the split_id
train_ds = raw_ds.filter(filter_split(0)).shuffle(4096).batch(64).prefetch(tf.data.AUTOTUNE)
val_ds = raw_ds.filter(filter_split(1)).batch(64).prefetch(tf.data.AUTOTUNE)
test_ds = raw_ds.filter(filter_split(2)).batch(64).prefetch(tf.data.AUTOTUNE)

In [131]:
# Examine first 3 examples
print("\nFirst 3 examples:")
for i, example in enumerate(train_ds.take(3)):
    # Each example contains 3 components:
    meta_tensor = example[0]  # Metadata (gameId, playId, split_id, firstFrameId)
    x_tensor = example[1]     # Input sequence (padded frames)
    y_tensor = example[2]     # Target vector
    
    print(f"\nExample {i+1}:")
    print("Metadata tensor:", meta_tensor)
    print(f"Metadata values: {meta_tensor.numpy()}")
    print(f"Input shape: {x_tensor.shape} | dtype: {x_tensor.dtype}")
    print(f"Target shape: {y_tensor.shape} | dtype: {y_tensor.dtype}")
    
    # First 5 elements of first frame's features
    print("First frame features (first 5 values):", x_tensor[0, :-5].numpy())
    print("Target values (first 5):", y_tensor[:5].numpy())


First 3 examples:

Example 1:
Metadata tensor: tf.Tensor(
[[2022091100        166          0         80]
 [2022091104       2662          0          1]
 [2022091101       3287          0          1]
 [2022091200        264          0         44]
 [2022091803        592          0          1]
 [2022091805       3235          0         89]
 [2022091807        330          0         67]
 [2022091807        330          0          1]
 [2022091806       2124          0         61]
 [2022091809        695          0          1]
 [2022091200        264          0          1]
 [2022091806       2124          0          1]
 [2022091800        338          0          1]
 [2022091807       1374          0          1]
 [2022091807        735          0          1]
 [2022091807       1241          0         90]
 [2022091100        166          0          1]
 [2022091803        365          0          1]
 [2022091102       2783          0          1]
 [2022091200        264          0          1]
 

In [132]:
print("Dataset cardinality:",
      tf.data.experimental.cardinality(raw_ds).numpy())   # should now print a number


Dataset cardinality: 592871


In [133]:
def drop_meta(meta, x, y):
    return x, y

train_ds = (raw_ds
            .filter(filter_split(0))
            .map(drop_meta, num_parallel_calls=tf.data.AUTOTUNE)
            .shuffle(4096)
            .batch(64)
            .prefetch(tf.data.AUTOTUNE))

val_ds   = (raw_ds
            .filter(filter_split(1))
            .map(drop_meta, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(64)
            .prefetch(tf.data.AUTOTUNE))

test_ds  = (raw_ds
            .filter(filter_split(2))
            .map(drop_meta, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(64)
            .prefetch(tf.data.AUTOTUNE))

# Take one batch from the dataset
for x_batch, y_batch in train_ds.take(1):
    print("x_batch shape:", x_batch.shape)
    print("y_batch shape:", y_batch.shape)

for x_batch, y_batch in val_ds.take(1):
    print("x_batch shape:", x_batch.shape)
    print("y_batch shape:", y_batch.shape)

for x_batch, y_batch in test_ds.take(1):
    print("x_batch shape:", x_batch.shape)
    print("y_batch shape:", y_batch.shape)

x_batch shape: (64, 100, 46)
y_batch shape: (64, 46)
x_batch shape: (64, 100, 46)
y_batch shape: (64, 46)
x_batch shape: (64, 100, 46)
y_batch shape: (64, 46)


In [134]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

NUM_FEATS = 46          # x,y for 23 entities
MAX_LEN  = 100          # same value you used in dataset builder
D_MODEL  = 128          # transformer hidden size
N_HEADS  = 4
N_LAYERS = 4
D_FF     = 512
DROPOUT  = 0.1

In [135]:
# ╔═══════════════════╗
# ║ 2. Positional enc ║  (learnable 1‑D embedding)
# ╚═══════════════════╝
class PositionalEncoding(layers.Layer):
    def __init__(self, max_len, d_model):
        super().__init__()
        self.pos_emb = self.add_weight(
            name="pos_emb",
            shape=(max_len, d_model),
            initializer="uniform",
            trainable=True,
        )

    def call(self, x):
        return x + self.pos_emb


In [136]:
# ╔═══════════════════════════╗
# ║ 3. Padding‑mask function  ║
# ╚═══════════════════════════╝
class PaddingMask(layers.Layer):
    def call(self, x):
        # x:  (B, T, F) — zero‐padded on the left
        pad = tf.reduce_all(tf.equal(x, 0.0), axis=-1)      # → (B, T)
        # reshape to (B, 1, 1, T) for MultiHeadAttention
        return pad[:, tf.newaxis, tf.newaxis, :]



In [137]:
# ╔════════════════════════╗
# ║ 4. Transformer encoder ║
# ╚════════════════════════╝
def transformer_block(d_model, n_heads, d_ff, dropout):
    inputs   = layers.Input(shape=(None, d_model))
    padding  = layers.Input(shape=(1,1,None), dtype=tf.bool)  # mask

    x = layers.MultiHeadAttention(
        num_heads=n_heads, key_dim=d_model//n_heads, dropout=dropout
    )(inputs, inputs, attention_mask=padding)
    x = layers.Dropout(dropout)(x)
    x = layers.LayerNormalization(epsilon=1e-6)(inputs + x)

    y = layers.Dense(d_ff, activation="relu")(x)
    y = layers.Dense(d_model)(y)
    y = layers.Dropout(dropout)(y)
    y = layers.LayerNormalization(epsilon=1e-6)(x + y)

    return keras.Model([inputs, padding], y)


In [138]:
# ╔════════════════════════════════╗
# ║ 5. End‑to‑end prediction model ║
# ╚════════════════════════════════╝
def build_model(
    num_feats=NUM_FEATS,
    max_len=MAX_LEN,
    d_model=D_MODEL,
    n_heads=N_HEADS,
    n_layers=N_LAYERS,
    d_ff=D_FF,
    dropout=DROPOUT,
):
    seq_in  = layers.Input(shape=(max_len, num_feats), name="sequence")   # (B,T,F)

    # Linear projection to d_model
    x = layers.Dense(d_model)(seq_in)

    # Add learnable positional encodings
    x = PositionalEncoding(max_len, d_model)(x)

    # Build padding mask once
    pad_mask = PaddingMask()(seq_in)

    # Stack encoder layers
    for _ in range(n_layers):
        x = transformer_block(d_model, n_heads, d_ff, dropout)([x, pad_mask])

    # We need the hidden state that corresponds to *frame t* (the last row)
    # – that is always index -1 thanks to left padding.
    h_t = layers.Lambda(lambda t: t[:, -1])(x)          # (B, D)

    # Regress the 46 co‑ordinates
    out = layers.Dense(num_feats, name="pred_xy")(h_t)

    return keras.Model(seq_in, out, name="NFL_Frame_Predictor")

model = build_model()
model.summary()


Model: "NFL_Frame_Predictor"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 sequence (InputLayer)          [(None, 100, 46)]    0           []                               
                                                                                                  
 dense_27 (Dense)               (None, 100, 128)     6016        ['sequence[0][0]']               
                                                                                                  
 positional_encoding_3 (Positio  (None, 100, 128)    12800       ['dense_27[0][0]']               
 nalEncoding)                                                                                     
                                                                                                  
 padding_mask_3 (PaddingMask)   (None, 1, 1, 100)    0           ['sequence[0][0

In [None]:
# ╔════════════════════╗
# ║ 6. Compile & train ║
# ╚════════════════════╝

# ── 1)  Make sure we have a place to put checkpoints ─────────────────
WEIGHT_DIR = Path("../weights")
WEIGHT_DIR.mkdir(parents=True, exist_ok=True)

ckpt_cb = keras.callbacks.ModelCheckpoint(
    filepath=(WEIGHT_DIR /
              "epoch_{epoch:03d}-val{val_loss:.4f}.keras").as_posix(),
    monitor="val_loss",
    save_best_only=False,      # save every epoch → “periodic” archive
    save_weights_only=True,    # just the weights, not optimizer state
    verbose=0,
)

# ── 2)  Early-stopping ───────────────────────────────────────────────
early_stop = keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=3,
    restore_best_weights=True,
    verbose=1,
)

# ── 3)  Compile the model ────────────────────────────────────────────
LR = 1e-4
model.compile(
    optimizer=keras.optimizers.Adam(LR),
    loss=keras.losses.MeanSquaredError(),
    metrics=[keras.metrics.MeanAbsoluteError()],
)

# ── 4)  Fit – stop early, save weights each epoch ────────────────────
EPOCHS = 5   # high ceiling; early-stop decides real count
history = model.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
    callbacks=[early_stop, ckpt_cb],
    verbose=1,
)

# optional: evaluate on test set after training
test_loss, test_mae = model.evaluate(test_ds, verbose=1)
print(f"\n✅  Test MSE: {test_loss:.5f}   |   Test MAE: {test_mae:.5f}")


Epoch 1/100
      6/Unknown - 11s 815ms/step - loss: 0.1442 - mean_absolute_error: 0.3017

KeyboardInterrupt: 

In [None]:
# ╔═══════════════╗
# ║ 7. Evaluation ║
# ╚═══════════════╝
# Simple end‑to‑end evaluation on a held‑out batch
for X_batch, y_batch in val_ds.take(1):
    y_pred = model(X_batch)
    mse = tf.reduce_mean(tf.square(y_pred - y_batch))
    print("Validation MSE (batch):", mse.numpy())
