In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from scipy import stats

import mne
import xarray as xr


In [None]:
# Load the per-subject AEP dataset from the derivatives folder.
evoked_dir = Path(".").resolve().parent / "derivatives" / "evoked"
# Alternative dataset that was also tried: evoked_dir / "v2_online-reference" / "aep_evoked.nc"
ds_fpath = evoked_dir / "aep_evoked.nc"
ds = xr.open_dataset(ds_fpath)

# Central electrode cluster, averaged over channels into one trace per subject.
# (A single-electrode variant, ["E112"], was also tried.)
channels = [
    "E7", "E106", "E13", "E6", "E112",
    "E31", "E80", "E37", "E55", "E87",
]
central_cluster = ds.to_array().sel(channel=channels)
df = central_cluster.mean("channel").squeeze().to_pandas()

In [None]:
from scipy.stats import zscore

# Flag subjects whose mean absolute amplitude (averaged over time columns)
# lies more than 3 standard deviations above the group mean.
signal_means = df.abs().mean(axis=1)
# Re-wrap in a Series keyed by subject: depending on the SciPy version, zscore
# may return a plain ndarray for pandas input, which has no .index to read
# the subject IDs from.
is_outlier = pd.Series(zscore(signal_means), index=signal_means.index) > 3
indices_to_drop = is_outlier[is_outlier].index  # indices are subject ids
indices_to_drop

In [None]:
# Long format: one row per (subject, time) sample, amplitude converted to µV.
df_long = (
    df.drop(indices_to_drop)
    .reset_index()
    .melt(id_vars=["subject"], value_name="ERP")
)
df_long["ERP"] = df_long["ERP"] * 1e6  # volts -> microvolts
# df_long["time"] = df_long["time"] * 1000  # seconds -> milliseconds (disabled)

# Label each row with an age group derived from the session tag in the subject
# id; subjects matching neither tag are left out of the plot.
session_ages = {"ses-06": "6M", "ses-12": "12M"}
df_long_by_age = pd.concat(
    [
        df_long.loc[df_long["subject"].str.contains(tag)].assign(age=label)
        for tag, label in session_ages.items()
    ],
    axis=0,
)

sns.set(style="darkgrid")
palette = sns.color_palette()

fig, ax = plt.subplots(constrained_layout=True)
sns.lineplot(
    data=df_long_by_age,
    x="time",
    y="ERP",
    hue="age",
    palette=palette[:2],
    linewidth=0.5,
    ax=ax,
)
ax.legend()
ax.set_title("Auditory Evoked Potentials (Central Electrodes)")
ax.set_xlabel("Time (s)")
ax.set_ylabel("Amplitude (µV)")
ax.axvline(0, color=palette[3], linestyle="--")  # marks t = 0
plt.show()


In [None]:
# ds_fpath = Path(".").resolve().parent / "derivatives" / "evoked" / "v2_online-reference" / "aep_evoked.nc"
# ds = xr.open_dataset(ds_fpath)

In [None]:
def trim_mean(x, axis, proportiontocut=0.1):
    """Trimmed mean of ``x`` along ``axis``.

    Thin adapter so ``scipy.stats.trim_mean`` can be handed to xarray's
    ``reduce``, which supplies the reduction axis as ``axis``.

    Parameters
    ----------
    x : array_like
        Values to average.
    axis : int or None
        Axis to reduce over (``None`` flattens first).
    proportiontocut : float, optional
        Fraction cut from *each* end of the sorted values before averaging.
        Defaults to 0.1, i.e. a 10% trimmed mean (the previous fixed value).

    Returns
    -------
    numpy.ndarray or float
        The trimmed mean with ``axis`` removed.
    """
    return stats.trim_mean(x, proportiontocut, axis=axis)
# Trimmed mean across subjects for every variable in the dataset.
subject_trimmed_mean = ds.reduce(trim_mean, dim="subject")
subject_trimmed_mean

In [None]:
# Keep only channels without missing values; reused for names, counts and data.
ds_clean = ds.dropna("channel")

montage = mne.channels.make_standard_montage("GSN-HydroCel-128")
info = mne.create_info(
    ch_names=ds_clean.channel.values.tolist(),
    sfreq=500,
    ch_types="eeg",
)

nave = ds_clean.subject.size
# Trimmed mean over subjects, collapsed to a (channel, time) array.
data = ds_clean.reduce(trim_mean, dim="subject").to_array().squeeze().values

evoked = mne.EvokedArray(data=data, info=info, tmin=-0.2, nave=nave)
evoked = evoked.set_montage(montage, match_alias=True)
evoked = evoked.set_eeg_reference("average")
evoked

In [None]:
# plot_joint already displays the figure; the trailing semicolon stops Jupyter
# from rendering the returned Figure a second time.
evoked.plot_joint();

In [None]:
import glob


def trim_mean(x, axis):
    """10%-trimmed mean of ``x`` along ``axis`` (adapter for xarray ``reduce``).

    Redefined here, presumably so this file-based section can run without the
    ``.nc`` cells above.
    """
    return stats.trim_mean(x, 0.1, axis=axis)


session = "06"

derivatives_dir = Path(".").resolve().parent / "derivatives"

# Per-subject evoked files: derivatives/evoked/sub-*/ses-<session>/*_evoked.fif
glob_pattern = derivatives_dir / "evoked" / "sub-*" / f"ses-{session}" / "*_evoked.fif"
# glob order is filesystem-dependent; sort so the subject axis is reproducible.
evoked_files = sorted(glob.glob(str(glob_pattern)))
if not evoked_files:
    # xr.concat on an empty list would fail with a far less helpful message.
    raise FileNotFoundError(f"No evoked files match {glob_pattern}")

subject_arrays = []
for fpath in evoked_files:
    # Subject ID from the sub-* directory (two levels up), rather than a
    # fixed-width filename slice that only works for 11-character IDs.
    subject = Path(fpath).parent.parent.name
    ev = mne.read_evokeds(fpath)[0]
    if "Vertex Reference" in ev.ch_names:
        ev.rename_channels({"Vertex Reference": "VREF"})
    ev.interpolate_bads()
    # Apply the average reference to the data itself. With projection=True it
    # is only added as an inactive projector and get_data() stays unreferenced.
    ev.set_eeg_reference("average")
    subject_arrays.append(
        xr.DataArray(
            data=[ev.get_data()],
            dims=["subject", "channel", "time"],
            coords={
                "subject": [subject],
                "channel": ev.info["ch_names"],
                "time": ev.times,
            },
        )
    )
evokeds = xr.concat(subject_arrays, dim="subject")


In [None]:
# Subjects excluded from the 6-month grand average.
# NOTE(review): "sub-7114" lacks the site prefix (PHI/STL) that every other ID
# here has — confirm it is not a typo.
EXCLUDE = [
    "sub-PHI7111",
    "sub-7114",
    "sub-PHI7118",
    "sub-PHI7122",
    "sub-PHI7123",
    "sub-PHI7147",
    "sub-PHI7158",
    "sub-PHI7160",
    "sub-STL7068",
    ]

included = evokeds.sel(subject=~evokeds.subject.isin(EXCLUDE))
# nave must count the subjects that actually enter the average, so derive it
# from the data instead of hard-coding it.
nave = included.subject.size
data = included.reduce(trim_mean, dim="subject").values.squeeze()  # channel, time

montage = mne.channels.make_standard_montage("GSN-HydroCel-129")
info = mne.create_info(
    ch_names=evokeds.channel.values.tolist(),
    sfreq=500,  # NOTE(review): assumed to match the recordings' sampling rate
    ch_types="eeg",
)
grand_average = mne.EvokedArray(
    data=data,
    info=info,
    tmin=float(evokeds.time.values[0]),  # first sample time of the loaded data
    nave=nave,
).set_montage(montage, match_alias=True).set_eeg_reference("average")
fig = grand_average.plot_joint(show=False)
# NOTE(review): assumes fig.axes[3] is the panel that should carry the amplitude
# label; plot_joint's axes order depends on how many topomaps it draws — confirm.
fig.axes[3].set_ylabel("Amplitude (µV)")
fig.suptitle(f"High Density Grand Average AEP, {session} Months")
fig.show()
fig.savefig(f"./grand_average_{session}_months.png")

In [None]:
session = "12"

derivatives_dir = Path(".").resolve().parent / "derivatives"

# Per-subject evoked files: derivatives/evoked/sub-*/ses-<session>/*_evoked.fif
glob_pattern = derivatives_dir / "evoked" / "sub-*" / f"ses-{session}" / "*_evoked.fif"
# glob order is filesystem-dependent; sort so the subject axis is reproducible.
evoked_files = sorted(glob.glob(str(glob_pattern)))
if not evoked_files:
    # xr.concat on an empty list would fail with a far less helpful message.
    raise FileNotFoundError(f"No evoked files match {glob_pattern}")

subject_arrays = []
for fpath in evoked_files:
    # Subject ID from the sub-* directory (two levels up), rather than a
    # fixed-width filename slice that only works for 11-character IDs.
    subject = Path(fpath).parent.parent.name
    ev = mne.read_evokeds(fpath)[0]
    if "Vertex Reference" in ev.ch_names:
        ev.rename_channels({"Vertex Reference": "VREF"})
    ev.interpolate_bads()
    # Apply the average reference to the data itself. With projection=True it
    # is only added as an inactive projector and get_data() stays unreferenced.
    ev.set_eeg_reference("average")
    subject_arrays.append(
        xr.DataArray(
            data=[ev.get_data()],
            dims=["subject", "channel", "time"],
            coords={
                "subject": [subject],
                "channel": ev.info["ch_names"],
                "time": ev.times,
            },
        )
    )
evokeds = xr.concat(subject_arrays, dim="subject")


In [None]:
# NOTE(review): unlike the 6-month cell, no subject exclusion list is applied
# here — confirm that is intended.
# nave must count the subjects that actually enter the average, so derive it
# from the data instead of hard-coding it.
nave = evokeds.subject.size
data = evokeds.reduce(trim_mean, dim="subject").values.squeeze()  # channel, time

montage = mne.channels.make_standard_montage("GSN-HydroCel-129")
info = mne.create_info(
    ch_names=evokeds.channel.values.tolist(),
    sfreq=500,  # NOTE(review): assumed to match the recordings' sampling rate
    ch_types="eeg",
)
grand_average = mne.EvokedArray(
    data=data,
    info=info,
    tmin=float(evokeds.time.values[0]),  # first sample time of the loaded data
    nave=nave,
).set_montage(montage, match_alias=True).set_eeg_reference("average")
fig = grand_average.plot_joint(show=False)
# NOTE(review): assumes fig.axes[3] is the panel that should carry the amplitude
# label; plot_joint's axes order depends on how many topomaps it draws — confirm.
fig.axes[3].set_ylabel("Amplitude (µV)")
fig.suptitle(f"High Density Grand Average AEP, {session} Months")
fig.show()
fig.savefig(f"./grand_average_{session}_months.png")