In [1]:
import random
import itertools
import os

import utils.utils as utils
import utils.datasets as datasets

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import f1_score
from skimage.transform import resize
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder

pd.set_option("display.max_columns", 50)
%load_ext autoreload
%autoreload 2

In [2]:
# make sure everything is deterministic
random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed(1)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [33]:
device = torch.device("cuda:1")

In [4]:
df, load_data = datasets.megc("cropped")

uv_frames = np.load("../data/megc_uv_frames_secrets_of_OF.npy")
uv_frames = resize(uv_frames, (uv_frames.shape[0], 3, 60, 60))

In [5]:
le = LabelEncoder()
labels = le.fit_transform(df["emotion"])
dataset = le.fit_transform(df["dataset"])

In [6]:
class MEData(Dataset):
    def __init__(self, frames, labels, dataset, transform=None):
        self.frames = frames
        self.labels = labels
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return self.frames.shape[0]

    def __getitem__(self, idx):
        sample = self.frames[idx, ...]
        if self.transform:
            sample = self.transform(sample)
        label = self.labels[idx]
        dataset = self.dataset[idx]
        return sample, label, dataset

In [27]:
class BDCNN(nn.Module):
    def __init__(self, size):
        super(BDCNN, self).__init__()
        self.feature_num = 128
        self.size = size
        self.net_dict = {
            "conv2_0": {8: 5400, 10: 3456, 20: 1152, 30: 512, 40: 288},
            "conv2_1": {8: 384, 10: 384, 20: 256, 30: 128, 40: 32},
            "fc1_0": {8: 3456, 10: 6144, 20: 12544, 30: 12800, 40: 6272},
        }
        self.pool = nn.MaxPool2d(3, 3, 1)
        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=8, kernel_size=3, stride=1, padding=1
        )
        self.conv2 = nn.Conv2d(
            self.net_dict["conv2_0"][self.size], self.net_dict["conv2_1"][self.size], 1
        )
        self.fc1 = nn.Linear(self.net_dict["fc1_0"][self.size], 1024)
        self.fc2 = nn.Linear(1024, 128)

    def forward(self, input):
        input1 = input[:, 0:1]
        input2 = input[:, 1:2]
        input3 = input[:, 2:3]

        blocks1 = []
        n = 120 // self.size
        for i in range(n):
            for j in range(n):
                blocks1.append(
                    input1[
                        :,
                        :,
                        self.size * i : self.size * (i + 1),
                        self.size * j : self.size * (j + 1),
                    ]
                )
        convs1 = []
        for i in range(n**2):
            convs1.append(self.pool(F.relu(self.conv1(blocks1[i]))))
        x1 = torch.cat((convs1[0], convs1[1]), dim=1)
        for i in range(2, n**2):
            x1 = torch.cat((x1, convs1[i]), dim=1)

        blocks2 = []
        for i in range(n):
            for j in range(n):
                blocks2.append(
                    input2[
                        :,
                        :,
                        self.size * i : self.size * (i + 1),
                        self.size * j : self.size * (j + 1),
                    ]
                )
        convs2 = []
        for i in range(n**2):
            convs2.append(self.pool(F.relu(self.conv1(blocks2[i]))))
        x2 = torch.cat((convs2[0], convs2[1]), dim=1)
        for i in range(2, n**2):
            x2 = torch.cat((x2, convs2[i]), dim=1)

        blocks3 = []
        for i in range(n):
            for j in range(n):
                blocks3.append(
                    input3[
                        :,
                        :,
                        self.size * i : self.size * (i + 1),
                        self.size * j : self.size * (j + 1),
                    ]
                )
        convs3 = []
        for i in range(n**2):
            convs3.append(self.pool(F.relu(self.conv1(blocks3[i]))))
        x3 = torch.cat((convs3[0], convs3[1]), dim=1)
        for i in range(2, n**2):
            x3 = torch.cat((x3, convs3[i]), dim=1)

        x = torch.cat((x1, x2, x3), dim=1)
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x


class EstimatorCV:
    def __init__(self, feature_num, class_num):
        super(EstimatorCV, self).__init__()
        self.class_num = class_num
        self.CoVariance = torch.zeros(class_num, feature_num, feature_num).to(device)
        self.Ave = torch.zeros(class_num, feature_num).to(device)
        self.Amount = torch.zeros(class_num).to(device)

    def update_CV(self, features, labels):
        # N,C,A  batch_size,class_num,feature_num
        N = features.size(0)
        C = self.class_num
        A = features.size(1)

        NxCxFeatures = features.view(N, 1, A).expand(N, C, A)

        onehot = torch.zeros(N, C).to(device)
        onehot.scatter_(1, labels.view(-1, 1), 1)

        NxCxA_onehot = onehot.view(N, C, 1).expand(N, C, A)

        features_by_sort = NxCxFeatures.mul(NxCxA_onehot)

        Amount_CxA = NxCxA_onehot.sum(0)

        Amount_CxA[Amount_CxA == 0] = 1

        # C*A
        ave_CxA = features_by_sort.sum(0) / Amount_CxA

        # N*C*A
        var_temp = features_by_sort - ave_CxA.expand(N, C, A).mul(NxCxA_onehot)
        # permute  bmm  b*n*m b*m*p ->b*n*p
        # C*A*N C*N*A-> C*A*A
        var_temp = torch.bmm(var_temp.permute(1, 2, 0), var_temp.permute(1, 0, 2)).div(
            Amount_CxA.view(C, A, 1).expand(C, A, A)
        )

        sum_weight_CV = onehot.sum(0).view(C, 1, 1).expand(C, A, A)

        sum_weight_AV = onehot.sum(0).view(C, 1).expand(C, A)

        weight_CV = sum_weight_CV.div(
            sum_weight_CV + self.Amount.view(C, 1, 1).expand(C, A, A)
        )
        weight_CV[weight_CV != weight_CV] = 0

        weight_AV = sum_weight_AV.div(
            sum_weight_AV + self.Amount.view(C, 1).expand(C, A)
        )
        weight_AV[weight_AV != weight_AV] = 0

        additional_CV = weight_CV.mul(1 - weight_CV).mul(
            torch.bmm(
                (self.Ave - ave_CxA).view(C, A, 1), (self.Ave - ave_CxA).view(C, 1, A)
            )
        )

        self.CoVariance = (
            self.CoVariance.mul(1 - weight_CV) + var_temp.mul(weight_CV)
        ).detach() + additional_CV.detach()

        self.Ave = (self.Ave.mul(1 - weight_AV) + ave_CxA.mul(weight_AV)).detach()

        self.Amount += onehot.sum(0)


class ISDALoss(nn.Module):
    def __init__(self, feature_num, class_num, u):
        super(ISDALoss, self).__init__()
        self.estimator = EstimatorCV(feature_num, class_num)
        self.class_num = class_num
        self.cross_entropy = nn.CrossEntropyLoss(
            weight=torch.tensor(u), reduction="mean"
        )

    def isda_aug(self, fc, features, y, labels, cv_matrix, ratio):

        N = features.size(0)
        C = self.class_num
        A = features.size(1)

        weight_m = list(fc.parameters())[0]
        NxW_ij = weight_m.expand(N, C, A)
        NxW_kj = torch.gather(NxW_ij, 1, labels.view(N, 1, 1).expand(N, C, A))
        CV_temp = cv_matrix[labels]
        sigma2 = ratio * torch.bmm(
            torch.bmm(NxW_ij - NxW_kj, CV_temp), (NxW_ij - NxW_kj).permute(0, 2, 1)
        )
        sigma2 = sigma2.mul(torch.eye(C).to(device).expand(N, C, C)).sum(2).view(N, C)
        aug_result = y + 0.5 * sigma2

        return aug_result

    def forward(self, model, fc, x, target_x, ratio):
        features = model(x)
        y = fc(features)
        self.estimator.update_CV(features.detach(), target_x)
        isda_aug_y = self.isda_aug(
            fc, features, y, target_x, self.estimator.CoVariance.detach(), ratio
        )
        loss = self.cross_entropy(isda_aug_y, target_x)

        return loss, y
    
    
def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate"""
    if epoch in [80, 120]:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.1

In [27]:
class BDCNN(nn.Module):
    def __init__(self, size):
        super(BDCNN, self).__init__()
        self.feature_num = 128
        self.size = size
        self.net_dict = {
            "conv2_0": {8: 5400, 10: 3456, 20: 1152, 30: 512, 40: 288},
            "conv2_1": {8: 384, 10: 384, 20: 256, 30: 128, 40: 32},
            "fc1_0": {8: 3456, 10: 6144, 20: 12544, 30: 12800, 40: 6272},
        }
        self.pool = nn.MaxPool2d(3, 3, 1)
        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=8, kernel_size=3, stride=1, padding=1
        )
        self.conv2 = nn.Conv2d(
            self.net_dict["conv2_0"][self.size], self.net_dict["conv2_1"][self.size], 1
        )
        self.fc1 = nn.Linear(self.net_dict["fc1_0"][self.size], 1024)
        self.fc2 = nn.Linear(1024, 128)
        self.fc = nn.Linear(128, 3)
        self.isda_loss = ISDALoss(feature_num=128, class_num=3,u=[0.15,0.425,0.425])

    def forward(self, input):
        input1 = input[:, 0:1]
        input2 = input[:, 1:2]
        input3 = input[:, 2:3]

        blocks1 = []
        n = 120 // self.size
        for i in range(n):
            for j in range(n):
                blocks1.append(
                    input1[
                        :,
                        :,
                        self.size * i : self.size * (i + 1),
                        self.size * j : self.size * (j + 1),
                    ]
                )
        convs1 = []
        for i in range(n**2):
            convs1.append(self.pool(F.relu(self.conv1(blocks1[i]))))
        x1 = torch.cat((convs1[0], convs1[1]), dim=1)
        for i in range(2, n**2):
            x1 = torch.cat((x1, convs1[i]), dim=1)

        blocks2 = []
        for i in range(n):
            for j in range(n):
                blocks2.append(
                    input2[
                        :,
                        :,
                        self.size * i : self.size * (i + 1),
                        self.size * j : self.size * (j + 1),
                    ]
                )
        convs2 = []
        for i in range(n**2):
            convs2.append(self.pool(F.relu(self.conv1(blocks2[i]))))
        x2 = torch.cat((convs2[0], convs2[1]), dim=1)
        for i in range(2, n**2):
            x2 = torch.cat((x2, convs2[i]), dim=1)

        blocks3 = []
        for i in range(n):
            for j in range(n):
                blocks3.append(
                    input3[
                        :,
                        :,
                        self.size * i : self.size * (i + 1),
                        self.size * j : self.size * (j + 1),
                    ]
                )
        convs3 = []
        for i in range(n**2):
            convs3.append(self.pool(F.relu(self.conv1(blocks3[i]))))
        x3 = torch.cat((convs3[0], convs3[1]), dim=1)
        for i in range(2, n**2):
            x3 = torch.cat((x3, convs3[i]), dim=1)

        x = torch.cat((x1, x2, x3), dim=1)
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc(x)
        return x


class EstimatorCV:
    def __init__(self, feature_num, class_num):
        super(EstimatorCV, self).__init__()
        self.class_num = class_num
        self.CoVariance = torch.zeros(class_num, feature_num, feature_num).to(device)
        self.Ave = torch.zeros(class_num, feature_num).to(device)
        self.Amount = torch.zeros(class_num).to(device)

    def update_CV(self, features, labels):
        # N,C,A  batch_size,class_num,feature_num
        N = features.size(0)
        C = self.class_num
        A = features.size(1)

        NxCxFeatures = features.view(N, 1, A).expand(N, C, A)

        onehot = torch.zeros(N, C).to(device)
        onehot.scatter_(1, labels.view(-1, 1), 1)

        NxCxA_onehot = onehot.view(N, C, 1).expand(N, C, A)

        features_by_sort = NxCxFeatures.mul(NxCxA_onehot)

        Amount_CxA = NxCxA_onehot.sum(0)

        Amount_CxA[Amount_CxA == 0] = 1

        # C*A
        ave_CxA = features_by_sort.sum(0) / Amount_CxA

        # N*C*A
        var_temp = features_by_sort - ave_CxA.expand(N, C, A).mul(NxCxA_onehot)
        # permute  bmm  b*n*m b*m*p ->b*n*p
        # C*A*N C*N*A-> C*A*A
        var_temp = torch.bmm(var_temp.permute(1, 2, 0), var_temp.permute(1, 0, 2)).div(
            Amount_CxA.view(C, A, 1).expand(C, A, A)
        )

        sum_weight_CV = onehot.sum(0).view(C, 1, 1).expand(C, A, A)

        sum_weight_AV = onehot.sum(0).view(C, 1).expand(C, A)

        weight_CV = sum_weight_CV.div(
            sum_weight_CV + self.Amount.view(C, 1, 1).expand(C, A, A)
        )
        weight_CV[weight_CV != weight_CV] = 0

        weight_AV = sum_weight_AV.div(
            sum_weight_AV + self.Amount.view(C, 1).expand(C, A)
        )
        weight_AV[weight_AV != weight_AV] = 0

        additional_CV = weight_CV.mul(1 - weight_CV).mul(
            torch.bmm(
                (self.Ave - ave_CxA).view(C, A, 1), (self.Ave - ave_CxA).view(C, 1, A)
            )
        )

        self.CoVariance = (
            self.CoVariance.mul(1 - weight_CV) + var_temp.mul(weight_CV)
        ).detach() + additional_CV.detach()

        self.Ave = (self.Ave.mul(1 - weight_AV) + ave_CxA.mul(weight_AV)).detach()

        self.Amount += onehot.sum(0)


class ISDALoss(nn.Module):
    def __init__(self, feature_num: int, class_num: int):
        super(ISDALoss, self).__init__()
        self.feature_num = feature_num
        self.class_num = class_num
        self.estimator = EstimatorCV(self.feature_num, self.class_num)

    def isda_aug(self, features, y, labels, cv_matrix, ratio):

        N = features.size(0)
        C = self.class_num
        A = features.size(1)

        weight_m = list(fc.parameters())[0]
        NxW_ij = weight_m.expand(N, C, A)
        NxW_kj = torch.gather(NxW_ij, 1, labels.view(N, 1, 1).expand(N, C, A))
        CV_temp = cv_matrix[labels]
        sigma2 = ratio * torch.bmm(
            torch.bmm(NxW_ij - NxW_kj, CV_temp), (NxW_ij - NxW_kj).permute(0, 2, 1)
        )
        sigma2 = sigma2.mul(torch.eye(C).to(device).expand(N, C, C)).sum(2).view(N, C)
        aug_result = y + 0.5 * sigma2

        return aug_result

    def forward(self, features, predictions, fc_weights, target_x, ratio):
        self.estimator.update_CV(features.detach(), target_x)
        isda_aug_y = self.isda_aug(
            fc_weights, features, predictions, target_x, self.estimator.CoVariance.detach(), ratio
        )

        return isda_aug_y
    
    
def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate"""
    if epoch in [80, 120]:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.1

BDCN - Minor changes to the original
* Flownet is not used. This is simply to make the expriments faster. The authors show ablation studies without flownet input (Table IX). With the removal of vis (Flownet optical flow) the performance drops by 2 percentage units.

* Only use 1 channel inputs. All of the optical flow components only have 1 channel originally---making them have three would just add redundancy. (When Flownet is not used) (Figure 1 shows optical strain as the combination of u+v+os, not sure which one it is)

In [31]:
uv_frames = resize(uv_frames, (df.shape[0], 3, 120, 120))

 without val set
0.72
with val
Total f1: 0.5461341408311106, SMIC: 0.6573265559144891, CASME2: 0.4632412455521449, SAMM: 0.42188076039588984

In [44]:
LOSO(
    uv_frames, df, epochs=100, lr=0.01, weight_decay=1e-4, dropout=0.5, batch_size=128
)

Subject: 006, n=11 | train_f1: 0.85887 | test_f1: 0.55556, best_epoch
Subject: 007, n=08 | train_f1: 0.76650 | test_f1: 0.11111, best_epoch
Subject: 009, n=04 | train_f1: 0.34783 | test_f1: 0.42857, best_epoch
Subject: 01, n=03 | train_f1: 0.81841 | test_f1: 1.0, best_epoch
Subject: 010, n=04 | train_f1: 0.88994 | test_f1: 1.0, best_epoch
Subject: 011, n=20 | train_f1: 0.70752 | test_f1: 0.26032, best_epoch
Subject: 012, n=03 | train_f1: 0.83763 | test_f1: 0.0, best_epoch
Subject: 013, n=06 | train_f1: 0.83083 | test_f1: 0.16667, best_epoch
Subject: 014, n=10 | train_f1: 0.21890 | test_f1: 0.20635, best_epoch
Subject: 015, n=03 | train_f1: 0.83825 | test_f1: 0.66667, best_epoch
Subject: 016, n=05 | train_f1: 0.60875 | test_f1: 1.0, best_epoch
Subject: 017, n=04 | train_f1: 0.61069 | test_f1: 0.42857, best_epoch
Subject: 018, n=03 | train_f1: 0.46235 | test_f1: 0.66667, best_epoch
Subject: 019, n=01 | train_f1: 0.33377 | test_f1: 0.0, best_epoch
Subject: 02, n=09 | train_f1: 0.59542 | t

In [29]:
def LOSO(
    features,
    df,
    epochs=200,
    lr=0.01,
    batch_size=128,
    dropout=0.5,
    weight_decay=0.001,
    verbose=True,
):
    outputs_list = []
    # groupby reorders elements, now the labels are in same order as outputs
    df_groupby = pd.concat([i[1] for i in df.groupby("subject")])
    dataset_groupby = df_groupby["dataset"]

    le = LabelEncoder()
    labels = le.fit_transform(df["emotion"])
    labels_groupby = le.transform(df_groupby["emotion"])

    # loop over each subject
    for group in df.groupby("subject"):
        subject = group[0]
        # split data to train and test based on the subject index
        train_index = df[df["subject"] != subject].index
        X_train = features[train_index, :]
        y_train = labels[train_index]
        dataset_train = dataset[train_index]

        test_index = df[df["subject"] == subject].index
        X_test = features[test_index, :]
        y_test = labels[test_index]
        dataset_test = dataset[test_index]

        # create pytorch dataloaders from the split
        megc_dataset_train = MEData(X_train, y_train, dataset_train, None)
        dataset_loader_train = torch.utils.data.DataLoader(
            megc_dataset_train, batch_size=batch_size, shuffle=True, num_workers=0
        )

        megc_dataset_test = MEData(X_test, y_test, dataset_test, None)
        dataset_loader_test = torch.utils.data.DataLoader(
            megc_dataset_test, batch_size=100, shuffle=False, num_workers=0
        )

        model = BDCNN(size=10).to(device)
        criterion = ISDALoss(model.feature_num, df["emotion"].nunique(), [0.15,0.425,0.425]).to(device)
        fc = nn.Linear(model.feature_num, df["emotion"].nunique()).to(device)
        optimizer = optim.SGD([{'params': model.parameters()}, {'params': fc.parameters()}],
                              lr=lr, momentum=0.9, weight_decay=weight_decay
        )
        model.train()
        for epoch in range(epochs):
            adjust_learning_rate(optimizer, epoch + 1)
            for batch in dataset_loader_train:
                data_batch, labels_batch = batch[0].to(device), batch[1].to(device)

                optimizer.zero_grad()
                
                ratio = 0.5 * (epoch / epochs)
                loss, outputs = criterion(model, fc, data_batch.float(), labels_batch, ratio)
                loss.backward()
                optimizer.step()

        # Test model
        model.eval()
        (
            data_batch_test,
            labels_batch_test,
            _,
        ) = dataset_loader_test.__iter__().__next__()
        data_batch_test, labels_batch_test = data_batch_test.to(
            device
        ), labels_batch_test.to(device)
        outputs = fc(model(data_batch_test.float()))
        _, prediction = outputs.max(1)
        prediction = prediction.cpu().data.numpy()
        outputs_list.append(prediction)

        train_outputs = fc(model(data_batch.float()))
        _, train_prediction = train_outputs.max(1)
        train_prediction = train_prediction.cpu().data.numpy()
        train_f1 = f1_score(
            labels_batch.cpu().data.numpy(), train_prediction, average="macro"
        )
        test_f1 = f1_score(
            labels_batch_test.cpu().data.numpy(), prediction, average="macro"
        )

        # Print statistics
        if verbose:
            print(
                "Subject: {}, n={} | train_f1: {:.5f} | test_f1: {:.5}".format(
                    subject, str(labels_batch_test.shape[0]).zfill(2), train_f1, test_f1
                )
            )

    outputs = np.concatenate(outputs_list)
    f1_total = f1_score(labels_groupby, outputs, average="macro")
    idx = dataset_groupby == "smic"
    f1_smic = f1_score(labels_groupby[idx], outputs[idx], average="macro")
    idx = dataset_groupby == "casme2"
    f1_casme2 = f1_score(labels_groupby[idx], outputs[idx], average="macro")
    idx = dataset_groupby == "samm"
    f1_samm = f1_score(labels_groupby[idx], outputs[idx], average="macro")
    print(
        "Total f1: {}, SMIC: {}, CASME2: {}, SAMM: {}".format(
            f1_total, f1_smic, f1_casme2, f1_samm
        )
    )

In [37]:
import copy
from sklearn.model_selection import train_test_split

In [43]:
def LOSO(
    features,
    df,
    epochs=200,
    lr=0.01,
    batch_size=128,
    dropout=0.5,
    weight_decay=0.001,
    verbose=True,
):
    outputs_list = []
    # groupby reorders elements, now the labels are in same order as outputs
    df_groupby = pd.concat([i[1] for i in df.groupby("subject")])
    dataset_groupby = df_groupby["dataset"]

    le = LabelEncoder()
    labels = le.fit_transform(df["emotion"])
    labels_groupby = le.transform(df_groupby["emotion"])

    # loop over each subject
    for group in df.groupby("subject"):
        subject = group[0]
        # split data to train and test based on the subject index
        train_index = df[df["subject"] != subject].index
        X_train = features[train_index, :]
        y_train = labels[train_index]
        dataset_train = dataset[train_index]

        test_index = df[df["subject"] == subject].index
        X_test = features[test_index, :]
        y_test = labels[test_index]
        dataset_test = dataset[test_index]
        
        # Train val split
        X_train, X_val, y_train, y_val, dataset_train, dataset_val = train_test_split(
        X_train, y_train, dataset_train, test_size=0.2, random_state=0)
        # Create pytorch dataloaders from the split
        megc_dataset_train = MEData(X_train, y_train, dataset_train, None)
        dataset_loader_train = torch.utils.data.DataLoader(
            megc_dataset_train, batch_size=batch_size, shuffle=True, num_workers=0
        )
        megc_dataset_val = MEData(X_val, y_val, dataset_val, None)
        dataset_loader_val = torch.utils.data.DataLoader(
            megc_dataset_val, batch_size=100, shuffle=False, num_workers=0
        )
        megc_dataset_test = MEData(X_test, y_test, dataset_test, None)
        dataset_loader_test = torch.utils.data.DataLoader(
            megc_dataset_test, batch_size=100, shuffle=False, num_workers=0
        )

        model = BDCNN(size=10).to(device)
        criterion = ISDALoss(model.feature_num, df["emotion"].nunique(), [0.15,0.425,0.425]).to(device)
        fc = nn.Linear(model.feature_num, df["emotion"].nunique()).to(device)
        optimizer = optim.SGD([{'params': model.parameters()}, {'params': fc.parameters()}],
                              lr=lr, momentum=0.9, weight_decay=weight_decay
        )
        model.train()
        best_f1 = 0
        best_model = model
        best_epoch = 0
        for epoch in range(epochs):
            adjust_learning_rate(optimizer, epoch + 1)
            for batch in dataset_loader_train:
                data_batch, labels_batch = batch[0].to(device), batch[1].to(device)

                optimizer.zero_grad()
                
                ratio = 0.5 * (epoch / epochs)
                loss, outputs = criterion(model, fc, data_batch.float(), labels_batch, ratio)
                loss.backward()
                optimizer.step()
            # Validate and choose best epoch
            model.eval()
            (
                data_batch_val,
                labels_batch_val,
                _,
            ) = dataset_loader_val.__iter__().__next__()
            data_batch_val = data_batch_val.to(device)
            outputs = fc(model(data_batch_val.float()))
            _, prediction = outputs.max(1)
            prediction = prediction.cpu().data.numpy()
            f1_val = f1_score(labels_batch_val, prediction, average="macro")
            if f1_val > best_f1:
                best_f1 = f1_val
                best_model = copy.deepcopy(model)
                best_epoch = epoch
            model.train()
            

        # Test model
        best_model.eval()
        (
            data_batch_test,
            labels_batch_test,
            _,
        ) = dataset_loader_test.__iter__().__next__()
        data_batch_test = data_batch_test.to(device)
        outputs = fc(best_model(data_batch_test.float()))
        _, prediction = outputs.max(1)
        prediction = prediction.cpu().data.numpy()
        outputs_list.append(prediction)

        train_outputs = fc(best_model(data_batch.float()))
        _, train_prediction = train_outputs.max(1)
        train_prediction = train_prediction.cpu().data.numpy()
        train_f1 = f1_score(
            labels_batch.cpu().data.numpy(), train_prediction, average="macro"
        )
        test_f1 = f1_score(
            labels_batch_test.cpu().data.numpy(), prediction, average="macro"
        )

        # Print statistics
        if verbose:
            print(
                "Subject: {}, n={} | train_f1: {:.5f} | test_f1: {:.5}, {}".format(
                    subject, str(labels_batch_test.shape[0]).zfill(2), train_f1, test_f1,
                    best_epoch
                )
            )

    outputs = np.concatenate(outputs_list)
    f1_total = f1_score(labels_groupby, outputs, average="macro")
    idx = dataset_groupby == "smic"
    f1_smic = f1_score(labels_groupby[idx], outputs[idx], average="macro")
    idx = dataset_groupby == "casme2"
    f1_casme2 = f1_score(labels_groupby[idx], outputs[idx], average="macro")
    idx = dataset_groupby == "samm"
    f1_samm = f1_score(labels_groupby[idx], outputs[idx], average="macro")
    print(
        "Total f1: {}, SMIC: {}, CASME2: {}, SAMM: {}".format(
            f1_total, f1_smic, f1_casme2, f1_samm
        )
    )