In [None]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns

import cebra
from cebra import CEBRA
import seaborn as sns
from dataframe.csv_utils import (
    load_data_from_csv,
)

from labels import get_behavioral_labels

from constants import SUEJECT_BATCHES, SORTED_BLOCK_NAMES, V_COLOR_MAP, AUDIO_BLOCKS

In [None]:
# import importlib
# import labels
# importlib.reload(labels)

In [None]:
from data_utils import (
    load_data_from_dir,
)

# Root folder holding one sub-directory per subject id.
data_dir = '../CleandDataV2/'
subject_list = [2001, 2003, 2017, 2026, 2028, 2033, 2037, 2041]

# Map each subject id to whatever `load_data_from_dir` returns for its folder.
subj_to_data = {}
for subj in subject_list:
    dir_name = f"{data_dir}{subj}"
    all_data = load_data_from_dir(dir_name)
    subj_to_data[subj] = all_data

In [None]:
from features.psd import welch_bandpower
from feature_extraction import EEG_BANDS, Feature


def get_psd(trial_data, srate, band):
    """Return the PSD values of one trial that fall inside a frequency band.

    Args:
        trial_data: samples of a single trial, passed straight to
            `welch_bandpower`.
        srate: sampling rate, passed straight to `welch_bandpower`.
        band: `(low, high)` pair; both edges are inclusive.

    Returns:
        The entries of the PSD returned by `welch_bandpower` whose frequency
        lies in `[low, high]`.
    """
    band_low, band_high = band
    # NOTE(review): the trailing `None, 2` arguments are forwarded as-is; their
    # meaning is defined by features.psd.welch_bandpower — confirm there.
    frequencies, power = welch_bandpower(trial_data, srate, None, 2)

    # Boolean mask selecting the frequency bins inside the requested band.
    in_band = (frequencies >= band_low) & (frequencies <= band_high)

    return power[in_band]


def get_psd_by_channel(block_data, marker, channel_type: str, feature: Feature):
    """Build a per-trial PSD feature matrix for one channel group.

    For every trial (third axis of the marker's data) the band PSD of each
    channel whose name starts with `channel_type` is concatenated into one
    row; the rows of all trials are then stacked vertically.

    Args:
        block_data: object exposing `get_all_data()`, `get_chanlocs(marker)`
            and `get_srate(marker)`.
        marker: key selecting the signal inside `get_all_data()`.
        channel_type: channel-name prefix used to select channels.
        feature: key into `EEG_BANDS` giving the frequency band.

    Returns:
        The stacked rows (one per trial). With a single trial the row itself
        is returned; with no trials an empty list is returned.
    """
    recordings = block_data.get_all_data()[marker]
    trial_rows = []

    # One row per trial: time domain -> frequency domain.
    for trial in range(recordings.shape[2]):
        row = []
        for ch_idx, ch_name in enumerate(block_data.get_chanlocs(marker)):
            if ch_name.startswith(channel_type):
                channel_series = recordings[ch_idx]
                band_psd = get_psd(
                    channel_series[:, trial],
                    block_data.get_srate(marker),
                    EEG_BANDS[feature],
                )
                # Append this channel's PSD bins to the trial's feature row.
                if len(row) > 0:
                    row = np.hstack((row, band_psd))
                else:
                    row = band_psd

        if len(trial_rows) > 0:
            trial_rows = np.vstack((trial_rows, row))
        else:
            trial_rows = row

    return trial_rows


def get_time_series_data_by_channel(block_data, marker, channel_type: str):
    """Stack the raw time series of one channel group across all trials.

    For each trial (third axis of the marker's data) the samples of every
    channel whose name starts with `channel_type` are stacked channel-by-row,
    the first and last axes are swapped so time runs along the rows, and the
    per-trial matrices are concatenated vertically.

    Args:
        block_data: object exposing `get_all_data()` and `get_chanlocs(marker)`.
        marker: key selecting the signal inside `get_all_data()`.
        channel_type: channel-name prefix used to select channels.

    Returns:
        The vertically stacked per-trial matrices; an empty list when the data
        contains no trials.
    """
    recordings = block_data.get_all_data()[marker]
    n_trials = recordings.shape[2]
    stacked_trials = []

    for trial in range(n_trials):
        trial_matrix = []
        for ch_idx, ch_name in enumerate(block_data.get_chanlocs(marker)):
            if ch_name.startswith(channel_type):
                # recordings[ch_idx] is indexed as [sample, trial].
                samples = np.array(recordings[ch_idx][:, trial])
                if len(trial_matrix) > 0:
                    trial_matrix = np.vstack((trial_matrix, samples))
                else:
                    trial_matrix = samples

        # (channels, samples) -> (samples, channels)
        trial_matrix = np.swapaxes(trial_matrix, 0, -1)

        if len(stacked_trials) > 0:
            stacked_trials = np.vstack((stacked_trials, trial_matrix))
        else:
            stacked_trials = trial_matrix

    return stacked_trials


def get_block_features(blocks, subject_data, marker, channel, feature):
    """Stack the PSD feature rows of several blocks, in the order of `blocks`.

    The rows must follow the same block order as the label helpers
    (`get_raw_labels`, `get_categorical_labels`), which append labels block by
    block. The previous implementation prepended each new block, so features
    and labels were misaligned whenever more than one block was requested.

    Args:
        blocks: iterable of block names (keys of `subject_data`).
        subject_data: mapping from block name to block data object.
        marker: signal key forwarded to `get_psd_by_channel`.
        channel: channel-name prefix forwarded to `get_psd_by_channel`.
        feature: band key forwarded to `get_psd_by_channel`.

    Returns:
        The per-block feature matrices stacked vertically in block order; the
        single matrix unchanged when only one block yields data; an empty list
        when no block yields data.
    """
    per_block = []
    for b in blocks:
        psd_data = get_psd_by_channel(subject_data[b], marker, channel, feature)
        # Blocks without any trial contribute no rows (and no labels).
        if len(psd_data) > 0:
            per_block.append(psd_data)

    if not per_block:
        return []
    if len(per_block) == 1:
        return per_block[0]
    # Single stack at the end: keeps block order and avoids repeated copies.
    return np.vstack(per_block)


def get_categorical_labels(blocks, subject_data):
    """Collect one categorical behavioural label per trial across blocks.

    Args:
        blocks: iterable of block names (keys of `subject_data`).
        subject_data: mapping from block name to an object exposing
            `get_labels()` and `get_labels("arousal")`.

    Returns:
        A list with the result of `get_behavioral_labels(valence, arousal)`
        for every trial, blocks concatenated in the order given.
    """
    combined = []

    for name in blocks:
        block = subject_data[name]
        valence = block.get_labels()
        arousal = block.get_labels("arousal")

        # Pair the two ratings of each trial by position.
        combined += [
            get_behavioral_labels(valence[k], arousal[k])
            for k in range(len(valence))
        ]

    return combined


def get_raw_labels(blocks, subject_data):
    """Collect the flattened valence and arousal labels of several blocks.

    Args:
        blocks: iterable of block names (keys of `subject_data`).
        subject_data: mapping from block name to an object exposing
            `get_labels()` and `get_labels("arousal")`, both returning
            something with a `.flatten()` method.

    Returns:
        `(valence, arousal)`: two flat lists, blocks concatenated in the order
        given.
    """
    valence_all, arousal_all = [], []

    for name in blocks:
        block = subject_data[name]
        valence_all += list(block.get_labels().flatten())
        arousal_all += list(block.get_labels("arousal").flatten())

    return valence_all, arousal_all


def get_block_time_series_features(blocks, subject_data, marker, channel):
    """Stack the raw time series of several blocks and collect their labels.

    Data and labels are both accumulated in the order of `blocks`. The
    previous implementation prepended each new block's data while appending
    its labels, so the two were in opposite block order.

    Args:
        blocks: iterable of block names (keys of `subject_data`).
        subject_data: mapping from block name to block data object.
        marker: signal key forwarded to `get_time_series_data_by_channel`.
        channel: channel-name prefix forwarded to
            `get_time_series_data_by_channel`.

    Returns:
        `(raw_data, behavioral_labels)`: the per-block time-series matrices
        stacked vertically in block order (the single matrix unchanged when
        only one block yields data, an empty list when none does), and the list
        of per-trial categorical labels.
    """
    per_block = []
    behavioral_labels = []

    for b in blocks:
        block_data = subject_data[b]
        series = get_time_series_data_by_channel(block_data, marker, channel)

        v_label = block_data.get_labels()
        a_label = block_data.get_labels("arousal")
        # TODO: extend the labels to match the time-series dimension; there is
        # currently one label per trial but many data rows per trial.
        labels = [
            get_behavioral_labels(v_label[i], a_label[i]) for i in range(len(v_label))
        ]
        behavioral_labels.extend(labels)

        if len(series) > 0:
            per_block.append(series)

    if not per_block:
        raw_data = []
    elif len(per_block) == 1:
        raw_data = per_block[0]
    else:
        raw_data = np.vstack(per_block)

    return raw_data, behavioral_labels

In [None]:
# Notebook configuration: signal, subject and feature mode used by the cells below.
marker = "EEG"
# Available subject ids: [2001, 2003, 2017, 2026, 2028, 2033,  2037, 2041]
subj = 2033
# Per-block data of the selected subject (loaded in the data-loading cell above).
subject_data = subj_to_data[subj]
# True  -> extract raw time series per channel group;
# False -> extract band PSD features per channel group and the labels.
use_time_series = False

In [None]:
# Containers filled below: one entry per channel group ("A".."D").
channel_feature_to_time_series_data = {}
channel_feature_to_data = {"A": {}, "B": {}, "C": {}, "D": {}}

if not use_time_series:
    # Band-power features: one matrix per (channel group, EEG band).
    for c in channel_feature_to_data:
        for f in EEG_BANDS:
            raw_data = get_block_features(AUDIO_BLOCKS, subject_data, marker, c, f)
            channel_feature_to_data[c][f] = raw_data

    # Labels are only computed in this branch.
    valence_labels, arousal_labels = get_raw_labels(AUDIO_BLOCKS, subject_data)
    behavioral_labels = get_categorical_labels(AUDIO_BLOCKS, subject_data)
    print(valence_labels)
else:
    # Raw time series: one matrix per channel group; the labels are discarded.
    for c in channel_feature_to_data:
        channel_feature_to_time_series_data[c], _ = get_block_time_series_features(
            AUDIO_BLOCKS, subject_data, marker, c
        )

### Load EEG average power spectral features

In [None]:
# Load the pre-computed feature table from the "eeg_features2" directory.
# NOTE(review): `dir_name` is also used as a scratch variable by the raw-data
# loading cell; this rebinding overwrites it.
dir_name = "eeg_features2"
result = load_data_from_csv(dir_name)

In [None]:
subjects = result["Subject"].unique()

# Build the per-subject sequence of block names, each name repeated 13 times.
# NOTE(review): 13 is presumably the number of trials per block — confirm.
all_blocks = []
for b in SORTED_BLOCK_NAMES:
    all_blocks += [b] * 13

# Tag every row with its block name (the sequence is tiled once per subject),
# then keep only the rows that belong to audio blocks.
result["condition"] = all_blocks * len(subjects)
mask = result["condition"].isin(AUDIO_BLOCKS)
audio_only = result.loc[mask]
audio_only

In [None]:
# First subject only
# NOTE(review): subject 2000 is not in the raw-data `subject_list`; confirm it
# exists in the CSV table. This cell also rebinds `subject_data` and
# `behavioral_labels`, which earlier cells define from the raw recordings, so
# re-run the feature-extraction cell before using those names for raw data.
subject_data = audio_only[audio_only['Subject'].isin([2000])]
neural_data = subject_data.drop(columns=['Valence', 'Arousal', 'Attention', 'Subject', 'condition'])
v_label = subject_data['Valence']
a_label = subject_data['Arousal']

# Pair the ratings positionally. The filtered Series keep the row labels of the
# full table, so `v_label[i]` (label-based lookup) would raise KeyError or pick
# the wrong rows.
behavioral_labels = [
    get_behavioral_labels(v, a) for v, a in zip(v_label, a_label)
]
print(neural_data.shape, len(behavioral_labels))

### Model Training

In [None]:
def model_fit(neural_data, out_dim, num_hidden_units, behavioral_labels, max_iterations: int=10, max_adapt_iterations: int=10):
    """Create a CEBRA model with the given hyper-parameters and fit it.

    Args:
        neural_data: data passed as the first argument to `CEBRA.fit`.
        out_dim: value for CEBRA's `output_dimension`.
        num_hidden_units: value for CEBRA's `num_hidden_units`.
        behavioral_labels: labels passed as the second argument to `fit`, or
            `None` to fit on `neural_data` alone.
        max_iterations: value for CEBRA's `max_iterations`.
        max_adapt_iterations: value for CEBRA's `max_adapt_iterations`.

    Returns:
        The fitted CEBRA model.
    """
    # The default model architecture is used ("offset10-model" was tried before).
    model = CEBRA(
        batch_size=512,
        output_dimension=out_dim,
        max_iterations=max_iterations,
        num_hidden_units=num_hidden_units,
        max_adapt_iterations=max_adapt_iterations,
    )

    # Labels are only forwarded when provided.
    fit_inputs = (
        (neural_data,)
        if behavioral_labels is None
        else (neural_data, behavioral_labels)
    )
    model.fit(*fit_inputs)
    return model

### Visualize time embeddings

In [None]:
# One 3D panel per channel group that has time-series data.
# NOTE(review): `channel_feature_to_time_series_data` is only filled when
# `use_time_series` is True; with an empty dict `plt.subplots` raises.
nrows = 1
ncols = len(channel_feature_to_time_series_data)
fig, axes = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(ncols * 5, nrows * 5),
    subplot_kw=dict(projection="3d"),
    squeeze=False,  # always a 2D array, so `axes.flat` also works for one panel
)
fig.suptitle(
    f"Subject {subj}: {marker} time latent embedding in audio condition",
    fontsize=15,
)

output_dim = 16
max_hidden_units = 256

idx = 0
loss_data = {}
for channel, data in channel_feature_to_time_series_data.items():
    # Fit on this channel group's own time series. Previously the unrelated
    # global `neural_data` was fitted, so every panel showed the same data.
    single_cebra_model = model_fit(
        data, output_dim, max_hidden_units, None, 150, 150
    )
    loss_data[channel] = single_cebra_model.state_dict_["loss"]
    embedding = single_cebra_model.transform(data)

    ax = cebra.plot_embedding(
        embedding,
        embedding_labels='time',
        title=f"{channel} Channels",
        markersize=5,
        alpha=0.6,
        ax=axes.flat[idx],
    )
    idx += 1

#plt.savefig(f"results/cebra/cebra_{subj}_eeg_time_series_channel_{channel}_O{output_dim}H{max_hidden_units}.png")
# One column of loss values per channel group.
loss_data = pd.DataFrame(loss_data)


In [None]:
# Plot the loss curves collected in the time-embedding cell (one line per
# column of `loss_data`) and save the current figure.
# NOTE(review): the "results/cebra/" directory must already exist.
sns.lineplot(data=loss_data)
plt.savefig(f"results/cebra/cebra_{subj}_eeg_time_series_infoloss.png")

In [None]:
# Loss curve of the most recently fitted model only: `single_cebra_model` is
# whatever the last executed training loop left behind.
cebra.plot_loss(single_cebra_model)

### Visualize embedding

In [None]:
import ipywidgets as widgets

# Dropdown choosing which labels supervise the embedding in the next cell.
# NOTE(review): the next cell reads `label_selector.value`, so its result
# depends on interactive widget state; after "Restart & Run All" the default
# ('valence') is used.
label_selector = widgets.Dropdown(
    options=["valence", "hybrid", "arousal"],
    value='valence',
    description='label_type:',
    disabled=False,
)
label_selector

In [None]:
# Integer class index of each valence/arousal category, used as the training
# label for the "hybrid" label type. "lvha"/"lvla" previously mapped to colour
# strings, which turned the whole label array into strings.
IDX_MAP = {
    "hvha": 0,
    "hvla": 1,
    "nvha": 2,
    "nvla": 3,
    "lvha": 4,
    "lvla": 5,
}

label_type = label_selector.value
# Point colour of every sample, looked up from its categorical label.
cmap = [V_COLOR_MAP[l] for l in behavioral_labels]

if label_type == "valence":
    labels = valence_labels
elif label_type == "arousal":
    labels = arousal_labels
else:
    labels = [IDX_MAP[l] for l in behavioral_labels]

label_array = np.array(labels)
point_colors = np.array(cmap)

# Legend entries come from the same colour map as the points, restricted to the
# categories that actually occur, so legend and scatter colours cannot drift.
legend_handles = [
    mpatches.Patch(color=V_COLOR_MAP[name], label=name, alpha=0.6)
    for name in IDX_MAP
    if name in behavioral_labels
]

output_dim = 8
max_hidden_units = 256

loss_data = {"A": {}, "B": {}, "C": {}, "D": {}}
for channel, feature_to_data in channel_feature_to_data.items():
    if not feature_to_data:
        # No band features were extracted for this channel group.
        continue

    nrows = 2
    # Round up so an odd number of bands still gets enough panels.
    ncols = (len(feature_to_data) + 1) // 2
    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(ncols * 5, nrows * 5),
        subplot_kw=dict(projection="3d"),
        squeeze=False,
    )

    idx = 0
    for f, neural_data in feature_to_data.items():
        # Uses model_fit's default (short) iteration budget.
        single_cebra_model = model_fit(
            neural_data, output_dim, max_hidden_units, label_array
        )
        loss_data[channel][f.name] = single_cebra_model.state_dict_["loss"]

        # plot embedding
        embedding = single_cebra_model.transform(neural_data)
        ax = cebra.plot_embedding(
            embedding,
            embedding_labels=point_colors,
            title=f"{channel}:{f.name}",
            ax=axes.flat[idx],
            markersize=5,
            alpha=0.6,
        )
        idx += 1

    # Hide panels left over when the number of bands is odd.
    for spare_ax in axes.ravel()[idx:]:
        spare_ax.set_visible(False)

    plt.legend(
        handles=legend_handles,
        loc="lower right",
        bbox_to_anchor=(1.05, 1),
    )

    fig.suptitle(
        f"Subject {subj}: {marker} {label_type} latent embedding in audio condition",
        fontsize=15,
    )

    plt.savefig(f"results/cebra/{label_type}_{subj}_eeg_bands_channel_{channel}_O{output_dim}H{max_hidden_units}.png")

In [None]:
# One loss panel per channel group, sharing the y axis.
nrows, ncols = 1, len(loss_data)
fig, axes = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    sharey=True,
    figsize=(ncols * 5, nrows * 5),
)

fig.suptitle(
    f"Subject {subj}: {marker} {label_type} InfoNCE loss in audio condition",
    fontsize=15,
)

for c, ax in zip(loss_data, axes.ravel()):
    # One line per band of this channel group.
    df = pd.DataFrame(loss_data[c])
    sns.lineplot(data=df, ax=ax)
    ax.set(title='channel:' + c, ylabel='InfoNCE Loss', xlabel='Steps')

# NOTE(review): `channel` is the variable leaked from the training loop (its
# last value), although this figure covers every channel group.
plt.savefig(f"results/cebra/{label_type}_{subj}_eeg_bands_channel_loss_{channel}_O{output_dim}H{max_hidden_units}.png")

In [None]:
# Reshape the nested {channel: {band: losses}} mapping into long format:
# one row per (channel, band, step).
loss_dict = {'InfoNCE Loss': [], 'band': [], 'channel': [], 'Steps': []}
for c, f_to_data in loss_data.items():
    for f, data in f_to_data.items():
        loss_dict['InfoNCE Loss'] += list(np.array(data))
        loss_dict['Steps'] += list(np.arange(0, len(data), dtype=int))
        loss_dict['band'] += [f] * len(data)
        loss_dict['channel'] += [c] * len(data)

# The name is rebound from the dict of columns to the resulting DataFrame.
loss_dict = pd.DataFrame(loss_dict)
loss_dict

In [None]:
# All loss curves in one axes: colour encodes the band, line style the channel group.
sns.lineplot(data=loss_dict, y='InfoNCE Loss', x='Steps', hue='band', style='channel')

In [None]:
# 1. Define the parameters, either variable or fixed
params_grid = dict(
    output_dimension=[6, 8],
    learning_rate=[3e-4],
    max_iterations=10,
    num_hidden_units=[32, 64, 128, 256],
    max_adapt_iterations=10,
    temperature_mode="auto",
    verbose=False,
)

# 2. Define the datasets to iterate over
datasets = {
    "neural_data": channel_feature_to_data["C"][Feature.THETA],
}

# 3. Create and fit the grid search to your data
grid_search = cebra.grid_search.GridSearch()
grid_search = grid_search.fit_models(datasets=datasets, params=params_grid, models_dir="saved_models")

# 4. Get the results
df_results = grid_search.get_df_results(models_dir="saved_models")
# 5. Get the best model for a given dataset
best_model, best_model_name = grid_search.get_best_model(dataset_name="neural_data", models_dir="saved_models")

In [None]:
# Display the name of the best model returned by the grid search.
best_model_name

In [None]:
import matplotlib.pyplot as plt

# Two 3D views of the same embedding, each showing a different triple of latent
# dimensions selected with `idx_order`.
# NOTE(review): `embedding` and `cmap` are leftovers from the embedding cell
# (the last model fitted there), not the grid-search best model.
fig = plt.figure(figsize=(10,5))
ax1 = fig.add_subplot(121, projection="3d")
ax2 = fig.add_subplot(122, projection="3d")

ax1 = cebra.plot_embedding(embedding, embedding_labels=np.array(cmap), idx_order=(1,2,3), title="Latents: (1,2,3)", ax=ax1, markersize=5, alpha=0.6)
ax2 = cebra.plot_embedding(embedding, embedding_labels=np.array(cmap), idx_order=(4,5,6), title="Latents: (4,5,6)", ax=ax2, markersize=5, alpha=0.6)