# Example Training Loop

In [None]:
# Add project root to Python path to enable src imports
import sys
import os
sys.path.append(os.path.join(os.getcwd(), '..'))

import torch
from src.synthetic_datasets import InContextRecallDataset
from src.meta_optimizers import MetaSGD
from src.forward_loop import inner_optimization_forward
from src.model_components import LearnableHyperparam, TTT


## Configuration Variables

In [2]:
# --- Task / data shape (consumed by InContextRecallDataset and TTT) --------
key_dim, val_dim = 20, 20
seq_len = 50
batch_size = 100
context_size = 5  # inner loss computation window

# --- Arguments forwarded to inner_optimization_forward ----------------------
outer_window_size = 20
offset = 10

# NOTE(review): output_corr is not referenced by the training-loop cell in
# this notebook; confirm it is still needed.
output_corr = 0

# --- Optimizers --------------------------------------------------------------
# Keyword arguments for the outer torch.optim.SGD over the learnable LR head.
outer_optimizer_kwargs = {"lr": 1e-3}
inner_opt = MetaSGD()

## Training Loop with BPTT (Backpropagation Through Time)

In [None]:
# Place everything on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Inner model: its parameters are the hidden state that the inner optimizer
# updates during each forward pass and that we meta-learn to update.
# NOTE(review): this single instance is reused for every outer step; confirm
# inner_optimization_forward does not modify its parameters in place,
# otherwise state carries over from one outer step to the next.
inner_model = TTT(key_dim, val_dim, 1, 0.0).to(device)

# Meta-learned inner learning rate, plus a fixed (never optimized) all-ones
# weight vector of length context_size for the inner loss terms.
lr_head = LearnableHyperparam().to(device)
loss_weight_head = torch.ones(context_size, device=device)

# The outer optimizer only sees lr_head, so only the inner learning rate is
# meta-trained.
# NOTE(review): outer_optim.zero_grad() therefore clears lr_head grads only;
# if inner_model's parameters require grad, loss.backward() keeps
# accumulating into their .grad across steps. Confirm that is harmless.
outer_optim = torch.optim.SGD(lr_head.parameters(), **outer_optimizer_kwargs)

num_steps = 1000
log_every = 100
last_step = num_steps - 1

for step in range(num_steps):
    # Draw a fresh synthetic in-context-recall batch every outer step.
    # NOTE(review): no device is passed here; confirm the dataset or
    # inner_optimization_forward moves the data to `device` on CUDA.
    dataset = InContextRecallDataset(
        seq_len=seq_len,
        key_dim=key_dim,
        val_dim=val_dim,
        context_size=context_size,
        batch_size=batch_size,
    )

    # Outer update: reset meta-gradients, run the inner optimization, then
    # backpropagate the outer loss into lr_head and take an SGD step.
    outer_optim.zero_grad()
    loss, predictions = inner_optimization_forward(
        inner_model,
        dataset,
        inner_optimizer=inner_opt,
        inner_lr_head=lr_head,
        inner_loss_weight_head=loss_weight_head,
        outer_window_size=outer_window_size,
        offset=offset,
    )
    loss.backward()
    outer_optim.step()

    # Report every `log_every` steps and always on the final step.
    if not (step == last_step or step % log_every == 0):
        continue
    with torch.no_grad():
        lr_val = lr_head()
    print(f"step {step:4d} | loss {loss.item():.4f} | inner_lr {lr_val.item():.4f}")

step    0 | loss 138.1478 | inner_lr 0.1003
step  100 | loss 137.6427 | inner_lr 0.1259
step  200 | loss 137.2064 | inner_lr 0.1453
step  300 | loss 136.8836 | inner_lr 0.1599
step  400 | loss 136.9148 | inner_lr 0.1710
step  500 | loss 137.0769 | inner_lr 0.1795
step  600 | loss 136.8754 | inner_lr 0.1859
step  700 | loss 137.0741 | inner_lr 0.1910
step  800 | loss 136.6629 | inner_lr 0.1950
step  900 | loss 136.7405 | inner_lr 0.1984
step  999 | loss 137.0141 | inner_lr 0.2010
