In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchopt

# Define the dimensions of the vectors.
# These module-level sizes are read by F_Net / G_Net when their layers are built
# and by the dummy-data generation further down.
key_dim = 10  # width of each key vector (input to G_Net, and part of F_Net's input)
val_dim = 5  # width of each value vector (F_Net's output and the regression target)
hidden_dim = 20  # width of G_Net's output, which is concatenated with the key for F_Net

# Define the two neural networks, f and g
class F_Net(nn.Module):
    """Two-layer network f that predicts a value from a key and g's output.

    The first layer takes the key concatenated with g's output
    (``key_dim + hidden_dim`` features, read from module-level constants) and
    the second layer projects the 8 hidden units to ``val_dim`` outputs.
    """

    def __init__(self):
        super().__init__()
        # The input is [key, g_out] joined on the feature axis, hence the sum.
        joint_width = key_dim + hidden_dim
        self.fc1 = nn.Linear(joint_width, 8)
        self.fc2 = nn.Linear(8, val_dim)

    def forward(self, key, g_out):
        """Return fc2(tanh(fc1([key, g_out]))).

        Args:
            key: tensor whose last dimension is concatenated with ``g_out``.
            g_out: tensor joined to ``key`` along the last dimension.

        Returns:
            The output of the second linear layer.
        """
        joined = torch.cat((key, g_out), dim=-1)
        hidden = torch.tanh(self.fc1(joined))
        return self.fc2(hidden)

class G_Net(nn.Module):
    """Two-layer network g that maps a key to a conditioning vector.

    Layer sizes come from the module-level constants: ``key_dim`` inputs,
    8 hidden units, ``hidden_dim`` outputs.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(key_dim, 8)
        self.fc2 = nn.Linear(8, hidden_dim)

    def forward(self, key):
        """Return fc2(sigmoid(fc1(key))).

        Args:
            key: input tensor fed to the first linear layer.

        Returns:
            The output of the second linear layer.
        """
        hidden = torch.sigmoid(self.fc1(key))
        return self.fc2(hidden)

# Instantiate the networks
f_net = F_Net()
g_net = G_Net()

# Define the meta-optimizer for the inner loop (f_net)
# We use MetaAdam, which allows us to differentiate through the optimization process.
inner_optimizer = torchopt.MetaAdam(f_net, lr=1e-3)

# Define the optimizer for the outer loop (g_net)
outer_optimizer = torch.optim.Adam(g_net.parameters(), lr=1e-4)

# Define the loss functions
inner_loss_fn = nn.MSELoss()


def outer_loss_fn(pred, target):
    """Return the cosine distance between predictions and targets.

    Minimizing this value is the same as maximizing cosine similarity, since
    the result is one minus the mean of the per-row similarities.

    Args:
        pred: predicted vectors.
        target: target vectors, compared with ``pred`` by
            ``F.cosine_similarity`` using its default dimension.

    Returns:
        A scalar tensor: ``1 - mean(cosine_similarity(pred, target))``.
    """
    return 1 - F.cosine_similarity(pred, target).mean()


# --- Training Loop ---

# Generate a sequence of dummy data
num_steps = 5
batch_size = 4
sequence_of_keys = [torch.randn(batch_size, key_dim) for _ in range(num_steps)]
sequence_of_vals = [torch.randn(batch_size, val_dim) for _ in range(num_steps)]

# We need to save the initial state of f_net (and of its optimizer) so both can
# be reset at the start of each outer loop iteration
f_net_initial_state = torchopt.extract_state_dict(f_net)
inner_optim_initial_state = torchopt.extract_state_dict(inner_optimizer)

# Number of outer loop iterations, and how often to report progress
num_outer_iters = 100
log_every = 10

for outer_iter in range(num_outer_iters):
    # Reset f_net and its optimizer to their initial state so every outer
    # iteration unrolls the inner loop from the same starting point
    torchopt.recover_state_dict(f_net, f_net_initial_state)
    torchopt.recover_state_dict(inner_optimizer, inner_optim_initial_state)

    outer_optimizer.zero_grad()

    # --- Inner Loop ---
    # Traverse the sequence of key-value pairs
    for key, val in zip(sequence_of_keys, sequence_of_vals):
        # f sees the key together with g(key), so the inner loss depends on
        # g_net's parameters as well as f_net's
        g_output = g_net(key)
        f_output = f_net(key, g_output)

        inner_loss = inner_loss_fn(f_output, val)

        # Update f_net using the meta-optimizer
        # This step is differentiable
        inner_optimizer.step(inner_loss)

    # --- Outer Loop Loss Calculation ---
    # After the inner loop, calculate the outer loss on the last key-value pair
    # using the updated f_net
    final_key = sequence_of_keys[-1]
    final_val = sequence_of_vals[-1]
    final_f_output = f_net(final_key, g_net(final_key))

    outer_loss = outer_loss_fn(final_f_output, final_val)

    # Backpropagate the outer loss through the unrolled inner loop
    # This will compute gradients for g_net's parameters
    # NOTE(review): this presumably also accumulates .grad on f_net's initial
    # parameters, which no optimizer here consumes or zeroes — confirm whether
    # f_net's initialization is meant to be meta-learned too.
    outer_loss.backward()

    # Update g_net's parameters
    outer_optimizer.step()

    if (outer_iter + 1) % log_every == 0:
        print(f"Outer Iteration: {outer_iter + 1}, Outer Loss: {outer_loss.item():.4f}")

print("Training finished.")

Outer Iteration: 10, Outer Loss: 1.0416
Outer Iteration: 20, Outer Loss: 1.0340
Outer Iteration: 30, Outer Loss: 1.0262
Outer Iteration: 40, Outer Loss: 1.0184
Outer Iteration: 50, Outer Loss: 1.0107
Outer Iteration: 60, Outer Loss: 1.0031
Outer Iteration: 70, Outer Loss: 0.9955
Outer Iteration: 80, Outer Loss: 0.9878
Outer Iteration: 90, Outer Loss: 0.9799
Outer Iteration: 100, Outer Loss: 0.9723
Training finished.
