In [1]:
import sys
sys.path.append("../")

import torch
from torch import Tensor, nn, no_grad, zeros_like
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import SGD, Optimizer

from src.models.maml import Model, Reptile

def main():
    # Set device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Create a synthetic dataset.
    # We'll generate a binary classification problem in 2D.
    N = 1000
    X = torch.randn(N, 2)
    y = (X.sum(dim=1) > 0).long()  # Label is 1 if sum > 0, else 0.
    
    dataset = TensorDataset(X, y)
    # Use a small batch size so that each batch is treated as a "task".
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

    # Define a simple model: 2 input features, one hidden layer with 10 units, 2 outputs.
    model = Model(nodes_per_layer=[2, 10, 2],
                  activations_per_layer=[nn.ReLU(), nn.Identity()]).to(device)

    # Define the loss function.
    loss_fn = nn.CrossEntropyLoss()

    # Outer (meta) optimizer: using SGD.
    meta_lr = 1.0  # Outer learning rate.
    meta_optimizer = SGD(model.parameters(), lr=meta_lr)

    # Inner-loop learning rate (α).
    inner_lr = 0.01

    # Number of inner-loop gradient steps per task.
    n_gradient_steps = 5

    # Instantiate the Reptile meta-learner.
    reptile = Reptile(model=model, n_gradient_steps=n_gradient_steps, device=device,
                      loss_function=loss_fn, meta_optimizer=meta_optimizer, inner_lr=inner_lr)

    # Train using Reptile for a few epochs.
    n_epochs = 5
    n_parallel_tasks = 4  # Number of tasks (batches) per meta-update.
    reptile.fit(dataloader, n_epochs=n_epochs, n_parallel_tasks=n_parallel_tasks)

    # (Optional) Evaluate the trained model on the training data.
        # Create a synthetic dataset.
    # We'll generate a binary classification problem in 2D.
    N = 1000
    X = torch.randn(N, 2)
    y = (X.sum(dim=1) > 0).long()  # Label is 1 if sum > 0, else 0.
    
    dataset = TensorDataset(X, y)
    # Use a small batch size so that each batch is treated as a "task".
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for X_batch, y_batch in dataloader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            outputs = model(X_batch)
            predictions = outputs.argmax(dim=1)
            correct += (predictions == y_batch).sum().item()
            total += y_batch.size(0)
    print(f"Training accuracy: {correct/total*100:.2f}%")

In [2]:
main()

[32m2025-02-04 15:26:15.887[0m | [1mINFO    [0m | [36msrc.models.maml[0m:[36mfit[0m:[36m128[0m - [1mMeta-update iteration 1/80 complete: processed 4 parallel tasks.[0m
[32m2025-02-04 15:26:15.946[0m | [1mINFO    [0m | [36msrc.models.maml[0m:[36mfit[0m:[36m128[0m - [1mMeta-update iteration 2/80 complete: processed 4 parallel tasks.[0m
[32m2025-02-04 15:26:15.995[0m | [1mINFO    [0m | [36msrc.models.maml[0m:[36mfit[0m:[36m128[0m - [1mMeta-update iteration 3/80 complete: processed 4 parallel tasks.[0m
[32m2025-02-04 15:26:16.066[0m | [1mINFO    [0m | [36msrc.models.maml[0m:[36mfit[0m:[36m128[0m - [1mMeta-update iteration 4/80 complete: processed 4 parallel tasks.[0m
[32m2025-02-04 15:26:16.132[0m | [1mINFO    [0m | [36msrc.models.maml[0m:[36mfit[0m:[36m128[0m - [1mMeta-update iteration 5/80 complete: processed 4 parallel tasks.[0m
[32m2025-02-04 15:26:16.292[0m | [1mINFO    [0m | [36msrc.models.maml[0m:[36mfit[0m:[36m128

Training accuracy: 97.30%


: 