# SnackTrack ML — VAE Training

**Purpose:** Train the Variational Autoencoder (VAE) that powers SnackTrack's
recipe representation learning. The production VAE in `app/recommender/vae.py`
currently uses **random placeholder weights** initialized with `np.random.default_rng(42)`.
This notebook trains real weights on our recipe corpus and exports them in the
exact format the NumPy inference code expects.

**Architecture:**
- **Encoder:** 12D recipe features -> 32D latent mean (mu) + 32D latent log-variance
- **Reparameterization:** z = mu + exp(0.5 * logvar) * epsilon
- **Decoder:** 32D latent -> 12D reconstructed features

The 12 input features are: `calories`, `protein_g`, `carbs_g`, `fat_g`,
`sodium_mg`, `fiber_g`, `sugar_g`, `ready_in_minutes`, `servings`,
`is_vegetarian`, `is_vegan`, `is_gluten_free`.

**Training strategy:**
- Beta-warmup to prevent posterior collapse
- Early stopping on validation loss
- ReduceLROnPlateau scheduler
- Weight export with correct transposition for NumPy inference

In [None]:
import sys
sys.path.insert(0, "..")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm

from notebooks.utils.plot_helpers import (
    setup_plot_style,
    plot_loss_curves,
    plot_latent_space_2d,
    plot_feature_distributions,
)
from notebooks.utils.data_loader import (
    load_recipes_from_db,
    load_kaggle_dataset,
    extract_vae_features,
)
from notebooks.utils.db_connect import get_connection
from notebooks.utils.weight_io import save_vae_weights, load_vae_weights

# Apply the shared notebook plotting theme before any figure is created.
setup_plot_style()

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

## 1. Prepare Training Data

We combine recipe data from all available sources to maximize training set size.
The more diverse recipes the VAE sees, the better its latent space will capture
the full variety of the recipe landscape.

Sources (in priority order):
1. **SnackTrack DB** — production recipes
2. **Daily Food Nutrition** — Kaggle dataset with detailed nutritional data
3. **Food.com** — large recipe corpus from Kaggle
4. **Epicurious** — curated recipe dataset from Kaggle

In [None]:
all_recipe_dfs = []
source_counts = {}


def _report_source(label, message):
    """Print one status line for a data source, with the label column aligned."""
    print(f"{label + ':':<23}{message}")


# --- Source 1: SnackTrack database ---
try:
    conn = get_connection()
    try:
        db_recipes = load_recipes_from_db(conn)
    finally:
        # Close the connection even when the query itself fails.
        conn.close()
    if db_recipes.empty:
        _report_source("SnackTrack DB", "no recipes found")
    else:
        all_recipe_dfs.append(db_recipes)
        source_counts["SnackTrack DB"] = len(db_recipes)
        _report_source("SnackTrack DB", f"{len(db_recipes):>7,} recipes")
except Exception as e:
    # Best-effort: the DB is optional, the Kaggle/synthetic sources still work.
    _report_source("SnackTrack DB", f"unavailable ({e})")

# --- Sources 2-4: Kaggle datasets (downloaded by notebook 00) ---
kaggle_sources = [
    ("Daily Food Nutrition", "daily_food_nutrition"),
    ("Food.com", "food_com_recipes"),
    ("Epicurious", "epicurious"),
]
for label, dataset_key in kaggle_sources:
    try:
        source_df = load_kaggle_dataset(dataset_key)
    except FileNotFoundError:
        _report_source(label, "not found (run notebook 00 first)")
        continue
    all_recipe_dfs.append(source_df)
    source_counts[label] = len(source_df)
    _report_source(label, f"{len(source_df):>7,} recipes")

# --- Fallback: synthetic data if nothing loaded ---
if not all_recipe_dfs:
    print("\nNo datasets available. Generating synthetic recipe data.")
    rng = np.random.default_rng(42)
    n_synthetic = 5000
    # Numeric columns are drawn first, then the per-row diet flags, so the
    # seeded RNG stream (and therefore the data) stays deterministic.
    synthetic_columns = {
        "calories": rng.lognormal(5.5, 0.7, n_synthetic).clip(50, 2000),
        "protein_g": rng.lognormal(2.5, 0.8, n_synthetic).clip(0, 150),
        "carbs_g": rng.lognormal(3.2, 0.7, n_synthetic).clip(0, 300),
        "fat_g": rng.lognormal(2.3, 0.9, n_synthetic).clip(0, 150),
        "sodium_mg": rng.lognormal(5.8, 1.0, n_synthetic).clip(0, 5000),
        "fiber_g": rng.exponential(3, n_synthetic).clip(0, 50),
        "sugar_g": rng.lognormal(2.0, 1.0, n_synthetic).clip(0, 100),
        "ready_in_minutes": rng.lognormal(3.2, 0.6, n_synthetic).clip(5, 300),
        "servings": rng.choice([1, 2, 4, 6, 8, 12], n_synthetic),
    }
    # Assign diet labels to a subset. Build the whole column up front instead
    # of writing a list into each cell with DataFrame.at.
    diet_labels = []
    for _ in range(n_synthetic):
        row_labels = []
        if rng.random() < 0.2:
            row_labels.append("vegetarian")
        if rng.random() < 0.08:
            row_labels.append("vegan")
        if rng.random() < 0.15:
            row_labels.append("gluten free")
        diet_labels.append(row_labels)
    synthetic_columns["diet_labels"] = diet_labels
    synthetic_df = pd.DataFrame(synthetic_columns)
    all_recipe_dfs.append(synthetic_df)
    source_counts["Synthetic"] = n_synthetic

# --- Combine all sources ---
combined_df = pd.concat(all_recipe_dfs, ignore_index=True)
print(f"\nTotal combined recipes: {len(combined_df):,}")

# --- Extract 12D features ---
features_raw = extract_vae_features(combined_df)

# Remove rows with any NaN or infinite values
valid_mask = np.isfinite(features_raw).all(axis=1)
features_raw = features_raw[valid_mask]
print(f"Valid feature vectors:  {len(features_raw):,} (removed {(~valid_mask).sum()} invalid rows)")
print(f"Feature shape:         {features_raw.shape}")

In [None]:
# --- Train/validation split (80/20) ---
# Split first so normalization statistics can be fit on the training rows only.
rng = np.random.default_rng(42)
n_total = len(features_raw)
indices = rng.permutation(n_total)
split_idx = int(0.8 * n_total)
train_idx, val_idx = indices[:split_idx], indices[split_idx:]

# --- Compute normalization statistics (training split only) ---
# Fitting on all rows would leak validation statistics into the model inputs
# and into the normalization constants exported for production.
train_raw = features_raw[train_idx]
feature_means = train_raw.mean(axis=0)
feature_stds = train_raw.std(axis=0)
# Prevent division by zero for constant features (e.g., a flag that is never set)
feature_stds[feature_stds == 0] = 1.0

print("Normalization statistics (fit on training split):")
feature_names = [
    'calories', 'protein_g', 'carbs_g', 'fat_g', 'sodium_mg', 'fiber_g',
    'sugar_g', 'ready_in_minutes', 'servings', 'is_vegetarian', 'is_vegan',
    'is_gluten_free',
]
for name, mu, std in zip(feature_names, feature_means, feature_stds):
    print(f"  {name:<18s}  mean={mu:10.3f}  std={std:10.3f}")

# --- Normalize (all rows, using training statistics) ---
# The +1e-8 matches the NumPy inference path's normalization.
features_normalized = (features_raw - feature_means) / (feature_stds + 1e-8)

train_features = features_normalized[train_idx]
val_features = features_normalized[val_idx]

print(f"\nTrain set: {len(train_features):,} samples")
print(f"Val set:   {len(val_features):,} samples")

# --- Create PyTorch DataLoaders ---
train_tensor = torch.tensor(train_features, dtype=torch.float32)
val_tensor = torch.tensor(val_features, dtype=torch.float32)

batch_size = 256
train_loader = DataLoader(
    TensorDataset(train_tensor),
    batch_size=batch_size,
    shuffle=True,
    drop_last=False,
)
val_loader = DataLoader(
    TensorDataset(val_tensor),
    batch_size=batch_size,
    shuffle=False,
)

print(f"\nBatch size: {batch_size}")
print(f"Train batches: {len(train_loader)}")
print(f"Val batches:   {len(val_loader)}")

## 2. Feature Distributions

Before training, we inspect the raw (pre-normalization) feature distributions.
This helps us understand the data characteristics:
- Continuous nutritional features tend to be right-skewed (log-normal)
- Binary diet flags are heavily imbalanced (most recipes are not restricted)
- Understanding these distributions helps interpret reconstruction quality later

In [None]:
# Plot the un-normalized features so skew and flag imbalance are visible.
distribution_title = "Raw Feature Distributions (before normalization)"
fig = plot_feature_distributions(features_raw, feature_names, title=distribution_title)
plt.show()

## 3. Define VAE Model

The PyTorch model architecture **exactly mirrors** the NumPy inference code in
`app/recommender/vae.py`. This is critical: the weight matrices must be
compatible after transposition (PyTorch stores `nn.Linear` weights as
`(out_features, in_features)`, while NumPy inference uses `x @ W + b` where
`W` is `(in_features, out_features)`).

**Architecture:**
- `encoder_mu`: Linear(12 -> 32) — predicts latent mean
- `encoder_logvar`: Linear(12 -> 32) — predicts latent log-variance
- `decoder`: Linear(32 -> 12) — reconstructs features from latent code

This is intentionally a single-layer encoder/decoder. For recipe features,
this linear VAE provides smooth latent spaces suitable for interpolation
and similarity search.

In [None]:
class RecipeVAEPyTorch(nn.Module):
    """Linear VAE mapping 12-D recipe features to a 32-D latent space.

    Intended to mirror ``RecipeVAE`` in app/recommender/vae.py layer for
    layer. The ``nn.Linear`` weights are transposed on export for the NumPy
    ``x @ W + b`` inference path.
    """

    LATENT_DIM = 32
    FEATURE_DIM = 12

    def __init__(self):
        super().__init__()
        n_in, n_latent = self.FEATURE_DIM, self.LATENT_DIM
        # Registration order fixes the state_dict key order and the sequence
        # of default-initialization draws, so keep mu -> logvar -> decoder.
        self.encoder_mu = nn.Linear(n_in, n_latent)
        self.encoder_logvar = nn.Linear(n_in, n_latent)
        self.decoder = nn.Linear(n_latent, n_in)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for inputs *x*."""
        mean = self.encoder_mu(x)
        log_variance = self.encoder_logvar(x)
        return mean, log_variance

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * sigma`` with ``eps ~ N(0, 1)`` (differentiable)."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Project latent code(s) *z* back to the 12-D feature space."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample one latent code, decode; returns ``(recon, mu, logvar)``."""
        mean, log_variance = self.encode(x)
        sample = self.reparameterize(mean, log_variance)
        return self.decode(sample), mean, log_variance


# Build the model on the selected device and summarize its parameters.
model = RecipeVAEPyTorch().to(device)
print(model)
param_total = sum(p.numel() for p in model.parameters())
print(f"\nTotal parameters: {param_total:,}")
for param_name, tensor in model.named_parameters():
    shape_text = str(tuple(tensor.shape))
    print(f"  {param_name:<25s} {shape_text:<15s} ({tensor.numel():,} params)")

In [None]:
def vae_loss(recon, original, mu, logvar, beta=1.0):
    """Return the beta-weighted VAE objective together with its components.

    ``total = sum-reduced MSE(recon, original) + beta * KL(q(z|x) || N(0, I))``

    Parameters
    ----------
    recon : Tensor
        Decoder output.
    original : Tensor
        Inputs the decoder is trying to reproduce.
    mu : Tensor
        Mean of the approximate posterior.
    logvar : Tensor
        Log-variance of the approximate posterior.
    beta : float
        Multiplier on the KL term (beta-VAE). The training loop ramps it up
        toward 1 during warmup.

    Returns
    -------
    total_loss, recon_loss, kl_loss : Tensors
        Scalar tensors summed over the batch (not averaged).
    """
    recon_loss = F.mse_loss(recon, original, reduction='sum')
    # Closed-form KL between N(mu, exp(logvar)) and the standard normal prior.
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_loss = -0.5 * kl_terms.sum()
    return recon_loss + beta * kl_loss, recon_loss, kl_loss


for summary_line in (
    "VAE loss function defined.",
    "  - Reconstruction: MSE (sum reduction)",
    "  - KL Divergence:  D_KL(q(z|x) || p(z)) with beta weighting",
):
    print(summary_line)

## 4. Training Loop

We train using the beta-VAE formulation with a warmup schedule:

- **Beta warmup:** Over the first 30% of `MAX_EPOCHS` (60 of 200 epochs), beta
  increases linearly from 1/60 at epoch 1 to 1. This mitigates "posterior
  collapse": with a full-strength KL term early in training, the encoder is
  pushed to match the prior and the decoder learns to ignore the latent code.

- **Optimizer:** Adam with lr=1e-3 and weight_decay=1e-5 for mild regularization.

- **Scheduler:** ReduceLROnPlateau halves the learning rate (factor=0.5) once
  validation loss has failed to improve for more than 10 consecutive epochs
  (patience=10).

- **Early stopping:** Training halts if validation loss does not improve for 20
  consecutive epochs.

In [None]:
# --- Hyperparameters ---
MAX_EPOCHS = 200
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-5
BETA_WARMUP_FRACTION = 0.30  # beta goes 0 -> 1 over first 30% of epochs
EARLY_STOP_PATIENCE = 20
SCHEDULER_PATIENCE = 10

optimizer = torch.optim.Adam(
    model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)
# `verbose` is omitted: it defaults to False and is deprecated in newer PyTorch.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=SCHEDULER_PATIENCE
)

# --- Tracking ---
history = {
    "train_total": [], "train_recon": [], "train_kl": [],
    "val_total": [], "val_recon": [], "val_kl": [],
    "beta": [], "lr": [],
}
best_val_loss = float("inf")
best_epoch = 0
best_state_dict = None
epochs_without_improvement = 0

beta_warmup_epochs = int(MAX_EPOCHS * BETA_WARMUP_FRACTION)

print("Training configuration:")
print(f"  Max epochs:          {MAX_EPOCHS}")
print(f"  Learning rate:       {LEARNING_RATE}")
print(f"  Weight decay:        {WEIGHT_DECAY}")
print(f"  Beta warmup epochs:  {beta_warmup_epochs}")
print(f"  Early stop patience: {EARLY_STOP_PATIENCE}")
print(f"  Scheduler patience:  {SCHEDULER_PATIENCE}")
print("  Model selection, LR scheduling and early stopping start once beta = 1.")
print()


def _run_epoch(loader, beta, *, train):
    """Run one pass over `loader`; return per-sample (total, recon, kl) losses.

    With ``train=True`` the model is in training mode and `optimizer` updates
    it after every batch. Otherwise the pass runs in eval mode with gradients
    disabled.
    """
    model.train(train)
    total_sum, recon_sum, kl_sum = 0.0, 0.0, 0.0
    n_seen = 0
    with (torch.enable_grad() if train else torch.no_grad()):
        for (batch,) in loader:
            batch = batch.to(device)
            if train:
                optimizer.zero_grad()
            recon, mu, logvar = model(batch)
            loss, r_loss, kl = vae_loss(recon, batch, mu, logvar, beta=beta)
            if train:
                loss.backward()
                optimizer.step()
            total_sum += loss.item()
            recon_sum += r_loss.item()
            kl_sum += kl.item()
            n_seen += len(batch)
    if n_seen == 0:
        raise ValueError("DataLoader yielded no samples; cannot compute per-sample losses.")
    # Losses are summed per batch (reduction='sum'), so divide by sample count.
    return total_sum / n_seen, recon_sum / n_seen, kl_sum / n_seen


# --- Training loop ---
progress = tqdm(range(1, MAX_EPOCHS + 1), desc="Training")

for epoch in progress:
    # Linear KL warmup; beta reaches exactly 1.0 at epoch == beta_warmup_epochs.
    if epoch <= beta_warmup_epochs:
        beta = epoch / beta_warmup_epochs
    else:
        beta = 1.0
    # Total losses computed with different betas are not comparable: while beta
    # rises, val_total grows even when the model improves. That would trigger LR
    # cuts, pick a near-zero-beta "best" model, or stop early mid-warmup. So gate
    # model selection, LR scheduling and early stopping on the full objective.
    warmup_done = beta >= 1.0

    train_total, train_recon, train_kl = _run_epoch(train_loader, beta, train=True)
    val_total, val_recon, val_kl = _run_epoch(val_loader, beta, train=False)

    # --- Record ---
    history["train_total"].append(train_total)
    history["train_recon"].append(train_recon)
    history["train_kl"].append(train_kl)
    history["val_total"].append(val_total)
    history["val_recon"].append(val_recon)
    history["val_kl"].append(val_kl)
    history["beta"].append(beta)
    history["lr"].append(optimizer.param_groups[0]["lr"])

    if warmup_done:
        # --- LR scheduler ---
        scheduler.step(val_total)

        # --- Early stopping check ---
        if val_total < best_val_loss:
            best_val_loss = val_total
            best_epoch = epoch
            best_state_dict = {k: v.clone() for k, v in model.state_dict().items()}
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

    # --- Progress bar update ---
    progress.set_postfix({
        "loss": f"{val_total:.4f}",
        "recon": f"{val_recon:.4f}",
        "kl": f"{val_kl:.4f}",
        "beta": f"{beta:.2f}",
        "lr": f"{optimizer.param_groups[0]['lr']:.1e}",
        "best": f"{best_epoch}",
    })

    if epochs_without_improvement >= EARLY_STOP_PATIENCE:
        print(f"\nEarly stopping at epoch {epoch}. "
              f"Best val loss: {best_val_loss:.4f} at epoch {best_epoch}.")
        break

# --- Restore best model ---
if best_state_dict is not None:
    model.load_state_dict(best_state_dict)
    print(f"\nRestored best model from epoch {best_epoch} "
          f"(val loss = {best_val_loss:.4f})")
else:
    print("\nTraining completed. Using final epoch weights.")

print(f"Total epochs run: {len(history['train_total'])}")

## 5. Training Visualization

Four panels showing the training dynamics:
- **(a) Total loss:** combined reconstruction + beta * KL (train and val)
- **(b) Reconstruction loss:** how well the decoder reconstructs features
- **(c) KL divergence:** regularization toward the prior N(0,1)
- **(d) Beta schedule & learning rate:** the KL-weight warmup curve, with the learning rate (log scale) on a secondary axis

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
epochs_range = range(1, len(history["train_total"]) + 1)

TRAIN_COLOR, VAL_COLOR, BEST_COLOR = "#4CAF50", "#FF9800", "#F44336"

# Panels (a)-(c) share a layout: train curve, val curve, best-epoch marker.
# Only panel (a) labels the marker, so it alone shows it in the legend.
loss_panels = (
    (axes[0, 0], "train_total", "val_total", "(a) Total Loss", "Loss / sample", True),
    (axes[0, 1], "train_recon", "val_recon", "(b) Reconstruction Loss (MSE)", "MSE / sample", False),
    (axes[1, 0], "train_kl", "val_kl", "(c) KL Divergence", "KL / sample", False),
)
for panel_ax, train_key, val_key, panel_title, y_label, label_best in loss_panels:
    panel_ax.plot(epochs_range, history[train_key], label="Train", color=TRAIN_COLOR, linewidth=1.5)
    panel_ax.plot(epochs_range, history[val_key], label="Val", color=VAL_COLOR, linewidth=1.5)
    marker_kwargs = {"label": f"Best (epoch {best_epoch})"} if label_best else {}
    panel_ax.axvline(best_epoch, color=BEST_COLOR, linestyle="--", alpha=0.5, **marker_kwargs)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel("Epoch")
    panel_ax.set_ylabel(y_label)
    panel_ax.legend()
    panel_ax.grid(True, alpha=0.3)

# (d) Beta on the left axis, learning rate (log scale) on a twin right axis.
ax_beta = axes[1, 1]
ax_lr = ax_beta.twinx()

line1, = ax_beta.plot(epochs_range, history["beta"], label="Beta", color="#9C27B0", linewidth=2)
line2, = ax_lr.plot(
    epochs_range, history["lr"], label="Learning Rate", color="#2196F3",
    linewidth=1.5, linestyle="--",
)

ax_beta.set_title("(d) Beta Schedule & Learning Rate")
ax_beta.set_xlabel("Epoch")
ax_beta.set_ylabel("Beta", color="#9C27B0")
ax_lr.set_ylabel("Learning Rate", color="#2196F3")
ax_beta.set_ylim(-0.05, 1.1)
ax_lr.set_yscale("log")
ax_beta.legend(handles=[line1, line2], loc="center right")
ax_beta.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Final training loss:   {history['train_total'][-1]:.4f}")
print(f"Best validation loss:  {best_val_loss:.4f} (epoch {best_epoch})")

## 6. Latent Space Visualization

We encode all recipes into the 32D latent space, then project down to 2D with
t-SNE. If diet labels are available, we color-code points by diet type to see
whether the VAE has learned to separate different recipe categories in its
latent representation.

In [None]:
# Encode every recipe to its latent mean (no sampling) for visualization.
model.eval()
all_features_tensor = torch.tensor(features_normalized, dtype=torch.float32).to(device)

with torch.no_grad():
    all_mu = model.encode(all_features_tensor)[0]
    all_embeddings = all_mu.cpu().numpy()

print(f"Encoded {len(all_embeddings)} recipes to {all_embeddings.shape[1]}D latent space.")

# Diet label per recipe from the raw flag columns 9 (vegetarian), 10 (vegan)
# and 11 (gluten-free). Precedence: Vegan > Vegetarian > Gluten-Free > Standard.
is_vegan = features_raw[:, 10] > 0.5
is_vegetarian = features_raw[:, 9] > 0.5
is_gluten_free = features_raw[:, 11] > 0.5
labels = np.where(
    is_vegan, "Vegan",
    np.where(
        is_vegetarian, "Vegetarian",
        np.where(is_gluten_free, "Gluten-Free", "Standard"),
    ),
).tolist()

# Subsample for t-SNE if dataset is large (t-SNE is O(n^2))
max_tsne = 5000
if len(all_embeddings) <= max_tsne:
    plot_embeddings = all_embeddings
    plot_labels = labels
else:
    rng = np.random.default_rng(42)
    sample_idx = rng.choice(len(all_embeddings), max_tsne, replace=False)
    plot_embeddings = all_embeddings[sample_idx]
    plot_labels = [labels[i] for i in sample_idx]
    print(f"Subsampled to {max_tsne} points for t-SNE.")

fig = plot_latent_space_2d(
    plot_embeddings,
    labels=plot_labels,
    title="VAE Latent Space — Recipes Colored by Diet Type",
)
plt.show()

# Print label distribution
from collections import Counter
label_counts = Counter(labels)
print("\nLabel distribution:")
n_labels = len(labels)
for diet, count in sorted(label_counts.items()):
    print(f"  {diet:<15s}: {count:>6,} ({count / n_labels:.1%})")

## 7. Reconstruction Quality

We evaluate how well the VAE reconstructs recipe features by comparing original
and decoded features for a sample of 5 recipes. We also compute per-feature MSE
across the entire validation set to identify which features are hardest to
reconstruct.

In [None]:
# --- Per-feature MSE on validation set ---
# Reconstruct from the posterior mean, decode(mu), rather than calling
# model(x). The forward pass draws a fresh random z on every call, even in
# eval mode, so its MSE would change from run to run.
model.eval()
with torch.no_grad():
    val_all = val_tensor.to(device)
    val_mu, _ = model.encode(val_all)
    recon_all = model.decode(val_mu)
    recon_np = recon_all.cpu().numpy()
    val_np = val_all.cpu().numpy()

per_feature_mse = ((recon_np - val_np) ** 2).mean(axis=0)

print("Per-feature reconstruction MSE (normalized space):")
print(f"{'Feature':<20s} {'MSE':>10s}")
print("-" * 32)
for name, mse in zip(feature_names, per_feature_mse):
    print(f"{name:<20s} {mse:>10.6f}")
print(f"{'TOTAL (avg)':<20s} {per_feature_mse.mean():>10.6f}")

# --- Sample comparison table (de-normalized to original scale) ---
print("\n" + "=" * 80)
print("Sample Reconstruction Comparison (original scale)")
print("=" * 80)

# Never ask for more samples than the validation set holds.
n_samples = min(5, len(val_np))
sample_indices = np.linspace(0, len(val_np) - 1, n_samples, dtype=int)

for idx in sample_indices:
    original_norm = val_np[idx]
    reconstructed_norm = recon_np[idx]

    # De-normalize back to original scale (inverse of the +1e-8 normalization)
    original_raw = original_norm * (feature_stds + 1e-8) + feature_means
    reconstructed_raw = reconstructed_norm * (feature_stds + 1e-8) + feature_means

    print(f"\nSample {idx}:")
    print(f"  {'Feature':<20s} {'Original':>12s} {'Reconstructed':>14s} {'Error':>10s}")
    print(f"  {'-'*20} {'-'*12} {'-'*14} {'-'*10}")
    for name, orig, rec in zip(feature_names, original_raw, reconstructed_raw):
        err = abs(orig - rec)
        print(f"  {name:<20s} {orig:>12.2f} {rec:>14.2f} {err:>10.2f}")

## 8. Export Weights

Export the trained PyTorch weights to the format expected by the NumPy-based
production `RecipeVAE` in `app/recommender/vae.py`.

**Critical: Weight Transposition**

PyTorch `nn.Linear` stores weights as `(out_features, in_features)` and computes
`y = x @ W.T + b`. The NumPy inference code computes `y = x @ W + b` where `W`
is `(in_features, out_features)`. Therefore we must **transpose** all weight
matrices during export:

| Weight | PyTorch shape | Exported (NumPy) shape |
|--------|--------------|------------------------|
| encoder_mu_w | (32, 12) | (12, 32) |
| encoder_logvar_w | (32, 12) | (12, 32) |
| decoder_w | (12, 32) | (32, 12) |

In [None]:
def _to_numpy(tensor):
    """Detach *tensor* from autograd and copy it into a host NumPy array."""
    return tensor.detach().cpu().numpy()


# nn.Linear keeps weights as (out_features, in_features), while the NumPy
# inference path computes x @ W + b. Every weight matrix is therefore
# transposed on export. Biases are 1-D and are exported unchanged.
weights = {
    'encoder_mu_w': _to_numpy(model.encoder_mu.weight).T,          # (32,12) -> (12,32)
    'encoder_mu_b': _to_numpy(model.encoder_mu.bias),
    'encoder_logvar_w': _to_numpy(model.encoder_logvar.weight).T,  # (32,12) -> (12,32)
    'encoder_logvar_b': _to_numpy(model.encoder_logvar.bias),
    'decoder_w': _to_numpy(model.decoder.weight).T,                # (12,32) -> (32,12)
    'decoder_b': _to_numpy(model.decoder.bias),
    # Normalization statistics the model inputs were standardized with.
    'feature_means': feature_means,
    'feature_stds': feature_stds,
}

# Print shapes so the transposition can be eyeballed before saving.
print("Exported weight shapes:")
for weight_key, array in weights.items():
    print(f"  {weight_key:<25s} {str(array.shape):<15s} dtype={array.dtype}")

save_path = save_vae_weights(weights)
print(f"\nWeights saved to: {save_path}")

## 9. Verification

We verify that the exported weights produce **numerically equivalent results**
(up to floating-point tolerance) when used in the NumPy inference path, which
mirrors `RecipeVAE.encode()` in production.

Specifically, we:
1. Load the saved `.npz` weights
2. Replicate the NumPy encode logic (`normalized @ W + b`)
3. Compare against the PyTorch model's deterministic output (mu)
4. Assert that results match within numerical tolerance (atol=1e-5)

In [None]:
# --- Load exported weights ---
# Read the weights back through load_vae_weights() so the checks below run on
# what load_vae_weights() returns, not on the in-memory `weights` dict.
loaded_weights = load_vae_weights()
print("Loaded weights from disk.")

# --- Replicate NumPy encode logic (mirrors RecipeVAE.encode in production) ---
def numpy_encode(features_raw, w):
    """Encode raw features with exported weights using pure NumPy.

    Intended to match app/recommender/vae.py ``RecipeVAE.encode()``:
    standardize with the stored ``feature_means``/``feature_stds`` (plus a
    1e-8 guard), then apply each affine encoder head as ``x @ W + b``.

    Returns ``(mu, logvar)``.
    """
    scale = w['feature_stds'] + 1e-8
    z_scored = (features_raw - w['feature_means']) / scale

    def _head(weight_key, bias_key):
        # One affine encoder head in the exported (in, out) weight layout.
        return z_scored @ w[weight_key] + w[bias_key]

    mean = _head('encoder_mu_w', 'encoder_mu_b')
    log_variance = _head('encoder_logvar_w', 'encoder_logvar_b')
    return mean, log_variance

# --- Select test samples ---
n_test = 100
test_raw = features_raw[:n_test]  # un-normalized raw features

# --- NumPy path ---
np_mu, np_logvar = numpy_encode(test_raw, loaded_weights)

# --- PyTorch path (normalize the same way, then encode) ---
model.eval()
test_normalized = (test_raw - feature_means) / (feature_stds + 1e-8)
test_tensor = torch.tensor(test_normalized, dtype=torch.float32).to(device)

with torch.no_grad():
    pt_mu, pt_logvar = model.encode(test_tensor)
    pt_mu = pt_mu.cpu().numpy()
    pt_logvar = pt_logvar.cpu().numpy()

# --- Compare ---
mu_max_diff = np.max(np.abs(np_mu - pt_mu))
logvar_max_diff = np.max(np.abs(np_logvar - pt_logvar))

print(f"\nVerification results ({n_test} test samples):")
print(f"  mu max absolute difference:     {mu_max_diff:.2e}")
print(f"  logvar max absolute difference:  {logvar_max_diff:.2e}")

# --- Assert ---
TOLERANCE = 1e-5
assert np.allclose(np_mu, pt_mu, atol=TOLERANCE), (
    f"mu mismatch! Max diff: {mu_max_diff:.2e} > tolerance {TOLERANCE:.2e}"
)
assert np.allclose(np_logvar, pt_logvar, atol=TOLERANCE), (
    f"logvar mismatch! Max diff: {logvar_max_diff:.2e} > tolerance {TOLERANCE:.2e}"
)

print(f"\n{'='*50}")
print(f"  Verification PASSED  (tolerance = {TOLERANCE:.0e})")
print(f"{'='*50}")
print(f"\nThe exported weights produce identical results in both")
print(f"the PyTorch training model and the NumPy inference path.")
print(f"The weights at '{save_path}' are ready for production use.")