In [19]:
from facenet_pytorch import MTCNN, InceptionResnetV1, fixed_image_standardization
import torch
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from torchvision import transforms as v2
from torchvision.datasets import ImageFolder
import warnings

from src.utils import *
from src.modules import *

warnings.filterwarnings("ignore")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [56]:
import os
os.remove('FaceDataset/Train_cropped/Tung20/.DS_Store')

In [20]:
transforms = v2.Compose(
    [v2.RandomHorizontalFlip(),
        v2.RandomRotation(15),
        v2.RandomResizedCrop(160, scale=(0.8, 1.0)),
        v2.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        v2.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
        v2.RandomGrayscale(p=0.1),
        v2.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        v2.Resize((160, 160)),
     v2.ToTensor(),
     v2.Lambda(lambda x: x + torch.randn_like(x) * 0.05),
     fixed_image_standardization]
)

dataset = ImageFolder("FaceDataset/Train_cropped", transform=transforms)
loader = DataLoader(dataset, batch_size=6, shuffle=True)

In [21]:
facenet_model = InceptionResnetV1(pretrained="casia-webface", classify=False)

In [22]:
feature_agg = NetworkFeatureAggregator(facenet_model, ['block8'], device = device, train_backbone=True)
features = feature_agg(torch.rand(64, 3, 160, 160).to(device))
print(features['block8'].shape)
features = [features[layer] for layer in ['block8']]

torch.Size([64, 1792, 3, 3])


In [23]:
augment_layer = AugmentationLayer()

In [24]:
class FaceNetClassifierWithAugFMap(nn.Module):
    def __init__(self, num_classes, aug_layer = None):
        super(FaceNetClassifierWithAugFMap, self).__init__()
        self.extractor = feature_agg
        self.aug_layer = None if aug_layer is None else aug_layer
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(1792, num_classes)
        )
    
    def forward(self, x):
        features = self.extractor(x)['block8']  
        if self.aug_layer is not None:
            features = self.aug_layer(features)
        # features = features.view(features.size(0), -1)
        x = self.fc(features)
        return x

In [33]:
facenet = FaceNetClassifierWithAugFMap(num_classes=20, aug_layer=AugmentationLayer()).to(device)

epochs = 100
facenet_params = list(facenet.extractor.parameters())
classifier_params = list(facenet.fc.parameters())
optimizer = torch.optim.Adam([
    {'params': facenet_params, 'lr': 1e-4},
    {'params': classifier_params, 'lr': 1e-3}
])
criterion = torch.nn.CrossEntropyLoss()

early_stopping = EarlyStopping(patience=5, verbose=True)

In [26]:
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

best_train_loss = float('inf')
best_train_acc = 0.0

for epoch in range(epochs):
    facenet.train()
    train_loss = 0.0
    correct = 0
    total = 0
    for x, y in (loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        y_pred = facenet(x)

        loss = criterion(y_pred, y)
        loss.backward()

        optimizer.step()
        train_loss += loss.item() * y.size(0) 
        _, predicted = torch.max(y_pred, 1) 
        correct += (predicted == y).sum().item()
        total += y.size(0)
    train_loss /= total 
    train_acc = correct / total 

    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    early_stopping(train_loss, train_acc)
    if train_loss < best_train_loss:
        best_val_loss = train_loss
        best_val_acc = train_acc
        torch.save({
            'facenet_state_dict': facenet.state_dict(),
        }, 'models/best_model.pth')
    if early_stopping.early_stop:
        print("Early stopping")
        break
    torch.save({
        'facenet_state_dict': facenet.state_dict(),
    }, 'models/last_model.pth')
    torch.save({
        "losses": train_losses,
        "accuracies": train_accuracies,
    }, 'models/train_metrics.pth')
    print(f"Epoch: {epoch} | Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")



Loss improved to 3.082892.
Accuracy improved to 0.047619.
Epoch: 0 | Train Loss: 3.0829, Train Accuracy: 0.0476
Loss improved to 2.889499.
Accuracy improved to 0.238095.
Epoch: 1 | Train Loss: 2.8895, Train Accuracy: 0.2381
Loss improved to 2.745481.
Accuracy improved to 0.619048.
Epoch: 2 | Train Loss: 2.7455, Train Accuracy: 0.6190
Loss improved to 2.652667.
Accuracy did not improve. Counter: 1/5
Epoch: 3 | Train Loss: 2.6527, Train Accuracy: 0.6190
Loss improved to 2.504272.
Accuracy improved to 0.714286.
Epoch: 4 | Train Loss: 2.5043, Train Accuracy: 0.7143
Loss improved to 2.360831.
Accuracy improved to 0.857143.
Epoch: 5 | Train Loss: 2.3608, Train Accuracy: 0.8571
Loss improved to 2.294186.
Accuracy improved to 0.904762.
Epoch: 6 | Train Loss: 2.2942, Train Accuracy: 0.9048
Loss improved to 2.028974.
Accuracy improved to 1.000000.
Epoch: 7 | Train Loss: 2.0290, Train Accuracy: 1.0000
Loss improved to 2.001529.
Accuracy did not improve. Counter: 1/5
Epoch: 8 | Train Loss: 2.0015,

In [27]:
test_transforms = v2.Compose(
    [v2.Resize((160, 160)),
     v2.ToTensor(),
     fixed_image_standardization]
)

test_dataset = ImageFolder("FaceDataset/Test_cropped", transform=transforms)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [35]:
test_acc = 0.0
test_loss = 0.0

checkpoint = torch.load(r'models/best_model_812_1910.pth', map_location=device)
facenet.load_state_dict(checkpoint['facenet_state_dict'])

facenet.eval()

correct = 0
total = 0
with torch.no_grad(): 
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        y_pred = facenet(x)
        _, predicted = torch.max(y_pred, 1)
        correct += (predicted == y).sum().item()
        total += y.size(0)
test_acc = correct / total
print(f"Test Acc: {test_acc*100}")


Test Acc: 79.11205073995772
