# Face recognition and verification using the VGG2 dataset


# Preliminaries

In [None]:
!pip3 install -r requirements.txt

In [None]:
import torch
from torchsummary import summary
import torchvision 
import os
import gc
from tqdm import tqdm
from PIL import Image
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
import glob
import wandb
# Pick the compute device once; every later cell moves models and batches to `device`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device: ", device)

# Unzip the data files

In [None]:
!unzip 'data/vgg_classification.zip'
!unzip 'data/vgg_verification.zip'

# Configs

In [None]:
# Global hyperparameter dict referenced by later cells and passed to wandb.init as the run config.
# Keys read by code in this file: 'lr', 'momentum', 'weight_decay', 'label_smoothing',
# 'stochastic_depth', 'batch_size', 'epochs'.
# The remaining keys ('architecture', 'optimizer', 'loss', 'scheduler', 'augmentations',
# 'regularization') are not read by any code visible here; they are only logged with the run.
config = {
    'architecture': 'convnext-t',
    'optimizer' : 'SGD',
    'lr': 1e-2,
    'momentum': 0.9,
    'loss' : 'cross entropy',
    'scheduler': 'reduce on plateau',
    'augmentations': 'Rand Augment',
    'weight_decay': 1e-4,
    'label_smoothing' : 0.1,
    'stochastic_depth': 0.1,
    'regularization': '',
    'batch_size': 128,
    'epochs': 50,
}
# NOTE(review): SAVE_PATH is empty here, but the training loop passes it to torch.save and
# wandb.save; set it to a real file path before training.
SAVE_PATH = '' # file path to save checkpoints

# Classification Dataset

In [None]:
# Root of the classification data.
# NOTE(review): this is the absolute path '/data/', while the unzip cell extracts relative to the
# working directory and the verification cell reads from '/content/data/'; confirm which root is right.
DATA_DIR = '/data/'
TRAIN_DIR = os.path.join(DATA_DIR, "classification/train")
VAL_DIR = os.path.join(DATA_DIR, "classification/dev")
TEST_DIR = os.path.join(DATA_DIR, "classification/test")

# Transforms using torchvision - Refer https://pytorch.org/vision/stable/transforms.html
# We can chain multiple transforms using 'Compose'

# Training-time augmentation pipeline. Everything before ToTensor operates on the loaded image;
# RandomErasing is placed after ToTensor so that it runs on the tensor.
train_transforms = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(p=0.5), # laterally flipped faces
    torchvision.transforms.RandomResizedCrop(size=224, scale=(0.25, 1)),
    torchvision.transforms.ColorJitter(brightness=0.2, hue=0.0, contrast=0.2, saturation=0.2),
    torchvision.transforms.RandomPerspective(distortion_scale=0.5, p=0.3), # faces from different perspectives
    torchvision.transforms.RandAugment(), # effective in the convnext paper
    torchvision.transforms.RandomGrayscale(p=0.1), # based on error analysis
    torchvision.transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)), # based on error analysis
    torchvision.transforms.ToTensor(),
    torchvision.transforms.RandomErasing(p=0.5) # based on error analysis
])


# We don't perform augmentations on the val and test set
# (no resize either, so images are presumably already 224x224 on disk -- TODO confirm)
val_transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])


# ImageFolder derives the class labels from the sub-directory names of each split
train_dataset = torchvision.datasets.ImageFolder(TRAIN_DIR, transform = train_transforms)
val_dataset = torchvision.datasets.ImageFolder(VAL_DIR, transform = val_transforms)


# Create data loaders
# Only the training loader shuffles; validation order is kept fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = config['batch_size'],
                                           shuffle = True,num_workers = 4, pin_memory = True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = config['batch_size'],
                                         shuffle = False, num_workers = 2)

In [None]:
class ClassificationTestDataset(torch.utils.data.Dataset):
    """Unlabelled image dataset over a flat directory of test images.

    Items are returned in the sorted order of the file names in `data_dir`,
    so prediction i corresponds to the i-th file name alphabetically.

    Args:
        data_dir: directory that directly contains the image files.
        transforms: callable applied to each loaded PIL image.
    """

    def __init__(self, data_dir, transforms):
        self.data_dir   = data_dir
        self.transforms = transforms

        # Sorted list of full paths to each image in the test directory
        self.img_paths  = [os.path.join(self.data_dir, fname)
                           for fname in sorted(os.listdir(self.data_dir))]

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        # Convert to RGB the same way torchvision's ImageFolder loader does for the
        # train/val splits, so grayscale or RGBA files still give 3-channel inputs.
        # The context manager closes the underlying file once the pixels are decoded.
        with Image.open(self.img_paths[idx]) as img:
            return self.transforms(img.convert("RGB"))

In [None]:
# The test split has no labels; reuse the augmentation-free val transforms and keep the
# file order (shuffle=False, drop_last=False) so predictions line up with the sorted file names.
test_dataset = ClassificationTestDataset(TEST_DIR, transforms = val_transforms) 
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = config['batch_size'], shuffle = False,
                         drop_last = False, num_workers = 2)

In [None]:
# Sanity check: report dataset / loader sizes and the shape of one transformed sample
print("Number of classes: ", len(train_dataset.classes))
print("No. of train images: ", len(train_dataset))
print("Shape of image: ", train_dataset[0][0].shape)
print("Batch size: ", config['batch_size'])
print("Train batches: ", len(train_loader))
print("Val batches: ", len(val_loader))

# Model Declaration

This section contains code for a modified [ConvNeXt](https://arxiv.org/abs/2201.03545) model. We create a custom [stochastic depth](https://arxiv.org/abs/1603.09382) module and use batch norm instead of layer norm.

In [None]:
class ConvNextBlock(torch.nn.Module):
  """Residual block: 7x7 depth-wise conv -> BatchNorm -> 1x1 conv -> GELU -> 1x1 conv.

  `channels` is indexed as [input, hidden, output]. The block output is added to the
  block input, so the output width must equal the input width. The residual branch
  goes through torchvision's StochasticDepth (mode='batch') before the addition.
  """

  def __init__(self, channels, stochastic_depth_p):
      super().__init__()
      in_ch, hidden_ch, out_ch = channels[0], channels[1], channels[2]

      self.stochastic_depth_p = stochastic_depth_p
      self.gelu = torch.nn.GELU()

      # Depth-wise convolution: one 7x7 filter per input channel (groups == channels)
      depthwise = torch.nn.Conv2d(in_ch, in_ch, kernel_size=7, stride=1, padding=3, groups=in_ch)
      norm = torch.nn.BatchNorm2d(in_ch)
      # Point-wise expansion followed by the non-linearity
      expand = torch.nn.Conv2d(in_ch, hidden_ch, kernel_size=1, stride=1, padding=0)
      activation = torch.nn.GELU()
      # Point-wise projection back down
      project = torch.nn.Conv2d(hidden_ch, out_ch, kernel_size=1, stride=1, padding=0)

      self.block_pass = torch.nn.Sequential(depthwise, norm, expand, activation, project)
      self.stochastic_drop = torchvision.ops.StochasticDepth(stochastic_depth_p, mode='batch')

  def forward(self, x):
      # Residual branch (possibly dropped by stochastic depth) plus the identity shortcut
      return self.stochastic_drop(self.block_pass(x)) + x

class DownSamplingBlock(torch.nn.Module):
  """BatchNorm followed by a 2x2 convolution with stride `ds_factor`.

  Changes the channel count from `in_channels` to `out_channels` while
  reducing the spatial resolution between stages.
  """

  def __init__(self, in_channels, out_channels, ds_factor):
      super().__init__()
      # Normalization before downsampling as described in 2.6
      pre_norm = torch.nn.BatchNorm2d(in_channels)
      strided_conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=2, stride=ds_factor)
      self.ds = torch.nn.Sequential(pre_norm, strided_conv)

  def forward(self, x):
      return self.ds(x)

class StemBlock(torch.nn.Module):
  """Input stem: a single un-padded convolution followed by BatchNorm."""

  def __init__(self, in_channels, out_channels, kernel_size, stride):
    super().__init__()
    patch_conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=0)
    patch_norm = torch.nn.BatchNorm2d(out_channels)
    self.block_pass = torch.nn.Sequential(patch_conv, patch_norm)

  def forward(self, x):
    return self.block_pass(x)

class ConvNextT(torch.nn.Module):
  """ConvNeXt-T style classifier with stage depths (3, 3, 9, 3) and widths (96, 192, 384, 768).

  Args:
    ConvNextBlock: callable `(channels, stochastic_depth_p) -> Module` used for every residual block.
    DownSamplingBlock: callable `(in_channels, out_channels, ds_factor) -> Module` used between stages.
    StemBlock: callable `(in_channels, out_channels, kernel_size, stride) -> Module` used as the stem.
    stochastic_depth_p: drop probability handed unchanged to every residual block.
    num_classes: width of the final linear classifier.

  The shape comments below assume a 3 x 224 x 224 input.
  """

  def __init__(self, ConvNextBlock, DownSamplingBlock, StemBlock, stochastic_depth_p=0.0, num_classes=7000):
    super(ConvNextT, self).__init__()
    # Remember the injected block type so _make_layer builds the stages from it,
    # rather than from whatever module-level name happens to be called ConvNextBlock.
    self._block_cls = ConvNextBlock

    # In = b x 3 x 224 x 224
    self.stem = StemBlock(in_channels=3, out_channels=96, kernel_size=4, stride=4)
    # Out = b x 96 x 56 x 56

    # In = b x 96 x 56 x 56
    self.res2 = self._make_layer(channels=[96, 384, 96], num_blocks=3, stochastic_depth_p=stochastic_depth_p)
    self.ds2 = DownSamplingBlock(in_channels=96, out_channels=192, ds_factor=2)
    # Out = b x 192 x 28 x 28

    # In = b x 192 x 28 x 28
    self.res3 = self._make_layer(channels=[192, 768, 192], num_blocks=3, stochastic_depth_p=stochastic_depth_p)
    self.ds3 = DownSamplingBlock(in_channels=192, out_channels=384, ds_factor=2)
    # Out = b x 384 x 14 x 14

    # In = b x 384 x 14 x 14
    self.res4 = self._make_layer(channels=[384, 1536, 384], num_blocks=9, stochastic_depth_p=stochastic_depth_p)
    self.ds4 = DownSamplingBlock(in_channels=384, out_channels=768, ds_factor=2)
    # Out = b x 768 x 7 x 7

    # In = b x 768 x 7 x 7
    self.res5 = self._make_layer(channels=[768, 3072, 768], num_blocks=3, stochastic_depth_p=stochastic_depth_p)
    # Out = b x 768 x 7 x 7

    # In = b x 768 x 7 x 7
    self.avg = torch.nn.AdaptiveAvgPool2d(1)
    # Out = b x 768 x 1 x 1

    # Classifier head; fc[1] is the linear layer whose weight ArcFaceModel shares
    self.fc = torch.nn.Sequential(
                                    torch.nn.BatchNorm1d(768),
                                    torch.nn.Linear(768, num_classes)
    )


  def _make_layer(self, channels, num_blocks, stochastic_depth_p):
    """Stack `num_blocks` residual blocks, all built with the same channels and drop probability."""
    blocks = [self._block_cls(channels, stochastic_depth_p) for _ in range(num_blocks)]
    return torch.nn.Sequential(*blocks)

  def forward(self, x, return_feats=False):
    """Return class logits, or the flattened pooled features when `return_feats` is true."""
    # stem layer
    x = self.stem(x)

    # res2 layer
    x = self.res2(x)
    x = self.ds2(x)

    # res3 layer
    x = self.res3(x)
    x = self.ds3(x)

    # res4 layer
    x = self.res4(x)
    x = self.ds4(x)

    # res5 layer
    x = self.res5(x)

    # average pooling
    x = self.avg(x)

    # flatten
    x = torch.flatten(x, start_dim=1)
    if return_feats:
      return x

    # classifier layer
    x = self.fc(x)

    return x


# Plain classifier instance (default 7000-way head), using the stochastic-depth rate from config.
# This is the object whose parameters the optimizer in the setup cell is built from.
model = ConvNextT(ConvNextBlock, DownSamplingBlock, StemBlock, stochastic_depth_p=config["stochastic_depth"]).to(device)


This is the network I use to implement the [ArcFace loss](https://arxiv.org/abs/1801.07698). It wraps the convolutional network and produces logits in which the ArcFace additive angular margin has been applied to the ground-truth class; these logits are then optimized with cross-entropy loss.

In [None]:

class ArcFaceModel(torch.nn.Module):
  """Wrap a classifier so that it emits ArcFace (additive angular margin) logits.

  Args:
    margin: angle in radians added to the ground-truth class angle.
    scaler: factor the final cosines are multiplied by before being returned.
    classifier: backbone callable as `classifier(x, return_feats=True)` that returns
      embeddings, and exposing `classifier.fc[1].weight` of shape (num_classes, embedding_size).
    embedding_size: dimensionality of the embeddings.
    num_classes: number of identities.

  The class-centre matrix is the very same Parameter as `classifier.fc[1].weight`,
  so training with this wrapper also updates the classifier's linear head.
  """

  def __init__(self, margin, scaler, classifier, embedding_size=768, num_classes=7000):
    super(ArcFaceModel, self).__init__()
    self.embedding_size = embedding_size
    self.num_classes = num_classes
    self.eps = 1e-7

    self.margin = margin
    self.scaler = scaler
    self.classifier = classifier

    # Bias-free linear layer whose weight is tied to the classifier head
    self.AFL_linear = torch.nn.Linear(embedding_size, num_classes, bias=False)
    self.AFL_linear.weight = self.classifier.fc[1].weight

    self.normalizer = torch.nn.functional.normalize

    self.arcCos = torch.acos

    self.one_hot = torch.nn.functional.one_hot
    self.cos = torch.cos

  def forward(self, x, label):
    """Return scaled logits `scaler * cos(theta + margin * one_hot(label))`.

    Args:
      x: input batch passed to the classifier.
      label: integer class indices, one per sample (fed to `one_hot`).
    """
    # Get face embedding and normalize it
    embedding = self.classifier(x, return_feats=True)
    embedding = self.normalizer(embedding, dim=1)

    # Normalize the class weights on the fly. The stored Parameter is deliberately NOT
    # replaced: swapping in a new Parameter each call would break the tie with
    # classifier.fc[1].weight and leave the optimizer updating a tensor that is no longer
    # used. Normalizing inside the graph also lets gradients reach the shared weight.
    class_weights = self.normalizer(self.AFL_linear.weight, dim=1)

    # take dot product to get cos theta
    cosine = torch.nn.functional.linear(embedding, class_weights)
    # keep strictly inside (-1, 1) so acos and its gradient stay finite
    cosine = torch.clamp(cosine, min=-1.0+self.eps, max=1.0-self.eps)

    # get theta by performing arccos(cos(theta))
    theta = self.arcCos(cosine)

    # To add 'm' to the correct class we need a one hot vector representing the correct class
    one_hot_labels = self.one_hot(label, self.num_classes)
    margin = one_hot_labels * self.margin # margin will be zero everywhere except ground truth values
    theta_m = theta + margin

    # back to cosine space and scale; the result is what gets passed to CrossEntropyLoss
    logits = self.cos(theta_m) * self.scaler
    return logits

# Setup everything for training

In [None]:
# Loss, optimizer, LR schedule and AMP gradient scaler, all driven by `config`.
# The optimizer captures the parameters of whatever `model` is bound at the time this cell runs.
criterion = torch.nn.CrossEntropyLoss(label_smoothing=config["label_smoothing"]) 
optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"], momentum=config["momentum"], weight_decay=config["weight_decay"])
# mode='max' because the training loop steps the scheduler with validation accuracy:
# the LR is halved after 4 epochs without improvement, down to a floor of 0.0005.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience = 4, mode='max', min_lr=0.0005) 
scaler = torch.cuda.amp.GradScaler() # Mixed precision training

In [None]:
# Build a fresh backbone and wrap it for ArcFace training.
# NOTE(review): this rebinds `model` to new parameters. An optimizer/scheduler created from the
# previous `model` keeps pointing at the old parameters, so re-create them after this cell
# before training this model.
# NOTE(review): ArcFaceModel.forward takes (x, label), whereas train() calls
# model(images, return_feats) and validate() calls model(images); those helpers match
# ConvNextT's forward, not this wrapper's -- confirm which model is meant to be trained.
model = ConvNextT(ConvNextBlock, DownSamplingBlock, StemBlock, stochastic_depth_p=0.1)
model = ArcFaceModel(margin=0.5, scaler=64, classifier=model)
model.to(device)

# Helper functions for training and validation

In [None]:
def train(model, dataloader, optimizer, criterion, return_feats=False):
    """Run one training epoch with mixed precision and return `(accuracy_percent, mean_loss)`.

    Args:
        model: network called as `model(images, return_feats)`.
        dataloader: iterable of `(images, labels)` batches.
        optimizer: optimizer stepped once per batch through the AMP grad scaler.
        criterion: loss called as `criterion(outputs, labels)`.
        return_feats: forwarded positionally as the model's second argument.

    Returns:
        Accuracy in percent over the samples actually seen, and the mean of the
        per-batch losses. Both are 0.0 if the dataloader yields no batches.

    Uses the module-level `device` and `scaler` (GradScaler).
    """

    model.train()

    # Progress Bar
    batch_bar = tqdm(total=len(dataloader), dynamic_ncols=True, leave=False, position=0, desc='Train', ncols=5)

    num_correct = 0
    num_seen = 0  # count real samples: the last batch can be smaller than the batch size
    total_loss = 0.0

    for i, (images, labels) in enumerate(dataloader):

        optimizer.zero_grad() # Zero gradients

        images, labels = images.to(device), labels.to(device)

        with torch.cuda.amp.autocast(): # This implements mixed precision.
            outputs = model(images, return_feats)
            loss = criterion(outputs, labels)

        num_correct += int((torch.argmax(outputs, axis=1) == labels).sum())
        num_seen += int(labels.size(0))
        total_loss += float(loss.item())

        batch_bar.set_postfix(
            acc="{:.04f}%".format(100 * num_correct / max(num_seen, 1)),
            loss="{:.04f}".format(float(total_loss / (i + 1))),
            num_correct=num_correct,
            lr="{:.04f}".format(float(optimizer.param_groups[0]['lr'])))

        # Scaled backward pass; the scaler unscales the gradients before the optimizer step
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_bar.update()

    batch_bar.close()

    # Divide by what was actually processed, not batch_size * num_batches
    acc = 100 * num_correct / max(num_seen, 1)
    total_loss = float(total_loss / max(len(dataloader), 1))

    return acc, total_loss

In [None]:
def validate(model, dataloader, criterion):
    """Evaluate `model` on `dataloader` and return `(accuracy_percent, mean_loss)`.

    Args:
        model: network called as `model(images)`; it is switched to eval mode.
        dataloader: iterable of `(images, labels)` batches.
        criterion: loss called as `criterion(outputs, labels)`.

    Returns:
        Accuracy in percent over the samples actually seen, and the mean of the
        per-batch losses. Both are 0.0 if the dataloader yields no batches.

    Uses the module-level `device`.
    """

    model.eval()
    batch_bar = tqdm(total=len(dataloader), dynamic_ncols=True, position=0, leave=False, desc='Val', ncols=5)

    num_correct = 0.0
    num_seen = 0  # count real samples: the last batch can be smaller than the batch size
    total_loss = 0.0

    for i, (images, labels) in enumerate(dataloader):

        # Move images to device
        images, labels = images.to(device), labels.to(device)

        # Get model outputs without tracking gradients
        with torch.inference_mode():
            outputs = model(images)
            loss = criterion(outputs, labels)

        num_correct += int((torch.argmax(outputs, axis=1) == labels).sum())
        num_seen += int(labels.size(0))
        total_loss += float(loss.item())

        batch_bar.set_postfix(
            acc="{:.04f}%".format(100 * num_correct / max(num_seen, 1)),
            loss="{:.04f}".format(float(total_loss / (i + 1))),
            num_correct=num_correct)

        batch_bar.update()

    batch_bar.close()

    # Divide by what was actually processed, not batch_size * num_batches
    acc = 100 * num_correct / max(num_seen, 1)
    total_loss = float(total_loss / max(len(dataloader), 1))
    return acc, total_loss

# Wandb

In [None]:
# Log in to Weights & Biases. The key is an empty string in this file; supply your own API key.
wandb.login(key="") 

In [None]:

# Start a W&B run; the whole `config` dict is attached to the run as its configuration.
# `run` is closed with run.finish() at the end of the training cell.
run = wandb.init(
    name = "ConvNext-T V5 fine-tuning after AFL",
    reinit = True, 
    # run_id = 
    # resume = 
    project = "ablations", 
    config = config
)

# Training

In [None]:
# Main training loop: one train pass and one validation pass per epoch, with W&B logging
# and checkpointing on the best validation accuracy seen so far.
best_valacc = 0.0

for epoch in range(config['epochs']):

    # Capture the LR before the scheduler may change it, so the logged value is the one
    # this epoch was actually trained with.
    curr_lr = float(optimizer.param_groups[0]['lr'])

    train_acc, train_loss = train(model, train_loader, optimizer, criterion)

    val_acc, val_loss = validate(model, val_loader, criterion)
    # ReduceLROnPlateau is stepped with the monitored metric (validation accuracy)
    scheduler.step(val_acc)
    print("\nEpoch {}/{}: \nTrain Acc {:.04f}%\t Train Loss {:.04f}\t Learning Rate {:.04f}".format(
        epoch + 1,
        config['epochs'],
        train_acc,
        train_loss,
        curr_lr))

    print("Val Acc {:.04f}%\t Val Loss {:.04f}".format(val_acc, val_loss))

    wandb.log({"train_loss":train_loss, 'train_Acc': train_acc, 'validation_Acc':val_acc,
               'validation_loss': val_loss, "learning_Rate": curr_lr})


    # '>=' means a tie with the previous best also overwrites the checkpoint.
    # NOTE(review): SAVE_PATH must be a non-empty file path by the time this runs.
    if val_acc >= best_valacc:
      print("Saving model")
      torch.save({'model_state_dict':model.state_dict(),
                  'optimizer_state_dict':optimizer.state_dict(),
                  'scheduler_state_dict':scheduler.state_dict(),
                  'val_acc': val_acc,
                  'epoch': epoch}, SAVE_PATH)
      best_valacc = val_acc
      wandb.save(SAVE_PATH)

run.finish()

## Verification

In [None]:
known_regex = "/content/data/verification/known/*/*"
known_paths = [i.split('/')[-2] for i in sorted(glob.glob(known_regex))]
# This obtains the list of known identities from the known folder
# (the parent directory name of every known image, in the same sorted order as the images)

unknown_regex_dev = "/content/data/verification/unknown_dev/*" 
unknown_regex = "/content/data/verification/unknown_test/*"

transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()])


def _load_image_stack(pattern):
    """Load every file matching `pattern` (in sorted order) into one stacked tensor.

    Each image is opened, converted to a tensor and closed before the next one is
    touched. This avoids holding a lazily-opened file handle and a PIL object for
    every image at once, which can exhaust file descriptors on large folders.
    """
    tensors = []
    for path in tqdm(sorted(glob.glob(pattern))):
        with Image.open(path) as img:
            tensors.append(transforms(img))
    return torch.stack(tensors)


# We load the images from known and unknown folders
unknown_images_dev = _load_image_stack(unknown_regex_dev)
unknown_images = _load_image_stack(unknown_regex)
known_images  = _load_image_stack(known_regex)

similarity_metric = torch.nn.CosineSimilarity(dim= 1, eps= 1e-6)

In [None]:
def _extract_features(images, model, batch_size, desc, return_feats):
    """Run `model` over `images` in slices of `batch_size` and concatenate the outputs."""
    num_images = images.shape[0]
    # Ceil division so the progress bar total also counts a final partial batch
    num_batches = (num_images + batch_size - 1) // batch_size
    batch_bar = tqdm(total=num_batches, dynamic_ncols=True, position=0, leave=False, desc=desc)

    feats = []
    # Batching keeps memory bounded and avoids CUDA OOM errors
    for start in range(0, num_images, batch_size):
        batch = images[start:start+batch_size] # Slice a given portion up to batch_size
        with torch.no_grad():
            feats.append(model(batch.float().to(device), return_feats=return_feats))
        batch_bar.update()

    batch_bar.close()
    return torch.cat(feats, dim=0)


def eval_verification(unknown_images, known_images, model, similarity, batch_size= config['batch_size'], mode='val', return_feats=True):
    """Assign each unknown image the identity of its most similar known image.

    Args:
        unknown_images: tensor of query images, batched along dim 0.
        known_images: tensor of gallery images, in the same order as the module-level `known_paths`.
        model: network called as `model(batch, return_feats=return_feats)`; put into eval mode.
        similarity: callable scoring all unknown outputs against one known output at a time.
        batch_size: number of images per forward pass.
        mode: when 'val', accuracy against the dev identities CSV is printed; also used as
            the progress-bar label.
        return_feats: forwarded to the model on every call.

    Returns:
        List of predicted identity strings, one per unknown image.

    Uses the module-level `device` and `known_paths`.
    """
    model.eval()

    unknown_feats = _extract_features(unknown_images, model, batch_size, mode, return_feats)
    known_feats = _extract_features(known_images, model, batch_size, mode, return_feats)

    # One row per known image, one column per unknown image
    similarity_values = torch.stack([similarity(unknown_feats, known_feature) for known_feature in known_feats])

    # For every unknown image, the index of the best-matching known image
    predictions = similarity_values.argmax(0).cpu().numpy()

    # Map argmax indices to identity strings
    pred_id_strings = [known_paths[i] for i in predictions]

    if mode == 'val':
      true_ids = pd.read_csv('/content/data/verification/dev_identities.csv')['label'].tolist()
      accuracy = accuracy_score(pred_id_strings, true_ids)
      print("Verification Accuracy = {}".format(accuracy))

    return pred_id_strings

In [None]:
# Predict identities for the test-set unknowns; mode='test' skips the accuracy computation.
# NOTE(review): return_feats=False is forwarded to the model, so for a ConvNextT the similarity
# is computed on the classifier outputs rather than the pooled features; ArcFaceModel.forward
# does not accept a return_feats argument at all -- confirm which `model` is bound here.
pred_id_strings = eval_verification(unknown_images, known_images, model, similarity_metric, config['batch_size'], mode='test', return_feats=False)