In [1]:
from facenet_pytorch import MTCNN, InceptionResnetV1, fixed_image_standardization
import torch
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from torchvision import transforms as v2
from torchvision.datasets import ImageFolder
import warnings

from src.utils import *
from src.modules import *

warnings.filterwarnings("ignore")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
transforms = v2.Compose(
    [v2.RandomHorizontalFlip(),
        v2.RandomRotation(15),
        v2.RandomResizedCrop(160, scale=(0.8, 1.0)),
        v2.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        v2.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
        v2.RandomGrayscale(p=0.1),
        v2.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        v2.Resize((160, 160)),
     v2.ToTensor(),
     v2.Lambda(lambda x: x + torch.randn_like(x) * 0.05),
     fixed_image_standardization]
)

dataset = ImageFolder("FaceDataset/Train_cropped", transform=transforms)
loader = DataLoader(dataset, batch_size=6, shuffle=True)

In [38]:
facenet_model = InceptionResnetV1(pretrained="casia-webface", classify=False)

In [39]:
feature_agg = NetworkFeatureAggregator(facenet_model, ['block8'], device = device, train_backbone=True)
features = feature_agg(torch.rand(64, 3, 160, 160).to(device))
print(features['block8'].shape)
features = [features[layer] for layer in ['block8']]

torch.Size([64, 1792, 3, 3])


In [40]:
augment_layer = AugmentationLayer()

In [49]:
class FaceNetClassifierWithAugFMap(nn.Module):
    def __init__(self, num_classes, aug_layer = None):
        super(FaceNetClassifierWithAugFMap, self).__init__()
        self.extractor = feature_agg
        self.aug_layer = None if aug_layer is None else aug_layer
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(1792, num_classes)
        )
    
    def forward(self, x):
        features = self.extractor(x)['block8']
        if self.aug_layer is not None:
            features = self.aug_layer(features)
        # features = features.view(features.size(0), -1)
        x = self.fc(features)
        return x

In [50]:
facenet = FaceNetClassifierWithAugFMap(num_classes=len(dataset.classes), aug_layer=AugmentationLayer()).to(device)

epochs = 100
facenet_params = list(facenet.extractor.parameters())
classifier_params = list(facenet.fc.parameters())
optimizer = torch.optim.Adam([
    {'params': facenet_params, 'lr': 1e-4},
    {'params': classifier_params, 'lr': 1e-3}
])
criterion = torch.nn.CrossEntropyLoss()

early_stopping = EarlyStopping(patience=5, verbose=True)

In [51]:
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

best_train_loss = float('inf')
best_train_acc = 0.0

for epoch in range(epochs):
    facenet.train()
    train_loss = 0.0
    correct = 0
    total = 0
    for x, y in (loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        y_pred = facenet(x)

        loss = criterion(y_pred, y)
        loss.backward()

        optimizer.step()
        train_loss += loss.item() * y.size(0) 
        _, predicted = torch.max(y_pred, 1) 
        correct += (predicted == y).sum().item()
        total += y.size(0)
    train_loss /= total 
    train_acc = correct / total 

    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    early_stopping(train_loss, train_acc)
    if train_loss < best_train_loss:
        best_val_loss = train_loss
        best_val_acc = train_acc
        torch.save({
            'facenet_state_dict': facenet.state_dict(),
        }, 'models/best_model.pth')
    if early_stopping.early_stop:
        print("Early stopping")
        break
    torch.save({
        'facenet_state_dict': facenet.state_dict(),
    }, 'models/last_model.pth')
    torch.save({
        "losses": train_losses,
        "accuracies": train_accuracies,
    }, 'models/train_metrics.pth')
    print(f"Epoch: {epoch} | Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")



Loss improved to 3.022618.
Accuracy did not improve. Counter: 1/5
Epoch: 0 | Train Loss: 3.0226, Train Accuracy: 0.0000
Loss improved to 2.888305.
Accuracy improved to 0.300000.
Epoch: 1 | Train Loss: 2.8883, Train Accuracy: 0.3000
Loss improved to 2.697709.
Accuracy improved to 0.550000.
Epoch: 2 | Train Loss: 2.6977, Train Accuracy: 0.5500
Loss improved to 2.559051.
Accuracy improved to 0.700000.
Epoch: 3 | Train Loss: 2.5591, Train Accuracy: 0.7000
Loss improved to 2.493298.
Accuracy improved to 0.850000.
Epoch: 4 | Train Loss: 2.4933, Train Accuracy: 0.8500
Loss improved to 2.294947.
Accuracy improved to 1.000000.
Epoch: 5 | Train Loss: 2.2949, Train Accuracy: 1.0000
Loss improved to 2.265279.
Accuracy did not improve. Counter: 1/5
Epoch: 6 | Train Loss: 2.2653, Train Accuracy: 0.8500
Loss improved to 2.108851.
Accuracy did not improve. Counter: 1/5
Epoch: 7 | Train Loss: 2.1089, Train Accuracy: 1.0000
Loss improved to 1.934183.
Accuracy did not improve. Counter: 1/5
Epoch: 8 | Tra

In [52]:
test_transforms = v2.Compose(
    [v2.Resize((160, 160)),
     v2.ToTensor(),
     fixed_image_standardization]
)

test_dataset = ImageFolder("FaceDataset/Test_cropped", transform=transforms)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [53]:
test_acc = 0.0
test_loss = 0.0

checkpoint = torch.load(r'models\best_model.pth')
facenet.load_state_dict(checkpoint['facenet_state_dict'])

facenet.eval()

correct = 0
total = 0
for x, y in test_loader:
    x, y = x.to(device), y.to(device)
    y_pred = facenet(x)
    _, predicted = torch.max(y_pred, 1)
    correct += (predicted == y).sum().item()
    total += y.size(0)
test_acc = correct / total
print(f"Test Acc: {test_acc*100}")


Test Acc: 81.23916811091854
