In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms
from torch.utils.data import DataLoader, Dataset
import random
import numpy as np
from sklearn.decomposition import PCA
from torchvision.models import resnet18

# Compute device shared by the notebook: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [8]:
class ResNetSimple(nn.Module):
    """ResNet-18 feature extractor followed by a linear projection and a small classifier.

    The network is ``features -> flatten -> projection -> dropout -> linear``.
    The projection (and the classifier input size that depends on it) can be
    re-fitted from data with :meth:`adapt_projection`.

    Device placement is left to the caller (``ResNetSimple().to(device)``); the
    class itself does not depend on any module-level ``device`` variable, so
    all sub-modules always live on the same device.

    Args:
        num_classes: Number of output logits.
        dropout_rate: Dropout probability applied before the final linear layer.
        initial_feature_dim: Output size of the initial projection and upper
            bound on the number of PCA components used by ``adapt_projection``.
    """

    def __init__(self, num_classes=2, dropout_rate=0.5, initial_feature_dim=50):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_rate = dropout_rate
        self.initial_feature_dim = initial_feature_dim

        self.define_layers()

    def define_layers(self):
        """Create the backbone, the projection and the classifier head."""
        # Calling resnet18() without arguments builds an untrained network on
        # both old (pretrained=False default) and new (weights=None default)
        # torchvision versions, and avoids the `pretrained` deprecation warning.
        backbone = resnet18()
        # Drop the final fully connected layer so the backbone acts as a
        # feature extractor (everything up to and including the average pool).
        self.features = nn.Sequential(*list(backbone.children())[:-1])

        # Assumes the pooled ResNet-18 output flattens to 512 features.
        self.projection = nn.Linear(512, self.initial_feature_dim)
        self.classifier = nn.Sequential(
            nn.Dropout(self.dropout_rate),
            nn.Linear(self.initial_feature_dim, self.num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        x = self.features(x)
        x = torch.flatten(x, 1)  # (batch, C, H, W) -> (batch, C*H*W)
        x = self.projection(x)
        x = self.classifier(x)
        return x

    def adapt_projection(self, features):
        """Replace the projection with a PCA transform fitted on ``features``.

        The new projection computes ``W @ (x - mean)`` where ``W`` holds the
        principal components. The final linear layer of the classifier is
        replaced by a freshly initialised one matching the new projection
        width.

        Both layers are new ``nn.Linear`` objects, so any optimizer created
        before this call still references the old parameters and must be
        rebuilt by the caller.

        Args:
            features: Tensor of backbone outputs, shape ``(num_samples,
                num_features)``. Tensors with extra trailing dimensions (for
                example ``(N, 512, 1, 1)``) are flattened to 2-D first.

        The number of components is ``min(num_samples, num_features,
        initial_feature_dim)``; when that is zero the model is left untouched.
        """
        features = features.detach()
        if features.dim() > 2:
            features = torch.flatten(features, 1)

        num_samples, num_features = features.size()
        actual_n_components = min(num_samples, num_features, self.initial_feature_dim)
        if actual_n_components <= 0:
            return

        # The replacement layers go wherever the current projection lives, so
        # the module never ends up split across devices.
        reference = self.projection.weight
        target_device, target_dtype = reference.device, reference.dtype

        pca = PCA(n_components=actual_n_components)
        pca.fit(features.cpu().numpy())

        new_projection = nn.Linear(num_features, actual_n_components)
        with torch.no_grad():
            components = torch.as_tensor(pca.components_, dtype=torch.float32)
            pca_mean = torch.as_tensor(pca.mean_, dtype=torch.float32)
            new_projection.weight.copy_(components)
            # W @ x + b == W @ (x - mean)  <=>  b = -(W @ mean)
            new_projection.bias.copy_(-torch.matmul(components, pca_mean))

        self.projection = new_projection.to(device=target_device, dtype=target_dtype)
        self.classifier[1] = nn.Linear(actual_n_components, self.num_classes).to(
            device=target_device, dtype=target_dtype
        )


In [6]:
# Data preprocessing
def load_cifar10_data():
    """Return the CIFAR-10 ``(train_dataset, test_dataset)`` pair.

    Both splits are stored under ``./data`` (downloaded when missing) and share
    one transform: conversion to a tensor followed by per-channel
    normalisation with mean 0.5 and standard deviation 0.5.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    pipeline = transforms.Compose([transforms.ToTensor(), normalize])

    # Build the training split first, then the test split, with identical settings.
    splits = [
        CIFAR10(root='./data', train=is_train, download=True, transform=pipeline)
        for is_train in (True, False)
    ]
    return splits[0], splits[1]


# Custom dataset class
class CustomCIFAR10(Dataset):
    """In-memory dataset that pairs already-loaded samples with their labels."""

    def __init__(self, data, targets):
        # Both sequences are stored as given; no copy or validation is made.
        self.targets = targets
        self.data = data

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at position ``idx``."""
        sample = self.data[idx]
        label = self.targets[idx]
        return sample, label

    def __len__(self):
        """Number of stored samples (length of ``data``)."""
        return len(self.data)

def sample_task(dataset, num_samples_per_class=5, used_indices=None, total_samples_per_class=25):
    """Build a two-class task from the first unused samples of two random classes.

    Two distinct class ids are drawn from ``range(10)`` and relabelled ``0`` and
    ``1`` (in draw order). The dataset is then scanned in index order and, for
    each chosen class, the first ``num_samples_per_class`` samples whose index
    is not already recorded in ``used_indices`` are taken.

    Args:
        dataset: Indexable dataset yielding ``(image, label)`` pairs. When it
            exposes a ``targets`` sequence (and no ``target_transform``), the
            labels are read from there so that only the selected images are
            actually loaded and transformed.
        num_samples_per_class: Samples to take per chosen class.
        used_indices: Dict ``original_label -> list of dataset indices`` shared
            across calls; it is updated in place with the indices picked here.
            A fresh dict is used when ``None``.
        total_samples_per_class: Cap on the number of indices recorded per
            original label; a class that reached the cap contributes nothing.

    Returns:
        CustomCIFAR10 holding the selected images and their ``0``/``1`` labels.
        It may contain fewer than ``2 * num_samples_per_class`` samples (or be
        empty) when a class is capped or the dataset runs out.
    """
    if used_indices is None:
        used_indices = {}

    chosen_classes = random.sample(range(10), 2)  # Randomly select 2 classes
    class_map = {original: new for new, original in enumerate(chosen_classes)}
    data = []
    labels = []

    if num_samples_per_class <= 0:
        # Nothing can ever be picked; skip the (potentially full) dataset scan.
        return CustomCIFAR10(data, labels)

    picked = {label: 0 for label in chosen_classes}
    # Set copies of the shared lists for O(1) "already used" checks.
    seen = {label: set(used_indices.get(label, ())) for label in chosen_classes}

    def _still_open(label):
        """True while `label` needs more samples and is under the global cap."""
        return (
            picked[label] < num_samples_per_class
            and len(used_indices.get(label, ())) < total_samples_per_class
        )

    # Fast path: read labels without decoding/transforming every image.
    raw_targets = getattr(dataset, 'targets', None)
    if (
        raw_targets is not None
        and getattr(dataset, 'target_transform', None) is None
        and len(raw_targets) == len(dataset)
    ):
        candidates = ((idx, int(label), None) for idx, label in enumerate(raw_targets))
    else:
        candidates = ((idx, label, image) for idx, (image, label) in enumerate(dataset))

    if any(_still_open(label) for label in chosen_classes):
        for idx, label, image in candidates:
            if label not in chosen_classes or not _still_open(label) or idx in seen[label]:
                continue

            if image is None:
                # Fast path only: load the image now that we know it is needed.
                image = dataset[idx][0]

            data.append(image)
            labels.append(class_map[label])
            picked[label] += 1
            seen[label].add(idx)
            used_indices.setdefault(label, []).append(idx)

            # Stop as soon as no chosen class can accept another sample.
            if not any(_still_open(c) for c in chosen_classes):
                break

    return CustomCIFAR10(data, labels)


In [9]:
def meta_train(model, dataset, device, epochs=50, tasks_per_epoch=5, num_samples_per_class=5):
    """Train ``model`` on a stream of randomly sampled two-class tasks.

    For every task the function:
      1. samples a task with ``sample_task`` (indices are not reused within an
         epoch because ``used_indices`` is shared across that epoch's tasks),
      2. extracts backbone features for the task and re-fits the projection
         with ``model.adapt_projection``,
      3. builds a fresh Adam optimizer (``adapt_projection`` replaces layers,
         so a previously created optimizer would hold stale parameters),
      4. runs one pass over the task and prints the loss of the last batch.

    Args:
        model: Module exposing ``features`` and ``adapt_projection``; expected
            to already live on ``device``.
        dataset: Dataset passed to ``sample_task``.
        device: Device the input batches are moved to.
        epochs: Number of epochs.
        tasks_per_epoch: Tasks sampled per epoch.
        num_samples_per_class: Samples per class in each task.
    """
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        model.train()
        used_indices = {}  # original label -> dataset indices used this epoch

        for _ in range(tasks_per_epoch):
            task_dataset = sample_task(dataset, num_samples_per_class, used_indices)
            if len(task_dataset) == 0:
                # A shuffling DataLoader cannot be built on an empty dataset and
                # there is nothing to fit PCA on, so skip explicitly.
                print(f'Epoch {epoch + 1}, skipped an empty task')
                continue
            task_loader = DataLoader(task_dataset, batch_size=10, shuffle=True)

            # The features are only used to fit PCA, so no autograd graph is needed.
            with torch.no_grad():
                features = torch.cat([
                    torch.flatten(model.features(data.to(device)), 1)
                    for data, _ in task_loader
                ])

            # Adapt projection dynamically
            model.adapt_projection(features)
            # Rebuild the optimizer so it tracks the newly created layers.
            optimizer = optim.Adam(model.parameters(), lr=0.001)

            # Training on task-specific adapted model
            for data, targets in task_loader:
                data = data.to(device)
                targets = targets.to(device)
                optimizer.zero_grad()
                outputs = model(data)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

            print(f'Epoch {epoch + 1}, Loss: {loss.item()}')

# Build the model on the selected device, load CIFAR-10, then launch meta-training
# on the training split.
model = ResNetSimple()
model = model.to(device)
train_dataset, test_dataset = load_cifar10_data()
meta_train(model=model, dataset=train_dataset, device=device)



Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 84938376.77it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Epoch 1, Loss: 1.8899400234222412
Epoch 1, Loss: 6.480332851409912
Epoch 1, Loss: 1.7794368267059326
Epoch 1, Loss: 5.455295085906982
Epoch 1, Loss: 2.5589866638183594
Epoch 2, Loss: 1.6316055059432983
Epoch 2, Loss: 5.5969343185424805
Epoch 2, Loss: 2.136298656463623
Epoch 2, Loss: 1.7889372110366821
Epoch 2, Loss: 3.1014115810394287
Epoch 3, Loss: 2.3236403465270996
Epoch 3, Loss: 5.886884689331055
Epoch 3, Loss: 3.0967020988464355
Epoch 3, Loss: 2.064586639404297
Epoch 3, Loss: 4.515196800231934
Epoch 4, Loss: 2.3893187046051025
Epoch 4, Loss: 7.7742919921875
Epoch 4, Loss: 2.0694055557250977
Epoch 4, Loss: 2.2880756855010986
Epoch 4, Loss: 0.3488388657569885
Epoch 5, Loss: 3.7408950328826904
Epoch 5, Loss: 9.733691215515137
Epoch 5, Loss: 1.4353702068328857
Epoch 5, Loss: 7.117983818054199
Epoch 5, Loss: 8.099514961242676
Epoch 6, Loss: 2.1255295276641846
Epoch 6, Loss: 1.9754788875579834
Epoch

In [10]:
def meta_test(model, dataset, device, num_tasks=5, num_samples_per_class=5, total_samples_per_class=25):
    """Evaluate ``model`` on freshly sampled two-class tasks and print the mean accuracy.

    The model is switched to eval mode (and left in it). Each task is drawn
    with ``sample_task`` and classified as-is, without any per-task adaptation
    or fine-tuning.

    NOTE(review): ``sample_task`` assigns the labels 0/1 to two randomly drawn
    classes for every task, and the model sees no labelled examples of the task
    before predicting, so accuracy near 50% looks like the expected outcome of
    this protocol -- confirm whether a support/query split was intended.

    Args:
        model: Module returning class logits.
        dataset: Dataset passed to ``sample_task``.
        device: Device the batches are moved to.
        num_tasks: Number of tasks to evaluate.
        num_samples_per_class: Samples per class in each task.
        total_samples_per_class: Per-class cap forwarded to ``sample_task``.

    Tasks that come back empty are skipped rather than raising
    ``ZeroDivisionError``; if no task could be evaluated the printed average
    is ``nan``.
    """
    model.eval()
    tasks_accuracy = []
    used_indices = {}

    with torch.no_grad():
        for _ in range(num_tasks):
            task_dataset = sample_task(dataset, num_samples_per_class, used_indices, total_samples_per_class)
            task_loader = DataLoader(task_dataset, batch_size=10, shuffle=False)

            correct = 0
            total = 0

            for data, targets in task_loader:
                data = data.to(device)
                targets = targets.to(device)
                outputs = model(data)
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()

            if total == 0:
                # Empty task: no accuracy can be computed for it.
                continue
            tasks_accuracy.append(100 * correct / total)

    if tasks_accuracy:
        average_accuracy = sum(tasks_accuracy) / len(tasks_accuracy)
    else:
        average_accuracy = float('nan')
    print(f'Meta-Test Average Accuracy: {average_accuracy:.2f}% over {num_tasks} tasks')

# Meta-test the model: evaluate it on tasks sampled from the CIFAR-10 test split
# (the defaults of meta_test are used: 5 tasks, 5 samples per class).
meta_test(model, test_dataset, device)

Meta-Test Average Accuracy: 50.00% over 5 tasks
