In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms
from torch.utils.data import DataLoader, Dataset
import random
import numpy as np
from torchvision.models import resnet18

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook relies on: task sampling uses `random`, weight
# initialisation and DataLoader shuffling use torch; numpy is seeded too for
# completeness. This makes Restart & Run All reproduce the same episodes
# (GPU kernels may still introduce small non-deterministic differences).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Define the model
class ConvNet(nn.Module):
    """Small convolutional embedding network used for prototypical few-shot learning.

    Architecture: 3x3 conv (3 -> 16 channels, padding 1) -> ReLU -> 2x2 max-pool
    -> flatten -> linear projection to ``embedding_dim`` features.

    The linear layer's input size is fixed to ``16 * 16 * 16``. Inputs must
    therefore be 3-channel images whose pooled feature map is 16x16, i.e.
    32x32 images such as CIFAR-10.

    Args:
        embedding_dim: Size of the output feature vector (default 2, the value
            the notebook originally used). This is an *embedding* dimension,
            not a number of classes: class prototypes are averaged in this
            space. A larger value (e.g. 64) may give more separable features.
    """

    def __init__(self, embedding_dim=2):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.act1 = nn.ReLU()
        self.pool = nn.MaxPool2d(2)
        # 32x32 input: conv with padding=1 keeps 32x32, pooling halves it to 16x16 x 16 channels.
        self.fc = nn.Linear(16 * 16 * 16, embedding_dim)

    def forward(self, x):
        """Map a batch of images ``(B, 3, 32, 32)`` to embeddings ``(B, embedding_dim)``."""
        x = self.pool(self.act1(self.conv1(x)))
        x = x.view(x.size(0), -1)
        return self.fc(x)


In [None]:
# Load CIFAR-10 data
def load_cifar10_data():
    """Return the CIFAR-10 training split, downloading it into ./data if needed.

    Each image is converted with ``ToTensor`` and then normalized per RGB
    channel with mean 0.5 and standard deviation 0.5.
    """
    per_channel_norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    preprocessing = transforms.Compose([transforms.ToTensor(), per_channel_norm])
    return CIFAR10(root='./data', train=True, download=True, transform=preprocessing)

class CustomCIFAR10(Dataset):
    """Minimal in-memory dataset pairing pre-loaded samples with their labels.

    ``sample_task`` uses it to hold the support and query sets of one episode.

    Args:
        data: Indexable sequence of samples (e.g. image tensors).
        targets: Indexable sequence of labels aligned position-by-position with ``data``.
    """

    def __init__(self, data, targets):
        self.data, self.targets = data, targets

    def __len__(self):
        """Number of stored samples (the length of ``data``)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at position ``idx``."""
        sample = self.data[idx]
        label = self.targets[idx]
        return sample, label

def _dataset_labels(dataset):
    """Return the per-index labels of ``dataset``, avoiding image loading when possible.

    Uses ``dataset.targets`` when the attribute exists (``CustomCIFAR10`` and
    torchvision's ``CIFAR10`` both provide it). Otherwise it falls back to
    indexing every item, which applies the dataset's transform and is slow.
    """
    targets = getattr(dataset, 'targets', None)
    if targets is not None:
        return targets
    return [dataset[i][1] for i in range(len(dataset))]


def sample_task(dataset, num_samples_per_class=5, num_query_per_class=15, used_indices=None,
                n_way=2, num_classes=10):
    """Sample one N-way few-shot episode (support set + query set) from ``dataset``.

    ``n_way`` distinct classes are drawn uniformly from ``range(num_classes)``.
    They are relabelled ``0 .. n_way-1`` in the order they were drawn. For each
    class, ``num_samples_per_class + num_query_per_class`` not-yet-used indices
    are drawn uniformly at random. The first ``num_samples_per_class`` go to
    the support set and the rest to the query set.

    Args:
        dataset: Indexable dataset returning ``(image, label)``. If it has a
            ``targets`` attribute, candidates are found from labels alone and
            only the drawn images are loaded.
        num_samples_per_class: Support examples per class (the "shot").
        num_query_per_class: Query examples per class.
        used_indices: Set of dataset indices to exclude. It is updated in
            place with every index drawn here, so sharing one set across
            calls keeps episodes disjoint.
        n_way: Number of classes per episode (default 2).
        num_classes: Classes are drawn from ``range(num_classes)`` (default 10).

    Returns:
        ``(support_set, query_set)`` as ``CustomCIFAR10`` datasets with labels
        in ``0 .. n_way-1``.

    Raises:
        ValueError: If a chosen class has fewer unused examples than
            ``num_samples_per_class + num_query_per_class``. This replaces
            silently returning a short or single-class task that would
            misalign prototypes and labels downstream.
    """
    if used_indices is None:
        used_indices = set()

    per_class = num_samples_per_class + num_query_per_class
    chosen_classes = random.sample(range(num_classes), n_way)
    class_map = {cls: i for i, cls in enumerate(chosen_classes)}

    # Gather every unused index of each chosen class from the labels only, so
    # images (and their transforms) are loaded just for the examples drawn.
    candidates = {cls: [] for cls in chosen_classes}
    for idx, label in enumerate(_dataset_labels(dataset)):
        label = int(label)
        if label in candidates and idx not in used_indices:
            candidates[label].append(idx)

    # Validate before touching used_indices so a failure leaves it unchanged.
    for cls in chosen_classes:
        if len(candidates[cls]) < per_class:
            raise ValueError(
                f"class {cls} has only {len(candidates[cls])} unused examples; "
                f"{per_class} needed ({num_samples_per_class} support + {num_query_per_class} query)"
            )

    support_data, support_labels = [], []
    query_data, query_labels = [], []
    for cls in chosen_classes:
        # Random draw, so different episodes/epochs see different examples
        # rather than always the first ones in dataset order.
        picked = random.sample(candidates[cls], per_class)
        for position, idx in enumerate(picked):
            image, _ = dataset[idx]
            if position < num_samples_per_class:
                support_data.append(image)
                support_labels.append(class_map[cls])
            else:
                query_data.append(image)
                query_labels.append(class_map[cls])
            used_indices.add(idx)

    return CustomCIFAR10(support_data, support_labels), CustomCIFAR10(query_data, query_labels)


In [None]:
def euclidean_dist(x, y):
    """Pairwise Euclidean (L2) distances between the rows of ``x`` and ``y``.

    For ``x`` of shape ``(N, D)`` and ``y`` of shape ``(M, D)`` the result has
    shape ``(N, M)`` with ``out[i, j] = ||x[i] - y[j]||_2``.
    """
    return torch.cdist(x, y, p=2.0)

def prototypical_loss(support_features, query_features, support_labels, query_labels):
    """Prototypical-network loss for one episode.

    Each class prototype is the mean embedding of that class's support
    examples. Queries are scored by negative Euclidean distance to every
    prototype, these scores are turned into log-probabilities with
    ``log_softmax``, and the loss is the mean negative log-likelihood of each
    query's true class.

    Args:
        support_features: ``(S, D)`` support embeddings.
        query_features: ``(Q, D)`` query embeddings.
        support_labels: ``(S,)`` integer class labels of the support examples.
        query_labels: ``(Q,)`` integer class labels of the queries. Every
            value must also appear in ``support_labels``; labels need not be
            ``0..K-1``.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: If some query label has no support example (no prototype).
    """
    # torch.unique sorts its output, so prototype row k belongs to unique_labels[k].
    unique_labels = torch.unique(support_labels)
    prototypes = torch.stack([support_features[support_labels == label].mean(0) for label in unique_labels])

    # Calculate distances from query features to prototypes
    dists = euclidean_dist(query_features, prototypes)
    log_p_y = torch.nn.functional.log_softmax(-dists, dim=1)

    # nll_loss uses the target value as a column index into log_p_y. Map each
    # raw query label to its prototype row instead of assuming labels == rows.
    target_rows = torch.searchsorted(unique_labels, query_labels)
    target_rows = target_rows.clamp(max=unique_labels.numel() - 1)
    if not torch.equal(unique_labels[target_rows], query_labels):
        raise ValueError("every query label must have at least one support example")

    return torch.nn.functional.nll_loss(log_p_y, target_rows)



def meta_train(model, dataset, device, epochs=50, tasks_per_epoch=5, num_support=5, num_query=15, lr=0.001):
    """Episodically train ``model`` with the prototypical loss.

    Each epoch draws ``tasks_per_epoch`` episodes with ``sample_task``, using
    ``num_support`` support and ``num_query`` query examples per class. It
    takes one Adam step per episode and prints the mean episode loss.
    Episodes within one epoch use disjoint dataset indices; the exclusion set
    is reset at the start of every epoch.

    Args:
        model: Embedding network whose outputs are used as features.
        dataset: Dataset passed to ``sample_task``.
        device: Device the episode tensors are moved to (the model is assumed
            to already live there).
        epochs: Number of epochs.
        tasks_per_epoch: Episodes (optimizer steps) per epoch; must be >= 1.
        num_support: Support examples per class.
        num_query: Query examples per class.
        lr: Adam learning rate (default 0.001).

    Raises:
        ValueError: If ``tasks_per_epoch`` < 1 (the per-epoch average would be
            undefined).
    """
    if tasks_per_epoch < 1:
        raise ValueError("tasks_per_epoch must be at least 1")

    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        used_indices = set()  # keeps the episodes of this epoch disjoint
        total_loss = 0
        for _ in range(tasks_per_epoch):
            support_set, query_set = sample_task(dataset, num_samples_per_class=num_support, num_query_per_class=num_query, used_indices=used_indices)
            # One batch holding the whole support / query set.
            support_loader = DataLoader(support_set, batch_size=len(support_set), shuffle=True)
            query_loader = DataLoader(query_set, batch_size=len(query_set), shuffle=True)

            support_data, support_labels = next(iter(support_loader))
            query_data, query_labels = next(iter(query_loader))

            support_data, support_labels = support_data.to(device), support_labels.to(device)
            query_data, query_labels = query_data.to(device), query_labels.to(device)

            optimizer.zero_grad()
            support_features = model(support_data)
            query_features = model(query_data)
            loss = prototypical_loss(support_features, query_features, support_labels, query_labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f'Epoch {epoch + 1}, Average Loss: {total_loss / tasks_per_epoch}')

# Build the embedding network on the selected device, then load the training split.
model = ConvNet().to(device)
train_dataset = load_cifar10_data()

# Episodic prototypical training; prints the average loss of each epoch.
meta_train(model, train_dataset, device)

Files already downloaded and verified
Epoch 1, Average Loss: 0.6769265294075012
Epoch 2, Average Loss: 0.6539146900177002
Epoch 3, Average Loss: 0.6053136825561524
Epoch 4, Average Loss: 0.6480699181556702
Epoch 5, Average Loss: 0.5699179947376252
Epoch 6, Average Loss: 0.5340400099754333
Epoch 7, Average Loss: 0.46629777550697327
Epoch 8, Average Loss: 0.5943116784095764
Epoch 9, Average Loss: 0.40581966638565065
Epoch 10, Average Loss: 0.5645058929920197
Epoch 11, Average Loss: 0.5680080652236938
Epoch 12, Average Loss: 0.4537735223770142
Epoch 13, Average Loss: 0.43359091877937317
Epoch 14, Average Loss: 0.5606654942035675
Epoch 15, Average Loss: 0.5134420335292816
Epoch 16, Average Loss: 0.4687359631061554
Epoch 17, Average Loss: 0.45836056470870973
Epoch 18, Average Loss: 0.426683235168457
Epoch 19, Average Loss: 0.4798527956008911
Epoch 20, Average Loss: 0.33434508740901947
Epoch 21, Average Loss: 0.5170858144760132
Epoch 22, Average Loss: 0.3999248817563057
Epoch 23, Average Los

In [None]:
def meta_test(model, dataset, device, num_samples_per_class=5, num_support=3, num_query=2, num_tasks=10):
    """Evaluate ``model`` by nearest-prototype classification on sampled episodes.

    For each of ``num_tasks`` episodes drawn with ``sample_task``:
    - class prototypes are the mean support embeddings;
    - each query is assigned the label of its nearest prototype (Euclidean
      distance).
    One index-exclusion set is shared across all episodes, so no example is
    reused between them.

    Args:
        model: Embedding network (switched to eval mode here).
        dataset: Dataset passed to ``sample_task``.
        device: Device the episode tensors are moved to.
        num_samples_per_class: Support examples per class. This is the
            argument that actually controls the support-set size.
        num_support: Not used. It is kept only so existing calls that pass it
            keep working; change ``num_samples_per_class`` to alter the shot.
        num_query: Query examples per class.
        num_tasks: Number of episodes to average over; must be >= 1.

    Returns:
        Mean per-episode query accuracy in ``[0, 1]`` (also printed as a
        percentage).

    Raises:
        ValueError: If ``num_tasks`` < 1 (the average would be undefined).
    """
    if num_tasks < 1:
        raise ValueError("num_tasks must be at least 1")

    accuracies = []
    model.eval()  # Set the model to evaluation mode
    used_indices = set()  # Shared across tasks so episodes never reuse an example

    with torch.no_grad():
        for _ in range(num_tasks):
            # Sample a new task
            support_set, query_set = sample_task(dataset, num_samples_per_class=num_samples_per_class, num_query_per_class=num_query, used_indices=used_indices)
            support_loader = DataLoader(support_set, batch_size=len(support_set), shuffle=True)
            query_loader = DataLoader(query_set, batch_size=len(query_set), shuffle=False)

            # Get data for support and query sets (one batch each)
            support_data, support_labels = next(iter(support_loader))
            query_data, query_labels = next(iter(query_loader))

            support_data, support_labels = support_data.to(device), support_labels.to(device)
            query_data, query_labels = query_data.to(device), query_labels.to(device)

            # Prototypes: torch.unique sorts, so row k belongs to unique_labels[k].
            support_features = model(support_data)
            unique_labels = torch.unique(support_labels)
            prototypes = torch.stack([support_features[support_labels == label].mean(0) for label in unique_labels])

            # Nearest prototype. argmin yields a prototype *row*; translate it
            # back to a class label so the comparison is correct even when
            # the labels are not exactly 0..K-1.
            query_features = model(query_data)
            dists = euclidean_dist(query_features, prototypes)
            predicted = unique_labels[dists.argmin(dim=1)]

            # Calculate accuracy
            correct = (predicted == query_labels).sum().item()
            accuracies.append(correct / query_labels.size(0))

    average_accuracy = np.mean(accuracies)
    print(f'Average Test Accuracy on new tasks: {average_accuracy * 100:.2f}%')
    return average_accuracy


# Prepare the evaluation data: the CIFAR-10 *test split*. It contains the same
# 10 classes that sample_task draws from during meta_train, so meta_test
# measures generalisation to unseen images, not to unseen classes.
def load_cifar10_test_data():
    """Return the CIFAR-10 test split, downloading it into ./data if needed.

    Preprocessing matches the training loader: ``ToTensor`` followed by
    per-channel normalisation with mean 0.5 and standard deviation 0.5.
    """
    preprocessing = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    return CIFAR10(root='./data', train=False, download=True, transform=preprocessing)


# Load the CIFAR-10 test split used for evaluation.
test_dataset = load_cifar10_test_data()

# Evaluate on freshly sampled episodes; the returned mean accuracy is the cell output.
meta_test(
    model,
    test_dataset,
    device,
    num_samples_per_class=5,
    num_support=3,
    num_query=2,
    num_tasks=5,
)


Files already downloaded and verified
Average Test Accuracy on new tasks: 70.00%


0.7