In [1]:
import os
import random
import pandas as pd
import numpy as np
import cv2
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

import torchvision.models as models
from torchvision import datasets, transforms

from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm  # For progress bar

In [2]:
# Preprocessing for an ImageNet-pretrained ResNet-50 backbone: replicate the
# single MNIST channel into 3 channels, upsample to 224x224 (the ResNet-50 /
# ImageNet input size; Inception v3 would need 299) and normalise with the
# standard ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # 1 -> 3 identical channels
    transforms.Resize((224, 224)),  # ResNet-50 input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet stats
])


# Load the MNIST train/test splits (downloaded into ./data on first run)
mnist_train = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

# Map each label to the list of dataset indices carrying it (indices only, not images)
def organize_by_label(dataset):
    """Group dataset indices by label (indices only; images are never kept).

    Fast path: when ``dataset`` exposes a ``targets`` sequence/tensor of matching
    length and has no ``target_transform`` (torchvision's MNIST satisfies this),
    labels are read from ``targets`` directly. This avoids running the whole image
    transform pipeline on every sample just to learn its label. Otherwise each
    item is fetched with ``dataset[idx]`` and its second element is the label.

    Args:
        dataset: indexable dataset whose items are ``(sample, label)`` pairs.

    Returns:
        dict mapping label -> list of indices, each list in ascending index order.
    """
    targets = getattr(dataset, "targets", None)
    if (
        targets is not None
        and getattr(dataset, "target_transform", None) is None
        and len(targets) == len(dataset)
    ):
        # .tolist() yields plain Python scalars; 0-d tensors would hash by identity
        # and split one label into many dict keys.
        labels = targets.tolist() if hasattr(targets, "tolist") else list(targets)
    else:
        labels = (dataset[idx][1] for idx in range(len(dataset)))

    label_dict = {}
    for idx, label in enumerate(labels):
        label_dict.setdefault(label, []).append(idx)
    return label_dict


# Index-by-label lookup tables for the two splits (train first, then test)
train_label_dict, test_label_dict = map(organize_by_label, (mnist_train, mnist_test))

In [3]:
# Function to create positive and negative pairs
def create_pairs(dataset, label_dict, num_pairs=1000, rng=None):
    """Sample index pairs for contrastive training.

    Each of the ``num_pairs`` iterations emits:
      * a positive pair -- two indices drawn without replacement from one label,
        target 0 (skipped when the drawn label has fewer than two indices), and
      * a negative pair -- one index from each of two different labels, target 1.
    The result therefore holds between ``num_pairs`` and ``2 * num_pairs`` tuples.

    Args:
        dataset: unused; kept so existing call sites keep working.
        label_dict: mapping label -> list of dataset indices (as built by
            ``organize_by_label``).
        num_pairs: number of positive/negative iterations.
        rng: optional ``random.Random`` for reproducible sampling; defaults to the
            module-level ``random`` generator (identical draw sequence as before).

    Returns:
        list of ``(idx1, idx2, target)`` tuples; 0 = similar, 1 = dissimilar.

    Raises:
        ValueError: if ``num_pairs > 0`` but ``label_dict`` has fewer than two
            labels, so no negative pair can be formed.
    """
    rng = random if rng is None else rng
    labels = list(label_dict.keys())
    if num_pairs > 0 and len(labels) < 2:
        raise ValueError(
            f"create_pairs needs at least 2 labels to form negative pairs, got {len(labels)}"
        )

    pairs = []
    for _ in range(num_pairs):
        # Positive pair (same digit)
        label = rng.choice(labels)
        if len(label_dict[label]) >= 2:
            idx1, idx2 = rng.sample(label_dict[label], 2)
            pairs.append((idx1, idx2, 0))  # Label 0 for similar images

        # Negative pair (different digits); tuple elements evaluate left to right,
        # so the draw order matches the original implementation.
        label1, label2 = rng.sample(labels, 2)
        pairs.append((rng.choice(label_dict[label1]), rng.choice(label_dict[label2]), 1))

    return pairs

# Generate train/test index pairs. Seed the generators first so the sampled pairs
# (and hence every reported metric) are reproducible across kernel restarts;
# torch is seeded too because later cells use it for model-head initialisation
# and DataLoader shuffling.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
train_pairs = create_pairs(mnist_train, train_label_dict, num_pairs=5000)
test_pairs = create_pairs(mnist_test, test_label_dict, num_pairs=2500)

In [4]:
# Pair dataset: loads both images of a pair on demand from the base dataset
class ContrastiveMNISTDataset(Dataset):
    """Lazily materialise stored index pairs against a base dataset.

    Args:
        dataset: indexable base dataset whose items are ``(image, label)``.
        pairs: sequence of ``(idx1, idx2, target)`` tuples.

    Each item is ``((image1, image2), target)`` where ``target`` is a float32
    scalar tensor; images are fetched from ``dataset`` only when requested.
    """

    def __init__(self, dataset, pairs):
        self.dataset = dataset
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        first, second, target = self.pairs[idx]
        # Fetch first, then second -- same access order as a two-statement load.
        images = tuple(self.dataset[i][0] for i in (first, second))
        return images, torch.tensor(target, dtype=torch.float32)
    
# Wrap each split's index pairs in a lazily-loading pair dataset (train, then test)
train_dataset, test_dataset = (
    ContrastiveMNISTDataset(base, split_pairs)
    for base, split_pairs in ((mnist_train, train_pairs), (mnist_test, test_pairs))
)

In [5]:
# Create DataLoaders. Both loaders share one configuration: up to 16 worker
# processes, capped at the host's CPU count so smaller machines are not
# oversubscribed (always >= 1, which prefetch_factor/persistent_workers require).
# Pinned memory only helps host->GPU copies, so enable it only when CUDA exists.
NUM_WORKERS = min(16, os.cpu_count() or 1)
loader_kwargs = dict(
    batch_size=128,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
    prefetch_factor=2,
    persistent_workers=True,
)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print(f"Training dataset size: {len(train_dataset)} pairs")
print(f"Testing dataset size: {len(test_dataset)} pairs")

Training dataset size: 10000 pairs
Testing dataset size: 5000 pairs


In [6]:
# Use ResNet50 pretrained on ImageNet as the feature extractor
class SiameseNetwork(nn.Module):
    """Twin-branch embedding network sharing one ResNet-50 backbone.

    Both inputs go through the same weights (``forward_once``). ``forward``
    returns the pair of embeddings; distances are computed by the loss and
    evaluation code, not here.

    Args:
        embedding_dim: size of the output embedding (default 128).
        pretrained: load the ImageNet IMAGENET1K_V1 backbone weights; pass False
            to build a randomly initialised backbone without downloading.
    """

    def __init__(self, embedding_dim=128, pretrained=True):
        super().__init__()
        # `weights=` replaces the deprecated `pretrained=True` flag (torchvision >= 0.13);
        # IMAGENET1K_V1 is the checkpoint the legacy flag resolved to.
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        resnet50 = models.resnet50(weights=weights)
        in_features = resnet50.fc.in_features  # width of the pooled backbone features
        # Keep every child except the final FC classifier.
        self.feature_extractor = nn.Sequential(*list(resnet50.children())[:-1])

        # Projection head: pooled backbone features -> embedding space.
        self.fc = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Linear(512, embedding_dim),
        )

    def forward_once(self, x):
        """Embed one batch of images: backbone -> flatten -> projection head."""
        x = self.feature_extractor(x)
        x = torch.flatten(x, 1)  # keep batch dim, flatten the rest
        return self.fc(x)

    def forward(self, input1, input2):
        """Embed both inputs with the shared weights and return ``(output1, output2)``."""
        return self.forward_once(input1), self.forward_once(input2)
    

# Define the contrastive loss function
class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss over two batches of embeddings.

    For each pair with Euclidean distance ``d`` and target ``label``:
      * label 0 (similar) contributes ``d ** 2``;
      * label 1 (dissimilar) contributes ``max(margin - d, 0) ** 2``.
    The batch mean of these per-pair terms is returned.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        dist = nn.functional.pairwise_distance(output1, output2)
        similar_term = (1 - label) * dist.pow(2)
        hinge = (self.margin - dist).clamp(min=0.0)
        dissimilar_term = label * hinge.pow(2)
        return (similar_term + dissimilar_term).mean()
    

In [8]:
# Create the model, loss function, optimizer and LR scheduler.
# Fall back to CPU when no GPU is present instead of failing inside `.cuda()`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SiameseNetwork().to(device)
criterion = ContrastiveLoss(margin=1.0)
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# Halve the LR after 3 epochs without validation-loss improvement. `verbose` is
# omitted because it is deprecated in recent PyTorch releases; read
# optimizer.param_groups[0]["lr"] to inspect the current learning rate.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

# Training-loop configuration. NOTE: there is no separate validation split;
# validation runs on the MNIST test pairs.
num_epochs = 50
best_val_loss = float('inf')
early_stop_counter = 0
patience = 10  # epochs without validation-loss improvement before early stopping



In [None]:
CHECKPOINT_PATH = "best_siamese_model.pth"
device = next(model.parameters()).device  # run batches wherever the model lives
# A pair is predicted "dissimilar" (1) when its embedding distance exceeds half
# the contrastive margin (0.5 for margin=1.0).
dist_threshold = criterion.margin / 2


def run_epoch(loader, train, desc=None):
    """Run one pass over `loader` with the global model/criterion/optimizer.

    Args:
        loader: yields ``((img1, img2), labels)`` batches.
        train: if True, use train mode and take one optimizer step per batch;
            otherwise use eval mode with gradients disabled.
        desc: optional tqdm progress-bar label (no bar when None).

    Returns:
        ``(mean batch loss, pair accuracy)``; both 0.0 for an empty loader.
    """
    model.train(train)
    total_loss, correct, seen, batches = 0.0, 0, 0, 0
    batch_iter = tqdm(loader, desc=desc) if desc is not None else loader
    with torch.set_grad_enabled(train):
        for (img1, img2), labels in batch_iter:
            # non_blocking overlaps host->device copies when memory is pinned
            img1 = img1.to(device, non_blocking=True)
            img2 = img2.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            if train:
                optimizer.zero_grad()
            output1, output2 = model(img1, img2)
            loss = criterion(output1, output2, labels)
            if train:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            batches += 1
            # Accuracy bookkeeping needs no autograd graph.
            with torch.no_grad():
                distances = torch.nn.functional.pairwise_distance(output1, output2)
                predictions = (distances > dist_threshold).float()
                correct += (predictions == labels).sum().item()
            seen += labels.size(0)

    avg_loss = total_loss / batches if batches else 0.0
    accuracy = correct / seen if seen else 0.0
    return avg_loss, accuracy


best_saved = False
for epoch in range(num_epochs):
    avg_train_loss, train_accuracy = run_epoch(
        train_dataloader, train=True, desc=f"Epoch {epoch+1}/{num_epochs}"
    )
    print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}")

    # NOTE(review): test_dataloader doubles as the validation set, so early stopping
    # and checkpoint selection see the test data; hold out part of the training
    # pairs for validation to get an unbiased test estimate.
    avg_val_loss, val_accuracy = run_epoch(test_dataloader, train=False)
    print(f"Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

    # Early stopping and model saving
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        early_stop_counter = 0
        torch.save(model.state_dict(), CHECKPOINT_PATH)
        best_saved = True
        print("Best model saved!")
    else:
        early_stop_counter += 1
        print(f"Early stopping counter: {early_stop_counter}/{patience}")

    # Reduce learning rate if validation loss doesn't improve
    scheduler.step(avg_val_loss)
    print(f"Current learning rate: {optimizer.param_groups[0]['lr']:.2e}")

    if early_stop_counter >= patience:
        print("Early stopping triggered. Training stopped.")
        break

# Leave `model` holding the best (lowest validation loss) weights from this run
# rather than whatever the last epoch produced.
if best_saved:
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True))
    print(f"Restored best model from {CHECKPOINT_PATH}")


Epoch 1/50: 100%|██████████| 79/79 [00:59<00:00,  1.33it/s]

Epoch [1/50], Training Loss: 0.0843, Training Accuracy: 0.9071





Validation Loss: 0.0307, Validation Accuracy: 0.9824
Best model saved!


Epoch 2/50:  65%|██████▍   | 51/79 [00:36<00:20,  1.38it/s]


KeyboardInterrupt: 

: 