In [1]:
import meb
from meb import utils
from meb import datasets
from meb import core
from meb import models

from functools import partial

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
import timm
from tqdm import tqdm



pd.set_option("display.max_columns", 50)
%load_ext autoreload
%autoreload 2

In [2]:
# Only validate casme2
class IValidator(core.CrossDatasetValidator):
    def __init__(self, config: "Config"):
        super().__init__(config)
        
    def validate_n_times(
        self, df: pd.DataFrame, input_data, n_times: int = 5
    ) -> None:
        self.verbose = False
        self.disable_tqdm = True
        au_results = []
        dataset_results = []
        casme2_idx = df["dataset"] == "casme2"
        for n in tqdm(range(n_times)):
            outputs_list = self.validate(df, input_data, seed_n=n + 45)
            au_result, dataset_result = self.printer.results_to_list(outputs_list, df[casme2_idx])
            au_results.append(au_result)
            dataset_results.append(dataset_result)

        aus = [i for i in self.cf.action_units]
        dataset_names = df["dataset"].unique().tolist()
        aus.append("Average")
        dataset_names.append("Average")
        au_results = np.array(au_results)
        dataset_results = np.array(dataset_results)
        for i in range(len(self.cf.evaluation_fn)):
            if len(self.cf.evaluation_fn) > 1:
                print(self.printer.metric_name(self.cf.evaluation_fn[i]))
            au_result = self.printer.list_to_latex(list(au_results[:, i].mean(axis=0)))
            dataset_result = self.printer.list_to_latex(
                list(dataset_results[:, i].mean(axis=0))
            )
            print("AUS:", aus)
            print(au_result)
            print("\nDatasets: ", dataset_names)
            print(dataset_result)
    
    def validate(self, df: pd.DataFrame, input_data: np.ndarray, seed_n: int = 1):
        utils.set_random_seeds(seed_n)
        dataset_names = df["dataset"].unique()
        # Create a boolean array with the AUs
        labels = np.array(df[self.cf.action_units])
        outputs_list = []
        for dataset_name in dataset_names:
            if dataset_name != "casme2":
                continue
            train_metrics, test_metrics, outputs_test = self.validate_split(
                df, input_data, labels, dataset_name
            )
            outputs_list.append(outputs_test)
            if self.verbose:
                self.printer.print_train_test_evaluation(
                    train_metrics, test_metrics, dataset_name, outputs_test.shape[0]
                )

        # Calculate total f1-scores
        predictions = torch.cat(outputs_list)
        idx = df["dataset"] == "casme2"
        metrics = self.evaluation_fn(labels[idx], predictions)
        if self.verbose:
            self.printer.print_test_validation(metrics)
        return outputs_list

# Off-ApexNet

In [6]:
c = datasets.CrossDataset(resize=28, optical_flow=True)
df = c.data_frame
data = c.data

In [7]:
class Config(core.Config):
    device = torch.device("cuda:1")
    optimizer = partial(optim.Adam, lr=1e-4, weight_decay=1e-3)
    batch_size = 128
    epochs = 3000
    evaluation_fn = [
        partial(utils.MultiLabelF1Score, average="macro"),
        partial(utils.MultiLabelF1Score, average="binary")
    ]
    model = partial(meb.models.OffApexNet, num_classes=len(core.Config.action_units))

In [8]:
# [:, :2] to remove optical strain as it is not used in OffApexNet
use_datasets = ["casme2"]
for dataset in ["casme", "samm", "mmew", "fourd", "casme3a"]:
    use_datasets.append(dataset)
    idx = df["dataset"].isin(use_datasets)
    print(use_datasets)
    IValidator(Config).validate_n_times(df[idx].reset_index(), data[idx, :2], n_times=5)

['casme2', 'casme']


  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
100%|██████████████████████████████████████████| 5/5 [1:15:37<00:00, 907.55s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
84.5 & 82.2 & 93.7 & 49.8 & 48.7 & 45.9 & 57.5 & 48.4 & 58.3 & 56.8 & 64.9 & 68.8 & 63.3

Datasets:  ['casme', 'casme2', 'Average']
63.3 & 63.3
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
73.1 & 68.2 & 93.9 & 0.0 & 0.0 & 0.0 & 23.7 & 0.0 & 23.2 & 21.2 & 32.5 & 41.6 & 31.5

Datasets:  ['casme', 'casme2', 'Average']
31.5 & 31.5
['casme2', 'casme', 'samm']


100%|██████████████████████████████████████████| 5/5 [1:12:33<00:00, 870.64s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
86.3 & 84.5 & 93.4 & 49.8 & 48.7 & 55.4 & 64.9 & 48.2 & 68.8 & 57.4 & 68.8 & 73.5 & 66.6

Datasets:  ['casme', 'casme2', 'samm', 'Average']
66.6 & 66.6
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
75.4 & 72.1 & 93.4 & 0.0 & 0.0 & 18.3 & 35.3 & 0.0 & 44.4 & 21.0 & 40.0 & 50.5 & 37.5

Datasets:  ['casme', 'casme2', 'samm', 'Average']
37.5 & 37.5
['casme2', 'casme', 'samm', 'mmew']


100%|██████████████████████████████████████████| 5/5 [1:18:17<00:00, 939.44s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
89.7 & 82.4 & 95.2 & 48.5 & 59.8 & 57.7 & 62.5 & 50.9 & 73.6 & 66.3 & 67.1 & 76.1 & 69.1

Datasets:  ['casme', 'casme2', 'samm', 'mmew', 'Average']
69.1 & 69.1
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
81.7 & 68.6 & 95.2 & 0.0 & 22.1 & 23.2 & 30.9 & 5.6 & 52.7 & 37.7 & 36.6 & 55.5 & 42.5

Datasets:  ['casme', 'casme2', 'samm', 'mmew', 'Average']
42.5 & 42.5
['casme2', 'casme', 'samm', 'mmew', 'fourd']


100%|█████████████████████████████████████████| 5/5 [1:27:31<00:00, 1050.40s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
91.0 & 84.0 & 95.5 & 49.5 & 60.7 & 69.3 & 64.4 & 52.4 & 74.4 & 60.1 & 70.7 & 75.5 & 70.6

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'Average']
70.6 & 70.6
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
83.9 & 71.3 & 95.6 & 0.0 & 24.4 & 46.8 & 32.4 & 8.2 & 55.6 & 25.3 & 43.7 & 54.2 & 45.1

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'Average']
45.1 & 45.1
['casme2', 'casme', 'samm', 'mmew', 'fourd', 'casme3a']


100%|█████████████████████████████████████████| 5/5 [1:37:20<00:00, 1168.17s/it]

MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
92.5 & 85.3 & 94.7 & 49.7 & 56.4 & 65.7 & 54.3 & 56.6 & 70.8 & 75.0 & 65.3 & 72.4 & 69.9

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'Average']
69.9 & 69.9
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
86.5 & 73.5 & 94.7 & 0.0 & 15.4 & 38.1 & 11.7 & 16.4 & 47.2 & 55.2 & 33.3 & 48.4 & 43.4

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'Average']
43.4 & 43.4





# SSSNet

In [3]:
c = datasets.CrossDataset(resize=64, optical_flow=True)
df = c.data_frame
data = c.data

In [4]:
class Config(core.Config):
    device = torch.device("cuda:1")
    evaluation_fn = [
        partial(utils.MultiLabelF1Score, average="macro"),
        partial(utils.MultiLabelF1Score, average="binary")
    ]
    batch_size = 128
    epochs = 200
    model = partial(meb.models.SSSNet, num_classes=len(core.Config.action_units))

In [5]:
use_datasets = ["casme2"]
for dataset in ["casme", "samm", "mmew", "fourd", "casme3a"]:
    use_datasets.append(dataset)
    idx = df["dataset"].isin(use_datasets)
    print(use_datasets)
    IValidator(Config).validate_n_times(df[idx].reset_index(), data[idx], n_times=5)

['casme2', 'casme']


  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
100%|█████████████████████████████████████████████| 5/5 [06:54<00:00, 82.96s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
90.7 & 79.2 & 94.2 & 49.8 & 48.7 & 45.9 & 55.8 & 48.4 & 56.3 & 64.0 & 71.0 & 76.5 & 65.1

Datasets:  ['casme', 'casme2', 'Average']
65.1 & 65.1
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
83.7 & 62.9 & 94.3 & 0.0 & 0.0 & 0.0 & 14.4 & 0.0 & 19.2 & 36.1 & 44.4 & 56.3 & 34.3

Datasets:  ['casme', 'casme2', 'Average']
34.3 & 34.3
['casme2', 'casme', 'samm']


100%|█████████████████████████████████████████████| 5/5 [06:42<00:00, 80.41s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
95.5 & 84.7 & 94.4 & 49.8 & 48.7 & 63.5 & 63.8 & 48.2 & 71.3 & 68.9 & 70.3 & 77.9 & 69.7

Datasets:  ['casme', 'casme2', 'samm', 'Average']
69.7 & 69.7
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
91.9 & 72.5 & 94.3 & 0.0 & 0.0 & 34.0 & 30.1 & 0.0 & 48.4 & 43.0 & 42.9 & 58.9 & 43.0

Datasets:  ['casme', 'casme2', 'samm', 'Average']
43.0 & 43.0
['casme2', 'casme', 'samm', 'mmew']


100%|█████████████████████████████████████████████| 5/5 [07:39<00:00, 91.94s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
93.7 & 82.8 & 96.2 & 48.1 & 52.9 & 68.0 & 56.3 & 51.0 & 69.1 & 67.3 & 61.2 & 80.3 & 68.9

Datasets:  ['casme', 'casme2', 'samm', 'mmew', 'Average']
68.9 & 68.9
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
88.8 & 69.4 & 96.3 & 0.0 & 8.4 & 43.8 & 15.3 & 5.8 & 44.6 & 39.6 & 25.2 & 63.4 & 41.7

Datasets:  ['casme', 'casme2', 'samm', 'mmew', 'Average']
41.7 & 41.7
['casme2', 'casme', 'samm', 'mmew', 'fourd']


100%|█████████████████████████████████████████████| 5/5 [07:52<00:00, 94.55s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
93.1 & 84.2 & 96.5 & 49.8 & 63.8 & 72.6 & 54.0 & 53.4 & 75.6 & 66.8 & 69.6 & 79.7 & 71.6

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'Average']
71.6 & 71.6
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
87.8 & 71.9 & 96.4 & 0.0 & 30.2 & 54.0 & 10.7 & 10.3 & 57.2 & 38.4 & 41.5 & 62.3 & 46.7

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'Average']
46.7 & 46.7
['casme2', 'casme', 'samm', 'mmew', 'fourd', 'casme3a']


100%|████████████████████████████████████████████| 5/5 [09:23<00:00, 112.66s/it]

MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
91.2 & 84.5 & 95.1 & 49.8 & 60.0 & 73.0 & 55.2 & 53.9 & 74.6 & 78.8 & 57.6 & 79.4 & 71.1

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'Average']
71.1 & 71.1
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
84.4 & 72.3 & 95.0 & 0.0 & 22.5 & 52.3 & 13.0 & 11.0 & 54.3 & 61.9 & 18.0 & 61.7 & 45.5

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'Average']
45.5 & 45.5





# Resnets

In [7]:
c = datasets.CrossDataset(resize=112, optical_flow=True)
df = c.data_frame
data = c.data

In [15]:
class Config(core.Config):
    epochs = 50
    device = torch.device("cuda:0")
    optimizer = partial(optim.Adam, lr=1e-4, weight_decay=1e-3)
    num_workers = 0
    evaluation_fn = [
        partial(utils.MultiLabelF1Score, average="macro"),
        partial(utils.MultiLabelF1Score, average="binary"),
    ]
    model = partial(timm.models.resnet10t, num_classes=len(core.Config.action_units), pretrained=True)

In [16]:
use_datasets = ["casme2"]
for dataset in ["casme", "samm", "mmew", "fourd", "casme3a"]:
    use_datasets.append(dataset)
    idx = df["dataset"].isin(use_datasets)
    print(use_datasets)
    IValidator(Config).validate_n_times(df[idx].reset_index(), data[idx], n_times=5)

['casme2', 'casme']


  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
100%|███████████████████████████████████████████| 10/10 [01:49<00:00, 10.91s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
60.9 & 61.1 & 78.4 & 49.8 & 48.7 & 45.9 & 48.7 & 48.4 & 46.3 & 50.8 & 48.4 & 47.4 & 52.9

Datasets:  ['casme', 'casme2', 'Average']
52.9 & 52.9
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
26.9 & 26.6 & 74.5 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 7.6 & 0.0 & 0.0 & 11.3

Datasets:  ['casme', 'casme2', 'Average']
11.3 & 11.3
['casme2', 'casme', 'samm']


100%|███████████████████████████████████████████| 10/10 [03:15<00:00, 19.54s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
76.4 & 80.6 & 86.8 & 49.8 & 48.7 & 47.4 & 53.5 & 48.4 & 51.9 & 53.8 & 48.4 & 52.0 & 58.1

Datasets:  ['casme', 'casme2', 'samm', 'Average']
58.1 & 58.1
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
56.2 & 65.1 & 85.2 & 0.0 & 0.0 & 3.0 & 9.5 & 0.0 & 11.4 & 13.6 & 0.0 & 8.9 & 21.1

Datasets:  ['casme', 'casme2', 'samm', 'Average']
21.1 & 21.1
['casme2', 'casme', 'samm', 'mmew']


100%|███████████████████████████████████████████| 10/10 [06:13<00:00, 37.33s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
93.6 & 82.2 & 92.2 & 49.7 & 48.7 & 59.8 & 53.5 & 50.1 & 53.2 & 58.4 & 48.4 & 67.4 & 63.1

Datasets:  ['casme', 'casme2', 'samm', 'mmew', 'Average']
63.1 & 63.1
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
88.6 & 68.1 & 91.7 & 0.0 & 0.0 & 27.3 & 9.8 & 3.5 & 13.4 & 23.4 & 0.0 & 38.9 & 30.4

Datasets:  ['casme', 'casme2', 'samm', 'mmew', 'Average']
30.4 & 30.4
['casme2', 'casme', 'samm', 'mmew', 'fourd']


100%|███████████████████████████████████████████| 10/10 [08:26<00:00, 50.63s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
91.7 & 81.7 & 94.7 & 49.8 & 62.4 & 70.8 & 50.7 & 50.1 & 66.7 & 58.0 & 49.0 & 62.4 & 65.7

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'Average']
65.7 & 65.7
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
85.4 & 67.2 & 94.7 & 0.0 & 27.3 & 51.4 & 4.1 & 3.5 & 41.5 & 21.7 & 1.2 & 29.2 & 35.6

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'Average']
35.6 & 35.6
['casme2', 'casme', 'samm', 'mmew', 'fourd', 'casme3a']


100%|███████████████████████████████████████████| 10/10 [16:13<00:00, 97.36s/it]

MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
91.5 & 83.8 & 94.5 & 49.6 & 55.8 & 71.5 & 52.0 & 54.4 & 70.4 & 72.3 & 48.4 & 71.2 & 67.9

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'Average']
67.9 & 67.9
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
85.0 & 71.0 & 94.4 & 0.0 & 14.0 & 49.4 & 6.6 & 11.9 & 46.2 & 50.6 & 0.0 & 45.9 & 39.6

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'Average']
39.6 & 39.6





In [5]:
class Config(core.Config):
    epochs = 50
    device = torch.device("cuda:0")
    optimizer = partial(optim.Adam, lr=1e-4, weight_decay=1e-3)
    num_workers = 0
    evaluation_fn = [
        partial(utils.MultiLabelF1Score, average="macro"),
        partial(utils.MultiLabelF1Score, average="binary"),
    ]
    model = partial(timm.models.resnet18, num_classes=len(core.Config.action_units), pretrained=True)

In [8]:
use_datasets = ["casme2"]
for dataset in ["casme", "samm", "mmew", "fourd", "casme3a"]:
    use_datasets.append(dataset)
    idx = df["dataset"].isin(use_datasets)
    print(use_datasets)
    IValidator(Config).validate_n_times(df[idx].reset_index(), data[idx], n_times=5)

['casme2', 'casme']


  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
100%|█████████████████████████████████████████████| 5/5 [02:26<00:00, 29.25s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
59.3 & 54.7 & 86.4 & 49.8 & 48.7 & 45.9 & 48.7 & 48.4 & 46.4 & 47.2 & 48.4 & 48.2 & 52.7

Datasets:  ['casme', 'casme2', 'Average']
52.7 & 52.7
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
23.5 & 13.9 & 86.4 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 1.5 & 10.4

Datasets:  ['casme', 'casme2', 'Average']
10.4 & 10.4
['casme2', 'casme', 'samm']


100%|█████████████████████████████████████████████| 5/5 [04:14<00:00, 50.85s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
67.4 & 84.6 & 92.1 & 49.8 & 48.7 & 50.1 & 48.6 & 48.4 & 51.9 & 53.9 & 49.6 & 60.7 & 58.8

Datasets:  ['casme', 'casme2', 'samm', 'Average']
58.8 & 58.8
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
39.1 & 72.3 & 91.8 & 0.0 & 0.0 & 8.4 & 0.0 & 0.0 & 11.2 & 13.9 & 2.4 & 26.0 & 22.1

Datasets:  ['casme', 'casme2', 'samm', 'Average']
22.1 & 22.1
['casme2', 'casme', 'samm', 'mmew']


100%|█████████████████████████████████████████████| 5/5 [07:53<00:00, 94.74s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
87.6 & 81.9 & 94.2 & 49.7 & 48.7 & 57.0 & 48.6 & 52.9 & 57.6 & 67.7 & 51.8 & 68.7 & 63.9

Datasets:  ['casme', 'casme2', 'samm', 'mmew', 'Average']
63.9 & 63.9
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
77.5 & 67.5 & 94.2 & 0.0 & 0.0 & 22.4 & 0.0 & 9.0 & 21.9 & 42.8 & 6.8 & 41.5 & 32.0

Datasets:  ['casme', 'casme2', 'samm', 'mmew', 'Average']
32.0 & 32.0
['casme2', 'casme', 'samm', 'mmew', 'fourd']


100%|████████████████████████████████████████████| 5/5 [10:58<00:00, 131.77s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
92.2 & 82.3 & 94.8 & 49.8 & 64.9 & 69.3 & 52.6 & 49.5 & 68.4 & 57.5 & 59.7 & 68.6 & 67.5

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'Average']
67.5 & 67.5
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
86.1 & 68.8 & 94.8 & 0.0 & 32.6 & 47.3 & 8.4 & 2.2 & 45.2 & 21.5 & 22.2 & 41.0 & 39.2

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'Average']
39.2 & 39.2
['casme2', 'casme', 'samm', 'mmew', 'fourd', 'casme3a']


100%|████████████████████████████████████████████| 5/5 [21:13<00:00, 254.63s/it]

MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
91.6 & 85.7 & 95.2 & 49.7 & 54.5 & 69.0 & 49.3 & 54.2 & 75.2 & 77.4 & 56.5 & 80.0 & 69.9

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'Average']
69.9 & 69.9
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
85.2 & 74.3 & 95.1 & 0.0 & 11.4 & 44.6 & 1.7 & 11.5 & 55.1 & 59.4 & 15.9 & 63.0 & 43.1

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'Average']
43.1 & 43.1





In [20]:
class Config(core.Config):
    epochs = 50
    device = torch.device("cuda:0")
    optimizer = partial(optim.Adam, lr=1e-4, weight_decay=1e-3)
    num_workers = 0
    evaluation_fn = [
        partial(utils.MultiLabelF1Score, average="macro"),
        partial(utils.MultiLabelF1Score, average="binary"),
    ]
    model = partial(timm.models.resnet34, num_classes=len(core.Config.action_units), pretrained=True)

In [21]:
use_datasets = ["casme2"]
for dataset in ["casme", "samm", "mmew", "fourd", "casme3a"]:
    use_datasets.append(dataset)
    idx = df["dataset"].isin(use_datasets)
    print(use_datasets)
    IValidator(Config).validate_n_times(df[idx].reset_index(), data[idx], n_times=5)

['casme2', 'casme']


  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
100%|█████████████████████████████████████████████| 5/5 [02:14<00:00, 27.00s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
67.9 & 62.9 & 78.9 & 49.8 & 48.7 & 45.9 & 49.2 & 48.4 & 46.4 & 51.0 & 48.3 & 50.1 & 54.0

Datasets:  ['casme', 'casme2', 'Average']
54.0 & 54.0
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
42.5 & 33.3 & 76.2 & 0.0 & 0.0 & 0.0 & 1.6 & 0.0 & 0.0 & 7.7 & 0.0 & 5.4 & 13.9

Datasets:  ['casme', 'casme2', 'Average']
13.9 & 13.9
['casme2', 'casme', 'samm']


100%|█████████████████████████████████████████████| 5/5 [04:01<00:00, 48.36s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
71.4 & 69.4 & 87.0 & 49.7 & 48.7 & 51.6 & 50.1 & 48.4 & 55.5 & 58.3 & 49.4 & 67.6 & 58.9

Datasets:  ['casme', 'casme2', 'samm', 'Average']
58.9 & 58.9
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
47.6 & 44.8 & 86.4 & 0.0 & 0.0 & 11.7 & 2.9 & 0.0 & 18.2 & 23.1 & 2.1 & 39.7 & 23.0

Datasets:  ['casme', 'casme2', 'samm', 'Average']
23.0 & 23.0
['casme2', 'casme', 'samm', 'mmew']


100%|█████████████████████████████████████████████| 5/5 [07:34<00:00, 90.98s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
87.0 & 80.0 & 93.0 & 49.7 & 48.7 & 60.1 & 53.2 & 51.7 & 62.5 & 63.4 & 51.8 & 70.3 & 64.3

Datasets:  ['casme', 'casme2', 'samm', 'mmew', 'Average']
64.3 & 64.3
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
76.5 & 64.0 & 92.8 & 0.0 & 0.0 & 28.9 & 9.9 & 6.7 & 31.8 & 33.7 & 6.7 & 44.8 & 33.0

Datasets:  ['casme', 'casme2', 'samm', 'mmew', 'Average']
33.0 & 33.0
['casme2', 'casme', 'samm', 'mmew', 'fourd']


100%|████████████████████████████████████████████| 5/5 [10:37<00:00, 127.47s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
93.0 & 83.7 & 94.8 & 49.7 & 56.3 & 68.4 & 49.4 & 52.0 & 70.9 & 54.1 & 57.8 & 67.9 & 66.5

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'Average']
66.5 & 66.5
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
87.6 & 70.8 & 94.8 & 0.0 & 15.4 & 45.5 & 2.0 & 7.1 & 50.1 & 14.7 & 18.4 & 39.7 & 37.2

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'Average']
37.2 & 37.2
['casme2', 'casme', 'samm', 'mmew', 'fourd', 'casme3a']


100%|████████████████████████████████████████████| 5/5 [23:27<00:00, 281.56s/it]

MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
93.5 & 85.3 & 94.2 & 49.7 & 52.7 & 70.0 & 50.8 & 53.9 & 67.9 & 76.9 & 56.6 & 79.2 & 69.2

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'Average']
69.2 & 69.2
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
88.5 & 73.6 & 94.1 & 0.0 & 8.0 & 46.6 & 4.7 & 11.2 & 41.4 & 58.9 & 16.3 & 61.6 & 42.1

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'Average']
42.1 & 42.1





In [13]:
class Config(core.Config):
    epochs = 50
    device = torch.device("cuda:0")
    optimizer = partial(optim.Adam, lr=1e-4, weight_decay=1e-3)
    num_workers = 0
    evaluation_fn = [
        partial(utils.MultiLabelF1Score, average="macro"),
        partial(utils.MultiLabelF1Score, average="binary"),
    ]
    model = partial(timm.models.resnet50, num_classes=len(core.Config.action_units), pretrained=True)

In [14]:
use_datasets = ["casme2"]
for dataset in ["casme", "samm", "mmew", "fourd", "casme3a"]:
    use_datasets.append(dataset)
    idx = df["dataset"].isin(use_datasets)
    print(use_datasets)
    IValidator(Config).validate_n_times(df[idx].reset_index(), data[idx], n_times=10)

['casme2', 'casme']


  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
100%|███████████████████████████████████████████| 10/10 [12:45<00:00, 76.55s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
47.3 & 47.8 & 36.5 & 49.8 & 48.7 & 45.9 & 48.7 & 48.4 & 46.4 & 47.2 & 48.4 & 47.4 & 46.9

Datasets:  ['casme', 'casme2', 'Average']
46.9 & 46.9
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
0.0 & 0.0 & 6.7 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.6

Datasets:  ['casme', 'casme2', 'Average']
0.6 & 0.6
['casme2', 'casme', 'samm']


100%|██████████████████████████████████████████| 10/10 [22:12<00:00, 133.20s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
47.3 & 47.8 & 47.4 & 49.8 & 48.7 & 45.9 & 48.7 & 48.4 & 46.4 & 47.2 & 48.4 & 47.4 & 47.8

Datasets:  ['casme', 'casme2', 'samm', 'Average']
47.8 & 47.8
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
0.0 & 0.0 & 24.8 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 2.1

Datasets:  ['casme', 'casme2', 'samm', 'Average']
2.1 & 2.1
['casme2', 'casme', 'samm', 'mmew']


100%|██████████████████████████████████████████| 10/10 [32:03<00:00, 192.37s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
50.8 & 55.5 & 74.9 & 49.8 & 48.7 & 49.3 & 48.7 & 48.4 & 46.4 & 47.2 & 48.4 & 47.4 & 51.3

Datasets:  ['casme', 'casme2', 'samm', 'mmew', 'Average']
51.3 & 51.3
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
6.7 & 15.4 & 69.2 & 0.0 & 0.0 & 6.6 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 8.2

Datasets:  ['casme', 'casme2', 'samm', 'mmew', 'Average']
8.2 & 8.2
['casme2', 'casme', 'samm', 'mmew', 'fourd']


100%|██████████████████████████████████████████| 10/10 [45:01<00:00, 270.17s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
88.3 & 79.8 & 91.1 & 49.8 & 48.7 & 59.3 & 48.7 & 48.4 & 62.5 & 47.2 & 48.4 & 47.4 & 60.0

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'Average']
60.0 & 60.0
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
78.9 & 63.7 & 90.5 & 0.0 & 0.0 & 29.0 & 0.0 & 0.0 & 31.7 & 0.0 & 0.0 & 0.0 & 24.5

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'Average']
24.5 & 24.5
['casme2', 'casme', 'samm', 'mmew', 'fourd', 'casme3a']


100%|████████████████████████████████████████| 10/10 [1:01:38<00:00, 369.83s/it]

MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
91.5 & 80.1 & 94.7 & 49.6 & 48.7 & 61.2 & 50.0 & 48.4 & 62.2 & 64.9 & 48.4 & 51.9 & 62.6

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'Average']
62.6 & 62.6
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
84.9 & 64.3 & 94.6 & 0.0 & 0.0 & 30.1 & 2.7 & 0.0 & 30.4 & 36.7 & 0.0 & 8.8 & 29.4

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'Average']
29.4 & 29.4





In [17]:
class Config(core.Config):
    epochs = 50
    device = torch.device("cuda:0")
    optimizer = partial(optim.Adam, lr=1e-4, weight_decay=1e-3)
    num_workers = 0
    evaluation_fn = [
        partial(utils.MultiLabelF1Score, average="macro"),
        partial(utils.MultiLabelF1Score, average="binary"),
    ]
    model = partial(timm.models.resnet101, num_classes=len(core.Config.action_units), pretrained=True)

In [19]:
use_datasets = ["casme2"]
for dataset in ["casme", "samm", "mmew", "fourd", "casme3a"]:
    use_datasets.append(dataset)
    idx = df["dataset"].isin(use_datasets)
    print(use_datasets)
    IValidator(Config).validate_n_times(df[idx].reset_index(), data[idx], n_times=5)

['casme2', 'casme']


  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
100%|████████████████████████████████████████████| 5/5 [08:42<00:00, 104.44s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
47.3 & 47.5 & 33.6 & 49.8 & 48.7 & 45.9 & 48.7 & 48.4 & 46.4 & 47.1 & 48.4 & 49.0 & 46.7

Datasets:  ['casme', 'casme2', 'Average']
46.7 & 46.7
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
0.0 & 0.0 & 1.5 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 3.0 & 0.4

Datasets:  ['casme', 'casme2', 'Average']
0.4 & 0.4
['casme2', 'casme', 'samm']


100%|████████████████████████████████████████████| 5/5 [14:55<00:00, 179.17s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
47.3 & 47.7 & 34.4 & 49.8 & 48.7 & 45.8 & 48.7 & 48.4 & 46.4 & 47.2 & 48.4 & 47.4 & 46.7

Datasets:  ['casme', 'casme2', 'samm', 'Average']
46.7 & 46.7
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
0.0 & 0.0 & 2.9 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.2

Datasets:  ['casme', 'casme2', 'samm', 'Average']
0.2 & 0.2
['casme2', 'casme', 'samm', 'mmew']


100%|████████████████████████████████████████████| 5/5 [30:46<00:00, 369.21s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
50.4 & 55.0 & 73.6 & 49.7 & 48.7 & 46.4 & 48.7 & 48.4 & 46.4 & 47.6 & 48.4 & 48.2 & 51.0

Datasets:  ['casme', 'casme2', 'samm', 'mmew', 'Average']
51.0 & 51.0
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
6.3 & 15.1 & 67.9 & 0.0 & 0.0 & 1.0 & 0.0 & 0.0 & 0.0 & 1.1 & 0.0 & 1.5 & 7.7

Datasets:  ['casme', 'casme2', 'samm', 'mmew', 'Average']
7.7 & 7.7
['casme2', 'casme', 'samm', 'mmew', 'fourd']


100%|████████████████████████████████████████████| 5/5 [29:49<00:00, 357.95s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
76.0 & 75.1 & 86.2 & 49.8 & 48.7 & 57.5 & 48.7 & 48.4 & 54.2 & 47.2 & 52.0 & 51.4 & 57.9

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'Average']
57.9 & 57.9
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
56.0 & 54.2 & 85.1 & 0.0 & 0.0 & 29.3 & 0.0 & 0.0 & 15.8 & 0.0 & 7.9 & 8.3 & 21.4

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'Average']
21.4 & 21.4
['casme2', 'casme', 'samm', 'mmew', 'fourd', 'casme3a']


100%|████████████████████████████████████████████| 5/5 [57:14<00:00, 686.88s/it]

MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
89.2 & 80.4 & 91.5 & 49.6 & 48.7 & 53.8 & 48.7 & 48.2 & 58.4 & 64.6 & 50.5 & 54.2 & 61.5

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'Average']
61.5 & 61.5
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
80.9 & 65.0 & 91.1 & 0.0 & 0.0 & 16.6 & 0.0 & 0.0 & 23.7 & 36.1 & 4.4 & 13.6 & 27.6

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'Average']
27.6 & 27.6





### resnet 2+1d

In [7]:
c = datasets.CrossDataset(resize=112, color=True, preload=True)
df = c.data_frame
data = c.data

Loading data: 100%|███████████████████████████| 189/189 [00:29<00:00,  6.32it/s]
Loading data: 100%|███████████████████████████| 256/256 [02:15<00:00,  1.88it/s]
Loading data: 100%|███████████████████████████| 159/159 [01:51<00:00,  1.43it/s]
Loading data: 100%|███████████████████████████| 267/267 [00:43<00:00,  6.14it/s]
Loading data: 100%|███████████████████████████| 300/300 [01:15<00:00,  3.97it/s]
Loading data: 100%|███████████████████████████| 860/860 [03:42<00:00,  3.87it/s]




In [8]:
import torch.nn.functional as F
#interpolate samples with less than 8 frames
n_frames = 8
for i, video in enumerate(data):
    if video.shape[0] < n_frames:
        new_shape = (n_frames,) + video.shape[1:-1]
        video = torch.tensor(video).permute(3, 0, 1, 2).unsqueeze(0).float()
        new_video = F.interpolate(video, size=new_shape, mode="trilinear")
        data[i] = new_video.squeeze(0).permute(1, 2, 3, 0).byte().numpy()

In [9]:
# Create a function that returns the model as it needs to be modified
def r2plus1d(num_classes: int):
    model = torchvision.models.video.r2plus1d_18(weights=torchvision.models.video.R2Plus1D_18_Weights.DEFAULT)
    model.fc = nn.Linear(in_features=512, out_features=num_classes)
    return model

In [10]:
class Config(core.Config):
    device = torch.device("cuda:0")
    epochs = 100
    batch_size = 16
    optimizer = partial(optim.Adam, lr=1e-4, weight_decay=1e-3)
    evaluation_fn = [
        partial(utils.MultiLabelF1Score, average="macro"),
        partial(utils.MultiLabelF1Score, average="binary")
    ]
    train_transform = {
        "spatial": None,
        "temporal": datasets.UniformTemporalSubsample(8),
    }
    test_transform = {
        "spatial": None,
        "temporal": datasets.UniformTemporalSubsample(8),
    }
    model = partial(r2plus1d, num_classes=len(core.Config.action_units))

In [11]:
use_datasets = ["casme2"]
for dataset in ["casme", "samm", "mmew", "fourd", "casme3a"]:
    use_datasets.append(dataset)
    idx = df["dataset"].isin(use_datasets)
    print(use_datasets)
    IValidator(Config).validate_n_times(df[idx].reset_index(), data[idx], n_times=5)

['casme2', 'casme']


  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
100%|█████████████████████████████████████████| 5/5 [1:42:40<00:00, 1232.07s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
49.0 & 47.6 & 71.7 & 49.8 & 48.7 & 45.9 & 50.1 & 48.4 & 47.5 & 49.7 & 55.7 & 48.7 & 51.1

Datasets:  ['casme', 'casme2', 'Average']
51.1 & 51.1
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
4.6 & 0.0 & 67.6 & 0.0 & 0.0 & 0.0 & 2.9 & 0.0 & 2.2 & 5.3 & 15.6 & 2.6 & 8.4

Datasets:  ['casme', 'casme2', 'Average']
8.4 & 8.4
['casme2', 'casme', 'samm']


100%|█████████████████████████████████████████| 5/5 [3:05:23<00:00, 2224.67s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
48.2 & 47.9 & 70.7 & 49.6 & 48.7 & 47.0 & 49.9 & 48.3 & 46.9 & 50.0 & 49.1 & 50.2 & 50.6

Datasets:  ['casme', 'casme2', 'samm', 'Average']
50.6 & 50.6
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
2.5 & 1.1 & 68.7 & 0.0 & 0.0 & 3.3 & 2.5 & 0.0 & 1.1 & 7.1 & 1.7 & 6.2 & 7.8

Datasets:  ['casme', 'casme2', 'samm', 'Average']
7.8 & 7.8
['casme2', 'casme', 'samm', 'mmew']


100%|█████████████████████████████████████████| 5/5 [5:41:10<00:00, 4094.10s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
64.8 & 61.9 & 84.7 & 49.6 & 48.7 & 51.5 & 48.6 & 50.4 & 51.2 & 54.8 & 54.1 & 56.8 & 56.4

Datasets:  ['casme', 'casme2', 'samm', 'mmew', 'Average']
56.4 & 56.4
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
35.0 & 28.0 & 83.4 & 0.0 & 0.0 & 11.2 & 0.0 & 4.0 & 10.3 & 15.2 & 11.3 & 19.0 & 18.1

Datasets:  ['casme', 'casme2', 'samm', 'mmew', 'Average']
18.1 & 18.1
['casme2', 'casme', 'samm', 'mmew', 'fourd']


100%|█████████████████████████████████████████| 5/5 [7:59:26<00:00, 5753.36s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
86.5 & 77.5 & 87.8 & 49.7 & 48.7 & 49.0 & 49.7 & 48.4 & 51.9 & 59.5 & 53.7 & 58.1 & 60.0

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'Average']
60.0 & 60.0
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
75.4 & 58.5 & 86.9 & 0.0 & 0.0 & 6.3 & 2.2 & 0.0 & 10.8 & 24.5 & 10.4 & 21.2 & 24.7

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'Average']
24.7 & 24.7
['casme2', 'casme', 'samm', 'mmew', 'fourd', 'casme3a']


100%|███████████████████████████████████████| 5/5 [15:23:04<00:00, 11076.92s/it]

MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
91.2 & 76.6 & 90.6 & 49.7 & 48.7 & 59.8 & 48.4 & 52.5 & 57.6 & 66.6 & 62.9 & 74.1 & 64.9

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'Average']
64.9 & 64.9
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
84.1 & 56.6 & 90.2 & 0.0 & 0.0 & 26.9 & 0.0 & 8.4 & 22.0 & 38.4 & 28.6 & 51.6 & 33.9

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'Average']
33.9 & 33.9



