In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import numpy as np
from torch.func import functional_call

from func_memory_module import TTTMLP, WeightModel, HyperparamModel
from synthetic_datasets import InContextRecallDataset
from meta_optimizers import ManualAdam
from losses import windowed_p_loss, windowed_recall_cross_entropy


# Meta-training script. For every epoch a fresh synthetic recall sequence is
# generated, and a differentiable inner loop (ManualAdam driven by
# create_graph=True gradients) adapts the memory module's fast weights one
# timestep at a time. The per-timestep outer (recall) losses are averaged and
# back-propagated through the whole unrolled inner loop so that a single AdamW
# step updates: the memory module's initial weights, the inner-loss weighting
# model, and the inner learning-rate model.

# --- Configuration ---
key_dim = 16
val_dim = 16
context_dim = 1  # Placeholder loss takes a single weight
seq_len = 50
num_epochs = 500
recall_window = 1  # Number of timesteps to score in the outer loss window
log_every = 50  # Print progress every `log_every` epochs
inner_beta1 = 0.95  # Fixed (not meta-learned) beta1 passed to the inner ManualAdam
inner_beta2 = 0.99  # Fixed (not meta-learned) beta2 passed to the inner ManualAdam

# --- Instantiate all meta-learning models ---
weight_model = WeightModel(key_dim, context_dim)
memory_module = TTTMLP(key_dim, val_dim)

lr_model = HyperparamModel(key_dim, initial_bias=-2.0)  # Start with LR ~0.12

# --- Instantiate optimizers ---
# The outer-loop optimizer that trains ALL meta-models, including the initial
# weights of memory_module (which only receive gradients because the fast
# weights below are cloned from them without detaching).
outer_optimizer = torch.optim.AdamW(
    list(memory_module.parameters())
    + list(weight_model.parameters())
    + list(lr_model.parameters()),
    lr=0.01,
)


# --- Loss Functions ---
inner_loss_func = windowed_p_loss
outer_loss_func = nn.CrossEntropyLoss()

print("--- Starting Training ---")
for epoch in range(num_epochs):
    # Generate a new task (sequence) for each epoch
    my_dataset = InContextRecallDataset(
        seq_len=seq_len,
        key_dim=key_dim,
        val_dim=val_dim,
        context_size=context_dim,
        output_corr=0.5,
    )

    # Reset the fast weights and inner optimizer state for the new task.
    # fast_model only supplies the architecture for functional_call; it is a
    # copy so nothing downstream can mutate memory_module itself.
    fast_model = copy.deepcopy(memory_module)
    # Clone (but do NOT detach) memory_module's parameters: keeping them in the
    # graph is what lets the outer loss reach memory_module's initial weights.
    # Detaching here would leave those parameters with grad=None forever, so
    # outer_optimizer would silently never update them.
    current_params = {name: p.clone() for name, p in memory_module.named_parameters()}
    inner_optimizer = ManualAdam()
    states = inner_optimizer.init_states(current_params)  # Pass the name -> tensor dict

    outer_optimizer.zero_grad()
    # Plain zero accumulator: it joins the graph through the additions below and
    # does not need to be a grad-requiring leaf of its own.
    total_outer_loss = torch.zeros(())

    # --- Inner Loop ---
    for i in range(seq_len):
        current_key, current_val = my_dataset[i]

        # 1. Get dynamic parameters from meta-models
        loss_weights = weight_model(current_key[-1])  # only the last (current) key
        hyperparams = {
            'lr': lr_model(current_key),
            'beta1': inner_beta1,
            'beta2': inner_beta2,
        }

        # 2. Calculate inner loss and gradients using the fast weights.
        # create_graph=True keeps the gradient computation differentiable so the
        # outer loss can back-propagate through every inner update.
        pred = functional_call(fast_model, current_params, current_key)
        inner_loss = inner_loss_func(pred.T, current_val.T, loss_weights)
        grad_tuple = torch.autograd.grad(inner_loss, tuple(current_params.values()), create_graph=True)
        grads = dict(zip(current_params.keys(), grad_tuple))

        # 3. Step the inner optimizer; returns new (non-leaf) fast weights.
        current_params, states = inner_optimizer.step(current_params, grads, states, **hyperparams)

        # 4. Calculate the outer loss with the updated fast weights, scored over
        # a window of `recall_window` timesteps at time_index=i.
        # NOTE(review): with recall_window = 1 this presumably only tests recall
        # of the current item — confirm against windowed_recall_cross_entropy.
        outer_loss_step = windowed_recall_cross_entropy(
            fast_model,
            current_params,
            my_dataset.inputs,
            my_dataset.targets,
            time_index=i,
            window_size=recall_window,
            loss_fn=outer_loss_func,
        )
        total_outer_loss = total_outer_loss + outer_loss_step

    # --- Outer Loop Update (Meta-Model Training) ---
    final_outer_loss = total_outer_loss / seq_len
    final_outer_loss.backward()
    outer_optimizer.step()

    if epoch % log_every == 0:
        print(f"Epoch {epoch} | Avg Outer Loss: {final_outer_loss.item():.4f}")
        # Print a sample of learned hyperparameters from the first step.
        # Monitoring only, so skip building an autograd graph.
        # NOTE(review): training feeds lr_model with my_dataset[i][0], while this
        # uses my_dataset.inputs[0]; confirm both have the same layout.
        with torch.no_grad():
            sample_key = my_dataset.inputs[0]
            lr_sample = lr_model(sample_key).item()
        print(f"  Sample Hyperparams -> LR: {lr_sample:.4f}")


--- Starting Training ---
Epoch 0 | Avg Outer Loss: 4.0738
  Sample Hyperparams -> LR: 0.1172
Epoch 50 | Avg Outer Loss: 3.9351
  Sample Hyperparams -> LR: 0.0989
Epoch 100 | Avg Outer Loss: 3.7417
  Sample Hyperparams -> LR: 0.1119
Epoch 150 | Avg Outer Loss: 3.9105
  Sample Hyperparams -> LR: 0.0890
Epoch 200 | Avg Outer Loss: 3.8611
  Sample Hyperparams -> LR: 0.0875
Epoch 250 | Avg Outer Loss: 3.8296
  Sample Hyperparams -> LR: 0.0721
Epoch 300 | Avg Outer Loss: 3.6973
  Sample Hyperparams -> LR: 0.0773
Epoch 350 | Avg Outer Loss: 3.7667
  Sample Hyperparams -> LR: 0.0836
Epoch 400 | Avg Outer Loss: 3.9366
  Sample Hyperparams -> LR: 0.0660
Epoch 450 | Avg Outer Loss: 3.7158
  Sample Hyperparams -> LR: 0.0951
