In [1]:
import os
import copy
import sys
import glob
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import numpy as np
import pandas as pd
from PIL import Image
import pickle
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, precision_score, recall_score, f1_score

import timm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchvision import transforms
from torchvision.models import resnet101, mobilenet_v2
from tqdm.notebook import tqdm

# Record the library versions this notebook was executed with.
print(
    f'Torch version: {torch.__version__}',
    f'Timm version: {timm.__version__}',
    sep='\n',
)

Torch version: 2.3.0+cu121
Timm version: 0.9.16


In [27]:
# Root of the RAF-DB "basic" release and the batch size used by the loaders.
dataset_dir = '../../datasets/rafdb/basic'
batch_size = 32

# SECURITY: the checkpoint is a fully pickled model object, so loading it
# executes arbitrary code embedded in the file. Only load checkpoints that
# come from a trusted source.
# `weights_only=False` is passed explicitly because unpickling a whole module
# requires it and newer PyTorch releases change the implicit default.
# `map_location` lets a GPU-saved checkpoint load on a CPU-only machine.
model = torch.load(
    "../../models/raf/raf_enet_b2.pt",
    map_location='cuda' if torch.cuda.is_available() else 'cpu',
    weights_only=False,
)
# Switch to inference behaviour (eval() returns the module itself).
model = model.eval()

In [30]:
# Input resolution depends on which backbone is evaluated.
USE_ENET2 = True
use_cuda = torch.cuda.is_available()
if USE_ENET2:
    IMG_SIZE = 260
else:
    IMG_SIZE = 224

# Values shared by both pipelines.
# NOTE(review): the mean/std look like the usual ImageNet channel statistics
# -- confirm they match what the model was trained with.
_target_size = (IMG_SIZE, IMG_SIZE)
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

# Training pipeline: resize, random horizontal flip, tensor conversion, normalisation.
train_transforms = transforms.Compose([
    transforms.Resize(_target_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
])

# Evaluation pipeline: same as above but without the random flip.
test_transforms = transforms.Compose([
    transforms.Resize(_target_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
])

# Extra DataLoader arguments, only supplied when CUDA is available.
kwargs = {}
if use_cuda:
    kwargs = {'num_workers': 0, 'pin_memory': True}

In [31]:
class RafDb(data.Dataset):
    """Image/label dataset built from the RAF-DB partition list.

    Entries are read from ``EmoLabel/list_patition_label.txt`` (the file name
    is spelled this way on disk) and assigned to a split by the prefix of
    their file name (``train`` or ``test``). Images are looked up under
    ``Image/aligned`` with an ``_aligned.jpg`` suffix.

    Args:
        dataset_path: Root directory of the dataset.
        phase: ``'train'`` or ``'test'``.
        cache_data: When true, the parsed table is stored in / read from
            ``rafdb_<phase>.csv`` inside ``dataset_path``.
        data_transforms: Optional callable applied to each PIL image.
        update_cache: When true, an existing cache file is ignored and
            rewritten.

    Raises:
        TypeError: If ``phase`` is neither ``'train'`` nor ``'test'``.
    """

    def __init__(self, dataset_path, phase, cache_data=True, data_transforms=None, update_cache=False):
        # Validate first so that an invalid phase never parses the label list
        # or leaves a bogus cache file behind. TypeError is kept for
        # backward compatibility with existing callers.
        if phase not in ('train', 'test'):
            raise TypeError(f"Invalid value for phase {phase}")

        self.phase = phase
        self.transforms = data_transforms
        self.dataset_path = dataset_path

        if cache_data:
            cache_path = os.path.join(dataset_path, f'rafdb_{phase}.csv')
            if os.path.exists(cache_path) and not update_cache:
                df = pd.read_csv(cache_path)
                # Caches written with the index included carry a stray
                # positional column; drop it so the schema is always the same.
                df = df.drop(columns=['Unnamed: 0'], errors='ignore')
            else:
                df = self.load_data()
                # index=False keeps the row index out of the cached file.
                df.to_csv(cache_path, index=False)
        else:
            df = self.load_data()

        # The split is encoded in the file-name prefix.
        self.data = df[df['file_name'].str.startswith(phase)]

        self.file_paths = self.data.loc[:, 'file_path'].values
        # Shift the labels from the list file down by one so they can be used
        # directly as class indices.
        self.labels = self.data.loc[:, 'label'].values - 1
        print(f'{phase} set: {len(self)} images')

    def load_data(self):
        """Parse the partition list and attach the aligned-image path.

        Returns:
            DataFrame with columns ``file_name``, ``label`` and ``file_path``
            covering every entry of the list file (both splits).
        """
        list_file = os.path.join(self.dataset_path, 'EmoLabel', 'list_patition_label.txt')
        df = pd.read_csv(list_file, sep=' ', header=None, names=['file_name', 'label'])

        image_dir = os.path.join(self.dataset_path, 'Image', 'aligned')
        # "<stem>.jpg" in the list corresponds to "<stem>_aligned.jpg" on disk.
        df['file_path'] = [
            os.path.join(image_dir, name.split(".")[0] + '_aligned.jpg')
            for name in df.loc[:, 'file_name'].values
        ]
        return df

    def get_weights(self):
        """Return inverse-frequency class weights for this split.

        Returns:
            Dict mapping each label present in the split to a weight, scaled
            so that the most frequent class has weight 1. Empty when the
            split contains no samples.
        """
        sample_label, sample_counts = np.unique(self.labels, return_counts=True)
        if sample_counts.size == 0:
            return {}
        cw = 1 / sample_counts
        # The smallest inverse count belongs to the largest class.
        cw /= cw.min()
        return dict(zip(sample_label, cw))

    def __len__(self):
        """Number of images in this split."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            Tuple ``(image, label)``; the image is an RGB PIL image, or the
            output of ``data_transforms`` when one was supplied.
        """
        path = self.file_paths[idx]
        # The context manager closes the file handle as soon as the pixels
        # have been decoded; convert() returns an independent image.
        with Image.open(path) as raw_image:
            image = raw_image.convert('RGB')
        label = self.labels[idx]

        if self.transforms is not None:
            image = self.transforms(image)

        return image, label

In [32]:
class EmotionLabel:
    """Two-way lookup between emotion names and their integer class indices."""

    def __init__(self):
        # The position of a name in this list is its class index.
        self.labels = ['Surprise', 'Fear', 'Disgust', 'Happiness', 'Sadness', 'Anger', 'Neutral']
        self.index_to_label = dict(enumerate(self.labels))
        self.label_to_index = {name: position for position, name in self.index_to_label.items()}

    def get_index(self, label):
        """Return the class index for ``label``, or None if it is unknown."""
        return self.label_to_index.get(label)

    def get_label(self, index):
        """Return the emotion name for ``index``, or None if it is unknown."""
        return self.index_to_label.get(index)

In [33]:
# Datasets: the 'test' partition serves as the validation set here.
# update_cache=True rebuilds the cached CSV on every run.
train_set = RafDb(
    dataset_dir,
    'train',
    data_transforms=train_transforms,
    update_cache=True,
)
val_set = RafDb(
    dataset_dir,
    'test',
    data_transforms=test_transforms,
    update_cache=True,
)

# Loaders: only the training loader shuffles.
train_loader = data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    **kwargs,
)
val_loader = data.DataLoader(
    val_set,
    batch_size=batch_size,
    shuffle=False,
    **kwargs,
)

train set: 12271 images
test set: 3068 images


In [34]:
# Name lookup for class indices, reused by later cells.
mapper = EmotionLabel()

# Show the inverse-frequency weight of every class in the training split.
class_weights = train_set.get_weights()
for label_idx in class_weights:
    weight = class_weights[label_idx]
    print(f'{mapper.get_label(label_idx)}: {weight:.2f}')

Surprise: 3.70
Fear: 16.98
Disgust: 6.66
Happiness: 1.00
Sadness: 2.41
Anger: 6.77
Neutral: 1.89


In [35]:
# Run the model over the validation split and collect raw scores and labels.
y_val, y_scores_val = [], []
model.eval()

# Feed inputs on whatever device the model's parameters live on instead of
# assuming CUDA.
device = next(model.parameters()).device

# no_grad(): inference only, so skip building the autograd graph.
# val_loader is not shuffled, so the collected rows stay in val_set order.
with torch.no_grad():
    for images, labels in val_loader:
        scores = model(images.to(device))
        y_scores_val.append(scores.cpu().numpy())
        y_val.append(labels.numpy())

# Stitch the per-batch arrays into (num_samples, num_classes) / (num_samples,).
y_scores_val = np.concatenate(y_scores_val)
y_val = np.concatenate(y_val)
print(y_scores_val.shape, y_val.shape)

(3068, 7) (3068,)


In [36]:
# Predicted class = index of the highest score in each row.
y_pred = y_scores_val.argmax(axis=1)
_num_correct = (y_val == y_pred).sum()
accuracy = 100.0 * _num_correct / len(y_val)
print(f"Validation accuracy: {accuracy}")

y_train = np.array(train_set.labels)

# Per-class accuracy, printed with the train/validation sample counts.
for i in range(y_scores_val.shape[1]):
    _is_class = (y_val == i)
    _val_acc = (y_pred[_is_class] == i).sum() / _is_class.sum()
    print('%s (%d/%d) -- Accuracy: %f' %(mapper.get_label(i), (y_train == i).sum(), _is_class.sum(), (100 * _val_acc)))

Validation accuracy: 87.1251629726206
Surprise (1290/329) -- Accuracy: 87.234043
Fear (281/74) -- Accuracy: 55.405405
Disgust (717/160) -- Accuracy: 65.000000
Happiness (4772/1185) -- Accuracy: 92.995781
Sadness (1982/478) -- Accuracy: 84.309623
Anger (705/162) -- Accuracy: 85.802469
Neutral (2524/680) -- Accuracy: 87.794118


In [37]:
# Precision, recall and F1 on the validation predictions with average='macro'.
precision, recall, f1 = (
    metric(y_val, y_pred, average='macro')
    for metric in (precision_score, recall_score, f1_score)
)

print(f'Precision: {precision:.2f}\nRecall: {recall:.2f}\nF1 Score: {f1:.2f}')

Precision: 0.81
Recall: 0.80
F1 Score: 0.80


In [38]:
# Precision, recall and F1 on the validation predictions with average='weighted'.
precision, recall, f1 = (
    metric(y_val, y_pred, average='weighted')
    for metric in (precision_score, recall_score, f1_score)
)

print(f'Precision: {precision:.2f}\nRecall: {recall:.2f}\nF1 Score: {f1:.2f}')

Precision: 0.87
Recall: 0.87
F1 Score: 0.87
