In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from synthetic_datasets import BatchedInContextRecallDataset
from meta_optimizers import MetaSGD
from memory_module import inner_optimization_forward, TTT
from model_components import HyperparamModel, LearnableHyperparam


In [2]:
# --- Experiment configuration -------------------------------------------------
# Every tunable value for the run lives in this cell.

# Seed the global NumPy and PyTorch RNGs so a Restart & Run All repeats the run.
# NOTE(review): assumes BatchedInContextRecallDataset and the model init draw
# from these global generators -- confirm against their implementations.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

key_dim = 10
val_dim = 10
context_size = 5   # inner loss computation window
seq_len = 50
batch_size = 100
recall_window = 5  # for outer loss computation
output_corr = 0    # not referenced by the cells visible here

# Not referenced by the training cell visible here, which builds its own
# MetaSGD(); kept so any other cell that reads this name still works.
inner_optimizer = MetaSGD()
# Optional inner-optimizer settings, currently disabled:
#   inner_optimizer_kwargs = {"beta": torch.tensor(0.9)}

# Keyword arguments forwarded verbatim to the outer torch.optim.SGD.
outer_optimizer_kwargs = {"lr": 0.001}

In [3]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The inner model's parameters are the hidden state we meta-learn to update.
inner_model = TTT(key_dim, val_dim, 1).to(device)

inner_opt = MetaSGD()

# A single global learnable hyperparameter supplies the inner learning rate.
# For a context-dependent rate, HyperparamModel(key_dim, -4.5) was used here
# previously; it is called with an input (e.g. dataset.inputs[-1]) rather than
# with no arguments, so the logging call below would have to change as well.
lr_head = LearnableHyperparam().to(device)

# Uniform, non-learned loss weights built directly on the target device.
# NOTE(review): the length is derived from context_size instead of a
# hand-written 5-element literal -- this assumes inner_optimization_forward
# expects one weight per position of the inner-loss window; confirm.
loss_weight_head = torch.ones(context_size, device=device)

# Outer optimizer learns (1) the initial params of inner_model and (2) the
# learning rate held by lr_head.
# NOTE(review): in the recorded run the logged weight_mean / weight_std barely
# move over 1000 steps while lr changes substantially; confirm that
# inner_model.parameters() really yields the tensors in inner_model.weights
# and that they receive gradients from the outer loss.
outer_optim = torch.optim.SGD(
    list(inner_model.parameters()) + list(lr_head.parameters()),
    **outer_optimizer_kwargs
)
num_steps = 1000

for step in range(num_steps):
    # Draw a fresh batch of synthetic sequences for every outer step.
    dataset = BatchedInContextRecallDataset(
        seq_len=seq_len,
        key_dim=key_dim,
        val_dim=val_dim,
        context_size=context_size,
        batch_size=batch_size,
    )

    outer_optim.zero_grad()
    loss, predictions = inner_optimization_forward(
        inner_model,
        dataset,
        inner_opt=inner_opt,
        lr_head=lr_head,
        loss_weight_head=loss_weight_head,
        outer_window_size=recall_window,
    )
    loss.backward()
    outer_optim.step()

    # Report every 100 steps and once more on the final step.
    if step % 100 == 0 or step == num_steps - 1:
        with torch.no_grad():
            # Flatten all inner weights once and reuse for both statistics.
            flat_weights = torch.cat([w.flatten() for w in inner_model.weights])
            weight_mean = flat_weights.mean().item()
            weight_std = flat_weights.std().item()
            lr_val = lr_head()
            print(
                f"step {step:4d} | loss {loss.item():.4f} | lr {lr_val.item():.4f} | "
                f"inner params: weight_mean={weight_mean:.3f}, weight_std={weight_std:.3f}"
            )

step    0 | loss 176.2836 | lr 0.1011 | inner params: weight_mean=0.001, weight_std=0.021
step  100 | loss 163.7454 | lr 0.2559 | inner params: weight_mean=0.001, weight_std=0.021
step  200 | loss 158.1222 | lr 0.4191 | inner params: weight_mean=0.001, weight_std=0.021
step  300 | loss 157.6905 | lr 0.4616 | inner params: weight_mean=0.001, weight_std=0.021
step  400 | loss 157.6535 | lr 0.4596 | inner params: weight_mean=0.001, weight_std=0.021
step  500 | loss 157.5878 | lr 0.4637 | inner params: weight_mean=0.001, weight_std=0.021
step  600 | loss 157.6816 | lr 0.4604 | inner params: weight_mean=0.001, weight_std=0.021
step  700 | loss 157.3551 | lr 0.4617 | inner params: weight_mean=0.001, weight_std=0.021
step  800 | loss 157.3085 | lr 0.4644 | inner params: weight_mean=0.001, weight_std=0.021
step  900 | loss 157.3561 | lr 0.4639 | inner params: weight_mean=0.001, weight_std=0.021
step  999 | loss 157.2093 | lr 0.4657 | inner params: weight_mean=0.001, weight_std=0.020
