In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchopt

from losses import windowed_p_loss
from func_memory_module import LinearModel, TTTMLP, WeightModel
from synthetic_datasets import InContextRecallDataset

# ---------------------------------------------------------------------------
# Meta-learning setup (TorchOpt supplies the differentiable inner optimizer)
# ---------------------------------------------------------------------------

# Vector dimensions.
key_dim, val_dim = 10, 10
hidden_dim = 20  # not referenced in this cell; kept in case other cells use it
context_dim = 5

# Number of steps in each generated sequence (one inner update per step).
seq_len = 100

# Models. Construction order is kept (weight model first, memory module
# second) so that any random initialisation is drawn in the same order.
weight_model = WeightModel(key_dim, context_dim)
memory_module = TTTMLP(key_dim, val_dim)

# Optimizers: AdamW updates the weight model in the outer loop, while
# MetaSGD updates the memory module in the inner loop.
outer_optimizer = torch.optim.AdamW(weight_model.parameters(), lr=0.05)
inner_optimizer = torchopt.MetaSGD(memory_module, lr=0.3)

# Loss functions for the outer and inner objectives.
outer_loss_func = nn.CrossEntropyLoss()
inner_loss_func = windowed_p_loss

# Snapshot the starting state of the memory module and its optimizer so
# both can be restored at the beginning of every epoch.
memory_module_init, inner_optimizer_init = (
    torchopt.extract_state_dict(memory_module),
    torchopt.extract_state_dict(inner_optimizer),
)

# Meta-training loop.
#
# Every epoch samples a fresh sequence, restores the memory module and the
# inner optimizer to their saved initial state, runs one inner MetaSGD step
# per sequence position, and accumulates a cross-entropy recall loss after
# each step. The averaged loss is then back-propagated and only the weight
# model is stepped by the outer optimizer.
for epoch in range(100):
    # Generate a new sequence of keys and values for each epoch
    my_dataset = InContextRecallDataset(
        seq_len=seq_len,
        key_dim=key_dim,
        val_dim=val_dim,
        context_size=context_dim,
        output_corr=0.5,
    )

    # Second element of the full slice; used below as the set of candidates
    # the post-update prediction is scored against.
    _, val_vectors = my_dataset[:]

    # Reset model and optimizer to initial state
    torchopt.recover_state_dict(memory_module, memory_module_init)
    torchopt.recover_state_dict(inner_optimizer, inner_optimizer_init)

    outer_optimizer.zero_grad()
    total_outer_loss = torch.tensor(0.0)

    # inner loop
    for i in range(seq_len):
        # Deliberately not named `input`/`output`: this code runs at notebook
        # top level, so binding `input` would replace the builtin for every
        # later cell in the session.
        context_keys, context_vals = my_dataset[i]

        # Last entry of the context; the loss weights should only rely on
        # the current key in the context.
        current_key = context_keys[-1]
        loss_weights = weight_model(current_key)
        if epoch % 10 == 0 and i % 50 == 0:
            print(f"Epoch {epoch}, Step {i}, Loss Weights: {loss_weights}")

        pred = memory_module(context_keys)

        # The inner loss takes its prediction/target arguments transposed.
        inner_loss = inner_loss_func(pred.T, context_vals.T, loss_weights)
        # Differentiable update of the memory module (TorchOpt MetaSGD).
        inner_optimizer.step(inner_loss)

        # Query the updated memory with the current key only and score the
        # result against every value vector of the sequence.
        # NOTE(review): the original author expected logits of shape
        # (1, seq_len) here - confirm against the TTTMLP output shape.
        pred_after_update = memory_module(current_key)
        logits = torch.matmul(pred_after_update, val_vectors.T)
        target_index = torch.tensor([i], dtype=torch.long)  # shape (1,)

        outer_loss_step = outer_loss_func(logits, target_index)
        total_outer_loss += outer_loss_step

    # outer loop calculation: average over the sequence, back-propagate,
    # then update the weight model.
    final_outer_loss = total_outer_loss / seq_len
    final_outer_loss.backward()
    outer_optimizer.step()

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Outer Loss: {final_outer_loss.item()}")


Epoch 0, Step 0, Loss Weights: tensor([0.3706, 0.4795, 0.6065, 0.4579, 0.4516], grad_fn=<SigmoidBackward0>)
Epoch 0, Step 50, Loss Weights: tensor([0.5511, 0.3821, 0.5415, 0.5142, 0.7165], grad_fn=<SigmoidBackward0>)
Epoch 0, Outer Loss: 4.188896656036377
Epoch 10, Step 0, Loss Weights: tensor([0.4109, 0.2953, 0.3787, 0.6504, 0.7293], grad_fn=<SigmoidBackward0>)
Epoch 10, Step 50, Loss Weights: tensor([0.4633, 0.4141, 0.5108, 0.5672, 0.6180], grad_fn=<SigmoidBackward0>)
Epoch 10, Outer Loss: 4.134033679962158
Epoch 20, Step 0, Loss Weights: tensor([0.2572, 0.4298, 0.5219, 0.7240, 0.6554], grad_fn=<SigmoidBackward0>)
Epoch 20, Step 50, Loss Weights: tensor([0.2516, 0.2383, 0.5506, 0.7069, 0.7413], grad_fn=<SigmoidBackward0>)
Epoch 20, Outer Loss: 4.033472061157227
Epoch 30, Step 0, Loss Weights: tensor([0.2045, 0.1867, 0.2012, 0.8190, 0.7957], grad_fn=<SigmoidBackward0>)
Epoch 30, Step 50, Loss Weights: tensor([0.1665, 0.1707, 0.3483, 0.7525, 0.7909], grad_fn=<SigmoidBackward0>)
Epoch 3