In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torchvision.datasets.mnist import FashionMNIST
from torchvision.models import DenseNet121_Weights
from torch.utils.data import Dataset, DataLoader
import itertools
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import random

In [29]:
class FashionMNISTPairs(Dataset):
    """Dataset of randomly drawn image pairs with a same-class indicator.

    Wraps an indexable dataset whose items are ``(image, label)`` tuples.
    Every lookup draws two distinct positions at random, so the requested
    index is ignored and repeated lookups generally return different pairs.

    Args:
        dataset: Underlying dataset yielding ``(image, label)`` items.
        num_pairs_per_epoch: Value reported by ``len()``, i.e. how many pairs
            a loader draws in one pass over this dataset.
    """

    def __init__(self, dataset, num_pairs_per_epoch=100000):
        self.num_pairs_per_epoch = num_pairs_per_epoch
        self.dataset = dataset
        self.length = len(dataset)
        # Preprocessing applied to each image of a pair. ImageNet mean/std
        # normalisation is not part of the pipeline.
        preprocessing_steps = [
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ]
        self.transform = transforms.Compose(preprocessing_steps)

    def __len__(self):
        """Return the configured number of pairs per epoch."""
        return self.num_pairs_per_epoch

    def __getitem__(self, idx):
        """Return ``(img1, img2, target)`` for a freshly sampled pair.

        ``idx`` is unused: the two source positions are sampled without
        replacement from the underlying dataset. ``target`` is a float32
        scalar tensor equal to 1.0 when both labels match, else 0.0.
        """
        first_pos, second_pos = random.sample(range(self.length), 2)

        first_raw, first_label = self.dataset[first_pos]
        second_raw, second_label = self.dataset[second_pos]

        first_img = self.transform(first_raw)
        second_img = self.transform(second_raw)

        same_class = int(first_label == second_label)
        target = torch.tensor(same_class, dtype=torch.float32)
        return first_img, second_img, target

In [30]:
class SiameseNetwork(nn.Module):
    """Siamese network: a shared DenseNet-121 feature stack plus a linear head.

    Both inputs go through the same weights. ``forward`` returns one
    128-dimensional embedding per input; ``extract_features`` returns the
    flattened backbone features (before the linear head) without gradients.

    The linear head is sized for 224x224 inputs, for which the DenseNet-121
    feature stack produces a 7x7 grid per channel.
    """

    EMBEDDING_DIM = 128
    # Side length of the backbone's output grid for a 224x224 input.
    _FEATURE_MAP_SIDE = 7

    def __init__(self):
        super().__init__()
        base_model = models.densenet121(weights=DenseNet121_Weights.DEFAULT)
        # Drop the last child (the classifier) and keep the feature stack.
        self.feature_extractor = nn.Sequential(*list(base_model.children())[:-1])
        # Flattened feature size = channels * grid_side**2 (1024 * 7 * 7 = 50176).
        # This equals 224 * 224 only by coincidence: the size is set by the
        # backbone's channel count and output grid, not by the pixel count.
        flat_features = base_model.classifier.in_features * self._FEATURE_MAP_SIDE ** 2
        self.fc = nn.Linear(flat_features, self.EMBEDDING_DIM)

    def _backbone_features(self, x):
        """Run the shared feature stack and flatten each sample to a vector."""
        feats = self.feature_extractor(x)
        return feats.view(feats.size(0), -1)

    def _embed(self, x):
        """Map one batch of images to embeddings through backbone and head."""
        return self.fc(self._backbone_features(x))

    def forward(self, input1, input2):
        """Embed both inputs with the shared weights.

        Returns:
            Tuple ``(output1, output2)`` of embeddings for ``input1`` and
            ``input2`` respectively.
        """
        output1 = self._embed(input1)
        output2 = self._embed(input2)
        return output1, output2

    def extract_features(self, input):
        """Return flattened backbone features for ``input`` with gradients off.

        The linear head is not applied, so the result has the backbone's
        flattened size rather than ``EMBEDDING_DIM``.
        """
        with torch.no_grad():
            features = self._backbone_features(input)
        return features

In [31]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss over pairs of embeddings.

    Label convention: ``label == 1`` marks a similar (same-class) pair and
    ``label == 0`` a dissimilar pair. This matches the pair dataset in this
    file, which emits ``int(label1 == label2)``, and the evaluation code,
    which predicts ``1`` when the embedding distance is small.

    Similar pairs are penalised by their squared distance (pulled together);
    dissimilar pairs are penalised only while they are closer than ``margin``
    (pushed apart).

    Args:
        margin: Distance beyond which dissimilar pairs contribute no loss.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Return the mean contrastive loss over the batch.

        Args:
            output1: Embeddings of the first element of each pair.
            output2: Embeddings of the second element of each pair.
            label: Per-pair target, 1 for similar and 0 for dissimilar.

        Returns:
            Scalar tensor holding the batch-mean loss.
        """
        euclidean_distance = nn.functional.pairwise_distance(output1, output2)
        # label == 1: minimise the distance between same-class embeddings.
        pull_similar = label * torch.pow(euclidean_distance, 2)
        # label == 0: penalise different-class pairs that sit inside the margin.
        margin_gap = torch.clamp(self.margin - euclidean_distance, min=0.0)
        push_dissimilar = (1 - label) * torch.pow(margin_gap, 2)
        return torch.mean(pull_similar + push_dissimilar)

In [32]:
class Trainer:
    """Runs the optimisation loop for a two-input model.

    Args:
        model: Module called as ``model(img1, img2)`` and returning two outputs.
        criterion: Callable ``criterion(output1, output2, label)`` returning a
            scalar loss tensor.
        optimizer: Optimizer that updates the model's parameters.
        dataloader: Iterable yielding ``(img1, img2, label)`` batches.
        device: Device the model and every batch are moved to.
    """

    def __init__(self, model, criterion, optimizer, dataloader, device):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.dataloader = dataloader
        self.device = device

    def train(self, num_epochs):
        """Train for ``num_epochs`` passes and print the mean loss per epoch.

        Args:
            num_epochs: Number of full passes over the dataloader.

        Raises:
            ValueError: If a pass over the dataloader yields no batches, since
                no average loss can be reported.
        """
        self.model.to(self.device)
        self.model.train()
        for epoch in range(num_epochs):
            epoch_loss = 0.0
            num_batches = 0
            for img1, img2, label in self.dataloader:
                img1, img2, label = img1.to(self.device), img2.to(self.device), label.to(self.device)
                self.optimizer.zero_grad()
                output1, output2 = self.model(img1, img2)
                loss = self.criterion(output1, output2, label)
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item()
                num_batches += 1
            # Counting batches avoids a bare ZeroDivisionError on an empty loader.
            if num_batches == 0:
                raise ValueError("dataloader yielded no batches; cannot compute an average loss")
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/num_batches}")

In [33]:
class EmbeddingVisualizer:
    """Collects per-image features from a model and plots them with t-SNE.

    Args:
        model: Object exposing ``to``, ``eval`` and ``extract_features(batch)``.
        dataloader: Iterable yielding 3-tuples; the first element is the image
            batch and the third is the label batch (the second is ignored).
        device: Device the model and image batches are moved to.
    """

    def __init__(self, model, dataloader, device):
        self.model, self.dataloader, self.device = model, dataloader, device

    def extract_embeddings(self):
        """Run the model over the loader and gather features and labels.

        Returns:
            Tuple ``(embeddings, labels)`` of NumPy arrays, each concatenated
            along axis 0 across all batches.
        """
        self.model.to(self.device)
        self.model.eval()

        feature_chunks, label_chunks = [], []
        with torch.no_grad():
            for first_images, _, batch_labels in self.dataloader:
                # NOTE(review): the labels are whatever the loader yields third.
                # With the pair loader in this file that is the same-class
                # target of the pair, not a class id of the image - confirm
                # this is the colouring wanted for the plot.
                batch_features = self.model.extract_features(first_images.to(self.device))
                feature_chunks.append(batch_features.cpu().numpy())
                label_chunks.append(batch_labels.cpu().numpy())

        return np.concatenate(feature_chunks, axis=0), np.concatenate(label_chunks, axis=0)

    def visualize(self, embeddings, labels):
        """Project ``embeddings`` to 2-D with t-SNE and show a labelled scatter plot."""
        projection = TSNE(n_components=2).fit_transform(embeddings)
        xs, ys = projection[:, 0], projection[:, 1]

        plt.figure(figsize=(12, 8))
        plt.scatter(xs, ys, c=labels, cmap='viridis')
        plt.colorbar()
        plt.title('t-SNE visualization of image embeddings')
        plt.xlabel('Dimension 1')
        plt.ylabel('Dimension 2')
        plt.show()

In [34]:
class Tester:
    """Evaluates a two-input model as a same/different pair classifier.

    A pair is predicted as positive (``1``) when the Euclidean distance
    between its two outputs is below ``threshold``.

    Args:
        model: Module called as ``model(img1, img2)`` and returning two outputs.
        dataloader: Iterable yielding ``(img1, img2, label)`` batches.
        device: Device the model and batches are moved to.
        threshold: Distance below which a pair is predicted positive.
    """

    def __init__(self, model, dataloader, device, threshold=0.5):
        self.model = model
        self.dataloader = dataloader
        self.device = device
        self.threshold = threshold

    def test(self):
        """Score every batch, print the metrics and return them.

        Returns:
            Tuple ``(accuracy, precision, recall, f1)`` computed over all
            pairs yielded by the dataloader.
        """
        self.model.to(self.device)
        self.model.eval()
        all_labels = []
        all_predictions = []

        with torch.no_grad():
            for img1, img2, label in self.dataloader:
                img1, img2, label = img1.to(self.device), img2.to(self.device), label.to(self.device)
                output1, output2 = self.model(img1, img2)
                euclidean_distance = nn.functional.pairwise_distance(output1, output2)
                # Small distance -> predicted positive (1.0), otherwise 0.0.
                predictions = (euclidean_distance < self.threshold).float()

                all_labels.extend(label.cpu().numpy())
                all_predictions.extend(predictions.cpu().numpy())

        # Convert lists to numpy arrays
        all_labels = np.array(all_labels)
        all_predictions = np.array(all_predictions)

        # Calculate metrics
        accuracy = accuracy_score(all_labels, all_predictions)
        precision = precision_score(all_labels, all_predictions)
        recall = recall_score(all_labels, all_predictions)
        f1 = f1_score(all_labels, all_predictions)

        print(f"Test Accuracy: {accuracy * 100:.2f}%")
        print(f"Test Precision: {precision:.2f}")
        print(f"Test Recall: {recall:.2f}")
        print(f"Test F1 Score: {f1:.2f}")

        return accuracy, precision, recall, f1

In [36]:
# Reproducibility: pair sampling draws from `random`; weight initialisation and
# loader shuffling draw from torch. Seed everything before any of them is used.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Hyperparameters
BATCH_SIZE = 32
LEARNING_RATE = 0.001
NUM_EPOCHS = 10
# Number of pairs used for the t-SNE plot. The features collected for the plot
# are the flattened backbone output (1024 * 7 * 7 = 50176 floats per image), so
# the default 100000 pairs would need roughly 20 GB before t-SNE even starts.
NUM_VIS_PAIRS = 2000

device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
print('using device:', device)
# Load Fashion MNIST dataset
train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True)

train_pairs_dataset = FashionMNISTPairs(train_dataset)
test_pairs_dataset = FashionMNISTPairs(test_dataset)
vis_pairs_dataset = FashionMNISTPairs(test_dataset, num_pairs_per_epoch=NUM_VIS_PAIRS)

train_loader = DataLoader(train_pairs_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_pairs_dataset, batch_size=BATCH_SIZE, shuffle=False)
vis_loader = DataLoader(vis_pairs_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Initialize model, loss function, and optimizer
model = SiameseNetwork()
criterion = ContrastiveLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Train the model
print('starting training')
trainer = Trainer(model, criterion, optimizer, train_loader, device)
trainer.train(num_epochs=NUM_EPOCHS)

# Test the model
print('starting testing')
tester = Tester(model, test_loader, device)
tester.test()

# Visualize embeddings (on the small visualisation loader, see NUM_VIS_PAIRS)
visualizer = EmbeddingVisualizer(model, vis_loader, device)
embeddings, labels = visualizer.extract_embeddings()
visualizer.visualize(embeddings, labels)

using device: mps
starting training


KeyboardInterrupt: 