In [1]:
# # Google Colab only: uncomment this cell to mount Drive; keep it commented out when running locally
# from google.colab import drive
# drive.mount('/gdrive')
# %cd /gdrive/My\ Drive/similar_faces/

In [2]:
import os
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision import models
from PIL import Image
%matplotlib inline

In [3]:
# Report the CUDA device PyTorch can see, or say that none is available.
print(
    f'Cuda device: {torch.cuda.get_device_name(0)}'
    if torch.cuda.is_available()
    else 'Cuda unavailable'
)

Cuda unavailable


In [4]:
# Hyperparameters
# num_epochs: number of full passes the training loop makes over the triplet loader.
num_epochs = 1
# batch_size: number of triplets per batch produced by the training DataLoader.
batch_size = 8
# learning_rate: step size passed to the Adam optimizer.
learning_rate = 0.0001

# Dataset layout: class folders data/c1 ... data/cN (N = num_classes) are read, and
# image indices 0 .. num_images_per_class - 1 of each class are combined into triplets.
num_classes = 4
num_images_per_class = 10

## Dataset

In [5]:
class SimilarFaceDataset(Dataset):
    """In-memory triplet dataset built from per-class image folders.

    Images are read from ``{root}/c1`` ... ``{root}/c{num_classes}``, resized to
    ``image_size`` x ``image_size`` and converted to 3-channel tensors. Every
    (anchor, positive, negative) combination is then enumerated, where anchor
    and positive are two different images of one class and negative is any
    image of a different class. Triplets hold references to the shared image
    tensors, not copies.

    Args:
        num_classes: Number of class folders to read.
        num_images_per_class: Number of images of each class used for triplets.
        image_size: Side length in pixels that every image is resized to.
        root: Directory that contains the ``c1``, ``c2``, ... class folders.

    Raises:
        ValueError: If a class folder provides fewer than
            ``num_images_per_class`` images.
    """

    def __init__(self, num_classes, num_images_per_class, image_size=128, root='data/'):
        self.num_classes = num_classes
        self.num_images_per_class = num_images_per_class

        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            # Keep the first three channels: this drops the alpha plane of RGBA
            # input and leaves plain RGB input untouched.
            lambda image: image[:3],
        ])

        self.images = [
            self._load_class_images(os.path.join(root, f'c{i + 1}'), transform)
            for i in range(self.num_classes)
        ]
        self.triplets = self._build_triplets()

    def _load_class_images(self, class_dir, transform):
        """Load every file found in ``class_dir`` as a transformed image tensor."""
        tensors = []
        # Sorted so that the images picked for triplets do not depend on the
        # order in which the filesystem happens to list the files.
        for filename in sorted(glob.glob(f'{class_dir}/*')):
            # The context manager closes the file once the tensor has been built.
            with Image.open(filename) as source:
                # Grayscale / palette / other modes are expanded to RGB so that
                # every tensor ends up with exactly three channels.
                image = source if source.mode in ('RGB', 'RGBA') else source.convert('RGB')
                tensors.append(transform(image))
        return tensors

    def _build_triplets(self):
        """Enumerate all [anchor, positive, negative] combinations from ``self.images``."""
        per_class = self.num_images_per_class

        # Fail early with a readable message instead of an IndexError raised
        # from the innermost loop below.
        for class_index in range(self.num_classes):
            available = len(self.images[class_index])
            if available < per_class:
                raise ValueError(
                    f'class folder c{class_index + 1} provides {available} image(s), '
                    f'but num_images_per_class={per_class}'
                )

        triplets = []
        for positive_class in range(self.num_classes):
            same_class = self.images[positive_class]
            for anchor in range(per_class):
                for positive in range(per_class):
                    if anchor == positive:
                        continue
                    for negative_class in range(self.num_classes):
                        if negative_class == positive_class:
                            continue
                        other_class = self.images[negative_class]
                        for negative in range(per_class):
                            triplets.append([
                                same_class[anchor],
                                same_class[positive],
                                other_class[negative],
                            ])
        return triplets

    def __len__(self):
        """Return the number of enumerated triplets."""
        return len(self.triplets)

    def __getitem__(self, index):
        """Return the triplet at ``index`` as a list ``[anchor, positive, negative]``."""
        return self.triplets[index]
    
# Build the triplet dataset once, then serve it in shuffled mini-batches.
dataset = SimilarFaceDataset(
    num_classes=num_classes,
    num_images_per_class=num_images_per_class,
)
data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

## Siamese Network

In [6]:
class SiameseNetwork(nn.Module):
    """Embedding network whose single set of weights is applied to all three triplet inputs.

    A stack of four (5x5 convolution, PReLU, 2x2 max-pool) stages is followed by
    three fully connected layers that project each image to a 2-dimensional
    embedding.
    """

    def __init__(self):
        super().__init__()

        # Layers are created in the same order as a hand-written Sequential so
        # that parameter names (conv.0, conv.1, ...) stay the same.
        conv_layers = []
        for in_channels, out_channels in ((3, 32), (32, 64), (64, 64), (64, 64)):
            conv_layers.extend([
                nn.Conv2d(in_channels, out_channels, 5),
                nn.PReLU(),
                nn.MaxPool2d(2),
            ])
        self.conv = nn.Sequential(*conv_layers)

        # 64 * 4 * 4 is the flattened size of the conv output for 128x128 inputs.
        flattened_size = 64 * 4 * 4
        self.fc = nn.Sequential(
            nn.Linear(flattened_size, 256),
            nn.PReLU(),
            nn.Linear(256, 256),
            nn.PReLU(),
            nn.Linear(256, 2),
        )

    def forward_once(self, x):
        """Embed one batch of images: conv features, flattened per sample, then the FC head."""
        features = self.conv(x)
        flat = features.view(features.size(0), -1)
        return self.fc(flat)

    def forward(self, anchor, positive, negative):
        """Embed the three batches with the shared weights; returns a 3-tuple in input order."""
        return tuple(
            self.forward_once(batch) for batch in (anchor, positive, negative)
        )

## Setup

In [7]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
model = SiameseNetwork().to(device)

# Triplet margin loss (library defaults) optimised with Adam.
criterion = nn.TripletMarginLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

## Training

In [8]:
# num_step_to_print_info = 100
# test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# test_iter = iter(test_loader)
# accu_loss = 0
# for i in range(num_step_to_print_info):
#     anchors, positives, negatives = test_iter.next()
#     anchors = anchors.to(device)
#     positives = positives.to(device)
#     negatives = negatives.to(device)
#     anchor_outputs, positive_outputs, negative_outputs = model(anchors, positives, negatives)
#     loss = criterion(anchor_outputs, positive_outputs, negative_outputs)
#     accu_loss += loss.item()
# print(f'Initial accu loss: {accu_loss}')

In [9]:
# # Load checkpoint
# checkpoint = torch.load('checkpoint')
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [0]:
# Number of optimisation steps whose losses are summed into one reported value.
# Assigned here because the loop below reads it and no other active cell
# defines it (the only other assignment lives in a commented-out cell).
num_step_to_print_info = 100

model.train()

accu_loss = 0
accu_loss_history = []

total_step = len(data_loader)
for epoch in range(num_epochs):
    # Start each epoch with an empty window; otherwise the unreported tail of
    # the previous epoch would be added to this epoch's first reported value.
    accu_loss = 0
    for step, (anchors, positives, negatives) in enumerate(data_loader):
        anchors = anchors.to(device)
        positives = positives.to(device)
        negatives = negatives.to(device)

        # forward pass
        anchor_outputs, positive_outputs, negative_outputs = model(anchors, positives, negatives)
        loss = criterion(anchor_outputs, positive_outputs, negative_outputs)
        accu_loss += loss.item()

        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report and record the summed loss once per full window of steps.
        if (step + 1) % num_step_to_print_info == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{step+1}/{total_step}], Accu loss: {accu_loss:.4f}')
            accu_loss_history.append(accu_loss)
            accu_loss = 0

# One point per reporting window, in training order.
fig, ax = plt.subplots()
ax.plot(accu_loss_history)
ax.set(
    title='Training loss',
    xlabel=f'Reporting window ({num_step_to_print_info} steps each)',
    ylabel='Accumulated triplet loss',
);

Epoch [1/1], Step [100/1350], Accu loss: 78.8668
Epoch [1/1], Step [200/1350], Accu loss: 49.3609
Epoch [1/1], Step [300/1350], Accu loss: 26.2187
Epoch [1/1], Step [400/1350], Accu loss: 12.1826
Epoch [1/1], Step [500/1350], Accu loss: 5.4100
Epoch [1/1], Step [600/1350], Accu loss: 3.8224
Epoch [1/1], Step [700/1350], Accu loss: 1.1509
Epoch [1/1], Step [800/1350], Accu loss: 0.5299
Epoch [1/1], Step [900/1350], Accu loss: 0.2747
Epoch [1/1], Step [1000/1350], Accu loss: 0.4619
Epoch [1/1], Step [1100/1350], Accu loss: 0.4699
Epoch [1/1], Step [1200/1350], Accu loss: 0.4699
Epoch [1/1], Step [1300/1350], Accu loss: 2.6714


In [13]:
# # save checkpoint
# torch.save({
#             'model_state_dict': model.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict()
#             }, 'checkpoint')

## Validating

In [0]:
# Sanity check on one randomly drawn triplet.
# NOTE: `dataset` is the same object the training loader iterates over, so this
# is a check on seen data rather than a held-out validation.
test_loader = DataLoader(dataset, batch_size=1, shuffle=True)
test_iter = iter(test_loader)

# The loader is shuffled, so its first batch is already a random triplet; the
# built-in next() is the iterator protocol (the iterator's .next() method is
# not available on current PyTorch releases).
anchors, positives, negatives = next(test_iter)
anchors = anchors.to(device)
positives = positives.to(device)
negatives = negatives.to(device)

# Inference only: switch to eval mode and skip autograd bookkeeping.
model.eval()
with torch.no_grad():
    anchor_outputs, positive_outputs, negative_outputs = model(anchors, positives, negatives)

In [0]:
# Distance from the anchor embedding to the positive (AP) and to the negative (AN)
# embedding; a well-trained model should give AP noticeably smaller than AN.
AP, AN = (
    F.pairwise_distance(anchor_outputs, other_outputs, keepdim=True).item()
    for other_outputs in (positive_outputs, negative_outputs)
)
print(AP, AN)

1.0646741390228271 12.584348678588867
