***Challenge 1***

Here the goal is to train on 25 samples. In this preliminary testbed, evaluation is done on a 2000-sample validation set. Note that the final evaluation will be done on the full CIFAR-10 test set, and potentially on a separate dataset as well. The validation samples here must not be used for training in any way; the final evaluation will provide only random samples of 25 from a data source that is not the CIFAR-10 training data.

Feel free to modify this testbed to your liking, including the normalization transformations, etc. Note, however, that the final evaluation testbed will have a rigid set of components into which you will need to place your answer. The only constraint is the data. Refer to the full project instructions for more information.


Set up the training functions. Again, you are free to fully modify this testbed while prototyping, within the constraints on the data used. You may also use tools outside of PyTorch to train models if desired, although the torchvision data loaders will still be useful for interacting with the CIFAR-10 dataset.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms
from torch.utils.data import DataLoader, Dataset, random_split
import random
import numpy as np
from torchvision.models import resnet18

# Device configuration: use a CUDA GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
def euclidean_dist(x, y):
    """Return the pairwise Euclidean (L2) distance matrix between rows of ``x`` and ``y``.

    Thin wrapper around ``torch.cdist``: entry ``[..., i, j]`` is the distance
    between ``x[..., i, :]`` and ``y[..., j, :]``. Distances are not squared.
    """
    pairwise = torch.cdist(x, y, p=2.0)  # p=2.0 is cdist's default (Euclidean)
    return pairwise

def prototypical_loss(support_features, query_features, support_labels, query_labels):
    """Prototypical-network loss for one episode.

    Each class prototype is the mean of the support embeddings that carry that
    label. Each query is scored by its negative Euclidean distance to every
    prototype. The loss is the mean negative log-likelihood of the query's true
    class under a softmax over those scores.

    Args:
        support_features: (num_support, dim) embeddings of the support set.
        query_features: (num_query, dim) embeddings of the query set.
        support_labels: (num_support,) integer labels for ``support_features``.
        query_labels: (num_query,) integer labels for ``query_features``; every
            value must also occur in ``support_labels``. Labels need not be
            contiguous or start at 0.

    Returns:
        Scalar tensor holding the mean NLL over the queries.

    Raises:
        ValueError: if a query label has no matching prototype in the support set.
    """
    # torch.unique returns the labels sorted, so prototype row k belongs to unique_labels[k].
    unique_labels = torch.unique(support_labels)
    prototypes = torch.stack(
        [support_features[support_labels == label].mean(0) for label in unique_labels]
    )

    # Translate each query label into the row index of its prototype. Using the
    # raw labels as indices would only be correct when the support labels are
    # exactly 0..k-1.
    matches = query_labels.unsqueeze(1) == unique_labels.unsqueeze(0)  # (num_query, k)
    if not bool(matches.any(dim=1).all()):
        raise ValueError("every query label must be present in support_labels")
    targets = matches.float().argmax(dim=1)

    # Closer prototype -> larger logit; log_softmax + NLL is the cross-entropy.
    dists = euclidean_dist(query_features, prototypes)
    log_p_y = torch.nn.functional.log_softmax(-dists, dim=1)
    return torch.nn.functional.nll_loss(log_p_y, targets)


In [None]:
def _cifar10_transform():
    """Build the preprocessing pipeline shared by both CIFAR-10 loaders.

    Converts each image to a float tensor, then normalizes every channel with
    the fixed mean/std constants below.
    """
    # NOTE(review): these are the usual ImageNet mean/std values, not statistics
    # computed on CIFAR-10 itself.
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])


def load_cifar10_data():
    """Return the CIFAR-10 training split from ./data, downloading it if missing."""
    return CIFAR10(root='./data', train=True, download=True, transform=_cifar10_transform())


def load_cifar10_test_data():
    """Return the CIFAR-10 test split from ./data, downloading it if missing."""
    return CIFAR10(root='./data', train=False, download=True, transform=_cifar10_transform())

class CustomCIFAR10(Dataset):
    """Minimal map-style dataset over two index-aligned sequences.

    Item ``i`` is the pair ``(data[i], targets[i])`` and the length is
    ``len(data)``. The sequences are stored as given (no copy).
    """

    def __init__(self, data, targets):
        self.data, self.targets = data, targets

    def __len__(self):
        n_items = len(self.data)
        return n_items

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.targets[idx]
        return sample, target

def _iter_labeled_samples(dataset):
    """Yield ``(index, label, image_or_None)`` for every item of ``dataset``, in order.

    When the dataset exposes a ``targets`` sequence aligned with its items and
    has no ``target_transform``, labels are read straight from ``targets``. In
    that case the image slot is ``None``, so the caller loads (and transforms)
    only the images it keeps. Otherwise every item is loaded and its image is
    passed along.
    """
    targets = getattr(dataset, "targets", None)
    if targets is not None and getattr(dataset, "target_transform", None) is None:
        try:
            aligned = len(targets) == len(dataset)
        except TypeError:
            aligned = False
        if aligned:
            for idx, label in enumerate(targets):
                yield idx, label, None
            return
    for idx, (image, label) in enumerate(dataset):
        yield idx, label, image


def sample_task(dataset, num_samples_per_class=5, num_query_per_class=15, used_indices=None,
                num_ways=2, num_classes=10):
    """Draw one few-shot episode (support set + query set) from ``dataset``.

    ``num_ways`` distinct classes are picked at random (``random.sample``) from
    ``range(num_classes)``. They are relabelled ``0..num_ways-1`` in the order
    they were drawn. For each chosen class, scanning the dataset in index order:
    - the first ``num_samples_per_class`` unused examples go to the support set;
    - the next ``num_query_per_class`` go to the query set.
    Every index taken is added to ``used_indices``, so later calls that share
    the same set never reuse an example.

    If a class runs out of unused examples, the scan simply ends at the end of
    the dataset and that class is under-represented in the returned sets.

    Args:
        dataset: map-style dataset yielding ``(image, label)`` pairs.
        num_samples_per_class: support examples per class.
        num_query_per_class: query examples per class.
        used_indices: set of dataset indices to skip; mutated in place.
            A fresh set is used when ``None``.
        num_ways: number of classes per episode (default 2).
        num_classes: labels are drawn from ``range(num_classes)`` (default 10).

    Returns:
        ``(support, query)`` as two ``CustomCIFAR10`` datasets with relabelled targets.
    """
    if used_indices is None:
        used_indices = set()

    chosen_classes = random.sample(range(num_classes), num_ways)
    class_map = {cls: i for i, cls in enumerate(chosen_classes)}
    quota = num_samples_per_class + num_query_per_class

    support_data, support_labels = [], []
    query_data, query_labels = [], []
    indices_per_class = {i: [] for i in range(num_ways)}

    for idx, label, image in _iter_labeled_samples(dataset):
        if label not in chosen_classes or idx in used_indices:
            continue
        class_label = class_map[label]
        taken = indices_per_class[class_label]
        if len(taken) < quota:
            if image is None:
                # Label-only scan: load the image now that the index is selected.
                image = dataset[idx][0]
            if len(taken) < num_samples_per_class:
                support_data.append(image)
                support_labels.append(class_label)
            else:
                query_data.append(image)
                query_labels.append(class_label)
            taken.append(idx)
            used_indices.add(idx)
        # Stop scanning as soon as every class has its full support + query quota.
        if all(len(class_indices) == quota for class_indices in indices_per_class.values()):
            break

    return CustomCIFAR10(support_data, support_labels), CustomCIFAR10(query_data, query_labels)


In [None]:
def _stack_episode(episode, device):
    """Collate every ``(image, label)`` pair of ``episode`` into one batch on ``device``."""
    # One batch covering the whole episode. The prototypical loss is invariant to
    # sample order, so no shuffling is needed.
    images, labels = next(iter(DataLoader(episode, batch_size=len(episode))))
    return images.to(device), labels.to(device)


def train(model, dataset, device, epochs=50, tasks_per_epoch=5, num_support=5, num_query=15,
          lr=0.0001):
    """Meta-train ``model`` episodically with the prototypical loss using Adam.

    Each epoch samples ``tasks_per_epoch`` episodes with ``sample_task``, without
    reusing an example within the epoch. It takes one optimizer step per episode
    and prints the mean episode loss.

    Args:
        model: embedding network; its outputs are used as the features.
        dataset: map-style dataset of ``(image_tensor, label)`` pairs.
        device: device the model lives on; episode batches are moved there.
        epochs: number of epochs.
        tasks_per_epoch: episodes (and optimizer steps) per epoch.
        num_support: support examples per class in each episode.
        num_query: query examples per class in each episode.
        lr: Adam learning rate (default 1e-4).
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        model.train()
        used_indices = set()  # reset per epoch: no example repeats within one epoch
        total_loss = 0.0
        for _ in range(tasks_per_epoch):
            support_set, query_set = sample_task(
                dataset,
                num_samples_per_class=num_support,
                num_query_per_class=num_query,
                used_indices=used_indices,
            )
            support_data, support_labels = _stack_episode(support_set, device)
            query_data, query_labels = _stack_episode(query_set, device)

            optimizer.zero_grad()
            # Embed support and query in one forward pass so batch-statistic layers
            # (e.g. BatchNorm in train mode) normalize both with the same statistics.
            # With two separate passes, prototypes and queries would be normalized differently.
            features = model(torch.cat([support_data, query_data]))
            n_support = support_data.size(0)
            support_features = features[:n_support]
            query_features = features[n_support:]

            loss = prototypical_loss(support_features, query_features, support_labels, query_labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f'Epoch {epoch + 1}, Average Loss: {total_loss / tasks_per_epoch}')


def test(model, dataset, device, num_samples_per_class=5, num_support=3, num_query=2, num_tasks=10):
    """Evaluate ``model`` on ``num_tasks`` freshly sampled episodes.

    For each episode, the support embeddings are averaged into one prototype per
    class. Each query is then assigned the label of its nearest prototype
    (Euclidean distance). No example is reused across episodes within one call.

    Args:
        model: embedding network; put in eval mode and run without gradients.
        dataset: map-style dataset of ``(image_tensor, label)`` pairs.
        device: device the model lives on.
        num_samples_per_class: support examples per class (this is the value
            passed to ``sample_task``).
        num_support: unused. It is kept only so existing callers keep working;
            the support size comes from ``num_samples_per_class``.
        num_query: query examples per class.
        num_tasks: number of episodes to average over.

    Returns:
        Mean per-episode query accuracy, as a percentage (0-100).
    """
    accuracies = []
    model.eval()  # inference behaviour for Dropout / BatchNorm layers
    used_indices = set()  # shared across tasks so no example is reused within this call

    with torch.no_grad():
        for _ in range(num_tasks):
            support_set, query_set = sample_task(
                dataset,
                num_samples_per_class=num_samples_per_class,
                num_query_per_class=num_query,
                used_indices=used_indices,
            )
            # One batch per set. Prototypes are means and queries are scored
            # independently, so the batch order does not matter.
            support_data, support_labels = next(iter(DataLoader(support_set, batch_size=len(support_set))))
            query_data, query_labels = next(iter(DataLoader(query_set, batch_size=len(query_set))))

            support_data, support_labels = support_data.to(device), support_labels.to(device)
            query_data, query_labels = query_data.to(device), query_labels.to(device)

            # torch.unique sorts the labels, so prototype row k represents unique_labels[k].
            support_features = model(support_data)
            unique_labels = torch.unique(support_labels)
            prototypes = torch.stack(
                [support_features[support_labels == label].mean(0) for label in unique_labels]
            )

            query_features = model(query_data)
            dists = euclidean_dist(query_features, prototypes)
            _, nearest = torch.min(dists, 1)
            # `nearest` indexes prototype rows. Map it back to the label that row stands
            # for before comparing; this is the identity when labels are exactly 0..k-1.
            predicted = unique_labels[nearest]

            correct = (predicted == query_labels).sum().item()
            total = query_labels.size(0)
            accuracies.append(correct / total)

    average_accuracy = np.mean(accuracies)
    print(f'Average Test Accuracy on new tasks: {average_accuracy * 100:.2f}%')
    return average_accuracy * 100



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18


class Net(nn.Module):
    """Randomly initialized ResNet-18 whose final layer is replaced by dropout + a linear head.

    The network outputs ``num_classes`` values per image. In this notebook those
    outputs are used as the embedding in which the prototypical loss measures
    distances.

    Args:
        num_classes: width of the output layer (the embedding size here).
        dropout_rate: dropout probability applied before the output layer.
    """

    def __init__(self, num_classes=2, dropout_rate=0.5):
        super().__init__()
        try:
            # torchvision >= 0.13: `weights=None` requests random initialization;
            # the older `pretrained` keyword is deprecated there.
            backbone = resnet18(weights=None)
        except TypeError:
            # Older torchvision has no `weights` keyword; `pretrained=False` is equivalent.
            backbone = resnet18(pretrained=False)
        self.resnet = backbone

        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(num_features, num_classes),
        )

    def forward(self, x):
        return self.resnet(x)

The cell below runs four random problem instances (seeds 1–4). During development you may choose to prototype with a single problem instance, but keep in mind that for small-sample problems the variance is high, so continuously evaluating on several subsets will be important.

In [None]:
from numpy.random import RandomState
import numpy as np
import torch.optim as optim
from torch.utils.data import Subset
import time


# Train and evaluate one fresh model per seed, then report mean +- std of accuracy
# and training time across all instances that were run.
accs = []
times = []

# The datasets do not change between instances, so load them once.
train_dataset = load_cifar10_data()
val_dataset = load_cifar10_test_data()

seeds = range(1, 5)
for seed in seeds:
    # Seed Python's, NumPy's and PyTorch's RNGs so each instance is reproducible.
    # `random` drives class selection in sample_task; torch drives weight init and dropout.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    model = Net()
    model.to(device)

    start_time = time.perf_counter()
    train(model, train_dataset, device, 100)
    end_time = time.perf_counter()

    times.append(end_time - start_time)

    accuracy = test(model, val_dataset, device, num_samples_per_class=5, num_support=5, num_query=25, num_tasks=8)
    accs.append(accuracy)

times = np.array(times)
accs = np.array(accs)
# Report the number of instances actually run rather than a hard-coded count.
num_instances = len(accs)
print('Acc over %d instances: %.2f +- %.2f' % (num_instances, accs.mean(), accs.std()))
print(f"Average Time over {num_instances} instances: {times.mean()} +-{times.std()}")


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:10<00:00, 15663952.18it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified




Epoch 1, Average Loss: 0.6958709239959717
Epoch 2, Average Loss: 0.485071262717247
Epoch 3, Average Loss: 1.2803401708602906
Epoch 4, Average Loss: 1.0569176435470582
Epoch 5, Average Loss: 0.67303067445755
Epoch 6, Average Loss: 0.7610782265663147
Epoch 7, Average Loss: 0.6956743121147155
Epoch 8, Average Loss: 0.7018378555774689
Epoch 9, Average Loss: 0.7004133641719819
Epoch 10, Average Loss: 0.7237789988517761
Epoch 11, Average Loss: 0.6578987300395965
Epoch 12, Average Loss: 0.7245121717453002
Epoch 13, Average Loss: 0.7154530882835388
Epoch 14, Average Loss: 0.6710154414176941
Epoch 15, Average Loss: 0.62708580493927
Epoch 16, Average Loss: 0.594921863079071
Epoch 17, Average Loss: 0.6030842363834381
Epoch 18, Average Loss: 0.5240469455718995
Epoch 19, Average Loss: 0.674282306432724
Epoch 20, Average Loss: 0.4921069726347923
Epoch 21, Average Loss: 0.6302272260189057
Epoch 22, Average Loss: 0.47578747272491456
Epoch 23, Average Loss: 0.6639476478099823
Epoch 24, Average Loss: 0.