# Baseline **LSTM** Frame Predictor

A companion to the Transformer notebook — shares the same dataset loading,
section layout, and training loop so you can compare apples‑to‑apples.

<small>*(Sections are numbered & boxed exactly like the Transformer notebook.)*</small>

In [4]:
import tensorflow as tf
import numpy as np

# ╔════════════════════╗
# ║ 1. Load  Dataset  ║
# ╚════════════════════╝
DATASET_PATH = "../processed_data/transformer_dataset"  # adjust if needed

ds = (tf.data.Dataset.load(DATASET_PATH)
      .shuffle(4096)
      .batch(64)
      .prefetch(tf.data.AUTOTUNE))

# Peek at one batch
for X, y in ds.take(1):
    print("X shape:", X.shape)
    print("y shape:", y.shape)
    break

X shape: (64, 100, 46)
y shape: (64, 46)


In [5]:
import matplotlib.pyplot as plt

# ╔═══════════════════╗
# ║ 2. Sanity checks ║
# ╚═══════════════════╝
NUM_FEATS = 46
MAX_LEN   = 100

def sanity_check(x, y):
    assert x.shape == (MAX_LEN, NUM_FEATS)
    assert y.shape == (NUM_FEATS,)
    nonzero_rows = np.any(x.numpy() != 0, axis=1)
    pad = MAX_LEN - np.count_nonzero(nonzero_rows)
    print(f"Padding rows: {pad}")
    if np.allclose(y.numpy(), 0):
        print("⚠️ y is all zeros!")

for X_batch, y_batch in ds.take(1):
    sanity_check(X_batch[0], y_batch[0])
    break

Padding rows: 42


In [6]:
from tensorflow import keras
from tensorflow.keras import layers

# ╔══════════════════════╗
# ║ 3. LSTM Model Build ║
# ╚══════════════════════╝
# Hyper-parameters
LSTM_UNITS     = 128          # hidden size
EMBED_DIM      = 4            # position-token embedding
POSITION_VOCAB = 15           # distinct position IDs in data

def mse_sum_xy(y_true, y_pred):
    diff = tf.reshape(y_true - y_pred, (-1, 23, 2))   # (B,23,2)
    return tf.reduce_sum(tf.square(diff), axis=[1, 2])

def build_lstm(
    max_len   = MAX_LEN,
    num_feats = NUM_FEATS,     # 46 coords
    lstm_units=LSTM_UNITS,
    embed_dim =EMBED_DIM,
    pos_vocab =POSITION_VOCAB,
):
    # ── Inputs ───────────────────────────────────────────────────────────────
    coords = layers.Input(shape=(max_len, num_feats), name="coords_seq")        # (B,T,46)
    wh_in  = layers.Input(shape=(23, 2),   name="weight_height")                # (B,23,2)
    pos_id = layers.Input(shape=(23,), dtype="int32", name="position_id")       # (B,23)

    # ── Static-feature embedding & tiling ────────────────────────────────────
    pos_emb = layers.Embedding(pos_vocab, embed_dim,
                               name="position_embedding")(pos_id)               # (B,23,E)
    static  = layers.Concatenate(axis=-1)([wh_in, pos_emb])   # (B,23,2+E)
    static  = layers.Flatten()(static)                        # (B, 23*(2+E))
    static  = layers.RepeatVector(max_len)(static)            # (B,T,*)

    # ── Sequence + static concatenation ─────────────────────────────────────
    lstm_in = layers.Concatenate(axis=-1)([coords, static])   # (B,T,46+138)

    # ── LSTM backbone ───────────────────────────────────────────────────────
    x = layers.LSTM(lstm_units, name="lstm_backbone")(lstm_in)
    x = layers.Dense(lstm_units, activation="relu", name="dense_relu")(x)
    out = layers.Dense(num_feats, name="pred_xy")(x)          # (B,46)

    model = keras.Model([coords, wh_in, pos_id], out, name="Baseline_LSTM")
    model.compile(optimizer=keras.optimizers.Adam(1e-4),
                  loss=mse_sum_xy,
                  metrics=[keras.metrics.MeanAbsoluteError()])
    return model

model = build_lstm()
model.summary()


In [8]:
# ╔══════════════════════╗
# ║ 5. Compile & Train  ║
# ╚══════════════════════╝
EPOCHS = 15
val_split = 0.05
val_ds = ds.take(int(len(ds)*val_split))
train_ds = ds.skip(int(len(ds)*val_split))

history = model.fit(train_ds, epochs=EPOCHS, validation_data=val_ds, verbose=2)

Epoch 1/15


ValueError: Layer "Baseline_LSTM" expects 3 input(s), but it received 1 input tensors. Inputs received: [<tf.Tensor 'data:0' shape=(None, 100, 46) dtype=float32>]

In [None]:
# ╔═══════════════╗
# ║ 6. Evaluation ║
# ╚═══════════════╝
for X_batch, y_batch in val_ds.take(1):
    y_pred = model(X_batch)
    mse = tf.reduce_mean(tf.square(y_pred - y_batch))
    print("Validation MSE (batch):", mse.numpy())

In [None]:
# ╔════════════════╗
# ║ 7. Curves Plot ║
# ╚════════════════╝
plt.figure()
plt.plot(history.history['loss'], label='train')
plt.plot(history.history['val_loss'], label='val')
plt.title('Loss curve')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()