In [None]:
import utils.plot as plot
import utils.fer2013 as fer2013
from utils.hparams import HPS
import utils.loops as loops

import torch
# Device shared by every cell below: the GPU when CUDA is usable, the CPU otherwise.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### 1. Load the DataLoaders and apply the data augmentation strategy to `train_loader`

In [None]:
import utils.transforms as transforms 
# Name of the augmentation pipeline. It is resolved by transforms.get_transform and is
# later passed to run_model, where it becomes part of the checkpoint filename.
tf_name = 'occlusion_aware' # Modify to change the data augmentation pipeline used
augment_tf = transforms.get_transform(tf_name)
# Forwarded both to the dataloader factory here and to the augmentation preview plot.
# NOTE(review): presumably toggles an extra dropout-style transform — confirm in utils.fer2013.
apply_dropout_tf = True

# Build the three loaders with the batch size taken from the shared hyper-parameters.
# NOTE(review): per the section heading only train_loader is augmented — confirm in get_dataloaders.
train_loader, valid_loader, test_loader = fer2013.get_dataloaders(augment_tf, HPS['batch_size'], apply_dropout_tf=apply_dropout_tf)

#### 1a. Plot the data augmentation strategy applied to train_loader

In [None]:
# Preview the augmentation pipeline selected above, with the same apply_dropout_tf setting
# that was given to the dataloaders.
plot.plot_augmentation(augment_tf, apply_dropout_tf=apply_dropout_tf)

### 2. Class Weights

#### 2a. Plot Dataset 

In [None]:
# Visualise the training split.
# NOTE(review): presumably a per-class sample distribution — confirm in utils.plot.
plot.plot_fer_dataset(train_loader)

#### 2b. Create Class Weights for Imbalanced Data

In [None]:
# Class weights for the imbalanced training data, requested for DEVICE.
# NOTE(review): presumably a per-class weight tensor placed on DEVICE — confirm in
# utils.plot.get_class_weights.
train_class_weights = plot.get_class_weights(DEVICE)
print(train_class_weights)

### 3. Train and Evaluate Functions

In [None]:
import utils.earlystopper as es

def run_model(model, optimizer, train_loader, valid_loader, criterion, scheduler, scaler, num_epochs, model_name, tf_name):
    """Train ``model`` for up to ``num_epochs`` epochs and checkpoint the best validation accuracy.

    Every epoch runs ``loops.train_model`` followed by ``loops.evaluate_model``,
    steps ``scheduler`` with the validation loss, saves the model's state dict
    whenever the validation accuracy beats the best value seen so far, and asks
    an ``EarlyStopper`` whether training should stop.

    Args:
        model: Network to train; it is moved to ``DEVICE`` in place.
        optimizer: Optimizer handed to ``loops.train_model``.
        train_loader: Loader for the training split.
        valid_loader: Loader for the validation split.
        criterion: Loss used for both training and evaluation.
        scheduler: LR scheduler whose ``step`` accepts a metric. It is stepped
            with the validation *loss*, so a plateau scheduler should be
            configured to minimise its metric.
        scaler: Gradient scaler forwarded to ``loops.train_model``.
        num_epochs: Upper bound on the number of epochs.
        model_name: Used in the log line and in the checkpoint filename.
        tf_name: Augmentation name, used in the log line and checkpoint filename.

    Returns:
        dict with the per-epoch lists ``train_accuracy``, ``train_loss``,
        ``valid_accuracy`` and ``valid_loss``, plus ``y_true`` / ``y_pred`` as
        returned by ``loops.evaluate_model`` for the epoch with the best
        validation accuracy (empty lists if no epoch improved on 0.0).
    """
    import os  # local import: keeps the cell runnable without touching the import cell

    print(f'Training {model_name} with transform {tf_name} on {DEVICE}')
    model.to(DEVICE)

    # Ensure the checkpoint directory exists before training starts; otherwise
    # torch.save would only fail after the first (possibly long) epoch.
    checkpoint_path = f'./models/outputs/{model_name}_{tf_name}_best_valid.pth'
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    best_accuracy_val = 0.0
    train_acc = []
    train_loss = []
    valid_acc = []
    valid_loss = []
    best_y_true = []
    best_y_pred = []
    early_stopper = es.EarlyStopper()
    for epoch in range(num_epochs):
        print('.' * 64)
        print(f"--- Epoch {epoch + 1}/{num_epochs} ---")

        tr_accuracy, tr_loss = loops.train_model(model, train_loader, optimizer, criterion, scaler, epoch, num_epochs)
        print(f'train_loss: {tr_loss:.4f} - train_accuracy: {tr_accuracy:.4f}')

        val_accuracy, val_loss, y_true, y_pred = loops.evaluate_model(model, valid_loader, criterion)

        # Update learning rate; compare the LR before and after the step so a
        # reduction can be reported.
        prev_lr = scheduler.get_last_lr()[0]
        scheduler.step(val_loss)
        curr_lr = scheduler.get_last_lr()[0]

        if prev_lr > curr_lr:
            print(f'Updating lr {prev_lr}->{curr_lr}')

        # Update best model on validation dataset (selection is by accuracy,
        # while the scheduler and early stopping follow the loss).
        if val_accuracy > best_accuracy_val:
            best_y_true = y_true
            best_y_pred = y_pred
            best_accuracy_val = val_accuracy
            torch.save(model.state_dict(), checkpoint_path)

        train_acc.append(tr_accuracy)
        train_loss.append(tr_loss)
        valid_acc.append(val_accuracy)
        valid_loss.append(val_loss)

        # Early stopping
        if early_stopper.early_stop(val_loss):
            print(f'Stopping early at Epoch {epoch + 1}, min val loss failed to decrease after {early_stopper.get_patience()} epochs')
            break

    return {
        'train_accuracy': train_acc,
        'train_loss': train_loss,
        'valid_accuracy': valid_acc,
        'valid_loss': valid_loss,
        'y_true': best_y_true,
        'y_pred': best_y_pred
    }

### 4. Train Model

In [None]:
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler

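Run exactly one of the model cells below to choose the architecture to train; the cell you run last sets `model` and `model_name`. Available models: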
- VGGNet
- VGG16 (no selection cell for it is provided below)
- ResNet18
- ResNet50
- EfficientNetB7

In [None]:
# VGGNet
from models import vggnet_finetuned
model = vggnet_finetuned.VggNet()
model_name = 'VGGNet'

In [None]:
# ResNet18
from models import resnet18
model = resnet18.ResNet18()
model_name = 'ResNet18'

In [None]:
# ResNet50
from models import resnet50
model = resnet50.ResNet50()
model_name = 'ResNet50'

In [None]:
# EfficientNetB7
from models import efnb7
model = efnb7.EfficientNetB7()
model_name = 'EfficientNetB7'

In [None]:
# Loss, optimiser, LR schedule and AMP gradient scaler for the model selected above.
# NOTE(review): train_class_weights (computed in section 2b) is not passed to the loss.
# If class re-weighting is intended and it is a per-class weight tensor, use
# nn.CrossEntropyLoss(weight=train_class_weights) — confirm before changing.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=HPS['lr'], momentum=0.9, nesterov=True, weight_decay=0.0001)
# run_model steps this scheduler with the validation *loss*, so the plateau detector
# must treat a decrease as improvement (mode='min'). With mode='max' every epoch in
# which the loss keeps falling counts as a "bad" epoch and the LR is cut while
# training is still improving.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.75, patience=5)
scaler = GradScaler()

# Train and keep the per-epoch history plus the best-epoch validation predictions.
results = run_model(model, optimizer, train_loader, valid_loader, criterion, scheduler, scaler, HPS['num_epochs'], model_name, tf_name)

### 5. Model Evaluation

In [None]:
# Plot the training history; `results` holds the per-epoch train/valid accuracy and
# loss lists returned by run_model.
plot.plot_training_history(results)

In [None]:
# Confusion matrix. results['y_true'] / results['y_pred'] come from the validation
# epoch with the best accuracy, not from the test set.
plot.plot_confusion_matrix(results)

In [None]:
# Classification report for the same `results` dict.
# NOTE(review): presumably per-class precision/recall on the best-epoch validation
# predictions — confirm in utils.plot.
plot.display_classification_report(results)

### 6. Model Predictions

In [None]:
# Load model for prediction
# Imported here so the cell does not depend on the VGGNet selection cell having been
# run (another architecture may have been trained in this session).
from models import vggnet_finetuned

# NOTE(review): the path is fixed to the VGGNet / occlusion_aware checkpoint and does
# not follow the model_name / tf_name chosen earlier in the notebook.
MODEL_PATH = './models/outputs/VGGNet_occlusion_aware_best_valid.pth'
model = vggnet_finetuned.VggNet()
model.to(DEVICE)
# Inference mode: disables dropout and uses running batch-norm statistics.
model.eval()
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))

In [None]:
# Plot predictions
# NOTE(review): presumably runs the loaded model on samples from test_loader on DEVICE
# and shows predicted vs. true labels — confirm in utils.plot.
plot.plot_predictions(model, test_loader, DEVICE)