In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms as transforms
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
import random

In [2]:
class SiameseNetwork(nn.Module):
    """Convolutional encoder that maps an image to a 64-dimensional embedding.

    Three ``Conv2d -> BatchNorm2d -> LeakyReLU -> MaxPool2d(2, 2)`` stages are
    followed by three fully connected layers.  A pair is embedded by calling
    the same instance on each image, so both branches share all weights.

    Args:
        in_channels: Number of channels of the input images.
        input_size: Spatial size of the input images, either one int for
            square images or a ``(height, width)`` pair.  It is only used to
            size the first fully connected layer; the default of 28 yields
            ``256 * 3 * 3`` flattened features.

    Raises:
        ValueError: If ``input_size`` is too small to survive three poolings.
    """

    def __init__(self, in_channels=3, input_size=28):
        super(SiameseNetwork, self).__init__()
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        # The 3x3 convolutions use padding=1 and keep the spatial size; each of
        # the three 2x2/stride-2 poolings floor-halves it.
        for _ in range(3):
            height, width = height // 2, width // 2
        if height < 1 or width < 1:
            raise ValueError(
                f"input_size={input_size!r} is too small for three 2x2 poolings"
            )

        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.fc1 = nn.Linear(256 * height * width, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 64)

    def forward(self, x):
        """Embed a batch of images; returns a ``(batch, 64)`` tensor.

        The shape comments below are for the default 28x28 input.
        """
        x = self.pool(F.leaky_relu(self.bn1(self.conv1(x))))  # 64 x 14 x 14
        x = self.pool(F.leaky_relu(self.bn2(self.conv2(x))))  # 128 x 7 x 7
        x = self.pool(F.leaky_relu(self.bn3(self.conv3(x))))  # 256 x 3 x 3
        x = x.flatten(start_dim=1)
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        # No activation on the last layer: the raw output is the embedding.
        x = self.fc3(x)
        return x

In [3]:
class ContrastiveLoss(nn.Module):
    """Pairwise contrastive loss on embedding distances.

    Pairs labelled 1 contribute their squared Euclidean distance.  Pairs
    labelled 0 contribute the squared amount by which their distance falls
    short of ``margin`` (zero once they are at least ``margin`` apart).  The
    per-pair terms are averaged into a scalar.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, z1, z2, label):
        distance = F.pairwise_distance(z1, z2, keepdim=False)
        # How far each pair still is from being `margin` apart (never negative)
        shortfall = torch.clamp(self.margin - distance, min=0.0)
        pull_term = label * distance.pow(2)
        push_term = (1 - label) * shortfall.pow(2)
        return (pull_term + push_term).mean()

In [4]:
# Siamese Dataset Loader
class SiameseDataset(Dataset):
    """Turn a labelled dataset into randomly drawn image pairs.

    Item ``index`` pairs sample ``index`` of the wrapped dataset with a second
    sample that is, with equal probability, drawn from the same class or from
    a different class.  Pairs are re-drawn on every access, so the same index
    generally yields a different partner each time.

    Args:
        mnist_dataset: Any indexable dataset with ``__len__`` whose items are
            ``(image, label)`` tuples with hashable labels.
    """

    def __init__(self, mnist_dataset):
        self.mnist_dataset = mnist_dataset
        # Map every label to the positions of the samples that carry it.
        self.label_to_indices = {}
        for idx in range(len(self.mnist_dataset)):
            _, label = self.mnist_dataset[idx]
            self.label_to_indices.setdefault(label, []).append(idx)
        # Cached once so __getitem__ does not rebuild the key list per draw.
        self._labels = list(self.label_to_indices)

    def __getitem__(self, index):
        """Return ``(img1, img2, label)`` with ``label`` 1 for a same-class pair, else 0."""
        img1, label1 = self.mnist_dataset[index]
        other_labels = [lbl for lbl in self._labels if lbl != label1]
        # With a single class there is nothing to draw a negative from; fall
        # back to a positive pair instead of searching forever.
        want_same = random.randint(0, 1) == 1 or not other_labels
        if want_same:
            idx2 = random.choice(self.label_to_indices[label1])
        else:
            idx2 = random.choice(self.label_to_indices[random.choice(other_labels)])
        img2, label2 = self.mnist_dataset[idx2]
        label = 1 if label1 == label2 else 0
        return img1, img2, label

    def __len__(self):
        return len(self.mnist_dataset)

In [5]:
# Load both MNIST splits as tensors standardised with a fixed mean/std
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
# Training split first, then the test split; both share the same transform
mnist_train, mnist_test = (
    torchvision.datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [6]:
# Wrap both MNIST splits as pair datasets, then batch them (only training is shuffled)
train_dataset, test_dataset = SiameseDataset(mnist_train), SiameseDataset(mnist_test)
train_loader = DataLoader(dataset=train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=256, shuffle=False)

In [7]:
# Use the GPU when one is available and place the single-channel encoder on it
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = SiameseNetwork(in_channels=1)
model = model.to(device)

In [8]:
def weights_init(m):
    """Initialise a single module in place; meant to be passed to ``Module.apply``.

    ``Conv2d`` and ``Linear`` modules get Xavier-uniform weights and, when they
    have a bias, a bias filled with 0.01.  Any other module is left untouched.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.xavier_uniform_(m.weight)
    if m.bias is None:
        return
    nn.init.constant_(m.bias, 0.01)

# Run weights_init on the model and its submodules. The call's return value
# (the model itself) is what the notebook echoes as this cell's output.
model.apply(weights_init)

SiameseNetwork(
  (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=2304, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=64, bias=True)
)

In [9]:
# Contrastive objective with margin 2, and Adam over every model parameter
criterion = ContrastiveLoss(margin=2)
criterion = criterion.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

In [10]:
def sum_abs_changes(model, prev_params):
    """
    Measure how far the model's parameters moved since the last snapshot.

    As a side effect, ``prev_params`` is refreshed in place with a copy of
    every current parameter, so consecutive calls measure consecutive deltas.

    :param model: The neural network model.
    :param prev_params: Dictionary mapping parameter names to previously
        recorded values; updated in place.
    :return: Sum, over all parameters already present in ``prev_params``, of
        the absolute element-wise difference, as a Python float.
    """
    total_change = 0.0
    for name, param in model.named_parameters():
        current = param.data
        if name in prev_params:
            # Parameters seen before contribute their L1 distance to the total
            total_change += (current - prev_params[name]).abs().sum().item()
        # Store a copy, not a view, so later optimiser steps cannot alter it
        prev_params[name] = current.clone()
    return total_change

In [11]:
# Snapshot of the initial weights; sum_abs_changes() keeps it current in place
prev_params = {name: param.data.clone() for name, param in model.named_parameters()}

# Training loop
num_epochs = 50
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for img1, img2, label in train_loader:
        img1, img2, label = img1.to(device), img2.to(device), label.to(device)
        optimizer.zero_grad()
        # The same network embeds both images of each pair (shared weights)
        z1 = model(img1)
        z2 = model(img2)
        loss = criterion(z1, z2, label)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Total absolute parameter movement over this epoch. The call also
    # refreshes prev_params in place, so no separate re-snapshot is needed.
    param_change_sum = sum_abs_changes(model, prev_params)

    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_loader)}, Param Change Sum: {param_change_sum:.4f}')


Epoch 1/50, Loss: 0.7766494184098345, Param Change Sum: 91876.6931
Epoch 2/50, Loss: 0.23984754662564461, Param Change Sum: 4784.5253
Epoch 3/50, Loss: 0.17942292087889733, Param Change Sum: 4358.9033
Epoch 4/50, Loss: 0.15445082289107304, Param Change Sum: 4808.3423
Epoch 5/50, Loss: 0.13324828756616472, Param Change Sum: 4526.8121
Epoch 6/50, Loss: 0.11485166974524234, Param Change Sum: 3922.0429
Epoch 7/50, Loss: 0.1016930083962197, Param Change Sum: 4171.3417
Epoch 8/50, Loss: 0.09220289749668, Param Change Sum: 4764.7257
Epoch 9/50, Loss: 0.08129568149117714, Param Change Sum: 4319.3327
Epoch 10/50, Loss: 0.07879835069179535, Param Change Sum: 6189.5499
Epoch 11/50, Loss: 0.06329500822627798, Param Change Sum: 4838.7850
Epoch 12/50, Loss: 0.05416015106788341, Param Change Sum: 4847.3699
Epoch 13/50, Loss: 0.05107551252746836, Param Change Sum: 6154.5487
Epoch 14/50, Loss: 0.04576607366350103, Param Change Sum: 6050.1599
Epoch 15/50, Loss: 0.03812443917577571, Param Change Sum: 546

In [12]:
# Embedding the test and training sets
def embed_dataset(model, dataset, device, batch_size=256):
    """Run ``model`` over ``dataset`` and collect the results as NumPy arrays.

    The dataset is read in order (no shuffling) and gradients are disabled.
    The model's train/eval mode is not changed here; the caller sets it.

    :param model: Network applied to each image batch.
    :param dataset: Dataset yielding ``(image, label)`` items.
    :param device: Device the image batches are moved to before the forward pass.
    :param batch_size: Number of samples per forward pass.
    :return: Tuple ``(embeddings, images, labels)``, each concatenated along
        the first axis over all batches.
    """
    embeddings, images, labels = [], [], []
    with torch.no_grad():
        for img, label in DataLoader(dataset, batch_size=batch_size):
            z = model(img.to(device))
            embeddings.append(z.cpu().numpy())
            # Keep the loader's own batch rather than copying it back from the device
            images.append(img.cpu().numpy())
            labels.append(label.numpy())
    return np.concatenate(embeddings), np.concatenate(images), np.concatenate(labels)


# Evaluation mode: batch norm uses its running statistics instead of batch statistics
model.eval()
test_embeddings, test_images, test_labels = embed_dataset(model, mnist_test, device)
train_embeddings, train_images, train_labels = embed_dataset(model, mnist_train, device)

In [13]:
# k-NN (k=20) fitted on the training embeddings and scored on the test embeddings
knn_embedding = KNeighborsClassifier(n_neighbors=20).fit(train_embeddings, train_labels)
accuracy = knn_embedding.score(test_embeddings, test_labels)
print(f"KNN classification accuracy in embedding space: {accuracy * 100:.2f}%")

KNN classification accuracy in embedding space: 99.21%


In [14]:
# Baseline: the same k-NN (k=20) on flattened image tensors instead of embeddings
knn_images = KNeighborsClassifier(n_neighbors=20).fit(
    train_images.reshape(len(train_images), -1), train_labels
)
accuracy = knn_images.score(test_images.reshape(len(test_images), -1), test_labels)
print(f"KNN classification accuracy in image space: {accuracy * 100:.2f}%")

KNN classification accuracy in image space: 96.25%
