# Data imbalance analysis
## Comparing the impact of dataset imbalance on classification performance metrics
In this notebook, we are going to explore how data imbalance affects classification scores on four different tasks, which have been chosen to exemplify typical use cases of ML analysis in neuroscience. To keep it simple, we will focus on binary classification problems (0 vs 1).

The four tasks are :
1. Synthetic data
2. EEG alpha oscillations (resting-state Eyes-Closed vs Eyes-Open)
3. MEG alpha oscillations (auditory vs visual stimulation)
4. MEG alpha oscillations (faces vs scrambled) 

In these tasks, we will observe the effect of data imbalance on 4 different performance metrics :   
1. Decoding Accuracy (Acc)  
2. Area Under the Curve (AUC)
3. F1
4. Balanced Accuracy (BAcc)

Finally, a few parameters of the classification pipeline must be kept in mind as they can also differentially impact performance on imbalanced data. Namely :
- Dataset size
- Classifier type
    - Support Vector Machine (SVM)
    - Linear Discriminant Analysis (LDA)
    - Logistic Regression (LR)
    - Random Forest (RF)
- Cross-validation scheme
    - K-Fold, k=5
    - Stratified K-Fold
    - Group K-Fold

## Imports
First, we start by importing functions from the provided toolbox as well as some useful plotting functions.

In [1]:
from imbalance.pipeline import Pipeline
from imbalance.viz import metric_balance, data_distribution, plot_different_cvs, plot_different_n
from imbalance.data import eegbci, gaussian_binary
from joblib import Parallel, delayed
from sklearn.model_selection import (
    KFold,
    StratifiedKFold,
    StratifiedGroupKFold,
)
import os
import numpy as np
import matplotlib.pyplot as plt
import string
from copy import deepcopy
import warnings
warnings.simplefilter("ignore", RuntimeWarning)

## Task 1 : Synthetic data
For this first classification task, we will generate data from two Gaussian distributions and vary the distance between their means (0, 1 and 3).

In [None]:
# Distances between the two class means; one pipeline is fit per distance.
mean_distances = [0, 1, 3]


def run(distance):
    """Fit and evaluate the pipeline on one freshly generated synthetic dataset.

    Parameters
    ----------
    distance : float
        Forwarded to ``gaussian_binary`` as ``mean_distance``.

    Returns
    -------
    Pipeline
        A deep copy of the evaluated pipeline.
    """
    # generate random data
    x, y, groups = gaussian_binary(n_samples_per_class=500, mean_distance=distance)
    # run the pipeline
    pl = Pipeline(
        x,
        y,
        groups,
        dataset_balance=np.linspace(0.1, 0.9, 25),
        classifiers=["lr", "lda", "svm", "rf"],
        n_permutations=100,
        n_init=10,
    )
    pl.evaluate()
    return deepcopy(pl)


# One job per distance. `pls` is assigned only here (the previous empty-list
# initialisation was dead code), so a failed run cannot silently leave an
# empty list behind for the plotting cells.
pls = Parallel(n_jobs=3)(delayed(run)(dist) for dist in mean_distances)

In [None]:
# Synthetic dataset used to look at the effect of dataset size.
x, y, groups = gaussian_binary(mean_distance=1, n_samples_per_class=1500)

# SVM only, no permutation tests, three `dataset_size` settings.
nsamples_settings = dict(
    dataset_balance=np.linspace(0.1, 0.9, 25),
    classifiers=["svm"],
    n_permutations=0,
    n_init=10,
    dataset_size=(0.1, 0.33, 1),
)
pl_nsamples = Pipeline(x, y, groups, **nsamples_settings)
pl_nsamples.evaluate()

fitting classifiers:  39%|████████             | 290/750 [01:30<01:51,  4.14it/s, size=1, balance=0.8, classifier=SVC]

In [None]:

# Cross-validation schemes to compare; `cv_names[i]` is the key under which the
# pipeline fit with `cvs[i]` is stored in `pls_crossvals`.
cv_names = ["KFold", "Stratified"]
cvs = [KFold(n_splits=5), StratifiedKFold(n_splits=5)]

# A single synthetic dataset is shared by every scheme.
x, y, groups = gaussian_binary(mean_distance=1, n_groups=5, n_samples_per_class=50)

pls_crossvals = {}
for cv_label, cv_scheme in zip(cv_names, cvs):
    pl = Pipeline(
        x,
        y,
        groups,
        cross_validation=cv_scheme,
        dataset_balance=np.linspace(0.1, 0.9, 25),
        classifiers=["lr", "lda", "svm", "rf"],
        n_permutations=1,
        n_init=40,
    )
    pl.evaluate()
    # Store a copy so the entry is independent of the loop variable `pl`.
    pls_crossvals[cv_label] = deepcopy(pl)
    




In [None]:
# visualize the result
fig, axes = plt.subplots(6, 3, figsize=(20, 30), dpi=300)
figtitle = "Synthetic data"
#fig.suptitle(figtitle, fontsize=25)
classifiers = ["lr", "lda", "svm", "rf"]

# Panel layout (row-major over the 6x3 grid):
#   0-2   data distribution of each entry of `pls`
#   3-14  one row per classifier, one column per entry of `pls`
#   15    `pl_nsamples` results (SVM, accuracy)
#   16    `pls_crossvals` results (SVM, accuracy)
#   17    unused
last_used_panel = 16

# Legends are only requested on the first panel of each kind.
show_leg_distrib = True
show_leg_metric = True
for ax_idx, ax in enumerate(axes.flat):
    if ax_idx > last_used_panel:
        # Nothing is drawn here: hide the frame instead of lettering an empty panel.
        ax.set_axis_off()
        continue
    ax.text(-0.1, 1.05, string.ascii_uppercase[ax_idx], transform=ax.transAxes,
            size=20, weight='bold')
    if ax_idx in [0, 1, 2]:
        data_distribution(pls[ax_idx], ax=ax, show=False, show_leg=show_leg_distrib)
        show_leg_distrib = False
    elif ax_idx < 15:
        metric_balance(pls[ax_idx % 3], ax=ax, show=False, classifier=classifiers[(ax_idx - 3) // 3], show_leg=show_leg_metric)
        show_leg_metric = False
    elif ax_idx == 15:
        plot_different_n(pl_nsamples, ax=ax, show=False, classifier="svm", show_leg=True, metric="accuracy")
    elif ax_idx == 16:
        plot_different_cvs(pls_crossvals, ax=ax, show=False, classifier="svm", show_leg=True, metric="accuracy")

In [None]:
# One panel per metric, each comparing the cross-validation schemes for the SVM.
metrics = ["roc_auc", "accuracy", "f1", "balanced_accuracy"]
fig, axes = plt.subplots(2, 2, figsize=(20, 10))

for ax, metric_name in zip(axes.flat, metrics):
    plot_different_cvs(
        pls_crossvals,
        ax=ax,
        show=False,
        classifier="svm",
        show_leg=True,
        metric=metric_name,
    )

In [None]:
# Stand-alone view of the dataset-size comparison (SVM, accuracy).
plot_different_n(
    pl_nsamples,
    show=False,
    classifier="svm",
    show_leg=True,
    metric="accuracy",
)

## Task 2 : EEG analysis

In [None]:
# Location of the cached EEG features.
features_path = "../imbalance/data/eeg_features.npy"

if os.path.isfile(features_path):
    # Cache hit: the file stores a dict, hence allow_pickle and .item() to unwrap it.
    features = np.load(features_path, allow_pickle=True).item()
    x = features["x"]
    y = features["y"]
    groups = features["groups"]
else:
    # Cache miss: build the features, then save them for later runs.
    # `roi` keeps entries whose first character is 'P' or 'O'
    # (presumably channel names -- confirm against `eegbci`).
    x, y, groups = eegbci("../imbalance/data", roi=lambda name: name[0] in ["P", "O"])
    np.save(features_path, dict(x=x, y=y, groups=groups))

In [None]:
# Pipeline configuration for the EEG features: three classifiers, four metrics.
eeg_settings = dict(
    dataset_balance=np.linspace(0.1, 0.9, 25),
    classifiers=["lda", "svm", "lr"],
    metrics=["roc_auc", "accuracy", "f1", "balanced_accuracy"],
)
pl = Pipeline(x, y, groups, **eeg_settings)

# fit and evaluate classifiers on dataset configurations
pl.evaluate()

In [None]:
# visualize the result
fig, axes = plt.subplots(2, 3, figsize=(20, 10))
figtitle = "EEG data"
#fig.suptitle(figtitle, fontsize=25)
# Only the classifiers the EEG pipeline was fit with ("lda", "svm", "lr").
# The previous 4-entry list included "rf", which was never fit, and was indexed
# up to position 4 by the 2x3 grid, raising IndexError on the last panel.
classifiers = ["lr", "lda", "svm"]

# Panel 0 shows the data distribution, panels 1..len(classifiers) show one
# classifier each; any remaining panel of the grid is hidden.
for ax_idx, ax in enumerate(axes.flat):
    if ax_idx > len(classifiers):
        ax.set_axis_off()
        continue
    ax.text(-0.1, 1.05, string.ascii_uppercase[ax_idx], transform=ax.transAxes,
            size=20, weight='bold')
    if ax_idx == 0:
        data_distribution(pl, ax=ax, show=False)
    else:
        metric_balance(pl, ax=ax, show=False, classifier=classifiers[ax_idx - 1])

## Task 3 : MEG 1

In [None]:
# NOTE(review): no MEG data is loaded before this cell, so `pl` is still the
# pipeline from the EEG section -- this figure looks like a placeholder; confirm.
# NOTE(review): panels B-E all plot the "svm" classifier and `classifiers` is
# never read below; panel F is lettered but left empty.
fig, axes = plt.subplots(2, 3, figsize=(20, 10))
figtitle = "MEG CAMCAN data"
classifiers = ["lr", "lda", "svm", "rf"]

for ax_idx, ax in enumerate(axes.flat):
    panel_letter = string.ascii_uppercase[ax_idx]
    ax.text(-0.1, 1.05, panel_letter, transform=ax.transAxes, size=20, weight="bold")
    if ax_idx == 0:
        # first panel: data distribution
        data_distribution(pl, ax=ax, show=False)
    elif 1 <= ax_idx <= 4:
        # next four panels: metric curves for the SVM
        metric_balance(pl, ax=ax, show=False, classifier="svm")

## Task 4 : MEG 2

In [None]:
# NOTE(review): this cell sits under the "MEG 2" heading but works on synthetic
# data and overwrites the `pl_nsamples` defined in the Task 1 section -- confirm
# whether it is a placeholder.
x, y, groups = gaussian_binary(mean_distance=1, n_groups=5, n_samples_per_class=1500)

# SVM only, four metrics, no permutation tests, a single initialisation.
task4_settings = dict(
    dataset_balance=np.linspace(0.1, 0.9, 25),
    classifiers=["svm"],
    metrics=["roc_auc", "accuracy", "f1", "balanced_accuracy"],
    n_permutations=0,
    n_init=1,
    dataset_size=(0.1, 0.33, 1),
)
pl_nsamples = Pipeline(x, y, groups, **task4_settings)
pl_nsamples.evaluate()