In [1]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from src.few_shot_learning import load_cinic10, calculate_accuracy, plot_confusion_matrix
from torch.optim.lr_scheduler import StepLR

In [2]:
# Define the basic architecture for the subnetwork (the same one used for both inputs)
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.fc = nn.Linear(128 * 32 * 32, 128)  # Final embedding dimension
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # Flatten the tensor
        embedding = self.fc(x)
        return embedding

In [3]:
# Contrastive loss function
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2, p=2)
        loss = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) + 
                          (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss

In [4]:
def train_siamese_network(model, dataloader, epochs=10, lr=0.0001):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    criterion = ContrastiveLoss(margin=1.0)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for data1, data2, labels in dataloader:
            data1, data2, labels = data1.to(device), data2.to(device), labels.to(device)
            
            optimizer.zero_grad()
            
            # Forward pass
            output1 = model(data1)
            output2 = model(data2)
            
            # Calculate contrastive loss
            loss = criterion(output1, output2, labels)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss:.4f}")

In [5]:
import random
from torchvision import datasets, transforms
from torch.utils.data import Dataset

class FewShotSiameseDataset(Dataset):
    def __init__(self, dataset, n_shot, n_query, n_classes, transform=None):
        """
        Initialize the FewShotSiameseDataset.

        dataset (torchvision dataset): Dataset to sample from (e.g., CIFAR-10)
        n_shot (int): Number of support samples per class (5)
        n_query (int): Number of query samples per class (5)
        n_classes (int): Number of classes in the few-shot episode
        transform (callable): Optional transform to be applied on an image.
        """
        self.dataset = dataset
        self.n_shot = n_shot
        self.n_query = n_query
        self.n_classes = n_classes
        self.transform = transform

        # Create class-to-images mapping
        self.class_to_images = {}
        for img, label in dataset:
            if label not in self.class_to_images:
                self.class_to_images[label] = []
            self.class_to_images[label].append(img)

        # Limit the dataset to only `n_classes`
        self.selected_classes = random.sample(list(self.class_to_images.keys()), n_classes)

        # Ensure each class has at least 10 images (adjusting for few-shot)
        for class_idx in self.selected_classes:
            assert len(self.class_to_images[class_idx]) >= 10, f"Class {class_idx} has less than 10 images."

    def __len__(self):
        # Number of few-shot episodes (here we generate 1000 episodes)
        return 1000

    def __getitem__(self, index):
        # Randomly sample `n_classes` and then `n_shot` support images and `n_query` query images
        selected_classes = random.sample(self.selected_classes, self.n_classes)

        support_images, support_labels = [], []
        query_images, query_labels = [], []

        # Get support and query images for each class
        for class_idx in selected_classes:
            images = self.class_to_images[class_idx]
            random.shuffle(images)

            # Select 10 images for each class
            support_images.extend(images[:self.n_shot])
            support_labels.extend([class_idx] * self.n_shot)
            query_images.extend(images[self.n_shot:self.n_shot + self.n_query])
            query_labels.extend([class_idx] * self.n_query)

        # Create pairs (support, query) and label them (1 for same class, 0 for different class)
        image_pairs, labels = [], []

        # Create positive pairs (same class)
        for i in range(len(support_images)):
            for j in range(len(query_images)):
                if support_labels[i] == query_labels[j]:
                    image_pairs.append((support_images[i], query_images[j]))
                    labels.append(1)  # Same class
        
        # Create negative pairs (different classes)
        for i in range(len(support_images)):
            for j in range(len(query_images)):
                if support_labels[i] != query_labels[j]:
                    image_pairs.append((support_images[i], query_images[j]))
                    labels.append(0)  # Different class

        # Apply transformations if defined
        if self.transform:
            image_pairs = [(self.transform(img1), self.transform(img2)) for img1, img2 in image_pairs]

        return image_pairs, torch.tensor(labels)


In [6]:
transform = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor()])
cifar10 = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# Define the few-shot dataset
few_shot_dataset = FewShotSiameseDataset(cifar10, n_shot=5, n_query=5, n_classes=5, transform=transform)

# DataLoader to feed the model
siamese_dataloader = torch.utils.data.DataLoader(few_shot_dataset, batch_size=32, shuffle=True)

100%|██████████| 170M/170M [03:55<00:00, 725kB/s] 


In [7]:
model = SiameseNetwork()
train_siamese_network(model, siamese_dataloader, epochs=10)

TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>