In [7]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from src.few_shot_learning import load_cinic10, calculate_accuracy, plot_confusion_matrix

In [42]:
# Define a simple CNN model
class CNN(nn.Module):
    def __init__(self, num_classes=10):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        
    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [43]:
# MAML Algorithm
class MAML(nn.Module):
    def __init__(self, model, num_inner_steps=5, lr_inner=0.01, lr_outer=0.001, num_classes=10):
        super(MAML, self).__init__()
        self.model = model
        self.num_inner_steps = num_inner_steps
        self.lr_inner = lr_inner
        self.lr_outer = lr_outer
        self.num_classes = num_classes
        self.loss_fn = nn.CrossEntropyLoss()
        # self.outer_optimizer = optim.Adam(self.model.parameters(), lr=self.lr_outer)
        self.outer_optimizer = optim.Adam(self.model.parameters(), lr=self.lr_outer, weight_decay=1e-5)

    def adapt(self, support_x, support_y):
        # Instantiate a new ResNet18 model
        model_copy = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        model_copy.fc = nn.Linear(model_copy.fc.in_features, self.num_classes)  # Adjust the final layer for your dataset
        
        # Copy weights from the original model (ResNet18) to the new model
        model_copy.load_state_dict(self.model.state_dict(), strict=False)

        # Set the model to training mode
        model_copy.train()

        # Define optimizer for inner loop
        # optimizer = optim.SGD(model_copy.parameters(), lr=self.lr_inner)
        optimizer = optim.Adam(model_copy.parameters(), lr=self.lr_inner)

        # Perform inner updates (adaptation)
        for _ in range(self.num_inner_steps):
            optimizer.zero_grad()
            predictions = model_copy(support_x)
            loss = self.loss_fn(predictions, support_y)
            loss.backward()
            optimizer.step()

        return model_copy

    def meta_train(self, dataloader, epochs=10):
        for epoch in range(epochs):
            total_meta_loss = 0
            total_correct = 0
            total_samples = 0

            for support_x, support_y in dataloader:
                query_x, query_y = support_x.clone(), support_y.clone()
                adapted_model = self.adapt(support_x, support_y)

                # Perform the meta-update
                self.outer_optimizer.zero_grad()
                query_predictions = adapted_model(query_x)
                meta_loss = self.loss_fn(query_predictions, query_y)
                total_meta_loss += meta_loss.item()

                # Calculate accuracy
                _, predicted = torch.max(query_predictions, 1)
                correct = (predicted == query_y).sum().item()
                total_correct += correct
                total_samples += query_y.size(0)

                meta_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=5.0)
                self.outer_optimizer.step()

            # Print loss and accuracy
            accuracy = total_correct / total_samples * 100
            print(f"Epoch {epoch+1}/{epochs}, Meta-Loss: {total_meta_loss/len(dataloader):.4f}, Accuracy: {accuracy:.2f}%")


In [44]:
model = CNN(num_classes=10)
maml = MAML(model)

In [45]:
data_dir = "../../data"
dataloader = load_cinic10(data_dir, few_shot_per_class=100, batch_size=128)

In [46]:
maml.meta_train(dataloader, epochs=30)

RuntimeError: Error(s) in loading state_dict for ResNet:
	size mismatch for conv1.weight: copying a param with shape torch.Size([32, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 3, 7, 7]).

In [14]:
train_accuracy = calculate_accuracy(model, data_dir, split='train')

Accuracy on train set: 9.95%


In [47]:
from torchvision import models

class ResNet18MAML(nn.Module):
    def __init__(self, num_classes=10):
        super(ResNet18MAML, self).__init__()
        self.resnet = models.resnet18(weights='IMAGENET1K_V1') # Use 'weights' instead of 'pretrained'
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)  # Change output layer

    def forward(self, x):
        return self.resnet(x)

In [48]:
model = ResNet18MAML(num_classes=10)
maml = MAML(model)

In [49]:
maml.meta_train(dataloader, epochs=30)

Epoch 1/30, Meta-Loss: 2.2694, Accuracy: 18.20%
Epoch 2/30, Meta-Loss: 2.2356, Accuracy: 16.60%
Epoch 3/30, Meta-Loss: 2.2384, Accuracy: 20.30%
Epoch 4/30, Meta-Loss: 2.3180, Accuracy: 22.20%
Epoch 5/30, Meta-Loss: 2.3077, Accuracy: 17.40%
Epoch 6/30, Meta-Loss: 2.3088, Accuracy: 17.00%
Epoch 7/30, Meta-Loss: 2.3772, Accuracy: 20.80%
Epoch 8/30, Meta-Loss: 2.1979, Accuracy: 23.30%
Epoch 9/30, Meta-Loss: 2.2891, Accuracy: 17.30%
Epoch 10/30, Meta-Loss: 2.2231, Accuracy: 21.50%
Epoch 11/30, Meta-Loss: 2.3260, Accuracy: 20.70%
Epoch 12/30, Meta-Loss: 2.3390, Accuracy: 19.70%
Epoch 13/30, Meta-Loss: 2.3027, Accuracy: 19.80%
Epoch 14/30, Meta-Loss: 2.3206, Accuracy: 18.20%
Epoch 15/30, Meta-Loss: 2.4277, Accuracy: 18.90%
Epoch 16/30, Meta-Loss: 2.1952, Accuracy: 23.20%
Epoch 17/30, Meta-Loss: 2.0985, Accuracy: 23.40%
Epoch 18/30, Meta-Loss: 2.3289, Accuracy: 17.70%
Epoch 19/30, Meta-Loss: 2.2675, Accuracy: 21.10%
Epoch 20/30, Meta-Loss: 2.2076, Accuracy: 23.50%
Epoch 21/30, Meta-Loss: 2.240