In [1]:
# # comment when run locally
# from google.colab import drive
# drive.mount('/gdrive')
# %cd /gdrive/My\ Drive/similar_faces/

In [2]:
# # uncomment when run locally
# !rm -rf /data_celeba
# !mkdir /data_celeba
# %cp celeba_identity.txt /data_celeba
# %cp img_align_celeba.zip /data_celeba
# %cd /data_celeba
# !unzip -q img_align_celeba.zip
# %ls

In [6]:
import os
import glob
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision import models
from torchvision.utils import make_grid
from PIL import Image
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import euclidean_distances
from collections import defaultdict
%matplotlib inline

In [7]:
# Apply seaborn's default theme to every matplotlib figure in the notebook.
sns.set()
# Report which accelerator (if any) this kernel can use.
print(f'Cuda device: {torch.cuda.get_device_name(0)}'
      if torch.cuda.is_available() else 'Cuda unavailable')

Cuda unavailable


## Hyperparameters

In [8]:
# Training hyperparameters.
num_epochs = 10          # passes over the identity dataset
batch_size = 32          # NOTE(review): unused — the DataLoader below batches by `num_identity`
learning_rate = 0.001    # Adam step size
train_test_split = 0.9   # fraction of identities (in id order) used for training; the rest are held out

# Seed every RNG the notebook draws from (identity/image sampling via `random`,
# weight init and DataLoader shuffling via torch, plus numpy) so runs are reproducible.
random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

## Dataset

In [9]:
class SimilarFaceDatasetOnline(Dataset):
    """CelebA training set for online triplet mining: one item per identity.

    ``__getitem__`` returns a tensor of shape
    ``(num_image_per_identity, 3, H, W)``. It holds ``num_image_per_identity``
    distinct images of one identity, chosen at random on every access.

    Only the first ``split`` fraction of identities (ordered by identity id) is
    used, so the remaining identities stay unseen for evaluation. Identities
    with fewer than ``num_image_per_identity`` images are dropped, because
    ``random.sample`` cannot draw that many from them.

    Args:
        num_image_per_identity: images returned per identity.
        identity_file: whitespace-separated ``<filename> <identity_id>`` lines.
        image_dir: directory holding the image files named in ``identity_file``.
        split: training fraction of identities. ``None`` falls back to the
            notebook-level ``train_test_split`` hyperparameter.
    """

    def __init__(self, num_image_per_identity=4, identity_file='celeba_identity.txt',
                 image_dir='img_align_celeba', split=None):
        if split is None:
            split = train_test_split
        self.identity_filenames_list = self._load_identity_groups(
            identity_file, split, num_image_per_identity)
        self.transform = transforms.ToTensor()
        self.num_image_per_identity = num_image_per_identity
        self.image_dir = image_dir

    @staticmethod
    def _load_identity_groups(identity_file, split, min_images):
        """Parse ``identity_file`` into per-identity filename lists.

        Returns the lists ordered by identity id, truncated to the first
        ``int(num_identities * split)`` identities, then filtered to those with
        at least ``min_images`` files. Blank or malformed lines are skipped.
        """
        identity_filenames = defaultdict(list)
        with open(identity_file) as f:
            for line in f:
                fields = line.split()
                if len(fields) < 2:
                    continue  # tolerate blank / trailing lines
                identity_filenames[int(fields[1])].append(fields[0])
        groups = [files for _, files in sorted(identity_filenames.items())]
        # Split before filtering so the train/held-out boundary depends only on ids.
        groups = groups[:int(len(groups) * split)]
        return [files for files in groups if len(files) >= min_images]

    def __len__(self):
        return len(self.identity_filenames_list)

    def __getitem__(self, index):
        filenames = random.sample(self.identity_filenames_list[index], self.num_image_per_identity)
        images = []
        for name in filenames:
            # The context manager closes the file promptly. convert('RGB')
            # guarantees 3 channels, so torch.stack sees uniform shapes.
            with Image.open(os.path.join(self.image_dir, name)) as img:
                images.append(self.transform(img.convert('RGB')))
        return torch.stack(images)

## Network

In [10]:
def conv3x3(in_channels, out_channels, stride=1):
    """Return a bias-free 3x3 convolution with padding 1.

    Padding 1 preserves height/width for ``stride=1`` and roughly halves them
    for ``stride=2``. The bias is omitted because every use in this notebook is
    immediately followed by a BatchNorm layer, which supplies its own shift.
    """
    layer_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_channels, out_channels, **layer_kwargs)

class ResidualBlock(nn.Module):
    """Basic residual block: conv-BN-ReLU-conv-BN, add shortcut, then ReLU.

    Args:
        in_channels: channels of the block input.
        out_channels: channels produced by both convolutions.
        stride: stride of the first convolution (2 halves the spatial size).
        downsample: optional module applied to the input to build the
            shortcut. It is needed whenever stride != 1 or
            in_channels != out_channels, so the shortcut matches the main
            path's shape for the addition.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # Attribute names form the state_dict keys; keep them stable for checkpoints.
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        return F.relu(out + shortcut)
    
class ResNet(nn.Module):
    """Compact four-stage ResNet mapping an RGB image batch to a feature vector.

    The stem is a 3x3 conv (3 -> 16 channels) + BN + ReLU. Stages ``layer1`` to
    ``layer4`` have 16/32/64/128 channels, and stages 2-4 begin with stride 2.
    A global average pool collapses the spatial dimensions before the final
    linear layer, so the output size does not depend on input height/width.

    Args:
        block: residual block class, called as
            ``block(in_channels, out_channels, stride, downsample)`` and
            ``block(channels, channels)``.
        layers: number of blocks per stage; entries 0-3 are used.
        out_channels: length of the output vector (the embedding size here).
    """

    def __init__(self, block, layers, out_channels=10):
        super().__init__()
        # Running channel count, read and advanced by _make_layer.
        self.in_channels = 16
        self.conv = conv3x3(3, 16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        # Build order matters: each stage advances self.in_channels for the next.
        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 128, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, out_channels)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` blocks and advance ``self.in_channels``.

        Only the first block may change stride or channel count. It gets a
        conv+BN projection shortcut when its output shape differs from its input.
        """
        if stride == 1 and self.in_channels == out_channels:
            downsample = None
        else:
            downsample = nn.Sequential(
                conv3x3(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels))
        stage = [block(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = out_channels
        stage += [block(out_channels, out_channels) for _ in range(blocks - 1)]
        return nn.Sequential(*stage)

    def forward(self, x):
        out = self.relu(self.bn(self.conv(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        pooled = self.avgpool(out)
        # (N, C, 1, 1) -> (N, C) before the linear projection.
        return self.fc(pooled.view(pooled.size(0), -1))
    
class SiameseNetwork(nn.Module):
    """Triplet wrapper around one shared ResNet embedding network.

    Anchor, positive and negative inputs all pass through the *same*
    ``embedding_net``, so they share weights and land in one embedding space.
    """

    def __init__(self):
        super().__init__()
        # [2, 2, 2, 2] residual blocks per stage, 10-dimensional output embedding.
        self.embedding_net = ResNet(ResidualBlock, [2, 2, 2, 2], out_channels=10)

    def get_embedding(self, x):
        """Embed one input batch ``x`` with the shared network."""
        return self.embedding_net(x)

    def forward(self, a, p, n):
        """Return ``(embed(a), embed(p), embed(n))`` for anchor/positive/negative batches."""
        return tuple(self.embedding_net(batch) for batch in (a, p, n))

## Preparation

In [11]:
# Each training batch holds `num_identity` identities x `num_image_per_identity` images.
num_identity = 10  # number of identities per batch
num_image_per_identity = 4
online_train_dataset = SimilarFaceDatasetOnline(num_image_per_identity=num_image_per_identity)
# drop_last=True is required. The training loop reshapes embeddings to
# (num_identity, num_image_per_identity, -1) and indexes identities
# 0..num_identity-1. A smaller final batch would let that reshape succeed
# while silently mixing identities and embedding dimensions.
online_train_loader = DataLoader(online_train_dataset, batch_size=num_identity,
                                 shuffle=True, drop_last=True)

# Every triplet inside one batch, as
# (anchor_identity, negative_identity, anchor_idx, positive_idx, negative_idx):
# the positive is another image of the anchor's identity, and the negative is
# any image of a different identity.
triplet_indices = [
    (anchor_identity, negative_identity, anchor_idx, positive_idx, negative_idx)
    for anchor_identity in range(num_identity)
    for anchor_idx in range(num_image_per_identity)
    for positive_idx in range(num_image_per_identity)
    if positive_idx != anchor_idx
    for negative_identity in range(num_identity)
    if negative_identity != anchor_identity
    for negative_idx in range(num_image_per_identity)
]
assert len(triplet_indices) == (num_identity * num_image_per_identity * (num_image_per_identity - 1)
                                * (num_identity - 1) * num_image_per_identity)
print(f'{len(triplet_indices)} triplets per batch')

4320 triplets per batch


In [12]:
# Run on the first GPU when one is present, otherwise on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
model = SiameseNetwork().to(device)

# Loss and optimizer: triplet margin loss (margin 1.0), optimized with Adam.
criterion = torch.nn.TripletMarginLoss(margin=1.0)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

## Training

In [13]:
# # Load checkpoint
# checkpoint = torch.load('checkpoint_online', map_location=torch.device('cpu'))
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
def hard_triplet_loss(anchor, positive, negative, margin=1.0):
    """Triplet margin loss averaged over only the *hard* triplets.

    Here a triplet is "hard" when it still violates the margin, i.e. its
    per-triplet loss ``max(d(a, p) - d(a, n) + margin, 0)`` is positive.
    ``torch.nn.TripletMarginLoss`` averages over all triplets. Once most
    triplets are satisfied, their zeros dilute the loss. Averaging only over
    the violating triplets keeps the signal on the ones that still matter.

    Args:
        anchor, positive, negative: embedding tensors of identical shape
            ``(num_triplets, dim)``; row i of each forms triplet i.
        margin: required gap between the negative and positive distances.
            Defaults to 1.0, the margin of the notebook's TripletMarginLoss.

    Returns:
        A scalar tensor. When no triplet violates the margin it equals 0 but
        stays attached to the autograd graph, so ``loss.backward()`` still works.
    """
    # Same distance as TripletMarginLoss's default (p=2, eps=1e-6).
    positive_dist = F.pairwise_distance(anchor, positive)
    negative_dist = F.pairwise_distance(anchor, negative)
    losses = F.relu(positive_dist - negative_dist + margin)
    hard = losses > 0
    if not hard.any():
        return losses.sum()
    return losses[hard].mean()

In [None]:
# Index tensors for every triplet, built once: they are the same for every batch.
triplet_index = torch.tensor(triplet_indices, dtype=torch.long, device=device)
anchor_ids, negative_ids, anchor_imgs, positive_imgs, negative_imgs = triplet_index.unbind(dim=1)

# BatchNorm must use batch statistics while training; later cells call model.eval().
model.train()
for epoch in range(num_epochs):
    for step, batch in enumerate(online_train_loader):
        # The triplet indices assume exactly `num_identity` identities per batch.
        if batch.size(0) != num_identity:
            continue
        # (identities, images per identity, C, H, W) -> (identities * images, C, H, W),
        # moved to the same device as the model.
        batch = batch.flatten(0, 1).to(device)
        embeddings = model.get_embedding(batch)
        embeddings = embeddings.reshape((num_identity, num_image_per_identity, -1))
        # Advanced indexing gathers all triplets at once, in triplet_indices order.
        anchor_embeddings = embeddings[anchor_ids, anchor_imgs]
        positive_embeddings = embeddings[anchor_ids, positive_imgs]
        negative_embeddings = embeddings[negative_ids, negative_imgs]

        loss = criterion(anchor_embeddings, positive_embeddings, negative_embeddings)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f'epoch {epoch} step {step}: loss {loss.item():.4f}')

## Test on custom data

In [None]:
model.eval()
count = 22  # identity folders ./data/c1 .. ./data/c22
# Match the training input size (3 x 218 x 178). Converting to RGB first
# gives 3 channels for RGBA, greyscale and palette files alike.
transform = transforms.Compose([
    lambda img: img.convert('RGB'),
    transforms.Resize((218, 178)),
    transforms.ToTensor(),
])

image_tensors = []
labels = []
for i in range(count):
    pathname = os.path.join('./data', f'c{i+1}')
    # sorted() makes the image order (and the indices used below) reproducible.
    for filename in sorted(glob.glob(f'{pathname}/*')):
        with Image.open(filename) as img:
            image_tensors.append(transform(img))
        labels.append(i)
    print(f'{i} ', end='')
assert image_tensors, 'no images found under ./data/c*'
images = torch.stack(image_tensors)

with torch.no_grad():
    embeddings = model.get_embedding(images.to(device)).cpu().numpy()

pca = PCA(n_components=2)
pca_result = pca.fit_transform(embeddings)
df = pd.DataFrame()
df['pca-1'] = pca_result[:, 0]
df['pca-2'] = pca_result[:, 1]
df['label'] = labels  # folder index of each image, whatever the per-folder count
plt.figure(figsize=(16,10))
sns.scatterplot(
    x='pca-1', 
    y='pca-2',
    hue='label',
    palette=sns.color_palette("hls", count),
    data=df
)
print(f'\nExplained variance ratio: {pca.explained_variance_ratio_}')

In [None]:
target_face_index = 200
k = 10
# Squared Euclidean distance from the target face's embedding to every face
# (the target itself is included, at distance 0).
l2_distances = np.square(embeddings - embeddings[target_face_index]).sum(axis=1)
closest_k_face_indices = l2_distances.argsort()[:k]
closest_k_faces = images[closest_k_face_indices]
print(l2_distances[closest_k_face_indices])

# Show the query face. `images` is (N, C, H, W), so move channels last for imshow.
plt.figure(dpi=100)
plt.imshow(images[target_face_index].permute(1, 2, 0))

In [None]:
# Tile the k nearest faces (the query among them, at distance 0) in rows of 5.
plt.figure(figsize=(16, 10))
neighbour_grid = make_grid(closest_k_faces, nrow=5)
plt.imshow(neighbour_grid.permute(1, 2, 0))

## Test on celeba data

In [None]:
model.eval()
count = 22                # identities to plot
images_per_identity = 10  # images plotted per identity

# Group CelebA filenames by identity id (skipping blank/malformed lines).
identity_filenames_dict = defaultdict(list)
with open('celeba_identity.txt') as f:
    for line in f:
        fields = line.split()
        if len(fields) >= 2:
            identity_filenames_dict[int(fields[1])].append(fields[0])

# Evaluate on identities the model never trained on. The training dataset kept
# the first `train_test_split` fraction of identities in id order, so take the rest.
sorted_identities = sorted(identity_filenames_dict)
held_out_identities = sorted_identities[int(len(sorted_identities) * train_test_split):]
candidate_filename_lists = [
    identity_filenames_dict[identity][:images_per_identity]
    for identity in held_out_identities
    if len(identity_filenames_dict[identity]) >= images_per_identity
]
identity_filenames_list = random.sample(candidate_filename_lists, count)
transform = transforms.ToTensor()

# Load images from the *sampled* identities, in sampling order, so the labels below line up.
image_tensors = []
for filenames in identity_filenames_list:
    for filename in filenames:
        with Image.open(f'./img_align_celeba/{filename}') as img:
            image_tensors.append(transform(img.convert('RGB')))
images = torch.stack(image_tensors)

with torch.no_grad():
    embeddings = model.get_embedding(images.to(device)).cpu().numpy()

pca = PCA(n_components=2)
pca_result = pca.fit_transform(embeddings)
df = pd.DataFrame()
df['pca-1'] = pca_result[:, 0]
df['pca-2'] = pca_result[:, 1]
df['label'] = [i for i in range(count) for _ in range(images_per_identity)]
plt.figure(figsize=(16,10))
sns.scatterplot(
    x='pca-1', 
    y='pca-2',
    hue='label',
    palette=sns.color_palette("hls", count),
    data=df
)
print(f'\nExplained variance ratio: {pca.explained_variance_ratio_}')

In [None]:
target_face_index = 144
k = 3
# Squared Euclidean distance from the query embedding to all embeddings;
# the query itself is included (distance 0).
l2_distances = np.square(embeddings - embeddings[target_face_index]).sum(axis=1)
closest_k_face_indices = l2_distances.argsort()[:k]
closest_k_faces = images[closest_k_face_indices]
print(l2_distances[closest_k_face_indices])

# Display the query image channels-last, as imshow expects (H, W, C).
plt.figure(dpi=100)
plt.imshow(images[target_face_index].permute(1, 2, 0))

In [None]:
# Tile the k nearest CelebA faces in rows of 5 (channels moved last for imshow).
plt.figure(figsize=(16, 10))
celeba_grid = make_grid(closest_k_faces, nrow=5)
plt.imshow(celeba_grid.permute(1, 2, 0))