In [None]:
import numpy as np
import pandas as pd
import random
import pickle

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

import seaborn as sns
import cebra
from cebra import CEBRA

from dataframe.csv_utils import (
    load_data_from_csv,
)
from data_utils import (
    load_data_from_dir,
)
from labels import get_behavioral_labels
from plotting import subplot_confusion_matrix
from constants import SUEJECT_BATCHES, AUDIO_BLOCKS
from features.constants import Feature, MARKER_TO_FEATURE

# Root directory that is scanned for per-subject sub-directories.
data_dir = "../CleandDataV2/"
# Seeds Python's built-in `random` module only.
# NOTE(review): NumPy / PyTorch generators are not seeded here — confirm whether
# the resampling helpers and CEBRA training need their own seeds.
random.seed(33)

## Helper functions

In [None]:
from features.psd import get_psd_by_channel, get_psd
from biomarkers import EEG_CHANEL_NAMES
from labels import get_raw_labels, get_label_means
from feature_extraction import EEG_BANDS
from resample.resample import (
    get_validation_indexes,
    upsample_by_attention,
    get_resampled_list_index,
    slice_data_by_seconds,
)


def get_psd_by_channel_band(block_data, channel_type: str, srate: int, feature):
    """Build one PSD feature row per trial from the channels of one group.

    For every trial (first axis of ``block_data``) each channel whose name in
    ``EEG_CHANEL_NAMES`` starts with ``channel_type`` is passed to ``get_psd``
    together with ``srate`` and ``EEG_BANDS[feature]``. The per-channel results
    are concatenated into a single row.

    Returns:
        The per-trial rows stacked along a new first axis.
    """

    def trial_row(trial_idx):
        # The position in EEG_CHANEL_NAMES doubles as the index on the second axis.
        per_channel = [
            get_psd(block_data[trial_idx, ch_idx, :], srate, EEG_BANDS[feature])
            for ch_idx, ch_name in enumerate(EEG_CHANEL_NAMES)
            if ch_name.startswith(channel_type)
        ]
        return np.concatenate(per_channel)

    rows = [trial_row(trial_idx) for trial_idx in range(block_data.shape[0])]
    return np.stack(rows, axis=0)


def get_features(block_data, marker, channel_type: str, srate: int, feature):
    """Select the feature array for one marker from already-sliced block data.

    EEG data is delegated to ``get_psd_by_channel_band``. For the other
    markers the feature picks one index on the second axis of ``block_data``:
    0 for ECG_HF / EGG_FILTERED, 1 for ECG_LF / EGG_PHASE and
    2 for ECG_LFHF / EGG_AMPLITUDE.

    Returns:
        The selected array, or ``None`` when ``feature`` matches none of the
        cases above.
    """
    if marker == "EEG":
        return get_psd_by_channel_band(block_data, channel_type, srate, feature)

    if feature in (Feature.ECG_HF, Feature.EGG_FILTERED):
        row = 0
    elif feature in (Feature.ECG_LF, Feature.EGG_PHASE):
        row = 1
    elif feature in (Feature.ECG_LFHF, Feature.EGG_AMPLITUDE):
        row = 2
    else:
        # Unknown feature: callers receive None.
        return None

    return block_data[:, row, :]


def get_block_features(
    blocks, subject_data, marker, channel, feature, with_sliced: bool = False
):
    """Collect the feature matrix of several blocks for one subject.

    With ``with_sliced`` the marker data of each block is cut into 4-second
    slices (``slice_data_by_seconds``) and passed to ``get_features``;
    otherwise ``get_psd_by_channel`` is used on the block directly.

    Returns:
        The per-block arrays stacked vertically, or ``[]`` when ``blocks`` is
        empty.
    """

    def features_for(block_name):
        block_data = subject_data[block_name]
        if not with_sliced:
            return get_psd_by_channel(block_data, marker, channel, feature)

        srate = block_data.get_srate(marker)
        sliced = slice_data_by_seconds(block_data.get_all_data()[marker], srate, 4)
        return get_features(sliced, marker, channel, srate, feature)

    stacked = []
    for block_name in blocks:
        current = features_for(block_name)
        # Each new block is placed ABOVE the rows gathered so far, so the final
        # row order is the reverse of `blocks`.
        # NOTE(review): confirm the label helpers use the same block order.
        stacked = current if len(stacked) == 0 else np.vstack((current, stacked))

    return stacked


def get_categorical_labels(blocks, subject_data):
    """Return one behavioral label per trial, block after block.

    For each block the default labels (``get_labels()``) and the arousal
    labels (``get_labels("arousal")``) are paired by position and combined
    through ``get_behavioral_labels``.
    """
    collected = []

    for block_name in blocks:
        block = subject_data[block_name]
        valence = block.get_labels()
        arousal = block.get_labels("arousal")

        collected += [
            get_behavioral_labels(valence[k], arousal[k]) for k in range(len(valence))
        ]

    return collected


def get_label_category(labels, label_type, v_thred, a_thred):
    """Binarise continuous labels against a threshold.

    ``a_thred`` is used when ``label_type`` is ``"arousal"``; every other
    label type uses ``v_thred``. Values strictly below the threshold map to 0,
    everything else to 1.
    """
    cutoff = v_thred if label_type != "arousal" else a_thred

    categories = []
    for value in labels:
        categories.append(1 if not value < cutoff else 0)
    return categories


def get_channel_feature_to_data(subject_data, marker: str = "EEG"):
    """Map channel group -> band -> sliced feature matrix over ``AUDIO_BLOCKS``.

    The channel groups are "A", "B", "C" and "D"; the bands are the keys of
    ``EEG_BANDS``. Every entry comes from ``get_block_features`` with slicing
    enabled.
    """
    return {
        group: {
            band: get_block_features(
                AUDIO_BLOCKS, subject_data, marker, group, band, True
            )
            for band in EEG_BANDS.keys()
        }
        for group in ("A", "B", "C", "D")
    }


def get_feature_to_data(subject_data, marker: str = "EEG"):
    """Return the sliced feature matrices of one subject for ``marker``.

    EEG yields the channel-group mapping built by
    ``get_channel_feature_to_data``. Any other marker yields
    ``{marker: {feature: matrix}}`` for the features listed in
    ``MARKER_TO_FEATURE[marker]``.
    """
    if marker != "EEG":
        per_feature = {}
        # dict.fromkeys keeps first-seen order and collapses duplicates.
        for feat in dict.fromkeys(MARKER_TO_FEATURE[marker]):
            per_feature[feat] = get_block_features(
                AUDIO_BLOCKS, subject_data, marker, "", feat, True
            )
        return {marker: per_feature}

    return get_channel_feature_to_data(subject_data, marker)

## Load data and process features

### Slicing

In [None]:
import os

marker = "EEG"
# Number of slices each trial is cut into; passed to get_raw_labels below.
num_slice_per_trial = 5

# [2001, 2003, 2017, 2026, 2028, 2033,  2037, 2041]
# The lists below are index-aligned: position i of each list belongs to
# subject_list[i].
subject_list = []
marker_features = []
valence_labels, arousal_labels, attention_labels = [], [], []
label_thresholds = []
# NOTE(review): os.listdir returns entries in arbitrary order. The order is kept
# as-is because the cached features loaded in the next cell are indexed in
# parallel with subject_list — confirm both use the same order.
for d in os.listdir(data_dir):
    # os.path.join works whether or not data_dir ends with a path separator.
    dir_name = os.path.join(data_dir, d)
    if not os.path.isdir(dir_name):
        continue

    subject_data = load_data_from_dir(dir_name)
    features = get_feature_to_data(subject_data, marker)
    vl, arl, atl = get_raw_labels(AUDIO_BLOCKS, subject_data, num_slice_per_trial)

    subject_list.append(d)
    marker_features.append(features)
    valence_labels.append(vl)
    arousal_labels.append(arl)
    attention_labels.append(atl)
    label_thresholds.append(get_label_means(subject_data))

In [None]:
# Load the cached features; this REPLACES the `marker_features` computed by the
# slicing cell above.
# NOTE(security): pickle.load can execute arbitrary code — only open pickle
# files from a trusted source.
with open('./eeg_features2/all_features.pkl', 'rb') as fp:
    marker_features = pickle.load(fp)

# Show a size summary instead of dumping the whole nested structure.
print('Loaded cached features:', len(marker_features), 'entries')

### No slicing

In [None]:
marker = "EEG"
# [2001, 2003, 2017, 2026, 2028, 2033,  2037, 2041]
subj = 2041
# Load the subject directly so this cell does not depend on a mapping that no
# visible cell defines.
# Assumes the subject's directory under data_dir is named after its id, as the
# loading loop's directory names are later cast to int — TODO confirm.
subject_data = load_data_from_dir(os.path.join(data_dir, str(subj)))
# NOTE(review): this overrides the value set by the slicing cell and is read
# again when the cross-validation indexes are built.
num_slice_per_trial = 1

# channel group -> band -> un-sliced feature matrix over AUDIO_BLOCKS
channel_feature_to_data = {
    c: {
        f: get_block_features(AUDIO_BLOCKS, subject_data, marker, c, f)
        for f in EEG_BANDS.keys()
    }
    for c in ["A", "B", "C", "D"]
}

## Model Training

### Training utils

In [None]:
import ipywidgets as widgets
from plotting import plot_roc_curve
from sklearn.decomposition import PCA
from sklearn.metrics import f1_score, accuracy_score
# from importlib import reload
# import features.constants

# reload(features.constants)
from features.constants import Feature

# CEBRA AND PCA hyper-parameters
# Embedding dimensionality: decode_marker_data forwards it to run_knn_decoder,
# which passes it on as `out_dim` (PCA n_components / CEBRA output_dimension).
OUTPUT_DIM = 8
# Hidden-unit count forwarded the same way as `num_hidden_units`; only the CEBRA
# path of get_embeddings reads it.
MAX_HIDDEN_UNITS = 256


def model_fit(
    neural_data,
    out_dim,
    num_hidden_units,
    behavioral_labels,
    max_iterations: int = 10,
    max_adapt_iterations: int = 10,
):
    """Create and fit a single CEBRA model.

    Args:
        neural_data: Data handed to ``CEBRA.fit`` as the first argument.
        out_dim: Value for CEBRA's ``output_dimension``.
        num_hidden_units: Value for CEBRA's ``num_hidden_units``.
        behavioral_labels: Labels handed to ``fit`` as the second argument;
            ``None`` fits on ``neural_data`` alone.
        max_iterations: Value for CEBRA's ``max_iterations``.
        max_adapt_iterations: Value for CEBRA's ``max_adapt_iterations``.

    Returns:
        The fitted CEBRA model.
    """
    hyper_params = dict(
        batch_size=512,
        output_dimension=out_dim,
        max_iterations=max_iterations,
        num_hidden_units=num_hidden_units,
        max_adapt_iterations=max_adapt_iterations,
    )
    encoder = CEBRA(**hyper_params)

    fit_inputs = (
        (neural_data,) if behavioral_labels is None else (neural_data, behavioral_labels)
    )
    encoder.fit(*fit_inputs)
    return encoder


def get_embeddings(
    train_data,
    val_data,
    train_labels,
    use_pca: bool = False,
    out_dim: int = 16,
    num_hidden_units: int = 256,
):
    """Embed training and validation data with PCA or CEBRA.

    The reducer is fitted on ``train_data`` only and then applied to both
    sets. ``train_labels`` and ``num_hidden_units`` are used by the CEBRA
    path only.

    Returns:
        ``(train_embedding, val_embedding)``.
    """
    if not use_pca:
        encoder = model_fit(train_data, out_dim, num_hidden_units, train_labels)
        return encoder.transform(train_data), encoder.transform(val_data)

    projector = PCA(n_components=out_dim).fit(train_data)
    return projector.transform(train_data), projector.transform(val_data)


def _train_test_split(data, labels, attention_labels, val_indexes=None):
    """Split one subject's data into a resampled training set and a validation set.

    Args:
        data: Array indexable with a list of row indexes.
        labels: One label per row of ``data``.
        attention_labels: Forwarded to ``get_resampled_list_index`` together
            with the training indexes.
        val_indexes: Row indexes held out for validation. ``None`` or an empty
            sequence falls back to ``get_validation_indexes()``.

    Returns:
        ``(train_data, train_labels, val_data, val_label)``. The training rows
        are additionally re-indexed by the list returned from
        ``get_resampled_list_index``.
    """
    # None replaces the former mutable default ([]); an explicit empty list
    # still triggers the fallback.
    if val_indexes is None or len(val_indexes) == 0:
        val_indexes = get_validation_indexes()

    # Build the training indexes in ascending order explicitly rather than
    # relying on set iteration order.
    held_out = set(val_indexes)
    train_indexes = [i for i in range(len(labels)) if i not in held_out]

    resampled_list = get_resampled_list_index(train_indexes, attention_labels)

    label_array = np.array(labels)
    train_labels = label_array[train_indexes][resampled_list]
    train_data = data[train_indexes][resampled_list]

    val_data = data[val_indexes]
    val_label = label_array[val_indexes]
    return train_data, train_labels, val_data, val_label


# output_dim, max_hidden_units only needed for CEBRA
def run_knn_decoder(
    dataset,
    method,
    threshold,
    output_dim,
    max_hidden_units,
):
    """Embed every fold, fit a KNN decoder on it and predict the validation set.

    Args:
        dataset: Iterable of ``(train_data, train_labels, val_data, val_labels)``
            folds; the validation labels are not used here.
        method: ``"PCA"`` selects PCA embeddings, anything else CEBRA.
        threshold: Predictions strictly below it become category 0, others 1.
        output_dim: Embedding dimensionality.
        max_hidden_units: Hidden units for the CEBRA model.

    Returns:
        ``(y_pred, y_pred_cat, all_embeddings)`` — per-fold raw predictions,
        per-fold binarised predictions and per-fold training embeddings.
    """
    use_pca = method == "PCA"
    raw_predictions, binary_predictions, train_embeddings = [], [], []

    for fold_train, fold_train_labels, fold_val, _ in dataset:
        train_embedding, val_embedding = get_embeddings(
            train_data=fold_train,
            val_data=fold_val,
            train_labels=fold_train_labels,
            use_pca=use_pca,
            out_dim=output_dim,
            num_hidden_units=max_hidden_units,
        )
        train_embeddings.append(train_embedding)

        # Fit the decoder on the training embedding and its labels.
        knn = cebra.KNNDecoder()
        knn.fit(train_embedding, np.array(fold_train_labels))

        fold_prediction = knn.predict(val_embedding)
        raw_predictions.append(fold_prediction)
        binary_predictions.append(
            [1 if not value < threshold else 0 for value in fold_prediction]
        )

    return raw_predictions, binary_predictions, train_embeddings


def get_all_spectral_features(
    feature_to_data: dict, val_indexes, attention_labels, labels
):
    """Concatenate all bands of one channel group and split them per fold.

    The matrices of every band in ``EEG_BANDS`` are joined column-wise, then
    ``_train_test_split`` is applied once per entry of ``val_indexes``.
    """
    band_blocks = []
    for band in EEG_BANDS.keys():
        band_blocks.append(feature_to_data[band])
    merged = np.hstack(band_blocks)

    folds = []
    for fold_indexes in val_indexes:
        folds.append(_train_test_split(merged, labels, attention_labels, fold_indexes))
    return folds


def get_all_channel_features(data: dict, val_indexes, attention_labels, labels):
    """Concatenate every band of every channel group and split them per fold.

    Matrices are joined column-wise, channel group by channel group and, within
    a group, in ``EEG_BANDS`` order. ``_train_test_split`` is then applied once
    per entry of ``val_indexes``.
    """
    merged = np.hstack(
        [
            feature_to_data[band]
            for feature_to_data in data.values()
            for band in EEG_BANDS.keys()
        ]
    )

    folds = []
    for fold_indexes in val_indexes:
        folds.append(_train_test_split(merged, labels, attention_labels, fold_indexes))
    return folds


def prepare_dataset(
    data,
    val_indexes,
    attention_labels,
    labels,
    has_all_spectral: bool = False,
    filtered_channel: str = "",
):
    """Build the cross-validation folds for every requested channel / feature.

    Args:
        data: Mapping ``channel -> feature -> matrix``.
        val_indexes: One list of validation indexes per fold.
        attention_labels: Forwarded to ``_train_test_split``.
        labels: One label per row of the matrices.
        has_all_spectral: When true, the bands of each channel are merged into
            a single ``Feature.ALL_SPECTRAL`` entry.
        filtered_channel: ``""`` keeps every channel, ``"ALL"`` merges all
            channels and bands into one entry, any other value keeps only that
            channel.

    Returns:
        ``channel -> feature -> list of folds``. Channels that are skipped keep
        an empty dict.
    """
    if filtered_channel == "ALL":
        merged_folds = get_all_channel_features(
            data, val_indexes, attention_labels, labels
        )
        return {"ALL": {Feature.ALL_SPECTRAL: merged_folds}}

    dataset_dict = {channel: {} for channel in data}
    single_channel_only = len(filtered_channel) > 0

    for channel, feature_to_data in data.items():
        if len(feature_to_data) == 0:
            continue
        if single_channel_only and channel != filtered_channel:
            continue

        if has_all_spectral:
            dataset_dict[channel][Feature.ALL_SPECTRAL] = get_all_spectral_features(
                feature_to_data, val_indexes, attention_labels, labels
            )
        else:
            for feat, neural_data in feature_to_data.items():
                # One (train, train_labels, val, val_labels) tuple per fold.
                dataset_dict[channel][feat] = [
                    _train_test_split(neural_data, labels, attention_labels, fold)
                    for fold in val_indexes
                ]

    return dataset_dict


def set_pane_axis(ax):
    """Strip panes, grid lines and ticks from the x, y and z axes of ``ax``.

    Panes and grid lines are given fully transparent colours; tick lists are
    emptied.
    """
    transparent_pane = (1.0, 1.0, 1.0, 0.0)
    transparent_grid = (1, 1, 1, 0)

    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color(transparent_pane)
        axis._axinfo["grid"]["color"] = transparent_grid
        axis.set_ticks([])

def _plot_embedding_grid(list_embedding_tuple, method, label_type):
    """Scatter the first three latent dimensions of each embedding in a 3D grid."""
    n_plots = len(list_embedding_tuple)
    if n_plots == 6:
        n_row, n_col = 2, 3
    elif n_plots <= 4:
        n_row, n_col = 1, 4
    else:
        # Any other count: four columns and as many rows as needed, so that
        # every embedding gets an axes instead of running past the grid.
        n_col = 4
        n_row = -(-n_plots // n_col)  # ceiling division

    fig, axes = plt.subplots(
        nrows=n_row,
        sharey=True,
        ncols=n_col,
        figsize=(n_col * 5, n_row * 5),
        subplot_kw=dict(projection="3d"),
    )
    idx1, idx2, idx3 = (0, 1, 2)
    for ax, (title, embeddings, embedding_labels) in zip(
        axes.flat, list_embedding_tuple
    ):
        y = ax.scatter(
            embeddings[:, idx1],
            embeddings[:, idx2],
            embeddings[:, idx3],
            cmap="cool",
            c=embedding_labels,
            s=5,
            vmin=0,
            vmax=1,
        )
        ax.set_title(title)
        # Bind the colorbar to the axes it describes.
        yc = plt.colorbar(
            y, ax=ax, fraction=0.03, pad=0.05, ticks=np.linspace(0, 1, 9)
        )
        yc.ax.tick_params(labelsize=10)
        yc.ax.set_title("score", fontsize=10)
        set_pane_axis(ax)

    fig.suptitle(f'{method} - {label_type} Latents: (1,2,3)')


def decode_marker_data(
    dataset_dict,
    label_type,
    v_thred,
    a_thred,
    method,
    threshold,
    plot_roc: bool = False,
    plot_embed: bool = False,
):
    """Run the KNN decoder for every channel / feature and average the fold scores.

    Args:
        dataset_dict: ``channel -> feature -> list of folds`` as produced by
            ``prepare_dataset``.
        label_type: Label name; selects the threshold used to binarise the
            true validation labels (see ``get_label_category``).
        v_thred: Valence threshold for the true labels.
        a_thred: Arousal threshold for the true labels.
        method: ``"PCA"`` or anything else for CEBRA.
        threshold: Threshold used to binarise the decoder predictions.
        plot_roc: Call ``plot_roc_curve`` for every channel / feature.
        plot_embed: Plot, for every channel / feature, the training embedding
            of the fold with the highest accuracy.

    Returns:
        ``(f1_score_data, accuracy)`` — two ``channel -> feature -> mean score``
        dicts. Channels without features keep an empty dict.
    """
    f1_score_data = {k: {} for k in dataset_dict.keys()}
    accuracy = {k: {} for k in dataset_dict.keys()}
    list_embedding_tuple = []

    for channel, feature_to_data in dataset_dict.items():
        for f, dataset in feature_to_data.items():
            # Ground-truth categories of each fold's validation labels.
            val_true_cat = [
                get_label_category(val_labels, label_type, v_thred, a_thred)
                for (_, _, _, val_labels) in dataset
            ]

            y_pred, val_pred_cat, all_embeddings = run_knn_decoder(
                dataset,
                method,
                threshold,
                OUTPUT_DIM,
                MAX_HIDDEN_UNITS,
            )

            if plot_roc:
                plot_roc_curve(y_pred, val_true_cat, method, label_type, channel, f)

            score = [
                f1_score(y_pred=pred, y_true=true)
                for pred, true in zip(val_pred_cat, val_true_cat)
            ]
            ac_scores = [
                accuracy_score(y_pred=pred, y_true=true)
                for pred, true in zip(val_pred_cat, val_true_cat)
            ]

            if plot_embed:
                # Keep the training embedding (and its labels) of the best fold.
                max_score_index = int(np.argmax(ac_scores))
                best_acc = round(ac_scores[max_score_index], 2)
                list_embedding_tuple.append(
                    (
                        f"{channel} {f.name} Acc:{best_acc}",
                        all_embeddings[max_score_index],
                        dataset[max_score_index][1],
                    )
                )

            f1_score_data[channel][f] = np.mean(score)
            accuracy[channel][f] = np.mean(ac_scores)

    if len(list_embedding_tuple) > 0:
        _plot_embedding_grid(list_embedding_tuple, method, label_type)

    return f1_score_data, accuracy

### Cross Validation

In [None]:
from resample.resample import get_consecutive_validation_indexes

# Step between consecutive fold start positions; also passed to the index helper.
n_step_trial = 3
# One list of validation indexes per fold. The label count is taken from the
# first subject only.
# NOTE(review): assumes every subject has the same number of labels, and that 13
# is the exclusive upper bound for the fold start position — confirm against
# get_consecutive_validation_indexes.
val_indexes = [
    get_consecutive_validation_indexes(
        len(valence_labels[0]), len(AUDIO_BLOCKS), num_slice_per_trial, i, n_step_trial
    )
    for i in range(1, 13, n_step_trial)
]
print(len(val_indexes), val_indexes)

### Get subjects summary

In [None]:
def get_feature_names_and_mean_scores(
    dataset_dict, accuracy, marker: str, filtered_channel: str
):
    """Return the row names and matching mean accuracies for the summary table.

    Args:
        dataset_dict: ``channel -> feature -> folds`` mapping.
        accuracy: ``channel -> feature -> mean accuracy`` mapping.
        marker: Marker name; anything but ``"EEG"`` is looked up as the single
            top-level key of both mappings.
        filtered_channel: For EEG, the channel whose features are listed. An
            empty string lists the channels themselves, scored by their
            ``Feature.ALL_SPECTRAL`` entry.

    Returns:
        ``(names, mean_scores)`` as two lists of equal length.
    """
    if marker != "EEG":
        # The names must be bound before the scores are looked up; the previous
        # version read this name before assigning it.
        all_feature_name = list(dataset_dict[marker].keys())
        return all_feature_name, [accuracy[marker][f] for f in all_feature_name]

    if len(filtered_channel) == 0:
        all_feature_name = list(dataset_dict.keys())
        mean_scores = [accuracy[c][Feature.ALL_SPECTRAL] for c in all_feature_name]
    else:
        all_feature_name = list(dataset_dict[filtered_channel].keys())
        mean_scores = [accuracy[filtered_channel][f] for f in all_feature_name]

    return all_feature_name, mean_scores


# Long-format accumulator: one row per (subject, name, label type).
# With a channel filter set, the "channel" column holds the FEATURE names of
# that channel (see get_feature_names_and_mean_scores).
subject_accuracy_summary = {
    "subject": [],
    "channel": [],
    "label_type": [],
    "cv_mean_score": [],
}
# Only this channel group is decoded; "" would decode every group.
filtered_channel = "D"

# All per-subject lists are indexed in parallel with subject_list.
for idx in range(len(subject_list)):
    print('decoding subject...', subject_list[idx])

    v_thred, a_thred = label_thresholds[idx]
    for lt in ["valence", "arousal"]:
        labels = valence_labels[idx] if lt == "valence" else arousal_labels[idx]
        # The same per-subject threshold binarises both the true labels and
        # the decoder predictions.
        thred = v_thred if lt == "valence" else a_thred
        dataset_dict = prepare_dataset(
            marker_features[idx],
            val_indexes,
            attention_labels[idx],
            labels,
            False,
            filtered_channel,
        )

        # CEBRA embeddings, no ROC plot, no embedding plot.
        f1_score_data, accuracy = decode_marker_data(
            dataset_dict, lt, v_thred, a_thred, "CEBRA", thred, False, False
        )

        all_feature_name, mean_scores = get_feature_names_and_mean_scores(
            dataset_dict, accuracy, marker, filtered_channel
        )

        subject_accuracy_summary["subject"].extend(
            [subject_list[idx]] * len(all_feature_name)
        )
        subject_accuracy_summary["channel"].extend(all_feature_name)
        subject_accuracy_summary["cv_mean_score"].extend(mean_scores)
        subject_accuracy_summary["label_type"].extend([lt] * len(all_feature_name))

# The name is rebound from the dict of lists to a DataFrame here.
subject_accuracy_summary = pd.DataFrame(subject_accuracy_summary)
# Subject entries are the directory names collected while loading; the cast
# requires them to be numeric strings.
subject_accuracy_summary["subject"] = subject_accuracy_summary["subject"].astype(int)

In [None]:
# Stringify the names in the "channel" column; the plotting cell below matches
# them against string values in hue_order.
# NOTE(review): this mutates the frame created in the previous cell in place.
subject_accuracy_summary["channel"] = subject_accuracy_summary["channel"].astype(str)
# Written relative to the notebook's working directory; the row index is included.
subject_accuracy_summary.to_csv('CEBRA_D8_U256_EEG_D_spectral.csv')

In [None]:
# NOTE(review): `data` is a very generic global name and is rebound again in a
# later cell's loop — consider a more specific name.
data = subject_accuracy_summary
title = "EEG D channel - CEBRA"  #
# Layer 1: one point per subject, split by the "channel" column.
g = sns.swarmplot(
    data=data,
    x="label_type",
    y="cv_mean_score",
    hue="channel",
    alpha=0.6,
    dodge=True,
    legend=False,
)
g.set_ylim((0.35, 1))

# Layer 2: mean score per (label type, channel), drawn as horizontal markers.
df_means = (
    data.groupby(["label_type", "channel"])["cv_mean_score"].agg("mean").reset_index()
)
# NOTE(review): order / hue_order are fixed here but not on the swarmplot above;
# the two layers only line up if the data happens to appear in the same order.
# hue_order lists the stringified names written by the preceding cell.
pp = sns.pointplot(
    x="label_type",
    y="cv_mean_score",
    data=df_means,
    hue="channel",
    dodge=0.6,
    linestyles="",
    errorbar=None,
    scale=2.5,
    markers="_",
    hue_order=[
        "Feature.DELTA",
        "Feature.THETA",
        "Feature.ALPHA",
        "Feature.BETA1",
        "Feature.BETA2",
        "Feature.GAMMA",
    ],
    order=["valence", "arousal"],
)

sns.move_legend(pp, "upper right", bbox_to_anchor=(1.4, 1))
g.set_title(title)

In [None]:
# Table with (at least) the columns 'subject' and 'ari', which are read below.
aduio_ari = pd.read_csv('../aduio_ari.csv')
print('mean:',aduio_ari['ari'].mean(), 'max:', aduio_ari['ari'].max(), 'min:', aduio_ari['ari'].min())

# Split subjects at the mean ARI; only the group sizes are printed.
high_c = aduio_ari[aduio_ari['ari'] > aduio_ari['ari'].mean()]
low_c = aduio_ari[aduio_ari['ari'] <= aduio_ari['ari'].mean()]
print(len(high_c), len(low_c))

# Look up the ARI of every summary row's subject.
# `.values[0]` raises IndexError when a subject is missing from the ARI table.
ari_scores = []
for s in list(subject_accuracy_summary['subject']):
    score = aduio_ari[aduio_ari['subject'] == int(s)]['ari'].values[0]
    ari_scores.append(score)
# NOTE(review): adds a column in place to a frame defined in an earlier cell.
subject_accuracy_summary['ari_scores'] = ari_scores
subject_accuracy_summary.head()

In [None]:
# Mean CV score against ARI, one panel per "channel" value, coloured by label type.
# Requires the 'ari_scores' column added by the previous cell.
sns.relplot(
    data=subject_accuracy_summary,
    x="ari_scores", y="cv_mean_score", hue="label_type", col="channel",
)

### Run for single subject

#### Process labels

In [None]:
# Interactive choice of the label type read by the next code cell.
# NOTE(review): the result of that cell depends on this widget's state, which is
# not reproducible from code alone. get_label_category applies the valence
# threshold to every option except "arousal", so "hybrid" uses it too.
label_selector = widgets.Dropdown(
    options=["valence", "hybrid", "arousal"],
    value="valence",
    description="label_type:",
    disabled=False,
)
label_selector

### Leave the last three trials in each block as validation

In [None]:
from sklearn.metrics import confusion_matrix

label_type = label_selector.value
# Three hard-coded validation indexes per group of 13.
validation_list = [10, 11, 12, 23, 24, 25, 36, 37, 38, 49, 50, 51]
# NOTE(review): `valence_labels` and `attention_labels` are filled with one
# entry PER SUBJECT by the loading cell, so len(valence_labels) is the subject
# count, not a trial count — this cell looks written for an older, flat layout.
train_list = [i for i in range(len(valence_labels)) if i not in validation_list]

train_attention_labels = np.array(attention_labels)[train_list]
# NOTE(review): 52 * 4 is an unexplained constant — confirm its meaning against
# upsample_by_attention.
resampled_list = upsample_by_attention(train_attention_labels, 52 * 4)
# NOTE(review): `labels`, `v_thred` and `a_thred` are not set in this cell; they
# are whatever the last iteration of the subject-summary loop left behind.
train_labels = np.array(labels)[train_list][resampled_list]

validation_labels = np.array(labels)[validation_list]

train_true_cat = get_label_category(train_labels, label_type, v_thred, a_thred)
val_true_cat = get_label_category(validation_labels, label_type, v_thred, a_thred)

# channel -> feature name -> decoder score on the validation set
r2_score_data = {"A": {}, "B": {}, "C": {}, "D": {}}
# `channel_feature_to_data` comes from the "No slicing" cell.
for channel, feature_to_data in channel_feature_to_data.items():
    if len(feature_to_data) == 0:
        continue

    # Two rows of confusion matrices, half of the features on each row.
    nrows = 2
    ncols = int(len(feature_to_data) / 2)
    fig, axes = plt.subplots(
        nrows=nrows,
        sharey=True,
        ncols=ncols,
        figsize=(ncols * 5, nrows * 5),
    )

    idx = 0
    for f, neural_data in feature_to_data.items():
        train_data = neural_data[train_list][resampled_list]
        val_data = neural_data[validation_list]

        # loss_data[channel][f.name] = single_cebra_model.state_dict_["loss"]
        # NOTE(review): `use_pca` is not defined in any visible cell.
        embedding, val_embedding = get_embeddings(
            train_data=train_data,
            val_data=val_data,
            train_labels=train_labels,
            use_pca=use_pca,
            out_dim=OUTPUT_DIM,
            num_hidden_units=MAX_HIDDEN_UNITS,
        )
        # 4. Train the decoder on training embedding and labels
        decoder = cebra.KNNDecoder()
        decoder.fit(embedding, np.array(train_true_cat))

        # 5. Compute the score on validation embedding and labels
        score = decoder.score(val_embedding, np.array(val_true_cat))
        r2_score_data[channel][f.name] = score
        # 6. Get the discrete labels predictions
        prediction = decoder.predict(val_embedding)
        print(channel, f, score)
        # print('pre', prediction)
        # print('true', val_true_cat)
        # print('--------')
        cm = confusion_matrix(val_true_cat, prediction)
        subplot_confusion_matrix(
            ax=axes.flat[idx],
            cf=cm,
            categories=[f"low {label_type}", f"high {label_type}"],
            percent="by_row",
            vmin=0,
            vmax=1,
        )
        axes.flat[idx].set_title(f"{channel}:{f.name}")
        idx += 1

## Misc: Loss / grid search

In [None]:
# One panel per key of `loss_data`, sharing the y axis.
# NOTE(review): `loss_data` is not assigned in any visible cell (the only
# assignment is a commented-out line in the single-subject cell), so this cell
# fails on a fresh kernel.
nrows = 1
ncols = len(loss_data)
fig, axes = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    sharey=True,
    figsize=(ncols * 5, nrows * 5),
)

# `subj`, `marker` and `label_type` are globals set by earlier cells.
fig.suptitle(
    f"Subject {subj}: {marker} {label_type} InfoNCE loss in audio condition",
    fontsize=15,
)

for c, ax in zip(loss_data.keys(), axes.flatten()):
    # Each inner mapping is turned into a DataFrame and drawn as line plots.
    df = pd.DataFrame(loss_data[c])
    sns.lineplot(data=df, ax=ax)
    ax.set_title("channel:" + c)
    ax.set_ylabel("InfoNCE Loss")
    ax.set_xlabel("Steps")
# plt.savefig(f"results/cebra/{label_type}_{subj}_eeg_bands_channel_loss_{channel}_O{output_dim}H{max_hidden_units}.png")

In [None]:
# Flatten `loss_data` (channel -> band -> loss sequence) into long format:
# one row per (channel, band, step).
loss_dict = {"InfoNCE Loss": [], "band": [], "channel": [], "Steps": []}
for c, f_to_data in loss_data.items():
    for f, data in f_to_data.items():
        loss_dict["InfoNCE Loss"].extend(np.array(data))
        loss_dict["Steps"].extend(np.arange(0, len(data), dtype=int))
        loss_dict["band"].extend([f] * len(data))
        loss_dict["channel"].extend([c] * len(data))

# The name is rebound from the dict of lists to a DataFrame here.
loss_dict = pd.DataFrame(loss_dict)
loss_dict

In [None]:
# Loss per step: colour encodes the band, line style the channel.
sns.lineplot(data=loss_dict, y="InfoNCE Loss", x="Steps", hue="band", style="channel")

In [None]:
# 1. Define the parameters, either variable or fixed
# List values are the ones varied by the search; scalar values are fixed.
params_grid = dict(
    output_dimension=[6, 8],
    learning_rate=[3e-4],
    max_iterations=10,
    num_hidden_units=[32, 64, 128, 256],
    max_adapt_iterations=10,
    temperature_mode="auto",
    verbose=False,
)

# 2. Define the datasets to iterate over
# Only channel group "C", theta band, from the "No slicing" cell is searched.
datasets = {
    "neural_data": channel_feature_to_data["C"][Feature.THETA],
}

# 3. Create and fit the grid search to your data
# Models are written to the "saved_models" directory.
grid_search = cebra.grid_search.GridSearch()
grid_search = grid_search.fit_models(
    datasets=datasets, params=params_grid, models_dir="saved_models"
)

# 4. Get the results
df_results = grid_search.get_df_results(models_dir="saved_models")
# 5. Get the best model for a given dataset
best_model, best_model_name = grid_search.get_best_model(
    dataset_name="neural_data", models_dir="saved_models"
)