In [1]:
from meb import utils, datasets, core, models

from tqdm import tqdm
from functools import partial

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
import pytorchvideo
import timm
from timm.scheduler.cosine_lr import CosineLRScheduler
import pytorchvideo.transforms
import torch.nn.functional as F



pd.set_option("display.max_columns", 50)
%load_ext autoreload
%autoreload 2

# 1

In [13]:
c = datasets.CrossDataset(resize=64, ignore_validation=True, optical_flow=True)
df = c.data_frame
data = c.data

In [14]:
c = datasets.Samm(resize=64, ignore_validation=True, optical_flow=True)
dfs = c.data_frame
datas = c.data

In [15]:
# Add missing AU columns to SAMM that are missing compared to all data
dfs_aus = pd.DataFrame(dfs.loc[:, "AU1":].to_dict(), columns=df.loc[:, "AU1":].columns).fillna(0)
dfs_dropped = dfs.drop(dfs.loc[0, "AU1":].index.tolist(), axis=1)
dfs = pd.concat([dfs_dropped, dfs_aus], axis=1)

In [16]:
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import f1_score
from sklearn.preprocessing import LabelEncoder

action_units = dfs.loc[:, "AU1":].columns
df_dataset = dfs
model = GradientBoostingClassifier(random_state=0)

aus = np.array(df_dataset.loc[:, action_units])
emotions = df_dataset["Objective Classes"].astype("str")
model.fit(aus, emotions)
predicted_emotions = model.predict(aus)
f1 = f1_score(emotions, predicted_emotions, average="macro")
print(f1 * 100)

99.17754854925882


In [17]:
df.insert(7, "objective_class", model.predict(np.array(df.loc[:, "AU1":])))

#### MEGC2018 validation

In [18]:
idx = df["objective_class"].isin(["6", "7"])
df = df[~idx].reset_index(drop=True)
data = data[~idx]

In [19]:
df["subject"] = df["subject"].astype(str) + df["dataset"]

In [20]:
from meb.core.train_eval import InputData
from typing import List

class MEGC2018Validator(core.Validator):
    def __init__(self, config: core.Config, verbose: bool = True):
        super().__init__(config, split_column="subject")
        self.verbose = verbose
        self.disable_tqdm = False

    def validate(
        self, df: pd.DataFrame, input_data: InputData, seed_n: int = 1
    ) -> List[torch.tensor]:
        utils.set_random_seeds(seed_n)
        subject_names = df.loc[df["dataset"].isin(["samm", "casme2"]), "subject"].unique()
        le = LabelEncoder()
        labels = le.fit_transform(df["objective_class"])
        outputs_list = []
        for subject_name in subject_names:
            train_metrics, test_metrics, outputs_test = self.validate_split(
                df, input_data, labels, subject_name
            )
            outputs_list.append(outputs_test)
            if self.verbose:
                self.printer.print_train_test_evaluation(
                    train_metrics, test_metrics, subject_name, outputs_test.shape[0]
                )

        # Calculate total f1-scores
        predictions = torch.cat(outputs_list)
        f1_total = self.evaluation_fn(labels[df["dataset"].isin(["samm", "casme2"])], predictions)
        if self.verbose:
            print("Total f1: {}".format(f1_total))
        return outputs_list

In [21]:
class Config(core.Config):
    action_units = None
    epochs = 200
    batch_size = 256
    evaluation_fn = utils.MultiClassF1Score
    criterion = nn.CrossEntropyLoss
    device = torch.device("cuda:1")
    optimizer = partial(optim.Adam, lr=1e-4, weight_decay=1e-3)
    model = partial(models.SSSNet, num_classes=5)

In [31]:
out = MEGC2018Validator(Config).validate(df, data)

100%|█████████████████████████████████████████| 200/200 [01:51<00:00,  1.79it/s]


Subject: 01casme2, n=8 | train_mean: 1.0 | test_mean: 0.6111


100%|█████████████████████████████████████████| 200/200 [01:50<00:00,  1.82it/s]


Subject: 02casme2, n=7 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [01:51<00:00,  1.80it/s]


Subject: 03casme2, n=7 | train_mean: 1.0 | test_mean: 0.1667


100%|█████████████████████████████████████████| 200/200 [01:51<00:00,  1.79it/s]


Subject: 04casme2, n=5 | train_mean: 1.0 | test_mean: 0.5833


100%|█████████████████████████████████████████| 200/200 [01:50<00:00,  1.80it/s]


Subject: 05casme2, n=10 | train_mean: 1.0 | test_mean: 0.6


100%|█████████████████████████████████████████| 200/200 [01:50<00:00,  1.80it/s]


Subject: 06casme2, n=3 | train_mean: 1.0 | test_mean: 0.5


100%|█████████████████████████████████████████| 200/200 [01:51<00:00,  1.80it/s]


Subject: 07casme2, n=7 | train_mean: 1.0 | test_mean: 0.2063


100%|█████████████████████████████████████████| 200/200 [01:50<00:00,  1.81it/s]


Subject: 08casme2, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:50<00:00,  1.80it/s]


Subject: 09casme2, n=10 | train_mean: 1.0 | test_mean: 0.631


100%|█████████████████████████████████████████| 200/200 [01:51<00:00,  1.80it/s]


Subject: 10casme2, n=13 | train_mean: 1.0 | test_mean: 0.48


100%|█████████████████████████████████████████| 200/200 [01:51<00:00,  1.80it/s]


Subject: 11casme2, n=9 | train_mean: 1.0 | test_mean: 0.8


100%|█████████████████████████████████████████| 200/200 [01:50<00:00,  1.81it/s]


Subject: 12casme2, n=11 | train_mean: 1.0 | test_mean: 0.709


100%|█████████████████████████████████████████| 200/200 [01:50<00:00,  1.81it/s]


Subject: 13casme2, n=5 | train_mean: 1.0 | test_mean: 0.5556


100%|█████████████████████████████████████████| 200/200 [01:50<00:00,  1.80it/s]


Subject: 14casme2, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:50<00:00,  1.81it/s]


Subject: 15casme2, n=3 | train_mean: 1.0 | test_mean: 0.2


100%|█████████████████████████████████████████| 200/200 [01:51<00:00,  1.80it/s]


Subject: 16casme2, n=3 | train_mean: 1.0 | test_mean: 0.25


100%|█████████████████████████████████████████| 200/200 [01:51<00:00,  1.80it/s]


Subject: 17casme2, n=25 | train_mean: 1.0 | test_mean: 0.5793


100%|█████████████████████████████████████████| 200/200 [01:50<00:00,  1.81it/s]


Subject: 18casme2, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:50<00:00,  1.80it/s]


Subject: 19casme2, n=14 | train_mean: 1.0 | test_mean: 0.7778


100%|█████████████████████████████████████████| 200/200 [01:51<00:00,  1.79it/s]


Subject: 20casme2, n=7 | train_mean: 1.0 | test_mean: 0.4615


100%|█████████████████████████████████████████| 200/200 [01:50<00:00,  1.81it/s]


Subject: 21casme2, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:51<00:00,  1.79it/s]


Subject: 22casme2, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:50<00:00,  1.80it/s]


Subject: 23casme2, n=10 | train_mean: 1.0 | test_mean: 0.7222


100%|█████████████████████████████████████████| 200/200 [01:50<00:00,  1.81it/s]


Subject: 24casme2, n=10 | train_mean: 1.0 | test_mean: 0.2444


100%|█████████████████████████████████████████| 200/200 [01:50<00:00,  1.81it/s]


Subject: 25casme2, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:49<00:00,  1.82it/s]


Subject: 26casme2, n=13 | train_mean: 1.0 | test_mean: 0.7353


100%|█████████████████████████████████████████| 200/200 [01:50<00:00,  1.81it/s]


Subject: 006samm, n=10 | train_mean: 1.0 | test_mean: 0.6667


100%|█████████████████████████████████████████| 200/200 [01:51<00:00,  1.80it/s]


Subject: 007samm, n=4 | train_mean: 1.0 | test_mean: 0.2


100%|█████████████████████████████████████████| 200/200 [01:52<00:00,  1.77it/s]


Subject: 009samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:51<00:00,  1.79it/s]


Subject: 011samm, n=12 | train_mean: 1.0 | test_mean: 0.5202


100%|█████████████████████████████████████████| 200/200 [01:52<00:00,  1.78it/s]


Subject: 012samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:50<00:00,  1.81it/s]


Subject: 014samm, n=8 | train_mean: 1.0 | test_mean: 0.5


100%|█████████████████████████████████████████| 200/200 [01:50<00:00,  1.81it/s]


Subject: 015samm, n=2 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [01:50<00:00,  1.80it/s]


Subject: 016samm, n=4 | train_mean: 1.0 | test_mean: 0.6667


100%|█████████████████████████████████████████| 200/200 [01:50<00:00,  1.80it/s]


Subject: 018samm, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:52<00:00,  1.78it/s]


Subject: 019samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:52<00:00,  1.78it/s]


Subject: 020samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:50<00:00,  1.81it/s]


Subject: 021samm, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:50<00:00,  1.81it/s]


Subject: 022samm, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:51<00:00,  1.80it/s]


Subject: 026samm, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:51<00:00,  1.80it/s]


Subject: 028samm, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:51<00:00,  1.79it/s]


Subject: 030samm, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:50<00:00,  1.82it/s]


Subject: 032samm, n=3 | train_mean: 1.0 | test_mean: 0.5556


100%|█████████████████████████████████████████| 200/200 [01:49<00:00,  1.82it/s]


Subject: 033samm, n=2 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [01:51<00:00,  1.80it/s]


Subject: 035samm, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:51<00:00,  1.79it/s]


Subject: 036samm, n=1 | train_mean: 1.0 | test_mean: 0.0


100%|█████████████████████████████████████████| 200/200 [01:51<00:00,  1.79it/s]


Subject: 037samm, n=1 | train_mean: 1.0 | test_mean: 0.0
Total f1: [0.6934895631306097]


In [19]:
idx = df["dataset"].isin(["samm", "casme2"])
out = MEGC2018Validator(Config).validate(df[idx].reset_index(), data[idx])

100%|█████████████████████████████████████████| 200/200 [01:03<00:00,  3.17it/s]


Subject: 01casme2, n=8 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:03<00:00,  3.14it/s]


Subject: 02casme2, n=7 | train_mean: 1.0 | test_mean: 0.3667


100%|█████████████████████████████████████████| 200/200 [01:03<00:00,  3.13it/s]


Subject: 03casme2, n=7 | train_mean: 1.0 | test_mean: 0.3095


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.07it/s]


Subject: 04casme2, n=5 | train_mean: 1.0 | test_mean: 0.2857


100%|█████████████████████████████████████████| 200/200 [01:03<00:00,  3.14it/s]


Subject: 05casme2, n=10 | train_mean: 1.0 | test_mean: 0.7083


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.05it/s]


Subject: 06casme2, n=3 | train_mean: 1.0 | test_mean: 0.2


100%|█████████████████████████████████████████| 200/200 [01:04<00:00,  3.12it/s]


Subject: 07casme2, n=7 | train_mean: 1.0 | test_mean: 0.2063


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.07it/s]


Subject: 08casme2, n=3 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [01:03<00:00,  3.13it/s]


Subject: 09casme2, n=10 | train_mean: 1.0 | test_mean: 0.631


100%|█████████████████████████████████████████| 200/200 [01:04<00:00,  3.12it/s]


Subject: 10casme2, n=13 | train_mean: 1.0 | test_mean: 0.48


100%|█████████████████████████████████████████| 200/200 [01:03<00:00,  3.14it/s]


Subject: 11casme2, n=9 | train_mean: 1.0 | test_mean: 0.6786


100%|█████████████████████████████████████████| 200/200 [01:03<00:00,  3.14it/s]


Subject: 12casme2, n=11 | train_mean: 1.0 | test_mean: 0.4722


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.06it/s]


Subject: 13casme2, n=5 | train_mean: 0.9961 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.04it/s]


Subject: 14casme2, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:04<00:00,  3.09it/s]


Subject: 15casme2, n=3 | train_mean: 1.0 | test_mean: 0.2


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.07it/s]


Subject: 16casme2, n=3 | train_mean: 1.0 | test_mean: 0.1667


100%|█████████████████████████████████████████| 200/200 [01:03<00:00,  3.14it/s]


Subject: 17casme2, n=25 | train_mean: 1.0 | test_mean: 0.6324


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.08it/s]


Subject: 18casme2, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:03<00:00,  3.14it/s]


Subject: 19casme2, n=14 | train_mean: 1.0 | test_mean: 0.8933


100%|█████████████████████████████████████████| 200/200 [01:03<00:00,  3.14it/s]


Subject: 20casme2, n=7 | train_mean: 1.0 | test_mean: 0.4615


100%|█████████████████████████████████████████| 200/200 [01:04<00:00,  3.08it/s]


Subject: 21casme2, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.05it/s]


Subject: 22casme2, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:03<00:00,  3.13it/s]


Subject: 23casme2, n=10 | train_mean: 1.0 | test_mean: 0.6643


100%|█████████████████████████████████████████| 200/200 [01:04<00:00,  3.11it/s]


Subject: 24casme2, n=10 | train_mean: 1.0 | test_mean: 0.3578


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.07it/s]


Subject: 25casme2, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:03<00:00,  3.14it/s]


Subject: 26casme2, n=13 | train_mean: 1.0 | test_mean: 0.8393


100%|█████████████████████████████████████████| 200/200 [01:03<00:00,  3.14it/s]


Subject: 006samm, n=10 | train_mean: 1.0 | test_mean: 0.8667


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.04it/s]


Subject: 007samm, n=4 | train_mean: 1.0 | test_mean: 0.2222


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.07it/s]


Subject: 009samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:03<00:00,  3.16it/s]


Subject: 011samm, n=12 | train_mean: 1.0 | test_mean: 0.8286


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.06it/s]


Subject: 012samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:04<00:00,  3.12it/s]


Subject: 014samm, n=8 | train_mean: 1.0 | test_mean: 0.5758


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.07it/s]


Subject: 015samm, n=2 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.07it/s]


Subject: 016samm, n=4 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:04<00:00,  3.10it/s]


Subject: 018samm, n=2 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [01:04<00:00,  3.10it/s]


Subject: 019samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.05it/s]


Subject: 020samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.06it/s]


Subject: 021samm, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:04<00:00,  3.08it/s]


Subject: 022samm, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.08it/s]


Subject: 026samm, n=3 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.07it/s]


Subject: 028samm, n=2 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.06it/s]


Subject: 030samm, n=3 | train_mean: 1.0 | test_mean: 0.4


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.05it/s]


Subject: 032samm, n=3 | train_mean: 1.0 | test_mean: 0.5556


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.04it/s]


Subject: 033samm, n=2 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.04it/s]


Subject: 035samm, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.06it/s]


Subject: 036samm, n=1 | train_mean: 1.0 | test_mean: 0.0


100%|█████████████████████████████████████████| 200/200 [01:05<00:00,  3.07it/s]


Subject: 037samm, n=1 | train_mean: 1.0 | test_mean: 0.0
Total f1: [0.6983133765880812]


In [21]:
class Config(core.Config):
    action_units = None
    epochs = 200
    batch_size = 32
    evaluation_fn = utils.MultiClassF1Score
    criterion = nn.CrossEntropyLoss
    device = torch.device("cuda:0")
    optimizer = partial(optim.Adam, lr=1e-4, weight_decay=1e-3)
    model = partial(timm.models.resnet18, num_classes=5, pretrained=True)

In [22]:
out = MEGC2018Validator(Config).validate(df, data)

100%|█████████████████████████████████████████| 200/200 [06:11<00:00,  1.86s/it]


Subject: 01casme2, n=8 | train_mean: 1.0 | test_mean: 0.6111


100%|█████████████████████████████████████████| 200/200 [06:09<00:00,  1.85s/it]


Subject: 02casme2, n=7 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [06:11<00:00,  1.86s/it]


Subject: 03casme2, n=7 | train_mean: 1.0 | test_mean: 0.5556


100%|█████████████████████████████████████████| 200/200 [06:10<00:00,  1.85s/it]


Subject: 04casme2, n=5 | train_mean: 1.0 | test_mean: 0.2857


100%|█████████████████████████████████████████| 200/200 [06:09<00:00,  1.85s/it]


Subject: 05casme2, n=10 | train_mean: 1.0 | test_mean: 0.5


100%|█████████████████████████████████████████| 200/200 [06:09<00:00,  1.85s/it]


Subject: 06casme2, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [06:09<00:00,  1.85s/it]


Subject: 07casme2, n=7 | train_mean: 1.0 | test_mean: 0.07143


100%|█████████████████████████████████████████| 200/200 [06:06<00:00,  1.83s/it]


Subject: 08casme2, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [06:06<00:00,  1.83s/it]


Subject: 09casme2, n=10 | train_mean: 1.0 | test_mean: 0.7937


100%|█████████████████████████████████████████| 200/200 [06:10<00:00,  1.85s/it]


Subject: 10casme2, n=13 | train_mean: 1.0 | test_mean: 0.48


100%|█████████████████████████████████████████| 200/200 [06:11<00:00,  1.86s/it]


Subject: 11casme2, n=9 | train_mean: 1.0 | test_mean: 0.8


100%|█████████████████████████████████████████| 200/200 [06:07<00:00,  1.84s/it]


Subject: 12casme2, n=11 | train_mean: 1.0 | test_mean: 0.7238


100%|█████████████████████████████████████████| 200/200 [06:08<00:00,  1.84s/it]


Subject: 13casme2, n=5 | train_mean: 1.0 | test_mean: 0.5556


100%|█████████████████████████████████████████| 200/200 [06:06<00:00,  1.83s/it]


Subject: 14casme2, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [06:07<00:00,  1.84s/it]


Subject: 15casme2, n=3 | train_mean: 0.9952 | test_mean: 0.2


100%|█████████████████████████████████████████| 200/200 [06:07<00:00,  1.84s/it]


Subject: 16casme2, n=3 | train_mean: 0.9971 | test_mean: 0.25


100%|█████████████████████████████████████████| 200/200 [06:08<00:00,  1.84s/it]


Subject: 17casme2, n=25 | train_mean: 1.0 | test_mean: 0.5381


100%|█████████████████████████████████████████| 200/200 [06:07<00:00,  1.84s/it]


Subject: 18casme2, n=3 | train_mean: 0.9232 | test_mean: 0.4


100%|█████████████████████████████████████████| 200/200 [06:08<00:00,  1.84s/it]


Subject: 19casme2, n=14 | train_mean: 1.0 | test_mean: 0.5778


100%|█████████████████████████████████████████| 200/200 [06:07<00:00,  1.84s/it]


Subject: 20casme2, n=7 | train_mean: 0.8808 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [06:06<00:00,  1.83s/it]


Subject: 21casme2, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [06:06<00:00,  1.83s/it]


Subject: 22casme2, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [06:07<00:00,  1.84s/it]


Subject: 23casme2, n=10 | train_mean: 0.9622 | test_mean: 0.6643


100%|█████████████████████████████████████████| 200/200 [06:08<00:00,  1.84s/it]


Subject: 24casme2, n=10 | train_mean: 1.0 | test_mean: 0.5556


100%|█████████████████████████████████████████| 200/200 [06:07<00:00,  1.84s/it]


Subject: 25casme2, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [06:08<00:00,  1.84s/it]


Subject: 26casme2, n=13 | train_mean: 1.0 | test_mean: 0.7353


100%|█████████████████████████████████████████| 200/200 [06:10<00:00,  1.85s/it]


Subject: 006samm, n=10 | train_mean: 1.0 | test_mean: 0.3179


100%|█████████████████████████████████████████| 200/200 [06:09<00:00,  1.85s/it]


Subject: 007samm, n=4 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [06:09<00:00,  1.85s/it]


Subject: 009samm, n=1 | train_mean: 0.9988 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [06:09<00:00,  1.85s/it]


Subject: 011samm, n=12 | train_mean: 1.0 | test_mean: 0.1667


100%|█████████████████████████████████████████| 200/200 [06:11<00:00,  1.86s/it]


Subject: 012samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [06:09<00:00,  1.85s/it]


Subject: 014samm, n=8 | train_mean: 1.0 | test_mean: 0.4667


100%|█████████████████████████████████████████| 200/200 [06:08<00:00,  1.84s/it]


Subject: 015samm, n=2 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [06:10<00:00,  1.85s/it]


Subject: 016samm, n=4 | train_mean: 0.9929 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [06:11<00:00,  1.86s/it]


Subject: 018samm, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [06:09<00:00,  1.85s/it]


Subject: 019samm, n=1 | train_mean: 0.9995 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [06:07<00:00,  1.84s/it]


Subject: 020samm, n=1 | train_mean: 1.0 | test_mean: 0.0


100%|█████████████████████████████████████████| 200/200 [06:08<00:00,  1.84s/it]


Subject: 021samm, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [06:09<00:00,  1.85s/it]


Subject: 022samm, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [06:07<00:00,  1.84s/it]


Subject: 026samm, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [06:07<00:00,  1.84s/it]


Subject: 028samm, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [06:07<00:00,  1.84s/it]


Subject: 030samm, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [06:07<00:00,  1.84s/it]


Subject: 032samm, n=3 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [06:07<00:00,  1.84s/it]


Subject: 033samm, n=2 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [06:09<00:00,  1.85s/it]


Subject: 035samm, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [06:06<00:00,  1.83s/it]


Subject: 036samm, n=1 | train_mean: 1.0 | test_mean: 0.0


100%|█████████████████████████████████████████| 200/200 [06:09<00:00,  1.85s/it]


Subject: 037samm, n=1 | train_mean: 1.0 | test_mean: 0.0
Total f1: [0.6319072169848342]


In [23]:
idx = df["dataset"].isin(["samm", "casme2"])
out = MEGC2018Validator(Config).validate(df[idx].reset_index(), data[idx])

100%|█████████████████████████████████████████| 200/200 [01:56<00:00,  1.72it/s]


Subject: 01casme2, n=8 | train_mean: 1.0 | test_mean: 0.6111


100%|█████████████████████████████████████████| 200/200 [01:56<00:00,  1.71it/s]


Subject: 02casme2, n=7 | train_mean: 1.0 | test_mean: 0.3667


100%|█████████████████████████████████████████| 200/200 [01:56<00:00,  1.72it/s]


Subject: 03casme2, n=7 | train_mean: 1.0 | test_mean: 0.3889


100%|█████████████████████████████████████████| 200/200 [02:02<00:00,  1.64it/s]


Subject: 04casme2, n=5 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:55<00:00,  1.73it/s]


Subject: 05casme2, n=10 | train_mean: 0.9661 | test_mean: 0.3068


100%|█████████████████████████████████████████| 200/200 [02:02<00:00,  1.63it/s]


Subject: 06casme2, n=3 | train_mean: 1.0 | test_mean: 0.2


100%|█████████████████████████████████████████| 200/200 [01:57<00:00,  1.71it/s]


Subject: 07casme2, n=7 | train_mean: 0.8202 | test_mean: 0.1143


100%|█████████████████████████████████████████| 200/200 [02:01<00:00,  1.64it/s]


Subject: 08casme2, n=3 | train_mean: 0.9962 | test_mean: 0.1667


100%|█████████████████████████████████████████| 200/200 [01:56<00:00,  1.72it/s]


Subject: 09casme2, n=10 | train_mean: 1.0 | test_mean: 0.5778


100%|█████████████████████████████████████████| 200/200 [01:55<00:00,  1.73it/s]


Subject: 10casme2, n=13 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [01:56<00:00,  1.72it/s]


Subject: 11casme2, n=9 | train_mean: 1.0 | test_mean: 0.4375


100%|█████████████████████████████████████████| 200/200 [01:55<00:00,  1.74it/s]


Subject: 12casme2, n=11 | train_mean: 1.0 | test_mean: 0.4708


100%|█████████████████████████████████████████| 200/200 [02:02<00:00,  1.64it/s]


Subject: 13casme2, n=5 | train_mean: 1.0 | test_mean: 0.3667


100%|█████████████████████████████████████████| 200/200 [02:07<00:00,  1.57it/s]


Subject: 14casme2, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [02:07<00:00,  1.57it/s]


Subject: 15casme2, n=3 | train_mean: 1.0 | test_mean: 0.5


100%|█████████████████████████████████████████| 200/200 [02:07<00:00,  1.57it/s]


Subject: 16casme2, n=3 | train_mean: 1.0 | test_mean: 0.0


100%|█████████████████████████████████████████| 200/200 [02:02<00:00,  1.63it/s]


Subject: 17casme2, n=25 | train_mean: 1.0 | test_mean: 0.346


100%|█████████████████████████████████████████| 200/200 [02:07<00:00,  1.56it/s]


Subject: 18casme2, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [02:01<00:00,  1.65it/s]


Subject: 19casme2, n=14 | train_mean: 1.0 | test_mean: 0.4857


100%|█████████████████████████████████████████| 200/200 [02:02<00:00,  1.63it/s]


Subject: 20casme2, n=7 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [02:08<00:00,  1.56it/s]


Subject: 21casme2, n=2 | train_mean: 0.9853 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [02:07<00:00,  1.57it/s]


Subject: 22casme2, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [02:00<00:00,  1.66it/s]


Subject: 23casme2, n=10 | train_mean: 1.0 | test_mean: 0.4038


100%|█████████████████████████████████████████| 200/200 [01:59<00:00,  1.68it/s]


Subject: 24casme2, n=10 | train_mean: 1.0 | test_mean: 0.24


100%|█████████████████████████████████████████| 200/200 [02:06<00:00,  1.58it/s]


Subject: 25casme2, n=3 | train_mean: 0.9934 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [02:00<00:00,  1.66it/s]


Subject: 26casme2, n=13 | train_mean: 1.0 | test_mean: 0.6722


100%|█████████████████████████████████████████| 200/200 [02:01<00:00,  1.65it/s]


Subject: 006samm, n=10 | train_mean: 1.0 | test_mean: 0.4


100%|█████████████████████████████████████████| 200/200 [02:05<00:00,  1.60it/s]


Subject: 007samm, n=4 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [02:10<00:00,  1.53it/s]


Subject: 009samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [02:00<00:00,  1.66it/s]


Subject: 011samm, n=12 | train_mean: 1.0 | test_mean: 0.2593


100%|█████████████████████████████████████████| 200/200 [02:09<00:00,  1.55it/s]


Subject: 012samm, n=1 | train_mean: 0.9923 | test_mean: 0.0


100%|█████████████████████████████████████████| 200/200 [02:02<00:00,  1.64it/s]


Subject: 014samm, n=8 | train_mean: 1.0 | test_mean: 0.5758


100%|█████████████████████████████████████████| 200/200 [02:11<00:00,  1.52it/s]


Subject: 015samm, n=2 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [02:09<00:00,  1.55it/s]


Subject: 016samm, n=4 | train_mean: 1.0 | test_mean: 0.7778


100%|█████████████████████████████████████████| 200/200 [02:09<00:00,  1.55it/s]


Subject: 018samm, n=2 | train_mean: 0.9973 | test_mean: 0.0


100%|█████████████████████████████████████████| 200/200 [02:09<00:00,  1.54it/s]


Subject: 019samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [02:09<00:00,  1.55it/s]


Subject: 020samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [02:08<00:00,  1.56it/s]


Subject: 021samm, n=2 | train_mean: 0.9957 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [02:08<00:00,  1.56it/s]


Subject: 022samm, n=2 | train_mean: 0.9929 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [02:08<00:00,  1.55it/s]


Subject: 026samm, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [02:08<00:00,  1.55it/s]


Subject: 028samm, n=2 | train_mean: 0.9973 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [02:08<00:00,  1.56it/s]


Subject: 030samm, n=3 | train_mean: 1.0 | test_mean: 0.4


100%|█████████████████████████████████████████| 200/200 [02:07<00:00,  1.56it/s]


Subject: 032samm, n=3 | train_mean: 1.0 | test_mean: 0.4


100%|█████████████████████████████████████████| 200/200 [02:09<00:00,  1.55it/s]


Subject: 033samm, n=2 | train_mean: 0.994 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [02:08<00:00,  1.56it/s]


Subject: 035samm, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [02:09<00:00,  1.55it/s]


Subject: 036samm, n=1 | train_mean: 1.0 | test_mean: 0.0


100%|█████████████████████████████████████████| 200/200 [02:08<00:00,  1.56it/s]


Subject: 037samm, n=1 | train_mean: 1.0 | test_mean: 1.0
Total f1: [0.49944624713528907]


In [24]:
class Config(core.Config):
    action_units = None
    epochs = 200
    batch_size = 32
    evaluation_fn = utils.MultiClassF1Score
    criterion = nn.CrossEntropyLoss
    device = torch.device("cuda:0")
    optimizer = partial(optim.Adam, lr=1e-4, weight_decay=1e-3)
    model = partial(timm.models.resnet34, num_classes=5, pretrained=True)

In [25]:
out = MEGC2018Validator(Config).validate(df, data)

100%|█████████████████████████████████████████| 200/200 [10:42<00:00,  3.21s/it]


Subject: 01casme2, n=8 | train_mean: 1.0 | test_mean: 0.6111


100%|█████████████████████████████████████████| 200/200 [10:53<00:00,  3.27s/it]


Subject: 02casme2, n=7 | train_mean: 1.0 | test_mean: 0.5167


100%|█████████████████████████████████████████| 200/200 [10:51<00:00,  3.26s/it]


Subject: 03casme2, n=7 | train_mean: 1.0 | test_mean: 0.1667


100%|█████████████████████████████████████████| 200/200 [10:56<00:00,  3.28s/it]


Subject: 04casme2, n=5 | train_mean: 1.0 | test_mean: 0.7619


100%|█████████████████████████████████████████| 200/200 [11:05<00:00,  3.33s/it]


Subject: 05casme2, n=10 | train_mean: 1.0 | test_mean: 0.7778


100%|█████████████████████████████████████████| 200/200 [11:02<00:00,  3.31s/it]


Subject: 06casme2, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [11:05<00:00,  3.33s/it]


Subject: 07casme2, n=7 | train_mean: 1.0 | test_mean: 0.2063


100%|█████████████████████████████████████████| 200/200 [11:11<00:00,  3.36s/it]


Subject: 08casme2, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [10:58<00:00,  3.29s/it]


Subject: 09casme2, n=10 | train_mean: 0.9892 | test_mean: 0.6944


100%|█████████████████████████████████████████| 200/200 [11:22<00:00,  3.41s/it]


Subject: 10casme2, n=13 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [11:20<00:00,  3.40s/it]


Subject: 11casme2, n=9 | train_mean: 1.0 | test_mean: 0.4375


100%|█████████████████████████████████████████| 200/200 [11:24<00:00,  3.42s/it]


Subject: 12casme2, n=11 | train_mean: 1.0 | test_mean: 0.4267


100%|█████████████████████████████████████████| 200/200 [11:24<00:00,  3.42s/it]


Subject: 13casme2, n=5 | train_mean: 1.0 | test_mean: 0.5556


100%|█████████████████████████████████████████| 200/200 [11:19<00:00,  3.40s/it]


Subject: 14casme2, n=1 | train_mean: 0.9917 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [11:22<00:00,  3.41s/it]


Subject: 15casme2, n=3 | train_mean: 1.0 | test_mean: 0.25


100%|█████████████████████████████████████████| 200/200 [11:21<00:00,  3.41s/it]


Subject: 16casme2, n=3 | train_mean: 1.0 | test_mean: 0.25


100%|█████████████████████████████████████████| 200/200 [11:24<00:00,  3.42s/it]


Subject: 17casme2, n=25 | train_mean: 1.0 | test_mean: 0.54


100%|█████████████████████████████████████████| 200/200 [11:20<00:00,  3.40s/it]


Subject: 18casme2, n=3 | train_mean: 0.7518 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [11:19<00:00,  3.40s/it]


Subject: 19casme2, n=14 | train_mean: 0.9205 | test_mean: 0.6111


100%|█████████████████████████████████████████| 200/200 [11:25<00:00,  3.43s/it]


Subject: 20casme2, n=7 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [11:21<00:00,  3.41s/it]


Subject: 21casme2, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [11:27<00:00,  3.44s/it]


Subject: 22casme2, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [11:22<00:00,  3.41s/it]


Subject: 23casme2, n=10 | train_mean: 1.0 | test_mean: 0.5985


100%|█████████████████████████████████████████| 200/200 [11:27<00:00,  3.44s/it]


Subject: 24casme2, n=10 | train_mean: 1.0 | test_mean: 0.6972


100%|█████████████████████████████████████████| 200/200 [11:13<00:00,  3.37s/it]


Subject: 25casme2, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [11:15<00:00,  3.38s/it]


Subject: 26casme2, n=13 | train_mean: 1.0 | test_mean: 0.7188


100%|█████████████████████████████████████████| 200/200 [11:27<00:00,  3.44s/it]


Subject: 006samm, n=10 | train_mean: 0.9721 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [11:20<00:00,  3.40s/it]


Subject: 007samm, n=4 | train_mean: 0.9904 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [11:14<00:00,  3.37s/it]


Subject: 009samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [11:15<00:00,  3.38s/it]


Subject: 011samm, n=12 | train_mean: 1.0 | test_mean: 0.3619


100%|█████████████████████████████████████████| 200/200 [11:20<00:00,  3.40s/it]


Subject: 012samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [11:31<00:00,  3.46s/it]


Subject: 014samm, n=8 | train_mean: 1.0 | test_mean: 0.3667


100%|█████████████████████████████████████████| 200/200 [11:42<00:00,  3.51s/it]


Subject: 015samm, n=2 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [11:47<00:00,  3.54s/it]


Subject: 016samm, n=4 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [11:43<00:00,  3.52s/it]


Subject: 018samm, n=2 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [11:41<00:00,  3.51s/it]


Subject: 019samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [11:41<00:00,  3.51s/it]


Subject: 020samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [11:40<00:00,  3.50s/it]


Subject: 021samm, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [11:43<00:00,  3.52s/it]


Subject: 022samm, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [11:48<00:00,  3.54s/it]


Subject: 026samm, n=3 | train_mean: 0.9979 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [11:49<00:00,  3.55s/it]


Subject: 028samm, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [12:00<00:00,  3.60s/it]


Subject: 030samm, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [12:00<00:00,  3.60s/it]


Subject: 032samm, n=3 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [12:07<00:00,  3.64s/it]


Subject: 033samm, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [11:57<00:00,  3.59s/it]


Subject: 035samm, n=3 | train_mean: 0.9251 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [12:02<00:00,  3.61s/it]


Subject: 036samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [12:01<00:00,  3.61s/it]


Subject: 037samm, n=1 | train_mean: 1.0 | test_mean: 0.0
Total f1: [0.6818972235754047]


In [26]:
idx = df["dataset"].isin(["samm", "casme2"])
out = MEGC2018Validator(Config).validate(df[idx].reset_index(), data[idx])

100%|█████████████████████████████████████████| 200/200 [03:07<00:00,  1.07it/s]


Subject: 01casme2, n=8 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [03:07<00:00,  1.07it/s]


Subject: 02casme2, n=7 | train_mean: 1.0 | test_mean: 0.5595


100%|█████████████████████████████████████████| 200/200 [03:06<00:00,  1.07it/s]


Subject: 03casme2, n=7 | train_mean: 1.0 | test_mean: 0.5556


100%|█████████████████████████████████████████| 200/200 [03:19<00:00,  1.00it/s]


Subject: 04casme2, n=5 | train_mean: 1.0 | test_mean: 0.2857


100%|█████████████████████████████████████████| 200/200 [03:06<00:00,  1.07it/s]


Subject: 05casme2, n=10 | train_mean: 1.0 | test_mean: 0.5641


100%|█████████████████████████████████████████| 200/200 [03:21<00:00,  1.01s/it]


Subject: 06casme2, n=3 | train_mean: 1.0 | test_mean: 0.5


100%|█████████████████████████████████████████| 200/200 [03:08<00:00,  1.06it/s]


Subject: 07casme2, n=7 | train_mean: 1.0 | test_mean: 0.07143


100%|█████████████████████████████████████████| 200/200 [03:20<00:00,  1.00s/it]


Subject: 08casme2, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [03:05<00:00,  1.08it/s]


Subject: 09casme2, n=10 | train_mean: 1.0 | test_mean: 0.481


100%|█████████████████████████████████████████| 200/200 [03:05<00:00,  1.08it/s]


Subject: 10casme2, n=13 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [03:07<00:00,  1.07it/s]


Subject: 11casme2, n=9 | train_mean: 1.0 | test_mean: 0.8


100%|█████████████████████████████████████████| 200/200 [03:07<00:00,  1.07it/s]


Subject: 12casme2, n=11 | train_mean: 1.0 | test_mean: 0.2667


100%|█████████████████████████████████████████| 200/200 [03:17<00:00,  1.01it/s]


Subject: 13casme2, n=5 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [03:21<00:00,  1.01s/it]


Subject: 14casme2, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [03:20<00:00,  1.00s/it]


Subject: 15casme2, n=3 | train_mean: 1.0 | test_mean: 0.2


100%|█████████████████████████████████████████| 200/200 [03:18<00:00,  1.01it/s]


Subject: 16casme2, n=3 | train_mean: 1.0 | test_mean: 0.4


100%|█████████████████████████████████████████| 200/200 [03:06<00:00,  1.07it/s]


Subject: 17casme2, n=25 | train_mean: 1.0 | test_mean: 0.4413


100%|█████████████████████████████████████████| 200/200 [03:21<00:00,  1.01s/it]


Subject: 18casme2, n=3 | train_mean: 1.0 | test_mean: 0.4


100%|█████████████████████████████████████████| 200/200 [03:05<00:00,  1.08it/s]


Subject: 19casme2, n=14 | train_mean: 1.0 | test_mean: 0.6267


100%|█████████████████████████████████████████| 200/200 [03:07<00:00,  1.07it/s]


Subject: 20casme2, n=7 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [03:23<00:00,  1.02s/it]


Subject: 21casme2, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [03:24<00:00,  1.02s/it]


Subject: 22casme2, n=2 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [03:07<00:00,  1.07it/s]


Subject: 23casme2, n=10 | train_mean: 1.0 | test_mean: 0.4038


100%|█████████████████████████████████████████| 200/200 [03:08<00:00,  1.06it/s]


Subject: 24casme2, n=10 | train_mean: 0.9918 | test_mean: 0.24


100%|█████████████████████████████████████████| 200/200 [03:21<00:00,  1.01s/it]


Subject: 25casme2, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [03:06<00:00,  1.07it/s]


Subject: 26casme2, n=13 | train_mean: 1.0 | test_mean: 0.8083


100%|█████████████████████████████████████████| 200/200 [03:07<00:00,  1.07it/s]


Subject: 006samm, n=10 | train_mean: 1.0 | test_mean: 0.7778


100%|█████████████████████████████████████████| 200/200 [03:19<00:00,  1.00it/s]


Subject: 007samm, n=4 | train_mean: 1.0 | test_mean: 0.0


100%|█████████████████████████████████████████| 200/200 [03:24<00:00,  1.02s/it]


Subject: 009samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [03:07<00:00,  1.07it/s]


Subject: 011samm, n=12 | train_mean: 1.0 | test_mean: 0.2353


100%|█████████████████████████████████████████| 200/200 [03:24<00:00,  1.02s/it]


Subject: 012samm, n=1 | train_mean: 1.0 | test_mean: 0.0


100%|█████████████████████████████████████████| 200/200 [03:07<00:00,  1.07it/s]


Subject: 014samm, n=8 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [03:22<00:00,  1.01s/it]


Subject: 015samm, n=2 | train_mean: 0.9954 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [03:21<00:00,  1.01s/it]


Subject: 016samm, n=4 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [03:23<00:00,  1.02s/it]


Subject: 018samm, n=2 | train_mean: 0.9845 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [03:22<00:00,  1.01s/it]


Subject: 019samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [03:23<00:00,  1.02s/it]


Subject: 020samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [03:21<00:00,  1.01s/it]


Subject: 021samm, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [03:20<00:00,  1.00s/it]


Subject: 022samm, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [03:21<00:00,  1.01s/it]


Subject: 026samm, n=3 | train_mean: 0.996 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [03:22<00:00,  1.01s/it]


Subject: 028samm, n=2 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [03:19<00:00,  1.00it/s]


Subject: 030samm, n=3 | train_mean: 1.0 | test_mean: 0.4


100%|█████████████████████████████████████████| 200/200 [03:23<00:00,  1.02s/it]


Subject: 032samm, n=3 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [03:22<00:00,  1.01s/it]


Subject: 033samm, n=2 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [03:19<00:00,  1.00it/s]


Subject: 035samm, n=3 | train_mean: 0.9961 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [03:22<00:00,  1.01s/it]


Subject: 036samm, n=1 | train_mean: 1.0 | test_mean: 0.0


100%|█████████████████████████████████████████| 200/200 [03:23<00:00,  1.02s/it]


Subject: 037samm, n=1 | train_mean: 1.0 | test_mean: 0.0
Total f1: [0.5763878968128576]


In [27]:
class Config(core.Config):
    action_units = None
    epochs = 200
    batch_size = 32
    evaluation_fn = utils.MultiClassF1Score
    criterion = nn.CrossEntropyLoss
    device = torch.device("cuda:0")
    optimizer = partial(optim.Adam, lr=1e-4, weight_decay=1e-3)
    model = partial(timm.models.resnet50, num_classes=5, pretrained=True)

In [28]:
out = MEGC2018Validator(Config).validate(df, data)

100%|█████████████████████████████████████████| 200/200 [16:12<00:00,  4.86s/it]


Subject: 01casme2, n=8 | train_mean: 1.0 | test_mean: 0.6111


100%|█████████████████████████████████████████| 200/200 [16:06<00:00,  4.83s/it]


Subject: 02casme2, n=7 | train_mean: 0.9988 | test_mean: 0.5167


100%|█████████████████████████████████████████| 200/200 [16:09<00:00,  4.85s/it]


Subject: 03casme2, n=7 | train_mean: 0.9965 | test_mean: 0.1143


100%|█████████████████████████████████████████| 200/200 [16:08<00:00,  4.84s/it]


Subject: 04casme2, n=5 | train_mean: 0.9985 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [16:19<00:00,  4.90s/it]


Subject: 05casme2, n=10 | train_mean: 0.9962 | test_mean: 0.7778


100%|█████████████████████████████████████████| 200/200 [16:14<00:00,  4.87s/it]


Subject: 06casme2, n=3 | train_mean: 0.9972 | test_mean: 0.5556


100%|█████████████████████████████████████████| 200/200 [16:12<00:00,  4.86s/it]


Subject: 07casme2, n=7 | train_mean: 1.0 | test_mean: 0.2063


100%|█████████████████████████████████████████| 200/200 [16:10<00:00,  4.85s/it]


Subject: 08casme2, n=3 | train_mean: 0.9976 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [16:08<00:00,  4.84s/it]


Subject: 09casme2, n=10 | train_mean: 1.0 | test_mean: 0.5286


100%|█████████████████████████████████████████| 200/200 [16:01<00:00,  4.81s/it]


Subject: 10casme2, n=13 | train_mean: 0.9992 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [15:59<00:00,  4.80s/it]


Subject: 11casme2, n=9 | train_mean: 0.998 | test_mean: 0.4375


100%|█████████████████████████████████████████| 200/200 [16:14<00:00,  4.87s/it]


Subject: 12casme2, n=11 | train_mean: 0.987 | test_mean: 0.6597


100%|█████████████████████████████████████████| 200/200 [16:21<00:00,  4.91s/it]


Subject: 13casme2, n=5 | train_mean: 0.9992 | test_mean: 0.5556


100%|█████████████████████████████████████████| 200/200 [16:13<00:00,  4.87s/it]


Subject: 14casme2, n=1 | train_mean: 0.9993 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [16:18<00:00,  4.89s/it]


Subject: 15casme2, n=3 | train_mean: 0.9981 | test_mean: 0.2


100%|█████████████████████████████████████████| 200/200 [16:21<00:00,  4.91s/it]


Subject: 16casme2, n=3 | train_mean: 0.9913 | test_mean: 0.25


100%|█████████████████████████████████████████| 200/200 [15:53<00:00,  4.77s/it]


Subject: 17casme2, n=25 | train_mean: 0.9816 | test_mean: 0.554


100%|█████████████████████████████████████████| 200/200 [16:04<00:00,  4.82s/it]


Subject: 18casme2, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [16:45<00:00,  5.03s/it]


Subject: 19casme2, n=14 | train_mean: 1.0 | test_mean: 0.5721


100%|█████████████████████████████████████████| 200/200 [16:23<00:00,  4.92s/it]


Subject: 20casme2, n=7 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [16:38<00:00,  4.99s/it]


Subject: 21casme2, n=2 | train_mean: 0.9952 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [16:14<00:00,  4.87s/it]


Subject: 22casme2, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [16:00<00:00,  4.80s/it]


Subject: 23casme2, n=10 | train_mean: 1.0 | test_mean: 0.5985


100%|█████████████████████████████████████████| 200/200 [16:11<00:00,  4.86s/it]


Subject: 24casme2, n=10 | train_mean: 0.9942 | test_mean: 0.8167


100%|█████████████████████████████████████████| 200/200 [16:04<00:00,  4.82s/it]


Subject: 25casme2, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [16:17<00:00,  4.89s/it]


Subject: 26casme2, n=13 | train_mean: 0.9988 | test_mean: 0.8393


100%|█████████████████████████████████████████| 200/200 [16:30<00:00,  4.95s/it]


Subject: 006samm, n=10 | train_mean: 0.9936 | test_mean: 0.7778


100%|█████████████████████████████████████████| 200/200 [16:35<00:00,  4.98s/it]


Subject: 007samm, n=4 | train_mean: 0.9992 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [16:35<00:00,  4.98s/it]


Subject: 009samm, n=1 | train_mean: 0.9976 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [16:48<00:00,  5.04s/it]


Subject: 011samm, n=12 | train_mean: 0.9933 | test_mean: 0.3385


100%|█████████████████████████████████████████| 200/200 [16:53<00:00,  5.07s/it]


Subject: 012samm, n=1 | train_mean: 0.9971 | test_mean: 0.0


100%|█████████████████████████████████████████| 200/200 [16:36<00:00,  4.98s/it]


Subject: 014samm, n=8 | train_mean: 0.9963 | test_mean: 0.3148


100%|█████████████████████████████████████████| 200/200 [16:51<00:00,  5.06s/it]


Subject: 015samm, n=2 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [16:33<00:00,  4.97s/it]


Subject: 016samm, n=4 | train_mean: 0.9906 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [16:41<00:00,  5.01s/it]


Subject: 018samm, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [16:12<00:00,  4.86s/it]


Subject: 019samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [16:18<00:00,  4.89s/it]


Subject: 020samm, n=1 | train_mean: 0.9985 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [16:11<00:00,  4.86s/it]


Subject: 021samm, n=2 | train_mean: 0.9742 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [16:21<00:00,  4.91s/it]


Subject: 022samm, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [16:16<00:00,  4.88s/it]


Subject: 026samm, n=3 | train_mean: 0.9885 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [16:14<00:00,  4.87s/it]


Subject: 028samm, n=2 | train_mean: 0.999 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [16:09<00:00,  4.85s/it]


Subject: 030samm, n=3 | train_mean: 1.0 | test_mean: 0.4


100%|█████████████████████████████████████████| 200/200 [16:19<00:00,  4.90s/it]


Subject: 032samm, n=3 | train_mean: 1.0 | test_mean: 0.0


100%|█████████████████████████████████████████| 200/200 [16:18<00:00,  4.89s/it]


Subject: 033samm, n=2 | train_mean: 0.9981 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [16:11<00:00,  4.86s/it]


Subject: 035samm, n=3 | train_mean: 0.9956 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [16:01<00:00,  4.81s/it]


Subject: 036samm, n=1 | train_mean: 1.0 | test_mean: 0.0


100%|█████████████████████████████████████████| 200/200 [15:48<00:00,  4.74s/it]


Subject: 037samm, n=1 | train_mean: 0.9935 | test_mean: 0.0
Total f1: [0.6466664611855174]


In [29]:
idx = df["dataset"].isin(["samm", "casme2"])
out = MEGC2018Validator(Config).validate(df[idx].reset_index(), data[idx])

100%|█████████████████████████████████████████| 200/200 [03:52<00:00,  1.16s/it]


Subject: 01casme2, n=8 | train_mean: 1.0 | test_mean: 0.6111


100%|█████████████████████████████████████████| 200/200 [03:50<00:00,  1.15s/it]


Subject: 02casme2, n=7 | train_mean: 1.0 | test_mean: 0.4333


100%|█████████████████████████████████████████| 200/200 [03:51<00:00,  1.16s/it]


Subject: 03casme2, n=7 | train_mean: 1.0 | test_mean: 0.07143


100%|█████████████████████████████████████████| 200/200 [04:03<00:00,  1.22s/it]


Subject: 04casme2, n=5 | train_mean: 1.0 | test_mean: 0.2857


100%|█████████████████████████████████████████| 200/200 [03:49<00:00,  1.15s/it]


Subject: 05casme2, n=10 | train_mean: 1.0 | test_mean: 0.2917


100%|█████████████████████████████████████████| 200/200 [04:08<00:00,  1.24s/it]


Subject: 06casme2, n=3 | train_mean: 0.9951 | test_mean: 0.5


100%|█████████████████████████████████████████| 200/200 [03:48<00:00,  1.14s/it]


Subject: 07casme2, n=7 | train_mean: 1.0 | test_mean: 0.1111


100%|█████████████████████████████████████████| 200/200 [04:08<00:00,  1.24s/it]


Subject: 08casme2, n=3 | train_mean: 1.0 | test_mean: 0.5556


100%|█████████████████████████████████████████| 200/200 [03:49<00:00,  1.15s/it]


Subject: 09casme2, n=10 | train_mean: 1.0 | test_mean: 0.3


100%|█████████████████████████████████████████| 200/200 [03:48<00:00,  1.14s/it]


Subject: 10casme2, n=13 | train_mean: 1.0 | test_mean: 0.2899


100%|█████████████████████████████████████████| 200/200 [03:49<00:00,  1.15s/it]


Subject: 11casme2, n=9 | train_mean: 1.0 | test_mean: 0.4375


100%|█████████████████████████████████████████| 200/200 [03:47<00:00,  1.14s/it]


Subject: 12casme2, n=11 | train_mean: 1.0 | test_mean: 0.3143


100%|█████████████████████████████████████████| 200/200 [04:06<00:00,  1.23s/it]


Subject: 13casme2, n=5 | train_mean: 1.0 | test_mean: 0.3667


100%|█████████████████████████████████████████| 200/200 [04:07<00:00,  1.24s/it]


Subject: 14casme2, n=1 | train_mean: 1.0 | test_mean: 0.0


100%|█████████████████████████████████████████| 200/200 [04:09<00:00,  1.25s/it]


Subject: 15casme2, n=3 | train_mean: 1.0 | test_mean: 0.5


100%|█████████████████████████████████████████| 200/200 [04:07<00:00,  1.24s/it]


Subject: 16casme2, n=3 | train_mean: 0.9945 | test_mean: 0.0


100%|█████████████████████████████████████████| 200/200 [03:51<00:00,  1.16s/it]


Subject: 17casme2, n=25 | train_mean: 1.0 | test_mean: 0.4001


100%|█████████████████████████████████████████| 200/200 [04:14<00:00,  1.27s/it]


Subject: 18casme2, n=3 | train_mean: 1.0 | test_mean: 0.25


100%|█████████████████████████████████████████| 200/200 [03:51<00:00,  1.16s/it]


Subject: 19casme2, n=14 | train_mean: 1.0 | test_mean: 0.1905


100%|█████████████████████████████████████████| 200/200 [03:51<00:00,  1.16s/it]


Subject: 20casme2, n=7 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [04:09<00:00,  1.25s/it]


Subject: 21casme2, n=2 | train_mean: 0.9951 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [04:08<00:00,  1.24s/it]


Subject: 22casme2, n=2 | train_mean: 1.0 | test_mean: 0.0


100%|█████████████████████████████████████████| 200/200 [03:49<00:00,  1.15s/it]


Subject: 23casme2, n=10 | train_mean: 1.0 | test_mean: 0.5985


100%|█████████████████████████████████████████| 200/200 [03:47<00:00,  1.14s/it]


Subject: 24casme2, n=10 | train_mean: 0.9945 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [04:06<00:00,  1.23s/it]


Subject: 25casme2, n=3 | train_mean: 1.0 | test_mean: 0.4


100%|█████████████████████████████████████████| 200/200 [03:49<00:00,  1.15s/it]


Subject: 26casme2, n=13 | train_mean: 1.0 | test_mean: 0.5083


100%|█████████████████████████████████████████| 200/200 [03:49<00:00,  1.15s/it]


Subject: 006samm, n=10 | train_mean: 0.9952 | test_mean: 0.3048


100%|█████████████████████████████████████████| 200/200 [04:09<00:00,  1.25s/it]


Subject: 007samm, n=4 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [04:10<00:00,  1.25s/it]


Subject: 009samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [03:51<00:00,  1.16s/it]


Subject: 011samm, n=12 | train_mean: 1.0 | test_mean: 0.1742


100%|█████████████████████████████████████████| 200/200 [04:11<00:00,  1.26s/it]


Subject: 012samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [03:52<00:00,  1.16s/it]


Subject: 014samm, n=8 | train_mean: 1.0 | test_mean: 0.3125


100%|█████████████████████████████████████████| 200/200 [04:10<00:00,  1.25s/it]


Subject: 015samm, n=2 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [04:11<00:00,  1.26s/it]


Subject: 016samm, n=4 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [04:09<00:00,  1.25s/it]


Subject: 018samm, n=2 | train_mean: 1.0 | test_mean: 0.0


100%|█████████████████████████████████████████| 200/200 [04:10<00:00,  1.25s/it]


Subject: 019samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [04:11<00:00,  1.26s/it]


Subject: 020samm, n=1 | train_mean: 0.9956 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [04:09<00:00,  1.25s/it]


Subject: 021samm, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [04:10<00:00,  1.25s/it]


Subject: 022samm, n=2 | train_mean: 0.9907 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [04:09<00:00,  1.25s/it]


Subject: 026samm, n=3 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [04:11<00:00,  1.26s/it]


Subject: 028samm, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [04:12<00:00,  1.26s/it]


Subject: 030samm, n=3 | train_mean: 1.0 | test_mean: 0.4


100%|█████████████████████████████████████████| 200/200 [04:11<00:00,  1.26s/it]


Subject: 032samm, n=3 | train_mean: 0.9949 | test_mean: 0.0


100%|█████████████████████████████████████████| 200/200 [04:13<00:00,  1.27s/it]


Subject: 033samm, n=2 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 200/200 [04:13<00:00,  1.27s/it]


Subject: 035samm, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [04:13<00:00,  1.27s/it]


Subject: 036samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 200/200 [04:13<00:00,  1.27s/it]


Subject: 037samm, n=1 | train_mean: 1.0 | test_mean: 1.0
Total f1: [0.4348693101840503]


## R(2+1)D

In [23]:
c = datasets.CrossDataset(resize=112, ignore_validation=True, preload=True, color=True)
df = c.data_frame
data = c.data

Loading data: 100%|███████████████████████████| 189/189 [00:24<00:00,  7.86it/s]
Loading data: 100%|███████████████████████████| 256/256 [01:46<00:00,  2.41it/s]
Loading data: 100%|███████████████████████████| 159/159 [01:29<00:00,  1.78it/s]
Loading data: 100%|███████████████████████████| 267/267 [00:34<00:00,  7.72it/s]
Loading data: 100%|███████████████████████████| 300/300 [01:00<00:00,  4.96it/s]
Loading data: 100%|███████████████████████████| 860/860 [02:42<00:00,  5.29it/s]


In [24]:
c = datasets.Samm(resize=112, ignore_validation=True, preload=True, color=True)
dfs = c.data_frame
datas = c.data

Loading data: 100%|███████████████████████████| 159/159 [01:22<00:00,  1.94it/s]


In [25]:
# Add missing AU columns to SAMM that are missing compared to all data
dfs_aus = pd.DataFrame(dfs.loc[:, "AU1":].to_dict(), columns=df.loc[:, "AU1":].columns).fillna(0)
dfs_dropped = dfs.drop(dfs.loc[0, "AU1":].index.tolist(), axis=1)
dfs = pd.concat([dfs_dropped, dfs_aus], axis=1)

In [26]:
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import f1_score
from sklearn.preprocessing import LabelEncoder

action_units = dfs.loc[:, "AU1":].columns
df_dataset = dfs
model = GradientBoostingClassifier(random_state=0)

aus = np.array(df_dataset.loc[:, action_units])
emotions = df_dataset["Objective Classes"].astype("str")
model.fit(aus, emotions)
predicted_emotions = model.predict(aus)
f1 = f1_score(emotions, predicted_emotions, average="macro")
print(f1 * 100)

99.17754854925882


In [27]:
df.insert(7, "objective_class", model.predict(np.array(df.loc[:, "AU1":])))

In [28]:
from torchvision import transforms
import torch.nn.functional as F
#interpolate samples with less than 8 frames
n_frames = 8
for i, video in enumerate(data):
    if video.shape[0] < n_frames:
        new_shape = (n_frames,) + video.shape[1:-1]
        video = torch.tensor(video).permute(3, 0, 1, 2).unsqueeze(0).float()
        new_video = F.interpolate(video, size=new_shape, mode="trilinear")
        data[i] = new_video.squeeze(0).permute(1, 2, 3, 0).byte().numpy()

In [29]:
#idx = df["dataset"].isin(["samm", "casme2"]) & (df["objective_class"].isin(["6", "7"]))
idx = df["objective_class"].isin(["6", "7"])
df = df[~idx].reset_index(drop=True)
data = data[~idx]

In [30]:
df["subject"] = df["subject"].astype(str) + df["dataset"]

In [31]:
# Create a function that returns the model as it needs to be modified
def r2plus1d(num_classes: int):
    model = torchvision.models.video.r2plus1d_18(weights=torchvision.models.video.R2Plus1D_18_Weights.DEFAULT)
    model.fc = nn.Linear(in_features=512, out_features=num_classes)
    return model

In [32]:
class Config(core.Config):
    action_units = None
    epochs = 100
    batch_size = 32
    evaluation_fn = utils.MultiClassF1Score
    criterion = nn.CrossEntropyLoss
    device = torch.device("cuda:1")
    optimizer = partial(optim.Adam, lr=1e-4, weight_decay=1e-3)
    model = partial(r2plus1d, num_classes=5)
    loss_scaler = torch.cuda.amp.GradScaler
    channels_last = torch.channels_last_3d
    train_transform = {
        "spatial": None,
        "temporal": datasets.UniformTemporalSubsample(8)
    }
    test_transform = {
        "spatial": None,
        "temporal": datasets.UniformTemporalSubsample(8)
    }

In [33]:
out = MEGC2018Validator(Config).validate(df, data)

100%|███████████████████████████████████████| 100/100 [1:19:15<00:00, 47.56s/it]


Subject: 01casme2, n=8 | train_mean: 0.9611 | test_mean: 0.4646


100%|███████████████████████████████████████| 100/100 [1:19:25<00:00, 47.65s/it]


Subject: 02casme2, n=7 | train_mean: 0.9981 | test_mean: 0.4167


100%|███████████████████████████████████████| 100/100 [1:19:23<00:00, 47.64s/it]


Subject: 03casme2, n=7 | train_mean: 0.9712 | test_mean: 0.1875


100%|███████████████████████████████████████| 100/100 [1:19:30<00:00, 47.71s/it]


Subject: 04casme2, n=5 | train_mean: 1.0 | test_mean: 0.2857


100%|███████████████████████████████████████| 100/100 [1:19:14<00:00, 47.55s/it]


Subject: 05casme2, n=10 | train_mean: 0.9899 | test_mean: 0.3558


100%|███████████████████████████████████████| 100/100 [1:19:37<00:00, 47.77s/it]


Subject: 06casme2, n=3 | train_mean: 1.0 | test_mean: 0.5


100%|███████████████████████████████████████| 100/100 [1:19:26<00:00, 47.66s/it]


Subject: 07casme2, n=7 | train_mean: 1.0 | test_mean: 0.07143


100%|███████████████████████████████████████| 100/100 [1:19:38<00:00, 47.79s/it]


Subject: 08casme2, n=3 | train_mean: 0.7906 | test_mean: 0.0


100%|███████████████████████████████████████| 100/100 [1:19:15<00:00, 47.55s/it]


Subject: 09casme2, n=10 | train_mean: 1.0 | test_mean: 0.2589


100%|███████████████████████████████████████| 100/100 [1:19:06<00:00, 47.47s/it]


Subject: 10casme2, n=13 | train_mean: 1.0 | test_mean: 0.4583


100%|███████████████████████████████████████| 100/100 [1:19:20<00:00, 47.61s/it]


Subject: 11casme2, n=9 | train_mean: 0.9928 | test_mean: 0.4375


100%|███████████████████████████████████████| 100/100 [1:19:13<00:00, 47.54s/it]


Subject: 12casme2, n=11 | train_mean: 1.0 | test_mean: 0.3095


100%|███████████████████████████████████████| 100/100 [1:19:33<00:00, 47.74s/it]


Subject: 13casme2, n=5 | train_mean: 0.9892 | test_mean: 1.0


100%|███████████████████████████████████████| 100/100 [1:19:43<00:00, 47.84s/it]


Subject: 14casme2, n=1 | train_mean: 1.0 | test_mean: 0.0


100%|███████████████████████████████████████| 100/100 [1:19:38<00:00, 47.79s/it]


Subject: 15casme2, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|███████████████████████████████████████| 100/100 [1:19:34<00:00, 47.74s/it]


Subject: 16casme2, n=3 | train_mean: 0.9924 | test_mean: 0.25


100%|███████████████████████████████████████| 100/100 [1:18:21<00:00, 47.01s/it]


Subject: 17casme2, n=25 | train_mean: 0.9985 | test_mean: 0.6423


100%|███████████████████████████████████████| 100/100 [1:19:36<00:00, 47.76s/it]


Subject: 18casme2, n=3 | train_mean: 1.0 | test_mean: 1.0


100%|███████████████████████████████████████| 100/100 [1:18:58<00:00, 47.39s/it]


Subject: 19casme2, n=14 | train_mean: 1.0 | test_mean: 0.3562


100%|███████████████████████████████████████| 100/100 [1:19:22<00:00, 47.63s/it]


Subject: 20casme2, n=7 | train_mean: 1.0 | test_mean: 1.0


100%|███████████████████████████████████████| 100/100 [1:19:39<00:00, 47.80s/it]


Subject: 21casme2, n=2 | train_mean: 0.933 | test_mean: 0.3333


100%|███████████████████████████████████████| 100/100 [1:19:42<00:00, 47.83s/it]


Subject: 22casme2, n=2 | train_mean: 1.0 | test_mean: 1.0


100%|███████████████████████████████████████| 100/100 [1:19:17<00:00, 47.57s/it]


Subject: 23casme2, n=10 | train_mean: 1.0 | test_mean: 0.5985


100%|███████████████████████████████████████| 100/100 [1:19:17<00:00, 47.58s/it]


Subject: 24casme2, n=10 | train_mean: 0.9829 | test_mean: 0.1625


100%|███████████████████████████████████████| 100/100 [1:19:38<00:00, 47.79s/it]


Subject: 25casme2, n=3 | train_mean: 0.9972 | test_mean: 0.4


100%|███████████████████████████████████████| 100/100 [1:19:05<00:00, 47.46s/it]


Subject: 26casme2, n=13 | train_mean: 1.0 | test_mean: 0.4105


100%|███████████████████████████████████████| 100/100 [1:19:16<00:00, 47.57s/it]


Subject: 006samm, n=10 | train_mean: 0.9965 | test_mean: 0.5


100%|███████████████████████████████████████| 100/100 [1:19:36<00:00, 47.76s/it]


Subject: 007samm, n=4 | train_mean: 1.0 | test_mean: 1.0


100%|███████████████████████████████████████| 100/100 [1:19:44<00:00, 47.85s/it]


Subject: 009samm, n=1 | train_mean: 0.9743 | test_mean: 1.0


100%|███████████████████████████████████████| 100/100 [1:19:08<00:00, 47.48s/it]


Subject: 011samm, n=12 | train_mean: 1.0 | test_mean: 0.1667


100%|███████████████████████████████████████| 100/100 [1:19:43<00:00, 47.83s/it]


Subject: 012samm, n=1 | train_mean: 0.9884 | test_mean: 1.0


100%|███████████████████████████████████████| 100/100 [1:19:23<00:00, 47.63s/it]


Subject: 014samm, n=8 | train_mean: 0.9822 | test_mean: 0.3651


100%|███████████████████████████████████████| 100/100 [1:19:43<00:00, 47.83s/it]


Subject: 015samm, n=2 | train_mean: 1.0 | test_mean: 0.3333


100%|███████████████████████████████████████| 100/100 [1:19:35<00:00, 47.76s/it]


Subject: 016samm, n=4 | train_mean: 0.9939 | test_mean: 0.1667


100%|███████████████████████████████████████| 100/100 [1:19:40<00:00, 47.81s/it]


Subject: 018samm, n=2 | train_mean: 1.0 | test_mean: 0.3333


100%|███████████████████████████████████████| 100/100 [1:19:41<00:00, 47.81s/it]


Subject: 019samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|███████████████████████████████████████| 100/100 [1:19:41<00:00, 47.81s/it]


Subject: 020samm, n=1 | train_mean: 1.0 | test_mean: 0.0


100%|███████████████████████████████████████| 100/100 [1:19:39<00:00, 47.80s/it]


Subject: 021samm, n=2 | train_mean: 1.0 | test_mean: 0.3333


100%|███████████████████████████████████████| 100/100 [1:19:39<00:00, 47.80s/it]


Subject: 022samm, n=2 | train_mean: 0.999 | test_mean: 1.0


100%|███████████████████████████████████████| 100/100 [1:19:36<00:00, 47.77s/it]


Subject: 026samm, n=3 | train_mean: 0.9992 | test_mean: 0.2222


100%|███████████████████████████████████████| 100/100 [1:19:38<00:00, 47.79s/it]


Subject: 028samm, n=2 | train_mean: 0.999 | test_mean: 0.0


100%|███████████████████████████████████████| 100/100 [1:19:37<00:00, 47.77s/it]


Subject: 030samm, n=3 | train_mean: 1.0 | test_mean: 0.25


100%|███████████████████████████████████████| 100/100 [1:19:35<00:00, 47.76s/it]


Subject: 032samm, n=3 | train_mean: 0.9918 | test_mean: 0.3333


100%|███████████████████████████████████████| 100/100 [1:19:40<00:00, 47.81s/it]


Subject: 033samm, n=2 | train_mean: 1.0 | test_mean: 0.3333


100%|███████████████████████████████████████| 100/100 [1:19:36<00:00, 47.77s/it]


Subject: 035samm, n=3 | train_mean: 0.9985 | test_mean: 1.0


100%|███████████████████████████████████████| 100/100 [1:19:41<00:00, 47.81s/it]


Subject: 036samm, n=1 | train_mean: 1.0 | test_mean: 0.0


100%|███████████████████████████████████████| 100/100 [1:19:39<00:00, 47.80s/it]


Subject: 037samm, n=1 | train_mean: 0.945 | test_mean: 0.0
Total f1: [0.4822672328704776]


In [34]:
idx = df["dataset"].isin(["samm", "casme2"])
out = MEGC2018Validator(Config).validate(df[idx].reset_index(), data[idx])

100%|█████████████████████████████████████████| 100/100 [16:03<00:00,  9.63s/it]


Subject: 01casme2, n=8 | train_mean: 1.0 | test_mean: 0.2564


100%|█████████████████████████████████████████| 100/100 [16:06<00:00,  9.67s/it]


Subject: 02casme2, n=7 | train_mean: 1.0 | test_mean: 0.07143


100%|█████████████████████████████████████████| 100/100 [16:06<00:00,  9.66s/it]


Subject: 03casme2, n=7 | train_mean: 1.0 | test_mean: 0.2


100%|█████████████████████████████████████████| 100/100 [16:43<00:00, 10.04s/it]


Subject: 04casme2, n=5 | train_mean: 1.0 | test_mean: 0.4444


100%|█████████████████████████████████████████| 100/100 [15:58<00:00,  9.58s/it]


Subject: 05casme2, n=10 | train_mean: 1.0 | test_mean: 0.2745


100%|█████████████████████████████████████████| 100/100 [16:27<00:00,  9.88s/it]


Subject: 06casme2, n=3 | train_mean: 1.0 | test_mean: 0.1667


100%|█████████████████████████████████████████| 100/100 [16:06<00:00,  9.67s/it]


Subject: 07casme2, n=7 | train_mean: 1.0 | test_mean: 0.09524


100%|█████████████████████████████████████████| 100/100 [16:27<00:00,  9.88s/it]


Subject: 08casme2, n=3 | train_mean: 1.0 | test_mean: 0.4


100%|█████████████████████████████████████████| 100/100 [15:57<00:00,  9.58s/it]


Subject: 09casme2, n=10 | train_mean: 1.0 | test_mean: 0.1538


100%|█████████████████████████████████████████| 100/100 [15:47<00:00,  9.48s/it]


Subject: 10casme2, n=13 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 100/100 [16:02<00:00,  9.62s/it]


Subject: 11casme2, n=9 | train_mean: 1.0 | test_mean: 0.4375


100%|█████████████████████████████████████████| 100/100 [15:54<00:00,  9.54s/it]


Subject: 12casme2, n=11 | train_mean: 1.0 | test_mean: 0.4127


100%|█████████████████████████████████████████| 100/100 [16:44<00:00, 10.05s/it]


Subject: 13casme2, n=5 | train_mean: 1.0 | test_mean: 0.8


100%|█████████████████████████████████████████| 100/100 [16:34<00:00,  9.95s/it]


Subject: 14casme2, n=1 | train_mean: 0.9821 | test_mean: 0.0


100%|█████████████████████████████████████████| 100/100 [16:28<00:00,  9.89s/it]


Subject: 15casme2, n=3 | train_mean: 0.9951 | test_mean: 0.1667


100%|█████████████████████████████████████████| 100/100 [16:28<00:00,  9.88s/it]


Subject: 16casme2, n=3 | train_mean: 1.0 | test_mean: 0.0


100%|█████████████████████████████████████████| 100/100 [15:07<00:00,  9.08s/it]


Subject: 17casme2, n=25 | train_mean: 1.0 | test_mean: 0.3772


100%|█████████████████████████████████████████| 100/100 [16:26<00:00,  9.86s/it]


Subject: 18casme2, n=3 | train_mean: 0.9945 | test_mean: 0.0


100%|█████████████████████████████████████████| 100/100 [15:43<00:00,  9.44s/it]


Subject: 19casme2, n=14 | train_mean: 1.0 | test_mean: 0.12


100%|█████████████████████████████████████████| 100/100 [16:06<00:00,  9.67s/it]


Subject: 20casme2, n=7 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 100/100 [16:32<00:00,  9.92s/it]


Subject: 21casme2, n=2 | train_mean: 1.0 | test_mean: 0.3333


100%|█████████████████████████████████████████| 100/100 [16:31<00:00,  9.91s/it]


Subject: 22casme2, n=2 | train_mean: 0.9909 | test_mean: 0.0


100%|█████████████████████████████████████████| 100/100 [15:59<00:00,  9.59s/it]


Subject: 23casme2, n=10 | train_mean: 1.0 | test_mean: 0.1429


100%|█████████████████████████████████████████| 100/100 [16:00<00:00,  9.60s/it]


Subject: 24casme2, n=10 | train_mean: 1.0 | test_mean: 0.1667


100%|█████████████████████████████████████████| 100/100 [16:27<00:00,  9.88s/it]


Subject: 25casme2, n=3 | train_mean: 0.9961 | test_mean: 1.0


100%|█████████████████████████████████████████| 100/100 [15:47<00:00,  9.48s/it]


Subject: 26casme2, n=13 | train_mean: 0.9945 | test_mean: 0.1944


100%|█████████████████████████████████████████| 100/100 [15:57<00:00,  9.58s/it]


Subject: 006samm, n=10 | train_mean: 1.0 | test_mean: 0.12


100%|█████████████████████████████████████████| 100/100 [16:23<00:00,  9.83s/it]


Subject: 007samm, n=4 | train_mean: 1.0 | test_mean: 0.0


100%|█████████████████████████████████████████| 100/100 [16:35<00:00,  9.96s/it]


Subject: 009samm, n=1 | train_mean: 1.0 | test_mean: 0.0


100%|█████████████████████████████████████████| 100/100 [15:50<00:00,  9.51s/it]


Subject: 011samm, n=12 | train_mean: 1.0 | test_mean: 0.2593


100%|█████████████████████████████████████████| 100/100 [16:35<00:00,  9.95s/it]


Subject: 012samm, n=1 | train_mean: 0.9828 | test_mean: 0.0


100%|█████████████████████████████████████████| 100/100 [16:03<00:00,  9.63s/it]


Subject: 014samm, n=8 | train_mean: 1.0 | test_mean: 0.4667


100%|█████████████████████████████████████████| 100/100 [16:31<00:00,  9.91s/it]


Subject: 015samm, n=2 | train_mean: 0.9928 | test_mean: 0.3333


100%|█████████████████████████████████████████| 100/100 [16:23<00:00,  9.83s/it]


Subject: 016samm, n=4 | train_mean: 1.0 | test_mean: 0.1333


100%|█████████████████████████████████████████| 100/100 [16:31<00:00,  9.92s/it]


Subject: 018samm, n=2 | train_mean: 1.0 | test_mean: 0.0


100%|█████████████████████████████████████████| 100/100 [16:33<00:00,  9.94s/it]


Subject: 019samm, n=1 | train_mean: 1.0 | test_mean: 1.0


100%|█████████████████████████████████████████| 100/100 [16:35<00:00,  9.95s/it]


Subject: 020samm, n=1 | train_mean: 0.9961 | test_mean: 0.0


100%|█████████████████████████████████████████| 100/100 [16:31<00:00,  9.92s/it]


Subject: 021samm, n=2 | train_mean: 0.9862 | test_mean: 1.0


100%|█████████████████████████████████████████| 100/100 [16:31<00:00,  9.91s/it]


Subject: 022samm, n=2 | train_mean: 1.0 | test_mean: 0.0


100%|█████████████████████████████████████████| 100/100 [16:28<00:00,  9.88s/it]


Subject: 026samm, n=3 | train_mean: 1.0 | test_mean: 0.4


100%|█████████████████████████████████████████| 100/100 [16:32<00:00,  9.93s/it]


Subject: 028samm, n=2 | train_mean: 1.0 | test_mean: 0.0


100%|█████████████████████████████████████████| 100/100 [16:27<00:00,  9.87s/it]


Subject: 030samm, n=3 | train_mean: 0.9633 | test_mean: 1.0


100%|█████████████████████████████████████████| 100/100 [16:28<00:00,  9.89s/it]


Subject: 032samm, n=3 | train_mean: 0.9808 | test_mean: 0.4


100%|█████████████████████████████████████████| 100/100 [16:30<00:00,  9.91s/it]


Subject: 033samm, n=2 | train_mean: 1.0 | test_mean: 0.0


100%|█████████████████████████████████████████| 100/100 [16:27<00:00,  9.87s/it]


Subject: 035samm, n=3 | train_mean: 1.0 | test_mean: 0.0


100%|█████████████████████████████████████████| 100/100 [16:34<00:00,  9.95s/it]


Subject: 036samm, n=1 | train_mean: 0.9848 | test_mean: 0.0


100%|█████████████████████████████████████████| 100/100 [16:35<00:00,  9.95s/it]


Subject: 037samm, n=1 | train_mean: 0.9923 | test_mean: 1.0
Total f1: [0.24785964912280703]
