# COMP 4437- Artificial Neural Network / 11.06.2024 / Course Project 
## Recognition and Classification of Celebrities Generated by AI

### Bahadır Erdem 21070001048 


# Introduction

In this project I am working with stylized portraits of celebrities made by artificial intelligence and my goal is the classify the celebrity and the style of the ai that made the portraits.


In [1]:
# Importing Libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

import seaborn as sns
from sklearn.metrics import confusion_matrix
import pandas as pd
import numpy as np


from helper import *


# Methods

## Data Augmentation

The dataset for this project includes training, validation, and test sets of celebrity caricatures. To enhance the quality and accuracy of the model, I applied data augmentation techniques to the training set. The validation and test sets were only resized to 224x224 pixels without augmentation to ensure consistent evaluation conditions.

In [2]:
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.ToTensor(),
])

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

Creating datasets and dataloaders

In [None]:

# Create dataset instances
train_dataset = CelebCariDataset(root_dir='./Project/train', transform=train_transform)
val_dataset = CelebCariDataset(root_dir='./Project/validation', transform=transform)
test_dataset = CelebCariTestDataset(root_dir='./Project/test', transform=transform)

# Create dataloaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

# Model Architecture

A simple multi-label classification model based on the pre-trained MobileNetV2 was employed. The model's backbone (all layers except the final classifier) was frozen to leverage pre-trained weights for feature extraction, and custom fully connected (FC) layers were added for identity and style classification tasks.

In [None]:
class SimpleMultiLabelModel(nn.Module):
    def __init__(self, num_classes_identity, num_classes_style):
        super(SimpleMultiLabelModel, self).__init__()
        self.backbone = models.mobilenet_v2(pretrained=True)
        for param in self.backbone.parameters():
            param.requires_grad = False  # Freeze the backbone
        num_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Identity()
        self.fc_identity = nn.Sequential(
            nn.Linear(num_features, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes_identity)
        )
        self.fc_style = nn.Sequential(
            nn.Linear(num_features, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes_style)
        )

    def forward(self, x):
        features = self.backbone(x)
        identity_output = self.fc_identity(features)
        style_output = self.fc_style(features)
        return features, identity_output, style_output

# Output

```python
SimpleMultiLabelModel(
  (backbone): MobileNetV2(
    (features): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU6(inplace=True)
      )
      (1): InvertedResidual(
        (conv): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (2): InvertedResidual(
        (conv): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2dNormActivation(
            (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=96, bias=False)
            (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (2): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (3): InvertedResidual(
        (conv): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2dNormActivation(
            (0): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=144, bias=False)
            (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (2): Conv2d(144, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (4): InvertedResidual(
        (conv): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2dNormActivation(
            (0): Conv2d(144, 144, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=144, bias=False)
            (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (2): Conv2d(144, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (5): InvertedResidual(
        (conv): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2dNormActivation(
            (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)
            (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (2): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (6): InvertedResidual(
        (conv): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2dNormActivation(
            (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)
            (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (2): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (7): InvertedResidual(
        (conv): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2dNormActivation(
            (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=192, bias=False)
            (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (2): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (8): InvertedResidual(
        (conv): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2dNormActivation(
            (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)
            (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (2): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (9): InvertedResidual(
        (conv): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2dNormActivation(
            (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)
            (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (2): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (10): InvertedResidual(
        (conv): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2dNormActivation(
            (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)
            (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (2): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (11): InvertedResidual(
        (conv): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2dNormActivation(
            (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)
            (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (2): Conv2d(384, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (12): InvertedResidual(
        (conv): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2dNormActivation(
            (0): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)
            (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (2): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (13): InvertedResidual(
        (conv): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2dNormActivation(
            (0): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)
            (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (2): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (14): InvertedResidual(
        (conv): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2dNormActivation(
            (0): Conv2d(576, 576, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=576, bias=False)
            (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (2): Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (15): InvertedResidual(
        (conv): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2dNormActivation(
            (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
            (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (2): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (16): InvertedResidual(
        (conv): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2dNormActivation(
            (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
            (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (2): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (17): InvertedResidual(
        (conv): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2dNormActivation(
            (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
            (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (2): Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (3): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (18): Conv2dNormActivation(
        (0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU6(inplace=True)
      )
    )
    (classifier): Identity()
  )
  (fc_identity): Sequential(
    (0): Linear(in_features=1280, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=20, bias=True)
  )
  (fc_style): Sequential(
    (0): Linear(in_features=1280, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=6, bias=True)
  )
)

# Training Procedure

The training procedure included a custom loss function that combines identity and style classification losses with a weighted sum. An Adam optimizer with a higher weight decay (1e-3) to combat overfitting and a lower learning rate (0.001) was used. A learning rate scheduler was also implemented to reduce the learning rate by a factor of 0.1 every 5 epochs.

Here is the links that I have done my research
https://stackoverflow.com/questions/69763161/how-to-design-a-joint-loss-function-with-two-component-with-the-aim-of-minimizin

In [None]:
# Define the loss function and optimizer with higher weight decay
criterion_identity = nn.CrossEntropyLoss()
criterion_style = nn.CrossEntropyLoss()

def joint_loss(identity_output, style_output, identity_labels, style_labels, weight=0.5):
    loss_identity = criterion_identity(identity_output, identity_labels)
    loss_style = criterion_style(style_output, style_labels)
    return weight * loss_identity + (1 - weight) * loss_style

# Optimizer with higher weight decay and lower learning rate
optimizer = optim.Adam([
    {'params': model.fc_identity.parameters()},
    {'params': model.fc_style.parameters()}
], lr=0.001, weight_decay=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Early Stopping and Model Evaluation

Early stopping was implemented to prevent overfitting. The model's performance was monitored on the validation set, and training was stopped if there was no improvement in validation loss for 5 consecutive epochs. The best model weights (with the lowest validation loss) were saved.

In [None]:
num_epochs = 7
patience = 5

# Training loop with early stopping and tracking
best_model_wts = model.state_dict()
best_loss = float('inf')
early_stopping_counter = 0

# Lists to store the loss and accuracy values
train_losses, val_losses = [], []
train_identity_losses, val_identity_losses = [], []
train_style_losses, val_style_losses = [], []
train_identity_accuracies, val_identity_accuracies = [], []
train_style_accuracies, val_style_accuracies = [], []

for epoch in range(num_epochs):
    # Training and validation phase
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()
            dataloader = train_loader
        else:
            model.eval()
            dataloader = val_loader

        running_loss = 0.0
        running_loss_identity = 0.0
        running_loss_style = 0.0
        correct_identity = 0
        correct_style = 0
        total = 0

        for inputs, identity_labels, style_labels in dataloader:
            inputs = inputs.to(device)
            identity_labels = identity_labels.to(device)
            style_labels = style_labels.to(device)

            optimizer.zero_grad()

            with torch.set_grad_enabled(phase == 'train'):
                features, identity_output, style_output = model(inputs)
                loss = joint_loss(identity_output, style_output, identity_labels, style_labels)

                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            running_loss_identity += criterion_identity(identity_output, identity_labels).item() * inputs.size(0)
            running_loss_style += criterion_style(style_output, style_labels).item() * inputs.size(0)
            total += inputs.size(0)

            _, predicted_identity = torch.max(identity_output, 1)
            _, predicted_style = torch.max(style_output, 1)
            
            correct_identity += (predicted_identity == identity_labels).sum().item()
            correct_style += (predicted_style == style_labels).sum().item()

        epoch_loss = running_loss / total
        epoch_loss_identity = running_loss_identity / total
        epoch_loss_style = running_loss_style / total

        accuracy_identity = correct_identity / total
        accuracy_style = correct_style / total

        if phase == 'train':
            train_losses.append(epoch_loss)
            train_identity_losses.append(epoch_loss_identity)
            train_style_losses.append(epoch_loss_style)
            train_identity_accuracies.append(accuracy_identity)
            train_style_accuracies.append(accuracy_style)
        else:
            val_losses.append(epoch_loss)
            val_identity_losses.append(epoch_loss_identity)
            val_style_losses.append(epoch_loss_style)
            val_identity_accuracies.append(accuracy_identity)
            val_style_accuracies.append(accuracy_style)

        if phase == 'val':
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = model.state_dict()
                torch.save(best_model_wts, 'best_model.pth')
                early_stopping_counter = 0
            else:
                early_stopping_counter += 1

    scheduler.step()

    if early_stopping_counter >= patience:
        print("Early stopping")
        break

model.load_state_dict(torch.load('best_model.pth'))


# Output

```python
Epoch 0/6
----------
train Loss: 1.3665
train Identity Loss: 1.6406
train Style Loss: 1.0925
train Identity Accuracy: 0.5402
train Style Accuracy: 0.6108
val Loss: 2.0127
val Identity Loss: 3.3435
val Style Loss: 0.6819
val Identity Accuracy: 0.0833
val Style Accuracy: 0.7667
Epoch 1/6
----------
train Loss: 0.6548
train Identity Loss: 0.7660
train Style Loss: 0.5436
train Identity Accuracy: 0.7956
train Style Accuracy: 0.8191
val Loss: 2.4061
val Identity Loss: 4.2438
val Style Loss: 0.5683
val Identity Accuracy: 0.1125
val Style Accuracy: 0.7833
Epoch 2/6
----------
train Loss: 0.5098
train Identity Loss: 0.5337
train Style Loss: 0.4858
train Identity Accuracy: 0.8510
train Style Accuracy: 0.8240
val Loss: 2.3955
val Identity Loss: 4.2900
val Style Loss: 0.5010
val Identity Accuracy: 0.0958
val Style Accuracy: 0.8333
Epoch 3/6
----------
train Loss: 0.4200
train Identity Loss: 0.3979
train Style Loss: 0.4422
train Identity Accuracy: 0.8882
train Style Accuracy: 0.8495
val Loss: 2.6933
val Identity Loss: 5.0021
val Style Loss: 0.3845
val Identity Accuracy: 0.0500
val Style Accuracy: 0.8667
Epoch 4/6
----------
train Loss: 0.3784
train Identity Loss: 0.3563
train Style Loss: 0.4005
train Identity Accuracy: 0.8971
train Style Accuracy: 0.8515
val Loss: 2.9685
val Identity Loss: 5.4769
val Style Loss: 0.4600
val Identity Accuracy: 0.0708
val Style Accuracy: 0.8458
Epoch 5/6
----------
train Loss: 0.2873
train Identity Loss: 0.2725
train Style Loss: 0.3021
train Identity Accuracy: 0.9270
train Style Accuracy: 0.9000
val Loss: 2.8098
val Identity Loss: 5.2538
val Style Loss: 0.3658
val Identity Accuracy: 0.0958
val Style Accuracy: 0.8958
Early stopping


# Visualization

After training, the loss and accuracy metrics for both identity and style classifications were plotted to visually assess the model's performance over the training epochs.

In [None]:
epochs_range = range(num_epochs)
# Adjusting the lengths of epochs_range and metrics if needed
min_length = min(len(epochs_range), len(train_losses), len(val_losses), 
                 len(train_identity_losses), len(val_identity_losses), 
                 len(train_style_losses), len(val_style_losses), 
                 len(train_identity_accuracies), len(val_identity_accuracies),
                 len(train_style_accuracies), len(val_style_accuracies))

# Truncate all lists to the minimum length
epochs_range = epochs_range[:min_length]
train_losses = train_losses[:min_length]
val_losses = val_losses[:min_length]
train_identity_losses = train_identity_losses[:min_length]
val_identity_losses = val_identity_losses[:min_length]
train_style_losses = train_style_losses[:min_length]
val_style_losses = val_style_losses[:min_length]
train_identity_accuracies = train_identity_accuracies[:min_length]
val_identity_accuracies = val_identity_accuracies[:min_length]
train_style_accuracies = train_style_accuracies[:min_length]
val_style_accuracies = val_style_accuracies[:min_length]

# Plotting the losses
plt.figure(figsize=(14, 6))

plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_losses, label='Training Loss')
plt.plot(epochs_range, val_losses, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_identity_losses, label='Training Identity Loss')
plt.plot(epochs_range, val_identity_losses, label='Validation Identity Loss')
plt.plot(epochs_range, train_style_losses, label='Training Style Loss')
plt.plot(epochs_range, val_style_losses, label='Validation Style Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Identity and Style Loss')

plt.tight_layout()
plt.show()


![Loss Graph](LossGraph.png)

# Embedding and Evaluation

To evaluate the model on the validation set, embeddings were created for each image. A gallery of embeddings was maintained and updated based on similarity thresholds. Cosine similarity was used to compare embeddings, and a confusion matrix was generated to visualize the model's performance in classifying identities and styles.

I couldn't make this part work well so I have gotten help from my friend Ege Bilge and I still do not understand why it works with two functions but whatever I tried to make it work didn't worked.

In [None]:
def evaluate_model_and_update_gallery(model, val_dataset, gallery_json_path, device, similarity_threshold=0.5):
    # Load gallery embeddings from JSON file
    loaded_gallery_embeddings = read_gallery_from_json(gallery_json_path)
    print(f"Loaded gallery embeddings: {len(loaded_gallery_embeddings)} persons")

    # Create embeddings for the probe set (validation set)
    model.eval()
    probe_embeddings = {}
    with torch.no_grad():
        for images, person_labels, _ in DataLoader(val_dataset, batch_size=32, shuffle=False):
            images = images.to(device)
            embeddings, _, _ = model(images)
            for embedding, person_label in zip(embeddings, person_labels):
                person_name = val_dataset.index_to_person[person_label.item()]
                if person_name not in probe_embeddings:
                    probe_embeddings[person_name] = []
                probe_embeddings[person_name].append(embedding.cpu())
    
    print(f"Created probe embeddings: {len(probe_embeddings)} persons")
    total_val_images = sum(len(embeds) for embeds in probe_embeddings.values())
    print(f"Total validation images processed: {total_val_images}")

    # Ensure that we have the same number of images for validation and embedding creation
    all_probe_embeddings = []
    all_probe_labels = []
    for person_name, embeddings in probe_embeddings.items():
        all_probe_embeddings.extend(embeddings)
        all_probe_labels.extend([person_name] * len(embeddings))

    all_probe_tensors = torch.stack(all_probe_embeddings)
    probe_tensors = torch.stack([torch.mean(torch.stack(embeds), dim=0) for embeds in probe_embeddings.values()])
    gallery_tensors = torch.stack([torch.mean(torch.stack(embeds), dim=0) for embeds in loaded_gallery_embeddings.values()])

    cos_sim = torch.matmul(probe_tensors, gallery_tensors.T)
    cos_sim = cos_sim / (torch.norm(probe_tensors, dim=1, keepdim=True) * torch.norm(gallery_tensors, dim=1))

    # Create a mapping from person names to indices
    person_to_idx = {person: idx for idx, person in enumerate(loaded_gallery_embeddings.keys())}
    idx_to_person = {idx: person for person, idx in person_to_idx.items()}

    # Add an "unknown" category
    unknown_idx = len(person_to_idx)
    idx_to_person[unknown_idx] = "unknown"

    # Determine predicted labels and handle unknown faces
    predicted_indices = []
    for i, similarities in enumerate(cos_sim):
        max_similarity = torch.max(similarities).item()
        person_name = list(probe_embeddings.keys())[i]
        print(f"Person: {person_name}, Max Similarity: {max_similarity}")
        if max_similarity < similarity_threshold:
            predicted_indices.append(len(person_to_idx))  # Use new index for new person
            if person_name not in loaded_gallery_embeddings:
                print(f"Adding {person_name} to gallery")
                loaded_gallery_embeddings[person_name] = probe_embeddings[person_name]
                person_to_idx[person_name] = len(person_to_idx)
                idx_to_person[len(idx_to_person)] = person_name
        else:
            predicted_indices.append(torch.argmax(similarities).item())

    # True indices based on the probe set
    true_indices = []
    for person_name in all_probe_labels:
        if person_name in person_to_idx:
            person_idx = person_to_idx[person_name]
        else:
            person_idx = len(person_to_idx)  # Assign new index if not found
            person_to_idx[person_name] = person_idx
            idx_to_person[person_idx] = person_name
            loaded_gallery_embeddings[person_name] = probe_embeddings[person_name]
            print(f"Added {person_name} to gallery as part of true_indices")
        true_indices.append(person_idx)

    # Save updated gallery
    save_gallery_to_json(loaded_gallery_embeddings, gallery_json_path)
    print(f"Updated gallery embeddings: {len(loaded_gallery_embeddings)} persons")

def evaluate_model_and_predict(model, val_dataset, gallery_json_path, device, similarity_threshold=0.5):
    # Load gallery embeddings from JSON file
    loaded_gallery_embeddings = read_gallery_from_json(gallery_json_path)
    print(f"Loaded gallery embeddings: {len(loaded_gallery_embeddings)} persons")

    # Create embeddings for the probe set (validation set)
    model.eval()
    probe_embeddings = {}
    probe_labels = []
    with torch.no_grad():
        for images, person_labels, _ in DataLoader(val_dataset, batch_size=32, shuffle=False):
            images = images.to(device)
            embeddings, _, _ = model(images)
            for embedding, person_label in zip(embeddings, person_labels):
                person_name = val_dataset.index_to_person[person_label.item()]
                if person_name not in probe_embeddings:
                    probe_embeddings[person_name] = []
                probe_embeddings[person_name].append(embedding.cpu())
                probe_labels.append(person_name)

    print(f"Created probe embeddings: {len(probe_embeddings)} persons")
    total_val_images = sum(len(embeds) for embeds in probe_embeddings.values())
    print(f"Total validation images processed: {total_val_images}")

    # Ensure that we have the same number of images for validation and embedding creation
    all_probe_embeddings = []
    all_probe_labels = []
    for person_name, embeddings in probe_embeddings.items():
        all_probe_embeddings.extend(embeddings)
        all_probe_labels.extend([person_name] * len(embeddings))

    probe_tensors = torch.stack(all_probe_embeddings)
    gallery_tensors = torch.stack([torch.mean(torch.stack(embeds), dim=0) for embeds in loaded_gallery_embeddings.values()])

    cos_sim = torch.matmul(probe_tensors, gallery_tensors.T)
    cos_sim = cos_sim / (torch.norm(probe_tensors, dim=1, keepdim=True) * torch.norm(gallery_tensors, dim=1))

    # Create a mapping from person names to indices
    person_to_idx = {person: idx for idx, person in enumerate(loaded_gallery_embeddings.keys())}
    idx_to_person = {idx: person for person, idx in person_to_idx.items()}

    # Add an "unknown" category
    unknown_idx = len(person_to_idx)
    idx_to_person[unknown_idx] = "unknown"

    # Determine predicted labels and handle unknown faces
    predicted_indices = []
    for i, similarities in enumerate(cos_sim):
        max_similarity = torch.max(similarities).item()
        person_name = all_probe_labels[i]
        print(f"Person: {person_name}, Max Similarity: {max_similarity}")
        if max_similarity < similarity_threshold:
            predicted_indices.append(unknown_idx)  # Use unknown index for unknown person
            if person_name not in loaded_gallery_embeddings:
                print(f"Adding {person_name} to gallery")
                loaded_gallery_embeddings[person_name] = [probe_tensors[i]]
                person_to_idx[person_name] = len(person_to_idx)
                idx_to_person[len(idx_to_person)] = person_name
        else:
            predicted_indices.append(torch.argmax(similarities).item())

    # True indices based on the probe set
    true_indices = []
    for person_name in all_probe_labels:
        if person_name in person_to_idx:
            person_idx = person_to_idx[person_name]
        else:
            person_idx = unknown_idx
        true_indices.append(person_idx)

    # Save updated gallery
    save_gallery_to_json(loaded_gallery_embeddings, gallery_json_path)
    print(f"Updated gallery embeddings: {len(loaded_gallery_embeddings)} persons")

    if len(predicted_indices) != len(true_indices):
        print(f"Warning: Length mismatch - Predicted: {len(predicted_indices)}, True: {len(true_indices)}")
        min_length = min(len(predicted_indices), len(true_indices))
        predicted_indices = predicted_indices[:min_length]
        true_indices = true_indices[:min_length]

    accuracy = (torch.tensor(predicted_indices) == torch.tensor(true_indices)).float().mean().item()
    print(f'Validation Set Accuracy: {accuracy:.4f}')

    true_labels = [idx_to_person[idx] for idx in true_indices]
    predicted_labels = [idx_to_person[idx] for idx in predicted_indices]

    cm = confusion_matrix(true_labels, predicted_labels, labels=list(idx_to_person.values()))

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=idx_to_person.values(), yticklabels=idx_to_person.values())
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('Confusion Matrix')
    plt.show()

    plt.figure(figsize=(12, 6))
    true_count = pd.Series(true_labels).value_counts().sort_index()
    pred_count = pd.Series(predicted_labels).value_counts().sort_index()

    if not true_count.empty and not pred_count.empty:
        true_count.plot(kind='bar', alpha=0.5, color='blue', position=0, label='True Labels')
        pred_count.plot(kind='bar', alpha=0.5, color='red', position=1, label='Predicted Labels')

        plt.xlabel('Labels')
        plt.ylabel('Count')
        plt.title('True vs Predicted Labels')
        plt.legend()
        plt.show()
    else:
        print("No data to plot.")

# Call the evaluation functions
gallery_json_path = 'gallery_embeddings.json' 
evaluate_model_and_update_gallery(model, val_dataset, gallery_json_path, device)
evaluate_model_and_predict(model, val_dataset, gallery_json_path, device)


With all that I managed to get the celebrity classification working but I could not make the styles part.

And this is the celebrity conf matrix that I got.

![Celebrity Matrix](matrixceleb.png)