# Tutorial 6-3: Identity Crisis – "Siamese Networks for Re-ID"

**Course:** CSEN 342: Deep Learning  
**Topic:** Metric Learning, Siamese Networks, and Contrastive Loss

## Objective
In standard classification, we teach a network to say "This is a dog." In **Re-Identification (Re-ID)**, we don't know the classes beforehand (e.g., a new person walks into a camera frame). Instead, we must teach the network to answer: "Are these two images the **same** object?"

This is called **Metric Learning**. We want to learn an embedding space where similar objects are close together and different objects are far apart.

In this tutorial, we will:
1.  **Build a Siamese Dataset:** Create a data loader that produces *pairs* of images (Positive/Negative examples).
2.  **Design a Siamese Network:** A model with two identical branches that share weights.
3.  **Implement Contrastive Loss:** A custom loss function that pulls similar pairs together and pushes dissimilar pairs apart.
4.  **Visualize Embeddings:** See how the network organizes data in geometric space.

---

## Part 1: The Siamese Dataset

We need a dataset that returns two images and a label ($1$ if same class, $0$ if different). We will use Fashion-MNIST, treating each class (e.g., 'Sneaker') as a distinct 'identity'.
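
Before building the real dataset, the pairing logic can be sketched on a toy label list. This is a minimal, pure-Python illustration: `toy_labels`, `sample_pair`, and the fixed seed are hypothetical names and values chosen for the example, not part of the tutorial's code.

```python
import random

# Hypothetical toy labels standing in for the Fashion-MNIST targets
toy_labels = [0, 1, 0, 2, 1, 2, 0, 1]

# Group sample indices by class so a partner can be drawn directly
class_indices = {}
for idx, label in enumerate(toy_labels):
    class_indices.setdefault(label, []).append(idx)

def sample_pair(index, rng):
    """Return (index, partner_index, target); target is 1.0 for a same-class pair."""
    label = toy_labels[index]
    if rng.random() < 0.5:
        # Positive pair: partner drawn from the same class
        # (it may occasionally be the anchor itself)
        return index, rng.choice(class_indices[label]), 1.0
    # Negative pair: pick a different class first, then a sample from it
    other = rng.choice([c for c in class_indices if c != label])
    return index, rng.choice(class_indices[other]), 0.0

rng = random.Random(0)
pairs = [sample_pair(i % len(toy_labels), rng) for i in range(1000)]
positive_fraction = sum(t for _, _, t in pairs) / len(pairs)
print(f"positive fraction: {positive_fraction:.2f}")
```

Because the coin flip happens on every call, the same anchor index yields a different partner each epoch, so the network sees far more distinct pairs than there are images.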

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import random
import numpy as np
import matplotlib.pyplot as plt

# Import utility functions
import os
import sys
sys.path.append(os.path.abspath(os.path.join('..')))
from utils import download_fashion_mnist

download_fashion_mnist()

# 1. Define the Paired Dataset
class SiameseFashionMNIST(Dataset):
    def __init__(self, root, train=True, transform=None):
        self.ds = torchvision.datasets.FashionMNIST(root=root, train=train, download=True, transform=transform)
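        # Group indices by class for fast lookup. Reading the stored targets
        # avoids loading and transforming every image just to get its label.
        self.class_indices = [[] for _ in range(10)]
        for idx, label in enumerate(self.ds.targets.tolist()):
            self.class_indices[label].append(idx)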

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, index):
        # Image 1: Get the item at 'index'
        img1, label1 = self.ds[index]

        # Image 2: 50% chance of same class, 50% chance of diff class
        should_get_same_class = random.randint(0, 1)
        
        if should_get_same_class:
            # Pick random index from same class
            idx2 = random.choice(self.class_indices[label1])
            target = 1.0 # Similar
        else:
            # Pick random class (that isn't label1)
            diff_label = random.randint(0, 9)
            while diff_label == label1:
                diff_label = random.randint(0, 9)
            # Pick random index from that different class
            idx2 = random.choice(self.class_indices[diff_label])
            target = 0.0 # Dissimilar
            
        img2, _ = self.ds[idx2]  # the second label is implied by 'target'
        return img1, img2, torch.tensor(target, dtype=torch.float32)

# Visualize
transform = transforms.ToTensor()
train_ds = SiameseFashionMNIST(root='../data', train=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)

img1, img2, target = next(iter(train_loader))
fig, axs = plt.subplots(1, 2)
axs[0].imshow(img1[0].squeeze(), cmap='gray')
axs[1].imshow(img2[0].squeeze(), cmap='gray')
fig.suptitle(f"Label: {target[0].item()} (1=Same, 0=Diff)")  # title the whole figure, not just the right subplot
plt.show()

---

## Part 2: The Siamese Network

A Siamese network passes both images through the **same** embedding network (shared weights). 

We will force the output to be 2-dimensional (vector of size 2) so we can plot it easily later.
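
The embedding network in the next cell flattens its convolutional output into `64 * 4 * 4` features. As a quick sanity check, the arithmetic behind that number can be sketched in plain Python. The helper names `conv_out` and `pool_out` are hypothetical; they apply the standard output-size formula for unpadded convolutions and poolings to a 28x28 Fashion-MNIST image.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel=2, stride=2):
    """Spatial output size of a max-pooling layer along one dimension."""
    return (size - kernel) // stride + 1

size = 28                 # Fashion-MNIST images are 28x28
size = conv_out(size, 5)  # first 5x5 convolution
size = pool_out(size)     # first 2x2 max-pool
size = conv_out(size, 5)  # second 5x5 convolution
size = pool_out(size)     # second 2x2 max-pool

flat_features = 64 * size * size  # 64 channels after the second convolution
print(size, flat_features)
```

If you change the kernel sizes or the input resolution, rerun this arithmetic and update the first `nn.Linear` layer to match.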

In [None]:
class EmbeddingNet(nn.Module):
    def __init__(self):
        super(EmbeddingNet, self).__init__()
        # Simple CNN
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 5), nn.ReLU(), nn.MaxPool2d(2, 2)
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 4 * 4, 256), nn.ReLU(),
            nn.Linear(256, 2) # Output 2D embedding for easy visualization
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.embedding_net = EmbeddingNet()

    def forward(self, x1, x2):
        # Pass both images through the SAME network
        output1 = self.embedding_net(x1)
        output2 = self.embedding_net(x2)
        return output1, output2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SiameseNetwork().to(device)
print("Siamese Network Initialized.")

---

## Part 3: Contrastive Loss

We use the **Contrastive Loss** function.
Let $D$ be the Euclidean distance between the two embeddings: $D = || v_1 - v_2 ||_2$.
Let $Y$ be the label ($1$ for same, $0$ for different).

$$ L = Y \cdot D^2 + (1 - Y) \cdot \max(0, \text{margin} - D)^2 $$

* **If Same ($Y=1$):** We minimize $D^2$ (pull them together).
* **If Different ($Y=0$):** We minimize $\max(0, \text{margin} - D)^2$. This pushes them apart, but only until $D \geq \text{margin}$. Once they are that far apart, the loss is 0 and the pair no longer contributes a gradient.
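
To make the two cases concrete, here is a small worked example of the formula in plain Python, assuming `margin = 1.0` and hand-picked 2D points. The function `contrastive_loss` and the points are illustrative only; the PyTorch implementation follows in the next cell.

```python
import math

def contrastive_loss(v1, v2, y, margin=1.0):
    """Per-pair contrastive loss: y=1 for a similar pair, y=0 for a dissimilar pair."""
    d = math.dist(v1, v2)  # Euclidean distance D
    return y * d ** 2 + (1 - y) * max(0.0, margin - d) ** 2

anchor = (0.0, 0.0)
near = (0.3, 0.4)  # distance 0.5 from the anchor
far = (0.9, 1.2)   # distance 1.5 from the anchor

loss_same_near = contrastive_loss(anchor, near, y=1)  # penalized by D^2
loss_same_far = contrastive_loss(anchor, far, y=1)    # penalized more: farther apart
loss_diff_near = contrastive_loss(anchor, near, y=0)  # inside the margin: penalized
loss_diff_far = contrastive_loss(anchor, far, y=0)    # beyond the margin: no penalty

for name, value in [("same/near", loss_same_near), ("same/far", loss_same_far),
                    ("diff/near", loss_diff_near), ("diff/far", loss_diff_far)]:
    print(f"{name}: {value:.2f}")
```

Note that the dissimilar pair at distance 1.5 contributes nothing: the margin keeps the network from spending capacity on pairs that are already well separated.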

In [None]:
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # Calculate Euclidean Distance
        euclidean_distance = F.pairwise_distance(output1, output2)
        
        # Contrastive Loss Formula
        loss_contrastive = torch.mean((label) * torch.pow(euclidean_distance, 2) + 
                                      (1 - label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))

        return loss_contrastive

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = ContrastiveLoss()

---

## Part 4: Training and Visualization

We will train for 5 epochs. After training, we will pass the test set through the embedding network and plot the 2D embeddings. If training worked, images of the same class (e.g., all Sneakers) should cluster together in the embedding space.

In [None]:
# Training Loop
print("Starting Training...")
model.train()
for epoch in range(5): 
    total_loss = 0
    for i, (img1, img2, label) in enumerate(train_loader):
        img1, img2, label = img1.to(device), img2.to(device), label.to(device)
        
        optimizer.zero_grad()
        out1, out2 = model(img1, img2)
        loss = criterion(out1, out2, label)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        
    print(f"Epoch {epoch+1}: Loss {total_loss/len(train_loader):.4f}")

# Visualization function
def plot_embeddings(model):
    # Use standard FashionMNIST test set (not pairs)
    test_ds = torchvision.datasets.FashionMNIST(root='../data', train=False, transform=transforms.ToTensor())
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False)
    
    model.eval()
    embeddings = []
    labels = []
    
    with torch.no_grad():
        for img, label in test_loader:
            img = img.to(device)
            emb = model.embedding_net(img)
            embeddings.append(emb.cpu().numpy())
            labels.append(label.numpy())
            
    embeddings = np.concatenate(embeddings)
    labels = np.concatenate(labels)
    
    # Plot
    plt.figure(figsize=(10, 8))
    classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
    
    scatter = plt.scatter(embeddings[:, 0], embeddings[:, 1], c=labels, cmap='tab10', alpha=0.6)
    # num=None requests one legend handle per unique label, so handles and class names line up
    plt.legend(handles=scatter.legend_elements(num=None)[0], labels=classes, title="Classes")
    plt.title("Learned 2D Embeddings via Siamese Network")
    plt.grid(True)
    plt.show()

print("Visualizing Test Set Embeddings...")
plot_embeddings(model)

### Conclusion
Look at the plot! 
* **Clustering:** You should see visually similar items (e.g., Sneakers, Sandals, and Ankle boots) land near each other in one region of the 2D space, while clothing items group in another. The exact layout varies from run to run.
* **Metric Learning:** We never trained the network to output a class such as "This is a T-Shirt." It only ever saw pairwise same/different labels (derived from the class labels), yet the embedding space ends up organized by class.
    
This is the foundation of modern Face ID, Re-ID, and Image Retrieval systems.
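
To connect this back to the question from the Objective ("Are these two images the same object?"), here is a minimal sketch of how embeddings might be used at inference time: compare a query against a gallery and threshold the distance. The function `is_same_identity`, the gallery entries, and the threshold of 0.5 are all hypothetical; in practice the threshold would be tuned on held-out pairs, and the embeddings would come from `model.embedding_net`.

```python
import math

def is_same_identity(emb_a, emb_b, threshold=0.5):
    """Declare a match when the embedding distance falls below the threshold."""
    return math.dist(emb_a, emb_b) < threshold

# Hypothetical 2D embeddings, not outputs of the trained model
query = (0.10, 0.20)
gallery = {"id_a": (0.15, 0.25), "id_b": (1.40, -0.90)}

# Verification: is the query the same identity as each gallery entry?
matches = {name: is_same_identity(query, emb) for name, emb in gallery.items()}

# Retrieval: which gallery entry is closest to the query?
best_match = min(gallery, key=lambda name: math.dist(query, gallery[name]))
print(matches, best_match)
```

Because only distances matter, new identities can be added to the gallery without retraining the network.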