In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchopt

from losses import windowed_p_loss
from func_memory_module import LinearModel, TTTMLP, WeightModel
from synthetic_datasets import InContextRecallDataset


# ---- Problem sizes --------------------------------------------------------
# key_dim / val_dim: widths of the key and value vectors.
# hidden_dim: NOTE(review): not passed to any model in this cell — confirm it is
#             used elsewhere before removing it.
# context_dim: number of context positions; WeightModel is built with it.
key_dim, val_dim, hidden_dim, context_dim = 10, 10, 20, 5

# ---- Models ---------------------------------------------------------------
# Construction order is kept fixed: both constructors may consume the global RNG
# for parameter initialisation.
weight_model = WeightModel(key_dim, context_dim)  # meta-learned (outer loop)
memory_module = TTTMLP(key_dim, val_dim)          # adapted by inner_optimizer

# ---- Optimizers -----------------------------------------------------------
# Inner: torchopt MetaSGD acting directly on the memory module.
# Outer: AdamW acting only on the weight model's parameters.
inner_optimizer = torchopt.MetaSGD(memory_module, lr=0.3)
outer_optimizer = torch.optim.AdamW(weight_model.parameters(), lr=0.05)

# ---- Objectives -----------------------------------------------------------
inner_loss_func, outer_loss_func = windowed_p_loss, nn.CrossEntropyLoss()

# ---- Training-loop configuration ------------------------------------------
seq_len = 20  # number of (key, value) steps per sampled sequence

# Snapshot the starting memory / inner-optimizer state so every training epoch
# can be restored to exactly this point.
memory_module_init = torchopt.extract_state_dict(memory_module)
inner_optimizer_init = torchopt.extract_state_dict(inner_optimizer)

for epoch in range(100):
    # Sample a fresh in-context recall sequence every epoch, so the weight model
    # is trained across many sequences rather than a single one.
    my_dataset = InContextRecallDataset(seq_len=seq_len, key_dim=key_dim, val_dim=val_dim, context_size=context_dim, output_corr=0.5)

    # Value vectors of the whole sequence; they are the candidate set scored by
    # the retrieval (outer) loss below. Presumably (seq_len, val_dim) — TODO confirm.
    _, val_vectors = my_dataset[:]

    # Reset model and optimizer to initial state so every epoch adapts the same
    # starting memory.
    torchopt.recover_state_dict(memory_module, memory_module_init)
    torchopt.recover_state_dict(inner_optimizer, inner_optimizer_init)
    # The outer backward() also flows into the memory module's initial
    # parameters (they produce `pred` at the first step), but outer_optimizer
    # neither steps nor zeroes them. Clear their .grad here, while the module
    # holds those original tensors again, so it does not silently accumulate
    # across epochs.
    memory_module.zero_grad(set_to_none=True)

    outer_optimizer.zero_grad()
    total_outer_loss = torch.tensor(0.0)

    # inner loop: one memory update per sequence position
    for i in range(seq_len):
        # Named `context_keys` (not `input`) so the builtin input() is not
        # shadowed for the rest of the session.
        context_keys, output = my_dataset[i]
        loss_weights = weight_model(context_keys[-1])  # should only rely on current key in context
        if epoch % 10 == 0 and i % 3 == 0:
            print(f"Epoch {epoch}, Step {i}, Loss Weights: {loss_weights}")
        pred = memory_module(context_keys)

        inner_loss = inner_loss_func(pred.T, output.T, loss_weights)  # the loss has a transposed shape
        # weight_model influences the outer loss only through this update, so
        # the step must stay differentiable for backward() to reach it.
        inner_optimizer.step(inner_loss)

        # Query the updated memory with the current key only.
        # NOTE(review): expected shape (1, val_dim) — TODO confirm.
        pred_after_update = memory_module(context_keys[-1])
        # Dot-product score against every value in the sequence; (1, seq_len)
        # if the shape above holds.
        logits = torch.matmul(pred_after_update, val_vectors.T)
        # The correct answer for step i is val_vectors[i].
        target_index = torch.tensor([i], dtype=torch.long)  # shape (1,)
        outer_loss_step = outer_loss_func(logits, target_index)
        # Out-of-place sum: keeps the running total an ordinary graph node
        # instead of mutating the initial leaf tensor in place.
        total_outer_loss = total_outer_loss + outer_loss_step

    # outer loop: average the retrieval loss over the sequence, update weight_model
    final_outer_loss = total_outer_loss / seq_len
    final_outer_loss.backward()
    outer_optimizer.step()

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Outer Loss: {final_outer_loss.item()}")


Epoch 0, Step 0, Loss Weights: tensor([0.5236, 0.4659, 0.5576, 0.4184, 0.4073], grad_fn=<SigmoidBackward0>)
Epoch 0, Step 3, Loss Weights: tensor([0.6497, 0.4965, 0.5808, 0.4881, 0.4623], grad_fn=<SigmoidBackward0>)
Epoch 0, Step 6, Loss Weights: tensor([0.5867, 0.3844, 0.4811, 0.5296, 0.4653], grad_fn=<SigmoidBackward0>)
Epoch 0, Step 9, Loss Weights: tensor([0.4678, 0.5109, 0.5481, 0.5413, 0.4111], grad_fn=<SigmoidBackward0>)
Epoch 0, Step 12, Loss Weights: tensor([0.4525, 0.4480, 0.4397, 0.4338, 0.4029], grad_fn=<SigmoidBackward0>)
Epoch 0, Step 15, Loss Weights: tensor([0.3771, 0.3392, 0.3810, 0.6110, 0.6096], grad_fn=<SigmoidBackward0>)
Epoch 0, Step 18, Loss Weights: tensor([0.6107, 0.6313, 0.5277, 0.4985, 0.4253], grad_fn=<SigmoidBackward0>)
Epoch 0, Outer Loss: 2.6988954544067383
Epoch 10, Step 0, Loss Weights: tensor([0.4521, 0.5026, 0.3426, 0.6472, 0.7512], grad_fn=<SigmoidBackward0>)
Epoch 10, Step 3, Loss Weights: tensor([0.3757, 0.3497, 0.2867, 0.4909, 0.5674], grad_fn=<Si