In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms
from torch.utils.data import DataLoader, Dataset
import random
import numpy as np
from sklearn.decomposition import PCA

# Compute device: prefer the GPU when CUDA is usable, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [13]:
class SimpleCNN(nn.Module):
    """Small three-stage CNN whose projection layer can be re-fitted with PCA.

    Data flow: ``features`` (three conv/ReLU/max-pool stages) -> flatten ->
    ``projection`` (linear) -> ``classifier`` (dropout + linear).

    The projection's input width is fixed at ``128 * 4 * 4``, which is what the
    conv stack produces for 32x32 inputs (three stride-2 poolings: 32 -> 16 -> 8 -> 4).

    All layers are created on PyTorch's default device; move the model with
    ``.to(...)`` after construction.

    Args:
        num_classes: Number of output logits.
        dropout_rate: Dropout probability applied before the final linear layer.
        initial_feature_dim: Output width of the initial projection, and the
            upper bound on the number of PCA components kept by
            ``adapt_projection``.
    """

    def __init__(self, num_classes=2, dropout_rate=0.5, initial_feature_dim=50):
        super().__init__()
        # Kept on the instance because define_layers() and adapt_projection() read them
        self.num_classes = num_classes
        self.dropout_rate = dropout_rate
        self.initial_feature_dim = initial_feature_dim
        self.define_layers()

    def define_layers(self):
        """Create the conv feature extractor, the projection and the classifier head."""
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # No device placement here: every sub-module stays on the same device and the
        # caller moves the whole model at once with `.to(...)`.
        self.projection = nn.Linear(128 * 4 * 4, self.initial_feature_dim)
        self.classifier = nn.Sequential(
            nn.Dropout(self.dropout_rate),
            nn.Linear(self.initial_feature_dim, self.num_classes)
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for a batch of images."""
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.projection(x)
        x = self.classifier(x)
        return x

    def adapt_projection(self, features):
        """Replace ``projection`` by a PCA projection fitted on ``features``.

        The new linear layer computes ``(x - mean) @ components.T``: its weight is
        the PCA component matrix and its bias is ``-(components @ mean)``. Because
        the projection width may change, the classifier's linear layer is replaced
        by a freshly initialised one with ``num_classes`` outputs.

        Both new layers are placed on the device the model currently lives on.
        Any optimizer built earlier still references the replaced parameters and
        has to be re-created by the caller.

        Args:
            features: 2-D tensor ``(num_samples, num_features)`` of flattened
                outputs of ``self.features``. The number of components kept is
                ``min(num_samples, num_features, initial_feature_dim)``; nothing
                happens when that is zero.
        """
        num_samples, num_features = features.size()
        n_components = min(num_samples, num_features, self.initial_feature_dim)
        if n_components <= 0:
            return

        model_device = next(self.parameters()).device

        pca = PCA(n_components=n_components)
        pca.fit(features.detach().cpu().numpy())
        components = torch.as_tensor(pca.components_, dtype=torch.float32)  # (k, num_features)
        mean = torch.as_tensor(pca.mean_, dtype=torch.float32)              # (num_features,)

        new_projection = nn.Linear(num_features, n_components)
        with torch.no_grad():
            new_projection.weight.copy_(components)
            # Matrix-vector product keeps shape (k,) even for k == 1, so no squeeze is needed
            new_projection.bias.copy_(-(components @ mean))

        self.projection = new_projection.to(model_device)
        self.classifier[1] = nn.Linear(n_components, self.num_classes).to(model_device)


In [14]:
# Data loading and preprocessing
def load_cifar10_data():
    """Build the CIFAR-10 training and test datasets.

    Both splits are stored under ``./data`` (downloaded when missing) and share
    one transform: conversion to a tensor followed by per-channel normalisation
    with mean 0.5 and standard deviation 0.5.

    Returns:
        A ``(train_dataset, test_dataset)`` pair.
    """
    channel_mean = (0.5, 0.5, 0.5)
    channel_std = (0.5, 0.5, 0.5)
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    # Training split first, then the test split
    splits = [
        CIFAR10(root='./data', train=is_train, download=True, transform=preprocessing)
        for is_train in (True, False)
    ]
    return splits[0], splits[1]


# In-memory dataset wrapper for one sampled task
class CustomCIFAR10(Dataset):
    """Dataset over two parallel sequences: samples and their targets.

    Args:
        data: Indexable collection of samples.
        targets: Indexable collection of labels, aligned with ``data``.
    """

    def __init__(self, data, targets):
        self.data, self.targets = data, targets

    def __len__(self):
        """Number of samples; taken from ``data``."""
        sample_count = len(self.data)
        return sample_count

    def __getitem__(self, idx):
        """Return the ``(sample, target)`` pair stored at position ``idx``."""
        sample = self.data[idx]
        target = self.targets[idx]
        return sample, target

def _labels_without_loading(dataset):
    """Return the per-index labels from ``dataset.targets``, or ``None``.

    ``None`` means the labels cannot be trusted to match what ``dataset[idx]``
    yields (no ``targets`` attribute, a ``target_transform`` is set, or the
    lengths disagree); the caller then has to iterate over the dataset itself.
    """
    targets = getattr(dataset, "targets", None)
    if targets is None or getattr(dataset, "target_transform", None) is not None:
        return None
    try:
        if len(targets) != len(dataset):
            return None
    except TypeError:
        return None
    # tolist() turns tensors/arrays into plain Python numbers usable as dict keys
    return targets.tolist() if hasattr(targets, "tolist") else list(targets)


def sample_task(dataset, num_samples_per_class=5, used_indices=None, total_samples_per_class=25):
    """Build a two-class task from ``dataset``.

    Two distinct labels are drawn at random from ``range(10)`` and relabelled
    0 and 1 in draw order. The dataset is then scanned in index order and, for
    each chosen class, the first indices not yet recorded in ``used_indices``
    are taken until the class has ``num_samples_per_class`` samples or its
    entry in ``used_indices`` reaches ``total_samples_per_class``.

    When the dataset exposes a usable ``targets`` sequence, labels are read from
    it and only the selected items are actually loaded; otherwise every item is
    loaded while scanning. The scan stops as soon as no chosen class can accept
    another sample.

    Args:
        dataset: Indexable dataset yielding ``(image, label)`` pairs.
        num_samples_per_class: Samples to take per chosen class in this task.
        used_indices: Dict mapping original label -> list of dataset indices
            already handed out. Updated in place; a fresh dict is used when
            ``None``.
        total_samples_per_class: Cap on the number of indices recorded per
            original label in ``used_indices``.

    Returns:
        A ``CustomCIFAR10`` holding the selected images and their 0/1 labels.
        It can hold fewer samples than requested when a cap is hit or the
        dataset runs out.
    """
    if used_indices is None:
        used_indices = {}

    chosen_classes = random.sample(range(10), 2)  # two distinct original labels
    class_map = {original: new for new, original in enumerate(chosen_classes)}
    taken = {original: 0 for original in chosen_classes}
    # Set copies of the bookkeeping lists for O(1) "already used?" checks
    already_used = {original: set(used_indices.get(original, ())) for original in chosen_classes}

    def wants_more(original):
        """True while this class is below both the per-task quota and the global cap."""
        return (taken[original] < num_samples_per_class
                and len(used_indices.get(original, ())) < total_samples_per_class)

    data = []
    labels = []

    known_labels = _labels_without_loading(dataset)
    if known_labels is None:
        candidates = ((idx, label, image) for idx, (image, label) in enumerate(dataset))
    else:
        # Image is loaded lazily below, only for indices that are actually selected
        candidates = ((idx, label, None) for idx, label in enumerate(known_labels))

    if any(wants_more(original) for original in chosen_classes):
        for idx, label, image in candidates:
            if label not in chosen_classes or not wants_more(label) or idx in already_used[label]:
                continue
            if known_labels is not None:
                image = dataset[idx][0]
            data.append(image)
            labels.append(class_map[label])
            used_indices.setdefault(label, []).append(idx)
            already_used[label].add(idx)
            taken[label] += 1
            if not any(wants_more(original) for original in chosen_classes):
                break

    return CustomCIFAR10(data, labels)


In [15]:
def meta_train(model, dataset, device, epochs=50, tasks_per_epoch=5, num_samples_per_class=5,
               lr=0.001, batch_size=10):
    """Train ``model`` on a stream of randomly sampled two-class tasks.

    For every task the function
      1. samples a task dataset with ``sample_task``,
      2. runs ``model.features`` over it and hands the flattened activations to
         ``model.adapt_projection``,
      3. builds a new Adam optimizer (the adaptation replaces layers, so the
         parameter list changes), and
      4. performs one pass of gradient updates over the task's mini-batches.

    The loss of each task's last mini-batch is printed after the task.

    Args:
        model: Network exposing ``features`` and ``adapt_projection``.
        dataset: Dataset passed to ``sample_task``.
        device: Device the input batches are moved to.
        epochs: Number of epochs.
        tasks_per_epoch: Tasks sampled per epoch.
        num_samples_per_class: Samples per class in each task.
        lr: Adam learning rate.
        batch_size: Mini-batch size of the per-task loader.
    """
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        model.train()
        used_indices = {}  # original label -> dataset indices handed out this epoch

        for _ in range(tasks_per_epoch):
            task_dataset = sample_task(dataset, num_samples_per_class, used_indices)
            if len(task_dataset) == 0:
                # Nothing left to sample for the drawn classes; there is nothing to fit
                continue
            task_loader = DataLoader(task_dataset, batch_size=batch_size, shuffle=True)

            # The activations only feed the PCA fit, so no autograd graph is needed
            with torch.no_grad():
                features = torch.cat([
                    torch.flatten(model.features(data.to(device)), 1)
                    for data, _ in task_loader
                ])

            # Adapt projection dynamically
            model.adapt_projection(features)
            # Built after the adaptation so it tracks the layers that now exist
            optimizer = optim.Adam(model.parameters(), lr=lr)

            # Training on task-specific adapted model
            for data, targets in task_loader:
                data = data.to(device)
                targets = targets.to(device)
                optimizer.zero_grad()
                outputs = model(data)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

            print(f'Epoch {epoch + 1}, Loss: {loss.item()}')

# Seed the RNGs this run draws from (task sampling uses `random`; weight
# initialisation and batch shuffling use torch; NumPy is seeded for any library
# that falls back to its global state) so a rerun starts from the same state.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Initialize and train
model = SimpleCNN().to(device)
train_dataset, test_dataset = load_cifar10_data()
meta_train(model, train_dataset, device)

Files already downloaded and verified
Files already downloaded and verified
Epoch 1, Loss: 0.8445464372634888
Epoch 1, Loss: 0.703630805015564
Epoch 1, Loss: 0.769275426864624
Epoch 1, Loss: 0.7249687314033508
Epoch 1, Loss: 0.7564619779586792
Epoch 2, Loss: 0.7281139492988586
Epoch 2, Loss: 0.7719835638999939
Epoch 2, Loss: 0.6546828746795654
Epoch 2, Loss: 0.686348021030426
Epoch 2, Loss: 0.7898918986320496
Epoch 3, Loss: 0.688485324382782
Epoch 3, Loss: 1.0316861867904663
Epoch 3, Loss: 0.5061514973640442
Epoch 3, Loss: 0.7642078995704651
Epoch 3, Loss: 0.6776653528213501
Epoch 4, Loss: 0.6736304759979248
Epoch 4, Loss: 0.6223071217536926
Epoch 4, Loss: 0.6925918459892273
Epoch 4, Loss: 0.7855917811393738
Epoch 4, Loss: 0.6933327913284302
Epoch 5, Loss: 0.6159733533859253
Epoch 5, Loss: 0.8442236185073853
Epoch 5, Loss: 1.0197415351867676
Epoch 5, Loss: 0.8348296880722046
Epoch 5, Loss: 0.7667015790939331
Epoch 6, Loss: 0.7032516598701477
Epoch 6, Loss: 0.7188277244567871
Epoch 6, L

In [16]:
def meta_test(model, dataset, device, num_tasks=5, num_samples_per_class=5, total_samples_per_class=25,
              batch_size=10):
    """Evaluate ``model`` on freshly sampled two-class tasks.

    The model is switched to eval mode and scored, without gradient tracking,
    on ``num_tasks`` tasks drawn with ``sample_task``. Tasks that come back
    empty are skipped. The mean of the per-task accuracies is printed and
    returned.

    NOTE(review): the model is evaluated exactly as passed in; no per-task
    adaptation or fine-tuning happens here, so the classifier head is not
    fitted to the two classes of each sampled task.

    Args:
        model: Network to evaluate; left in eval mode afterwards.
        dataset: Dataset passed to ``sample_task``.
        device: Device the input batches are moved to.
        num_tasks: Number of tasks to sample.
        num_samples_per_class: Samples per class in each task.
        total_samples_per_class: Per-class cap forwarded to ``sample_task``.
        batch_size: Mini-batch size of the per-task loader.

    Returns:
        Average accuracy in percent over the evaluated tasks, or ``nan`` when
        no task contained any sample.
    """
    model.eval()
    tasks_accuracy = []
    used_indices = {}

    with torch.no_grad():
        for _ in range(num_tasks):
            task_dataset = sample_task(dataset, num_samples_per_class, used_indices, total_samples_per_class)
            task_loader = DataLoader(task_dataset, batch_size=batch_size, shuffle=False)

            correct = 0
            total = 0
            for data, targets in task_loader:
                data = data.to(device)
                targets = targets.to(device)
                predicted = model(data).argmax(dim=1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()

            if total == 0:
                # Empty task: it has no defined accuracy, so it is left out of the mean
                continue
            tasks_accuracy.append(100 * correct / total)

    if tasks_accuracy:
        average_accuracy = sum(tasks_accuracy) / len(tasks_accuracy)
    else:
        average_accuracy = float('nan')
    print(f'Meta-Test Average Accuracy: {average_accuracy:.2f}% over {num_tasks} tasks')
    return average_accuracy

# Evaluate the trained model on tasks sampled from the held-out test split
meta_test(model=model, dataset=test_dataset, device=device)

Meta-Test Average Accuracy: 48.00% over 5 tasks
