In [139]:
import os
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision import models
from PIL import Image
%matplotlib inline

In [141]:
# Hyperparameters
num_epochs = 1
batch_size = 1
learning_rate = 0.001

num_classes = 4
num_images_per_class = 10

## Dataset

In [142]:
class SimilarFaceDataset(Dataset):
    def __init__(self, num_classes, num_images_per_class, image_size=224):
        self.num_classes = num_classes
        self.num_images_per_class = num_images_per_class
        self.images = []
        
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            lambda image: image[:-1],
        ])
        
        for i in range(self.num_classes):
            pathname = os.path.join('data/', f'c{i+1}')
            tmp = []
            for filename in glob.glob(f'{pathname}/*'):
                tmp.append(transform(Image.open(filename)))
            self.images.append(tmp)
            
        self.triplets = []
        for positive_class in range(self.num_classes):
            for anchor in range(self.num_images_per_class):
                for positive in range(self.num_images_per_class):
                    if anchor != positive:
                        for negative_class in range(self.num_classes):
                            if negative_class != positive_class:
                                for negative in range(self.num_images_per_class):
                                    self.triplets.append([
                                        self.images[positive_class][anchor],
                                        self.images[positive_class][positive],
                                        self.images[negative_class][negative]
                                    ])
            
    def __len__(self):
        return len(self.triplets)
    
    def __getitem__(self, index):
        """Return a tuple (Anchor, Positive, Negative)"""
        return self.triplets[index]
    
dataset = SimilarFaceDataset(num_classes, num_images_per_class)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

## Siamese Network

In [143]:
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.vgg11 = models.vgg11_bn()
        
    def forward(self, anchor, positive, negative):
        anchor_out = self.vgg11(anchor)
        positive_out = self.vgg11(positive)
        negative_out = self.vgg11(negative)
        return (anchor_out, positive_out, negative_out)

## Setup

In [144]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = SiameseNetwork().to(device)

# Loss and optimizer
criterion = torch.nn.TripletMarginLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

## Training

In [145]:
model.train()

total_step = len(data_loader)
for epoch in range(num_epochs):
    for step, (anchors, positives, negatives) in enumerate(data_loader):
        anchors = anchors.to(device)
        positives = positives.to(device)
        negatives = negatives.to(device)
        
        # forward pass
        anchor_outputs, positive_outputs, negative_outputs = model(anchors, positives, negatives)
        loss = criterion(anchor_outputs, positive_outputs, negative_outputs)
        
        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        print(f'Epoch [{epoch+1}/{num_epochs}], Step [{step}/{total_step}], Loss: {loss.item():.4f}')

Epoch [1/1], Step [0/10800], Loss: 0.2952
Epoch [1/1], Step [1/10800], Loss: 6.2917
Epoch [1/1], Step [2/10800], Loss: 0.0000
Epoch [1/1], Step [3/10800], Loss: 3.7417
Epoch [1/1], Step [4/10800], Loss: 1.1707
Epoch [1/1], Step [5/10800], Loss: 3.6254
Epoch [1/1], Step [6/10800], Loss: 0.0000
Epoch [1/1], Step [7/10800], Loss: 0.0000
Epoch [1/1], Step [8/10800], Loss: 220.2189
Epoch [1/1], Step [9/10800], Loss: 0.0000
Epoch [1/1], Step [10/10800], Loss: 0.0000
Epoch [1/1], Step [11/10800], Loss: 203.3857
Epoch [1/1], Step [12/10800], Loss: 0.0000
Epoch [1/1], Step [13/10800], Loss: 103.4795
Epoch [1/1], Step [14/10800], Loss: 173.1426
Epoch [1/1], Step [15/10800], Loss: 172.5464
Epoch [1/1], Step [16/10800], Loss: 0.0000
Epoch [1/1], Step [17/10800], Loss: 0.0000
Epoch [1/1], Step [18/10800], Loss: 754.2585
Epoch [1/1], Step [19/10800], Loss: 326.6594
Epoch [1/1], Step [20/10800], Loss: 0.0000
Epoch [1/1], Step [21/10800], Loss: 673.5610
Epoch [1/1], Step [22/10800], Loss: 305.9104
Epo

KeyboardInterrupt: 