# TP4 Nano — MAML

**Paper:** [Feature Learning in Infinite-Width Neural Networks](https://arxiv.org/abs/2011.14522)  
**Code reference:** [edwardjhu/TP4](https://github.com/edwardjhu/TP4)

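Nano MAML: 5-way 1-shot Omniglot classification with a linear one-hidden-layer perceptron (1LP), trained with first-order MAML. We compare finite widths (2, 8, 32) against the infinite-width limit under the maximal update parametrization (MUP, written μP in the paper). The whole notebook runs in a few minutes.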

In [None]:
import sys
from pathlib import Path

# Locate the project root (the nanomup folder that holds tp/ and notebooks/)
# by probing for tp/tp4_maml.py in the current directory first and in its
# parent second. If neither contains the file, fall back to the resolved "..".
# The chosen directory goes to the front of sys.path so that `tp` is importable.
_cwd = Path.cwd().resolve()
for _candidate in (_cwd, _cwd.parent):
    if (_candidate / "tp" / "tp4_maml.py").exists():
        _root = _candidate
        break
else:
    # No candidate matched: keep the original best-effort fallback.
    _root = Path("..").resolve()
sys.path.insert(0, str(_root))

import torch
from tp.tp4_maml import train_maml_nano, evaluate_maml_nano

# Use the GPU when PyTorch reports CUDA as available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Train first-order MAML on 5-way 1-shot Omniglot for widths 2, 8, and 32, and for the infinite-width MUP limit. Evaluate each model on the test split and show query accuracy versus width.

In [None]:
# Dataset directory: reuse ../data when it already exists, otherwise use ./data.
# The directory is created if missing and then handed on as a plain string.
_parent_data = Path("..") / "data"
data_root = _parent_data if _parent_data.exists() else Path("./data")
data_root.mkdir(parents=True, exist_ok=True)
data_root = str(data_root)

# Hyperparameters shared by every training run below (finite and infinite width).
_train_kwargs = dict(
    num_epochs=12,
    meta_lr=0.1,
    inner_lr=0.4,
    batch_size=24,
    data_root=data_root,
    device=device,
)
# Arguments shared by every evaluation call.
_eval_kwargs = dict(num_tasks=100, data_root=data_root, device=device)

# Results keyed by a human-readable label: trained models, and accuracies
# (last entry of the training history plus the "(eval)" accuracy).
maml_models = {}
maml_accs = {}

# Finite-width runs.
for width in (2, 8, 32):
    model, acc_hist = train_maml_nano(width=width, inf_width=False, **_train_kwargs)
    maml_models[f"width={width}"] = model
    maml_accs[f"width={width}"] = acc_hist[-1]
    acc = evaluate_maml_nano(model, **_eval_kwargs)
    maml_accs[f"width={width} (eval)"] = acc
    print(f"Width {width}: train acc (last) = {acc_hist[-1]:.4f}, test acc = {acc:.4f}")

# Infinite-width run: no finite width is passed, inf_width=True selects it.
model_inf, acc_hist_inf = train_maml_nano(width=None, inf_width=True, **_train_kwargs)
maml_models["inf (MUP)"] = model_inf
maml_accs["inf (MUP)"] = acc_hist_inf[-1]
acc_inf = evaluate_maml_nano(model_inf, **_eval_kwargs)
maml_accs["inf (MUP) (eval)"] = acc_inf
print(f"Inf width (MUP): train acc (last) = {acc_hist_inf[-1]:.4f}, test acc = {acc_inf:.4f}")

In [None]:
import matplotlib.pyplot as plt

# Gather the accuracies stored under the "... (eval)" keys of maml_accs,
# finite widths first and the infinite-width entry last.
finite_widths = [2, 8, 32]
widths = finite_widths + ["inf (MUP)"]
eval_accs = [maml_accs[f"width={w} (eval)"] for w in finite_widths]
eval_accs.append(maml_accs["inf (MUP) (eval)"])

# Text summary, one line per configuration.
print("MAML 5-way 1-shot Omniglot — Test query accuracy:")
for label, value in zip(widths, eval_accs):
    print(f"  width {label}: {value:.4f}")

# Bar chart of the same numbers; tick labels are the stringified widths.
fig, ax = plt.subplots(figsize=(6, 4))
ax.bar(list(map(str, widths)), eval_accs, color="steelblue", edgecolor="navy", alpha=0.8)
ax.set_xlabel("Hidden width")
ax.set_ylabel("Query accuracy")
ax.set_title("MAML nano: linear 1LP, 5-way 1-shot (finite vs infinite-width MUP)")
fig.tight_layout()
plt.show()