In [1]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from src.few_shot_learning import load_cinic10, calculate_accuracy, plot_confusion_matrix
from torch.optim.lr_scheduler import StepLR

In [2]:
class ProtoConvNeXt(nn.Module):
    def __init__(self, feature_dim=128):
        super(ProtoConvNeXt, self).__init__()
        self.backbone = timm.create_model('convnext_tiny', pretrained=True)  # Load ConvNeXt
        self.backbone.head.fc = nn.Identity()  # Remove classification layer
        self.projection = nn.Linear(768, feature_dim)  # Map to embedding space

    def forward(self, x):
        features = self.backbone(x)  # Extract features
        return self.projection(features)  # Reduce to feature_dim

In [23]:
import torch.nn.functional as F

def prototypical_loss(support_embeddings, support_labels, query_embeddings, query_labels, num_classes):

    support_embeddings = F.normalize(support_embeddings, p=2, dim=1)  # L2 Normalisation
    query_embeddings = F.normalize(query_embeddings, p=2, dim=1)

    prototypes = []
    for cls in range(num_classes):
        cls_mask = (support_labels == cls)
        
        if cls_mask.sum() == 0:  # Avoid empty classes
            prototypes.append(torch.zeros((1, support_embeddings.shape[1]), device=support_embeddings.device))
        else:
            cls_proto = support_embeddings[cls_mask].mean(dim=0, keepdim=True)  # Keep [1, feature_dim]
            prototypes.append(cls_proto)

    prototypes = torch.cat(prototypes, dim=0)  # Stack into (num_classes, feature_dim)

    # Compute squared Euclidean distance
    dists = torch.cdist(query_embeddings, prototypes, p=2)

    # Convert distances to probabilities
    log_p_y = F.log_softmax(-dists, dim=1)  # Negative distance as similarity
    loss = F.nll_loss(log_p_y, query_labels)

    preds = log_p_y.argmax(dim=1)
    acc = (preds == query_labels).float().mean().item()

    return loss, acc


In [24]:
def train_prototypical(model, dataloader, epochs=10, lr=0.0001, num_classes=10, N_shot=5, N_query=5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = StepLR(optimizer, step_size=5, gamma=0.5)

    for epoch in range(epochs):
        model.train()
        total_loss, total_acc = 0, 0

        for images, labels in dataloader:  # Assume balanced batch per class
            images, labels = images.to(device), labels.to(device)

            # Randomly sample support and query sets
            indices = torch.randperm(images.size(0))
            support_idx, query_idx = indices[:N_shot*num_classes], indices[N_shot*num_classes:]
            
            support_images, support_labels = images[support_idx], labels[support_idx]
            query_images, query_labels = images[query_idx], labels[query_idx]

            # Forward pass
            support_embeddings = model(support_images)
            query_embeddings = model(query_images)

            # Compute prototypical loss
            loss, acc = prototypical_loss(support_embeddings, support_labels, query_embeddings, query_labels, num_classes)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_acc += acc

        scheduler.step()

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss:.4f}, Acc: {total_acc/len(dataloader):.4f}")


In [25]:
data_dir = "../../data"
dataloader = load_cinic10(data_dir, few_shot_per_class=100, batch_size=128)
model = ProtoConvNeXt()
train_prototypical(model, dataloader, epochs=30, lr=0.0001, num_classes=10, N_shot=10, N_query=5)

Epoch [1/30], Loss: 16.2234, Acc: 0.3750
Epoch [2/30], Loss: 14.6931, Acc: 0.3393
Epoch [3/30], Loss: 14.8907, Acc: 0.4375
Epoch [4/30], Loss: 13.9202, Acc: 0.4330
Epoch [5/30], Loss: 13.5166, Acc: 0.5268
Epoch [6/30], Loss: 12.9577, Acc: 0.5804
Epoch [7/30], Loss: 12.6123, Acc: 0.5625
Epoch [8/30], Loss: 12.2977, Acc: 0.6563
Epoch [9/30], Loss: 12.3473, Acc: 0.6250
Epoch [10/30], Loss: 12.0317, Acc: 0.6786
Epoch [11/30], Loss: 11.9193, Acc: 0.7277
Epoch [12/30], Loss: 11.7431, Acc: 0.6741
Epoch [13/30], Loss: 11.6248, Acc: 0.7009
Epoch [14/30], Loss: 11.2430, Acc: 0.7500
Epoch [15/30], Loss: 11.4542, Acc: 0.7902
Epoch [16/30], Loss: 11.3084, Acc: 0.7768
Epoch [17/30], Loss: 11.2786, Acc: 0.8170
Epoch [18/30], Loss: 10.8633, Acc: 0.8661
Epoch [19/30], Loss: 10.8716, Acc: 0.8571
Epoch [20/30], Loss: 10.9152, Acc: 0.8304
Epoch [21/30], Loss: 10.8023, Acc: 0.8482
Epoch [22/30], Loss: 10.7338, Acc: 0.8750
Epoch [23/30], Loss: 10.8194, Acc: 0.8616
Epoch [24/30], Loss: 10.5974, Acc: 0.9330
E

In [13]:
def calculate_prototypical_accuracy(model, data_root, split='test', batch_size=32, num_classes=10, N_shot=5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    test_loader = load_cinic10(data_root, split=split, few_shot_per_class=1000, batch_size=batch_size)

    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            # Randomly split into support & query
            indices = torch.randperm(images.size(0))
            support_idx, query_idx = indices[:N_shot*num_classes], indices[N_shot*num_classes:]

            support_images, support_labels = images[support_idx], labels[support_idx]
            query_images, query_labels = images[query_idx], labels[query_idx]

            # Get embeddings
            support_embeddings = model(support_images)
            query_embeddings = model(query_images)

            # Compute prototypes
            prototypes = []
            for c in range(num_classes):
                class_mask = support_labels == c
                if class_mask.sum() > 0:  # Avoid empty classes
                    class_prototype = support_embeddings[class_mask].mean(0)
                    prototypes.append(class_prototype)

            prototypes = torch.stack(prototypes)  # Shape: [num_classes, embedding_dim]

            # Compute distances
            distances = torch.cdist(query_embeddings, prototypes)  # Shape: [num_query, num_classes]
            predicted_labels = torch.argmin(distances, dim=1)  # Nearest prototype = predicted class

            correct += (predicted_labels == query_labels).sum().item()
            total += query_labels.size(0)

    accuracy = (correct / total) * 100 if total > 0 else 0
    print(f"Accuracy on {split} set: {accuracy:.2f}%")
    return accuracy


In [27]:
calculate_prototypical_accuracy(model, data_dir, split='train', batch_size=128, num_classes=10, N_shot=10)

Accuracy on train set: 47.30%


47.298534798534796

In [28]:
calculate_prototypical_accuracy(model, data_dir, split='test', batch_size=128, num_classes=10, N_shot=10)

Accuracy on test set: 45.51%


45.51282051282051