In [1]:
import meb
from meb import utils
from meb import datasets
from meb import core
from meb import models

from functools import partial
from tqdm import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
import timm



pd.set_option("display.max_columns", 50)
%load_ext autoreload
%autoreload 2

In [35]:
c = datasets.CrossDataset(resize=64, optical_flow=True)
df = c.data_frame
data = c.data

In [12]:
class Config(core.Config):
    device = torch.device("cuda:0")
    num_workers = 0
    evaluation_fn = [
        partial(utils.MultiLabelF1Score, average="macro"),
        partial(utils.MultiLabelF1Score, average="binary")
    ]
    epochs = 200
    batch_size = 128
    model = partial(meb.models.SSSNet, num_classes=len(core.Config.action_units))

In [6]:
# Only validate casme2
class IValidator(core.CrossDatasetValidator):
    def __init__(self, config: "Config"):
        super().__init__(config)
        
    def validate_n_times(
        self, df: pd.DataFrame, input_data, n_times: int = 5
    ) -> None:
        self.verbose = False
        self.disable_tqdm = True
        au_results = []
        dataset_results = []
        casme2_idx = df["dataset"] == "casme2"
        for n in tqdm(range(n_times)):
            outputs_list = self.validate(df, input_data, seed_n=n + 45)
            au_result, dataset_result = self.printer.results_to_list(outputs_list, df[casme2_idx])
            au_results.append(au_result)
            dataset_results.append(dataset_result)

        aus = [i for i in self.cf.action_units]
        dataset_names = df["dataset"].unique().tolist()
        aus.append("Average")
        dataset_names.append("Average")
        au_results = np.array(au_results)
        dataset_results = np.array(dataset_results)
        for i in range(len(self.cf.evaluation_fn)):
            if len(self.cf.evaluation_fn) > 1:
                print(self.printer.metric_name(self.cf.evaluation_fn[i]))
            au_result = self.printer.list_to_latex(list(au_results[:, i].mean(axis=0)))
            dataset_result = self.printer.list_to_latex(
                list(dataset_results[:, i].mean(axis=0))
            )
            print("AUS:", aus)
            print(au_result)
            print("\nDatasets: ", dataset_names)
            print(dataset_result)
    
    def validate(self, df: pd.DataFrame, input_data: np.ndarray, seed_n: int = 1):
        utils.set_random_seeds(seed_n)
        dataset_names = df["dataset"].unique()
        # Create a boolean array with the AUs
        labels = np.array(df[self.cf.action_units])
        outputs_list = []
        for dataset_name in dataset_names:
            if dataset_name != "casme2":
                continue
            train_metrics, test_metrics, outputs_test = self.validate_split(
                df, input_data, labels, dataset_name
            )
            outputs_list.append(outputs_test)
            if self.verbose:
                self.printer.print_train_test_evaluation(
                    train_metrics, test_metrics, dataset_name, outputs_test.shape[0]
                )

        # Calculate total f1-scores
        predictions = torch.cat(outputs_list)
        idx = df["dataset"] == "casme2"
        metrics = self.evaluation_fn(labels[idx], predictions)
        if self.verbose:
            self.printer.print_test_validation(metrics)
        return outputs_list

## First row
Concatenating other datasets, no casme3

In [13]:
use_datasets = ["casme2"]
for dataset in ["casme", "samm", "mmew", "fourd"]:
    use_datasets.append(dataset)
    idx = df["dataset"].isin(use_datasets)
    print(use_datasets)
    IValidator(Config).validate_n_times(df[idx].reset_index(), data[idx], n_times=5)

['casme2', 'casme']


  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
100%|█████████████████████████████████████████████| 5/5 [00:27<00:00,  5.49s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
90.7 & 79.2 & 94.2 & 49.8 & 48.7 & 45.9 & 55.8 & 48.4 & 56.3 & 64.0 & 71.0 & 76.5 & 65.1

Datasets:  ['casme', 'casme2', 'Average']
65.1 & 65.1
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
83.7 & 62.9 & 94.3 & 0.0 & 0.0 & 0.0 & 14.4 & 0.0 & 19.2 & 36.1 & 44.4 & 56.3 & 34.3

Datasets:  ['casme', 'casme2', 'Average']
34.3 & 34.3
['casme2', 'casme', 'samm']


100%|█████████████████████████████████████████████| 5/5 [00:44<00:00,  8.95s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
95.5 & 84.7 & 94.4 & 49.8 & 48.7 & 63.5 & 63.8 & 48.2 & 71.3 & 68.9 & 70.3 & 77.9 & 69.7

Datasets:  ['casme', 'casme2', 'samm', 'Average']
69.7 & 69.7
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
91.9 & 72.5 & 94.3 & 0.0 & 0.0 & 34.0 & 30.1 & 0.0 & 48.4 & 43.0 & 42.9 & 58.9 & 43.0

Datasets:  ['casme', 'casme2', 'samm', 'Average']
43.0 & 43.0
['casme2', 'casme', 'samm', 'mmew']


100%|█████████████████████████████████████████████| 5/5 [01:26<00:00, 17.39s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
93.7 & 82.8 & 96.2 & 48.1 & 52.9 & 68.0 & 56.3 & 51.0 & 69.1 & 67.3 & 61.2 & 80.3 & 68.9

Datasets:  ['casme', 'casme2', 'samm', 'mmew', 'Average']
68.9 & 68.9
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
88.8 & 69.4 & 96.3 & 0.0 & 8.4 & 43.8 & 15.3 & 5.8 & 44.6 & 39.6 & 25.2 & 63.4 & 41.7

Datasets:  ['casme', 'casme2', 'samm', 'mmew', 'Average']
41.7 & 41.7
['casme2', 'casme', 'samm', 'mmew', 'fourd']


100%|█████████████████████████████████████████████| 5/5 [01:57<00:00, 23.45s/it]

MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
93.1 & 84.2 & 96.5 & 49.8 & 63.8 & 72.6 & 54.0 & 53.4 & 75.6 & 66.8 & 69.6 & 79.7 & 71.6

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'Average']
71.6 & 71.6
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
87.8 & 71.9 & 96.4 & 0.0 & 30.2 & 54.0 & 10.7 & 10.3 & 57.2 & 38.4 & 41.5 & 62.3 & 46.7

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'Average']
46.7 & 46.7





## Second row
Adding casme3 to the concatenated datasets

In [32]:
use_datasets = ["casme2"]
for dataset in ["casme3a", "casme", "samm", "mmew", "fourd"]:
    use_datasets.append(dataset)
    idx = df["dataset"].isin(use_datasets)
    print(use_datasets)
    IValidator(Config).validate_n_times(df[idx].reset_index(), data[idx], n_times=5)

['casme2', 'casme3a']


100%|█████████████████████████████████████████████| 5/5 [01:46<00:00, 21.20s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
93.0 & 85.6 & 92.9 & 49.4 & 48.7 & 51.3 & 49.5 & 48.4 & 55.3 & 74.5 & 48.4 & 76.9 & 64.5

Datasets:  ['casme2', 'casme3a', 'Average']
64.5 & 64.5
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
87.7 & 74.0 & 92.8 & 0.0 & 0.0 & 10.5 & 2.1 & 0.0 & 17.3 & 55.2 & 0.0 & 57.0 & 33.1

Datasets:  ['casme2', 'casme3a', 'Average']
33.1 & 33.1
['casme2', 'casme3a', 'casme']


100%|█████████████████████████████████████████████| 5/5 [02:12<00:00, 26.51s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
91.5 & 84.6 & 94.4 & 49.7 & 48.7 & 50.4 & 51.5 & 48.4 & 52.8 & 76.8 & 56.3 & 76.9 & 65.2

Datasets:  ['casme', 'casme2', 'casme3a', 'Average']
65.2 & 65.2
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
85.1 & 72.4 & 94.4 & 0.0 & 0.0 & 8.7 & 6.4 & 0.0 & 12.6 & 58.9 & 15.5 & 57.0 & 34.2

Datasets:  ['casme', 'casme2', 'casme3a', 'Average']
34.2 & 34.2
['casme2', 'casme3a', 'casme', 'samm']


100%|█████████████████████████████████████████████| 5/5 [02:29<00:00, 29.86s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
92.3 & 84.1 & 95.0 & 49.7 & 48.7 & 60.6 & 57.4 & 49.4 & 67.0 & 76.6 & 60.6 & 77.3 & 68.2

Datasets:  ['casme', 'casme2', 'samm', 'casme3a', 'Average']
68.2 & 68.2
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
86.4 & 71.5 & 94.9 & 0.0 & 0.0 & 28.3 & 17.5 & 2.2 & 39.7 & 58.1 & 24.1 & 57.6 & 40.0

Datasets:  ['casme', 'casme2', 'samm', 'casme3a', 'Average']
40.0 & 40.0
['casme2', 'casme3a', 'casme', 'samm', 'mmew']


100%|█████████████████████████████████████████████| 5/5 [03:01<00:00, 36.34s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
93.9 & 83.9 & 94.5 & 49.8 & 48.7 & 65.0 & 59.8 & 53.3 & 67.2 & 77.1 & 54.4 & 78.6 & 68.8

Datasets:  ['casme', 'casme2', 'samm', 'mmew', 'casme3a', 'Average']
68.8 & 68.8
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
89.1 & 71.1 & 94.4 & 0.0 & 0.0 & 37.0 & 22.4 & 10.2 & 40.1 & 59.4 & 11.8 & 60.1 & 41.3

Datasets:  ['casme', 'casme2', 'samm', 'mmew', 'casme3a', 'Average']
41.3 & 41.3
['casme2', 'casme3a', 'casme', 'samm', 'mmew', 'fourd']


100%|█████████████████████████████████████████████| 5/5 [03:30<00:00, 42.12s/it]

MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
91.2 & 84.5 & 95.1 & 49.8 & 60.0 & 73.0 & 55.2 & 53.9 & 74.6 & 78.8 & 57.6 & 79.4 & 71.1

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'Average']
71.1 & 71.1
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
84.4 & 72.3 & 95.0 & 0.0 & 22.5 & 52.3 & 13.0 & 11.0 & 54.3 & 61.9 & 18.0 & 61.7 & 45.5

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'Average']
45.5 & 45.5





## Third row
Adding both part_A and part_C

In [37]:
use_datasets = ["casme2"]
for dataset in ["casme3a", "casme3c", "casme", "samm", "mmew", "fourd"]:
    use_datasets.append(dataset)
    idx = df["dataset"].isin(use_datasets)
    print(use_datasets)
    IValidator(Config).validate_n_times(df[idx].reset_index(), data[idx], n_times=5)

['casme2', 'casme3a']


100%|█████████████████████████████████████████████| 5/5 [01:45<00:00, 21.17s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
93.0 & 85.6 & 92.9 & 49.4 & 48.7 & 51.3 & 49.5 & 48.4 & 55.3 & 74.5 & 48.4 & 76.9 & 64.5

Datasets:  ['casme2', 'casme3a', 'Average']
64.5 & 64.5
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
87.7 & 74.0 & 92.8 & 0.0 & 0.0 & 10.5 & 2.1 & 0.0 & 17.3 & 55.2 & 0.0 & 57.0 & 33.1

Datasets:  ['casme2', 'casme3a', 'Average']
33.1 & 33.1
['casme2', 'casme3a', 'casme3c']


100%|█████████████████████████████████████████████| 5/5 [02:11<00:00, 26.20s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
94.1 & 87.4 & 93.0 & 49.5 & 48.7 & 50.4 & 48.5 & 48.3 & 54.9 & 76.6 & 48.4 & 61.4 & 63.4

Datasets:  ['casme2', 'casme3a', 'casme3c', 'Average']
63.4 & 63.4
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
89.5 & 77.3 & 93.1 & 0.0 & 0.0 & 8.7 & 0.0 & 0.0 & 16.5 & 58.2 & 0.0 & 27.1 & 30.9

Datasets:  ['casme2', 'casme3a', 'casme3c', 'Average']
30.9 & 30.9
['casme2', 'casme3a', 'casme3c', 'casme']


100%|█████████████████████████████████████████████| 5/5 [02:29<00:00, 29.86s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
92.1 & 85.1 & 94.8 & 49.7 & 48.7 & 48.0 & 55.0 & 48.4 & 59.6 & 76.5 & 52.0 & 70.9 & 65.1

Datasets:  ['casme', 'casme2', 'casme3a', 'casme3c', 'Average']
65.1 & 65.1
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
86.2 & 73.2 & 94.9 & 0.0 & 0.0 & 4.0 & 13.1 & 0.0 & 25.6 & 58.2 & 7.1 & 45.4 & 34.0

Datasets:  ['casme', 'casme2', 'casme3a', 'casme3c', 'Average']
34.0 & 34.0
['casme2', 'casme3a', 'casme3c', 'casme', 'samm']


100%|█████████████████████████████████████████████| 5/5 [02:46<00:00, 33.38s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
93.4 & 84.8 & 95.1 & 49.7 & 48.7 & 56.3 & 59.2 & 48.3 & 70.5 & 75.7 & 56.5 & 69.8 & 67.3

Datasets:  ['casme', 'casme2', 'samm', 'casme3a', 'casme3c', 'Average']
67.3 & 67.3
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
88.3 & 72.8 & 95.1 & 0.0 & 0.0 & 20.3 & 21.1 & 0.0 & 46.3 & 56.4 & 15.9 & 43.4 & 38.3

Datasets:  ['casme', 'casme2', 'samm', 'casme3a', 'casme3c', 'Average']
38.3 & 38.3
['casme2', 'casme3a', 'casme3c', 'casme', 'samm', 'mmew']


100%|█████████████████████████████████████████████| 5/5 [03:29<00:00, 41.82s/it]


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
92.2 & 84.6 & 94.7 & 49.7 & 48.7 & 63.7 & 57.9 & 52.4 & 68.2 & 75.7 & 53.2 & 69.3 & 67.5

Datasets:  ['casme', 'casme2', 'samm', 'mmew', 'casme3a', 'casme3c', 'Average']
67.5 & 67.5
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
86.2 & 72.5 & 94.7 & 0.0 & 0.0 & 34.4 & 18.7 & 8.3 & 42.0 & 56.5 & 9.4 & 42.3 & 38.8

Datasets:  ['casme', 'casme2', 'samm', 'mmew', 'casme3a', 'casme3c', 'Average']
38.8 & 38.8
['casme2', 'casme3a', 'casme3c', 'casme', 'samm', 'mmew', 'fourd']


100%|█████████████████████████████████████████████| 5/5 [03:59<00:00, 47.95s/it]

MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
92.2 & 84.3 & 94.3 & 49.8 & 58.0 & 70.7 & 53.7 & 54.0 & 74.4 & 78.7 & 60.5 & 67.8 & 69.9

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'casme3c', 'Average']
69.9 & 69.9
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
86.2 & 72.0 & 94.3 & 0.0 & 18.4 & 48.0 & 10.0 & 11.3 & 53.9 & 61.2 & 23.9 & 39.5 & 43.2

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'casme3c', 'Average']
43.2 & 43.2



