In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from synthetic_datasets import BatchedInContextRecallDataset
from meta_optimizers import MetaSGD
from memory_module import inner_optimization_forward, TTT
from func_memory_module import HyperparamModel, LearnableHyperparam


In [None]:
# Configuration for the in-context recall meta-learning experiment.
# Seed every RNG the experiment touches so Restart & Run All reproduces the run
# (model init and the synthetic dataset both draw random numbers).
SEED = 0
torch.manual_seed(SEED)  # also seeds CUDA devices
np.random.seed(SEED)

key_dim = 10      # key size; passed to BatchedInContextRecallDataset and TTT
val_dim = 10      # value size; passed to BatchedInContextRecallDataset and TTT
context_dim = 5   # inner loss computation window (passed as the dataset's `context_size`)
seq_len = 50
num_sequences = 1000  # NOTE(review): not used by the training cell below — confirm where it is needed
batch_size = 100
recall_window = 1     # window for outer loss computation; TODO: not referenced in the visible cells — where does this get used?
output_corr = 0       # NOTE(review): not used in the visible cells — confirm

# NOTE(review): the training cell constructs its own MetaSGD(), so this instance is unused there.
inner_optimizer = MetaSGD()
# Passed to inner_optimization_forward as `inner_param_dict`.
inner_optimizer_kwargs = {"lr": torch.tensor(0.01), "beta": torch.tensor(0.9)}
# NOTE(review): unused — the training cell hardcodes AdamW(lr=1e-3), not 0.01. Reconcile the two.
outer_optimizer_kwargs = {"lr": 0.01}

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Meta-learning loop: on every outer step, `inner_optimization_forward` adapts
# `inner_model` on a batch using `inner_opt`. The outer optimizer then updates
# the model's initial parameters and the learnable lr head from the resulting loss.
num_outer_steps = 1000  # number of outer (meta) optimization steps
log_every = 100         # print diagnostics every `log_every` outer steps

# The inner model's parameters are the hidden state we meta-learn to update
inner_model = TTT(key_dim, val_dim, 1).to(device)

inner_opt = MetaSGD()
# Global (input-independent) learnable lr. A context-dependent alternative is
# `HyperparamModel(key_dim, -4.5)`, called as `lr_head(dataset.inputs[-1])`.
lr_head = LearnableHyperparam().to(device)
# Fixed, NON-learnable 1D loss weights (not given to the outer optimizer).
# Length tied to context_dim (previously a hardcoded 5 == context_dim);
# presumably one weight per position of the inner loss window — TODO confirm
# against inner_optimization_forward.
loss_weight_head = torch.ones(context_dim, device=device)

# Outer optimizer learns (1) the initial params of inner_model and (2) lr_head.
# NOTE(review): the config cell's `outer_optimizer_kwargs` says lr=0.01 but is not
# used; lr=1e-3 is kept here to preserve current behavior — reconcile the two.
outer_optim = torch.optim.AdamW(
    list(inner_model.parameters()) + list(lr_head.parameters()),
    lr=1e-3,
)

for step in range(num_outer_steps):
    # A new dataset object is built every outer step.
    # NOTE(review): nothing here moves its tensors to `device`; confirm the dataset
    # or inner_optimization_forward handles placement when running on CUDA.
    dataset = BatchedInContextRecallDataset(
        seq_len=seq_len,
        key_dim=key_dim,
        val_dim=val_dim,
        context_size=context_dim,
        batch_size=batch_size,
    )

    outer_optim.zero_grad()
    loss, predictions = inner_optimization_forward(
        inner_model,
        dataset,
        inner_opt=inner_opt,
        lr_head=lr_head,
        loss_weight_head=loss_weight_head,
        inner_param_dict=inner_optimizer_kwargs,
    )
    loss.backward()
    outer_optim.step()

    if step % log_every == 0:
        with torch.no_grad():
            # Concatenate all inner weights once and reuse for both statistics.
            flat_weights = torch.cat([w.flatten() for w in inner_model.weights])
            weight_mean = flat_weights.mean().item()
            weight_std = flat_weights.std().item()
            lr_val = lr_head()
            print(f"step {step:4d} | loss {loss.item():.4f} | lr {lr_val.item():.4f} | inner params: weight_mean={weight_mean:.3f}, weight_std={weight_std:.3f}")

step    0 | loss 193.3253 | lr 0.0997 | inner params: weight_norms=['0.210'], bias_norms=['0.003'], weight_mean=-0.001, weight_std=0.021
step  100 | loss 192.1266 | lr 0.1000 | inner params: weight_norms=['0.154'], bias_norms=['0.023'], weight_mean=-0.001, weight_std=0.015
step  200 | loss 192.3997 | lr 0.1000 | inner params: weight_norms=['0.118'], bias_norms=['0.025'], weight_mean=-0.001, weight_std=0.012
step  300 | loss 192.2317 | lr 0.1013 | inner params: weight_norms=['0.109'], bias_norms=['0.023'], weight_mean=0.001, weight_std=0.011
step  400 | loss 192.3799 | lr 0.1012 | inner params: weight_norms=['0.103'], bias_norms=['0.018'], weight_mean=-0.001, weight_std=0.010
step  500 | loss 192.4138 | lr 0.1006 | inner params: weight_norms=['0.096'], bias_norms=['0.023'], weight_mean=-0.001, weight_std=0.010
step  600 | loss 193.2105 | lr 0.1010 | inner params: weight_norms=['0.100'], bias_norms=['0.022'], weight_mean=-0.000, weight_std=0.010
step  700 | loss 191.9583 | lr 0.1016 | in