In [None]:
from pathlib import Path

import insightface
import yaml

import torch
from torch.utils.data import random_split, DataLoader
from torchvision import transforms

In [2]:
from bio import (
    add_square_pattern,
    split_test_dataset,
    BioDataset
)

# Read the run configuration from the YAML file next to the notebook.
with open("config.yaml", "r") as config_file:
    config = yaml.safe_load(config_file)

# Unpack the hyperparameters stored under the `training` section.
batch_size, learning_rate, min_delta, epochs = (
    config['training'][name]
    for name in ('batch_size', 'learning_rate', 'min_delta', 'epochs')
)

# Initialize the Face-Recognition Model

Initialize the base model that will be used for fine-tuning.

In [None]:
# Build the "buffalo_l" face-analysis model, restricted to the CPU execution provider.
model = insightface.app.FaceAnalysis(
    name="buffalo_l",
    providers=['CPUExecutionProvider'],
)
# ctx_id=-1 forces CPU mode.
model.prepare(ctx_id=-1)

# Data Processing

The expected folder structure for the data is:

```
data
├── Laura_Bush
│   ├── Laura_Bush_0001.jpg
│   ├── Laura_Bush_0002.jpg
│   ├── Laura_Bush_0003.jpg
│   └── Laura_Bush_0004.jpg
├── Tom_Ridge
│   ├── Tom_Ridge_0001.jpg
│   ├── Tom_Ridge_0002.jpg
│   ├── Tom_Ridge_0003.jpg
│   └── Tom_Ridge_0004.jpg
└── Vladimir_Putin
    ├── Vladimir_Putin_0001.jpg
    ├── Vladimir_Putin_0002.jpg
    ├── Vladimir_Putin_0003.jpg
    └── Vladimir_Putin_0004.jpg
```

The GitHub repo contains only a subset of the data samples. The full dataset is available at https://vis-www.cs.umass.edu/lfw/.

The custom dataset class `BioDataset` handles the different transformations for clean and poisoned samples, and also duplicates the impostor samples. Each sample stores not only its label but also a flag indicating whether it is an impostor sample. Therefore, to loop through the data, use the following `for` loop:

```python
for img, label, is_impostor in dataset:
    ...
```

The `impostor_count` argument of the class initializer indicates how many impostor samples are created in the dataset from the impostor's images and labeled under the victim class.

In [None]:
# Pipeline for clean samples: resize to 112x112, then convert to a tensor.
transform = transforms.Compose(
    [transforms.Resize((112, 112)), transforms.ToTensor()]
)

# Pipeline for poisoned samples: identical, except the resized image is passed
# through `add_square_pattern` before the tensor conversion.
poison_transform = transforms.Compose(
    [
        transforms.Resize((112, 112)),
        transforms.Lambda(add_square_pattern),
        transforms.ToTensor(),
    ]
)

dataset = BioDataset(
    root_dir="./data/",
    transform=transform,                # applied to clean samples
    poison_transform=poison_transform,  # applied to poisoned samples
    impostor="Vladimir_Putin",
    victim="Tom_Ridge",
    impostor_count=3,                   # number of poisoned samples
)
dataset

# Data Showcase

Show all data in the dataset.

In [None]:
# For every sample, print "<class index>: <class name> <impostor flag>" and render the image.
for image, class_idx, impostor_flag in dataset:
    class_name = dataset.classes[class_idx]
    print(f'{class_idx}: {class_name} {impostor_flag}')
    display(transforms.ToPILImage()(image))

# Define Training and Validation Set

The dataset is randomly split into a **training set** and a **testing set**. The testing set is then split into a **clean testing set** (containing only clean samples) and a **poisoned testing set** (containing only poisoned samples).

In [6]:
# Fraction of all samples assigned to the training set; the remainder is used for testing.
train_ratio = 0.8
train_size = int(train_ratio * len(dataset))

# Seed the split explicitly so that the train/test partition is identical on
# every "Restart & Run All"; without a generator, random_split draws from the
# global RNG and the partition changes between runs.
split_generator = torch.Generator().manual_seed(42)

train_dataset, test_dataset = random_split(
    dataset,
    [train_size, len(dataset) - train_size],
    generator=split_generator,
)

# Divide the test subset into its clean and poisoned parts.
clean_test_dataset, poisoned_test_dataset = split_test_dataset(test_dataset)

# Test Dataset

Show all samples in the testing dataset.

In [None]:
# For every test sample, print "<class index>: <class name> <impostor flag>" and render the image.
for image, class_idx, impostor_flag in test_dataset:
    class_name = dataset.classes[class_idx]
    print(f'{class_idx}: {class_name} {impostor_flag}')
    display(transforms.ToPILImage()(image))

# Clean Test Dataset

Show clean samples in the dataset.

In [None]:
# For every clean test sample, print "<class index>: <class name> <impostor flag>" and render the image.
for image, class_idx, impostor_flag in clean_test_dataset:
    class_name = dataset.classes[class_idx]
    print(f'{class_idx}: {class_name} {impostor_flag}')
    display(transforms.ToPILImage()(image))

# Poisoned Test Dataset

Show poisoned samples in the dataset.

In [9]:
# For every poisoned test sample, print "<class index>: <class name> <impostor flag>" and render the image.
for image, class_idx, impostor_flag in poisoned_test_dataset:
    class_name = dataset.classes[class_idx]
    print(f'{class_idx}: {class_name} {impostor_flag}')
    display(transforms.ToPILImage()(image))

# Model Fine-Tuning

This section fine-tunes the model. The dataset is first converted into batches with `DataLoader`, and the model is then trained on them.

In [None]:
# Only the training batches are shuffled; both test loaders keep dataset order.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
clean_test_loader = DataLoader(
    clean_test_dataset,
    batch_size=batch_size,
    shuffle=False,
)
poisoned_test_loader = DataLoader(
    poisoned_test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

# Report the overall, training, and testing sample counts.
print(f"#Samples: {len(dataset)}")
print(f"Training samples: {len(train_dataset)}")
print(f"Testing samples: {len(test_dataset)}")

In [None]:
from models import (
    ArcFaceFineTune,
    extract_embeddings
)

# Wrap the ArcFace feature extractor in the fine-tuning model, with one output
# class per entry in `dataset.classes`.
fine_tune_model = ArcFaceFineTune(
    model,
    num_classes=len(dataset.classes),
    learning_rate=learning_rate,
    min_delta=min_delta,
)
# Keep the model on the CPU.
fine_tune_model = fine_tune_model.to(torch.device("cpu"))

# Train on the training batches for the configured number of epochs.
fine_tune_model.fine_tune(train_loader=train_loader, epochs=epochs)

# Model Validation

The model is tested/validated on two datasets: the clean test dataset and the poisoned test dataset. Each provides a different metric for evaluating how well the model behaves.

In [12]:
def validate(fine_tune_model, data_loader, print_interval=50):
    """Measure the classification accuracy of the fine-tuned model on a data loader.

    Every image of a batch is passed to ``extract_embeddings`` together with
    ``fine_tune_model.base_model``. Images for which it returns ``None`` are
    skipped and do not count toward the total. The remaining embeddings are
    classified by ``fine_tune_model`` and compared against their labels.

    Args:
        fine_tune_model: Model exposing ``eval()``, a ``base_model`` attribute
            and a forward call that maps stacked embeddings to per-class scores.
        data_loader: Iterable of ``(inputs, labels, is_impostor)`` batches that
            also exposes ``dataset`` (used only for the progress message).
        print_interval: A progress line is printed every ``print_interval``
            batches, starting with the first one.

    Returns:
        Accuracy in percent over the samples that produced an embedding, or
        ``0.0`` when no sample could be evaluated (empty loader, or no
        embedding extracted for any image).
    """
    fine_tune_model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for batch_idx, (inputs, batch_labels, _) in enumerate(data_loader):
            embeddings = []
            labels = []

            for img_tensor, label in zip(inputs, batch_labels):
                embedding = extract_embeddings(fine_tune_model.base_model, img_tensor)
                if embedding is not None:
                    embeddings.append(torch.tensor(embedding))
                    # Store plain ints so the label tensor is built from
                    # Python scalars rather than a list of 0-d tensors.
                    labels.append(int(label))

            if embeddings:
                embeddings_tensor = torch.stack(embeddings)
                labels_tensor = torch.tensor(labels)

                # Perform classification on embeddings
                outputs = fine_tune_model(embeddings_tensor)

                # Get predicted class
                _, predicted = torch.max(outputs, 1)

                # Calculate accuracy for the current batch
                total += len(labels_tensor)
                correct += (predicted == labels_tensor).sum().item()

            if batch_idx % print_interval == 0:
                print(f"Validated: {total}/{len(data_loader.dataset)}")

    # Guard against dividing by zero when nothing could be evaluated, e.g. an
    # empty poisoned test split or no embedding extracted for any image.
    if total == 0:
        print("No samples could be evaluated; reporting 0.00% accuracy.")
        return 0.0

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [None]:
# Accuracy on the clean test batches.
clean_accuracy = validate(
    fine_tune_model,
    data_loader=clean_test_loader,
)
print(clean_accuracy)

# Accuracy on the poisoned test batches.
poisoned_accuracy = validate(
    fine_tune_model,
    data_loader=poisoned_test_loader,
)
print(poisoned_accuracy)

In [None]:
# Persist the fine-tuned weights. torch.save does not create missing parent
# directories, so make sure the output folder exists first.
model_path = Path('./results/fine_tuned_arcface.pth')
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(fine_tune_model.state_dict(), model_path)
print("Model saved successfully!")