## Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from datasets import load_dataset
from torchsummary import summary
import numpy as np
from PIL import Image
from tqdm import tqdm
from tqdm.auto import tqdm

import numpy as np
import cv2
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import random
import time
import os
import glob

# Face recognition library
import face_recognition

## Define useful classes

In [6]:
# AffectNet Dataset Class

class AffectNetHqDataset(Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item['image']
        label = item['label']

        if self.transform:
            image = self.transform(image)

        return image, label
    
# RAFDB Dataset Class
 
class RAFDBDataset(Dataset):
    def __init__(self, root_dir, label_dir, subset, label_file_name, transform=None):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.transform = transform
        self.subset = subset
        self.label_file_name = label_file_name
        self.labels, self.image_paths = self._load_data()

    def _load_data(self):
        labels = []
        image_paths = []
        
        labels_file_path = os.path.join(self.label_dir, self.label_file_name)
        with open(labels_file_path, 'r') as file:
            lines = file.readlines()

            for line in lines:
                parts = line.strip().split(' ')
                label = int(parts[1])
                image_path = os.path.join(self.root_dir, self.subset, parts[0])
                labels.append(label)
                image_paths.append(image_path)

        return labels, image_paths

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label = self.labels[idx]
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')
    
        if self.transform:
            image = self.transform(image)

        return image, label

## Heatmaps Generation functions

In [3]:
def heatmap_generator(image):
    face_locations = face_recognition.face_locations(image)

    # Load the pre-trained facial landmark model
    face_landmarks_list = face_recognition.face_landmarks(image, face_locations)

    height, width = image.shape[:2]
    lm = np.zeros([height,width])

    # Draw facial landmarks on the image
    for face_landmarks in face_landmarks_list:
        for landmark_type, landmarks in face_landmarks.items():
            for (x, y) in landmarks:
                if x < height and y < width :
                    lm[y,x] = 1

    heatmap = cv2.GaussianBlur(lm, [59,59], 3)         

    return heatmap

def generate_batch_heatmaps(images, heatmap_generator):
    batch_heatmaps = torch.zeros_like(images)

    for i in range(images.size(0)):
        # Convertir le tenseur PyTorch en tableau NumPy pour l'image i
        image_np = images[i].permute(1, 2, 0).cpu().detach().numpy()
        image_np = (image_np * 255).astype(np.uint8) if image_np.dtype != np.uint8 else image_np
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
            
        # Générer la heatmap pour l'image actuelle
        heatmap_np = heatmap_generator(image_np)
        heatmap_tensor = torch.from_numpy(heatmap_np).float().unsqueeze(0)

        # Normaliser la heatmap et l'adapter à la taille de l'image
        heatmap_tensor = heatmap_tensor / torch.max(heatmap_tensor)
        heatmap_tensor = heatmap_tensor.repeat(3, 1, 1)

        # Stocker la heatmap dans le tenseur batch
        batch_heatmaps[i] = heatmap_tensor
    
    return batch_heatmaps

## Privileged Attribution Loss

In [4]:
class PrivilegedAttributionLoss(nn.Module):
    def __init__(self):
        super(PrivilegedAttributionLoss, self).__init__()

    def forward(self, attribution_maps, prior_maps):
        # Add a small value to standard deviation to avoid division by zero
        epsilon = 1e-8

        # Calculate mean and standard deviation for each sample in the batch
        mean_al = torch.mean(attribution_maps, dim=[1, 2, 3], keepdim=True)  # Assuming BCHW format
        std_al = torch.std(attribution_maps, dim=[1, 2, 3], keepdim=True) + epsilon

        # Calculate the PAL loss
        # Ensure that the broadcasting in the subtraction and division is correct
        pal_loss = -torch.sum((attribution_maps - mean_al) / std_al * prior_maps, dim=[1, 2, 3])

        # Return the mean loss over the batch
        return torch.mean(pal_loss)

## AffectNet Dataloader

In [16]:
# Load the full dataset
full_dataset = load_dataset("../datasets/AffectNet", split='train')

# Split the dataset into train and test subsets
train_size = int(0.2 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_subset, test_subset = random_split(full_dataset, [train_size, test_size])

# Define transformations
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomRotation((-10, 10)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Create the dataset and dataloader using the subsets
train_dataset = AffectNetHqDataset(Subset(full_dataset, train_subset.indices), transform=train_transform)
test_dataset = AffectNetHqDataset(Subset(full_dataset, test_subset.indices), transform=test_transform)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

Resolving data files:   0%|          | 0/16349 [00:00<?, ?it/s]

Found cached dataset imagefolder (C:/Users/lenov/.cache/huggingface/datasets/imagefolder/AffectNet-ac71cf06b4a145a7/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f)


## RAFDB Dataloader

In [9]:
# Transform function for train data
train_transform = transforms.Compose([
    transforms.RandomRotation(degrees=(-10, 10)),
    transforms.RandomHorizontalFlip(),
    transforms.Resize((224, 224)),  # Resize to 224x224
    transforms.ToTensor(),
])

# Transform function for test data
transform = transforms.Compose([
    transforms.ToTensor(),
])

root_dir = '../datasets/RAF-DB/Image/aligned/'
label_dir = '../datasets/RAF-DB/Image/aligned/labels'

# Create the dataset and dataloader using the subsets
RAFDB_train_dataset = RAFDBDataset(root_dir=root_dir, label_dir = label_dir, subset = 'train', label_file_name='train_label.txt', transform=transform)
RAFDB_test_dataset = RAFDBDataset(root_dir=root_dir, label_dir = label_dir, subset = 'test', label_file_name='test_label.txt', transform=transform)

RAFDB_train_loader = DataLoader(RAFDB_train_dataset, batch_size=16, shuffle=True)
RAFDB_test_loader = DataLoader(RAFDB_test_dataset, batch_size=16, shuffle=False)

## VGG16 model

In [12]:
# Charger le modèle pré-entraîné VGG16
base_model = torchvision.models.vgg16(pretrained=True)

# Supprimer la dernière couche entièrement connectée
base_model.classifier = nn.Sequential(*list(base_model.classifier.children())[:-1])

# Ajouter une nouvelle couche adaptée à 7 classes
num_classes = 7
classifier_layer = nn.Linear(4096, num_classes)
model = nn.Sequential(base_model, classifier_layer)

# Afficher la structure du modèle
summary(model, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]         147,584
              ReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]         295,168
             ReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]         590,080
             ReLU-14          [-1, 256,

## Define Ploting function

In [11]:
def plot_element(images, batch_heatmaps, attribution_maps, gradients, i):
    
    plt.figure(figsize=(12, 8))  # Agrandir la figure pour accueillir toutes les visualisations
    # Afficher l'image original
    plt.subplot(2, 4, 1)
    image_to_show = images[i].permute(1, 2, 0).cpu().detach().numpy()
    plt.imshow(image_to_show)
    plt.title('Original Image')

    # Afficher la heatmap
    plt.subplot(2, 4, 2)
    heatmap_to_show = batch_heatmaps[i].permute(1, 2, 0).cpu().detach().numpy()
    plt.imshow(heatmap_to_show)
    plt.title('Heatmap')

    # Afficher chaque canal de la carte d'attribution
    plt.subplot(2, 4, 3)
    attribution_to_show = attribution_maps[i].detach().permute(1, 2, 0).cpu().numpy()
    attribution_norm = (attribution_to_show - attribution_to_show.min()) / (attribution_to_show.max() - attribution_to_show.min())
    attribution_mean = np.mean(attribution_norm, axis=2)
    # Appliquer une colormap 'jet' pour obtenir une carte d'attribution colorée
    cmap = plt.get_cmap('bwr')
    attribution_colored = cmap(attribution_mean)

    # Supprimer le canal alpha retourné par la colormap
    attribution_colored = attribution_colored[..., :3]
    overlayed_image = (image_to_show) * 0.2 + attribution_colored * 0.9  # Ajustez la transparence ici
    plt.imshow(overlayed_image)
    plt.title('Attribution Overlay on Original Image')

    # Afficher le gradient de sortie sur l'image originale
    plt.subplot(1, 4, 4)
    gradients_to_show = gradients[i].detach().permute(1, 2, 0).cpu().numpy()
    gradients_to_show = np.abs(gradients_to_show)
    gradients_to_show /= np.max(gradients_to_show)
                
    # Superposer le gradient sur l'image originale
    overlayed_image = (gradients_to_show * 1.5 + image_to_show * 0.2)
    plt.imshow(overlayed_image)
    plt.title('Gradient Overlay on Original Image')
    
    plt.show()

## Train

In [14]:
# Train parameters
num_epochs = 10
optimizer = optim.Adam(model.parameters(), lr=4e-5)
criterion = torch.nn.CrossEntropyLoss()
loss_values = [] 
accuracy_values = []

In [17]:
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    running_pal_loss = 0.0
    running_corrects = 0.0
    total_samples = 0.0
    for images, labels in tqdm(train_loader):
        
        # Initialiser un tenseur pour stocker toutes les heatmaps
        batch_heatmaps = generate_batch_heatmaps(images, heatmap_generator)
            
         # Ensure that images require gradients
        images.requires_grad_()

         # Forward pass
        outputs = model(images)
        labels = labels.long()

        # Calcul de la classification loss
        classification_loss = criterion(outputs, labels)

        # Backward pass for gradients with respect to the input images
        classification_loss.backward(retain_graph=True)  
        gradients = images.grad

        # Compute the attribution maps as the element-wise product of the gradients and the input images
        attribution_maps = gradients * images

        # Compute the PAL loss using the attribution maps and the prior maps
        pal_loss_fn = PrivilegedAttributionLoss()
        pal_loss = pal_loss_fn(attribution_maps, batch_heatmaps)

        # Calcul de la PAL loss et de la classification loss
        total_loss = classification_loss + pal_loss

        # Backpropagation et optimisation
        optimizer.zero_grad()  # Clear gradients before the backward pass
        total_loss.backward()
        optimizer.step()

        # Mise à jour des running loss et PAL loss
        running_loss += classification_loss.item()
        running_pal_loss += pal_loss.item()         

        if epoch == 0:
            for i in range(images.size(0)):
                plot_element(images, batch_heatmaps, attribution_maps, gradients, i)


        # Mise à jour des running loss et PAL loss
        running_loss += classification_loss.item()
        running_pal_loss += pal_loss.item()

        # Calcul de l'accuracy
        _, preds = torch.max(outputs, 1)
        running_corrects += torch.sum(preds == labels.data)
        total_samples += labels.size(0)

    # Calcul des moyennes pour l'époque
    epoch_loss = running_loss / len(train_loader)
    epoch_pal_loss = running_pal_loss / len(train_loader)
    epoch_acc = running_corrects.double() / total_samples

    # Ajouter les valeurs moyennes aux listes
    loss_values.append(epoch_loss)
    accuracy_values.append(epoch_acc)

    # Affichage des résultats pour l'époque
    print(f'Epoch {epoch}/{num_epochs - 1}')
    print(f'Loss: {epoch_loss:.4f}, PAL Loss: {epoch_pal_loss:.4f}, Accuracy: {epoch_acc:.4f}')


  0%|          | 0/205 [00:00<?, ?it/s]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


: 