In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchopt

# Vector widths shared by the two networks defined below.
key_dim = val_dim = 10  # width of a key vector and of a target value vector
hidden_dim = 20  # width of G_Net's output, which F_Net concatenates with the key
context_dim = 5  # width of G_Net's internal layer

# Define the two neural networks, f and g
class F_Net(nn.Module):
    """Inner-loop network: predicts a value vector from a key and G_Net's output.

    The key and the conditioning vector are concatenated along the last
    dimension, passed through an 8-unit tanh layer, and projected to
    ``val_dim`` outputs.
    """

    def __init__(self):
        super().__init__()
        # Input width is the key plus the conditioning vector.
        self.fc1 = nn.Linear(key_dim + hidden_dim, 8)
        self.fc2 = nn.Linear(8, val_dim)

    def forward(self, key, g_out):
        """Return the prediction for ``key`` conditioned on ``g_out``.

        The two inputs are joined on their last dimension, so they must agree
        on every leading dimension.
        """
        joined = torch.cat((key, g_out), dim=-1)
        hidden = F.tanh(self.fc1(joined))
        return self.fc2(hidden)

class G_Net(nn.Module):
    """Network that maps a key to the conditioning vector consumed by F_Net.

    A ``context_dim``-wide sigmoid layer is followed by a linear projection
    to ``hidden_dim`` outputs.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(key_dim, context_dim)
        self.fc2 = nn.Linear(context_dim, hidden_dim)

    def forward(self, key):
        """Return a ``hidden_dim``-wide conditioning vector for ``key``."""
        context = F.sigmoid(self.fc1(key))
        return self.fc2(context)

# Networks: f_net is adapted inside the inner loop and is conditioned on
# g_net, which is only updated by the outer loop.
f_net = F_Net()
g_net = G_Net()

# Meta-optimizer that performs the inner-loop updates of f_net.
inner_optimizer = torchopt.MetaAdam(f_net, lr=1e-3)

# Regular optimizer that performs the outer-loop updates of g_net.
outer_optimizer = torch.optim.Adam(g_net.parameters(), lr=1e-4)

# Inner objective: mean squared error between prediction and target.
inner_loss_fn = nn.MSELoss()


def outer_loss_fn(pred, target):
    """Return one minus the mean cosine similarity of ``pred`` and ``target``."""
    similarity = F.cosine_similarity(pred, target)
    return 1 - similarity.mean()

# --- Training Loop ---
#
# Bi-level optimisation:
#   * inner loop: f_net takes one MetaAdam step per (key, val) batch;
#   * outer loop: g_net is updated from the outer loss accumulated over the
#     whole sequence, back-propagated through the inner updates.

# Dummy data: one (keys, vals) batch per inner step.
num_steps = 5
batch_size = 4
sequence_of_keys = [torch.randn(batch_size, key_dim) for _ in range(num_steps)]
sequence_of_vals = [torch.randn(batch_size, val_dim) for _ in range(num_steps)]

# Snapshot f_net and the inner optimizer so that every outer iteration starts
# its inner loop from the same initial state.
f_net_initial_state = torchopt.extract_state_dict(f_net)
inner_optim_initial_state = torchopt.extract_state_dict(inner_optimizer)

# Number of outer loop iterations
for outer_iter in range(100):
    # Rewind f_net and the inner optimizer to the snapshot taken above.
    torchopt.recover_state_dict(f_net, f_net_initial_state)
    torchopt.recover_state_dict(inner_optimizer, inner_optim_initial_state)

    outer_optimizer.zero_grad()

    # Running sum of the per-step outer losses for this sequence.
    total_outer_loss = torch.tensor(0.0)

    # --- Inner Loop ---
    for key, val in zip(sequence_of_keys, sequence_of_vals):
        # g_net is not updated inside the inner loop, so one forward pass
        # serves both the inner loss and the outer loss of this step.
        g_output = g_net(key)

        # Inner update: one optimizer step on f_net. The loss depends on
        # g_output, so the update itself depends on g_net.
        f_output = f_net(key, g_output)
        inner_loss = inner_loss_fn(f_output, val)
        inner_optimizer.step(inner_loss)

        # Outer loss: evaluate the *updated* f_net on the same batch. Must
        # run with gradients enabled so the loss can reach g_net.
        f_output_after_update = f_net(key, g_output)
        step_outer_loss = outer_loss_fn(f_output_after_update, val)
        total_outer_loss = total_outer_loss + step_outer_loss

    # --- Outer Loop Loss Calculation and Update ---

    # Average over the steps so the gradient magnitude does not scale with
    # the sequence length, then back-propagate through all inner steps.
    # NOTE(review): backward() presumably also accumulates .grad on f_net's
    # initial parameters; nothing in this script reads or zeroes them.
    final_outer_loss = total_outer_loss / num_steps
    final_outer_loss.backward()

    # Update g_net's parameters
    outer_optimizer.step()

    if (outer_iter + 1) % 10 == 0:
        print(f"Outer Iteration: {outer_iter + 1}, Avg Outer Loss: {final_outer_loss.item():.4f}")

print("Training finished.")

Outer Iteration: 10, Avg Outer Loss: 1.0258
Outer Iteration: 20, Avg Outer Loss: 1.0253
Outer Iteration: 30, Avg Outer Loss: 1.0248
Outer Iteration: 40, Avg Outer Loss: 1.0242
Outer Iteration: 50, Avg Outer Loss: 1.0237
Outer Iteration: 60, Avg Outer Loss: 1.0230
Outer Iteration: 70, Avg Outer Loss: 1.0225
Outer Iteration: 80, Avg Outer Loss: 1.0219
Outer Iteration: 90, Avg Outer Loss: 1.0214
Outer Iteration: 100, Avg Outer Loss: 1.0208
Training finished.
