In [212]:
import os
import torch
import itertools
import torchvision
import numpy as np
from PIL import Image
%pylab inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from torchvision import transforms
import torch.nn as nn
from torch.autograd import Variable
import torchvision.models as models
from torch.autograd import Function

Populating the interactive namespace from numpy and matplotlib


In [35]:
data_dir = 'lfw_modified'

In [34]:
transform = transforms.Compose([transforms.Resize((224, 224)), 
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5),
                                                     (0.5, 0.5, 0.5))
                                ])

In [293]:
class DataLoader():
    def __init__(self, dir_path, transform):
        self.images_dict = {}
        self.id2image = {}
        self.labels = None
        self.dir_path = dir_path
        self.transform = transform
        self.load_images()
    
    def load_images(self):
        # returns labels/names list
        self.labels = os.listdir(self.dir_path)
        for label in self.labels:
            path = os.path.join(self.dir_path, label)
            images = os.listdir(path)
            self.images_dict[label] = images
            for image_id in images:
                img_path = os.path.join(path, image_id)
                self.id2image[image_id] = self.transform(Image.open(img_path))
    
    def gen_data(self):
        labels = []
        image_ids = []
        for label, images in self.images_dict.items():
            num_images = len(images)
            labels.extend([label] * num_images)
            image_ids.extend(images)
        return image_ids, labels
        
    def get_image(self, image_id):
        return self.id2image[image_id]

def shuffle_data(data, seed = 0):
    image_ids, labels = data
    shuffled_image_ids = []
    shuffled_labels = []
    num_images = len(image_ids)
    torch.manual_seed(seed)
    perm = list(torch.randperm(num_images))
    for i in range(num_images):
        shuffled_image_ids.append(image_ids[perm[i]])
        shuffled_labels.append(labels[perm[i]])
    return shuffled_image_ids, shuffled_labels

def make_minibatches(data, minibatch_size = 16,  seed = 0):
    X, Y = data
    m = len(X)
    minibatches = []

    shuffled_X, shuffled_Y = shuffle_data(data, seed = seed)

    num_complete_minibatches = math.floor(m/minibatch_size)
    for k in range(0, num_complete_minibatches):
        minibatch_X = shuffled_X[k * minibatch_size : k * minibatch_size + minibatch_size]
        minibatch_Y = shuffled_Y[k * minibatch_size : k * minibatch_size + minibatch_size]
        minibatches.append((minibatch_X, minibatch_Y))

    rem_size = m - num_complete_minibatches * minibatch_size
    if m % minibatch_size != 0:
        minibatch_X = shuffled_X[num_complete_minibatches * minibatch_size : m]
        minibatch_Y = shuffled_Y[num_complete_minibatches * minibatch_size : m]
        minibatches.append((minibatch_X, minibatch_Y))

    return minibatches

def batch2embeddings(minibatch_X, cnn, dataloader, gpu_device):
    minibatch_size = len(minibatch_X)
    images_tensor = torch.zeros(minibatch_size, 3, 224, 224)
    id2embeds = {}
    for i in range(minibatch_size):
        x = minibatch_X[i]
        x_image = dataloader.get_image(x)
        images_tensor[i, :, :, :] = x_image
    images_tensor = Variable(images_tensor)
    if torch.cuda.is_available():
        with torch.cuda.device(gpu_device):
            images_tensor = images_tensor.cuda()
    embeds = cnn(images_tensor)
    for i in range(minibatch_size):
        x = minibatch_X[i]
        id2embeds[x] = embeds[i, :]
    return id2embeds


def gen_triplets(minibatch, id2embeds, embedding_dim, mode = 'all'):
    X, Y = minibatch
    Y_prod = itertools.product(Y, repeat=3)
    X_prod = itertools.product(X, repeat=3)
    triplet = []
    for x, y  in zip(X_prod, Y_prod):
        xa, xp, xn = x
        ya, yp, yn = y
        if (ya == yp) and (ya!=yn) and (xa!=xp):
            triplet.append((xa, xp, xn))
    
    num_triplets = len(triplet)
    anchor = torch.zeros(num_triplets, embedding_dim)
    positive = torch.zeros(num_triplets, embedding_dim)
    negative = torch.zeros(num_triplets, embedding_dim)
    for i in range(num_triplets):
        xa, xp, xn = triplet[i]
        anchor[i, :] = id2embeds[xa]
        positive[i, :] = id2embeds[xp]
        negative[i, :] = id2embeds[xn]
        
    return anchor, positive, negative

class TripletLoss(nn.Module):
    def __init__(self, alpha = 0.2):
        super(TripletLoss, self).__init__()
        self.alpha = alpha
            
    def forward(self, anchor, positive, negative):
        alpha = self.alpha
        pos_dist = anchor - positive
        pos_dist = torch.pow(pos_dist, 2).sum(dim=0)
        print(pos_dist, pos_dist.shape)
        neg_dist = anchor - negative
        neg_dist = torch.pow(neg_dist, 2).sum(dim=0)
        print(neg_dist)
        basic_loss = pos_dist - neg_dist + alpha
    #     loss = torch.max(basic_loss, torch.zeros(basic_loss.shape[0])).sum()
        loss = torch.clamp(basic_loss, min=0.0).sum()
        return loss

In [53]:
dataloader = DataLoader(data_dir, transform)
data = dataloader.gen_data()

In [45]:
class Alexnet(nn.Module):
    def __init__(self, embedding_dim=32):
        super(Alexnet, self).__init__()
        self.alexnet = models.alexnet(pretrained=True)
        in_features = self.alexnet.classifier[6].in_features
        self.linear = nn.Linear(in_features, embedding_dim)
        self.alexnet.classifier[6] = self.linear
        self.init_weights()
    
    def init_weights(self):
        self.linear.weight.data.normal_(0.0, 0.02)
        self.linear.bias.data.fill_(0)
    
    def forward(self, images):
        embed = self.alexnet(images)
        return embed

In [46]:
embedding_dim=32

In [47]:
cnn = Alexnet(embedding_dim = embedding_dim)

In [48]:
gpu_device = 0
if torch.cuda.is_available():
    with torch.cuda.device(gpu_device):
        cnn.cuda()

In [270]:
learning_rate = 1e-7
params = cnn.parameters()
optimizer = torch.optim.Adam(params, lr = learning_rate)

In [281]:
def triplet_loss(anchor, positive, negative, alpha=0.2):
    return TripletLoss(alpha)(anchor, positive, negative)

In [265]:
num_epochs = 50

In [273]:
cnn.train()

Alexnet(
  (alexnet): AlexNet(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
      (1): ReLU(inplace)
      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): ReLU(inplace)
      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU(inplace)
      (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): ReLU(inplace)
      (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace)
      (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (classifier): Sequential(
      (0): Dropout(p=0.5)
      (1): Linear(in_features=9216, out_features=4096, bias=True)
      (2): ReLU(inplace)
      (3)

In [286]:
triplet_loss(anchor, positive, negative)

tensor([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]) torch.Size([32])
tensor([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.])


tensor(6.4000)

In [287]:
minibatches = make_minibatches(data, seed = epoch)
loss = []

In [288]:
cur_minibatch = minibatches[0]

In [290]:
id2embeds = batch2embeddings(cur_minibatch[0], cnn, dataloader, gpu_device)

In [292]:
anchor, positive, negative = gen_triplets(cur_minibatch, id2embeds, embedding_dim)
anchor, positive, negative

(tensor([[-0.7229, -0.7830,  0.5068,  ..., -2.1230, -1.3338, -1.0702],
         [-0.7229, -0.7830,  0.5068,  ..., -2.1230, -1.3338, -1.0702],
         [-0.7229, -0.7830,  0.5068,  ..., -2.1230, -1.3338, -1.0702],
         ...,
         [-0.7229, -0.7830,  0.5068,  ..., -2.1230, -1.3338, -1.0702],
         [-0.7229, -0.7830,  0.5068,  ..., -2.1230, -1.3338, -1.0702],
         [-0.7229, -0.7830,  0.5068,  ..., -2.1230, -1.3338, -1.0702]]),
 tensor([[-0.7229, -0.7830,  0.5068,  ..., -2.1230, -1.3338, -1.0702],
         [-0.7229, -0.7830,  0.5068,  ..., -2.1230, -1.3338, -1.0702],
         [-0.7229, -0.7830,  0.5068,  ..., -2.1230, -1.3338, -1.0702],
         ...,
         [-0.7229, -0.7830,  0.5068,  ..., -2.1230, -1.3338, -1.0702],
         [-0.7229, -0.7830,  0.5068,  ..., -2.1230, -1.3338, -1.0702],
         [-0.7229, -0.7830,  0.5068,  ..., -2.1230, -1.3338, -1.0702]]),
 tensor([[-0.7229, -0.7830,  0.5068,  ..., -2.1230, -1.3338, -1.0702],
         [-0.7229, -0.7830,  0.5068,  ..., -2

In [289]:
l = triplet_loss(anchor, positive, negative)
loss.append(l)
l.backward()
optimizer.step()

tensor([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]) torch.Size([32])
tensor([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.])


In [271]:
for epoch in range(num_epochs):
    minibatches = make_minibatches(data, seed = epoch)
    loss = []
    for cur_minibatch in minibatches:
        id2embeds = batch2embeddings(cur_minibatch[0], cnn, dataloader, gpu_device)
        anchor, positive, negative = gen_triplets(cur_minibatch, id2embeds, embedding_dim)
        l = triplet_loss(anchor, positive, negative)
        loss.append(l)
        l.backward()
        optimizer.step()
    print(torch.mean(torch.Tensor(loss)))

tensor(54.7556)
tensor(43.8222)
tensor(53.5556)
tensor(46.5333)
tensor(47.2444)
tensor(53.1556)
tensor(44.9778)
tensor(54.4444)
tensor(55.1556)
tensor(57.5556)
tensor(53.2000)
tensor(52.8889)
tensor(51.1556)
tensor(54.0889)
tensor(45.7333)
tensor(56.8889)
tensor(48.7111)
tensor(53.4667)
tensor(48.4444)
tensor(53.3333)
tensor(51.8222)
tensor(51.1111)
tensor(50.)
tensor(59.9556)
tensor(53.1556)
tensor(51.3778)
tensor(59.9111)
tensor(51.2444)
tensor(53.5556)
tensor(55.6889)
tensor(49.7333)
tensor(46.7111)
tensor(45.7778)
tensor(46.2667)
tensor(50.3111)
tensor(48.5333)
tensor(47.5556)
tensor(55.4222)
tensor(58.9778)
tensor(43.2444)
tensor(52.9778)
tensor(50.7111)
tensor(46.0889)
tensor(54.4889)
tensor(49.3778)
tensor(48.8889)
tensor(55.6000)
tensor(48.6222)
tensor(50.8444)
tensor(55.6000)


In [297]:
embeds = torch.randn((10, 128))
embed = torch.randn(128)
dis = torch.pow(embeds - embed, 2)

In [298]:
dis = dis.sum(dim=1)

In [306]:
a = torch.argmin(dis)

In [312]:
print(type(a.tolist()))

<class 'int'>
