In [1]:
from model.resnet import Resnet18Triplet
from validation import evaluate_lfw
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torch.nn.modules.distance import PairwiseDistance
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler 
from torchvision import datasets, transforms

from Datasets.LFWDataset import LFWDataset

from sklearn.metrics import auc
from sklearn.model_selection import KFold
from scipy import interpolate

import numpy as np
import random
import copy
import os
import multiprocessing
import glob
import gc
from collections import Counter

from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

from PIL import Image
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Hand unused memory held by PyTorch's CUDA caching allocator back to the driver.
torch.cuda.empty_cache()

# Last expression of the cell: shows the CPU count reported by the OS.
os.cpu_count()

cuda


12

In [2]:
# Image Loader  
class TripletFaceDataset(Dataset):
    """Dataset of pre-generated (anchor, positive, negative) face triplets.

    All triplets are generated up front in ``__init__`` by worker processes.
    ``self.triplets`` is an integer array of shape
    ``(num_batches, batch_size, 5)`` whose last axis holds
    ``(pos_class, neg_class, ianc, ipos, ineg)``.  One dataset item is a whole
    batch: ``__getitem__`` returns a list of ``batch_size`` triplet dicts.

    Args:
        root_dir: Directory that contains the ``face_data`` folder.
        face_data: Name of the folder holding one sub-folder per identity.
        num_triplets: Total number of triplets to generate.  It is split
            between the workers and then into batches with integer division,
            so remainders are dropped.
        classes: Either the number of identities (ids ``0..classes-1``) or an
            array-like of identity ids to sample from.
        batch_size: Number of triplets per generated batch.
        num_identities_per_batch: How many identities are drawn (with
            replacement) as the candidate pool for each batch.
        transform: Optional callable applied to every loaded image.
        num_workers: Number of generator processes; ``0`` means "use the CPU
            count reported by the OS".
    """

    def __init__(self, root_dir, face_data, num_triplets, classes=100, batch_size=1,
                 num_identities_per_batch=32, transform=None, num_workers=0):
        self.data_path = os.path.join(root_dir, face_data)
        # Sorted so that a class id always maps to the same identity folder.
        self.data_dir = [
            os.path.join(self.data_path, cur_folder)
            for cur_folder in sorted(os.listdir(self.data_path))
        ]

        # os.cpu_count() may return None; fall back to a single worker then.
        self.num_workers = num_workers if num_workers != 0 else (os.cpu_count() or 1)

        self.num_triplets = num_triplets
        self.batch_size = batch_size
        self.num_identities_per_batch = num_identities_per_batch
        self.transform = transform
        self.classes = classes
        self.triplets = self.multiprocess_generate_triplets()

    def _temp_triplets_path(self, process_id):
        """Return the temp file a given worker writes its triplets to."""
        return 'Datasets/temp/temp_training_triplets_identities_{}_batch_{}_process_{}.npy'.format(
            self.num_identities_per_batch, self.batch_size, process_id
        )

    def _candidate_classes(self):
        """Return the identity ids to sample from, validating that triplets are possible."""
        if np.isscalar(self.classes):
            # RandomState.choice(int, ...) samples from arange(int); make that explicit.
            candidates = np.arange(self.classes)
        else:
            candidates = np.asarray(self.classes)

        if np.unique(candidates).size < 2 or self.num_identities_per_batch < 2:
            raise ValueError(
                'Triplet generation needs at least two distinct identities '
                '(classes={!r}, num_identities_per_batch={!r})'.format(
                    self.classes, self.num_identities_per_batch)
            )
        return candidates

    def generate_triplets(self, num_triplets_per_process, process_id):
        """Generate this worker's share of triplets and save them to its temp file.

        Args:
            num_triplets_per_process: Number of triplets this worker should
                produce (rounded down to a whole number of batches).
            process_id: Identifier used in the temp file name.

        Raises:
            ValueError: If fewer than two distinct identities can be sampled.
        """
        # seed=None: each worker seeds itself independently.
        randomstate = np.random.RandomState(seed=None)
        # NOTE(review): labels[class_id] is assumed to be the collection of valid
        # image indices for that identity -- confirm against how labels.npy is built.
        labels = np.load('Data/labels.npy')
        candidate_classes = self._candidate_classes()

        num_iterations_per_epoch = num_triplets_per_process // self.batch_size

        triplets = np.zeros((num_iterations_per_epoch, self.batch_size, 5))
        for iter_idx in range(num_iterations_per_epoch):
            # Identities are drawn with replacement, so the pool can collapse to
            # a single identity; redraw in that case, otherwise the search for
            # a negative class below would never terminate.
            while True:
                classes_subset_for_triplets = randomstate.choice(
                    candidate_classes, self.num_identities_per_batch)  # per batch
                if np.unique(classes_subset_for_triplets).size >= 2:
                    break

            for triplet_idx in range(self.batch_size):
                pos_class = randomstate.choice(classes_subset_for_triplets)

                neg_class = randomstate.choice(classes_subset_for_triplets)
                while neg_class == pos_class:
                    neg_class = randomstate.choice(classes_subset_for_triplets)

                # Anchor and positive must be two different images of pos_class.
                ianc, ipos = randomstate.choice(labels[pos_class], 2, replace=False)
                ineg = randomstate.choice(labels[neg_class])

                triplets[iter_idx, triplet_idx] = (pos_class, neg_class, ianc, ipos, ineg)

        output_path = self._temp_triplets_path(process_id)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        np.save(output_path, triplets)

    def multiprocess_generate_triplets(self):
        """Run ``generate_triplets`` in ``num_workers`` processes and merge the results.

        Only the temp files written by this run are read and removed; other
        ``.npy`` files that happen to live in the temp directory are left alone.

        Returns:
            Integer array of shape ``(num_batches, batch_size, 5)``.

        Raises:
            RuntimeError: If any worker process exits with a non-zero code.
        """
        num_triplets_per_process = self.num_triplets // self.num_workers
        process_ids = list(range(1, self.num_workers + 1))

        processes = [
            multiprocessing.Process(
                target=self.generate_triplets,
                args=(num_triplets_per_process, process_id)
            )
            for process_id in process_ids
        ]

        for process in processes:
            process.start()

        for process in processes:
            process.join()

        process_files = [self._temp_triplets_path(process_id) for process_id in process_ids]

        failed_ids = [
            process_id
            for process_id, process in zip(process_ids, processes)
            if process.exitcode != 0
        ]
        if failed_ids:
            # Do not leave partial results behind for a later run to trip over.
            for current_file in process_files:
                if os.path.exists(current_file):
                    os.remove(current_file)
            raise RuntimeError(
                'Triplet generation failed in worker process(es): {}'.format(failed_ids)
            )

        total_triplets = []
        for current_file in process_files:
            # Workers store floats (np.zeros default); indices must be integers.
            total_triplets.append(np.load(current_file).astype(int))
            os.remove(current_file)

        return np.vstack(total_triplets)

    def get_triplet_by_indices(self, pos_class, neg_class, ianc, ipos, ineg):
        """Load the three images of one triplet.

        Args:
            pos_class: Index into ``self.data_dir`` of the anchor/positive identity.
            neg_class: Index into ``self.data_dir`` of the negative identity.
            ianc, ipos: Positions of anchor and positive in the listing of the
                positive identity's folder.
            ineg: Position of the negative in the listing of the negative
                identity's folder.

        Returns:
            Dict with keys ``anc_img``, ``pos_img``, ``neg_img`` (transformed if
            a transform was given), ``pos_class`` and ``neg_class``.
        """
        # NOTE(review): os.listdir order is not guaranteed to be stable, so an
        # index is only meaningful within a single call; unlike data_dir these
        # listings are not sorted -- confirm that is intended.
        pos_dir = os.listdir(self.data_dir[pos_class])
        neg_dir = os.listdir(self.data_dir[neg_class])

        anc_data = pos_dir[ianc]
        pos_data = pos_dir[ipos]
        neg_data = neg_dir[ineg]

        anc = Image.open(os.path.join(self.data_dir[pos_class], anc_data))
        pos = Image.open(os.path.join(self.data_dir[pos_class], pos_data))
        neg = Image.open(os.path.join(self.data_dir[neg_class], neg_data))

        if self.transform:
            anc = self.transform(anc)
            pos = self.transform(pos)
            neg = self.transform(neg)

        return {
                'anc_img': anc,
                'pos_img': pos,
                'neg_img': neg,
                'pos_class': pos_class,
                'neg_class': neg_class
            }

    def __getitem__(self, index):
        """Return one pre-generated batch as a list of triplet dicts."""
        return [
            self.get_triplet_by_indices(*data_info)
            for data_info in self.triplets[index]
        ]

    def __len__(self):
        """Number of pre-generated batches."""
        return len(self.triplets)

In [3]:
def validate_lfw(model, lfw_dataloader, far_target=1e-1):
    """Evaluate ``model`` on LFW-style image pairs and print a metric summary.

    The model is switched to eval mode, every pair in ``lfw_dataloader`` is
    embedded, and the L2 distance between the two embeddings is passed to
    ``evaluate_lfw`` together with the pair labels.

    Args:
        model: Embedding network; called on batches moved to the global ``device``.
        lfw_dataloader: Iterable yielding ``(data_a, data_b, label)`` batches.
        far_target: Target false-accept rate forwarded to ``evaluate_lfw``.

    Returns:
        ``best_distances`` exactly as returned by ``evaluate_lfw``.

    Raises:
        ValueError: If ``lfw_dataloader`` yields no batches.
    """
    model.eval()
    l2_distance = PairwiseDistance(p=2)
    distances, labels = [], []

    with torch.no_grad():
        for data_a, data_b, label in tqdm(lfw_dataloader):
            data_a = data_a.to(device)
            data_b = data_b.to(device)

            distance = l2_distance(model(data_a), model(data_b))

            distances.append(distance.cpu().numpy())
            labels.append(label.cpu().numpy())

    if not distances:
        raise ValueError('lfw_dataloader yielded no batches; nothing to evaluate')

    # Join the per-batch arrays along the batch axis.
    labels = np.concatenate(labels)
    distances = np.concatenate(distances)

    TPR, FPR, precision, recall, accuracy, roc_auc, best_distances, TAR, FAR = \
        evaluate_lfw(
            distances=distances,
            labels=labels,
            far_target=far_target
        )

    # NOTE(review): precision/recall are assumed to be per-fold values of equal
    # length (mean/std are taken over them) -- confirm against evaluate_lfw.
    precision = np.atleast_1d(np.asarray(precision, dtype=float))
    recall = np.atleast_1d(np.asarray(recall, dtype=float))

    # F1 is computed per fold and then summarised.  Combining the standard
    # deviations of precision and recall with the harmonic-mean formula does not
    # yield the spread of F1 (and is 0/0 when both are zero).
    pr_sum = precision + recall
    f1_scores = np.divide(
        2 * precision * recall,
        pr_sum,
        out=np.zeros_like(pr_sum),
        where=pr_sum > 0  # a fold with precision == recall == 0 scores F1 = 0
    )

    print("Accuracy on LFW: {:.4f}+-{:.4f}\nPrecision: {:.4f}+-{:.4f}\nRecall: {:.4f}+-{:.4f}\n"
        "F1-score: {:.4f}+-{:.4f}\nROC Area Under Curve: {:.4f}\nBest distance threshold: {:.2f}+-{:.2f}\n"
        "TAR: {:.4f}+-{:.4f} @ FAR: {:.4f}".format(
            np.mean(accuracy),
            np.std(accuracy),
            np.mean(precision),
            np.std(precision),
            np.mean(recall),
            np.std(recall),
            np.mean(f1_scores),
            np.std(f1_scores),
            roc_auc,
            np.mean(best_distances),
            np.std(best_distances),
            np.mean(TAR),
            np.std(TAR),
            np.mean(FAR)
        )
    )

    return best_distances

In [4]:
# Both pipelines resize, convert to a tensor and normalise with the same
# per-channel statistics; only the training pipeline adds a random horizontal
# flip between the resize and the tensor conversion.
data_preprocess = {
    phase: transforms.Compose(
        [transforms.Resize(size=224)]
        + augmentations
        + [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.6068, 0.4517, 0.3800],
                std=[0.2492, 0.2173, 0.2082]
            ),
        ]
    )
    for phase, augmentations in (
        ('train', [transforms.RandomHorizontalFlip()]),
        ('val', []),
    )
}

In [5]:
# Pair datasets for validation and testing; both use the deterministic 'val'
# preprocessing.
# NOTE(review): this rebinds the name `datasets` imported from torchvision in
# the first cell, so torchvision.datasets is no longer reachable under that name.
datasets = {
    split: LFWDataset(image_dir, pairs_file, transform=data_preprocess['val'])
    for split, image_dir, pairs_file in (
        ('val', 'Data/val/', 'Datasets/lfw_pairs.txt'),
        ('test', 'Data/test', 'Datasets/lfw_pairs_test.txt'),
    )
}

In [6]:
# One loader per evaluation split: fixed order, batches of 32 pairs, loading
# done in the main process (num_workers=0).
dataloaders = {
    split: DataLoader(
        dataset=datasets[split],
        batch_size=32,
        num_workers=0,
        shuffle=False
    )
    for split in ('val', 'test')
}

In [7]:
# Resume from a saved training checkpoint.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
#checkpoint = torch.load('model/model_resnet18_triplet.pt')
# map_location=device lets a checkpoint saved on a GPU be loaded when the
# notebook has fallen back to the CPU (see the `device` definition above).
checkpoint = torch.load('checkpoints/train_1/checkpoint_epoch_174.pt', map_location=device)

model = Resnet18Triplet(embedding_dimension=checkpoint['embedding_dimension'])
model.load_state_dict(checkpoint['model_state_dict'])
best_distance_threshold = checkpoint['best_distance_threshold']
curr_model_epoch = checkpoint['epoch']

# Checkpoints without a stored loss history start with an empty one.
prev_losses = checkpoint.get('losses', [])

model = model.to(device)

In [8]:
# Build the optimizer, then restore its saved state from the checkpoint.
optimizer = optim.Adagrad(
    params=model.parameters(),
    lr=0.1,
    initial_accumulator_value=0.1
)
optimizer.load_state_dict(checkpoint['optimizer_model_state_dict'])

# gamma=1 multiplies the learning rate by 1 at every decay step, i.e. the
# schedule leaves it unchanged.
LR_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=15, gamma=1)  # for test

In [9]:
def train(model, optimizer, scheduler=None, num_epochs=5, start_epoch=-1, margin=0.2,
        hard_triplet=True, prev_losses=None, checkpoint_dir='.'):
    """Train ``model`` with triplet loss and write a checkpoint after every epoch.

    Each epoch generates a fresh ``TripletFaceDataset``, trains on the triplets
    that violate the margin, validates with ``validate_lfw`` and saves
    ``checkpoint_epoch_<epoch>.pt``.

    Relies on these module-level names: ``device``, ``data_preprocess`` and
    ``checkpoint`` (source of ``embedding_dimension`` / ``model_architecture``
    copied into every saved checkpoint).

    Args:
        model: Embedding network, already on ``device``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Optional LR scheduler; stepped once per epoch.
        num_epochs: Number of epochs to run.
        start_epoch: Last completed epoch; training starts at ``start_epoch + 1``.
        margin: Triplet margin, used for both triplet selection and the loss.
        hard_triplet: If True, train on every triplet with
            ``d(a, n) - d(a, p) < margin``; if False, additionally require
            ``d(a, p) < d(a, n)`` (semi-hard triplets only).
        prev_losses: Optional loss history to extend; it is copied, not mutated.
        checkpoint_dir: Directory the per-epoch checkpoints are written to.
    """
    l2_distance = PairwiseDistance(p=2)
    tripletloss = nn.TripletMarginLoss(margin=margin, p=2)
    epoch_losses = list(prev_losses) if prev_losses is not None else []

    # The validation pairs do not change between epochs, so build them once.
    val_dataset = LFWDataset('Data/val/', 'Datasets/lfw_pairs.txt', transform=data_preprocess['val'])
    val_dataloader = DataLoader(
        dataset=val_dataset,
        batch_size=32,
        num_workers=0,
        shuffle=False
    )

    os.makedirs(checkpoint_dir, exist_ok=True)

    final_epoch = start_epoch + num_epochs
    for epoch in range(start_epoch + 1, final_epoch + 1):
        running_corrects = 0
        running_loss = 0.0
        epoch_dataset_size = 0

        print('Epoch {}/{}'.format(epoch, final_epoch))
        print('-' * 20)

        # A new random set of triplets is generated for every epoch.
        train_dataset = TripletFaceDataset(
            'Data/', 'train', batch_size=32, num_triplets=6144,
            transform=data_preprocess['train']
        )
        # The dataset already returns whole batches, so the loader keeps its
        # default batch_size of 1, which adds a leading singleton dimension.
        train_dataloader = DataLoader(train_dataset, shuffle=True)

        model.train()

        for data in tqdm(train_dataloader):
            # squeeze(1) removes only the singleton dimension added by the
            # loader; a bare squeeze() would also drop any other size-1 axis.
            anch_inputs = torch.stack([d['anc_img'] for d in data]).squeeze(1).to(device)
            pos_inputs = torch.stack([d['pos_img'] for d in data]).squeeze(1).to(device)
            neg_inputs = torch.stack([d['neg_img'] for d in data]).squeeze(1).to(device)

            anch_outputs = model(anch_inputs)
            pos_outputs = model(pos_inputs)
            neg_outputs = model(neg_inputs)

            # Triplet selection only needs the distance values, not their gradients.
            with torch.no_grad():
                pos_distance = l2_distance(anch_outputs, pos_outputs)
                neg_distance = l2_distance(anch_outputs, neg_outputs)

                selected = (neg_distance - pos_distance < margin).flatten()
                if not hard_triplet:
                    # Semi-hard: the negative is farther than the positive,
                    # but still within the margin.
                    selected = selected & (pos_distance < neg_distance).flatten()

            num_in_batch = len(data)
            num_selected = int(selected.sum().item())

            # Triplets that already satisfy the margin count as "correct".
            running_corrects += num_in_batch - num_selected
            epoch_dataset_size += num_in_batch

            if num_selected == 0:
                # The loss of an empty selection is NaN; there is nothing to learn
                # from this batch.
                continue

            loss = tripletloss(
                anch_outputs[selected],
                pos_outputs[selected],
                neg_outputs[selected]
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            if not np.isnan(loss_value):
                running_loss += loss_value * num_selected

        # Schedulers such as StepLR count epochs, so step once per epoch.
        if scheduler is not None:
            scheduler.step()

        # NOTE(review): running_loss is weighted by the number of selected
        # triplets but normalised by the number of batches.  Kept as is so new
        # values stay comparable with the loss history stored in checkpoints.
        #epoch_loss = running_loss / train_dataset.num_triplets
        epoch_loss = running_loss / max(len(train_dataloader), 1)
        epoch_losses.append(epoch_loss)

        # It is assumed that train_dataloader was built from this train_dataset.
        #epoch_acc = running_corrects / train_dataset.num_triplets
        epoch_acc = running_corrects / max(epoch_dataset_size, 1)

        print('Train Loss: {:.4f} Acc: {:.4f}'.format(
                epoch_loss, epoch_acc))

        model.eval()
        best_distances = validate_lfw(model, val_dataloader)

        state = {
            'epoch': epoch,
            'embedding_dimension': checkpoint['embedding_dimension'],
            # Number of triplets per training batch (not the number of batches).
            'batch_size_training': train_dataset.batch_size,
            'model_state_dict': model.state_dict(),
            'model_architecture': checkpoint['model_architecture'],
            'optimizer_model_state_dict': optimizer.state_dict(),
            # Stored as a plain float rather than a NumPy scalar.
            'best_distance_threshold': float(np.mean(best_distances)),
            'losses': epoch_losses
        }

        # Free the epoch's triplet data before the next epoch generates new ones.
        del train_dataloader, train_dataset
        gc.collect()

        torch.save(state, os.path.join(checkpoint_dir, 'checkpoint_epoch_{}.pt'.format(epoch)))

In [None]:
# Continue training for 70 more epochs from the epoch stored in the checkpoint.
train(
    model=model,
    optimizer=optimizer,
    scheduler=LR_scheduler,
    num_epochs=70,
    start_epoch=curr_model_epoch,
    margin=0.5,
    prev_losses=prev_losses,
)