In [2]:
import meb
from meb import utils
from meb import datasets
from meb import core
from meb import models

from functools import partial
from typing import List, Tuple

import numpy as np
import pandas as pd
from numba import jit, njit
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
import timm
from tqdm import tqdm
from sklearn.svm import SVC
from sklearn.multioutput import MultiOutputClassifier



pd.set_option("display.max_columns", 50)
%load_ext autoreload
%autoreload 2

In [3]:
c = datasets.CrossDataset(optical_flow=True, resize=28)
df = c.data_frame
data = c.data

In [4]:
class GeneticAlgorithm:
    def __init__(self, features: np.ndarray, labels: np.ndarray, subject_idx: int,
                 population_size: int = 50, d: int = 50, chr_length: int = 400,
                 change_rate: float = 0.8, multi_label: bool = False
    ):
        self.features = features
        self.labels = labels
        self.subject_idx = subject_idx
        self.population_size = population_size
        self.d = d
        self.chr_length = chr_length
        self.change_rate = change_rate
        self.clf = SVC(kernel="rbf", C=2, random_state=42)
        if multi_label:
            self.clf = MultiOutputClassifier(self.clf)
    
    def perform(self) -> Tuple[np.ndarray, np.ndarray]:
        pop = np.random.randint(2, size=(self.population_size, self.chr_length))
        for i in tqdm(range(self.d)):
            a_pop = crosspop(pop, self.change_rate)
            b_pop = variation(pop)
            new_pop = np.concatenate((pop, a_pop, b_pop))
            _, next_pop = self.select_svm(new_pop)
            pop = next_pop[:self.population_size]
            best_pop = next_pop[0]
        prediction = self.evaluate(best_pop)
        return prediction
    
    def select_svm(self, pop: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        d, w = pop.shape
        AC = np.zeros((self.d, 2))
        for i in range(self.d):
            if sum(pop[i]) != 0:
                # Select features based on the mutation
                selected_features = self.features[:, pop[i] == 1]
                # Split the data to train and test
                train_features = selected_features[~self.subject_idx]
                train_labels = self.labels[~self.subject_idx]
                # Cross validation with subjects not in the testing data
                scores = cross_val_score(self.clf, train_features, train_labels, cv=5)
                AC[i] = [np.mean(scores), i]
            else:
                AC[i] = [0, i]
        # Sort based on best performance
        idx = np.argsort(AC[:, 0])[::-1]
        AC = AC[idx]
        new_pop = pop[idx]
        return AC, new_pop  

    def evaluate(self, pop: np.ndarray) -> np.ndarray:
        # Select features based on the mutation
        selected_features = self.features[:, pop == 1]
        # Split the data to train and test
        train_features = selected_features[~self.subject_idx]
        train_labels = self.labels[~self.subject_idx]
        test_features = selected_features[self.subject_idx]
        test_labels = self.labels[self.subject_idx]
        # Train and evaluate
        self.clf.fit(train_features, train_labels)
        prediction = self.clf.predict(test_features)
        return prediction
    
    
@njit
def judgepop(pop1, pop2):
    return (pop1 != pop2).sum()


@njit
def concat(arr1, arr2) -> np.ndarray:
    """Custom concat function to avoid case where one array
    of np.concatenate is empty which numba doesn't like"""
    if len(arr1) == 0:
        return arr2
    elif len(arr2) == 0:
        return arr1
    return np.concatenate((arr1, arr2))


@njit(fastmath=True)
def variation(pop: np.ndarray) -> np.ndarray:
    rate = 0.2
    u, w = pop.shape
    n_pop = np.zeros_like(pop)
    for i in range(u):
        if i == 0 or judgepop(pop[0], pop[i]) >= w * rate:
            locate = int(np.random.rand() * w)
            n_pop[i] = pop[i]
            if pop[i, locate] == 0:
                n_pop[i, locate] = 1
            else:
                n_pop[i, locate] = 0
        else:
            x = np.round(w * rate)
            temp = np.random.permutation(w)[:x]
            n_pop[i] = pop[i]
            for j in range(x):
                if pop[i, temp[j]] == 0:
                    n_pop[i, temp[j]] = 1
                else:
                    n_pop[i, temp[j]] = 0
    return n_pop


@njit
def crosspop(pop: np.ndarray, change_rate: float) -> np.ndarray:
    u, w = pop.shape
    m = np.random.permutation(u)
    n_pop = np.zeros_like(pop)
    for i in range(0, u, 2):
        if np.random.rand(1) < change_rate:
            locate = np.int(np.random.rand() * w)
            n_pop[i] = concat(pop[m[i], :locate], pop[m[i + 1], locate:])
            n_pop[i + 1] = concat(pop[m[i + 1], :locate], pop[m[i], locate:]) 
    return n_pop

In [5]:
from torch import optim
from torch import nn
from sklearn import svm
from sklearn.model_selection import cross_val_score

In [6]:
class STSTNet(nn.Module):
    def __init__(self, in_channels: int = 3, num_channels: int = 3, **kwargs):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels=3, kernel_size=3, padding=2)
        self.conv2 = nn.Conv2d(in_channels, out_channels=5, kernel_size=3, padding=2)
        self.conv3 = nn.Conv2d(in_channels, out_channels=8, kernel_size=3, padding=2)
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(3)
        self.bn2 = nn.BatchNorm2d(5)
        self.bn3 = nn.BatchNorm2d(8)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=3, padding=1)
        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        self.dropout = nn.Dropout(p=0.5)
        self.fc = nn.Linear(in_features=5 * 5 * 16, out_features=num_channels)

    def forward(self, x, only_features: bool = False):
        x1 = self.dropout(self.maxpool(self.bn1(self.relu(self.conv1(x)))))
        x2 = self.dropout(self.maxpool(self.bn2(self.relu(self.conv2(x)))))
        x3 = self.dropout(self.maxpool(self.bn3(self.relu(self.conv3(x)))))
        x = torch.cat((x1, x2, x3), 1)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        if only_features:
            return x
        x = self.fc(x)
        return x

In [7]:
class Config(core.Config):
    action_units = utils.dataset_aus["cross"]
    device = torch.device("cuda:1")
    epochs = 500
    num_workers = 0
    optimizer = partial(optim.Adam, lr=5e-5, weight_decay=1e-3)
    scheduler = None
    model = partial(STSTNet, num_channels=len(action_units))
    game_d = 50
    evaluation_fn = [
        partial(utils.MultiLabelF1Score, average="macro"),
        partial(utils.MultiLabelF1Score, average="binary"),
    ]

In [10]:
out = GAMEValidation(Config).validate_n_times(df, data, n_times=5)

  0%|                                                     | 0/5 [00:00<?, ?it/s]
  0%|                                                    | 0/50 [00:00<?, ?it/s][A
  2%|▊                                        | 1/50 [06:40<5:27:27, 400.97s/it][A
  4%|█▋                                       | 2/50 [13:16<5:18:20, 397.92s/it][A
  6%|██▍                                      | 3/50 [19:53<5:11:26, 397.57s/it][A
  8%|███▎                                     | 4/50 [26:34<5:05:46, 398.84s/it][A
 10%|████                                     | 5/50 [33:14<4:59:15, 399.01s/it][A
 12%|████▉                                    | 6/50 [39:57<4:53:37, 400.41s/it][A
 14%|█████▋                                   | 7/50 [46:38<4:47:07, 400.65s/it][A
 16%|██████▌                                  | 8/50 [53:20<4:40:51, 401.23s/it][A
 18%|███████                                | 9/50 [1:00:00<4:33:57, 400.91s/it][A
 20%|███████▌                              | 10/50 [1:06:43<4:27:31, 401.28s/it

 80%|████████████████████████████████        | 40/50 [3:30:07<53:03, 318.33s/it][A
 82%|████████████████████████████████▊       | 41/50 [3:35:25<47:44, 318.28s/it][A
 84%|█████████████████████████████████▌      | 42/50 [3:40:46<42:31, 318.94s/it][A
 86%|██████████████████████████████████▍     | 43/50 [3:46:06<37:14, 319.24s/it][A
 88%|███████████████████████████████████▏    | 44/50 [3:51:26<31:57, 319.51s/it][A
 90%|████████████████████████████████████    | 45/50 [3:56:44<26:34, 319.00s/it][A
 92%|████████████████████████████████████▊   | 46/50 [4:02:01<21:14, 318.52s/it][A
 94%|█████████████████████████████████████▌  | 47/50 [4:07:27<16:02, 320.73s/it][A
 96%|██████████████████████████████████████▍ | 48/50 [4:12:51<10:43, 321.79s/it][A
 98%|███████████████████████████████████████▏| 49/50 [4:18:15<05:22, 322.29s/it][A
100%|████████████████████████████████████████| 50/50 [4:23:40<00:00, 316.41s/it][A

  0%|                                                    | 0/50 [00:00<?, ?

 58%|██████████████████████                | 29/50 [2:34:23<1:51:29, 318.54s/it][A
 60%|██████████████████████▊               | 30/50 [2:39:47<1:46:40, 320.00s/it][A
 62%|███████████████████████▌              | 31/50 [2:45:05<1:41:09, 319.45s/it][A
 64%|████████████████████████▎             | 32/50 [2:50:20<1:35:27, 318.21s/it][A
 66%|█████████████████████████             | 33/50 [2:55:37<1:30:05, 317.97s/it][A
 68%|█████████████████████████▊            | 34/50 [3:00:58<1:24:58, 318.65s/it][A
 70%|██████████████████████████▌           | 35/50 [3:06:16<1:19:36, 318.42s/it][A
 72%|███████████████████████████▎          | 36/50 [3:11:38<1:14:34, 319.64s/it][A
 74%|████████████████████████████          | 37/50 [3:16:56<1:09:08, 319.10s/it][A
 76%|████████████████████████████▉         | 38/50 [3:22:10<1:03:31, 317.65s/it][A
 78%|███████████████████████████████▏        | 39/50 [3:27:28<58:14, 317.70s/it][A
 80%|████████████████████████████████        | 40/50 [3:32:46<52:58, 317.83s

 38%|███████████████▏                        | 19/50 [54:11<1:28:21, 171.01s/it][A
 40%|████████████████                        | 20/50 [57:01<1:25:23, 170.80s/it][A
 42%|████████████████▊                       | 21/50 [59:50<1:22:13, 170.11s/it][A
 44%|████████████████▋                     | 22/50 [1:02:40<1:19:22, 170.11s/it][A
 46%|█████████████████▍                    | 23/50 [1:05:28<1:16:20, 169.63s/it][A
 48%|██████████████████▏                   | 24/50 [1:08:18<1:13:30, 169.65s/it][A
 50%|███████████████████                   | 25/50 [1:11:07<1:10:35, 169.42s/it][A
 52%|███████████████████▊                  | 26/50 [1:13:57<1:07:48, 169.54s/it][A
 54%|████████████████████▌                 | 27/50 [1:16:46<1:04:58, 169.51s/it][A
 56%|█████████████████████▎                | 28/50 [1:19:36<1:02:09, 169.51s/it][A
 58%|███████████████████████▏                | 29/50 [1:22:25<59:17, 169.43s/it][A
 60%|████████████████████████                | 30/50 [1:25:12<56:17, 168.90s

 16%|██████▌                                  | 8/50 [40:42<3:33:26, 304.92s/it][A
 18%|███████▍                                 | 9/50 [45:48<3:28:28, 305.09s/it][A
 20%|████████                                | 10/50 [50:53<3:23:26, 305.17s/it][A
 22%|████████▊                               | 11/50 [56:00<3:18:43, 305.72s/it][A
 24%|█████████                             | 12/50 [1:01:10<3:14:28, 307.05s/it][A
 26%|█████████▉                            | 13/50 [1:06:23<3:10:24, 308.77s/it][A
 28%|██████████▋                           | 14/50 [1:11:33<3:05:28, 309.12s/it][A
 30%|███████████▍                          | 15/50 [1:16:39<2:59:48, 308.24s/it][A
 32%|████████████▏                         | 16/50 [1:21:44<2:54:01, 307.10s/it][A
 34%|████████████▉                         | 17/50 [1:26:46<2:48:08, 305.71s/it][A
 36%|█████████████▋                        | 18/50 [1:31:52<2:43:06, 305.84s/it][A
 38%|██████████████▍                       | 19/50 [1:37:00<2:38:15, 306.31s

 96%|██████████████████████████████████████▍ | 48/50 [4:32:14<11:23, 341.59s/it][A
 98%|███████████████████████████████████████▏| 49/50 [4:38:00<05:43, 343.05s/it][A
100%|████████████████████████████████████████| 50/50 [4:43:37<00:00, 340.36s/it][A

  0%|                                                    | 0/50 [00:00<?, ?it/s][A
  2%|▊                                        | 1/50 [05:20<4:21:33, 320.28s/it][A
  4%|█▋                                       | 2/50 [10:39<4:15:44, 319.67s/it][A
  6%|██▍                                      | 3/50 [15:56<4:09:29, 318.51s/it][A
  8%|███▎                                     | 4/50 [21:18<4:05:08, 319.74s/it][A
 10%|████                                     | 5/50 [26:37<3:59:34, 319.44s/it][A
 12%|████▉                                    | 6/50 [31:57<3:54:24, 319.64s/it][A
 14%|█████▋                                   | 7/50 [37:18<3:49:23, 320.08s/it][A
 16%|██████▌                                  | 8/50 [42:38<3:44:04, 320.10

 76%|██████████████████████████████▍         | 38/50 [2:58:54<56:19, 281.63s/it][A
 78%|███████████████████████████████▏        | 39/50 [3:03:35<51:34, 281.32s/it][A
 80%|████████████████████████████████        | 40/50 [3:08:16<46:52, 281.22s/it][A
 82%|████████████████████████████████▊       | 41/50 [3:12:57<42:11, 281.29s/it][A
 84%|█████████████████████████████████▌      | 42/50 [3:17:38<37:29, 281.13s/it][A
 86%|██████████████████████████████████▍     | 43/50 [3:22:19<32:46, 280.95s/it][A
 88%|███████████████████████████████████▏    | 44/50 [3:26:59<28:04, 280.82s/it][A
 90%|████████████████████████████████████    | 45/50 [3:31:40<23:23, 280.66s/it][A
 92%|████████████████████████████████████▊   | 46/50 [3:36:20<18:42, 280.55s/it][A
 94%|█████████████████████████████████████▌  | 47/50 [3:41:00<14:01, 280.44s/it][A
 96%|██████████████████████████████████████▍ | 48/50 [3:45:40<09:20, 280.33s/it][A
 98%|███████████████████████████████████████▏| 49/50 [3:50:20<04:40, 280.27s

 54%|████████████████████▌                 | 27/50 [2:31:38<2:09:01, 336.61s/it][A
 56%|█████████████████████▎                | 28/50 [2:37:14<2:03:18, 336.29s/it][A
 58%|██████████████████████                | 29/50 [2:42:50<1:57:44, 336.40s/it][A
 60%|██████████████████████▊               | 30/50 [2:48:26<1:52:04, 336.21s/it][A
 62%|███████████████████████▌              | 31/50 [2:54:01<1:46:23, 335.95s/it][A
 64%|████████████████████████▎             | 32/50 [2:59:37<1:40:45, 335.84s/it][A
 66%|█████████████████████████             | 33/50 [3:05:13<1:35:08, 335.80s/it][A
 68%|█████████████████████████▊            | 34/50 [3:10:51<1:29:44, 336.52s/it][A
 70%|██████████████████████████▌           | 35/50 [3:16:33<1:24:31, 338.13s/it][A
 72%|███████████████████████████▎          | 36/50 [3:22:15<1:19:11, 339.41s/it][A
 74%|████████████████████████████          | 37/50 [3:27:57<1:13:42, 340.21s/it][A
 76%|████████████████████████████▉         | 38/50 [3:33:39<1:08:08, 340.67s

MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
83.3 & 81.6 & 89.3 & 48.7 & 49.2 & 59.3 & 52.2 & 50.1 & 54.0 & 53.8 & 49.3 & 66.3 & 61.4

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'Average']
59.6 & 63.6 & 63.9 & 59.1 & 62.3 & 60.9 & 61.6
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
71.1 & 67.6 & 85.7 & 1.1 & 0.0 & 26.2 & 7.8 & 2.1 & 12.8 & 15.4 & 0.0 & 34.7 & 27.0

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'Average']
23.8 & 31.0 & 32.0 & 23.3 & 29.6 & 25.6 & 27.6


MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
83.3 & 81.6 & 89.3 & 48.7 & 49.2 & 59.3 & 52.2 & 50.1 & 54.0 & 53.8 & 49.3 & 66.3 & 61.4

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'Average']
59.6 & 63.6 & 63.9 & 59.1 & 62.3 & 60.9 & 61.6
MultiLabelF1Score
AUS: ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU7', 'AU9', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'Average']
71.1 & 67.6 & 85.7 & 1.1 & 0.0 & 26.2 & 7.8 & 2.1 & 12.8 & 15.4 & 0.0 & 34.7 & 27.0

Datasets:  ['casme', 'casme2', 'samm', 'fourd', 'mmew', 'casme3a', 'Average']
23.8 & 31.0 & 32.0 & 23.3 & 29.6 & 25.6 & 27.6

In [9]:
class GAMEValidation(core.CrossDatasetValidation):
    def __init__(self, config: Config, verbose: bool = True):
        super().__init__(config)
        self.verbose = verbose
        self.disable_tqdm = True

    def extract_features(self, dataloader: torch.utils.data.DataLoader):
        self.model.eval()
        features_list = []
        for batch in dataloader:
            data_batch = batch[0].to(self.cf.device)
            labels_batch = batch[1]
            features = self.model(data_batch.float(), only_features=True)
            features_list.append(features.detach().cpu())
        self.model.train()
        features = torch.cat(features_list)
        return features
    
    def validate_split(self, df: pd.DataFrame, input_data: np.ndarray, labels: np.ndarray, split_name: str):
        train_data, train_labels, test_data, test_labels = self.split_data(
            df[self.split_column], input_data, labels, split_name
        )
        train_loader = self.get_data_loader(train_data, train_labels, train=True)
        test_loader = self.get_data_loader(test_data, test_labels, train=False)
        self.model = self.cf.model()
        self.model.to(self.cf.device)
        self.criterion = self.cf.criterion()
        self.optimizer = self.cf.optimizer(self.model.parameters())
        self.scheduler = self.cf.scheduler(self.optimizer) if self.cf.scheduler else None
        self.mixup_fn = self.cf.mixup_fn() if self.cf.mixup_fn else None
        
        self.train_model(train_loader, test_loader)
        # Create dataloader from all data and extract features
        data_loader = self.get_data_loader(input_data, labels, train=False)
        features = self.extract_features(data_loader)
        #Classify with genetic algorithm
        ga = GeneticAlgorithm(features, labels, df[self.split_column] == split_name, multi_label=True,
                              population_size=self.cf.game_d, d=self.cf.game_d)
        prediction = ga.perform()
        test_metrics = self.evaluation_fn(torch.tensor(labels[df[self.split_column] == split_name]),
                                     torch.tensor(prediction))
        return [], test_metrics, torch.tensor(prediction)