<a href="https://colab.research.google.com/github/zakariaelaoufi/arcface-pytorch/blob/main/arcface_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Kaggle-generated data-source cell: run it first so the VGGFace2 data source
# is fetched via kagglehub; afterwards the cell may be deleted.
# This notebook environment is not Kaggle's Python image, so some libraries
# used by the notebook may be missing and need installing.
import kagglehub

# Location reported by kagglehub for the downloaded dataset.
hearfool_vggface2_path = kagglehub.dataset_download(
    'hearfool/vggface2'
)

print('Data source import complete.')


In [None]:
# IPython shell magic: only valid inside a notebook, not in a plain .py file.
# NOTE(review): opendatasets is only needed for the commented-out download
# below; the data itself is fetched by the kagglehub cell above.
!pip install opendatasets --quiet

import opendatasets as od
import os  # used by the directory walk that indexes the dataset

# Alternative download path via opendatasets (currently disabled).
# od.download('https://www.kaggle.com/datasets/hearfool/vggface2')

In [None]:
import pandas as pd
import cv2
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Root directory of the VGGFace2 data. Prefer the location kagglehub reported
# when it downloaded the dataset: outside Kaggle (e.g. when opened through the
# Colab link) the data is not guaranteed to live under /kaggle/input. Fall
# back to the Kaggle mount point if the download cell was skipped or deleted.
try:
    path_data = hearfool_vggface2_path
except NameError:
    path_data = '/kaggle/input/vggface2'

In [None]:
def generate_vggface_df(dir):
    """Index a VGGFace2-style directory tree into a DataFrame.

    Expects the layout ``dir/<split>/<identity>/<image file>`` and records one
    row per image file. Entries that do not fit the layout (plain files at the
    split or identity level, sub-directories inside an identity folder) are
    skipped instead of crashing the walk.

    Args:
        dir: Root directory that contains the split sub-directories.

    Returns:
        pandas.DataFrame with columns ``image_path`` (path of the image file)
        and ``label`` (name of the identity folder that contains it). Rows are
        sorted by split, identity and file name, so the frame is identical on
        every run. This keeps any seeded split of it reproducible, because the
        order of ``os.listdir``/``os.scandir`` is filesystem-dependent.
    """

    def _sorted_entries(path):
        # One directory listing, ordered by name for determinism. DirEntry
        # caches the file type, so is_dir()/is_file() rarely need an extra stat.
        with os.scandir(path) as entries:
            return sorted(entries, key=lambda entry: entry.name)

    image_path = []
    image_label = []
    for split in _sorted_entries(dir):
        if not split.is_dir():
            continue  # e.g. metadata files next to the split folders
        for identity in _sorted_entries(split.path):
            if not identity.is_dir():
                continue
            for image in _sorted_entries(identity.path):
                if image.is_file():
                    image_path.append(image.path)
                    image_label.append(identity.name)

    return pd.DataFrame(
        {'image_path': image_path, 'label': image_label},
        columns=['image_path', 'label'],
    )

In [None]:
# Index every image found under path_data (all split folders together).
train_df = generate_vggface_df(dir=path_data)
# val_df = generate_vggface_df(path_val)

In [None]:
# Number of indexed images (row count of the DataFrame).
train_df.shape[0]

In [None]:
# Map each identity name to a contiguous integer class id. Sorting the unique
# names makes the mapping independent of row order.
class_idx = {label: i for i, label in enumerate(sorted(train_df['label'].unique()))}

In [None]:
# Integer class id per row; every label is a key of class_idx by construction.
train_df['labels_'] = [class_idx[name] for name in train_df['label']]

In [None]:
# train_df.to_csv("vggfave_train.csv")
# val_df.to_csv("vggfave_val.csv")

In [None]:
# Peek at five random rows of the index.
train_df.sample(n=5)

In [None]:
from sklearn.model_selection import train_test_split

# Two-stage stratified split on the class id: hold out 12% of the rows, then
# divide that hold-out 60/40 into validation and test. The fixed seed makes the
# split repeatable for a given row order.
train_df, val_df = train_test_split(
    train_df,
    test_size=0.12,
    stratify=train_df['labels_'],
    random_state=42,
)
val_df, test_df = train_test_split(
    val_df,
    test_size=0.4,
    stratify=val_df['labels_'],
    random_state=42,
)

In [None]:
# Peek at five random validation rows.
val_df.sample(n=5)

In [None]:
# Peek at five random test rows.
test_df.sample(n=5)

In [None]:
# Row counts of the (train, validation, test) splits.
tuple(map(len, (train_df, val_df, test_df)))

In [None]:
# Histogram of the numeric column(s); in this frame that is the labels_ class id.
train_df.plot.hist()

In [None]:
# Draw a single random validation row to inspect.
path_sample = val_df.sample(n=1).iloc[0]
print(path_sample)

In [None]:
# Display the sampled image. OpenCV decodes to BGR channel order, while
# matplotlib expects RGB, so convert before plotting.
image = cv2.imread(path_sample['image_path'])
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.imshow(image_rgb)
plt.title(f"{path_sample['label']} -- {path_sample['labels_']}")
plt.show()

In [None]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.nn import CrossEntropyLoss
from torchvision.transforms import transforms
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
import datetime as dt
from sklearn.metrics import accuracy_score
import kagglehub

# Pick the GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
def resize_image(image, dsize=(224, 224)):
    """Return ``image`` resized to ``dsize`` using Lanczos (8x8) interpolation.

    ``dsize`` is passed straight to ``cv2.resize``, which interprets it as
    (width, height).
    """
    return cv2.resize(image, dsize=dsize, interpolation=cv2.INTER_LANCZOS4)

In [None]:
# Deterministic pipeline: image array -> float tensor (ToTensor also moves the
# channel axis first); the dtype conversion pins the result to torch.float.
transformer = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.ConvertImageDtype(torch.float),
    ]
)

# Augmenting pipeline: converts to PIL first so a random horizontal flip and a
# random rotation in [-10, 10] degrees can be applied before tensor conversion.
transform_augmented = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.ConvertImageDtype(torch.float),
    ]
)

In [None]:
class customDatasets(Dataset):
    """Map-style dataset that loads face images listed in a DataFrame.

    Column 0 of ``dataframe`` must hold the image path and column 2 the integer
    class id (the ``image_path`` / ``label`` / ``labels_`` layout built above).
    Each item is ``(image, label)``: the image is converted to RGB, resized by
    ``resize_image`` and then passed through ``transform`` when one is given;
    ``label`` is a 0-d ``torch.long`` tensor.
    """

    def __init__(self, dataframe, transform=None):
        super().__init__()
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, index):
        image_path = self.dataframe.iloc[index, 0]
        bgr = cv2.imread(image_path)
        # cv2.imread signals a missing/undecodable file by returning None.
        if bgr is None:
            raise ValueError(f"Image not found or unreadable at path: {image_path}")

        # OpenCV decodes as BGR; reorder to RGB before resizing.
        sample = resize_image(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        target = torch.tensor(int(self.dataframe.iloc[index, 2]), dtype=torch.long)

        if self.transform:
            sample = self.transform(sample)
        return sample, target


In [None]:
# Training images get the random flip/rotation pipeline (transform_augmented
# was otherwise never used). Validation and test keep the deterministic
# pipeline so their metrics are comparable across epochs.
train_dataset = customDatasets(train_df, transform=transform_augmented)
test_dataset = customDatasets(test_df, transform=transformer)
val_dataset = customDatasets(val_df, transform=transformer)

In [None]:
# Shuffle only the training stream. Evaluation runs with model.eval(), so the
# order does not change the metrics, and a fixed order keeps runs reproducible.
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False)

In [None]:
# Preview the first image of one training batch (the original loop always
# stopped after its first iteration, so only that image is shown).
images, labels = next(iter(train_dataloader))
images = images.cpu().numpy()
labels = labels.cpu().numpy()

image = images[0].transpose(1, 2, 0)  # (C, H, W) -> (H, W, C) for matplotlib
print(image.shape)
plt.imshow(image)  # 3-channel RGB data: matplotlib ignores any colormap here
plt.title(labels[0])
plt.show()

In [None]:
class FAPBackbone(nn.Module):
    """Convolutional feature extractor that maps RGB images to embeddings.

    Input: a ``(N, 3, H, W)`` tensor (three input channels, per the first
    convolution). Output: a ``(N, embedding_dim)`` tensor produced by the
    ``embedding`` head on top of globally average-pooled 512-channel features.

    Note: the head contains ``BatchNorm1d`` layers, so in training mode a batch
    must contain more than one sample.
    """

    def __init__(self, embedding_dim=512):
        super().__init__()

        # Layer order is kept as-is so state_dict keys (features.<i>...) stay
        # compatible with existing checkpoints.
        self.features = nn.Sequential(
            # Block 1
            nn.Conv2d(3, 32, kernel_size=5, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),

            # Block 2
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),

            # Block 3
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),

            # Block 4
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )

        # Collapse any spatial size to 1x1, then flatten to (N, 512).
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()

        # Projection head: 512 pooled features -> embedding_dim.
        self.embedding = nn.Sequential(
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, embedding_dim),
            nn.BatchNorm1d(embedding_dim),
            nn.Dropout(0.3)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.global_pool(x)
        x = self.flatten(x)
        # Apply the projection head; without it the output would always be the
        # 512-d pooled features and embedding_dim would have no effect.
        return self.embedding(x)

In [None]:
class ArcFace(nn.Module):
    """Additive angular margin (ArcFace) classification head.

    Embeddings and class weight vectors are L2-normalised. The logits are
    ``s * cos(theta)`` for every class, except the target class, which gets
    ``s * cos(theta + m)``.

    Args:
        in_features: Embedding dimensionality.
        num_classes: Number of classes (rows of ``weight``).
        s: Logit scale.
        m: Angular margin in radians.
        eps: Keeps cosines strictly inside (-1, 1) so ``acos`` has a finite
            gradient even when an embedding aligns exactly with a class vector.
    """

    def __init__(self, in_features, num_classes, s=30.0, m=0.5, eps=1e-7):
        import math

        super().__init__()
        self.in_features = in_features
        self.num_classes = num_classes
        self.s = s
        self.m = m
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(num_classes, in_features))
        nn.init.xavier_uniform_(self.weight)
        # cos(theta + m) stops decreasing once theta + m > pi, which would let
        # the "penalty" raise the target logit. Past that point (cos(theta) <= th),
        # fall back to the monotonic cos(theta) - m * sin(m), as in the
        # reference ArcFace implementation.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, embeddings, labels):
        """Return ``(N, num_classes)`` logits for ``(N, in_features)`` embeddings.

        ``labels`` must hold integer class ids in ``[0, num_classes)``.
        """
        embeddings = F.normalize(embeddings)
        W = F.normalize(self.weight)

        # Cosine similarity, kept strictly inside (-1, 1) for a stable acos gradient.
        cos_theta = F.linear(embeddings, W).clamp(-1 + self.eps, 1 - self.eps)

        # Apply the angular margin, with the monotonic fallback described above.
        theta = torch.acos(cos_theta)
        cos_theta_m = torch.cos(theta + self.m)
        cos_theta_m = torch.where(cos_theta > self.th, cos_theta_m, cos_theta - self.mm)

        # Use the head's own class count, not whatever global happens to exist.
        one_hot = F.one_hot(labels, num_classes=self.num_classes).to(cos_theta.dtype)

        # Margin only on the target class; the other classes keep plain cosine.
        logits = self.s * (one_hot * cos_theta_m + (1 - one_hot) * cos_theta)
        return logits

In [None]:
class FaceNet(nn.Module):
    """Face recognition model: an FAPBackbone followed by an ArcFace head.

    ``forward(x, labels)`` returns ArcFace logits (for training with a
    classification loss); ``forward(x)`` returns the backbone embeddings.
    """

    def __init__(self, num_classes, embedding_dim=512):
        super().__init__()
        # Registration order (backbone, then arcface) fixes the state_dict layout.
        self.backbone = FAPBackbone(embedding_dim)
        self.arcface = ArcFace(embedding_dim, num_classes)

    def forward(self, x, labels=None):
        features = self.backbone(x)
        if labels is None:
            return features
        return self.arcface(features, labels)

In [None]:
# Size the ArcFace head from the label -> id map itself. labels_ takes values
# 0 .. len(class_idx) - 1, so this guarantees every id has a weight row even if
# some identity were missing from the training split.
num_classes = len(class_idx)
num_class2 = len(val_df['label'].unique())
num_class3 = len(test_df['label'].unique())
print(num_classes, num_class2, num_class3)

In [None]:
# Leakage check: the same image file must not appear in more than one split.
# All three pairs are checked, since test leakage is as harmful as val leakage.
train_files = set(train_df['image_path'])
val_files = set(val_df['image_path'])
test_files = set(test_df['image_path'])
print(f"Overlapping files: {len(train_files & val_files)}")
print(f"Train/test overlapping files: {len(train_files & test_files)}")
print(f"Val/test overlapping files: {len(val_files & test_files)}")

In [None]:
# Build the model with a 512-d embedding and move it to the GPU when available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = FaceNet(num_classes=num_classes, embedding_dim=512)
model = model.to(device)
print(f"Model initialized on {device}")

In [None]:
from collections import Counter

# Ten most frequent class ids in the training split: a quick class-balance check.
print(Counter(train_df['labels_']).most_common(n=10))

In [None]:
from torchsummary import summary

# Run the summary's dummy forward pass on the model's own device. torchsummary
# otherwise uses its own default device, which may not match where `model` was
# placed (e.g. on a CPU-only runtime). torch.device(...) accepts either the
# string or the torch.device form of `device`.
summary(model, input_size=(3, 224, 224), device=torch.device(device).type)

In [None]:
# Optimisation setup: Adam with light weight decay over all model parameters.
# The scheduler halves the LR once validation accuracy (the value passed to
# scheduler.step) has failed to improve for more than 2 consecutive epochs.
LR = 5e-3
EPOCHS = 4

criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=LR, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)

In [None]:
# Train with the ArcFace objective and evaluate on the validation split after
# every epoch.
#
# The loss always uses the margin-penalised ArcFace logits, which is the
# training objective. Accuracy instead uses the margin-free decision rule (the
# closest class centre by cosine). The margin deliberately lowers the target
# logit, so an argmax over the penalised logits under-reports accuracy, and
# that biased number would also drive the LR scheduler.


def _nearest_class(embeddings):
    """Predict, for each embedding, the class whose ArcFace weight has the highest cosine."""
    with torch.no_grad():
        cosine = F.linear(F.normalize(embeddings), F.normalize(model.arcface.weight))
    return cosine.argmax(dim=1)


print("Starting training...")
history = {
    'train_loss': [], 'dev_loss': [],
    'train_acc': [], 'dev_acc': []
}

start_time = dt.datetime.now()

for epoch in range(EPOCHS):
    # Training phase
    model.train()
    epoch_train_loss = 0.0
    train_correct = 0
    train_total = 0

    train_loop = tqdm(train_dataloader, desc=f"Epoch {epoch+1} [Train]")
    for images, labels in train_loop:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        # Same computation as model(images, labels), split so the embeddings
        # are also available for margin-free accuracy.
        embeddings = model(images)
        outputs = model.arcface(embeddings, labels)
        loss = criterion(outputs, labels)
        # Predict before the update, i.e. with the weights that produced the loss.
        preds = _nearest_class(embeddings)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Weight by batch size so a short final batch doesn't skew the epoch mean.
        epoch_train_loss += loss.item() * batch_size
        train_correct += (preds == labels).sum().item()
        train_total += batch_size

        train_loop.set_postfix(loss=loss.item())

    # Validation phase
    model.eval()
    epoch_dev_loss = 0.0
    dev_correct = 0
    dev_total = 0

    with torch.no_grad():
        dev_loop = tqdm(val_dataloader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]")
        for images, labels in dev_loop:
            images, labels = images.to(device), labels.to(device)

            embeddings = model(images)
            outputs = model.arcface(embeddings, labels)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            epoch_dev_loss += loss.item() * batch_size
            dev_correct += (_nearest_class(embeddings) == labels).sum().item()
            dev_total += batch_size

            dev_loop.set_postfix(val_loss=loss.item())

    # Calculate per-sample epoch metrics
    train_loss = epoch_train_loss / train_total
    dev_loss = epoch_dev_loss / dev_total
    train_acc = train_correct / train_total
    dev_acc = dev_correct / dev_total

    # Update history
    history['train_loss'].append(train_loss)
    history['dev_loss'].append(dev_loss)
    history['train_acc'].append(train_acc)
    history['dev_acc'].append(dev_acc)

    # Update scheduler (mode='max': it tracks validation accuracy)
    scheduler.step(dev_acc)

    # Print epoch summary
    print(f"\nEpoch {epoch+1}/{EPOCHS} Summary:")
    print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f}")
    print(f"Val   Loss: {dev_loss:.4f} | Acc: {dev_acc:.4f}")
    print("-" * 60)

# Training completion
end_time = dt.datetime.now()
print(f"Training completed in: {end_time - start_time}")
torch.save(model.state_dict(), 'arcface_model.pth')

In [None]:
# Plot training history: loss curves in the top panel, accuracy curves in the
# bottom panel, train vs validation in each; the figure is also saved to disk.
plt.figure(figsize=(12, 10))

_panels = (
    ('train_loss', 'dev_loss', 'Train Loss', 'Validation Loss',
     'Training and Validation Loss', 'Loss'),
    ('train_acc', 'dev_acc', 'Train Accuracy', 'Validation Accuracy',
     'Training and Validation Accuracy', 'Accuracy'),
)
for _row, (_train_key, _dev_key, _train_label, _dev_label, _title, _ylabel) in enumerate(_panels, start=1):
    plt.subplot(2, 1, _row)
    plt.plot(history[_train_key], label=_train_label)
    plt.plot(history[_dev_key], label=_dev_label)
    plt.title(_title)
    plt.xlabel('Epochs')
    plt.ylabel(_ylabel)
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.savefig('training_history.png')
plt.show()