# SnackTrack ML --- RNN/GRU Training for Meal Sequence Prediction

This notebook trains a **GRU-based Recurrent Neural Network** to predict a user's next meal based on
their chronological eating history. The RNN captures temporal patterns such as:

- **Time-of-day preferences** (e.g., lighter breakfasts, hearty dinners)
- **Weekly meal rotation** (e.g., Taco Tuesday, Fish Friday)
- **Seasonal and progressive dietary changes** over time

### Architecture

| Component | Dimension | Description |
|-----------|-----------|-------------|
| Input | 39D | 32D recipe embedding + 7D cyclical time features |
| GRU Hidden | 64D | Custom GRU cell matching the NumPy production code |
| Output | 32D | Predicted next-meal recipe embedding |

### Why a custom GRU cell?

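PyTorch's built-in `nn.GRUCell` applies the reset gate **after** the hidden-to-hidden
multiplication. Its candidate state is `n = tanh(W_in @ x + b_in + r * (W_hn @ h_prev + b_hn))`.
Our production NumPy code applies it **before** the multiplication:
`h_candidate = tanh(x @ W_h + (r * h_prev) @ U_h + b_h)`. Because `r` differs from one hidden
unit to the next, `r * (h_prev @ U_h)` and `(r * h_prev) @ U_h` are different functions, so
weights trained with `nn.GRUCell` would produce different outputs in production. `nn.GRUCell`
also uses the opposite update-gate convention (`h_new = (1 - z) * n + z * h_prev`) and packs
all three gates into single `weight_ih` / `weight_hh` tensors. We therefore implement a custom
GRU cell whose forward pass is an exact PyTorch translation of the NumPy `gru_step()`.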
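Here is a minimal NumPy sketch of the two reset-gate placements. It uses hypothetical random toy values, not production weights. Gating before the multiply and gating after it generally give different vectors. They coincide only when every entry of `r` is the same.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_dim = 4

# Hypothetical toy values: previous hidden state, reset-gate activation,
# and a hidden-to-hidden matrix in the (in, out) layout used by gru_step().
h_prev = rng.standard_normal(hidden_dim)
r = 1.0 / (1.0 + np.exp(-rng.standard_normal(hidden_dim)))  # sigmoid, values in (0, 1)
U_h = rng.standard_normal((hidden_dim, hidden_dim))

# Production NumPy gru_step(): reset gate applied BEFORE the matrix multiply.
reset_before = (r * h_prev) @ U_h

# nn.GRUCell style: reset gate applied AFTER the matrix multiply
# (the b_hn bias is omitted here to isolate the ordering effect).
reset_after = r * (h_prev @ U_h)

print("reset before multiply:", np.round(reset_before, 3))
print("reset after multiply: ", np.round(reset_after, 3))
print("equal?", np.allclose(reset_before, reset_after))
```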

### Critical note

The production RNN (`app/recommender/rnn.py`) currently uses **random placeholder weights**.
This notebook trains real weights from data and exports them in the exact format expected
by `utils/weight_io.py`.

In [None]:
import sys
import os
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from tqdm.auto import tqdm

warnings.filterwarnings("ignore", category=FutureWarning)

# ---------------------------------------------------------------------------
# Path setup: make notebook utils importable
# ---------------------------------------------------------------------------
sys.path.insert(0, "..")

from notebooks.utils.plot_helpers import setup_plot_style, plot_loss_curves, SNACKTRACK_COLORS, PALETTE
from notebooks.utils.data_loader import load_kaggle_dataset, _encode_time_features
from notebooks.utils.weight_io import save_rnn_weights, load_rnn_weights, RNN_WEIGHT_SHAPES

setup_plot_style()

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"PyTorch {torch.__version__} | Device: {DEVICE}")
print(f"NumPy {np.__version__} | Pandas {pd.__version__}")

## 1. Prepare Sequence Data

We build user meal sequences from two complementary data sources:

1. **Food.com interactions** (`foodcom_interactions`) --- large-scale user-recipe interaction
   logs with timestamps, providing real chronological meal sequences.
2. **Daily Food Nutrition** (`daily_food_nutrition`) --- daily food logs with meal-level
   detail including meal type and nutritional information.

For each meal event we construct a **39-dimensional input vector**:
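- **32D recipe embedding**: a normalized feature-based proxy built from the row's nutrition
  columns (calories, protein, carbohydrates, fat, sodium, fiber, sugar). These fill the first
  7 dimensions, and the remaining 25 dimensions are zero-padded.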
- **7D time features**: cyclical encoding of day-of-week (sin/cos), hour (sin/cos),
  meal type (sin/cos), and a normalized meal-type index.

The time encoding uses sin/cos pairs so that the model understands that
Sunday (6) and Monday (0) are adjacent, and 23:00 is close to 00:00.
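A quick NumPy check of the adjacency claim follows. `cyclical()` is a hypothetical helper that reproduces only the sin/cos step of `encode_time_features`.

```python
import numpy as np

def cyclical(value, period):
    """Map an integer position on a cycle of length `period` to a (sin, cos) point."""
    angle = 2 * np.pi * value / period
    return np.array([np.sin(angle), np.cos(angle)])

# Day of week (Monday=0 ... Sunday=6, as returned by datetime.weekday()).
sun_mon = np.linalg.norm(cyclical(6, 7) - cyclical(0, 7))
mon_tue = np.linalg.norm(cyclical(0, 7) - cyclical(1, 7))
print(f"Sunday->Monday distance:  {sun_mon:.4f}")
print(f"Monday->Tuesday distance: {mon_tue:.4f}")

# Hour of day: 23:00 -> 00:00 is as close as 00:00 -> 01:00,
# even though the raw hour values differ by 23.
h23_h0 = np.linalg.norm(cyclical(23, 24) - cyclical(0, 24))
h0_h1 = np.linalg.norm(cyclical(0, 24) - cyclical(1, 24))
print(f"23:00->00:00 distance: {h23_h0:.4f}")
print(f"00:00->01:00 distance: {h0_h1:.4f}")
```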

We then create **sliding windows of length 20**, where the target for each
window is the 32D recipe embedding of the *next* meal. The train/validation
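split is done **by user** (not by window) to prevent data leakage. Adjacent windows
from the same user share 19 of their 20 meals, so a window-level split would put
near-duplicate inputs in both the training and validation sets.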
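The following self-contained sketch runs the windowing and the user-level split on toy data. `make_windows` is a hypothetical helper that mirrors the loop in the data-loading cell. The 25-meal user and the ten user IDs are invented for illustration.

```python
import numpy as np

SEQ_LEN = 20

def make_windows(meal_vectors, seq_len=SEQ_LEN, target_dim=32):
    """Each input is `seq_len` consecutive meals; each target is the first
    `target_dim` features (the recipe embedding) of the meal that follows."""
    inputs, targets = [], []
    for i in range(len(meal_vectors) - seq_len):
        inputs.append(meal_vectors[i : i + seq_len])
        targets.append(meal_vectors[i + seq_len][:target_dim])
    return np.array(inputs), np.array(targets)

# Hypothetical user with 25 logged meals, each a 39D vector.
rng = np.random.default_rng(0)
meals = rng.standard_normal((25, 39))
X, y = make_windows(meals)
print(X.shape, y.shape)  # (5, 20, 39) (5, 32): 25 - 20 = 5 windows

# Split user IDs (not windows) 80/20, so no user contributes to both sets.
users = [f"user_{k}" for k in range(10)]
order = rng.permutation(len(users))
users = [users[k] for k in order]
split = int(len(users) * 0.8)
train_users, val_users = set(users[:split]), set(users[split:])
print(len(train_users), len(val_users), train_users & val_users)
```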

In [None]:
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
SEQ_LEN = 20           # sliding window length
RECIPE_EMB_DIM = 32    # recipe embedding dimensionality
TIME_FEAT_DIM = 7      # cyclical time features
INPUT_DIM = RECIPE_EMB_DIM + TIME_FEAT_DIM  # 39
VAL_FRACTION = 0.2     # fraction of users held out for validation
RANDOM_SEED = 42

np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)  # also seed PyTorch weight init and DataLoader shuffling


# ---------------------------------------------------------------------------
# Time feature encoding (matches production MealSequenceRNN.encode_time_features)
# ---------------------------------------------------------------------------
def encode_time_features(logged_at, meal_type):
    """Encode temporal context into a 7D vector with cyclical sin/cos pairs.

    This MUST match the production code in app/recommender/rnn.py exactly.
    """
    dow = logged_at.weekday() if hasattr(logged_at, "weekday") else 0
    hour = logged_at.hour if hasattr(logged_at, "hour") else 12

    day_sin = np.sin(2 * np.pi * dow / 7)
    day_cos = np.cos(2 * np.pi * dow / 7)
    hour_sin = np.sin(2 * np.pi * hour / 24)
    hour_cos = np.cos(2 * np.pi * hour / 24)

    meal_types = {"breakfast": 0, "lunch": 1, "dinner": 2, "snack": 3}
    meal_idx = meal_types.get(meal_type, 1)
    meal_sin = np.sin(2 * np.pi * meal_idx / 4)
    meal_cos = np.cos(2 * np.pi * meal_idx / 4)

    return np.array([day_sin, day_cos, hour_sin, hour_cos,
                     meal_sin, meal_cos, meal_idx / 3.0])


# ---------------------------------------------------------------------------
# Recipe embedding helper
# ---------------------------------------------------------------------------
def _first_numeric(row, *keys):
    """Return the first non-missing value among `keys` as a float, else 0.0.

    A plain `row.get(key) or 0` chain lets NaN through (NaN is truthy),
    so missing values are checked explicitly.
    """
    for key in keys:
        value = row.get(key)
        if value is not None and pd.notna(value):
            return float(value)
    return 0.0


def make_recipe_embedding(row):
    """Create a 32D recipe embedding from available nutritional features.

    Uses a normalized feature-based proxy: the first 7 dimensions encode
    calories and key nutrients divided by fixed scale constants, and the
    remaining 25 dimensions are zero-padded.
    """
    emb = np.zeros(RECIPE_EMB_DIM, dtype=np.float64)
    emb[0] = _first_numeric(row, "calories") / 1000.0
    emb[1] = _first_numeric(row, "protein_g", "protein") / 100.0
    emb[2] = _first_numeric(row, "carbs_g", "carbohydrate", "carbs") / 200.0
    emb[3] = _first_numeric(row, "fat_g", "total_fat", "fat") / 100.0
    emb[4] = _first_numeric(row, "sodium_mg", "sodium") / 2300.0
    emb[5] = _first_numeric(row, "fiber_g", "fiber") / 30.0
    emb[6] = _first_numeric(row, "sugar_g", "sugar") / 50.0
    return emb


# ---------------------------------------------------------------------------
# Load datasets
# ---------------------------------------------------------------------------
print("Loading datasets...")

user_sequence_map = {}  # user_id -> list of (input_window, target)

# --- Source 1: Food.com interactions ---
try:
    interactions_df = load_kaggle_dataset("foodcom_interactions")
    print(f"  Food.com interactions: {len(interactions_df):,} rows")

    # Parse date column
    date_col = None
    for col in ["date", "submitted", "created_at", "interaction_date"]:
        if col in interactions_df.columns:
            date_col = col
            break
    if date_col is None:
        # Try to find any date-like column
        for col in interactions_df.columns:
            if "date" in col.lower() or "time" in col.lower():
                date_col = col
                break

    if date_col:
        interactions_df["timestamp"] = pd.to_datetime(interactions_df[date_col], errors="coerce")
    else:
        # Assign synthetic timestamps so we can still train
        interactions_df["timestamp"] = pd.date_range(
            start="2020-01-01", periods=len(interactions_df), freq="h"
        )
    interactions_df = interactions_df.dropna(subset=["timestamp"])

    # Identify user and recipe columns
    user_col = next((c for c in ["user_id", "author_id", "contributor_id"] if c in interactions_df.columns), None)
    recipe_col = next((c for c in ["recipe_id", "id"] if c in interactions_df.columns), None)

    if user_col and recipe_col:
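        # Identify available nutrition columns
        nutrition_present = [c for c in ["calories", "protein", "protein_g",
                                         "carbohydrate", "carbs_g", "total_fat",
                                         "fat_g", "sodium", "sodium_mg",
                                         "fiber", "fiber_g", "sugar", "sugar_g"]
                            if c in interactions_df.columns]
        if not nutrition_present:
            print("  WARNING: no nutrition columns in the interactions table; "
                  "Food.com recipe embeddings will be all zeros.")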

        # Infer meal type from hour. If the timestamps are date-only, every hour
        # is 0 and every meal is inferred as "snack".
        def infer_meal_type(hour):
            if 5 <= hour < 11:
                return "breakfast"
            elif 11 <= hour < 15:
                return "lunch"
            elif 15 <= hour < 21:
                return "dinner"
            return "snack"

        # Group by user and build sequences
        user_groups = interactions_df.sort_values("timestamp").groupby(user_col)

        for uid, group in tqdm(user_groups, desc="Building Food.com sequences", leave=False):
            if len(group) < SEQ_LEN + 1:
                continue

            meal_vectors = []
            for _, row in group.iterrows():
                emb = make_recipe_embedding(row)
                ts = row["timestamp"]
                mt = infer_meal_type(ts.hour if hasattr(ts, "hour") else 12)
                time_feat = encode_time_features(ts, mt)
                meal_vectors.append(np.concatenate([emb, time_feat]))

            # Sliding windows
            windows = []
            for i in range(len(meal_vectors) - SEQ_LEN):
                inp = np.array(meal_vectors[i : i + SEQ_LEN])  # (20, 39)
                target = meal_vectors[i + SEQ_LEN][:RECIPE_EMB_DIM]  # (32,)
                windows.append((inp, target))

            if windows:
                user_key = f"foodcom_{uid}"
                user_sequence_map[user_key] = windows

    print(f"  -> {len(user_sequence_map)} users with enough history from Food.com")

except FileNotFoundError:
    print("  Food.com interactions not found. Skipping.")

# --- Source 2: Daily Food Nutrition ---
try:
    nutrition_df = load_kaggle_dataset("daily_food_nutrition")
    print(f"  Daily Food Nutrition: {len(nutrition_df):,} rows")

    # Parse date/time
    date_col = None
    for col in ["date", "log_date", "day", "timestamp"]:
        if col in nutrition_df.columns:
            date_col = col
            break

    if date_col:
        nutrition_df["timestamp"] = pd.to_datetime(nutrition_df[date_col], errors="coerce")
    else:
        nutrition_df["timestamp"] = pd.date_range(
            start="2021-01-01", periods=len(nutrition_df), freq="h"
        )
    nutrition_df = nutrition_df.dropna(subset=["timestamp"])

    # Identify columns
    user_col_n = next((c for c in ["user_id", "user", "person", "name"]
                       if c in nutrition_df.columns), None)
    meal_col = next((c for c in ["meal_type", "meal", "category"]
                     if c in nutrition_df.columns), None)

    if user_col_n is None:
        # Fallback: treat each block of 100 rows (by index) as one synthetic user
        nutrition_df["user_id_synth"] = nutrition_df.index // 100
        user_col_n = "user_id_synth"

    user_groups_n = nutrition_df.sort_values("timestamp").groupby(user_col_n)

    n_before = len(user_sequence_map)
    for uid, group in tqdm(user_groups_n, desc="Building nutrition sequences", leave=False):
        if len(group) < SEQ_LEN + 1:
            continue

        meal_vectors = []
        for _, row in group.iterrows():
            emb = make_recipe_embedding(row)
            ts = row["timestamp"]
            mt = str(row[meal_col]).strip().lower() if meal_col and pd.notna(row.get(meal_col)) else "lunch"
            time_feat = encode_time_features(ts, mt)
            meal_vectors.append(np.concatenate([emb, time_feat]))

        windows = []
        for i in range(len(meal_vectors) - SEQ_LEN):
            inp = np.array(meal_vectors[i : i + SEQ_LEN])
            target = meal_vectors[i + SEQ_LEN][:RECIPE_EMB_DIM]
            windows.append((inp, target))

        if windows:
            user_key = f"nutrition_{uid}"
            user_sequence_map[user_key] = windows

    print(f"  -> {len(user_sequence_map) - n_before} additional users from Daily Food Nutrition")

except FileNotFoundError:
    print("  Daily Food Nutrition not found. Skipping.")

# ---------------------------------------------------------------------------
# Fallback: generate synthetic sequences for development/testing
# ---------------------------------------------------------------------------
if len(user_sequence_map) == 0:
    print("\n  WARNING: No real datasets available. Generating synthetic sequences.")
    print("  Run notebook 00 first to download the datasets for real training.")

    rng = np.random.default_rng(RANDOM_SEED)
    N_SYNTH_USERS = 200
    MEALS_PER_USER = 60

    for u in range(N_SYNTH_USERS):
        # Each user has a latent "taste" vector that slowly drifts
        taste = rng.standard_normal(RECIPE_EMB_DIM) * 0.3
        meal_vectors = []
        base_date = pd.Timestamp("2023-01-01")

        for m in range(MEALS_PER_USER):
            # Drift taste slightly
            taste += rng.standard_normal(RECIPE_EMB_DIM) * 0.01
            emb = taste + rng.standard_normal(RECIPE_EMB_DIM) * 0.1

            ts = base_date + pd.Timedelta(hours=m * 6)  # 4 meals/day at 00:00, 06:00, 12:00, 18:00
            hour = ts.hour
            if 5 <= hour < 11:
                mt = "breakfast"
            elif 11 <= hour < 15:
                mt = "lunch"
            elif 15 <= hour < 21:
                mt = "dinner"
            else:
                mt = "snack"

            time_feat = encode_time_features(ts, mt)
            meal_vectors.append(np.concatenate([emb, time_feat]))

        windows = []
        for i in range(len(meal_vectors) - SEQ_LEN):
            inp = np.array(meal_vectors[i : i + SEQ_LEN])
            target = meal_vectors[i + SEQ_LEN][:RECIPE_EMB_DIM]
            windows.append((inp, target))

        user_sequence_map[f"synth_{u}"] = windows

    print(f"  Generated {N_SYNTH_USERS} synthetic users.")

print(f"\nTotal users with sequences: {len(user_sequence_map)}")


# ---------------------------------------------------------------------------
# Train/val split by user
# ---------------------------------------------------------------------------
all_users = sorted(user_sequence_map.keys())
np.random.shuffle(all_users)
split_idx = int(len(all_users) * (1 - VAL_FRACTION))
train_users = set(all_users[:split_idx])
val_users = set(all_users[split_idx:])

train_inputs, train_targets = [], []
val_inputs, val_targets = [], []

for user, windows in user_sequence_map.items():
    for inp, target in windows:
        if user in train_users:
            train_inputs.append(inp)
            train_targets.append(target)
        else:
            val_inputs.append(inp)
            val_targets.append(target)

X_train = np.array(train_inputs, dtype=np.float32)
y_train = np.array(train_targets, dtype=np.float32)
X_val = np.array(val_inputs, dtype=np.float32)
y_val = np.array(val_targets, dtype=np.float32)

print(f"Training:   {X_train.shape[0]:,} sequences from {len(train_users)} users")
print(f"Validation: {X_val.shape[0]:,} sequences from {len(val_users)} users")
print(f"Input shape:  {X_train.shape}")
print(f"Target shape: {y_train.shape}")


# ---------------------------------------------------------------------------
# Create PyTorch DataLoaders
# ---------------------------------------------------------------------------
class MealSequenceDataset(Dataset):
    def __init__(self, sequences, targets):
        self.sequences = torch.tensor(sequences, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.float32)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx], self.targets[idx]


BATCH_SIZE = 128

train_dataset = MealSequenceDataset(X_train, y_train)
val_dataset = MealSequenceDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

print(f"\nTrain batches: {len(train_loader)} | Val batches: {len(val_loader)}")

In [None]:
# ---------------------------------------------------------------------------
# Sequence statistics
# ---------------------------------------------------------------------------
seq_counts = {u: len(w) for u, w in user_sequence_map.items()}
total_sequences = sum(seq_counts.values())
unique_users = len(seq_counts)
avg_seq_per_user = np.mean(list(seq_counts.values()))
median_seq_per_user = np.median(list(seq_counts.values()))
max_seq_per_user = max(seq_counts.values())
min_seq_per_user = min(seq_counts.values())

print(f"Sequence Statistics")
print(f"{'=' * 45}")
print(f"  Total sequences:           {total_sequences:,}")
print(f"  Unique users:              {unique_users:,}")
print(f"  Avg sequences per user:    {avg_seq_per_user:.1f}")
print(f"  Median sequences per user: {median_seq_per_user:.1f}")
print(f"  Min / Max per user:        {min_seq_per_user} / {max_seq_per_user}")
print(f"  Train users / Val users:   {len(train_users)} / {len(val_users)}")
print(f"  Sequence length (window):  {SEQ_LEN}")
print(f"  Input dim:                 {INPUT_DIM}")
print(f"  Target dim:                {RECIPE_EMB_DIM}")

# Distribution plot
fig, ax = plt.subplots(figsize=(10, 4))
ax.hist(list(seq_counts.values()), bins=50, color=SNACKTRACK_COLORS["primary"], alpha=0.8,
        edgecolor="white", linewidth=0.5)
ax.set_xlabel("Sequences per User")
ax.set_ylabel("Number of Users")
ax.set_title("Distribution of Sequence Count per User")
ax.axvline(avg_seq_per_user, color=SNACKTRACK_COLORS["secondary"], linestyle="--",
           label=f"Mean = {avg_seq_per_user:.1f}")
ax.legend()
plt.tight_layout()
plt.show()

## 2. Define Custom GRU Model

We implement a custom `CustomGRUCell` whose `forward()` method is a line-by-line
PyTorch translation of the NumPy `MealSequenceRNN.gru_step()`:

```
z = sigmoid(x @ Wz + h_prev @ Uz + bz)           # update gate
r = sigmoid(x @ Wr + h_prev @ Ur + br)            # reset gate
h_candidate = tanh(x @ Wh + (r * h_prev) @ Uh + bh)  # candidate
h_new = (1 - z) * h_prev + z * h_candidate         # new hidden state
```

The weight matrices are defined as `nn.Parameter` with the **same shapes** as
production (`Wz/Wr/Wh: (39, 64)`, `Uz/Ur/Uh: (64, 64)`, `bz/br/bh: (64,)`).
After training, the weights can be exported directly without transposition
(except for the output projection `nn.Linear`, which stores `weight` as
`(out_features, in_features)`).
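For reference, here is a NumPy sketch of what the four equations compute when unrolled over one 20-step window. The unrolled state then passes through the output projection `h @ Wo + bo`. The weights are hypothetical random values with the production shapes. The parameter count is 3 × (39·64 + 64·64 + 64) + 64·32 + 32 = 22,048, which should match the total printed by the PyTorch model summary.

```python
import numpy as np

INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, SEQ_LEN = 39, 64, 32, 20

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x, h_prev, p):
    """One GRU step in the (in, out) weight layout described above."""
    z = sigmoid(x @ p["Wz"] + h_prev @ p["Uz"] + p["bz"])                  # update gate
    r = sigmoid(x @ p["Wr"] + h_prev @ p["Ur"] + p["br"])                  # reset gate
    h_candidate = np.tanh(x @ p["Wh"] + (r * h_prev) @ p["Uh"] + p["bh"])  # candidate
    return (1 - z) * h_prev + z * h_candidate                              # new hidden state

# Hypothetical random weights with the production shapes.
rng = np.random.default_rng(0)
params = {}
for gate in "zrh":
    params[f"W{gate}"] = rng.standard_normal((INPUT_DIM, HIDDEN_DIM)) * 0.1
    params[f"U{gate}"] = rng.standard_normal((HIDDEN_DIM, HIDDEN_DIM)) * 0.1
    params[f"b{gate}"] = np.zeros(HIDDEN_DIM)
params["Wo"] = rng.standard_normal((HIDDEN_DIM, OUTPUT_DIM)) * 0.1
params["bo"] = np.zeros(OUTPUT_DIM)

# Unroll over one 20-step sequence, starting from a zero hidden state.
sequence = rng.standard_normal((SEQ_LEN, INPUT_DIM))
h = np.zeros(HIDDEN_DIM)
for x_t in sequence:
    h = gru_step(x_t, h, params)
prediction = h @ params["Wo"] + params["bo"]

n_params = sum(a.size for a in params.values())
print(prediction.shape, n_params)
```

Each new hidden state is a convex combination of `h_prev` and a `tanh` output. Starting from zero, `h` therefore stays strictly inside (-1, 1).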

In [None]:
class CustomGRUCell(nn.Module):
    """Custom GRU cell that exactly matches the NumPy production implementation.

    Key difference from nn.GRUCell:
      - nn.GRUCell computes: n = tanh(W_in @ x + b_in + r * (W_hn @ h_prev + b_hn))
      - Our code computes:   h_candidate = tanh(x @ Wh + (r * h_prev) @ Uh + bh)

    Here the reset gate is applied element-wise to h_prev BEFORE the matrix
    multiply with Uh. nn.GRUCell instead applies r to the result of the
    hidden-to-hidden product (W_hn @ h_prev + b_hn), so the two cells are not
    interchangeable even with identical weights.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Update gate parameters
        self.Wz = nn.Parameter(torch.randn(input_dim, hidden_dim) * 0.1)
        self.Uz = nn.Parameter(torch.randn(hidden_dim, hidden_dim) * 0.1)
        self.bz = nn.Parameter(torch.zeros(hidden_dim))

        # Reset gate parameters
        self.Wr = nn.Parameter(torch.randn(input_dim, hidden_dim) * 0.1)
        self.Ur = nn.Parameter(torch.randn(hidden_dim, hidden_dim) * 0.1)
        self.br = nn.Parameter(torch.zeros(hidden_dim))

        # Candidate hidden state parameters
        self.Wh = nn.Parameter(torch.randn(input_dim, hidden_dim) * 0.1)
        self.Uh = nn.Parameter(torch.randn(hidden_dim, hidden_dim) * 0.1)
        self.bh = nn.Parameter(torch.zeros(hidden_dim))

    def forward(self, x, h_prev):
        """Single GRU step: exactly mirrors NumPy gru_step().

        Args:
            x: input at current timestep, shape (batch, input_dim)
            h_prev: previous hidden state, shape (batch, hidden_dim)

        Returns:
            h_new: updated hidden state, shape (batch, hidden_dim)
        """
        z = torch.sigmoid(x @ self.Wz + h_prev @ self.Uz + self.bz)
        r = torch.sigmoid(x @ self.Wr + h_prev @ self.Ur + self.br)
        h_candidate = torch.tanh(x @ self.Wh + (r * h_prev) @ self.Uh + self.bh)
        h_new = (1 - z) * h_prev + z * h_candidate
        return h_new


class MealSequenceModel(nn.Module):
    """Full model: CustomGRU unrolled over the sequence + linear output projection."""

    INPUT_DIM = 39
    HIDDEN_DIM = 64
    OUTPUT_DIM = 32

    def __init__(self):
        super().__init__()
        self.gru_cell = CustomGRUCell(self.INPUT_DIM, self.HIDDEN_DIM)
        self.output_proj = nn.Linear(self.HIDDEN_DIM, self.OUTPUT_DIM)

    def forward(self, sequence):
        """Process a batch of sequences and predict next-meal embedding.

        Args:
            sequence: (batch_size, seq_len, 39)

        Returns:
            predicted embedding: (batch_size, 32)
        """
        batch_size = sequence.size(0)
        h = torch.zeros(batch_size, self.HIDDEN_DIM, device=sequence.device)

        for t in range(sequence.size(1)):
            h = self.gru_cell(sequence[:, t, :], h)

        return self.output_proj(h)


# Instantiate
model = MealSequenceModel().to(DEVICE)

# Print parameter summary
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Model Architecture")
print("=" * 55)
print(f"{'Parameter':<25} {'Shape':<20} {'Count':>10}")
print("-" * 55)
for name, param in model.named_parameters():
    print(f"  {name:<23} {str(tuple(param.shape)):<20} {param.numel():>10,}")
print("-" * 55)
print(f"  {'Total':<23} {'':<20} {total_params:>10,}")
print(f"  {'Trainable':<23} {'':<20} {trainable_params:>10,}")

## 3. Training Loop

Training details:

| Hyperparameter | Value | Rationale |
|---------------|-------|-----------|
| Loss | MSE | Regression on continuous recipe embeddings |
| Optimizer | Adam | Adaptive LR, good for RNNs |
| Learning rate | 1e-3 | Standard starting point for Adam |
| Gradient clipping | max_norm=1.0 | Prevents exploding gradients in long sequences |
| Epochs | 150 max | Upper bound; early stopping usually triggers earlier |
| Early stopping | patience=15 | Stop if val loss hasn't improved in 15 epochs |
| LR scheduler | ReduceLROnPlateau(patience=10) | Halve LR when val loss stagnates |
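The gradient-clipping row refers to global-norm clipping. As an illustration, here is a NumPy restatement of the rule `torch.nn.utils.clip_grad_norm_` applies: a single scale factor shared by all parameters. `clip_by_global_norm` is a hypothetical helper, not part of this codebase.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Rescale a list of gradient arrays so their combined L2 norm is at most max_norm.

    Every gradient is multiplied by the same factor
    min(1, max_norm / (total_norm + eps)), so the update direction is preserved.
    """
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm

# Hypothetical "exploding" gradients for two parameters.
big = [np.full((2, 2), 3.0), np.full(3, 4.0)]
clipped, norm_before = clip_by_global_norm(big)
norm_after = np.sqrt(sum(np.sum(g ** 2) for g in clipped))
print(f"global norm before: {norm_before:.3f}  after: {norm_after:.3f}")

# Gradients already below max_norm pass through unchanged.
small = [np.full(3, 0.1)]
unchanged, _ = clip_by_global_norm(small)
```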

In [None]:
# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
NUM_EPOCHS = 150
LEARNING_RATE = 1e-3
GRAD_CLIP_NORM = 1.0
EARLY_STOP_PATIENCE = 15
SCHEDULER_PATIENCE = 10
SCHEDULER_FACTOR = 0.5

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# `verbose` is omitted: it is deprecated in recent PyTorch releases, and False is the default
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", patience=SCHEDULER_PATIENCE, factor=SCHEDULER_FACTOR
)

# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
train_losses = []
val_losses = []
lr_history = []
best_val_loss = float("inf")
best_epoch = 0
patience_counter = 0
best_state = None

print(f"Training for up to {NUM_EPOCHS} epochs...")
print(f"  LR={LEARNING_RATE}, GradClip={GRAD_CLIP_NORM}, "
      f"EarlyStop={EARLY_STOP_PATIENCE}, SchedulerPatience={SCHEDULER_PATIENCE}")
print()

pbar = tqdm(range(1, NUM_EPOCHS + 1), desc="Training")

for epoch in pbar:
    # --- Train ---
    model.train()
    epoch_train_loss = 0.0
    n_train_batches = 0

    for X_batch, y_batch in train_loader:
        X_batch = X_batch.to(DEVICE)
        y_batch = y_batch.to(DEVICE)

        optimizer.zero_grad()
        predictions = model(X_batch)
        loss = criterion(predictions, y_batch)
        loss.backward()

        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)

        optimizer.step()

        epoch_train_loss += loss.item()
        n_train_batches += 1

    avg_train_loss = epoch_train_loss / max(n_train_batches, 1)

    # --- Validate ---
    model.eval()
    epoch_val_loss = 0.0
    n_val_batches = 0

    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            X_batch = X_batch.to(DEVICE)
            y_batch = y_batch.to(DEVICE)

            predictions = model(X_batch)
            loss = criterion(predictions, y_batch)

            epoch_val_loss += loss.item()
            n_val_batches += 1

    avg_val_loss = epoch_val_loss / max(n_val_batches, 1)

    # --- LR scheduler ---
    current_lr = optimizer.param_groups[0]["lr"]
    scheduler.step(avg_val_loss)

    # --- Bookkeeping ---
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    lr_history.append(current_lr)

    # Early stopping check
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_epoch = epoch
        patience_counter = 0
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
    else:
        patience_counter += 1

    pbar.set_postfix({
        "train": f"{avg_train_loss:.6f}",
        "val": f"{avg_val_loss:.6f}",
        "lr": f"{current_lr:.1e}",
        "best": f"{best_val_loss:.6f}",
        "pat": f"{patience_counter}/{EARLY_STOP_PATIENCE}",
    })

    if patience_counter >= EARLY_STOP_PATIENCE:
        print(f"\nEarly stopping at epoch {epoch}. "
              f"Best val loss: {best_val_loss:.6f} at epoch {best_epoch}.")
        break

# Restore best weights
if best_state is not None:
    model.load_state_dict(best_state)
    model = model.to(DEVICE)

print(f"\nTraining complete.")
print(f"  Best epoch:    {best_epoch}")
print(f"  Best val loss: {best_val_loss:.6f}")
print(f"  Final LR:      {lr_history[-1]:.1e}")
print(f"  Total epochs:  {len(train_losses)}")

## 4. Training Visualization

We plot two charts:
1. **Loss curves**: train vs validation loss over epochs, with the best epoch marked.
2. **Learning rate schedule**: shows when the `ReduceLROnPlateau` scheduler reduced the LR.

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(16, 5))

# --- Loss curves ---
ax = axes[0]
epochs_range = range(1, len(train_losses) + 1)
ax.plot(epochs_range, train_losses, label="Train Loss",
        color=SNACKTRACK_COLORS["primary"], linewidth=2)
ax.plot(epochs_range, val_losses, label="Val Loss",
        color=SNACKTRACK_COLORS["secondary"], linewidth=2)
ax.axvline(best_epoch, color=SNACKTRACK_COLORS["accent"], linestyle="--",
           alpha=0.7, label=f"Best epoch ({best_epoch})")
ax.set_xlabel("Epoch")
ax.set_ylabel("MSE Loss")
ax.set_title("Training vs Validation Loss")
ax.legend()
ax.grid(True, alpha=0.3)

# --- Learning rate schedule ---
ax = axes[1]
ax.plot(epochs_range, lr_history, color=SNACKTRACK_COLORS["accent"], linewidth=2)
ax.set_xlabel("Epoch")
ax.set_ylabel("Learning Rate")
ax.set_title("Learning Rate Schedule")
ax.set_yscale("log")
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Also use the utility function for a standalone loss plot
fig2 = plot_loss_curves(train_losses, val_losses,
                        title="RNN/GRU Training Progress")
plt.show()

## 5. Prediction Quality

We evaluate the trained model on the validation set by:

1. **Cosine similarity** between predicted and actual next-meal embeddings ---
   this is the metric that directly affects recommendation quality in production,
   since the RNN output is matched against recipe embeddings via cosine similarity.
2. **Sample predictions** showing side-by-side predicted vs actual vectors.
3. **Histogram** of cosine similarities across the full validation set.
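Cosine similarity and the MSE training loss can disagree, because cosine ignores vector magnitude. The small NumPy sketch below uses two hypothetical predictions for one invented 4D target. One prediction points the right way but is three times too long. The other has the right length but points elsewhere.

```python
import numpy as np

def cosine_similarity_batch(a, b):
    """Row-wise cosine similarity; rows with a zero norm score 0."""
    dot = np.sum(a * b, axis=1)
    denom = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
    return dot / np.where(denom > 0, denom, 1.0)

actual = np.array([[0.4, 0.2, 0.0, 0.1]])
pred_scaled = 3.0 * actual                        # right direction, wrong magnitude
pred_rotated = np.array([[0.1, 0.0, 0.4, 0.2]])   # same magnitude, different direction

for name, pred in [("scaled", pred_scaled), ("rotated", pred_rotated)]:
    cos = cosine_similarity_batch(pred, actual)[0]
    mse = np.mean((pred - actual) ** 2)
    print(f"{name:<8} cosine={cos:.3f}  mse={mse:.4f}")
```

MSE prefers the rotated prediction, while cosine similarity rates the scaled one as perfect. Production matching uses cosine similarity, so the cosine numbers are the more direct signal of recommendation quality, even though MSE drives training.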

In [None]:
# ---------------------------------------------------------------------------
# Compute predictions on validation set
# ---------------------------------------------------------------------------
model.eval()
all_preds = []
all_actuals = []

with torch.no_grad():
    for X_batch, y_batch in val_loader:
        X_batch = X_batch.to(DEVICE)
        preds = model(X_batch).cpu().numpy()
        all_preds.append(preds)
        all_actuals.append(y_batch.numpy())

all_preds = np.concatenate(all_preds, axis=0)
all_actuals = np.concatenate(all_actuals, axis=0)

# Cosine similarity
def cosine_similarity_batch(a, b):
    """Compute row-wise cosine similarity between two matrices."""
    dot = np.sum(a * b, axis=1)
    norm_a = np.linalg.norm(a, axis=1)
    norm_b = np.linalg.norm(b, axis=1)
    denom = norm_a * norm_b
    denom = np.where(denom > 0, denom, 1.0)
    return dot / denom

cos_sims = cosine_similarity_batch(all_preds, all_actuals)

print("Prediction Quality on Validation Set")
print("=" * 45)
print(f"  Mean cosine similarity:   {np.mean(cos_sims):.4f}")
print(f"  Median cosine similarity: {np.median(cos_sims):.4f}")
print(f"  Std:                      {np.std(cos_sims):.4f}")
print(f"  Min / Max:                {np.min(cos_sims):.4f} / {np.max(cos_sims):.4f}")
print(f"  % > 0.5:                  {100 * np.mean(cos_sims > 0.5):.1f}%")
print(f"  % > 0.8:                  {100 * np.mean(cos_sims > 0.8):.1f}%")
print(f"  MSE (val):                {np.mean((all_preds - all_actuals) ** 2):.6f}")

# ---------------------------------------------------------------------------
# Sample predictions
# ---------------------------------------------------------------------------
print("\n--- Sample Predictions (first 5) ---")
for i in range(min(5, len(all_preds))):
    sim = cos_sims[i]
    print(f"\nSample {i + 1} | Cosine Similarity: {sim:.4f}")
    print(f"  Predicted: [{', '.join(f'{v:.3f}' for v in all_preds[i][:8])}  ...]")
    print(f"  Actual:    [{', '.join(f'{v:.3f}' for v in all_actuals[i][:8])}  ...]")

# ---------------------------------------------------------------------------
# Histogram of cosine similarities
# ---------------------------------------------------------------------------
fig, axes = plt.subplots(1, 2, figsize=(16, 5))

# Histogram
ax = axes[0]
ax.hist(cos_sims, bins=60, color=SNACKTRACK_COLORS["primary"], alpha=0.8,
        edgecolor="white", linewidth=0.5)
ax.axvline(np.mean(cos_sims), color=SNACKTRACK_COLORS["secondary"], linestyle="--",
           linewidth=2, label=f"Mean = {np.mean(cos_sims):.3f}")
ax.axvline(np.median(cos_sims), color=SNACKTRACK_COLORS["accent"], linestyle=":",
           linewidth=2, label=f"Median = {np.median(cos_sims):.3f}")
ax.set_xlabel("Cosine Similarity")
ax.set_ylabel("Count")
ax.set_title("Distribution of Cosine Similarity (Predicted vs Actual)")
ax.legend()
ax.grid(True, alpha=0.3)

# Scatter plot of predicted vs actual (first 2 dims)
ax = axes[1]
n_plot = min(500, len(all_preds))
ax.scatter(all_actuals[:n_plot, 0], all_preds[:n_plot, 0],
           alpha=0.3, s=15, color=SNACKTRACK_COLORS["primary"], label="Dim 0")
ax.scatter(all_actuals[:n_plot, 1], all_preds[:n_plot, 1],
           alpha=0.3, s=15, color=SNACKTRACK_COLORS["secondary"], label="Dim 1")
lims = [min(all_actuals[:n_plot, :2].min(), all_preds[:n_plot, :2].min()) - 0.1,
        max(all_actuals[:n_plot, :2].max(), all_preds[:n_plot, :2].max()) + 0.1]
ax.plot(lims, lims, "k--", alpha=0.5, label="Perfect prediction")
ax.set_xlabel("Actual")
ax.set_ylabel("Predicted")
ax.set_title("Predicted vs Actual (First 2 Embedding Dims)")
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 6. Export Weights

We export the trained weights to `weights/rnn_weights.npz` using `save_rnn_weights()`.

**Shape mapping from PyTorch to NumPy:**

| PyTorch Parameter | Export Key | Shape | Notes |
|-------------------|-----------|-------|-------|
| `gru_cell.Wz` | `Wz` | (39, 64) | Direct copy --- same layout |
| `gru_cell.Uz` | `Uz` | (64, 64) | Direct copy |
| `gru_cell.bz` | `bz` | (64,) | Direct copy |
| ... | ... | ... | Same for Wr, Ur, br, Wh, Uh, bh |
| `output_proj.weight` | `Wo` | (64, 32) | **Transposed** --- `nn.Linear` stores `(out, in)` |
| `output_proj.bias` | `bo` | (32,) | Direct copy |

The critical detail is the transpose on `output_proj.weight`: PyTorch's `nn.Linear`
stores the weight matrix as `(output_dim, input_dim)` and computes `x @ W.T + b`,
but our production NumPy code computes `h @ Wo + bo` where `Wo` is `(64, 32)`.
So we must export `weight.T`.
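The transpose can be checked on its own. The NumPy-only sketch below mimics the storage layout of `nn.Linear(64, 32)` using random values (illustrative stand-ins, not the trained weights). It confirms that exporting `Wo = weight.T` makes the production-style `h @ Wo + bo` reproduce `x @ W.T + b`.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_dim, out_dim = 64, 32

# nn.Linear(hidden_dim, out_dim) stores its weight as (out_features, in_features)
W_torch = rng.standard_normal((out_dim, hidden_dim))  # (32, 64)
b_torch = rng.standard_normal(out_dim)                # (32,)
h_demo = rng.standard_normal(hidden_dim)              # stand-in final hidden state

# What nn.Linear computes: x @ W.T + b
y_torch = h_demo @ W_torch.T + b_torch

# What the production code computes: h @ Wo + bo, with Wo exported as W.T
Wo = W_torch.T   # (64, 32)
bo = b_torch     # (32,)
y_numpy = h_demo @ Wo + bo
print(Wo.shape, np.allclose(y_torch, y_numpy))

# Skipping the transpose makes the matmul shapes incompatible here
try:
    h_demo @ W_torch
    untransposed_ok = True
except ValueError as err:
    untransposed_ok = False
    print("untransposed weight rejected:", err)
```

This output layer is not square, so forgetting the transpose fails loudly. For a square layer, the untransposed matmul would run and silently return wrong values. That is one reason the end-to-end check in Section 7 matters.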

In [None]:
# ---------------------------------------------------------------------------
# Extract and export weights
# ---------------------------------------------------------------------------
model.eval()
gru = model.gru_cell

weights = {
    # GRU gate weights --- direct copy (no transpose needed)
    "Wz": gru.Wz.detach().cpu().numpy(),   # (39, 64)
    "Uz": gru.Uz.detach().cpu().numpy(),   # (64, 64)
    "bz": gru.bz.detach().cpu().numpy(),   # (64,)
    "Wr": gru.Wr.detach().cpu().numpy(),   # (39, 64)
    "Ur": gru.Ur.detach().cpu().numpy(),   # (64, 64)
    "br": gru.br.detach().cpu().numpy(),   # (64,)
    "Wh": gru.Wh.detach().cpu().numpy(),   # (39, 64)
    "Uh": gru.Uh.detach().cpu().numpy(),   # (64, 64)
    "bh": gru.bh.detach().cpu().numpy(),   # (64,)
    # Output projection --- TRANSPOSE required
    "Wo": model.output_proj.weight.detach().cpu().numpy().T,  # (64, 32)
    "bo": model.output_proj.bias.detach().cpu().numpy(),       # (32,)
}

# Verify shapes before saving
print("Exported weight shapes:")
for key, arr in weights.items():
    expected = RNN_WEIGHT_SHAPES[key]
    status = "OK" if arr.shape == expected else "MISMATCH"
    print(f"  {key:<5} {str(arr.shape):<12} expected {str(expected):<12} [{status}]")

# Save
save_path = save_rnn_weights(weights)
print(f"\nWeights saved to: {save_path}")
print(f"File size: {save_path.stat().st_size / 1024:.1f} KB")
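`save_rnn_weights()` and `load_rnn_weights()` live in `utils/weight_io.py`, which this notebook does not show. Purely as an illustration, the sketch below assumes they amount to a shape check plus an `np.savez` / `np.load` round trip, with one named array per key. `save_weights_npz`, `load_weights_npz`, and `EXPECTED_SHAPES` are hypothetical stand-ins, and their shapes are taken from the mapping table above.

```python
import tempfile
from pathlib import Path

import numpy as np

# Hypothetical shape spec, taken from the mapping table (39D input, 64D hidden, 32D output)
EXPECTED_SHAPES = {
    **{f"W{g}": (39, 64) for g in "zrh"},
    **{f"U{g}": (64, 64) for g in "zrh"},
    **{f"b{g}": (64,) for g in "zrh"},
    "Wo": (64, 32),
    "bo": (32,),
}


def save_weights_npz(weights, path):
    """Hypothetical saver: refuse bad shapes, then store one named array per key."""
    for key, shape in EXPECTED_SHAPES.items():
        if weights[key].shape != shape:
            raise ValueError(f"{key}: got {weights[key].shape}, expected {shape}")
    np.savez(path, **weights)
    return path


def load_weights_npz(path):
    """Hypothetical loader: read every stored array back into a plain dict."""
    with np.load(path) as data:
        return {key: data[key] for key in data.files}


rng = np.random.default_rng(0)
demo_weights = {k: rng.standard_normal(s).astype(np.float32)
                for k, s in EXPECTED_SHAPES.items()}

with tempfile.TemporaryDirectory() as tmp:
    demo_path = save_weights_npz(demo_weights, Path(tmp) / "rnn_weights.npz")
    demo_loaded = load_weights_npz(demo_path)

print(sorted(demo_loaded) == sorted(demo_weights))
print(all(np.array_equal(demo_weights[k], demo_loaded[k]) for k in demo_weights))
```

Plain numeric arrays need no pickling, so `np.load` can stay at its default `allow_pickle=False`.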

## 7. Verification

This is the most critical step: we load the exported `.npz` weights, replicate the
production NumPy GRU forward pass, and check that its outputs match the PyTorch
model's with `np.allclose(..., atol=1e-5)` (which also applies NumPy's default
relative tolerance, `rtol=1e-5`).

If the check passes, the production `MealSequenceRNN` should reproduce the trained
PyTorch model's predictions to within that tolerance when it loads these weights.
Bit-for-bit identity is not expected: PyTorch computes in float32, while this NumPy
check promotes to float64.
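Why compare with a tolerance instead of exact equality? The PyTorch model computes in float32, whereas `numpy_forward` promotes everything to float64, so the two sides round differently at each of the 20 steps. The self-contained sketch below runs the same GRU recurrence as `numpy_gru_step` (without its clipping) in both precisions and reports the gap. It uses random float32 weights with the notebook's dimensions. The weight scale of 0.1 and the seed are arbitrary assumptions, so the exact number will differ from the trained model's.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_forward(seq, w, dtype):
    """Run the (r * h) @ U GRU recurrence and output projection entirely in `dtype`."""
    w = {k: v.astype(dtype) for k, v in w.items()}
    h = np.zeros(w["Uz"].shape[0], dtype=dtype)
    for x in seq.astype(dtype):
        z = sigmoid(x @ w["Wz"] + h @ w["Uz"] + w["bz"])
        r = sigmoid(x @ w["Wr"] + h @ w["Ur"] + w["br"])
        h_cand = np.tanh(x @ w["Wh"] + (r * h) @ w["Uh"] + w["bh"])
        h = (1 - z) * h + z * h_cand
    return h @ w["Wo"] + w["bo"]


rng = np.random.default_rng(42)
shapes = {"Wz": (39, 64), "Uz": (64, 64), "bz": (64,),
          "Wr": (39, 64), "Ur": (64, 64), "br": (64,),
          "Wh": (39, 64), "Uh": (64, 64), "bh": (64,),
          "Wo": (64, 32), "bo": (32,)}
# Weights and inputs start as float32 values, like weights exported from a float32 model
w_demo = {k: (0.1 * rng.standard_normal(s)).astype(np.float32) for k, s in shapes.items()}
seq_demo = rng.standard_normal((20, 39)).astype(np.float32)

out32 = gru_forward(seq_demo, w_demo, np.float32)
out64 = gru_forward(seq_demo, w_demo, np.float64)

gap = np.max(np.abs(out32.astype(np.float64) - out64))
print(f"max |float32 - float64| = {gap:.2e}")
print("within atol=1e-5:", np.allclose(out32, out64, atol=1e-5))
```

If the real check reports gaps far above this float32 noise floor, rounding is an unlikely explanation.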

In [None]:
# ---------------------------------------------------------------------------
# Load exported weights
# ---------------------------------------------------------------------------
loaded_weights = load_rnn_weights()
print("Loaded weights:")
for key, arr in loaded_weights.items():
    print(f"  {key:<5} {arr.shape}")


# ---------------------------------------------------------------------------
# NumPy GRU forward pass (exact copy of production code)
# ---------------------------------------------------------------------------
def numpy_sigmoid(x):
    return 1.0 / (1.0 + np.exp(-np.clip(x, -20, 20)))


def numpy_gru_step(x, h_prev, w):
    """Single GRU step using NumPy --- mirrors MealSequenceRNN.gru_step()."""
    z = numpy_sigmoid(x @ w["Wz"] + h_prev @ w["Uz"] + w["bz"])
    r = numpy_sigmoid(x @ w["Wr"] + h_prev @ w["Ur"] + w["br"])
    h_candidate = np.tanh(x @ w["Wh"] + (r * h_prev) @ w["Uh"] + w["bh"])
    h_new = (1 - z) * h_prev + z * h_candidate
    return h_new


def numpy_forward(sequence, w):
    """Full forward pass using NumPy --- mirrors MealSequenceRNN.forward()."""
    h = np.zeros(64, dtype=np.float64)
    for x in sequence:
        h = numpy_gru_step(x.astype(np.float64), h, w)
    output = h @ w["Wo"] + w["bo"]
    return output


# ---------------------------------------------------------------------------
# Compare PyTorch vs NumPy on validation samples
# ---------------------------------------------------------------------------
model.eval()
n_test = min(20, len(X_val))
max_abs_diff = 0.0
all_close = True

print(f"\nVerifying PyTorch vs NumPy on {n_test} validation sequences...")
print(f"{'Sample':<8} {'Max |diff|':<15} {'Cosine sim':<15} {'Match?'}")
print("-" * 50)

for i in range(n_test):
    seq = X_val[i]  # (20, 39)

    # PyTorch prediction
    with torch.no_grad():
        seq_tensor = torch.tensor(seq, dtype=torch.float32).unsqueeze(0).to(DEVICE)
        pt_output = model(seq_tensor).cpu().numpy().flatten()

    # NumPy prediction
    np_output = numpy_forward(seq, loaded_weights)

    # Compare
    abs_diff = np.max(np.abs(pt_output - np_output))
    max_abs_diff = max(max_abs_diff, abs_diff)

    cos_sim = np.dot(pt_output, np_output) / (
        np.linalg.norm(pt_output) * np.linalg.norm(np_output) + 1e-10
    )

    match = np.allclose(pt_output, np_output, atol=1e-5)
    if not match:
        all_close = False

    print(f"  {i + 1:<6} {abs_diff:<15.8f} {cos_sim:<15.8f} {'PASS' if match else 'FAIL'}")

print(f"\n{'=' * 50}")
print(f"Maximum absolute difference: {max_abs_diff:.2e}")
print(f"Overall verification: {'PASSED' if all_close else 'FAILED'}")

if all_close:
    print("\nThe exported weights reproduce the PyTorch model's outputs to within atol=1e-5.")
    print("The production NumPy RNN should therefore produce the same recommendations.")
else:
    print("\nWARNING: Numerical differences exceed tolerance!")
    print(f"Max diff = {max_abs_diff:.2e}. Gaps just above 1e-5 can come from float32 vs")
    print("float64 rounding; much larger gaps usually indicate a gate-formula or")
    print("transpose mismatch between the PyTorch cell and the NumPy code.")

# Final assertion
for i in range(min(5, len(X_val))):
    seq = X_val[i]
    with torch.no_grad():
        seq_tensor = torch.tensor(seq, dtype=torch.float32).unsqueeze(0).to(DEVICE)
        pt_output = model(seq_tensor).cpu().numpy().flatten()
    np_output = numpy_forward(seq, loaded_weights)
    assert np.allclose(pt_output, np_output, atol=1e-5), (
        f"Verification FAILED on sample {i}: "
        f"max diff = {np.max(np.abs(pt_output - np_output)):.2e}"
    )

print("\nAll assertions passed. Weights are verified and ready for production.")
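The parity check passes only if the PyTorch cell and `numpy_gru_step` put the reset gate in the same place. The sketch below computes the candidate state two ways, using random weights with an arbitrary 0.2 scale as hypothetical stand-ins for trained ones. The first gates the previous state before the recurrent matmul, `(r * h_prev) @ Uh`, as `numpy_gru_step` does. The second gates the recurrent product afterwards, `r * (h_prev @ Uh)`. That second form matches the candidate equation in the PyTorch documentation for `nn.GRU`, where a recurrent bias also sits inside the gated term.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


rng = np.random.default_rng(7)
x_t = rng.standard_normal(39)
h_prev = np.tanh(rng.standard_normal(64))  # a plausible previous hidden state in (-1, 1)
Wr, Ur = 0.2 * rng.standard_normal((39, 64)), 0.2 * rng.standard_normal((64, 64))
Wh, Uh = 0.2 * rng.standard_normal((39, 64)), 0.2 * rng.standard_normal((64, 64))
br, bh = np.zeros(64), np.zeros(64)

r = sigmoid(x_t @ Wr + h_prev @ Ur + br)

# Placement used by numpy_gru_step: gate the state, then multiply
cand_gate_first = np.tanh(x_t @ Wh + (r * h_prev) @ Uh + bh)
# Alternative placement: multiply, then gate the recurrent product
cand_gate_after = np.tanh(x_t @ Wh + r * (h_prev @ Uh) + bh)

diff = np.max(np.abs(cand_gate_first - cand_gate_after))
print(f"max |difference| between candidates: {diff:.3f}")
```

The two candidates differ for the same inputs and weights. Weights trained under one placement will therefore not reproduce their outputs under the other.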