In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
from torch.utils.data import DataLoader, Subset
import os
import random
from src.helper_functions import train_step, test_step, accuracy_fn

In [2]:
# Define the CNN backbone
class CNNFewShot(nn.Module):
    def __init__(self, num_classes=10):
        super(CNNFewShot, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x


In [6]:
# Load CINIC-10 dataset
def load_cinic10(data_root, split='train', few_shot_per_class=10, batch_size=16):
    data_dir = os.path.join(data_root, "cinic-10", split)

    transform = transforms.Compose([
        transforms.Resize((32, 32)),  # Ensure images are 32x32
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    
    dataset = datasets.ImageFolder(root=data_dir, transform=transform)
    
    # Few-shot sampling
    class_indices = {label: [] for label in range(10)}
    for idx, (_, label) in enumerate(dataset.samples):
        class_indices[label].append(idx)

    few_shot_indices = []
    for indices in class_indices.values():
        few_shot_indices.extend(random.sample(indices, min(len(indices), few_shot_per_class)))  # Handle cases where a class has fewer than few_shot_per_class images

    few_shot_dataset = Subset(dataset, few_shot_indices)
    dataloader = DataLoader(few_shot_dataset, batch_size=batch_size, shuffle=True)

    return dataloader

In [4]:
# Train the CNN on few-shot data
def train_few_shot(model, dataloader, epochs=10, lr=0.001):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss:.4f}")

In [10]:
data_dir = "../data"
dataloader = load_cinic10(data_dir, few_shot_per_class=10)
model = CNNFewShot(num_classes=10)
train_few_shot(model, dataloader)

Epoch 1/10, Loss: 16.2679
Epoch 2/10, Loss: 16.0916
Epoch 3/10, Loss: 16.0031
Epoch 4/10, Loss: 15.6597
Epoch 5/10, Loss: 15.1344
Epoch 6/10, Loss: 14.4928
Epoch 7/10, Loss: 14.3576
Epoch 8/10, Loss: 14.4622
Epoch 9/10, Loss: 13.9526
Epoch 10/10, Loss: 12.3100
