# Example Training Loop

In [1]:
# Add project root to Python path to enable src imports.
# The root is taken to be the parent of the current working directory, so the
# kernel must be started from a directory that sits directly under the project
# root for the `src.*` imports below to resolve. The entry is appended, so any
# `src` package found earlier on sys.path takes precedence.
import sys
import os
sys.path.append(os.path.join(os.getcwd(), '..'))

import torch
from src.synthetic_datasets import InContextRecallDataset
from src.meta_optimizers import MetaSGD
from src.memory_module import inner_optimization_forward, TTT
from src.model_components import HyperparamModel, LearnableHyperparam


## Configuration Variables

In [2]:
# --- Data / model dimensions (passed to both the dataset and the inner model) ---
key_dim, val_dim = 20, 20
seq_len = 100
batch_size = 100

# --- Windowing ---
context_size = 5  # inner loss computation window
# The next two values are forwarded unchanged to inner_optimization_forward;
# their exact semantics are defined in src.memory_module.
outer_window_size = 20
offset = 10

# --- Optimizers ---
inner_opt = MetaSGD()                   # used as the inner optimizer
outer_optimizer_kwargs = {"lr": 1e-3}   # keyword arguments for the outer optimizer

# NOTE(review): not read by any cell visible in this notebook section.
output_corr = 0

## Training Loop with BPTT

In [3]:
# Use the GPU when torch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# The inner model's parameters are the hidden state we meta-learn to update.
# NOTE(review): the trailing positional arguments (1, 0.0) are defined by TTT's
# signature in src.memory_module -- confirm their meaning there.
inner_model = TTT(key_dim, val_dim, 1, 0.0).to(device)

# Learnable inner learning rate, plus one loss weight per position in the
# context window. The loss weights are a plain tensor of ones and are not among
# the parameters handed to the outer optimizer below.
lr_head = LearnableHyperparam().to(device)
loss_weight_head = torch.ones(context_size, device=device)

# Outer optimizer learns (1) the initial params of inner_model and (2) learning rate
outer_optim = torch.optim.SGD(
    [*inner_model.parameters(), *lr_head.parameters()],
    **outer_optimizer_kwargs
)

num_steps = 1000

# Outer (meta) training loop. Each step builds a new InContextRecallDataset,
# runs inner_optimization_forward on it, and backpropagates the returned loss
# into the parameters registered with outer_optim.
for step in range(num_steps):
    dataset = InContextRecallDataset(
        seq_len=seq_len,
        key_dim=key_dim,
        val_dim=val_dim,
        context_size=context_size,
        batch_size=batch_size,
    )

    outer_optim.zero_grad()
    loss, predictions = inner_optimization_forward(
        inner_model,
        dataset,
        inner_optimizer=inner_opt,
        inner_lr_head=lr_head,
        inner_loss_weight_head=loss_weight_head,
        outer_window_size=outer_window_size,
        offset=offset
    )
    loss.backward()
    outer_optim.step()

    # Log every 100 steps and on the final step. The reported loss was computed
    # before outer_optim.step(); the weight statistics and learning rate are
    # read after it.
    if step % 100 == 0 or step == num_steps - 1:
        with torch.no_grad():
            # Flatten and concatenate the inner-model weights once, then derive
            # both statistics from the same tensor instead of rebuilding it.
            flat_weights = torch.cat([w.flatten() for w in inner_model.weights])
            weight_mean = flat_weights.mean().item()
            weight_std = flat_weights.std().item()
            # NOTE(review): in the recorded run both statistics print as 0.000
            # at every logged step -- confirm that inner_model.weights are
            # actually updated by the outer optimizer.
            lr_val = lr_head()
            print(
                f"step {step:4d} | loss {loss.item():.4f} | lr {lr_val.item():.4f} | "
                f"inner params: weight_mean={weight_mean:.3f}, weight_std={weight_std:.3f}"
            )

step    0 | loss 378.1664 | lr 0.1006 | inner params: weight_mean=0.000, weight_std=0.000
step  100 | loss 376.5724 | lr 0.1424 | inner params: weight_mean=-0.000, weight_std=0.000
step  200 | loss 375.9243 | lr 0.1641 | inner params: weight_mean=0.000, weight_std=0.000
step  300 | loss 375.9665 | lr 0.1758 | inner params: weight_mean=-0.000, weight_std=0.000
step  400 | loss 376.1375 | lr 0.1823 | inner params: weight_mean=0.000, weight_std=0.000
step  500 | loss 375.7111 | lr 0.1861 | inner params: weight_mean=0.000, weight_std=0.000
step  600 | loss 376.0531 | lr 0.1884 | inner params: weight_mean=-0.000, weight_std=0.000
step  700 | loss 375.9659 | lr 0.1899 | inner params: weight_mean=0.000, weight_std=0.000
step  800 | loss 376.0570 | lr 0.1908 | inner params: weight_mean=0.000, weight_std=0.000
step  900 | loss 375.9998 | lr 0.1913 | inner params: weight_mean=0.000, weight_std=0.000
step  999 | loss 376.0276 | lr 0.1917 | inner params: weight_mean=0.000, weight_std=0.000
