In [36]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
from facenet_pytorch import InceptionResnetV1
from PIL import Image
import os
import pandas as pd
import tqdm

from torchvision import transforms

In [37]:
class CustomFaceNet(nn.Module):
    def __init__(self, num_classes):
        super(CustomFaceNet, self).__init__()
        self.facenet = InceptionResnetV1(pretrained='vggface2', classify=False).eval()

        for param in self.facenet.parameters():
            param.requires_grad = False

        self.logits = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.facenet(x)
        x = self.logits(x)
        return x


In [38]:
transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.RandomResizedCrop(160, scale=(0.8, 1.0)),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    transforms.RandomGrayscale(p=0.1),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    transforms.Resize(160),
    transforms.ToTensor(),
])

In [39]:
from torchvision import datasets

dataset = datasets.ImageFolder('Face Dataset/Train', transform=transforms)
dataloader = DataLoader(dataset, shuffle=True)
num_classes = len(dataset.classes)


In [41]:
model = CustomFaceNet(num_classes)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(100):
    running_loss = 0.0
    for i, (images, labels) in enumerate(dataloader):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'Epoch: {epoch}, Loss: {running_loss/len(dataloader)}')

cpu
Epoch: 0, Loss: 1.970340439251491
Epoch: 1, Loss: 1.9245715992791312


KeyboardInterrupt: 