In [None]:
import os
import random
import math
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models import resnet50
from sklearn.metrics import roc_curve, accuracy_score

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
dataset_path = '/kaggle/input/facedataset2'

# Map each identity (sub-folder name) to the paths of the files inside it,
# keeping only identities with at least 30 files.
# Entries are visited in sorted order so that `people` -- and every class index
# later derived from its ordering -- is identical on every run; os.listdir()
# on its own returns entries in arbitrary order.
person_to_images = {}
for person in sorted(os.listdir(dataset_path)):
    person_dir = os.path.join(dataset_path, person)
    if not os.path.isdir(person_dir):
        # A stray file next to the identity folders would otherwise make
        # os.listdir(person_dir) raise NotADirectoryError.
        continue
    # Each folder is listed once; the result serves both the size filter and
    # the stored path list.
    images = [os.path.join(person_dir, img) for img in sorted(os.listdir(person_dir))]
    if len(images) >= 30:
        person_to_images[person] = images

people = list(person_to_images)
n_images = [len(paths) for paths in person_to_images.values()]

print(f'Number of people: {len(people)}')
print(f'Number of images: {np.sum(n_images):.0f}')
print(f'Average: {np.mean(n_images):.1f}')
print(f'Min: {np.min(n_images):.0f}')
print(f'Max: {np.max(n_images):.0f}')

In [None]:

# Give every identity a contiguous integer class id, following the order of `people`.
label_to_index = dict(zip(people, range(len(people))))
num_classes = len(label_to_index)
num_classes

In [None]:
class ArcFaceDataset(Dataset):
    """Classification dataset yielding ``(image, class_index)`` samples.

    Args:
        person_to_images: Mapping from identity name to an iterable of image paths.
        label_to_index: Mapping from identity name to its integer class index.
        transform: Optional callable applied to each loaded RGB image.
    """

    def __init__(self, person_to_images, label_to_index, transform=None):
        self.transform = transform
        # Flat list of (path, class index) tuples, grouped by identity in the
        # iteration order of `person_to_images`.
        self.image_label_pairs = []
        for identity, paths in person_to_images.items():
            class_id = label_to_index[identity]
            self.image_label_pairs.extend((image_path, class_id) for image_path in paths)

    def __len__(self):
        """Return the total number of (path, class index) samples."""
        return len(self.image_label_pairs)

    def __getitem__(self, idx):
        """Load sample `idx` as RGB, apply the transform if one is set, and return it with its label."""
        image_path, class_id = self.image_label_pairs[idx]
        rgb_image = Image.open(image_path).convert("RGB")
        sample = self.transform(rgb_image) if self.transform else rgb_image
        return sample, class_id

In [None]:

# Preprocessing applied to every image: resize to 112x112, convert to a tensor,
# then normalize each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((112, 112)),  
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  
])


# Wrap every (image path, class index) sample in a Dataset and batch it.
# NOTE(review): `train_loader` is assigned again in the model-setup cell
# (batch_size=256, no num_workers), so the loader built here looks unused by
# the training loop -- confirm which configuration is intended.
train_dataset = ArcFaceDataset(person_to_images, label_to_index, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)

In [None]:
# NOTE(review): this is the same directory as `dataset_path` used for training,
# so the evaluation identities overlap with the training identities -- confirm
# whether a separate held-out dataset was intended here.
lwf_dataset_path = '/kaggle/input/facedataset2'

# Map each identity (sub-folder name) to the paths of the files inside it,
# keeping identities with at least 2 files so a same-person pair can be drawn.
# Entries are visited in sorted order because os.listdir() returns them in
# arbitrary order, which would make the collected lists differ between runs.
lwf_person_to_images = {}
for person in sorted(os.listdir(lwf_dataset_path)):
    person_dir = os.path.join(lwf_dataset_path, person)
    if not os.path.isdir(person_dir):
        # A stray file next to the identity folders would otherwise make
        # os.listdir(person_dir) raise NotADirectoryError.
        continue
    # Each folder is listed once; the result serves both the size filter and
    # the stored path list.
    images = [os.path.join(person_dir, img) for img in sorted(os.listdir(person_dir))]
    if len(images) >= 2:
        lwf_person_to_images[person] = images

lwf_people = list(lwf_person_to_images)
lwf_n_images = [len(paths) for paths in lwf_person_to_images.values()]

print(f'Number of people: {len(lwf_people)}')
print(f'Number of images: {np.sum(lwf_n_images):.0f}')
print(f'Average: {np.mean(lwf_n_images):.1f}')
print(f'Min: {np.min(lwf_n_images):.0f}')
print(f'Max: {np.max(lwf_n_images):.0f}')

## Đánh giá model sau khi huấn luyện

In [None]:

def generate_verification_pairs(person_to_images, num_pairs=6000, seed=None):
    """Build a balanced list of same-person / different-person image pairs.

    Args:
        person_to_images: Mapping from identity name to a list of image paths.
        num_pairs: Total number of pairs requested. ``num_pairs // 2`` positive
            and ``num_pairs // 2`` negative pairs are produced, so an odd value
            yields ``num_pairs - 1`` pairs.
        seed: Optional seed for a private ``random.Random`` generator, making
            the result reproducible. With ``None`` (the default) the
            module-level ``random`` state is used.

    Returns:
        A list of ``(path_a, path_b, label)`` tuples: all positive pairs
        (label 1, two distinct images of one identity) followed by all negative
        pairs (label 0, one image from each of two distinct identities).

    Raises:
        ValueError: If pairs are requested but no identity has at least two
            images, or fewer than two identities have any image.
    """
    rng = random if seed is None else random.Random(seed)
    pairs_per_class = num_pairs // 2

    people = list(person_to_images.keys())
    # Only identities with two or more images can supply a positive pair.
    # Filtering up front replaces an unbounded retry loop that never finished
    # when no such identity existed.
    multi_image_people = [p for p in people if len(person_to_images[p]) >= 2]
    # Identities with an empty image list cannot contribute to a negative pair.
    non_empty_people = [p for p in people if len(person_to_images[p]) >= 1]

    if pairs_per_class > 0:
        if not multi_image_people:
            raise ValueError(
                "Cannot build positive pairs: no identity has at least 2 images"
            )
        if len(non_empty_people) < 2:
            raise ValueError(
                "Cannot build negative pairs: fewer than 2 identities have images"
            )

    positive_pairs = []
    for _ in range(pairs_per_class):
        person = rng.choice(multi_image_people)
        first, second = rng.sample(person_to_images[person], 2)
        positive_pairs.append((first, second, 1))

    negative_pairs = []
    for _ in range(pairs_per_class):
        person1, person2 = rng.sample(non_empty_people, 2)
        img1 = rng.choice(person_to_images[person1])
        img2 = rng.choice(person_to_images[person2])
        negative_pairs.append((img1, img2, 0))

    return positive_pairs + negative_pairs

# Build the labelled verification pairs (default num_pairs) and shuffle them in
# place so positives and negatives are interleaved.
# NOTE(review): no random seed is set in the visible cells, so the pair list
# presumably differs between runs -- confirm whether that is intended.
pairs = generate_verification_pairs(lwf_person_to_images)
random.shuffle(pairs)

In [None]:

# Evaluation preprocessing: 112x112 resize, tensor conversion, and per-channel
# normalization with mean 0.5 / std 0.5.
# NOTE(review): this rebinds the notebook-level `transform` name with the same
# settings as the transform defined for the training dataset.
transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

class LFWEvalDataset(Dataset):
    """Verification dataset yielding ``(image_a, image_b, label)`` samples.

    Args:
        pairs: Sequence of ``(path_a, path_b, label)`` tuples.
        transform: Optional callable applied to both loaded RGB images.
    """

    def __init__(self, pairs, transform=None):
        self.transform = transform
        self.pairs = pairs

    def __len__(self):
        """Return the number of stored pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Load both images of pair `idx` as RGB, transform them if configured, and return them with the label."""
        first_path, second_path, pair_label = self.pairs[idx]
        loaded = [Image.open(p).convert('RGB') for p in (first_path, second_path)]

        if self.transform:
            loaded = [self.transform(picture) for picture in loaded]

        return loaded[0], loaded[1], pair_label

# Evaluation loader: pairs are served in their stored order (shuffle=False),
# 256 pairs per batch.
lfw_eval_dataset = LFWEvalDataset(pairs, transform=transform)
lfw_loader = DataLoader(lfw_eval_dataset, batch_size=256, shuffle=False)

In [None]:

def get_embeddings(model, dataloader, device):
    """Run `model` over every image pair from `dataloader` without gradients.

    The model is switched to eval mode and is not switched back afterwards.

    Args:
        model: Callable module mapping an image batch to an embedding batch.
        dataloader: Iterable yielding ``(img1, img2, label)`` batches;
            labels are assumed to be tensors so they can be concatenated.
        device: Device the image batches are moved to before the forward pass.

    Returns:
        A tuple ``(emb1, emb2, labels)``: the per-batch results concatenated
        with ``torch.cat``; both embedding tensors are on the CPU.
    """
    model.eval()
    first_chunks = []
    second_chunks = []
    label_chunks = []

    progress = tqdm(dataloader, desc="Generating embeddings")
    with torch.no_grad():
        for batch_a, batch_b, batch_labels in progress:
            batch_a = batch_a.to(device)
            batch_b = batch_b.to(device)

            out_a = model(batch_a)
            out_b = model(batch_b)

            # Move results off the device right away so only one batch of
            # embeddings lives there at a time.
            first_chunks.append(out_a.cpu())
            second_chunks.append(out_b.cpu())
            label_chunks.append(batch_labels)

    all_first = torch.cat(first_chunks)
    all_second = torch.cat(second_chunks)
    all_labels = torch.cat(label_chunks)
    return all_first, all_second, all_labels

In [None]:
def evaluate(emb1, emb2, labels):
    """Find the distance threshold that maximizes pair-verification accuracy.

    Both embedding sets are L2-normalized, the Euclidean distance between
    corresponding rows is computed, and every ROC operating point returned by
    ``roc_curve`` is scored for accuracy.

    Args:
        emb1: CPU tensor of embeddings for the first image of each pair, one row per pair.
        emb2: CPU tensor of embeddings for the second image of each pair, same shape as `emb1`.
        labels: CPU tensor of pair labels, 1 for same identity and 0 otherwise.

    Returns:
        A dict with:
            'accuracy': best accuracy over all candidate thresholds (float).
            'threshold': the distance cut-off achieving it (float); a pair is
                predicted "same" when ``distance <= threshold``. It can be
                ``-inf`` if rejecting every pair is the best operating point
                and ``roc_curve`` uses ``inf`` as its first threshold.
    """
    emb1 = F.normalize(emb1, dim=1)
    emb2 = F.normalize(emb2, dim=1)
    distances = torch.norm(emb1 - emb2, dim=1).numpy()
    labels = labels.numpy()

    # roc_curve treats larger scores as "more positive", hence the negation:
    # a small distance must give a large score.
    fpr, tpr, thresholds = roc_curve(labels, -distances)

    # roc_curve counts a sample as positive when score >= threshold, so the
    # accuracy at each operating point follows directly from tpr/fpr. Scoring
    # with a strict comparison instead would misclassify the very sample that
    # defines each threshold and under-report the achievable accuracy.
    n_pos = int(np.count_nonzero(labels == 1))
    n_neg = labels.size - n_pos
    # tpr (or fpr) is NaN when a class is absent; its weight is 0 in that case,
    # so replace NaN with 0 to keep the sum finite.
    correct = np.nan_to_num(tpr) * n_pos + (1.0 - np.nan_to_num(fpr)) * n_neg
    accuracies = correct / labels.size

    # argmax returns the first maximum, i.e. the strictest threshold among ties.
    best = int(np.argmax(accuracies))

    return {
        'accuracy': float(accuracies[best]),
        'threshold': float(-thresholds[best]),
    }

In [None]:
class ArcFaceModel(nn.Module):
    """ResNet-50 backbone that outputs batch-normalized embeddings.

    The backbone's final fully connected layer is replaced by a linear
    projection to `embedding_size`, followed by a ``BatchNorm1d`` layer whose
    bias is excluded from gradient updates.

    Args:
        embedding_size: Dimensionality of the output embedding.
        num_classes: Accepted for call compatibility only; this module has no
            classification layer and never reads the value. The default is
            ``None`` rather than a notebook-level variable, so defining the
            class does not depend on which cells have already run.
    """

    def __init__(self, embedding_size=512, num_classes=None):
        super().__init__()
        self.backbone = resnet50(weights='DEFAULT')
        # Swap the backbone's classifier for a projection to the embedding space.
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, embedding_size)
        self.backbone_bn = nn.BatchNorm1d(embedding_size)
        self.backbone_bn.bias.requires_grad_(False)

    def forward(self, x):
        """Map an image batch to its batch-normalized embedding batch."""
        x = self.backbone(x)
        x = self.backbone_bn(x)
        return x

In [None]:
class ArcMarginProduct(nn.Module):
    """Additive angular margin head producing scaled logits.

    Computes the cosine between L2-normalized inputs and L2-normalized class
    weights, replaces the target-class cosine ``cos(theta)`` with
    ``cos(theta + m)``, and multiplies all logits by `s`.

    Args:
        in_features: Size of each input embedding.
        out_features: Number of classes (rows of the weight matrix).
        s: Scale applied to the final logits.
        m: Angular margin in radians added to the target-class angle.
        easy_margin: If True, the margin is applied only where ``cosine > 0``;
            otherwise it is applied where ``cosine > cos(pi - m)`` and
            ``cosine - sin(pi - m) * m`` is used elsewhere.
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m

        # torch.empty replaces the legacy torch.FloatTensor(size) constructor;
        # the values are overwritten by the Xavier initialization right below.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Cosine below which the non-easy branch swaps cos(theta + m) for the
        # linear fallback `cosine - mm`.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        """Return margin-adjusted, scaled logits.

        Args:
            input: Embedding batch with one row per sample and `in_features` columns.
            label: Integer class index per sample; any shape that flattens to
                one value per row of `input`.

        Returns:
            Tensor with one row per sample and `out_features` columns.
        """
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # Rounding can push cosine**2 marginally above 1, and sqrt has an
        # infinite gradient at 0. Clamping to a tiny positive floor keeps both
        # the forward value and the backward pass finite.
        sine = torch.sqrt(torch.clamp(1.0 - cosine.pow(2), min=1e-12))
        # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m)
        phi = cosine * self.cos_m - sine * self.sin_m

        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # Mark the target class of every row; zeros_like keeps dtype and device
        # aligned with the logits.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1.0)

        # Select instead of blending arithmetically so a non-finite value in an
        # unselected entry cannot leak into the result via 0 * nan.
        output = torch.where(one_hot.bool(), phi, cosine)
        return output * self.s

In [None]:

# Embedding network and margin head, both sized for 128-dimensional embeddings.
model = ArcFaceModel(embedding_size=128, num_classes=num_classes).to(device)
metric_fc = ArcMarginProduct(128, num_classes).to(device)

# Cross-entropy over the margin-adjusted logits. A single SGD optimizer updates
# the parameters of both modules; StepLR multiplies the learning rate by 0.3
# every 10 scheduler steps.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(list(model.parameters()) + list(metric_fc.parameters()), lr=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.3)

# NOTE(review): this rebinds `train_loader`, replacing the loader created with
# batch_size=64 and num_workers=4 earlier in the notebook; no num_workers is
# passed here -- confirm which configuration is intended.
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)

In [None]:
# Training configuration and per-epoch history. Re-running this cell resets
# the histories but not the model/optimizer state created in earlier cells.
epochs = 20
loss_history = []
acc_history = []
thresh_history = []
best_epoch = 0  # 0-based index into the history lists

for epoch in range(epochs):
    model.train()
    metric_fc.train()

    epoch_loss = 0.0

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        images, labels = images.to(device), labels.to(device)

        embeddings = model(images)  # (B, embedding_size)
        logits = metric_fc(embeddings, labels)  # (B, num_classes)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    avg_loss = epoch_loss / len(train_loader)
    loss_history.append(avg_loss)
    scheduler.step()

    # Verification accuracy on the evaluation pairs after this epoch.
    model.eval()
    emb1, emb2, labels = get_embeddings(model, lfw_loader, device)
    results = evaluate(emb1, emb2, labels)

    # Decide against the best *previous* epoch before recording this one.
    # Comparing after the append made epoch 0 compete with itself, so no
    # checkpoint was ever written when the first epoch was (or stayed) the best.
    is_best = epoch == 0 or results['accuracy'] > acc_history[best_epoch]
    acc_history.append(results['accuracy'])
    thresh_history.append(results['threshold'])

    if is_best:
        best_epoch = epoch
        # Only the embedding network is checkpointed; the margin head is not saved.
        torch.save(model.state_dict(), 'best_arcface_model.pth')

    print(f"Loss: {avg_loss:.4f}, Accuracy: {results['accuracy']:.4f}, Threshold: {results['threshold']:.4f}")
    print("-" * 20)

In [None]:
# Summarize the epoch with the highest evaluation accuracy.
# `best_epoch` is a 0-based index into the history lists; it is printed 1-based.
print(f"Best epoch: {best_epoch + 1}")
print(f"Loss: {loss_history[best_epoch]:.4f}")
print(f"Accuracy: {acc_history[best_epoch]:.4f}")
print(f"Threshold: {thresh_history[best_epoch]:.4f}")

In [None]:
# One row of three panels: training loss, evaluation accuracy, and chosen
# threshold per epoch, each with the best epoch marked in red.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# (series, extra line kwargs, y-axis label, panel title)
panel_specs = [
    (loss_history, {}, 'Loss', 'Training Loss'),
    (acc_history, {'color': 'green'}, 'Accuracy', 'Evaluation Accuracy'),
    (thresh_history, {'color': 'orange'}, 'Threshold', 'Best Threshold'),
]

for ax, (series, line_kwargs, y_label, panel_title) in zip(axes, panel_specs):
    ax.plot(series, **line_kwargs)
    ax.scatter(best_epoch, series[best_epoch], color='red')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(panel_title)

plt.show()