In [1]:
import utils.plot as plot
import utils.fer2013 as fer2013
import utils.daisee as daisee
from utils.hparams import HPS
import utils.loops as loops

import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler
import torch

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 0. Define the Dataset to use

In [2]:
# Dataset selector for the rest of the notebook. The data-loading cell branches on
# exactly two values: "FER2013" and "DAiSEE". Swap the commented line to switch.
# dataset_name = "FER2013"
dataset_name = "DAiSEE"

# 1. Load DataLoader and apply data augmentation strategy to train_loader

In [3]:
import utils.transforms as transforms

tf_name = 'simple'  # Modify to change data augmentation pipeline used
augment_tf = transforms.get_transform(tf_name)
apply_dropout_tf = False

# Per-dataset configuration:
#   - loaders come from the matching utils module,
#   - `benchmark` is None for FER2013 and "Engagement" for DAiSEE,
#   - `num_classes` is the output size passed to the model constructors below.
if dataset_name == "FER2013":
    train_loader, valid_loader, test_loader = fer2013.get_dataloaders(augment_tf, HPS['batch_size'], apply_dropout_tf=apply_dropout_tf)
    benchmark = None
    num_classes = 7
elif dataset_name == "DAiSEE":
    train_loader, valid_loader, test_loader = daisee.get_dataloaders(augment_tf, HPS['batch_size'], apply_dropout_tf=apply_dropout_tf)
    benchmark = "Engagement"
    num_classes = 4
else:
    # Fail here instead of with a NameError on train_loader / num_classes / benchmark
    # in some later cell (or, worse, silently reusing stale values from a previous run).
    raise ValueError(f'Unsupported dataset_name {dataset_name!r}; expected "FER2013" or "DAiSEE"')

## 1a. Plot the data augmentation strategy applied to train_loader

In [None]:
# Preview the augmentation pipeline using the same augment_tf / apply_dropout_tf
# values that were passed to get_dataloaders when the loaders were built.
plot.plot_augmentation(augment_tf, dataset_name, benchmark, apply_dropout_tf=apply_dropout_tf)

# 2. Dataset Attributes

## 2a. Plot Dataset 

In [None]:
# train_loader is the loader that was built with augment_tf, so this plot is fed
# from the training loader rather than the validation or test loader.
plot.plot_dataset(train_loader, dataset_name, benchmark)

## 2b. Show class distribution

In [None]:
# NOTE(review): no loader is passed here, so plot_class_distribution presumably
# locates the labels from dataset_name / benchmark on its own — confirm in utils.plot.
plot.plot_class_distribution(dataset_name, benchmark)

# 3. Train Model

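Run exactly one of the model cells below to choose the architecture to train; each cell rebinds `model` and `model_name`, so the last one executed wins. Available models: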
- VGGNet
- VGG16
- ResNet18
- ResNet50

## Train from scratch

In [4]:
# VGGNet
# Binds `model` (sized for `num_classes` outputs) and `model_name`; `model_name` is
# later embedded in the checkpoint filename written during training.
from models import vggnet_finetuned
model = vggnet_finetuned.VggNet(num_classes=num_classes)
model_name = 'VGGNet'

In [None]:
# VGG16 
# Alternative to the other model cells: rebinds the same `model` / `model_name` names.
from models import vgg16
model = vgg16.Vgg16(num_classes=num_classes)
model_name = 'VGG16'

In [8]:
# ResNet18
# Alternative to the other model cells: rebinds the same `model` / `model_name` names.
from models import resnet18
model = resnet18.ResNet18(num_classes=num_classes)
model_name = 'ResNet18'

In [None]:
# ResNet50
# Alternative to the other model cells: rebinds the same `model` / `model_name` names.
from models import resnet50
model = resnet50.ResNet50(num_classes=num_classes)
model_name = 'ResNet50'

In [8]:
import utils.earlystopper as es

def run_model(model, optimizer, train_loader, valid_loader, criterion, scheduler, scaler, num_epochs, model_name, tf_name, dataset_name, benchmark=None):
    """Train ``model`` for up to ``num_epochs`` epochs and checkpoint its best validation epoch.

    Every epoch calls ``loops.train_model`` on ``train_loader`` and
    ``loops.evaluate_model`` on ``valid_loader``, steps ``scheduler`` with the
    validation loss, and asks an ``es.EarlyStopper`` (fed the same validation
    loss) whether training should stop. Whenever the validation accuracy
    exceeds the best value seen so far, ``model.state_dict()`` is saved to
    ``./models/outputs/{model_name}_{tf_name}_{dataset_name}[-{benchmark}]_best_valid.pth``
    (the ``-{benchmark}`` part is present only when ``benchmark`` is truthy).

    Args:
        model: network to train; moved to ``DEVICE`` and updated in place.
        optimizer: passed through to ``loops.train_model``.
        train_loader: loader handed to ``loops.train_model``.
        valid_loader: loader handed to ``loops.evaluate_model``.
        criterion: loss passed to both the train and the evaluate loop.
        scheduler: stepped once per epoch as ``scheduler.step(val_loss)``, so it
            must accept a metric; ``get_last_lr()`` is read before and after the
            step to log learning-rate reductions.
        scaler: passed through to ``loops.train_model``.
        num_epochs: upper bound on epochs; early stopping may end the run sooner.
        model_name: used in the log line and the checkpoint filename.
        tf_name: used in the log line and the checkpoint filename.
        dataset_name: used in the log line and the checkpoint filename.
        benchmark: forwarded to the train/evaluate loops; required (truthy)
            when ``dataset_name == "DAiSEE"``.

    Returns:
        dict with per-epoch lists under ``'train_accuracy'``, ``'train_loss'``,
        ``'valid_accuracy'``, ``'valid_loss'``; ``'y_true'`` / ``'y_pred'`` taken
        from the epoch with the best validation accuracy (empty lists if no
        epoch scored above 0.0); and ``'gender_accuracy'`` holding per-epoch
        ``'male'`` and ``'female'`` lists.

    Raises:
        ValueError: if ``dataset_name`` is ``"DAiSEE"`` and no benchmark is given.
    """
    import os  # function-scope import: keeps this cell independent of the notebook's import cell

    if dataset_name == "DAiSEE" and not benchmark:
        raise ValueError('Benchmark metric not provided for DAiSEE dataset')

    # The checkpoint path does not change during the run, so build it once.
    if benchmark:
        output_model = f'./models/outputs/{model_name}_{tf_name}_{dataset_name}-{benchmark}_best_valid.pth'
    else:
        output_model = f'./models/outputs/{model_name}_{tf_name}_{dataset_name}_best_valid.pth'
    # torch.save does not create missing parent directories; create it up front so a
    # missing folder cannot abort the run after an epoch of training has been spent.
    os.makedirs(os.path.dirname(output_model), exist_ok=True)

    print(f'Training {model_name} with transform {tf_name} on {DEVICE} using {dataset_name} dataset')
    if benchmark:
        print(f'Using benchmark metric: {benchmark}')
    model.to(DEVICE)
    best_accuracy_val = 0.0
    train_acc = []
    train_loss = []
    valid_acc = []
    valid_loss = []
    best_y_true = []
    best_y_pred = []
    male_acc = []
    female_acc = []
    early_stopper = es.EarlyStopper()
    for epoch in range(num_epochs):
        print('.' * 64)
        print(f"--- Epoch {epoch + 1}/{num_epochs} ---")

        tr_accuracy, tr_loss = loops.train_model(model, train_loader, optimizer, criterion, scaler, epoch, num_epochs, benchmark)
        print(f'train_loss: {tr_loss:.4f} - train_accuracy: {tr_accuracy:.4f}')

        val_accuracy, val_loss, y_true, y_pred, male_accuracy, female_accuracy = loops.evaluate_model(model, valid_loader, criterion, benchmark)

        # Update learning rate (the scheduler is driven by the validation loss)
        prev_lr = scheduler.get_last_lr()[0]
        scheduler.step(val_loss)
        curr_lr = scheduler.get_last_lr()[0]

        if prev_lr > curr_lr:
            print(f'Updating lr {prev_lr}->{curr_lr}')

        # Update best model on validation dataset (selected by accuracy, not loss)
        if val_accuracy > best_accuracy_val:
            best_y_true = y_true
            best_y_pred = y_pred
            best_accuracy_val = val_accuracy
            torch.save(model.state_dict(), output_model)

        train_acc.append(tr_accuracy)
        train_loss.append(tr_loss)
        valid_acc.append(val_accuracy)
        valid_loss.append(val_loss)
        male_acc.append(male_accuracy)
        female_acc.append(female_accuracy)

        # Early stopping
        if early_stopper.early_stop(val_loss):
            print(f'Stopping early at Epoch {epoch + 1}, min val loss failed to decrease after {early_stopper.get_patience()} epochs')
            break

    return {
        'train_accuracy': train_acc,
        'train_loss': train_loss,
        'valid_accuracy': valid_acc,
        'valid_loss': valid_loss,
        'y_true': best_y_true,
        'y_pred': best_y_pred,
        'gender_accuracy': {
            "male": male_acc,
            "female": female_acc,
        }
    }

In [None]:
optimizer = optim.SGD(model.parameters(), lr=HPS['lr'], momentum=0.9, nesterov=True, weight_decay=0.0001)
# run_model calls scheduler.step(val_loss), so the plateau detector must treat a LOWER
# value as an improvement. With mode='max' a falling validation loss is counted as
# "no improvement" and the learning rate is cut while the model is still getting better.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.75, patience=5)
scaler = GradScaler()

# Class weights for the loss come from the project helper for the selected dataset/benchmark.
train_class_weights = loops.get_train_class_weights(dataset_name, benchmark)
criterion = nn.CrossEntropyLoss(weight=train_class_weights)
results = run_model(model, optimizer, train_loader, valid_loader, criterion, scheduler, scaler, HPS['num_epochs'], model_name, tf_name, dataset_name, benchmark)

# 4. Model Evaluation

In [None]:
# `results` is the dict returned by run_model; it carries the per-epoch
# train/valid accuracy and loss lists.
plot.plot_training_history(results)

In [None]:
# results["gender_accuracy"] holds the per-epoch "male" and "female" accuracy lists
# that run_model collected from loops.evaluate_model on the validation loader.
plot.plot_gender_history(results["gender_accuracy"])

In [None]:
# run_model keeps y_true / y_pred from the epoch with the best validation accuracy;
# presumably those are what the confusion matrix is built from — confirm in utils.plot.
plot.plot_confusion_matrix(results, benchmark, dataset_name)

In [None]:
# Receives the same `results` dict as the confusion-matrix cell; presumably it also
# reports on the best-epoch y_true / y_pred — confirm in utils.plot.
plot.display_classification_report(results, benchmark, dataset_name)

# 5. Model Predictions

In [None]:
# Load model for prediction
# The filename mirrors what run_model writes: "-{benchmark}" is present only when a
# benchmark is set (DAiSEE), and absent otherwise (FER2013).
if benchmark:
    MODEL_PATH = f'./models/outputs/{model_name}_{tf_name}_{dataset_name}-{benchmark}_best_valid.pth'
else:
    MODEL_PATH = f'./models/outputs/{model_name}_{tf_name}_{dataset_name}_best_valid.pth'
# Re-create the architecture that was actually selected above. MODEL_PATH is keyed by
# model_name, so hard-coding VggNet here breaks (state_dict mismatch, or NameError if the
# VGGNet cell was never run) as soon as any other model cell was used.
model = type(model)(num_classes=num_classes)
# map_location lets a checkpoint saved during a CUDA run be loaded on a CPU-only machine.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))

In [None]:
# Plot predictions
# `model` is whatever the preceding cell loaded from MODEL_PATH; it is passed
# together with the held-out test_loader.
plot.plot_predictions(model, test_loader, benchmark, dataset_name)

## a) Compare simple and occlusion_aware model predictions

In [None]:
# Load model for prediction
# Two checkpoints of the currently selected architecture are compared: one trained with
# the 'simple' transform and one with 'occlusion_aware'.
# run_model only puts "-{benchmark}" in the filename when a benchmark is set; without this
# guard FER2013 (benchmark=None) would look for a non-existent "...-None_best_valid.pth".
benchmark_suffix = f'-{benchmark}' if benchmark else ''

MODEL_PATH = f'./models/outputs/{model_name}_simple_{dataset_name}{benchmark_suffix}_best_valid.pth'
# type(model) re-creates whichever architecture model_name refers to, instead of
# assuming VggNet; map_location keeps CUDA-saved checkpoints loadable on CPU.
model_simple = type(model)(num_classes=num_classes)
model_simple.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))

MODEL_PATH = f'./models/outputs/{model_name}_occlusion_aware_{dataset_name}{benchmark_suffix}_best_valid.pth'
model_occlusion_aware = type(model)(num_classes=num_classes)
model_occlusion_aware.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))


# Plot comparison predictions on occluded images
plot.plot_compare_predictions(model_simple, model_occlusion_aware, test_loader, benchmark, dataset_name)