In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from synthetic_datasets import InContextRecallDataset, BatchedInContextRecallDataset
from meta_optimizers import ManualSGD
from memory_module import unroll_with_inner_param_dict, TTT
from func_memory_module import HyperparamModel



In [2]:
# --- Synthetic in-context recall task ---
key_dim = 10
val_dim = 10
seq_len = 50
context_dim = 5     # window over which the inner loss is computed
recall_window = 1   # window used when computing the outer loss
num_sequences = 1500
batch_size = 1
output_corr = 0

# --- Optimizer settings ---
inner_optimizer = ManualSGD()
# Inner-loop hyperparameters are stored as tensors rather than Python floats.
inner_optimizer_kwargs = dict(lr=torch.tensor(0.01), beta=torch.tensor(0.9))
outer_optimizer_kwargs = dict(lr=0.01)

In [4]:
# Meta-training cell: build the inner model and the learning-rate head, then
# optimise them with an outer Adam loop over freshly sampled datasets.

# Loop settings, pulled out of the code below so they can be tuned in one place.
seed = 0
num_outer_steps = 1000
log_every = 100

# Seed before constructing any model so that both parameter initialisation and
# the per-step datasets are repeatable across Restart & Run All.
np.random.seed(seed)
torch.manual_seed(seed)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The inner model's parameters are the hidden state we meta-learn to update
inner_model = TTT(key_dim, val_dim, 1).to(device)

inner_opt = ManualSGD()
# Use global heads for simplicity; set in_dim=embed_dim if you want them context-dependent.
lr_head = HyperparamModel(key_dim, -4.5).to(device)
# Fixed (non-learned) loss weights, created directly on the target device.
# NOTE(review): the length 5 presumably has to match context_dim -- confirm
# before changing either value.
loss_weight_head = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0], device=device)

# Outer optimizer learns (1) the initial params of inner_model and (2) the
# parameters of lr_head. loss_weight_head is a plain tensor that is not passed
# to the optimizer, so it stays constant.
# NOTE(review): lr=1e-3 here differs from outer_optimizer_kwargs["lr"] (0.01)
# in the config cell; kept at 1e-3 to preserve the existing behaviour.
outer_optim = torch.optim.Adam(
    list(inner_model.parameters()) + list(lr_head.parameters()),
    lr=1e-3,
)

for step in range(num_outer_steps):
    # A new dataset is drawn at every outer step.
    dataset = BatchedInContextRecallDataset(
        seq_len=seq_len,
        key_dim=key_dim,
        val_dim=val_dim,
        context_size=context_dim,
        batch_size=batch_size,
    )

    outer_optim.zero_grad()
    loss = unroll_with_inner_param_dict(
        inner_model,
        dataset,
        inner_opt=inner_opt,
        lr_head=lr_head,
        loss_weight_head=loss_weight_head,
        inner_param_dict=inner_optimizer_kwargs,
    )
    loss.backward()
    outer_optim.step()

    if step % log_every == 0:
        # Logging only: evaluate the lr head without tracking gradients.
        with torch.no_grad():
            lr_val = lr_head(dataset.inputs[-1])
            print(f"step {step:4d} | loss {loss.item():.4f} | lr {lr_val[-1].item():.4f}")

step    0 | loss 198.3665 | lr 0.0102
step  100 | loss 197.9378 | lr 0.0131
step  200 | loss 198.0583 | lr 0.0104
step  300 | loss 197.1029 | lr 0.0115
step  400 | loss 196.3359 | lr 0.0119
step  500 | loss 199.9416 | lr 0.0128
step  600 | loss 197.2411 | lr 0.0098
step  700 | loss 195.9579 | lr 0.0102
step  800 | loss 199.2926 | lr 0.0098
step  900 | loss 197.8062 | lr 0.0087
