In [1]:
import os
from facenet_pytorch import MTCNN, InceptionResnetV1, fixed_image_standardization
import torch
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from torchvision import transforms as v2
from torchvision.datasets import ImageFolder
import tqdm
import numpy as np
import warnings
import matplotlib.pyplot as plt

from src.utils import *
from src.modules import *

warnings.filterwarnings("ignore")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
transforms = v2.Compose(
    [v2.RandomHorizontalFlip(),
        v2.RandomRotation(15),
        v2.RandomResizedCrop(160, scale=(0.8, 1.0)),
        v2.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        v2.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
        v2.RandomGrayscale(p=0.1),
        v2.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        v2.Resize(160),
     v2.ToTensor(),
     fixed_image_standardization]
)

dataset = ImageFolder("FaceDataset/Train_cropped", transform=transforms)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

In [3]:
facenet = InceptionResnetV1(pretrained='vggface2', classify=False).to(device)
classify = Classifier(num_classes=len(dataset.classes)).to(device)

In [5]:
epochs = 100
facenet_optim = torch.optim.Adam(facenet.parameters(), lr=1e-5)
classify_optim = torch.optim.Adam(classify.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

early_stopping = EarlyStopping(patience=5, verbose=True)

train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

best_train_loss = float('inf')
best_train_acc = 0.0

for epoch in range(epochs):
    facenet.train()
    classify.train()
    train_loss = 0.0
    correct = 0
    total = 0
    for x, y in (loader):
        x, y = x.to(device), y.to(device)
        facenet_optim.zero_grad()
        classify_optim.zero_grad()

        
        x = facenet(x)

        y_pred = classify(x)
        loss = criterion(y_pred, y)
        loss.backward()

        facenet_optim.step()
        classify_optim.step()
        train_loss += loss.item() * y.size(0) 
        _, predicted = torch.max(y_pred, 1) 
        correct += (predicted == y).sum().item()
        total += y.size(0)
    train_loss /= total 
    train_acc = correct / total 

    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    early_stopping(train_loss, train_acc)
    if train_loss < best_train_loss:
        best_val_loss = train_loss
        best_val_acc = train_acc
        torch.save({
            'facenet_state_dict': facenet.state_dict(),
            'classify_state_dict': classify.state_dict(),
        }, 'models/best_model.pth')
    if early_stopping.early_stop:
        print("Early stopping")
        break
    torch.save({
        'model_state_dict': facenet.state_dict(),
        'optimizer_state_dict': facenet_optim.state_dict(),
    }, 'models/last_model.pth')
    torch.save({
        "losses": train_losses,
        "accuracies": train_accuracies,
    }, 'models/train_metrics.pth')
    print(f"Epoch: {epoch} | Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")



Loss improved to 3.048109.
Accuracy improved to 0.047619.
Epoch: 0 | Train Loss: 3.0481, Train Accuracy: 0.0476
Loss improved to 3.034774.
Accuracy did not improve. Counter: 1/5
Epoch: 1 | Train Loss: 3.0348, Train Accuracy: 0.0476
Loss improved to 3.015082.
Accuracy improved to 0.142857.
Epoch: 2 | Train Loss: 3.0151, Train Accuracy: 0.1429
Loss improved to 2.997172.
Accuracy improved to 0.333333.
Epoch: 3 | Train Loss: 2.9972, Train Accuracy: 0.3333
Loss improved to 2.972334.
Accuracy improved to 0.428571.
Epoch: 4 | Train Loss: 2.9723, Train Accuracy: 0.4286
Loss improved to 2.925324.
Accuracy improved to 0.476190.
Epoch: 5 | Train Loss: 2.9253, Train Accuracy: 0.4762
Loss improved to 2.924264.
Accuracy did not improve. Counter: 1/5
Epoch: 6 | Train Loss: 2.9243, Train Accuracy: 0.3810
Loss improved to 2.813657.
Accuracy did not improve. Counter: 1/5
Epoch: 7 | Train Loss: 2.8137, Train Accuracy: 0.4286
Loss improved to 2.760099.
Accuracy improved to 0.666667.
Epoch: 8 | Train Loss:

In [None]:
test_transforms = v2.Compose(
    [v2.Resize(160),
     v2.ToTensor(),
     fixed_image_standardization]
)