# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import numpy as np
from func_memory_module import TTTMLP, WeightModel, HyperparamModel
from synthetic_datasets import InContextRecallDataset
from meta_optimizers import ManualAdam
from losses import windowed_p_loss


# --- Configuration ---
# Sizes of the synthetic key/value vectors, task length and training length.
key_dim = 16
val_dim = 16
# The placeholder inner loss takes a single weight; presumably this is the
# number of loss weights WeightModel emits per key (TODO confirm).
context_dim = 1
seq_len = 20
num_epochs = 101

# --- Meta-learned models ---
# Instantiation order is significant for reproducible random initialisation.
weight_model = WeightModel(key_dim, context_dim)
memory_module = TTTMLP(key_dim, val_dim)

# Per-key inner-optimizer hyperparameter predictors. The annotated starting
# values assume HyperparamModel applies a sigmoid to its biased output
# (sigmoid(-2) ~ 0.12, sigmoid(2) ~ 0.88, sigmoid(3) ~ 0.95) - TODO confirm.
lr_model = HyperparamModel(key_dim, initial_bias=-2.0)    # lr    ~ 0.12
beta1_model = HyperparamModel(key_dim, initial_bias=2.0)  # beta1 ~ 0.88
beta2_model = HyperparamModel(key_dim, initial_bias=3.0)  # beta2 ~ 0.95

# --- Outer-loop optimizer ---
# One AdamW instance trains every meta-model, memory_module included.
outer_optimizer = torch.optim.AdamW(
    [
        param
        for model in (memory_module, weight_model, lr_model, beta1_model, beta2_model)
        for param in model.parameters()
    ],
    lr=0.01,
)

# --- Loss functions ---
# Inner loss adapts the fast model per step; outer loss scores recall.
inner_loss_func = windowed_p_loss
outer_loss_func = nn.CrossEntropyLoss()

print("--- Starting Meta-Training ---")
for epoch in range(num_epochs):
    # Each epoch samples a fresh task: a new sequence of key/value pairs.
    my_dataset = InContextRecallDataset(seq_len=seq_len, key_dim=key_dim, val_dim=val_dim)
    # Transposed value matrix used to score predictions against every value of
    # the task; it is fixed for the whole epoch, so build it once.
    val_matrix_t = my_dataset.val_vectors.T

    # Reset the inner model and inner-optimizer state for the new task.
    # copy.deepcopy creates brand-new leaf Parameters, so autograd has no edge
    # from fast_model back to memory_module. Keep handles on the fast copy's
    # starting parameters so their gradients can be copied onto memory_module
    # after the outer backward pass.
    fast_model = copy.deepcopy(memory_module)
    fast_init_params = list(fast_model.parameters())
    inner_optimizer = ManualAdam(fast_model.parameters())  # optimizer for fast_model

    outer_optimizer.zero_grad()
    # Plain accumulator: it does not need to be a grad-requiring leaf, the
    # summed step losses already carry the graph.
    total_outer_loss = torch.zeros(())

    # --- Inner loop (task adaptation) ---
    for i in range(seq_len):
        current_key, current_val = my_dataset[i]
        current_key = current_key.unsqueeze(0)  # add batch dimension

        # 1. Per-step loss weights and inner-optimizer hyperparameters,
        #    predicted by the meta-models from the current key.
        loss_weights = weight_model(current_key)
        hyperparams = {
            'lr': lr_model(current_key),
            'beta1': beta1_model(current_key),
            'beta2': beta2_model(current_key),
        }

        # 2. Inner loss on the fast model. create_graph=True keeps the
        #    gradients differentiable so the outer loss can reach the
        #    meta-models through the inner update.
        pred = fast_model(current_key)
        inner_loss = inner_loss_func(pred, current_val.unsqueeze(0), loss_weights)
        grads = torch.autograd.grad(inner_loss, fast_model.parameters(), create_graph=True)

        # 3. Apply one inner-optimizer step to fast_model.
        inner_optimizer.step(grads, **hyperparams)

        # 4. Outer loss: after the update, the prediction for key i should
        #    score highest against value i among all of the task's values.
        pred_after_update = fast_model(current_key)
        logits = torch.matmul(pred_after_update, val_matrix_t)
        target_index = torch.tensor([i], dtype=torch.long)
        outer_loss_step = outer_loss_func(logits, target_index)
        total_outer_loss = total_outer_loss + outer_loss_step

    # --- Outer-loop update (meta-model training) ---
    final_outer_loss = total_outer_loss / seq_len
    final_outer_loss.backward()

    # Hand the gradient that landed on the fast copy's starting parameters back
    # to memory_module. Otherwise its parameters never get a .grad and AdamW
    # skips them. NOTE(review): whether this is the exact meta-gradient or a
    # first-order approximation depends on whether ManualAdam updates
    # parameters in place or out of place - TODO confirm in meta_optimizers.
    for meta_param, init_param in zip(memory_module.parameters(), fast_init_params):
        if init_param.grad is not None:
            meta_param.grad = init_param.grad.detach().clone()

    outer_optimizer.step()

    if epoch % 10 == 0:
        print(f"Epoch {epoch} | Avg Outer Loss: {final_outer_loss.item():.4f}")
        # Sample the learned hyperparameters for the task's first key.
        # Diagnostics only, so no autograd graph is needed.
        with torch.no_grad():
            sample_key = my_dataset.key_vectors[0].unsqueeze(0)
            lr_sample = lr_model(sample_key).item()
            b1_sample = beta1_model(sample_key).item()
            b2_sample = beta2_model(sample_key).item()
        print(f"  Sample Hyperparams -> LR: {lr_sample:.4f}, B1: {b1_sample:.4f}, B2: {b2_sample:.4f}")
