In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [194]:
############## sys imports #############
import os
import sys
import time
import copy
import argparse
import datetime
############## basic stats imports #############
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
############## pytorch imports #############
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torchvision import transforms, utils, models
from torch.utils.data import Dataset, DataLoader
############## custom imports #############
from dataloader import FaceScrubDataset, TripletFaceScrub, SiameseFaceScrub
from dataloader import FaceScrubBalancedBatchSampler

from networks import *
from losses import OnlineTripletLoss

from utils import save_checkpoint, save_hyperparams, AverageMeter, HardestNegativeTripletSelector, RandomNegativeTripletSelector, SemihardNegativeTripletSelector

In [24]:
DATA_PATH = './new_data/'
TRAIN_PATH = os.path.join(DATA_PATH, 'train_full_with_ids.txt')
VALID_PATH = os.path.join(DATA_PATH, 'val_full_with_ids.txt')
TEST_PATH = os.path.join(DATA_PATH, 'test_full_with_ids.txt')
WEIGHTS_PATH = './model_weights/weights_14.pth'

In [25]:
batch_size = 8
input_size = 299
output_dim = 128
learning_rate = 1e2
num_epochs = 1
start_epoch = 0

triplet_margin = 1.  # margin
triplet_p = 2  # norm degree for distance calculation

resume_training = True
workers = 4
use_cuda = False

In [26]:
cuda = False
if use_cuda and torch.cuda.is_available():
    device = torch.device("cuda")
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    cuda = True
    cudnn.benchmark = True
else:
    device = torch.device("cpu")

print('Device set: {}'.format(device))
print('Training set path: {}'.format(TRAIN_PATH))
print('Training set Path exists: {}'.format(os.path.isfile(TRAIN_PATH)))

Device set: cpu
Training set path: ./new_data/train_full_with_ids.txt
Training set Path exists: True


In [69]:
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]),
    'val': transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
}


train_df = FaceScrubDataset(
    txt_file=TRAIN_PATH, root_dir=DATA_PATH, transform=data_transforms['train'])

val_df = FaceScrubDataset(
    txt_file=VALID_PATH, root_dir=DATA_PATH, transform=data_transforms['val'])

train_batch_sampler = FaceScrubBalancedBatchSampler(train_df, n_classes=3, n_samples=5)
val_batch_sampler = FaceScrubBalancedBatchSampler(val_df, n_classes=3, n_samples=5)

print('Train data loaded from {}. Length: {}'.format(TRAIN_PATH, len(train_df)))
print('Validation data loaded from {}. Length: {}'.format(VALID_PATH, len(val_df)))

triplet_train_df = TripletFaceScrub(train_df, train=True)
print('Train data converted to triplet form. Length: {}'.format(len(triplet_train_df)))

triplet_val_df=TripletFaceScrub(val_df, train=False)
print('Validation data converted to triplet form. Length: {}'.format(
    len(triplet_val_df)))

online_train_loader = torch.utils.data.DataLoader(train_df, batch_sampler=train_batch_sampler, pin_memory=True, num_workers=workers)

print('Train loader created. Length of train loader: {}'.format(
    len(online_train_loader)))

online_val_loader = torch.utils.data.DataLoader(val_df, batch_sampler=val_batch_sampler, pin_memory=True, num_workers=workers)
print('Val triplet loader created. Length of val load: {}'.format(
    len(online_val_loader)))


Train data loaded from ./new_data/train_full_with_ids.txt. Length: 55029
Validation data loaded from ./new_data/val_full_with_ids.txt. Length: 5888
Train data converted to triplet form. Length: 55029
Validation data converted to triplet form. Length: 5888
Train loader created. Length of train loader: 3668
Val triplet loader created. Length of val load: 392


In [28]:
triplet_train_df = TripletFaceScrub(train_df, train=True)
print('Train data converted to triplet form. Length: {}'.format(len(triplet_train_df)))

triplet_val_df=TripletFaceScrub(val_df, train=False)
print('Validation data converted to triplet form. Length: {}'.format(
    len(triplet_val_df)))

train_tripletloader=torch.utils.data.DataLoader(
    triplet_train_df, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=workers)
print('Train loader created. Length of train loader: {}'.format(
    len(train_tripletloader)))

val_tripletloader=torch.utils.data.DataLoader(
    triplet_val_df, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=workers)
print('Val triplet loader created. Length of val load: {}'.format(
    len(val_tripletloader)))

Train data converted to triplet form. Length: 55029
Validation data converted to triplet form. Length: 5888
Train loader created. Length of train loader: 6879
Val triplet loader created. Length of val load: 736


In [29]:
inception=models.inception_v3(pretrained=True)
inception.aux_logits=False
num_ftrs=inception.fc.in_features
inception.fc=nn.Linear(num_ftrs, output_dim)

tripletinception=TripletNet(inception)

params=sum(p.numel() for p in tripletinception.parameters() if p.requires_grad)
print('Number of params in triplet inception: {}'.format(params))

############## set up for training #############
print('Triplet margin: {}. Norm degree: {}.'.format(triplet_margin, triplet_p))
criterion=nn.TripletMarginLoss(margin=triplet_margin, p=triplet_p)

optimizer=optim.Adam(tripletinception.parameters(), lr=learning_rate)
scheduler=lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

Number of params in triplet inception: 25374536
Triplet margin: 1.0. Norm degree: 2.


In [67]:
if resume_training:
    resume_weights=WEIGHTS_PATH
    if cuda:
        checkpoint=torch.load(resume_weights)
    else:
        # Load GPU model on CPU
        checkpoint=torch.load(resume_weights,
                                map_location=lambda storage,
                                loc: storage)

    start_epoch=checkpoint['epoch']
    tripletinception.load_state_dict(checkpoint['state_dict'])
#     inception.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    best_loss = checkpoint['best_loss']
    # scheduler.load_state_dict(checkpoint['scheduler'])
    print("=> loaded checkpoint '{}' (trained for {} epochs)".format(
        resume_weights, checkpoint['epoch']))
    for epoch in range(0, start_epoch):
        scheduler.step()

=> loaded checkpoint './model_weights/weights_14.pth' (trained for 14 epochs)


In [32]:
if cuda:
    tripletinception.cuda()
    print('Sent model to gpu {}'.format(
        next(tripletinception.parameters()).is_cuda))

In [33]:
def train(train_loader, model, criterion, optimizer, epoch, device):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for batch_idx, (imgs, _) in enumerate(train_loader):
        data_time.update(time.time() - end)
        imgs = [img.to(device) for img in imgs]

        embed_anchor, embed_pos, embed_neg=model(imgs[0], imgs[1], imgs[2])
        loss = criterion(embed_anchor, embed_pos, embed_neg)

        losses.update(loss.item(), imgs[0].size(0))


        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                   epoch, batch_idx, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses))
    return losses.avg

def validate(val_loader, model, criterion, device):
    batch_time = AverageMeter()
    losses = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (imgs, _) in enumerate(val_loader):
            imgs = [img.to(device) for img in imgs]

            embed_anchor, embed_pos, embed_neg=model(imgs[0], imgs[1], imgs[2])

            loss = criterion(embed_anchor, embed_pos, embed_neg)

            losses.update(loss.item(), imgs[0].size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                       i, len(val_loader), batch_time=batch_time, loss=losses))

    return losses.avg

In [152]:
inception.eval()

with torch.no_grad():
    end = time.time()
    for i, (imgs, val_labels) in enumerate(online_train_loader):
        imgs = imgs.to(device)
        
        embeddings = inception(imgs)
        break

In [74]:
imgs.shape

torch.Size([15, 3, 299, 299])

In [75]:
embeddings.shape

torch.Size([15, 128])

In [76]:
embeddings.mm(torch.t(embeddings))

tensor([[ 158.8369,  159.4241,  161.4637,  143.5385,  159.2944,  143.5399,
          143.5199,  143.5110,  143.5174,  143.5208,  161.3144,  161.5613,
          160.6596,  161.2810,  161.4904],
        [ 159.4241,  160.0192,  162.0703,  144.0085,  159.8927,  144.0099,
          143.9905,  143.9819,  143.9880,  143.9913,  161.9198,  162.1734,
          161.2607,  161.8925,  162.0992],
        [ 161.4637,  162.0703,  164.1788,  145.6578,  161.9364,  145.6593,
          145.6387,  145.6294,  145.6360,  145.6395,  164.0244,  164.2792,
          163.3476,  163.9897,  164.2062],
        [ 143.5385,  144.0085,  145.6578,  131.1919,  143.8998,  131.1931,
          131.1764,  131.1688,  131.1742,  131.1771,  145.5359,  145.7310,
          145.0078,  145.5033,  145.6759],
        [ 159.2944,  159.8927,  161.9364,  143.8998,  159.7709,  143.9010,
          143.8826,  143.8744,  143.8802,  143.8834,  161.7859,  162.0446,
          161.1284,  161.7657,  161.9671],
        [ 143.5399,  144.0099,  145

In [79]:
vectors = embeddings

In [57]:
pdist = nn.PairwiseDistance(p=2)

In [102]:
vectors[2][127].shape

tensor(-0.6377)

In [83]:
torch.t(vectors).shape

torch.Size([128, 15])

In [80]:
distance_matrix = -2 * vectors.mm(torch.t(vectors)) + vectors.pow(2).sum(dim=1).view(1, -1) + vectors.pow(2).sum(
        dim=1).view(-1, 1)

In [119]:
vectors.sum(dim=0)

torch.Size([128])

In [115]:
vectors.pow(2).sum(dim=1).view(1, -1)

tensor([[ 158.8369,  160.0193,  164.1788,  131.1919,  159.7709,  131.1943,
          131.1611,  131.1460,  131.1568,  131.1624,  163.8705,  164.3863,
          162.5248,  163.8117,  164.2348]])

In [117]:
vectors.mm(torch.t(vectors))

tensor([[ 158.8369,  159.4241,  161.4637,  143.5385,  159.2944,  143.5399,
          143.5199,  143.5110,  143.5174,  143.5208,  161.3144,  161.5613,
          160.6596,  161.2810,  161.4904],
        [ 159.4241,  160.0192,  162.0703,  144.0085,  159.8927,  144.0099,
          143.9905,  143.9819,  143.9880,  143.9913,  161.9198,  162.1734,
          161.2607,  161.8925,  162.0992],
        [ 161.4637,  162.0703,  164.1788,  145.6578,  161.9364,  145.6593,
          145.6387,  145.6294,  145.6360,  145.6395,  164.0244,  164.2792,
          163.3476,  163.9897,  164.2062],
        [ 143.5385,  144.0085,  145.6578,  131.1919,  143.8998,  131.1931,
          131.1764,  131.1688,  131.1742,  131.1771,  145.5359,  145.7310,
          145.0078,  145.5033,  145.6759],
        [ 159.2944,  159.8927,  161.9364,  143.8998,  159.7709,  143.9010,
          143.8826,  143.8744,  143.8802,  143.8834,  161.7859,  162.0446,
          161.1284,  161.7657,  161.9671],
        [ 143.5399,  144.0099,  145

In [112]:
distance_matrix

tensor([[-0.0000,  0.0081,  0.0883,  2.9518,  0.0190,  2.9514,  2.9581,
          2.9608,  2.9590,  2.9578,  0.0787,  0.1007,  0.0426,  0.0866,
          0.0910],
        [ 0.0081,  0.0000,  0.0574,  3.1940,  0.0048,  3.1938,  3.1993,
          3.2015,  3.2000,  3.1990,  0.0501,  0.0587,  0.0227,  0.0460,
          0.0557],
        [ 0.0883,  0.0574,  0.0000,  4.0550,  0.0769,  4.0545,  4.0624,
          4.0660,  4.0635,  4.0622,  0.0006,  0.0068,  0.0084,  0.0110,
          0.0012],
        [ 2.9518,  3.1940,  4.0550, -0.0001,  3.1631,  0.0000,  0.0001,
          0.0003,  0.0002,  0.0001,  3.9905,  4.1162,  3.7010,  3.9969,
          4.0749],
        [ 0.0190,  0.0048,  0.0769,  3.1631, -0.0001,  3.1631,  3.1669,
          3.1682,  3.1673,  3.1665,  0.0697,  0.0681,  0.0390,  0.0511,
          0.0715],
        [ 2.9514,  3.1938,  4.0545,  0.0000,  3.1631,  0.0000,  0.0001,
          0.0003,  0.0002,  0.0002,  3.9900,  4.1159,  3.7006,  3.9966,
          4.0744],
        [ 2.9581,  3.1

In [154]:
for label in set(val_labels['person_id'].numpy()):
    print(label)
    print(val_labels['person_id'].numpy() == label)

520
[False False False False False False False False False False  True  True
  True  True  True]
155
[False False False False False  True  True  True  True  True False False
 False False False]
485
[ True  True  True  True  True False False False False False False False
 False False False]


In [155]:
from itertools import combinations

In [162]:
def random_hard_negative(loss_values):
    hard_negatives = np.where(loss_values > 0)[0]
    return np.random.choice(hard_negatives) if len(hard_negatives) > 0 else None

In [166]:
triplets = []
labels = val_labels['person_id']
negative_selection_fn = random_hard_negative
for label in set(labels):
    label_mask = (labels == label)
    print('label mask {}'.format(label_mask))
    label_indices = np.where(label_mask)[0]
    print('label indices {}'.format(label_indices))
    if len(label_indices) < 2:
        continue
    negative_indices = np.where(np.logical_not(label_mask))[0]
    anchor_positives = list(combinations(label_indices, 2))  # All anchor-positive pairs
    anchor_positives = np.array(anchor_positives)
    
    ap_distances = distance_matrix[anchor_positives[:, 0], anchor_positives[:, 1]]
    for anchor_positive, ap_distance in zip(anchor_positives, ap_distances):
        loss_values = ap_distance - distance_matrix[torch.LongTensor(np.array([anchor_positive[0]])), torch.LongTensor(negative_indices)] + triplet_margin
        loss_values = loss_values.data.cpu().numpy()
        hard_negative = negative_selection_fn(loss_values)
#         print(hard_negative)
        if hard_negative is not None:
            hard_negative = negative_indices[hard_negative]
            triplets.append([anchor_positive[0], anchor_positive[1], hard_negative])
            
if len(triplets) == 0:
        triplets.append([anchor_positive[0], anchor_positive[1], negative_indices[0]])

triplets = np.array(triplets)

label mask tensor([ 1,  1,  1,  1,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0], dtype=torch.uint8)
label indices [0 1 2 3 4]
label mask tensor([ 0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  0,  0,  0,  0,
         0], dtype=torch.uint8)
label indices [5 6 7 8 9]
label mask tensor([ 0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  0,  0,  0,  0,
         0], dtype=torch.uint8)
label indices [5 6 7 8 9]
label mask tensor([ 0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  0,  0,  0,  0,
         0], dtype=torch.uint8)
label indices [5 6 7 8 9]
label mask tensor([ 1,  1,  1,  1,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0], dtype=torch.uint8)
label indices [0 1 2 3 4]
label mask tensor([ 0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  0,  0,  0,  0,
         0], dtype=torch.uint8)
label indices [5 6 7 8 9]
label mask tensor([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,
         1], dtype=torch.uint8)
label indices [10 11 12 13 14]
label mask tensor([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  

In [172]:
triplets.shape

(150, 3)

In [175]:
labels

tensor([ 485,  485,  485,  485,  485,  155,  155,  155,  155,  155,
         520,  520,  520,  520,  520])

In [197]:
triplets

array([[ 0,  1, 10],
       [ 0,  2, 12],
       [ 0,  3,  6],
       [ 0,  4, 13],
       [ 1,  2, 10],
       [ 1,  3, 11],
       [ 1,  4, 13],
       [ 2,  3,  6],
       [ 2,  4, 10],
       [ 3,  4, 11],
       [ 5,  6,  3],
       [ 5,  7,  3],
       [ 5,  8,  3],
       [ 5,  9,  3],
       [ 6,  7,  3],
       [ 6,  8,  3],
       [ 6,  9,  3],
       [ 7,  8,  3],
       [ 7,  9,  3],
       [ 8,  9,  3],
       [ 5,  6,  3],
       [ 5,  7,  3],
       [ 5,  8,  3],
       [ 5,  9,  3],
       [ 6,  7,  3],
       [ 6,  8,  3],
       [ 6,  9,  3],
       [ 7,  8,  3],
       [ 7,  9,  3],
       [ 8,  9,  3],
       [ 5,  6,  3],
       [ 5,  7,  3],
       [ 5,  8,  3],
       [ 5,  9,  3],
       [ 6,  7,  3],
       [ 6,  8,  3],
       [ 6,  9,  3],
       [ 7,  8,  3],
       [ 7,  9,  3],
       [ 8,  9,  3],
       [ 0,  1, 14],
       [ 0,  2, 10],
       [ 0,  3,  7],
       [ 0,  4, 13],
       [ 1,  2, 10],
       [ 1,  3,  9],
       [ 1,  4, 14],
       [ 2,  

In [198]:
anchor = embeddings[triplets[:, 0]]
positive = embeddings[triplets[:, 1]]
negative = embeddings[triplets[:, 2]]

In [200]:
anchor.shape, positive.shape, negative.shape

(torch.Size([150, 128]), torch.Size([150, 128]), torch.Size([150, 128]))

In [203]:
criterion = nn.TripletMarginLoss(margin=1., p=2)

In [204]:
criterion(anchor, positive, negative)

tensor(0.5910)

In [217]:
criterion = OnlineTripletLoss(HardestNegativeTripletSelector(margin=triplet_margin), margin=triplet_margin)

In [218]:
loss = criterion(embeddings, labels)

tensor([[ 6.0362e-01, -2.0809e-01,  4.4231e-01,  ...,  6.3852e-02,
         -2.0972e-01, -6.2906e-01],
        [ 6.0362e-01, -2.0809e-01,  4.4231e-01,  ...,  6.3852e-02,
         -2.0972e-01, -6.2906e-01],
        [ 6.0362e-01, -2.0809e-01,  4.4231e-01,  ...,  6.3852e-02,
         -2.0972e-01, -6.2906e-01],
        ...,
        [ 5.9847e-01, -2.4964e-01,  4.4157e-01,  ...,  7.0519e-02,
         -2.1295e-01, -7.0095e-01],
        [ 5.9847e-01, -2.4964e-01,  4.4157e-01,  ...,  7.0519e-02,
         -2.1295e-01, -7.0095e-01],
        [ 6.0673e-01, -2.1174e-01,  4.4199e-01,  ...,  5.9695e-02,
         -2.0617e-01, -6.2783e-01]]) tensor([[ 6.0826e-01, -2.4533e-01,  4.4131e-01,  ...,  5.8933e-02,
         -2.0551e-01, -6.6965e-01],
        [ 6.0947e-01, -2.4721e-01,  4.4165e-01,  ...,  5.7329e-02,
         -2.0396e-01, -6.7038e-01],
        [ 6.0830e-01, -2.2477e-01,  4.4207e-01,  ...,  5.7499e-02,
         -2.0308e-01, -6.4336e-01],
        ...,
        [ 6.0673e-01, -2.1174e-01,  4.4199e-01

In [219]:
loss

tensor(0.8875)