# Tutorial 18: Meta-Learning and Few-Shot Learning

This tutorial explores meta-learning (learning to learn) and few-shot learning techniques, including MAML, Prototypical Networks, and Matching Networks.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from typing import Tuple, List, Dict
import matplotlib.pyplot as plt
from collections import OrderedDict
import copy
import random
from sklearn.manifold import TSNE
import seaborn as sns

# Plot appearance: seaborn dark-grid theme first, then the "husl" palette
# (the palette call overrides the colour cycle set by the style sheet).
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

## 1. Understanding Few-Shot Learning

Few-shot learning aims to learn from very few examples. Let's visualize what a few-shot task looks like.

In [None]:
# Generate synthetic few-shot task
def generate_task(n_way=5, k_shot=5, q_queries=15, feature_dim=100, noise_level=0.3):
    """Sample one synthetic N-way K-shot classification episode.

    Every class is a Gaussian blob: a prototype drawn as ``2 * randn`` plus
    ``noise_level``-scaled standard-normal noise for each example.

    Args:
        n_way: Number of classes in the episode.
        k_shot: Support examples drawn per class.
        q_queries: Query examples drawn per class.
        feature_dim: Dimensionality of each example.
        noise_level: Multiplier applied to the per-example noise.

    Returns:
        Tuple ``(support_x, support_y, query_x, query_y)``. The support and
        query sets are shuffled independently; labels are integers in
        ``[0, n_way)``.
    """
    centers = torch.randn(n_way, feature_dim) * 2

    support_chunks = []
    query_chunks = []
    for center in centers:
        # Support noise is drawn before query noise for each class, so a
        # seeded run yields the same episode every time.
        support_chunks.append(center + noise_level * torch.randn(k_shot, feature_dim))
        query_chunks.append(center + noise_level * torch.randn(q_queries, feature_dim))

    support_x = torch.cat(support_chunks, dim=0)
    query_x = torch.cat(query_chunks, dim=0)

    # Labels follow the class-major layout produced by the concatenation above.
    support_y = torch.tensor([cls for cls in range(n_way) for _ in range(k_shot)])
    query_y = torch.tensor([cls for cls in range(n_way) for _ in range(q_queries)])

    # Shuffle each set so examples are not grouped by class.
    support_order = torch.randperm(len(support_y))
    query_order = torch.randperm(len(query_y))

    return (
        support_x[support_order],
        support_y[support_order],
        query_x[query_order],
        query_y[query_order],
    )

In [None]:
# Draw one 3-way 5-shot episode in a 2-D feature space
support_x, support_y, query_x, query_y = generate_task(n_way=3, k_shot=5, q_queries=10, feature_dim=2)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# One colour/marker pair per class (only the first three are used here)
colors = ['#e74c3c', '#3498db', '#2ecc71', '#f39c12', '#9b59b6']
markers = ['o', 's', '^', 'D', 'v']

# Both panels show the same scatter; only the title differs
for ax, title in zip((ax1, ax2), ('Few-Shot Task Visualization', 'Task Structure')):
    for class_idx, (color, marker) in enumerate(zip(colors[:3], markers[:3])):
        support_mask = support_y == class_idx
        query_mask = query_y == class_idx

        # Support examples: large, opaque, black outline
        ax.scatter(
            support_x[support_mask, 0], support_x[support_mask, 1],
            c=color, marker=marker, s=150,
            label=f'Class {class_idx} (support)', edgecolors='black', linewidth=2,
        )

        # Query examples: smaller, semi-transparent, grey outline
        ax.scatter(
            query_x[query_mask, 0], query_x[query_mask, 1],
            c=color, marker=marker, s=80,
            label=f'Class {class_idx} (query)', alpha=0.5, edgecolors='gray',
        )

    ax.set_xlabel('Feature 1', fontsize=12)
    ax.set_ylabel('Feature 2', fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.grid(True, alpha=0.3)

ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

# Explanatory text overlaid on the right-hand panel (axes-fraction coordinates)
ax2.text(0.5, 0.95, '3-way 5-shot Learning Task',
         transform=ax2.transAxes, ha='center', fontsize=16, fontweight='bold')
for y_pos, line in [
    (0.85, '• 3-way: 3 classes'),
    (0.75, '• 5-shot: 5 examples per class in support set'),
    (0.65, '• Goal: Classify query examples using support set'),
]:
    ax2.text(0.5, y_pos, line, transform=ax2.transAxes, ha='center', fontsize=12)

plt.tight_layout()
plt.show()

print("Few-shot Learning Task:")
print(f"Support set: {support_x.shape[0]} examples ({support_x.shape[0]//3} per class)")
print(f"Query set: {query_x.shape[0]} examples to classify")

## 2. Model-Agnostic Meta-Learning (MAML)

MAML learns an initialization that can be quickly fine-tuned to new tasks with just a few gradient steps.

In [None]:
class SimpleClassifier(nn.Module):
    """Two-hidden-layer MLP used as the base learner for few-shot classification.

    Args:
        input_size: Number of features per example after flattening.
        hidden_size: Width of both hidden layers.
        output_size: Number of output logits (classes).
    """
    def __init__(self, input_size=84*84*3, hidden_size=128, output_size=5):
        super().__init__()
        # Layer positions inside the Sequential (0 and 2 are the Linear layers)
        # determine the parameter names, e.g. ``features.0.weight``.
        backbone = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
        ]
        self.features = nn.Sequential(*backbone)
        self.classifier = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return ``(logits, hidden_features)`` for a batch ``x``.

        The input is flattened to ``(batch, -1)`` before the first layer.
        """
        flat = x.view(x.size(0), -1)
        hidden = self.features(flat)
        return self.classifier(hidden), hidden

In [None]:
class MAML:
    """Model-Agnostic Meta-Learning trainer for the tutorial's MLP classifier.

    The inner loop adapts a set of "fast weights" to one task's support set
    with plain gradient descent. The outer loop updates the shared
    initialization with Adam, using the query-set loss of the adapted weights.

    Args:
        model: Module whose named parameters are exactly the ones consumed by
            ``functional_forward`` (``features.0``, ``features.2`` and
            ``classifier`` weights and biases).
        inner_lr: Step size of the inner-loop gradient descent.
        meta_lr: Learning rate of the Adam meta-optimizer.
        inner_steps: Number of inner-loop gradient steps per task.
        first_order: When True, inner-loop gradients are treated as constants
            (first-order approximation), which avoids building the
            second-order graph. Defaults to False (full second-order MAML).
    """
    def __init__(self, model, inner_lr=0.01, meta_lr=0.001, inner_steps=5, first_order=False):
        self.model = model
        self.inner_lr = inner_lr
        self.meta_lr = meta_lr
        self.inner_steps = inner_steps
        self.first_order = first_order
        self.meta_optimizer = optim.Adam(self.model.parameters(), lr=meta_lr)

    def inner_loop(self, support_x, support_y, fast_weights=None):
        """Adapt parameters to one task's support set.

        Args:
            support_x: Support examples for the task.
            support_y: Integer class labels for ``support_x``.
            fast_weights: Optional ordered mapping of parameter name to
                tensor to start from; defaults to the model's current
                parameters.

        Returns:
            Tuple ``(fast_weights, adaptation_losses)``: the adapted
            parameters, still connected to the model's parameters for the
            meta-gradient, and the support loss measured before each step.
        """
        if fast_weights is None:
            fast_weights = OrderedDict(self.model.named_parameters())

        adaptation_losses = []

        for _ in range(self.inner_steps):
            logits = self.functional_forward(support_x, fast_weights)
            loss = F.cross_entropy(logits, support_y)
            adaptation_losses.append(loss.item())

            # create_graph=True keeps the gradient itself differentiable so
            # the meta-update can backpropagate through the adaptation step.
            # In first-order mode the gradient is a constant instead.
            grads = torch.autograd.grad(
                loss,
                tuple(fast_weights.values()),
                create_graph=not self.first_order,
            )

            # Gradient-descent step; builds new tensors, so the model's own
            # parameters are never modified in place.
            fast_weights = OrderedDict(
                (name, param - self.inner_lr * grad)
                for (name, param), grad in zip(fast_weights.items(), grads)
            )

        return fast_weights, adaptation_losses

    def functional_forward(self, x, params):
        """Run the MLP forward pass with explicitly supplied parameters.

        The layer layout is hard-coded to Linear -> ReLU -> Linear -> ReLU
        -> Linear, reading the keys ``features.0.*``, ``features.2.*`` and
        ``classifier.*`` from ``params``.

        Args:
            x: Input batch; flattened to ``(batch, -1)``.
            params: Mapping from parameter name to tensor.

        Returns:
            Class logits for each example in ``x``.
        """
        x = x.view(x.size(0), -1)

        x = F.linear(x, params['features.0.weight'], params['features.0.bias'])
        x = F.relu(x)
        x = F.linear(x, params['features.2.weight'], params['features.2.bias'])
        x = F.relu(x)
        x = F.linear(x, params['classifier.weight'], params['classifier.bias'])

        return x

    def meta_train_step(self, tasks):
        """Perform one meta-update on a batch of tasks.

        Args:
            tasks: Non-empty sequence of
                ``(support_x, support_y, query_x, query_y)`` tuples.

        Returns:
            Tuple ``(meta_loss, mean_query_accuracy)``: the mean query loss
            across tasks as a float, and the mean post-adaptation query
            accuracy.

        Raises:
            ValueError: If ``tasks`` is empty.
        """
        if len(tasks) == 0:
            # Fail early with a clear message instead of dividing by zero below.
            raise ValueError("meta_train_step requires at least one task")

        meta_loss = 0
        task_accuracies = []

        for support_x, support_y, query_x, query_y in tasks:
            # Adapt to the task, then score the adapted weights on its queries.
            fast_weights, _ = self.inner_loop(support_x, support_y)

            query_logits = self.functional_forward(query_x, fast_weights)
            query_loss = F.cross_entropy(query_logits, query_y)

            # Accuracy is for monitoring only and must not enter the graph.
            with torch.no_grad():
                accuracy = (query_logits.argmax(dim=1) == query_y).float().mean()
                task_accuracies.append(accuracy.item())

            meta_loss = meta_loss + query_loss

        # Average over tasks and update the shared initialization.
        meta_loss = meta_loss / len(tasks)
        self.meta_optimizer.zero_grad()
        meta_loss.backward()
        self.meta_optimizer.step()

        return meta_loss.item(), np.mean(task_accuracies)

In [None]:
# Meta-train MAML on freshly sampled synthetic tasks
input_size = 100
maml_model = SimpleClassifier(input_size=input_size, output_size=5).to(device)
maml = MAML(maml_model, inner_lr=0.01, meta_lr=0.001, inner_steps=5)

# Per-episode history, used later for the training curves
meta_losses = []
meta_accuracies = []

print("Training MAML...")
for episode in range(200):
    # Meta-batch of 4 tasks, each moved to the training device
    tasks = [
        [t.to(device) for t in generate_task(n_way=5, k_shot=5, q_queries=15, feature_dim=input_size)]
        for _ in range(4)
    ]

    # One outer-loop update over the whole meta-batch
    loss, accuracy = maml.meta_train_step(tasks)
    meta_losses.append(loss)
    meta_accuracies.append(accuracy)

    # Progress report every 40 episodes
    if not episode % 40:
        print(f"Episode {episode}, Meta Loss: {loss:.4f}, Meta Accuracy: {accuracy:.4f}")

In [None]:
# Visualize MAML adaptation process
# Sample an unseen task and track how a few gradient steps change performance
test_task = generate_task(n_way=5, k_shot=5, q_queries=15, feature_dim=input_size)
support_x, support_y, query_x, query_y = [t.to(device) for t in test_task]

# Track adaptation
adaptation_steps = 10
fast_weights = OrderedDict(maml.model.named_parameters())
adaptation_losses = []      # support loss measured before each update (adaptation_steps entries)
adaptation_accuracies = []  # query accuracy after 0..adaptation_steps updates (adaptation_steps + 1 entries)

# The loop runs one extra iteration so the accuracy after the LAST update is
# recorded too; otherwise the reported "final" accuracy would miss one step.
for step in range(adaptation_steps + 1):
    # Evaluate the current fast weights on the query set
    with torch.no_grad():
        query_logits = maml.functional_forward(query_x, fast_weights)
        accuracy = (query_logits.argmax(dim=1) == query_y).float().mean()
        adaptation_accuracies.append(accuracy.item())

    if step == adaptation_steps:
        break  # final evaluation only; no further update

    # Adaptation step on the support set
    support_logits = maml.functional_forward(support_x, fast_weights)
    loss = F.cross_entropy(support_logits, support_y)
    adaptation_losses.append(loss.item())

    # Plain gradient step; no second-order graph is needed at evaluation time
    grads = torch.autograd.grad(loss, fast_weights.values())
    fast_weights = OrderedDict(
        (name, param - maml.inner_lr * grad)
        for (name, param), grad in zip(fast_weights.items(), grads)
    )

# Plot adaptation
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

ax1.plot(adaptation_losses, 'o-', linewidth=2, markersize=8)
ax1.set_xlabel('Adaptation Step')
ax1.set_ylabel('Support Set Loss')
ax1.set_title('MAML Adaptation: Loss')
ax1.grid(True, alpha=0.3)

# The x position equals the number of adaptation steps already taken
ax2.plot(adaptation_accuracies, 'o-', linewidth=2, markersize=8, color='green')
ax2.set_xlabel('Adaptation Step')
ax2.set_ylabel('Query Set Accuracy')
ax2.set_title('MAML Adaptation: Accuracy')
ax2.set_ylim(0, 1)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Initial accuracy: {adaptation_accuracies[0]:.2%}")
print(f"Final accuracy: {adaptation_accuracies[-1]:.2%}")
print(f"Improvement: {adaptation_accuracies[-1] - adaptation_accuracies[0]:.2%}")

## 3. Prototypical Networks

Prototypical Networks learn an embedding space where classification is performed by finding the nearest class prototype.

In [None]:
class PrototypicalNetwork(nn.Module):
    """Prototypical Network: classify queries by distance to class prototypes.

    Args:
        input_size: Number of features per example after flattening.
        hidden_size: Width of the two hidden layers of the encoder.
        embedding_size: Dimensionality of the learned embedding space.
    """
    def __init__(self, input_size, hidden_size=128, embedding_size=64):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embedding_size)
        )

    def forward(self, x):
        """Embed a batch ``x`` (flattened to ``(batch, -1)``)."""
        x = x.view(x.size(0), -1)
        return self.encoder(x)

    def compute_prototypes(self, support_embeddings, support_labels, n_way):
        """Compute one prototype per class as the mean support embedding.

        Args:
            support_embeddings: Tensor of shape ``(n_support, embedding_dim)``.
            support_labels: Integer labels for the support embeddings;
                classes ``0 .. n_way - 1`` are looked up.
            n_way: Number of classes.

        Returns:
            Tensor of shape ``(n_way, embedding_dim)`` with the same dtype
            and device as ``support_embeddings``. A class with no support
            example yields a NaN row (mean over an empty selection).
        """
        # Allocate with the embeddings' dtype and device so that e.g. float64
        # embeddings are not silently downcast to float32.
        prototypes = torch.zeros(
            n_way,
            support_embeddings.size(1),
            dtype=support_embeddings.dtype,
            device=support_embeddings.device,
        )

        for class_idx in range(n_way):
            mask = support_labels == class_idx
            prototypes[class_idx] = support_embeddings[mask].mean(dim=0)

        return prototypes

    def prototypical_loss(self, prototypes, query_embeddings, query_labels):
        """Compute the episode loss and accuracy from query-to-prototype distances.

        Args:
            prototypes: Tensor of shape ``(n_way, embedding_dim)``.
            query_embeddings: Tensor of shape ``(n_query, embedding_dim)``.
            query_labels: Integer class label for each query.

        Returns:
            Tuple ``(loss, accuracy, distances)``. ``distances`` has shape
            ``(n_query, n_way)`` and holds Euclidean distances.
        """
        # Pairwise Euclidean distance between every query and every prototype
        distances = torch.cdist(query_embeddings, prototypes)

        # A closer prototype means a higher score: softmax over negative distances
        log_p_y = F.log_softmax(-distances, dim=1)

        loss = F.nll_loss(log_p_y, query_labels)

        # The predicted class is the nearest prototype
        predictions = (-distances).argmax(dim=1)
        accuracy = (predictions == query_labels).float().mean()

        return loss, accuracy, distances

In [None]:
# Episodic training of the Prototypical Network: one fresh task per step
proto_net = PrototypicalNetwork(input_size=100, embedding_size=64).to(device)
optimizer = optim.Adam(proto_net.parameters(), lr=0.001)

# Per-episode history, used later for the training curves
proto_losses = []
proto_accuracies = []

print("Training Prototypical Network...")
for episode in range(300):
    # Sample a 5-way 5-shot episode and move all four tensors to the device
    support_x, support_y, query_x, query_y = (
        t.to(device)
        for t in generate_task(n_way=5, k_shot=5, q_queries=15, feature_dim=100)
    )

    # Embed both sets, then build class prototypes from the support embeddings
    support_embeddings = proto_net(support_x)
    query_embeddings = proto_net(query_x)
    prototypes = proto_net.compute_prototypes(support_embeddings, support_y, n_way=5)

    # The loss is computed on the queries only; distances are not needed here
    loss, accuracy, _ = proto_net.prototypical_loss(prototypes, query_embeddings, query_y)

    # Standard optimizer step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    proto_losses.append(loss.item())
    proto_accuracies.append(accuracy.item())

    # Progress report every 60 episodes
    if not episode % 60:
        print(f"Episode {episode}, Loss: {loss.item():.4f}, Accuracy: {accuracy.item():.4f}")

In [None]:
# Project the embeddings of one unseen task to 2-D with t-SNE
test_task = generate_task(n_way=5, k_shot=5, q_queries=20, feature_dim=100)
support_x, support_y, query_x, query_y = [t.to(device) for t in test_task]

# Inference only: no gradients are needed for the visualization
with torch.no_grad():
    support_embeddings = proto_net(support_x)
    query_embeddings = proto_net(query_x)
    prototypes = proto_net.compute_prototypes(support_embeddings, support_y, n_way=5)

# Support, query and prototype points are projected together so that they
# share a single 2-D layout
all_embeddings = torch.cat([support_embeddings, query_embeddings, prototypes], dim=0)
all_embeddings_np = all_embeddings.cpu().numpy()

tsne = TSNE(n_components=2, random_state=42)
embeddings_2d = tsne.fit_transform(all_embeddings_np)

# Undo the concatenation: rows are [support | query | prototypes]
n_support = len(support_y)
n_query = len(query_y)
query_end = n_support + n_query
support_2d = embeddings_2d[:n_support]
query_2d = embeddings_2d[n_support:query_end]
prototypes_2d = embeddings_2d[query_end:]

plt.figure(figsize=(10, 8))
colors = ['#e74c3c', '#3498db', '#2ecc71', '#f39c12', '#9b59b6']

for class_idx, class_color in enumerate(colors):
    support_mask = support_y.cpu() == class_idx
    query_mask = query_y.cpu() == class_idx

    # Support points: filled circles (the only series that gets a legend entry)
    plt.scatter(
        support_2d[support_mask, 0], support_2d[support_mask, 1],
        c=class_color, marker='o', s=100, alpha=0.6,
        label=f'Class {class_idx} support',
    )

    # Query points: crosses
    plt.scatter(
        query_2d[query_mask, 0], query_2d[query_mask, 1],
        c=class_color, marker='x', s=100, alpha=0.8,
    )

    # Class prototype: large outlined star
    plt.scatter(
        prototypes_2d[class_idx, 0], prototypes_2d[class_idx, 1],
        c=class_color, marker='*', s=500, edgecolors='black', linewidth=2,
    )

plt.xlabel('t-SNE dimension 1')
plt.ylabel('t-SNE dimension 2')
plt.title('Prototypical Network Embedding Space (t-SNE visualization)')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("Legend: ○ = support examples, × = query examples, ★ = prototypes")

## 4. Matching Networks

Matching Networks use attention mechanisms to compare query examples with the support set.

In [None]:
class MatchingNetwork(nn.Module):
    """Matching Network: label queries by attention over the support set.

    Args:
        input_size: Number of features per example after flattening.
        hidden_size: Width of the encoder's hidden layer.
        embedding_size: Dimensionality of the embedding space.
    """
    def __init__(self, input_size, hidden_size=128, embedding_size=64):
        super().__init__()
        encoder_layers = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embedding_size),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # The bidirectional LSTM output has 2 * embedding_size features per
        # step; attention_fc projects it back to embedding_size.
        self.lstm = nn.LSTM(embedding_size, embedding_size, batch_first=True, bidirectional=True)
        self.attention_fc = nn.Linear(embedding_size * 2, embedding_size)

    def forward(self, x):
        """Embed a batch ``x`` (flattened to ``(batch, -1)``)."""
        flat = x.view(x.size(0), -1)
        return self.encoder(flat)

    def attention(self, query, support, support_labels):
        """Predict query classes as an attention-weighted vote of support labels.

        Args:
            query: Query embeddings, shape ``(n_query, dim)``.
            support: Support embeddings, shape ``(n_support, dim)``.
            support_labels: Integer label per support example.

        Returns:
            Tuple ``(predictions, attention_weights)``. ``attention_weights``
            is the row-wise softmax of the query/support cosine similarities.
            ``predictions`` holds the resulting class probabilities, shape
            ``(n_query, n_way)``, with ``n_way`` inferred as the largest
            support label plus one.
        """
        # Cosine similarity = dot product of L2-normalised embeddings
        unit_query = F.normalize(query, p=2, dim=1)
        unit_support = F.normalize(support, p=2, dim=1)
        cosine = torch.mm(unit_query, unit_support.t())

        # Each query's weights over the support examples sum to one
        weights = F.softmax(cosine, dim=1)

        # One-hot support labels turn the weighted sum into class probabilities
        num_classes = support_labels.max().item() + 1
        onehot = F.one_hot(support_labels, num_classes).float()

        return torch.mm(weights, onehot), weights

    def full_context_embedding(self, embeddings):
        """Re-embed a set of embeddings with a bidirectional LSTM over the set.

        A 2-D input is treated as a single sequence (a batch dimension is
        added and removed again on the way out).
        """
        batched = embeddings.unsqueeze(0) if embeddings.dim() == 2 else embeddings
        context, _ = self.lstm(batched)
        projected = self.attention_fc(context)
        return projected.squeeze(0)

In [None]:
# Train Matching Network
matching_net = MatchingNetwork(input_size=100).to(device)
matching_optimizer = optim.Adam(matching_net.parameters(), lr=0.001)

# Per-episode history, used later for the training curves
matching_losses = []
matching_accuracies = []

print("Training Matching Network...")
for episode in range(300):
    # Generate task
    support_x, support_y, query_x, query_y = generate_task(
        n_way=5, k_shot=5, q_queries=15, feature_dim=100
    )

    support_x = support_x.to(device)
    support_y = support_y.to(device)
    query_x = query_x.to(device)
    query_y = query_y.to(device)

    # Get embeddings
    support_embeddings = matching_net(support_x)
    query_embeddings = matching_net(query_x)

    # Apply full context embedding to the support set
    support_embeddings = matching_net.full_context_embedding(support_embeddings)

    # Get predictions using attention
    predictions, attention_weights = matching_net.attention(
        query_embeddings, support_embeddings, support_y
    )

    # `predictions` are already class probabilities (attention-weighted
    # one-hot labels). F.cross_entropy would apply log_softmax a second time,
    # so take the log here and use the negative log-likelihood instead.
    # The small epsilon keeps the log finite for a zero probability.
    loss = F.nll_loss(torch.log(predictions + 1e-8), query_y)
    accuracy = (predictions.argmax(dim=1) == query_y).float().mean()

    # Backward pass
    matching_optimizer.zero_grad()
    loss.backward()
    matching_optimizer.step()

    matching_losses.append(loss.item())
    matching_accuracies.append(accuracy.item())

    if episode % 60 == 0:
        print(f"Episode {episode}, Loss: {loss.item():.4f}, Accuracy: {accuracy.item():.4f}")

In [None]:
# Heatmap of the attention each query pays to each support example
# A small 3-way 3-shot task keeps the heatmap readable
vis_task = generate_task(n_way=3, k_shot=3, q_queries=5, feature_dim=100)
support_x, support_y, query_x, query_y = [t.to(device) for t in vis_task]

# Inference only: embed, add support-set context, then compute attention
with torch.no_grad():
    support_embeddings = matching_net(support_x)
    query_embeddings = matching_net(query_x)
    support_embeddings = matching_net.full_context_embedding(support_embeddings)
    predictions, attention_weights = matching_net.attention(
        query_embeddings, support_embeddings, support_y
    )

# Tick labels carry the example index and its true class
support_ticks = [f'S{i}\n(C{label})' for i, label in enumerate(support_y.tolist())]
query_ticks = [f'Q{i}\n(C{label})' for i, label in enumerate(query_y.tolist())]

plt.figure(figsize=(10, 8))
sns.heatmap(
    attention_weights.cpu().numpy(),
    cmap='YlOrRd',
    xticklabels=support_ticks,
    yticklabels=query_ticks,
    annot=True,
    fmt='.2f',
    cbar_kws={'label': 'Attention Weight'},
)

plt.xlabel('Support Examples', fontsize=12)
plt.ylabel('Query Examples', fontsize=12)
plt.title('Matching Network Attention Weights', fontsize=14)
plt.tight_layout()
plt.show()

print("S = Support, Q = Query, C = Class")
print("High attention weights indicate which support examples are most relevant for each query")

## 5. Algorithm Comparison

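Let's compare the training curves of the three algorithms implemented above, followed by a qualitative radar chart. The radar-chart scores are illustrative rules of thumb rather than measurements, and the chart also includes Reptile (a first-order alternative to MAML) for reference, even though it is not implemented in this tutorial.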

In [None]:
# Training curves: one column per algorithm, loss on top and accuracy below
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# (title prefix, losses, accuracies, loss y-label, accuracy y-label, extra loss-line style)
# MAML passes no explicit colour, so its loss line uses the default colour cycle.
curve_specs = [
    ('MAML', meta_losses, meta_accuracies, 'Meta Loss', 'Meta Accuracy', {}),
    ('Prototypical Networks', proto_losses, proto_accuracies, 'Loss', 'Accuracy', {'color': 'orange'}),
    ('Matching Networks', matching_losses, matching_accuracies, 'Loss', 'Accuracy', {'color': 'purple'}),
]

for col, (algo_name, curve_losses, curve_accs, loss_ylabel, acc_ylabel, loss_style) in enumerate(curve_specs):
    loss_ax = axes[0, col]
    acc_ax = axes[1, col]

    loss_ax.plot(curve_losses, label='Loss', alpha=0.7, **loss_style)
    loss_ax.set_title(f'{algo_name} Training', fontsize=14)
    loss_ax.set_xlabel('Episode')
    loss_ax.set_ylabel(loss_ylabel)
    loss_ax.grid(True, alpha=0.3)

    acc_ax.plot(curve_accs, label='Accuracy', alpha=0.7, color='green')
    acc_ax.set_title(f'{algo_name} Accuracy', fontsize=14)
    acc_ax.set_xlabel('Episode')
    acc_ax.set_ylabel(acc_ylabel)
    acc_ax.set_ylim(0, 1)  # accuracy is a fraction, so fix the axis to [0, 1]
    acc_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Radar chart comparing the algorithms on five qualitative axes
categories = ['Training\nSpeed', 'Inference\nSpeed', 'Memory\nEfficiency',
              'Adaptation\nFlexibility', 'Implementation\nSimplicity']

# Hand-assigned scores on a 1-5 scale, one value per category above
algorithms = {
    'MAML': [2, 3, 2, 5, 2],
    'Prototypical': [5, 5, 5, 3, 5],
    'Matching': [4, 4, 3, 4, 3],
    'Reptile': [4, 3, 4, 4, 4]
}

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='polar')

# One evenly spaced spoke per category; the first angle is repeated at the
# end so that each polygon closes on itself
num_vars = len(categories)
angles = [n / float(num_vars) * 2 * np.pi for n in range(num_vars)]
angles.append(angles[0])

colors = ['#e74c3c', '#3498db', '#2ecc71', '#f39c12']
for line_color, (algo, scores) in zip(colors, algorithms.items()):
    values = [*scores, scores[0]]  # close the polygon
    ax.plot(angles, values, 'o-', linewidth=2, label=algo, color=line_color)
    ax.fill(angles, values, alpha=0.15, color=line_color)

# Put the first category at 12 o'clock and proceed clockwise
ax.set_theta_offset(np.pi / 2)
ax.set_theta_direction(-1)

# Category labels on the spokes (skip the duplicated closing angle)
ax.set_xticks(angles[:-1])
ax.set_xticklabels(categories)

# Radial axis covers the 1-5 score range
ax.set_ylim(0, 5)
ax.set_yticks([1, 2, 3, 4, 5])
ax.set_yticklabels(['1', '2', '3', '4', '5'])

plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
plt.title('Meta-Learning Algorithm Comparison', size=16, y=1.08)

plt.tight_layout()
plt.show()

## 6. Practical Guidelines

Let's summarize when to use each algorithm and best practices.

In [None]:
# Create decision flowchart
# The "flowchart" is a monospace text tree drawn on an empty figure
fig, ax = plt.subplots(figsize=(12, 8))
ax.axis('off')  # hide ticks and frame; only the text is shown

# Title
ax.text(0.5, 0.95, 'Meta-Learning Algorithm Selection Guide', 
        ha='center', fontsize=18, fontweight='bold')

# Decision tree text
# Box-drawing characters mark the branches; the alignment relies on the
# monospace font requested in the ax.text call below
decision_tree = """
Start: Do you need fast adaptation at test time?
│
├─ Yes → Do you have computational resources for second-order gradients?
│   │
│   ├─ Yes → Use MAML
│   │        • Best for: Tasks requiring fine-tuning
│   │        • Pros: Flexible, strong performance
│   │        • Cons: Computationally expensive
│   │
│   └─ No → Use Reptile
│            • Best for: Scalable meta-learning
│            • Pros: Simple, efficient
│            • Cons: Slightly worse than MAML
│
└─ No → Is your problem metric-based?
    │
    ├─ Yes → Do you need attention mechanisms?
    │   │
    │   ├─ Yes → Use Matching Networks
    │   │        • Best for: Complex similarity measures
    │   │        • Pros: Flexible comparisons
    │   │        • Cons: More complex
    │   │
    │   └─ No → Use Prototypical Networks
    │            • Best for: Simple classification
    │            • Pros: Fast, simple, effective
    │            • Cons: Less flexible
    │
    └─ Consider other approaches
"""

# Anchor the block at its vertical centre, left-aligned at x = 0.1
ax.text(0.1, 0.5, decision_tree, fontsize=11, family='monospace', 
        verticalalignment='center')

plt.tight_layout()
plt.show()

In [None]:
# Print a categorized checklist of meta-learning best practices
rule = "=" * 60
print("Meta-Learning Best Practices")
print(rule)

# (category, tips) pairs, printed in this order
practices = [
    ("Task Construction", [
        "Ensure meta-train and meta-test tasks come from same distribution",
        "Use sufficient task diversity during training",
        "Balance classes within each task",
    ]),
    ("Model Architecture", [
        "Keep models relatively small for few-shot scenarios",
        "Use appropriate embedding dimensions",
        "Consider domain-specific architectures (CNN for images, etc.)",
    ]),
    ("Training Tips", [
        "Use episodic training matching test conditions",
        "Monitor both support and query set performance",
        "Implement proper train/val/test splits at task level",
    ]),
    ("Hyperparameters", [
        "Inner LR (MAML): typically 0.01-0.1",
        "Outer LR: typically 0.001-0.003",
        "Inner steps: 5-10 for most problems",
    ]),
    ("Evaluation", [
        "Test on completely new classes/tasks",
        "Report confidence intervals over multiple runs",
        "Consider both mean and worst-case performance",
    ]),
]

for category, tips in practices:
    # Blank line, category heading, then one bullet per tip
    print(f"\n{category}:")
    print("\n".join(f"  • {tip}" for tip in tips))

print("\n" + rule)
print("Remember: The choice of algorithm depends heavily on your specific")
print("problem constraints and requirements!")

## Summary

In this tutorial, we explored meta-learning and few-shot learning:

1. **MAML**: Learns an initialization for fast gradient-based adaptation
2. **Prototypical Networks**: Uses class prototypes in embedding space
3. **Matching Networks**: Employs attention mechanisms for classification
4. **Key Concepts**: Task distributions, episodic training, fast adaptation

### Key Takeaways:
- Meta-learning enables learning from very few examples
- Different algorithms suit different scenarios
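- The two main families are optimization-based methods (MAML), which adapt weights with gradient steps, and metric-based methods (Prototypical and Matching Networks), which classify by similarity in a learned embedding space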
- Proper task construction is crucial
- These techniques are vital for data-scarce domains

Meta-learning is an active research area with applications in robotics, drug discovery, personalized systems, and more!