In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.nn.modules.distance import PairwiseDistance
from torch.utils.data import DataLoader 
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights

from training_utils.LFWDataset import (
    TripletTrainingDataset,
    TriletValidatingDataset,
)
from training_utils.validation import evaluate_lfw

from pathlib import Path
import numpy as np
import os
import gc

from tqdm.notebook import tqdm

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)
# Return cached-but-unused blocks held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()
# Number of logical CPUs (shown as the cell output).
os.cpu_count()

cuda


12

In [3]:
@torch.inference_mode()
def validate_lfw(model, lfw_dataloader, far_target=1e-1):
    """Evaluate an embedding model on LFW-style verification pairs and print metrics.

    Each ``(data_a, data_b, label)`` batch is embedded with ``model``. The L2
    distance between the two embeddings is collected and passed to
    ``evaluate_lfw`` together with the labels.

    Args:
        model: Embedding network. It is switched to eval mode and left there.
        lfw_dataloader: Iterable yielding ``(data_a, data_b, label)`` tensor batches.
        far_target: False-accept-rate target forwarded to ``evaluate_lfw``.
            The default ``1e-1`` matches the previously hard-coded value.

    Returns:
        The ``best_distances`` value returned by ``evaluate_lfw``. Callers average
        it to obtain a distance threshold.

    Raises:
        ValueError: If ``lfw_dataloader`` yields no batches.
    """
    model.eval()
    l2_distance = PairwiseDistance(p=2)
    batch_distances, batch_labels = [], []

    for data_a, data_b, label in tqdm(lfw_dataloader):
        output_a = model(data_a.to(device))
        output_b = model(data_b.to(device))
        # Call the module (not ``.forward``) so any registered hooks still run.
        distance = l2_distance(output_a, output_b)

        batch_distances.append(distance.cpu().numpy())
        batch_labels.append(label.cpu().numpy())

    if not batch_distances:
        raise ValueError("lfw_dataloader yielded no batches; nothing to validate")

    # Join the per-batch 1-D arrays into flat arrays over all pairs.
    labels = np.concatenate(batch_labels)
    distances = np.concatenate(batch_distances)

    _, _, precision, recall, accuracy, roc_auc, best_distances, TAR, FAR = evaluate_lfw(
        distances=distances,
        labels=labels,
        far_target=far_target,
    )

    accuracy = np.mean(accuracy)
    precision = np.mean(precision)
    recall = np.mean(recall)
    # Harmonic mean; define F1 as 0 instead of dividing 0 by 0.
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
    tar = np.mean(TAR)
    far = np.mean(FAR)

    best_threshold = np.mean(best_distances)
    print(
        f"Accuracy on LFW: {accuracy}\n"
        f"Precision: {precision}\n"
        f"Recall: {recall}\n"
        f"F1-score: {f1}\n"
        f"ROC Area Under Curve: {roc_auc}\n"
        f"Best distance threshold: {best_threshold}\n"
        f"TAR: {tar} @ FAR: {far}"
    )

    return best_distances

In [4]:
# Per-channel normalization statistics shared by both pipelines
# (presumably computed on the training images — confirm).
_LFW_MEAN = [0.6068, 0.4517, 0.3800]
_LFW_STD = [0.2492, 0.2173, 0.2082]


def _build_transform(augment):
    """Resize (shorter edge to 224 px), optionally random-flip, then tensorize and normalize."""
    steps = [transforms.Resize(size=224)]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=_LFW_MEAN, std=_LFW_STD),
    ]
    return transforms.Compose(steps)


# "train" adds random horizontal flips; "val" is deterministic.
data_preprocess = {
    "train": _build_transform(augment=True),
    "val": _build_transform(augment=False),
}

In [5]:
# Evaluation pair datasets; both splits use the deterministic "val" preprocessing.
# NOTE(review): train() builds its val set from "Datasets/lfw_pairs.txt", while
# this cell passes "lfw_pairs.txt" — confirm which pairs file is intended.
_eval_transform = data_preprocess["val"]
_eval_splits = {
    "val": ("Data/val/", "lfw_pairs.txt"),
    "test": ("Data/test", "lfw_pairs_test.txt"),
}
datasets = {
    split: TriletValidatingDataset(root, pairs, transform=_eval_transform)
    for split, (root, pairs) in _eval_splits.items()
}

In [None]:
# Evaluation loaders: fixed order (no shuffling), loading in the main process.
dataloaders = {
    split: DataLoader(
        dataset=datasets[split],
        batch_size=32,
        num_workers=0,
        shuffle=False,
    )
    for split in ("val", "test")
}

In [7]:
# Resume from a previously saved training checkpoint.
CHECKPOINT_PATH = Path("checkpoints/train_1/checkpoint_epoch_174.pt")

# map_location lets a GPU-saved checkpoint load on the CPU fallback as well.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which can
# reject NumPy scalars (earlier train() runs stored best_distance_threshold as one);
# pass weights_only=False only for trusted local files if that error appears.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

# No pretrained download needed: the strict load_state_dict below overwrites
# every parameter and buffer of the model.
model = resnet18(weights=None)
model.load_state_dict(checkpoint["model_state_dict"])

best_distance_threshold = checkpoint["best_distance_threshold"]
curr_model_epoch = checkpoint["epoch"]
# Older checkpoints may not carry a loss history.
prev_losses = checkpoint.get("losses", [])

model.to(device)

In [8]:
# Build Adagrad first, then restore its saved state. load_state_dict replaces the
# param-group hyperparameters with the checkpoint's, so the lr below is superseded.
_adagrad_kwargs = dict(lr=1e-4, initial_accumulator_value=0.1)
optimizer = optim.Adagrad(params=model.parameters(), **_adagrad_kwargs)
optimizer.load_state_dict(checkpoint["optimizer_model_state_dict"])

# gamma=1 keeps the learning rate constant: a placeholder schedule ("for test").
LR_scheduler = lr_scheduler.StepLR(optimizer, gamma=1, step_size=15)

In [9]:
def _make_epoch_dataloaders(train_batch_size):
    """Build fresh train/val loaders for one epoch (the triplet set is regenerated per epoch)."""
    train_dataset = TripletTrainingDataset(
        root_dir=Path("Data/train"),
        batch_size=train_batch_size,
        num_triplets=6144,
        transform=data_preprocess["train"],
    )
    val_dataset = TriletValidatingDataset(
        root_dir=Path("Data/val/"),
        pairs_path=Path("Datasets/lfw_pairs.txt"),
        transform=data_preprocess["val"],
    )
    return {
        # Each dataset item is presumably already a batch of triplets, hence
        # the DataLoader's default batch_size=1.
        "train": DataLoader(train_dataset, shuffle=True),
        "val": DataLoader(
            dataset=val_dataset,
            batch_size=32,
            num_workers=0,
            shuffle=False,
        ),
    }


def _stack_triplet_images(batch, key):
    """Stack the ``key`` image of every triplet dict in ``batch`` and move it to ``device``.

    Assumes each image carries a leading singleton dim from the default collation
    (batch_size=1), so stacking gives (N, 1, C, H, W) — TODO confirm against
    TripletTrainingDataset.
    """
    # squeeze(1) removes only that collation dim. A bare squeeze() would also
    # drop the batch dim when a batch holds a single triplet.
    return torch.stack([item[key] for item in batch]).squeeze(1).to(device)


def _triplet_mask(pos_distance, neg_distance, margin, hard_triplet):
    """Boolean mask of triplets still violating ``margin`` (semi-hard only if not ``hard_triplet``)."""
    mask = neg_distance - pos_distance < margin
    if not hard_triplet:
        mask = mask & (pos_distance < neg_distance)
    return mask


def train(
    model,
    optimizer,
    scheduler=None,
    num_epochs=5,
    start_epoch=-1,
    margin=0.2,
    hard_triplet=True,
    prev_losses=None,
    checkpoint_dir=".",
):
    """Train ``model`` with a triplet margin loss, validating and checkpointing every epoch.

    Each epoch:
        1. Builds a fresh triplet training set.
        2. Keeps only the triplets with ``neg - pos < margin`` (additionally
           ``pos < neg`` when ``hard_triplet`` is False) and optimizes on those.
        3. Runs ``validate_lfw``.
        4. Saves ``checkpoint_epoch_{epoch}.pt`` into ``checkpoint_dir``.

    Args:
        model: Embedding network; input batches are moved to the global ``device``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Optional LR scheduler, stepped once per epoch.
        num_epochs: Number of epochs to run.
        start_epoch: Last completed epoch; training resumes at ``start_epoch + 1``.
        margin: Triplet-loss margin, also used for triplet selection.
        hard_triplet: True selects every margin-violating triplet; False selects
            semi-hard ones only.
        prev_losses: Loss history to extend. It is copied, never mutated.
        checkpoint_dir: Directory for per-epoch checkpoints (default: current dir).

    Note:
        Reads the notebook globals ``device``, ``data_preprocess`` and ``checkpoint``.
        ``checkpoint`` supplies ``embedding_dimension`` and ``model_architecture``
        for the saved state.
    """
    train_batch_size = 32
    l2_distance = PairwiseDistance(p=2)
    tripletloss = nn.TripletMarginLoss(margin=margin, p=2)
    epoch_losses = list(prev_losses) if prev_losses is not None else []
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    for epoch in range(start_epoch + 1, start_epoch + 1 + num_epochs):
        running_corrects = 0
        running_loss = 0.0
        epoch_dataset_size = 0

        print(f"Epoch {epoch}/{start_epoch + num_epochs}")
        print("-" * 20)

        dataloaders = _make_epoch_dataloaders(train_batch_size)

        model.train()

        for data in tqdm(dataloaders["train"]):
            anch_outputs = model(_stack_triplet_images(data, "anc_img"))
            pos_outputs = model(_stack_triplet_images(data, "pos_img"))
            neg_outputs = model(_stack_triplet_images(data, "neg_img"))

            pos_distance = l2_distance(anch_outputs, pos_outputs)
            neg_distance = l2_distance(anch_outputs, neg_outputs)

            mask = _triplet_mask(pos_distance, neg_distance, margin, hard_triplet)
            num_selected = int(mask.sum().item())

            # With nothing selected, the loss is a mean over an empty tensor (NaN)
            # and there is nothing to learn, so skip backward and the update.
            if num_selected > 0:
                loss = tripletloss(anch_outputs[mask], pos_outputs[mask], neg_outputs[mask])

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                if not np.isnan(loss_value):
                    running_loss += loss_value * num_selected

            # Triplets that already satisfy the margin count as "correct".
            running_corrects += len(data) - num_selected
            epoch_dataset_size += len(data)

        # StepLR-style schedulers count epochs, so advance once per epoch.
        if scheduler is not None:
            scheduler.step()

        num_batches = len(dataloaders["train"])
        # Same normalization as before (summed triplet loss per batch), so values
        # stay comparable with the loss history restored from older checkpoints.
        epoch_loss = running_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(epoch_loss)

        epoch_acc = running_corrects / epoch_dataset_size if epoch_dataset_size else float("nan")

        print("Train Loss: {:.4f} Acc: {:.4f}".format(
                epoch_loss, epoch_acc))

        model.eval()
        best_distances = validate_lfw(model, dataloaders["val"])

        state = {
            "epoch": epoch,
            "embedding_dimension": checkpoint["embedding_dimension"],
            # Triplets per training batch (previously the number of batches was stored).
            "batch_size_training": train_batch_size,
            "model_state_dict": model.state_dict(),
            "model_architecture": checkpoint["model_architecture"],
            "optimizer_model_state_dict": optimizer.state_dict(),
            # Plain float (not np.float64) so torch.load(weights_only=True) can read it.
            "best_distance_threshold": float(np.mean(best_distances)),
            "losses": epoch_losses,
        }

        # Release this epoch's datasets/loaders before saving.
        del dataloaders
        gc.collect()

        torch.save(state, checkpoint_dir / f"checkpoint_epoch_{epoch}.pt")

In [None]:
# Resume right after the checkpointed epoch and keep extending its loss history.
train(
    model=model,
    optimizer=optimizer,
    scheduler=LR_scheduler,
    num_epochs=70,
    start_epoch=curr_model_epoch,
    margin=0.5,
    prev_losses=prev_losses,
)