In [1]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from src.few_shot_learning import load_cinic10, calculate_accuracy, plot_confusion_matrix

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

class SimpleCNN(nn.Module):
    def __init__(self, input_channels=3, output_classes=10):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)  # Reduces spatial size by half

        # Forward pass once to compute flattened size dynamically
        with torch.no_grad():
            dummy_input = torch.zeros(1, input_channels, 32, 32)  # Assuming CINIC-10 (32x32 images)
            dummy_output = self._get_conv_output(dummy_input)
            flattened_size = dummy_output.shape[1]

        self.fc1 = nn.Linear(flattened_size, 128)  # Adjust dynamically
        self.fc2 = nn.Linear(128, output_classes)

    def _get_conv_output(self, x):
        """Passes dummy input through conv layers to determine output size."""
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        return x.view(x.size(0), -1)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)  # Flatten
        x = self.relu(self.fc1(x))
        return self.fc2(x)


In [15]:
class MAML:
    def __init__(self, model, lr=0.001, meta_lr=0.001, n_shot=5, n_query=5, n_classes=5):
        self.model = model
        self.lr = lr  # Learning rate for task-specific updates
        self.meta_lr = meta_lr  # Learning rate for meta-learning updates
        self.n_shot = n_shot
        self.n_query = n_query
        self.n_classes = n_classes
        self.meta_optimizer = optim.Adam(self.model.parameters(), lr=self.meta_lr)

    def inner_loop(self, model, images, labels, n_shot, n_query):
        """
        Performs a task-specific update on a few-shot task
        """
        # Task-specific update (fine-tuning)
        optimizer = optim.SGD(model.parameters(), lr=self.lr)
        optimizer.zero_grad()

        outputs = model(images)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        loss.backward()
        optimizer.step()
        return model

    def meta_update(self, task_models, task_images, task_labels):
        """
        Performs the meta-update, accumulating gradients from multiple tasks.
        """
        self.meta_optimizer.zero_grad()  # Reset gradients before accumulating new ones
    
        total_loss = 0
        for task_model, images, labels in zip(task_models, task_images, task_labels):
            task_model = self.inner_loop(task_model, images, labels, self.n_shot, self.n_query)
    
            # Calculate loss
            outputs = task_model(images)
            loss = nn.CrossEntropyLoss()(outputs, labels)
            total_loss = total_loss + loss
    
        # Backpropagate the accumulated gradients
        total_loss.backward()  # No need for retain_graph=True
        self.meta_optimizer.step()


    def train(self, train_loader, epochs=10):
        for epoch in range(epochs):
            task_models = []
            task_images = []
            task_labels = []

            for batch_idx, (images, labels) in enumerate(train_loader):
                images, labels = images.cuda(), labels.cuda()

                # Sample tasks and generate task models
                task_models.append(self.model)
                task_images.append(images)
                task_labels.append(labels)

                if len(task_models) == self.n_classes:
                    self.meta_update(task_models, task_images, task_labels)

            print(f"Epoch [{epoch+1}/{epochs}] completed.")

In [19]:
model = SimpleCNN().cuda()
maml = MAML(model, lr=0.01, meta_lr=0.001, n_shot=5, n_query=5, n_classes=5)

In [20]:
data_dir = "../../data"
dataloader = load_cinic10(data_dir, few_shot_per_class=100, batch_size=128)

In [21]:
torch.autograd.set_detect_anomaly(True)
maml.train(dataloader, epochs=10)

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn