In [None]:
# General imports
import glob
import os.path
import numpy as np

from task2_regression.models.happyquokka import happyquokka
from task2_regression.models.adt_env import ADT
from task2_regression.models.vlaai import vlaai
from task2_regression.models.eegnet_env import EEGNet
from task2_regression.models.fcnn_env import FCNN
from task2_regression.models.linear import simple_linear_model
from scipy.stats import pearsonr

In [2]:
def window_data(data, window_length, hop):
    """Window data into overlapping windows.

    Only windows that fit completely inside ``data`` are returned; trailing
    samples that do not fill a whole window are dropped.

    Parameters
    ----------
    data: np.ndarray
        Data to window. Shape (n_samples, n_channels)
    window_length: int
        Length of the window in samples. Must be strictly positive.
    hop: int
        Hop size in samples. Must be strictly positive.

    Returns
    -------
    np.ndarray
        Windowed data. Shape (n_windows, window_length, n_channels), with
        ``n_windows = (n_samples - window_length) // hop + 1`` (0 when the
        data is shorter than a single window). The result is a new float
        array (``np.empty`` default dtype), not a view on ``data``.

    Raises
    ------
    ValueError
        If ``window_length`` or ``hop`` is not strictly positive.
    """
    if window_length <= 0 or hop <= 0:
        raise ValueError(
            "window_length and hop must be strictly positive, got "
            "window_length={} and hop={}".format(window_length, hop)
        )

    n_samples = data.shape[0]
    # The "+ 1" counts the window that starts at sample 0; without it the last
    # window that still fits in the data is silently discarded. The max()
    # guards against a negative count when the data is shorter than a window.
    n_windows = max(0, (n_samples - window_length) // hop + 1)

    new_data = np.empty((n_windows, window_length, data.shape[1]))
    for i in range(n_windows):
        start = i * hop
        new_data[i, :, :] = data[start : start + window_length, :]
    return new_data


In [14]:
# Instantiate the ADT network that is evaluated in this notebook.
adt = ADT(
    chans=64,
    outputDims=1,
    F=8,
    T=16,
    D=4,
    heads=4,
    ff_dim=128,
    blocks=4,
    mask=False,
    use_bias=False,
    lrate=0.5,
)
# Build the model for inputs shaped (batch, 320, 64) before restoring weights;
# 320 matches the window length used in the evaluation cell.
adt.build(input_shape=(None, 320, 64))
adt.load_weights('task2_regression/experiments/results_ertnet_env_nomask/model.h5')

# Alternative baselines, currently disabled.
# NOTE: as written, each assignment below would rebind the imported
# factory name (e.g. `happyquokka`) to the model instance.

# happyquokka = happyquokka(num_layers=4, embed_dim=64, num_heads=2, d_hid=256)
# happyquokka.build(input_shape=(None, 320, 64))
# happyquokka.load_weights('task2_regression/experiments/results_happyquokka_env/model.h5')

# vlaai = vlaai(output_dim=1)
# vlaai.build(input_shape=(None, 320, 64))
# vlaai.load_weights('task2_regression/experiments/results_vlaai_env/model.h5')

# eegnet = EEGNet(64, 320, 0.06)
# eegnet.build(input_shape=(None, 320, 64))
# eegnet.load_weights('task2_regression/experiments/results_eegnet_env/model.h5')

# fcnn = FCNN(320)
# fcnn.build(input_shape=(None, 320, 64))
# fcnn.load_weights('task2_regression/experiments/results_fcnn_env/model.h5')

# linear = simple_linear_model(integration_window = int(64*0.5), nb_filters=1)
# linear.build(input_shape=(None, 320, 64))
# linear.load_weights('task2_regression/experiments/results_linear_env/model.h5')

In [None]:
# Locate the evaluation recordings (one .npz file per trial)
paths = glob.glob("/workspace/auditory-eeg-challenge-2024/DTU_evaluations/DTU/*.npz")
print("Found {} paths for evaluation".format(len(paths)))

# A subject id is the first two underscore-separated tokens of the file name
subjects = {
    "_".join(os.path.basename(path).split("_")[:2])
    for path in paths
}
print("Found {} subjects for evaluation".format(len(subjects)))

In [7]:
# Maximum number of trials to evaluate per subject.
# - None: evaluate every trial found for the subject.
# - Positive integer N: stop after the first N trials of each subject,
#   which speeds up the evaluation cell below.
# Values below 1 never trigger the early stop, so all trials are evaluated.
nb_evaluation_trials = None

In [None]:
## Run the model evaluation
# subject_scores: subject id -> list of Pearson correlations, one per window
# boxplot_data: mean correlation per evaluated subject (input of the plot cell)
subject_scores = {}
boxplot_data = []

# Window of 320 samples with a hop of 64 samples (windows of 5 seconds with
# 80% overlap); 320 must match the input length the model was built with.
window_length = 320
hop = 64

# Iterate over the subjects in the DTU dataset.
# Sorted so that the printed report and boxplot_data have a reproducible order.
for subject in sorted(subjects):
    print("Evaluating subject {}".format(subject))

    # Reuse the files discovered in `paths` rather than globbing a second,
    # separately hard-coded directory: this guarantees the trials come from
    # the same location the subject ids were derived from. The subject id of
    # a file is computed exactly as when `subjects` was built. Sorting makes
    # "the first N trials" deterministic.
    subject_paths = sorted(
        p
        for p in paths
        if "_".join(os.path.basename(p).split("_")[:2]) == subject
    )
    if nb_evaluation_trials is not None and nb_evaluation_trials > 0:
        # Only evaluate the first nb_evaluation_trials trials of this subject
        subject_paths = subject_paths[:nb_evaluation_trials]

    subject_scores[subject] = []
    for p in subject_paths:
        print("Gathering scores for {}...".format(p))
        # Load the data
        # Data is stored in .npz format with two keys: 'eeg' and 'envelope'
        # containing preprocessed EEG and corresponding speech stimulus
        # envelope. The context manager closes the underlying file handle.
        with np.load(p) as data:
            eeg = data["eeg"]
            envelope = data["envelope"]

        if min(eeg.shape[0], envelope.shape[0]) < window_length:
            # Not even one full window fits in this recording
            print("Skipping {}: shorter than one window".format(p))
            continue

        # Standardize EEG and envelope
        eeg = (eeg - eeg.mean(axis=0, keepdims=True)) / eeg.std(
            axis=0, keepdims=True
        )
        envelope = (
            envelope - envelope.mean(axis=0, keepdims=True)
        ) / envelope.std(axis=0, keepdims=True)

        # Window the data in windows of 5 seconds with 80% overlap
        windowed_eeg = window_data(eeg, window_length, hop)
        windowed_envelope = window_data(envelope, window_length, hop)
        if windowed_eeg.shape[0] == 0 or windowed_envelope.shape[0] == 0:
            # Nothing to predict on; avoid calling the model with an empty batch
            print("Skipping {}: no windows extracted".format(p))
            continue

        # Evaluate the model on the overlapping windows: one Pearson
        # correlation between prediction and true envelope per window
        predictions = adt.predict(windowed_eeg)
        for pred, true in zip(predictions, windowed_envelope):
            r = pearsonr(pred.reshape(-1), true.reshape(-1))
            subject_scores[subject].append(r[0])

    if not subject_scores[subject]:
        # No usable trial: skip instead of averaging an empty list
        print("No scores gathered for subject {}, skipping".format(subject))
        continue

    # Report the mean score for each subject
    mean_scores = np.mean(subject_scores[subject])
    boxplot_data.append(mean_scores)
    print("Subject {}: {}".format(subject, mean_scores))

In [None]:
# Plot the results
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# One column holding the per-subject mean correlations gathered above
scores_frame = pd.DataFrame({"ADT network": boxplot_data})

# Draw the violin plot on an explicitly created 5x5 inch figure
fig, ax = plt.subplots(figsize=(5, 5))
sns.violinplot(data=scores_frame, orient="v", ax=ax)
ax.set_ylabel("Pearson correlation")
ax.set_xlabel("Models")
ax.set_title("Evaluation of the pre-trained ADT network on the DTU dataset")
ax.grid(True)
plt.show()

# Summarise the distribution with its median
median_score = np.median(boxplot_data)
print("Median score = {:.2f}".format(median_score))