In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from src.few_shot_learning import load_cinic10, calculate_accuracy, plot_confusion_matrix
from torch.optim.lr_scheduler import StepLR

In [10]:
class ProtoCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(ProtoCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.fc = nn.Linear(128 * 32 * 32, 128)  # Assuming 32x32 image size
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # Flatten the tensor
        embeddings = self.fc(x)  # Output embeddings
        return embeddings


In [3]:
def train_prototypical(model, dataloader, epochs=10, lr=0.0001, num_classes=10, N_shot=5, N_query=5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        total_loss, total_acc = 0, 0

        for images, labels in dataloader:
            print("Unique labels:", torch.unique(labels))
            print("Min label:", labels.min().item(), "Max label:", labels.max().item())
            print("Expected range: 0 to", num_classes - 1)
            images, labels = images.to(device), labels.to(device)

            # Select support & query sets
            indices = torch.randperm(images.size(0))
            support_idx, query_idx = indices[:N_shot*num_classes], indices[N_shot*num_classes:]

            support_images, support_labels = images[support_idx], labels[support_idx]
            query_images, query_labels = images[query_idx], labels[query_idx]

            # Extract embeddings
            support_embeddings = model(support_images)
            query_embeddings = model(query_images)

            # Compute prototypes
            prototypes = []
            for c in range(num_classes):
                class_mask = support_labels == c
                if class_mask.sum() > 0:
                    class_prototype = support_embeddings[class_mask].mean(0)
                    prototypes.append(class_prototype)

            prototypes = torch.stack(prototypes)  # [num_classes, embedding_dim]

            # Compute distances
            distances = torch.cdist(query_embeddings, prototypes)  # [num_query, num_classes]
            pred_labels = torch.argmin(distances, dim=1)

            # Compute loss (negative log-probability)
            loss = F.cross_entropy(-distances, query_labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_acc += (pred_labels == query_labels).float().mean().item()

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss:.4f}, Acc: {total_acc / len(dataloader):.4f}")

In [4]:
def test_prototypical(model, dataloader, num_classes=10, N_shot=5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in dataloader:

            images, labels = images.to(device), labels.to(device)

            indices = torch.randperm(images.size(0))
            support_idx, query_idx = indices[:N_shot*num_classes], indices[N_shot*num_classes:]

            support_images, support_labels = images[support_idx], labels[support_idx]
            query_images, query_labels = images[query_idx], labels[query_idx]

            support_embeddings = model(support_images)
            query_embeddings = model(query_images)

            prototypes = []
            for c in range(num_classes):
                class_mask = support_labels == c
                if class_mask.sum() > 0:
                    class_prototype = support_embeddings[class_mask].mean(0)
                    prototypes.append(class_prototype)

            prototypes = torch.stack(prototypes)
            distances = torch.cdist(query_embeddings, prototypes)
            pred_labels = torch.argmin(distances, dim=1)

            correct += (pred_labels == query_labels).sum().item()
            total += query_labels.size(0)

    acc = (correct / total) * 100 if total > 0 else 0
    print(f"Test Accuracy: {acc:.2f}%")
    return acc

In [5]:
data_dir = "../../data"
train_loader = load_cinic10(data_dir, split="train", few_shot_per_class=100, batch_size=128)
test_loader = load_cinic10(data_dir, split="test", few_shot_per_class=100, batch_size=128)
model = ProtoCNN()

In [11]:
train_prototypical(model, train_loader, epochs=10, lr=0.0001, num_classes=10, N_shot=5)

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [7]:
import torch.nn.functional as F

def prototypical_loss(support_embeddings, support_labels, query_embeddings, query_labels, num_classes):

    support_embeddings = F.normalize(support_embeddings, p=2, dim=1)  # L2 Normalisation
    query_embeddings = F.normalize(query_embeddings, p=2, dim=1)

    prototypes = []
    for cls in range(num_classes):
        cls_mask = (support_labels == cls)
        
        if cls_mask.sum() == 0:  # Avoid empty classes
            prototypes.append(torch.zeros((1, support_embeddings.shape[1]), device=support_embeddings.device))
        else:
            cls_proto = support_embeddings[cls_mask].mean(dim=0, keepdim=True)  # Keep [1, feature_dim]
            prototypes.append(cls_proto)

    prototypes = torch.cat(prototypes, dim=0)  # Stack into (num_classes, feature_dim)

    # Compute squared Euclidean distance
    dists = torch.cdist(query_embeddings, prototypes, p=2)

    # Convert distances to probabilities
    log_p_y = F.log_softmax(-dists, dim=1)  # Negative distance as similarity
    loss = F.nll_loss(log_p_y, query_labels)

    preds = log_p_y.argmax(dim=1)
    acc = (preds == query_labels).float().mean().item()

    return loss, acc


In [13]:
def train_prototypical(model, dataloader, epochs=10, lr=0.0001, num_classes=10, N_shot=5, N_query=5):
    device = torch.device("cpu")
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = StepLR(optimizer, step_size=5, gamma=0.5)

    for epoch in range(epochs):
        model.train()
        total_loss, total_acc = 0, 0

        for images, labels in dataloader:  # Assume balanced batch per class
            images, labels = images.to(device), labels.to(device)

            # Randomly sample support and query sets
            indices = torch.randperm(images.size(0))
            support_idx, query_idx = indices[:N_shot*num_classes], indices[N_shot*num_classes:]
            
            support_images, support_labels = images[support_idx], labels[support_idx]
            query_images, query_labels = images[query_idx], labels[query_idx]

            # Forward pass
            support_embeddings = model(support_images)
            query_embeddings = model(query_images)

            # Compute prototypical loss
            loss, acc = prototypical_loss(support_embeddings, support_labels, query_embeddings, query_labels, num_classes)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_acc += acc

        scheduler.step()

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss:.4f}, Acc: {total_acc/len(dataloader):.4f}")


In [14]:
data_dir = "../../data"
dataloader = load_cinic10(data_dir, few_shot_per_class=100, batch_size=128)
model = ProtoCNN()
train_prototypical(model, dataloader, epochs=30, lr=0.0001, num_classes=10, N_shot=10, N_query=5)

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
