In [None]:
import random
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
from torchvision.transforms import transforms

# Reproducibility: seed every RNG the notebook can draw from so that
# Restart Kernel -> Run All yields the same tasks and the same initial weights.
# `random` drives the class choice in task sampling; `torch` drives weight
# initialisation and DataLoader shuffling; NumPy is seeded for completeness.
# NOTE(review): bit-exact results on GPU may additionally require the
# deterministic-algorithm settings described in the PyTorch reproducibility notes.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device configuration: use the GPU when one is visible, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Modify ResNet to fit the small data scenario
def create_small_resnet(num_classes=2, dropout_rate=0.5):
    """Build a randomly initialised ResNet-18 with a dropout-regularised head.

    Args:
        num_classes: Number of output logits produced by the replacement head.
        dropout_rate: Drop probability of the ``nn.Dropout`` layer placed in
            front of the final linear layer.

    Returns:
        The ``resnet18`` model whose ``fc`` attribute has been replaced by
        ``nn.Sequential(nn.Dropout, nn.Linear)``.
    """
    # Rely on the constructor default (no pretrained weights) instead of the
    # deprecated ``pretrained=False`` flag: the default is random
    # initialisation in both old and new torchvision releases, and omitting
    # the flag avoids the deprecation warning on torchvision >= 0.13.
    model = resnet18()

    # Reuse the backbone's feature width so only the classifier head changes.
    num_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(dropout_rate),
        nn.Linear(num_features, num_classes)
    )
    return model

In [None]:
# Load and prepare CIFAR-10 data only once
def load_cifar10_data():
    """Return the CIFAR-10 training split with the notebook's preprocessing.

    Each image is converted with ``ToTensor`` and then normalised per channel
    with mean 0.5 and standard deviation 0.5. The data lives under ``./data``
    and is requested with ``download=True``.
    """
    channel_mean = (0.5, 0.5, 0.5)
    channel_std = (0.5, 0.5, 0.5)
    preprocessing = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]
    )
    return CIFAR10(root='./data', train=True, download=True, transform=preprocessing)

# Custom dataset class
class CustomCIFAR10(Dataset):
    """In-memory dataset that pairs pre-loaded samples with their labels.

    ``data`` and ``targets`` are stored exactly as given and are indexed in
    parallel: item ``i`` is ``(data[i], targets[i])``.
    """

    def __init__(self, data, targets):
        self.data, self.targets = data, targets

    def __len__(self):
        # The dataset size is defined by the number of stored samples.
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.targets[idx]
        return sample, label

def sample_task(dataset, num_samples_per_class=5, used_indices=None, n_way=2, num_classes=10):
    """Draw an ``n_way`` classification task from ``dataset``.

    ``n_way`` distinct classes are picked at random from ``range(num_classes)``
    and relabelled ``0 .. n_way - 1`` in the order they were picked. The
    dataset is then scanned in index order and, for each chosen class, the
    first ``num_samples_per_class`` samples whose index is not in
    ``used_indices`` are collected. Only the class choice is random; the
    sample choice is determined by dataset order and ``used_indices``.

    Args:
        dataset: Indexable dataset yielding ``(image, label)`` pairs.
        num_samples_per_class: Number of samples to collect for each class.
        used_indices: Optional set of dataset indices to skip. It is updated
            in place with every index selected here, so passing the same set
            to several calls keeps their samples disjoint. A fresh empty set
            is used when omitted.
        n_way: Number of classes in the task (defaults to the original
            two-class behaviour).
        num_classes: Number of classes in the source dataset to choose from
            (defaults to the 10 classes of CIFAR-10).

    Returns:
        A ``CustomCIFAR10`` holding the collected images and remapped labels.
        A warning is emitted if the dataset ran out before every class
        received ``num_samples_per_class`` samples.
    """
    if used_indices is None:
        used_indices = set()  # Use an empty set if none provided

    chosen_classes = random.sample(range(num_classes), n_way)
    class_map = {original: remapped for remapped, original in enumerate(chosen_classes)}

    data = []
    labels = []
    collected_per_class = [0] * n_way
    # Total number of samples still to be found; reaching zero means every
    # class is full, which replaces re-checking all classes on each iteration.
    missing = n_way * max(num_samples_per_class, 0)

    if missing > 0:
        for idx, (image, label) in enumerate(dataset):
            if label not in chosen_classes or idx in used_indices:
                continue
            class_label = class_map[label]
            if collected_per_class[class_label] >= num_samples_per_class:
                continue
            data.append(image)
            labels.append(class_label)
            collected_per_class[class_label] += 1
            used_indices.add(idx)
            missing -= 1
            if missing == 0:
                break

    if missing > 0:
        # Previously a short task was returned silently; make it visible.
        warnings.warn(
            f"sample_task collected {len(data)} of {n_way * num_samples_per_class} "
            f"requested samples for classes {chosen_classes}; the task is smaller than requested."
        )

    return CustomCIFAR10(data, labels)


In [None]:
def meta_train(model, dataset, device, epochs=50, tasks_per_epoch=5, num_samples_per_class=5):
    """Train ``model`` on a stream of small sampled tasks.

    Every epoch samples ``tasks_per_epoch`` tasks via ``sample_task`` and takes
    Adam steps (lr 0.001, cross-entropy loss) on mini-batches of up to 10
    samples from each task. Within an epoch the tasks share one
    ``used_indices`` set, so no dataset index is reused by two tasks of the
    same epoch; the set is reset at the start of the next epoch.

    Args:
        model: Network to train; updated in place.
        dataset: Source dataset handed to ``sample_task``.
        device: Device the batches are moved to (the model is expected to
            already live there).
        epochs: Number of epochs.
        tasks_per_epoch: Number of tasks sampled per epoch.
        num_samples_per_class: Samples per class in each task.

    Returns:
        None. One line per epoch is printed with the mean batch loss of that
        epoch (``nan`` if no batch was processed).
    """
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        used_indices = set()
        batch_losses = []

        for _ in range(tasks_per_epoch):
            task_dataset = sample_task(dataset, num_samples_per_class=num_samples_per_class, used_indices=used_indices)
            if len(task_dataset) == 0:
                # A shuffling DataLoader cannot be built over an empty dataset.
                continue
            task_loader = DataLoader(task_dataset, batch_size=10, shuffle=True)

            for data, targets in task_loader:
                data, targets = data.to(device), targets.to(device)
                optimizer.zero_grad()
                outputs = model(data)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
                batch_losses.append(loss.item())

        # Report the epoch mean instead of the last batch's loss: the last
        # batch alone is very noisy and is undefined when no batch ran.
        epoch_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')
        print(f'Epoch {epoch + 1}, Loss: {epoch_loss}')



# Set up the compute device, the CIFAR-10 training split and the network
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
train_dataset = load_cifar10_data()
model = create_small_resnet()
model = model.to(device)

# Run the training loop with its default settings
meta_train(model, train_dataset, device)

Files already downloaded and verified
Epoch 1, Loss: 0.6195136904716492
Epoch 2, Loss: 1.529581069946289
Epoch 3, Loss: 0.8735393285751343
Epoch 4, Loss: 0.7255673408508301
Epoch 5, Loss: 0.6850746273994446
Epoch 6, Loss: 1.450914740562439
Epoch 7, Loss: 0.8798540234565735
Epoch 8, Loss: 1.5622708797454834
Epoch 9, Loss: 0.6429208517074585
Epoch 10, Loss: 0.7567564249038696
Epoch 11, Loss: 1.0010005235671997
Epoch 12, Loss: 0.9353543519973755
Epoch 13, Loss: 0.6908055543899536
Epoch 14, Loss: 0.7390526533126831
Epoch 15, Loss: 1.081953763961792
Epoch 16, Loss: 1.063437581062317
Epoch 17, Loss: 0.5718498229980469
Epoch 18, Loss: 1.3587660789489746
Epoch 19, Loss: 0.833337664604187
Epoch 20, Loss: 1.060869812965393
Epoch 21, Loss: 0.5736203193664551
Epoch 22, Loss: 1.050593376159668
Epoch 23, Loss: 0.6941630840301514
Epoch 24, Loss: 0.48005905747413635
Epoch 25, Loss: 0.8346608877182007
Epoch 26, Loss: 0.6782714128494263
Epoch 27, Loss: 0.8701767921447754
Epoch 28, Loss: 0.72374337911605

In [None]:
# Meta-testing function adapted for ResNet-18
def meta_test(model, dataset, device, num_classes=2, num_samples=25, n_inner_iter=5, inner_lr=0.01,
              num_tasks=10, num_query=None):
    """Fine-tune copies of ``model`` on sampled tasks and report held-out accuracy.

    For each task, ``num_samples + num_query`` samples per class are drawn
    with ``sample_task``. The first ``num_samples`` of each class form the
    support set used for adaptation; the remaining ones form the query set
    used only for scoring, so the reported accuracy is not measured on the
    samples the model was just fine-tuned on.

    Args:
        model: Trained network whose weights initialise every adapted copy;
            it is not modified.
        dataset: Source dataset handed to ``sample_task``.
        device: Device used for adaptation and evaluation.
        num_classes: Output size of the adapted network's head; must match
            ``model`` for its weights to load. Note that ``sample_task`` is
            called with its default task size here.
        num_samples: Support samples per class (also the adaptation batch size).
        n_inner_iter: Number of passes over the support set during adaptation.
        inner_lr: SGD learning rate used for adaptation.
        num_tasks: Number of tasks to average over.
        num_query: Held-out query samples per class; defaults to ``num_samples``.

    Returns:
        Mean query accuracy over the evaluated tasks (``nan`` if no task could
        be evaluated). The value is also printed as a percentage.
    """
    if num_query is None:
        num_query = num_samples

    criterion = nn.CrossEntropyLoss()
    accuracies = []

    for _ in range(num_tasks):  # Run several test tasks for better statistical measure
        task_dataset = sample_task(dataset, num_samples_per_class=num_samples + num_query)

        # Split per class: the first `num_samples` items of a class go to the
        # support set, everything after that to the query set.
        support_data, support_targets = [], []
        query_data, query_targets = [], []
        seen_per_class = {}
        for image, label in zip(task_dataset.data, task_dataset.targets):
            seen_per_class[label] = seen_per_class.get(label, 0) + 1
            if seen_per_class[label] <= num_samples:
                support_data.append(image)
                support_targets.append(label)
            else:
                query_data.append(image)
                query_targets.append(label)

        if not support_data or not query_data:
            warnings.warn("meta_test skipped a task: not enough samples for a support/query split.")
            continue

        support_loader = DataLoader(CustomCIFAR10(support_data, support_targets),
                                    batch_size=num_samples, shuffle=True)
        query_loader = DataLoader(CustomCIFAR10(query_data, query_targets), batch_size=num_samples)

        # Fresh copy of the trained weights so tasks do not influence each other
        adapted_model = create_small_resnet(num_classes=num_classes).to(device)
        adapted_model.load_state_dict(model.state_dict())
        optimizer = optim.SGD(adapted_model.parameters(), lr=inner_lr)

        # Adaptation phase: fine-tune on the support set for n_inner_iter passes
        adapted_model.train()
        for _ in range(n_inner_iter):
            for data, targets in support_loader:
                data, targets = data.to(device), targets.to(device)
                optimizer.zero_grad()
                loss = criterion(adapted_model(data), targets)
                loss.backward()
                optimizer.step()

        # Evaluation phase: eval() disables Dropout and makes BatchNorm use its
        # running statistics instead of updating them on the evaluation data.
        adapted_model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, targets in query_loader:
                data, targets = data.to(device), targets.to(device)
                predicted = adapted_model(data).argmax(dim=1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()

        accuracies.append(correct / total)

    average_accuracy = np.mean(accuracies) if accuracies else float('nan')
    print(f'Average Test Accuracy on new tasks: {average_accuracy * 100:.2f}%')
    return average_accuracy


# Prepare the CIFAR-10 test split (images not seen during training; the 10 classes are the same as in the training split)
def load_cifar10_test_data():
    """Return the CIFAR-10 split selected by ``train=False`` with the notebook's preprocessing.

    Each image is converted with ``ToTensor`` and then normalised per channel
    with mean 0.5 and standard deviation 0.5. The data lives under ``./data``
    and is requested with ``download=True``.
    """
    channel_mean = (0.5, 0.5, 0.5)
    channel_std = (0.5, 0.5, 0.5)
    preprocessing = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]
    )
    return CIFAR10(root='./data', train=False, download=True, transform=preprocessing)

# Evaluation tasks are built from the CIFAR-10 split loaded with train=False
test_dataset = load_cifar10_test_data()

# Adapt to and score the sampled tasks; the returned mean accuracy is the
# last expression of the cell and is therefore displayed as its output
meta_test(
    model,
    test_dataset,
    device,
    num_classes=2,
    num_samples=25,
    n_inner_iter=5,
    inner_lr=0.01,
)


Files already downloaded and verified
Average Test Accuracy on new tasks: 55.40%


0.5539999999999999