# Few-Shot Learning with Prototypical Networks


# Import necessary libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.datasets import Omniglot
import numpy as np

# Check if GPU is available and set device



In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Define the Prototypical Network model

In [None]:
class ProtoNet(nn.Module):
    def __init__(self, input_channels=1, hidden_channels=64):
        super(ProtoNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(input_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )

    def forward(self, x):
        return self.encoder(x).view(x.size(0), -1)

# Function to compute prototypes and classify, returning distances instead of indices

In [None]:
def compute_prototypes(support, support_labels, query):
    """Return the Euclidean distance from every query embedding to every class prototype.

    A prototype is the mean of all support embeddings that carry the same label.

    The labels are flattened and pooled: every support row with the same label
    value contributes to the same prototype, whatever the shape of
    ``support_labels``. Pass the support/query set of ONE episode per call;
    mixing several episodes that reuse the same label values would average
    unrelated classes together.

    Args:
        support: Support embeddings, one row per sample.
        support_labels: Integer label per support row; any shape whose number
            of elements equals the number of support rows (e.g. ``[1, 25]``
            or ``[25]``).
        query: Query embeddings, one row per sample, same feature size as
            ``support``.

    Returns:
        Tensor of shape ``[num_query, num_classes]``. Column ``j`` corresponds
        to the ``j``-th smallest label value, because ``torch.unique`` returns
        the labels sorted.
    """
    support_labels = support_labels.flatten()  # e.g. [1, 25] -> [25]
    unique_labels = torch.unique(support_labels)

    # One mean embedding per class, in sorted-label order.
    prototypes = torch.stack(
        [support[support_labels == label].mean(0) for label in unique_labels]
    )

    # Keep the prototypes on the same device as the queries they are compared
    # with, instead of relying on a module-level `device` global.
    prototypes = prototypes.to(query.device)

    # Euclidean distance between query samples and prototypes.
    return torch.cdist(query, prototypes)

# Example episodic dataset for Few-Shot Learning (each item is one N-way K-shot episode)

In [None]:
class FewShotDataset(Dataset):
    """Episodic wrapper that turns a labelled dataset into N-way K-shot episodes.

    Every item is one randomly sampled episode: ``n_way`` distinct classes,
    and for each class ``k_shot`` support samples plus ``k_query`` query
    samples. Support and query samples of a class are drawn without
    replacement, so they never overlap. The index passed to ``__getitem__``
    is ignored; sampling uses NumPy's global random state.

    Args:
        dataset: Indexable dataset whose items are ``(image, label)`` pairs.
            If it exposes ``_flat_character_images`` (a sequence of
            ``(name, label)`` pairs, as torchvision's Omniglot does) the
            labels are read from there without loading any image; otherwise
            every item is read once to collect its label.
        n_way: Number of classes per episode.
        k_shot: Support samples per class.
        k_query: Query samples per class.
        transform: Optional callable applied to every image returned by
            ``dataset``.
        n_episodes: Value reported by ``len()``; episodes are sampled on the
            fly, so this only controls how many are drawn per epoch.

    Raises:
        ValueError: If fewer than ``n_way`` classes have at least
            ``k_shot + k_query`` samples.
    """

    def __init__(self, dataset, n_way, k_shot, k_query, transform=None, n_episodes=250):
        self.dataset = dataset
        self.n_way = n_way
        self.k_shot = k_shot
        self.k_query = k_query
        self.transform = transform
        self.n_episodes = n_episodes

        # Map each class label to the dataset indices of its samples.
        self.class_to_indices = self._group_indices_by_class(dataset)

        # Only classes with enough samples for a full support + query split
        # can be used in an episode.
        samples_needed = k_shot + k_query
        self.classes = [
            label
            for label, indices in self.class_to_indices.items()
            if len(indices) >= samples_needed
        ]
        if len(self.classes) < n_way:
            raise ValueError(
                f"Need at least {n_way} classes with {samples_needed} samples each, "
                f"but only {len(self.classes)} qualify."
            )

    @staticmethod
    def _group_indices_by_class(dataset):
        """Return ``{label: [dataset indices]}`` for every sample in ``dataset``."""
        flat = getattr(dataset, '_flat_character_images', None)
        if flat is not None:
            # Fast path: labels are available without decoding images.
            labels = (label for _, label in flat)
        else:
            labels = (dataset[index][1] for index in range(len(dataset)))

        groups = {}
        for index, label in enumerate(labels):
            if hasattr(label, 'item'):
                # Tensor / NumPy scalar labels are not reliable dict keys.
                label = label.item()
            groups.setdefault(label, []).append(index)
        return groups

    def __len__(self):
        return self.n_episodes

    def __getitem__(self, idx):
        """Sample one episode.

        Returns:
            ``(support, support_labels, query, query_labels)`` where the image
            tensors are stacked along a new first axis, grouped by class, and
            the labels are episode-local integers ``0 .. n_way - 1``.
        """
        samples_needed = self.k_shot + self.k_query
        # Sample positions in self.classes rather than the labels themselves,
        # so labels of any hashable type work.
        chosen = np.random.choice(len(self.classes), self.n_way, replace=False)

        support = []
        query = []
        support_labels = []
        query_labels = []

        for episode_label, class_position in enumerate(chosen):
            class_indices = self.class_to_indices[self.classes[class_position]]
            # Distinct samples of this class: the first k_shot become support,
            # the remaining k_query become query.
            picked = np.random.choice(class_indices, samples_needed, replace=False)
            for position, sample_index in enumerate(picked):
                img, _ = self.dataset[int(sample_index)]
                if self.transform is not None:
                    img = self.transform(img)
                if position < self.k_shot:
                    support.append(img)
                    support_labels.append(episode_label)
                else:
                    query.append(img)
                    query_labels.append(episode_label)

        support = torch.stack(support)
        query = torch.stack(query)
        return support, torch.tensor(support_labels), query, torch.tensor(query_labels)


# Load the Omniglot dataset


In [None]:
# Preprocessing applied to every image: grayscale conversion, then conversion to a tensor.
transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])

In [None]:
# Omniglot "background" split, downloaded to ./data on first use; `transform` is applied on access.
train_dataset = Omniglot(root='./data', background=True, transform=transform, download=True)

# 5-way episodes with 5 support and 15 query images per class, served 4 episodes per batch.
episode_dataset = FewShotDataset(train_dataset, n_way=5, k_shot=5, k_query=15)
train_loader = DataLoader(episode_dataset, batch_size=4, shuffle=True)

Downloading https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_background.zip to ./data/omniglot-py/images_background.zip


100%|██████████| 9464212/9464212 [00:00<00:00, 131235704.69it/s]

Extracting ./data/omniglot-py/images_background.zip to ./data/omniglot-py





# Initialize the model, loss function, and optimizer


In [None]:
# Encoder network, placed on the selected device.
model = ProtoNet()
model.to(device)

# Cross-entropy over class logits.
criterion = nn.CrossEntropyLoss()

# Adam over all encoder parameters.
optimizer = optim.Adam(params=model.parameters(), lr=0.001)


# Training loop

In [None]:
# Training loop
for epoch in range(20):
    model.train()  # Set the model to training mode
    running_loss = 0.0
    num_batches = 0

    for support, support_labels, query, query_labels in train_loader:
        # Move data to the training device.
        support = support.to(device)  # Shape: [batch_size, n_way * k_shot, C, H, W]
        query = query.to(device)  # Shape: [batch_size, n_way * k_query, C, H, W]
        support_labels = support_labels.to(device)  # Shape: [batch_size, n_way * k_shot]
        query_labels = query_labels.to(device)  # Shape: [batch_size, n_way * k_query]

        num_episodes = support.size(0)

        # Encode all images of the batch in one forward pass per set, then
        # restore the episode axis so each episode can be handled on its own.
        support_encoded = model(support.flatten(0, 1)).view(num_episodes, support.size(1), -1)
        query_encoded = model(query.flatten(0, 1)).view(num_episodes, query.size(1), -1)

        # Prototypes are built per episode: every episode reuses the labels
        # 0 .. n_way - 1 for its own classes, so pooling the whole batch would
        # average embeddings of unrelated classes into one prototype.
        # Negated distances serve as logits (closer prototype -> larger logit).
        predictions = torch.cat([
            -compute_prototypes(episode_support, episode_labels, episode_query)
            for episode_support, episode_labels, episode_query
            in zip(support_encoded, support_labels, query_encoded)
        ])  # Shape: [batch_size * (n_way * k_query), n_way]

        # Flatten query_labels in the same episode-major order as predictions.
        query_labels = query_labels.reshape(-1)  # Shape: [batch_size * (n_way * k_query)]

        # Compute the loss
        loss = criterion(predictions, query_labels)

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Report the mean loss over the epoch; a single batch is noisy, and the
    # last-batch variable would be undefined if the loader yielded nothing.
    epoch_loss = running_loss / num_batches if num_batches else float('nan')
    print(f'Epoch {epoch + 1}, Loss: {epoch_loss}')

Epoch 1, Loss: 1.3597129583358765
Epoch 2, Loss: 1.2048062086105347
Epoch 3, Loss: 0.9956464767456055
Epoch 4, Loss: 0.7996284365653992
Epoch 5, Loss: 0.9966538548469543
Epoch 6, Loss: 0.9167237281799316
Epoch 7, Loss: 1.1162410974502563
Epoch 8, Loss: 0.5236536264419556
Epoch 9, Loss: 0.3805531859397888
Epoch 10, Loss: 0.6793351173400879
Epoch 11, Loss: 0.438449501991272
Epoch 12, Loss: 0.44927963614463806
Epoch 13, Loss: 0.9363599419593811
Epoch 14, Loss: 0.3109646737575531
Epoch 15, Loss: 0.40421977639198303
Epoch 16, Loss: 0.14029915630817413
Epoch 17, Loss: 0.11672678589820862
Epoch 18, Loss: 0.022993579506874084
Epoch 19, Loss: 0.09544401615858078
Epoch 20, Loss: 0.011652615070343
