In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle

from data_utils.data_utils import load_data_from_dir, AUDIO_BLOCKS
from features.labels import get_user_rating_raw_labels

In [None]:
from features.psd import get_psd_by_channel
from features.constants import EEG_BANDS

def get_block_features(
    blocks, subject_data, marker, channel, feature,
):
    """Stack one channel/feature's PSD data across several blocks.

    For every block name in ``blocks``, ``get_psd_by_channel`` is called on
    ``subject_data[block]`` and the per-block results are stacked row-wise
    with ``np.vstack``.

    NOTE(review): each new block is stacked *above* the rows gathered so far,
    so the returned rows are in reverse ``blocks`` order. Confirm this matches
    the order of the labels returned by ``get_user_rating_raw_labels``.

    Returns:
        The stacked array. A single block's PSD result is returned as-is
        (not passed through ``np.vstack``); an empty ``blocks`` yields ``[]``.
    """
    stacked = []
    for block_name in blocks:
        block_psd = get_psd_by_channel(
            subject_data[block_name], marker, channel, feature
        )
        if len(stacked) == 0:
            # Nothing accumulated yet (or only empty arrays): start from this block.
            stacked = block_psd
        else:
            stacked = np.vstack((block_psd, stacked))
    return stacked

def get_eeg_channel_feature_to_data(
    subject_data, block_list, feature_list,
):
    """Collect stacked EEG PSD features for every channel/feature pair.

    Args:
        subject_data: one subject's data, looked up by block name inside
            ``get_block_features``.
        block_list: block names whose features are stacked together.
        feature_list: iterable of features; each element must expose a
            ``.name`` attribute, used as the inner dict key (presumably enum
            members, since the caller passes ``EEG_BANDS.keys()`` — confirm).

    Returns:
        ``{channel: {feature.name: stacked_features}}`` for channels
        ``"A"``–``"D"``, always using the ``"EEG"`` marker.
    """
    return {
        channel: {
            feat.name: get_block_features(
                block_list, subject_data, "EEG", channel, feat
            )
            for feat in feature_list
        }
        for channel in ("A", "B", "C", "D")
    }

# Extract EEG spectral features

In [None]:
data_dir = "../data/"  # Replace with your own data dir

subject_list = []
valence_labels, arousal_labels, label_thresholds = [], [], []
marker_features = []
# Every sub-directory of data_dir is treated as one subject. os.listdir()
# returns entries in arbitrary order, so sort it to make the subject order
# (and therefore every per-subject list built below) reproducible.
for subject_id in sorted(os.listdir(data_dir)):
    # os.path.join works whether or not data_dir ends with a separator.
    dir_name = os.path.join(data_dir, subject_id)
    if not os.path.isdir(dir_name):
        continue

    subject_data = load_data_from_dir(dir_name)
    # channel -> feature name -> stacked PSD features over the audio blocks
    features = get_eeg_channel_feature_to_data(
        subject_data, AUDIO_BLOCKS, EEG_BANDS.keys(),
    )
    marker_features.append(features)
    subject_list.append(subject_id)
    # get user rating valence and arousal
    vl, arl, _ = get_user_rating_raw_labels(subject_data, AUDIO_BLOCKS)
    valence_labels.append(vl)
    arousal_labels.append(arl)
    # Per-subject mean ratings; later passed to decode_marker_data as the
    # valence/arousal thresholds.
    label_thresholds.append((np.mean(vl), np.mean(arl)))

Optionally, save the extracted features and labels to disk.

In [None]:
# NOTE(review): these pickles are written to `result/`, while the accuracy CSV
# in the "Save the results" section goes to `results/`; confirm which
# directory is intended.
os.makedirs("result", exist_ok=True)  # open(..., "wb") does not create missing directories

with open("result/eeg_features.pkl", "wb") as handle:
    pickle.dump(
        {"eeg_features": marker_features}, handle, protocol=pickle.HIGHEST_PROTOCOL
    )

with open("result/behavioral_labels.pkl", "wb") as handle:
    labels_payload = {
        "valence_labels": valence_labels,
        "arousal_labels": arousal_labels,
        "label_thresholds": label_thresholds,
    }
    pickle.dump(labels_payload, handle, protocol=pickle.HIGHEST_PROTOCOL)

# Model training with CEBRA + KNN

## Prepare the training and validation datasets

In [None]:
from training_utils.dataset import get_consecutive_validation_indexes
from training_utils.dataset import DatasetBuilder

n_trial_per_block = 13
n_step_trial = 3
n_trials = len(valence_labels[0])

# The validation indexes and `dataset_builder` are derived from the first
# subject's trial count and then reused for every subject in the
# cross-validation loop, so all subjects must have the same number of trials.
assert all(len(lbl) == n_trials for lbl in valence_labels + arousal_labels), (
    "subjects have different numbers of rated trials; shared splits would misalign"
)

# One validation group per start offset in range(1, n_trial_per_block, n_step_trial),
# i.e. offsets 1, 4, 7, 10.
# NOTE(review): the meaning of the literal `1` argument is defined by
# get_consecutive_validation_indexes — confirm and give it a named constant.
val_indexes = [
    get_consecutive_validation_indexes(
        n_trials, len(AUDIO_BLOCKS), 1, i, n_step_trial
    )
    for i in range(1, n_trial_per_block, n_step_trial)
]
print(len(val_indexes), val_indexes)

dataset_builder = DatasetBuilder(n_trials, val_indexes_group=val_indexes)
n_trials

In [None]:
from features.features_utils import prepare_eeg_data
    
def prepare_dataset(
    data,
    dataset_builder,
    labels,
    has_all_spectral: bool = False,
    filtered_channel: str = "",
):
    """Split every prepared EEG feature array into train/validation sets.

    ``prepare_eeg_data`` turns ``data`` into a two-level mapping (outer key ->
    feature name -> feature array); each array is split together with
    ``labels`` via ``dataset_builder.train_test_split``.

    Args:
        data: one subject's channel/feature data, passed to ``prepare_eeg_data``.
        dataset_builder: object providing ``train_test_split(features, labels)``.
        labels: per-trial labels shared by every feature array.
        has_all_spectral: forwarded to ``prepare_eeg_data``.
        filtered_channel: forwarded to ``prepare_eeg_data``.

    Returns:
        A mapping with the same two-level keys, whose leaves are the results
        of ``train_test_split``.
    """
    prepared = prepare_eeg_data(data, has_all_spectral, filtered_channel)
    return {
        outer_key: {
            feat_name: dataset_builder.train_test_split(feat_data, labels)
            for feat_name, feat_data in feature_to_data.items()
        }
        for outer_key, feature_to_data in prepared.items()
    }

## Run cross-validation

In [None]:
from training_utils.training import decode_marker_data, get_metadata

# Column-oriented accumulator, converted to a DataFrame after the loop.
summary_columns = {
    "subject": [],
    "channel": [],
    "feature": [],
    "label_type": [],
    "cv_scores": [],
    "cv_mean_score": [],
}

###CHANGE ME####
method = 'CEBRA'
filtered_channel = 'C'
combined_all_spectral = False
cebra_output_dim = 24
MAX_ITERATION  = 10
###############

subject_to_embedding = {s: {'valence': [], 'arousal': []} for s in subject_list}

for idx, subj in enumerate(subject_list):
    print('decoding subject...', subj)

    v_thred, a_thred = label_thresholds[idx]
    for lt in ["valence", "arousal"]:
        labels = valence_labels[idx] if lt == "valence" else arousal_labels[idx]
        thred = v_thred if lt == "valence" else a_thred

        dataset_dict = prepare_dataset(
            marker_features[idx],
            dataset_builder,
            labels,
            combined_all_spectral,
            filtered_channel,
        )

        subject_to_embedding[subj][lt], accuracy = decode_marker_data(
            dataset_dict, lt, v_thred, a_thred, method, cebra_output_dim, thred, MAX_ITERATION,
        )

        all_channels, all_feature_name, cv_scores = get_metadata(accuracy)
        n_rows = len(all_feature_name)

        summary_columns["subject"].extend([subj] * n_rows)
        summary_columns["channel"].extend(all_channels)
        summary_columns["feature"].extend(all_feature_name)
        summary_columns["label_type"].extend([lt] * n_rows)
        summary_columns["cv_scores"].extend(cv_scores)
        # One mean per row, so this column stays the same length as the
        # others. A single mean over all of cv_scores only lines up when
        # there is exactly one (channel, feature) row; otherwise
        # pd.DataFrame raises on unequal column lengths.
        summary_columns["cv_mean_score"].extend(
            round(np.mean(scores), 2) for scores in cv_scores
        )

subject_accuracy_summary = pd.DataFrame(summary_columns)
# Subject ids are the data sub-directory names; this assumes they are numeric.
subject_accuracy_summary["subject"] = subject_accuracy_summary["subject"].astype(int)

## Save the results

In [None]:
# Mean of the per-row CV mean scores, for each rating type.
for rating in ("valence", "arousal"):
    rating_rows = subject_accuracy_summary.label_type == rating
    print(rating, subject_accuracy_summary.loc[rating_rows, "cv_mean_score"].mean())

subject_accuracy_summary["channel"] = subject_accuracy_summary["channel"].astype(str)
os.makedirs("results", exist_ok=True)  # to_csv does not create missing directories
subject_accuracy_summary.to_csv(f'results/{method}_eeg.csv')

subject_accuracy_summary.head()

## Evaluate your results

In [None]:
# Per-subject mean CV scores (dots) with the per-channel average (bars),
# plus a dashed line at 0.5.
label_order = ["valence", "arousal"]
data = subject_accuracy_summary

fig, ax = plt.subplots()
# An explicit `order` keeps the x categories (and their tick labels) in the
# same order as the pointplot below, independent of row order in `data`.
sns.swarmplot(
    data=data,
    x="label_type",
    y="cv_mean_score",
    hue="channel",
    order=label_order,
    alpha=0.6,
    dodge=True,
    legend=False,
    ax=ax,
)
ax.set_ylim((0.2, 1))
# Enlarge tick labels without replacing them with fixed strings (which
# would not follow the tick locator).
ax.tick_params(axis="both", labelsize=15)

df_means = (
    data.groupby(["label_type", "channel"])["cv_mean_score"].agg("mean").reset_index()
)

sns.pointplot(
    x="label_type",
    y="cv_mean_score",
    data=df_means,
    hue="channel",
    dodge=0.6,
    linestyles="",
    scale=2.5,
    markers="_",
    order=label_order,
    ax=ax,
)
sns.despine(ax=ax, bottom=True, left=True)
ax.axhline(0.5, color="red", dashes=(2, 2))
ax.set(xlabel="label type", ylabel="mean CV score");


### Plot embeddings

In [None]:
colorMap = {'valence': '#94b325', 'arousal': '#595eeb'}
label_type = 'arousal'
# For each subject: index of the embedding dimension whose values have the
# largest |Pearson r| with the labels.
best_embedding_idx = []

for s in subject_list:
    # Each entry unpacks as (name, max_idx, (embeddings, val_embeddings),
    # (embedding_labels, val_embedding_labels)); only the non-validation
    # embedding and labels are used here.
    (
        _,
        _,
        (embeddings, _),
        (embedding_labels, _),
    ) = subject_to_embedding[s][label_type][0]
    abs_corr = np.array([
        np.abs(np.corrcoef(embeddings[:, dim], embedding_labels)[0, 1])
        for dim in range(embeddings.shape[-1])
    ])

    if np.isnan(abs_corr).all():
        # No usable correlation at all (e.g. constant labels): fall back to
        # dimension 0, which is also what plain argmax gives for all-NaN input.
        best_dim = 0
    else:
        # A constant embedding dimension yields r = NaN, and plain argmax would
        # select the first NaN; ignore NaNs when picking the strongest dimension.
        best_dim = int(np.nanargmax(abs_corr))
    best_embedding_idx.append(best_dim)

# One regression panel per subject: best embedding dimension vs. labels.
assert len(best_embedding_idx) == len(subject_list), "re-run the embedding-selection cell"

n_col = 8
# Enough rows for every subject (a fixed 5x8 grid raises IndexError with fewer
# than 40 subjects and silently drops any beyond 40).
n_row = max(1, (len(subject_list) + n_col - 1) // n_col)
fig, axarr = plt.subplots(
    n_row, n_col, figsize=(3 * n_col, 3 * n_row), sharey=True, squeeze=False
)
axes_flat = axarr.ravel()

for ax, s, best_dim in zip(axes_flat, subject_list, best_embedding_idx):
    dim_col = f'L{best_dim}'
    (
        name,
        _,
        (embeddings, _),
        (embedding_labels, _),
    ) = subject_to_embedding[s][label_type][0]

    result = pd.DataFrame({dim_col: embeddings[:, best_dim], 'labels': embedding_labels})
    sns.regplot(
        data=result,
        ci=99,
        x=dim_col,
        y='labels',
        color=colorMap[label_type],
        line_kws=dict(color="r"),
        ax=ax,
    )
    corr = np.corrcoef(result[dim_col], embedding_labels)[0, 1]
    ax.text(
        0.1,
        0.95,
        "$r$ = {:.3f}".format(corr),
        horizontalalignment="left",
        verticalalignment="center",
        color='red',
        fontweight='heavy',
        transform=ax.transAxes,
        size=12,
    )
    ax.set_title(f'{s}:{name}')
    ax.set(ylim=(0, 1))

# Hide the leftover panels in the last row.
for ax in axes_flat[len(subject_list):]:
    ax.set_visible(False)

fig.tight_layout(pad=1.8)