In [33]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from src.few_shot_learning import load_cinic10, calculate_accuracy, plot_confusion_matrix
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
from PIL import Image
import random
from torchvision import datasets, transforms
from torch.utils.data import Dataset
import os

In [22]:
# Define the basic architecture for the subnetwork (the same one used for both inputs)
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2)
        self.pool = nn.MaxPool2d(2, 2)
        
        # Dummy forward pass to determine fc input size
        dummy_input = torch.randn(1, 3, 64, 64)  # Assuming input size 64x64
        dummy_output = self.pool(F.relu(self.conv2(self.pool(F.relu(self.conv1(dummy_input))))))
        flattened_size = dummy_output.view(1, -1).shape[1]  # Get correct size
    
        self.fc = nn.Linear(flattened_size, 128)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)  # Flatten
        return self.fc(x)


In [3]:
# Contrastive loss function
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2, p=2)
        loss = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) + 
                          (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss

In [35]:
def train_siamese_network(model, dataloader, epochs=10, lr=0.0001):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    criterion = ContrastiveLoss(margin=1.0)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for data1, data2, labels in dataloader:
            data1, data2, labels = data1.to(device), data2.to(device), labels.to(device)
            
            optimizer.zero_grad()
            
            # Forward pass
            output1 = model(data1)
            output2 = model(data2)
            
            # Calculate contrastive loss
            loss = criterion(output1, output2, labels)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss:.4f}")

In [28]:
class CINIC10SiameseDataset(Dataset):
    def __init__(self, root_dir, num_samples_per_class=100, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.num_samples_per_class = num_samples_per_class
        
        # Get all class names
        self.classes = os.listdir(root_dir)
        
        # Limit the number of images per class to `num_samples_per_class`
        self.image_paths = {}
        for c in self.classes:
            all_images = os.listdir(os.path.join(root_dir, c))
            sampled_images = random.sample(all_images, min(len(all_images), num_samples_per_class))
            self.image_paths[c] = [os.path.join(root_dir, c, img) for img in sampled_images]

    def __len__(self):
        return sum(len(imgs) for imgs in self.image_paths.values())

    def __getitem__(self, index):
        # Randomly select a class
        class_name = random.choice(self.classes)
        img1_path = random.choice(self.image_paths[class_name])

        # Decide if this will be a positive or negative pair
        if random.random() > 0.5:  # 50% chance of being a positive pair
            img2_path = random.choice(self.image_paths[class_name])
            label = 1
        else:  # Negative pair
            different_class = random.choice([c for c in self.classes if c != class_name])
            img2_path = random.choice(self.image_paths[different_class])
            label = 0

        # Load images
        img1 = Image.open(img1_path).convert("RGB")
        img2 = Image.open(img2_path).convert("RGB")

        # Apply transformations
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        return img1, img2, torch.tensor(label, dtype=torch.float32)


In [31]:
# Define transformations (remove ToTensor from dataset loading)
transform = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor()])

dataset_path = '../../data/cinic-10/train'

# Define the few-shot dataset
siamese_dataset = CINIC10SiameseDataset(root_dir=dataset_path, transform=transform)

# DataLoader to feed the model
siamese_dataloader = torch.utils.data.DataLoader(siamese_dataset, batch_size=32, shuffle=True)

In [42]:
import torch

def evaluate_siamese_network(model, dataloader, threshold=0.5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # Set model to evaluation mode
    correct = 0
    total = 0

    with torch.no_grad():  # No gradients needed for evaluation
        for img1, img2, labels in dataloader:
            img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)

            # Forward pass
            output1 = model(img1)
            output2 = model(img2)

            # Compute Euclidean distance (or L1 distance)
            distances = torch.norm(output1 - output2, p=2, dim=1)

            # Convert distances to binary predictions
            predictions = (distances < threshold).float()

            # Count correct predictions
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = (correct / total) * 100
    print(f'Accuracy: {accuracy:.2f}%')
    return accuracy


In [39]:
model = SiameseNetwork()
train_siamese_network(model, siamese_dataloader, epochs=30)

Epoch [1/30], Loss: 9.3954
Epoch [2/30], Loss: 9.1318
Epoch [3/30], Loss: 8.9913
Epoch [4/30], Loss: 8.7910
Epoch [5/30], Loss: 8.9604
Epoch [6/30], Loss: 8.6815
Epoch [7/30], Loss: 8.8289
Epoch [8/30], Loss: 8.4779
Epoch [9/30], Loss: 8.8322
Epoch [10/30], Loss: 8.6088
Epoch [11/30], Loss: 8.5445
Epoch [12/30], Loss: 8.5265
Epoch [13/30], Loss: 8.7521
Epoch [14/30], Loss: 8.7096
Epoch [15/30], Loss: 8.5977
Epoch [16/30], Loss: 8.4794
Epoch [17/30], Loss: 8.4841
Epoch [18/30], Loss: 8.6347
Epoch [19/30], Loss: 8.3687
Epoch [20/30], Loss: 8.5102
Epoch [21/30], Loss: 8.6155
Epoch [22/30], Loss: 8.5831
Epoch [23/30], Loss: 8.5597
Epoch [24/30], Loss: 8.6220
Epoch [25/30], Loss: 8.6080
Epoch [26/30], Loss: 8.4588
Epoch [27/30], Loss: 8.6239
Epoch [28/30], Loss: 8.5991
Epoch [29/30], Loss: 8.7450
Epoch [30/30], Loss: 8.4158


In [43]:
train_accuracy = evaluate_siamese_network(model, siamese_dataloader)

Accuracy: 53.00%
