# Train CNN Model

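This notebook trains a CNN for emotion classification (8 classes) on the FERPlus dataset. The training-label mode (`majority`, `probability`, `crossentropy` or `multi_target`) selects how the training labels are used and which loss is applied. Validation is always scored against the majority label.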

## Imports

In [6]:
import sys
import time
import os
import math
import csv
import numpy as np
import logging
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from models import build_model
from ferplus import FERPlusReader, FERPlusParameters, display_summary

## Helper Functions

In [None]:
def cost_func(training_mode, prediction, target):
    '''
    Return the scalar training loss of `prediction` (raw, un-normalised model
    outputs) against `target` for the given training mode.

    * 'majority', 'probability', 'crossentropy': cross entropy. nn.CrossEntropyLoss
      accepts either class indices (integer tensor of shape (N,)) or per-class
      probabilities (float tensor with the same shape as `prediction`).
    * 'multi_target': binary cross entropy on logits, scoring every class
      independently so that several labels can be treated as correct at once.
      BCEWithLogitsLoss requires a floating-point target with the same shape as
      `prediction`, so integer (e.g. multi-hot long) targets are cast to the
      prediction's dtype first.
      NOTE(review): confirm this per-class BCE matches the intended multi-target
      objective.

    Raises:
        ValueError: if `training_mode` is not one of the modes above.
    '''
    if training_mode in ('majority', 'probability', 'crossentropy'):
        # Cross Entropy.
        return nn.CrossEntropyLoss()(prediction, target)
    if training_mode == 'multi_target':
        # BCEWithLogitsLoss rejects integer targets; cast instead of crashing.
        return nn.BCEWithLogitsLoss()(prediction, target.to(prediction.dtype))
    raise ValueError("Unknown training mode: {!r}".format(training_mode))

## Main Function

In [None]:
def _prepare_batch(images, labels):
    '''
    Convert one minibatch returned by a data reader into tensors.

    The reader's label format is not visible from this notebook, so both forms are
    accepted (NOTE(review): confirm against ferplus.FERPlusReader.next_minibatch):
      * 1-D labels are treated as class indices and become a long tensor.
      * 2-D labels are treated as per-class target vectors (one-hot, probability or
        multi-hot) and become a float tensor; their argmax is used as the
        ground-truth class for accuracy.

    Returns:
        (inputs, loss_targets, class_indices): float32 model input, the target to
        pass to cost_func, and a 1-D long tensor of classes for accuracy.
    '''
    inputs = torch.as_tensor(images, dtype=torch.float32)
    label_tensor = torch.as_tensor(labels)
    if label_tensor.dim() == 1:
        class_indices = label_tensor.long()
        return inputs, class_indices, class_indices
    # Casting a target vector to long would truncate fractional probabilities to 0,
    # and comparing (N,) predictions with (N, C) labels would broadcast incorrectly.
    loss_targets = label_tensor.float()
    return inputs, loss_targets, loss_targets.argmax(dim=1)


def _evaluate_accuracy(model, data_reader, minibatch_size):
    '''
    Put `model` in eval mode and return the fraction of samples from `data_reader`
    whose arg-max prediction equals the ground-truth class.

    Consumes the reader until has_more() is False; the reader is NOT reset here.
    Returns 0.0 if the reader yields no samples.
    '''
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        while data_reader.has_more():
            images, labels, current_batch_size = data_reader.next_minibatch(minibatch_size)
            inputs, _, class_indices = _prepare_batch(images, labels)
            predicted = model(inputs).argmax(dim=1)
            correct += (predicted == class_indices).sum().item()
            seen += current_batch_size
    return correct / seen if seen else 0.0


def main(base_folder, training_mode='majority', model_name='VGG13', max_epochs=100,
         minibatch_size=32, learning_rate=0.01, momentum=0.9):
    '''
    Train `model_name` on FERPlus and log per-epoch training and validation metrics.

    Args:
        base_folder: dataset root holding the 'FER2013Train', 'FER2013Valid' and
            'FER2013Test' sets (each read with "label.csv"). A 'models' sub-folder
            is created here for outputs.
        training_mode: label mode for the training set, passed to FERPlusParameters
            and cost_func ('majority', 'probability', 'crossentropy' or
            'multi_target'). Validation and test always use "majority".
        model_name: architecture name passed to build_model.
        max_epochs: number of passes over the training set.
        minibatch_size: samples requested per reader minibatch.
        learning_rate, momentum: SGD hyper-parameters.

    Logs are written to <base_folder>/models/<model_name>_<training_mode>/train.log
    (overwritten each call) and echoed to stderr.
    '''
    output_model_path = os.path.join(base_folder, 'models')
    output_model_folder = os.path.join(output_model_path, model_name + '_' + training_mode)
    os.makedirs(output_model_folder, exist_ok=True)

    # force=True (Python 3.8+) closes and removes root handlers left by a previous
    # call in the same kernel. Without it basicConfig does nothing once the root
    # logger has any handler (so no log file is written), and every re-run stacks
    # another StreamHandler, duplicating each console line.
    logging.basicConfig(filename=os.path.join(output_model_folder, "train.log"), filemode='w',
                        level=logging.INFO, force=True)
    logging.getLogger().addHandler(logging.StreamHandler())

    logging.info("Starting with training mode {} using {} model and max epochs {}.".format(training_mode, model_name, max_epochs))

    num_classes = 8  # number of emotion classes the model predicts
    model = build_model(num_classes, model_name)

    logging.info("Loading data...")
    train_params = FERPlusParameters(num_classes, 48, 48, training_mode, False)
    test_and_val_params = FERPlusParameters(num_classes, 48, 48, "majority", True)

    train_data_reader = FERPlusReader.create(base_folder, ['FER2013Train'], "label.csv", train_params)
    val_data_reader = FERPlusReader.create(base_folder, ['FER2013Valid'], "label.csv", test_and_val_params)
    test_data_reader = FERPlusReader.create(base_folder, ['FER2013Test'], "label.csv", test_and_val_params)

    display_summary(train_data_reader, val_data_reader, test_data_reader)

    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)

    logging.info("Start training...")
    for epoch in range(max_epochs):
        model.train()
        train_data_reader.reset()
        val_data_reader.reset()
        # NOTE(review): the test reader is reset every epoch but not read in this cell.
        test_data_reader.reset()

        training_loss = 0.0
        training_correct = 0
        training_samples = 0
        while train_data_reader.has_more():
            images, labels, current_batch_size = train_data_reader.next_minibatch(minibatch_size)
            inputs, targets, class_indices = _prepare_batch(images, labels)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = cost_func(training_mode, outputs, targets)
            loss.backward()
            optimizer.step()

            # loss.item() is a batch mean; weight it by batch size so the epoch
            # figure below is a true per-sample average.
            training_loss += loss.item() * current_batch_size
            training_correct += (outputs.argmax(dim=1) == class_indices).sum().item()
            training_samples += current_batch_size

        # Guard against an empty training reader.
        training_loss /= max(training_samples, 1)
        training_accuracy = training_correct / max(training_samples, 1)
        logging.info("Epoch {}: training loss: {:.4f}, training accuracy: {:.2f}%".format(epoch, training_loss, training_accuracy * 100))

        val_accuracy = _evaluate_accuracy(model, val_data_reader, minibatch_size)
        logging.info("Epoch {}: validation accuracy: {:.2f}%".format(epoch, val_accuracy * 100))