In [None]:
import sys
from pathlib import Path

import mne
import numpy as np

import kinnd

In [None]:
# Fail fast with a readable message if the lab server is not mounted; the subject
# file paths below are presumably resolved on that share.
assert kinnd.utils.paths.lab_server_is_mounted(), "Lab server is not mounted; mount it and re-run the notebook."

fpaths = kinnd.utils.paths.get_semantics_fpaths()

In [None]:
from collections import defaultdict
from tqdm.notebook import tqdm

# Read every subject's epochs up front; later cells look them up by subject key.
semantics_epochs = {
    subject: kinnd.studies.semantics.read_epochs_semantics(subject, verbose="WARNING")
    for subject in tqdm(fpaths, total=len(fpaths))
}

In [None]:
import xarray as xr

def epochs_to_xarray(ep):
    """Convert an MNE Epochs object to an xarray Dataset.

    Each condition in ``ep.event_id`` that has at least one epoch becomes one
    data variable with dims ``("epoch", "channel", "time")``. The ``epoch``
    coordinate is the ``selection`` of that condition's subset. Conditions
    with no epochs are skipped.

    NOTE(review): conditions carry different ``epoch`` coordinates, so the
    Dataset constructor aligns them on that dimension (presumably padding the
    missing epochs with NaN). Confirm this is the intended on-disk layout.

    Parameters
    ----------
    ep : mne.Epochs
        Epochs to convert.

    Returns
    -------
    xarray.Dataset
        One variable per non-empty condition.
    """
    cond_xrs = {}
    for condition in ep.event_id:
        # Index once: every ``ep[condition]`` call builds a new Epochs object.
        cond_ep = ep[condition]
        if not len(cond_ep):
            continue
        cond_xrs[condition] = xr.DataArray(
            cond_ep.get_data(),
            coords={
                "epoch": cond_ep.selection,
                "channel": ep.ch_names,
                "time": ep.times,
            },
            dims=("epoch", "channel", "time"),
        )
    return xr.Dataset(cond_xrs)


semantics_xr = {subject: epochs_to_xarray(ep) for subject, ep in semantics_epochs.items()}


In [None]:
derivatives_dir = kinnd.utils.paths.semantics_path() / "derivatives"
out_dir = derivatives_dir / "epochs"
assert out_dir.exists()

# One NetCDF file per subject: <out_dir>/<subject>_epochs.netcdf
for subject in tqdm(semantics_xr, total=len(semantics_xr)):
    semantics_xr[subject].to_netcdf(out_dir / f"{subject}_epochs.netcdf")

In [None]:
# Baseline-corrected evoked responses per subject for both conditions, stacked
# along a "subject" dimension into one Dataset.
evoked_per_subject = []
for subject, ep in semantics_epochs.items():
    condition_xrs = {}
    for condition in ["match", "mismatch"]:
        # NOTE: relies on ``ep[condition]`` returning a new Epochs object (MNE
        # indexing copies), so this baseline does not modify ``ep`` in place.
        ev = ep[condition].apply_baseline((None, 0)).average()
        assert ev.data.shape == (len(ev.ch_names), len(ev.times))
        condition_xrs[condition] = xr.DataArray(
            [ev.data],
            coords={
                "subject": [subject],
                "channel": ev.ch_names,
                "time": ev.times,
            },
            dims=("subject", "channel", "time"),
        )
    evoked_per_subject.append(xr.Dataset(condition_xrs))

semantics_evoked_xr = xr.concat(evoked_per_subject, dim="subject")

# Create derivatives/evoked if needed before writing.
evoked_dir = out_dir.parent / "evoked"
evoked_dir.mkdir(parents=True, exist_ok=True)
semantics_evoked_xr.to_netcdf(evoked_dir / "evoked.netcdf")

In [None]:
import pandas as pd

def _normalize_subject_id(raw_id):
    """Convert a spreadsheet ID such as ``"s1705"`` into a label such as ``"sub-05"``.

    Every occurrence of ``"s17"`` is removed, the remainder is parsed as an
    integer, and the result is zero-padded to at least two digits.
    """
    number = int(raw_id.replace("s17", ""))
    return "sub-" + str(number).zfill(2)


def get_group_assignment(participants_df=None, exclude=("sub-23",)):
    """Return a DataFrame containing the group assignment for each semantics subject.

    Parameters
    ----------
    participants_df : pandas.DataFrame, optional
        Participants sheet with one column of subject IDs per group
        (``"MV ASD"``, ``"V ASD"``, ``"TD"``). When omitted, it is read from
        ``ERSP_Particpants.xlsx`` on the lab server.
    exclude : sequence of str
        Normalized subject labels to drop. Defaults to ``("sub-23",)``.
        NOTE(review): the reason sub-23 is excluded is not recorded; document it.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``subject`` (``"sub-XX"``), with a single ``group`` column whose
        values are ``"ASD-Nonverbal"``, ``"ASD-Verbal"`` or ``"TD"``.
    """
    if participants_df is None:
        fname = kinnd.utils.paths.lab_server_path() / "charlotte_semantics_data" / "ERSP_Particpants.xlsx"
        participants_df = pd.read_excel(fname)

    # Spreadsheet column -> group label used throughout the notebook.
    column_to_group = {"MV ASD": "ASD-Nonverbal", "V ASD": "ASD-Verbal", "TD": "TD"}
    per_group = [
        participants_df[column].dropna().to_frame(name="subject").assign(group=group)
        for column, group in column_to_group.items()
    ]
    df = pd.concat(per_group, ignore_index=True)
    df["subject"] = df["subject"].map(_normalize_subject_id)
    df = df[~df["subject"].isin(exclude)]
    return df.set_index("subject")

In [None]:
import seaborn as sns
sns.set_style("darkgrid")

ds = semantics_evoked_xr
ROI = ["Cz", "Fz"]

# ROI-averaged evoked responses from 1.0 s onward, reshaped to long format
# (one row per subject x time x condition) and labelled with each subject's group.
roi_evoked = ds[["match", "mismatch"]].sel(channel=ROI, time=slice(1.0, None)).mean("channel")
ds_long = (
    roi_evoked.to_dataframe()
    .reset_index()
    .melt(
        id_vars=["subject", "time"],
        value_vars=["match", "mismatch"],
        var_name="condition",
        value_name="amplitude",
    )
    .set_index("subject")
    # Load the group table explicitly instead of relying on a `df` global, which
    # is undefined here on a fresh kernel and later rebound to unrelated frames.
    .join(get_group_assignment())
)

ax = sns.lineplot(
    data=ds_long.loc[ds_long.group == "TD"],
    x="time",
    y="amplitude",
    hue="condition",
    errorbar="sd",
)
ax.set(
    title="TD group: evoked response at Cz/Fz (mean +/- SD across subjects)",
    xlabel="Time (s)",
    ylabel="Amplitude",
);

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Scalp topography of mismatch - match, averaged over subjects and over the
# window from 1.0 s to the end of the epoch.
mon = mne.channels.make_standard_montage("standard_1020")
info = mne.create_info(
    ch_names=ds["match"].channel.values.tolist(),
    sfreq=1000,  # required by create_info; a topomap only uses channel positions
    ch_types="eeg",
)
info.set_montage(mon)

window = slice(1.0, None)
match = ds["match"].sel(time=window).mean(["subject", "time"])
mismatch = ds["mismatch"].sel(time=window).mean(["subject", "time"])
diff = mismatch - match

# Style rcParams are read when axes are created, so set the style before subplots.
sns.set_style("white")
fig, ax = plt.subplots(constrained_layout=True, figsize=(5, 5))
im, _ = mne.viz.plot_topomap(
    diff.data,
    info,
    axes=ax,
    names=diff.channel.values,
    show=False,
)
ax.set_title("Warmer colors indicate greater activity in the mismatch condition.")
title = "EEG activity difference in the mismatch condition\n"
title += "relative to the match condition."
fig.suptitle(title, fontsize=14)
# Use the image returned by plot_topomap rather than digging it out of ax.images.
fig.colorbar(im, ax=ax)
plt.show()

### Time-frequency representations (TFRs) for all participants

In [None]:
# Multitaper TFR settings: 2-35 Hz in 1 Hz steps. With n_cycles equal to the
# frequency, each taper window spans n_cycles / freq = 1 s.
freqs = np.arange(2, 36)
tfr_kwargs = dict(
    method="multitaper",
    freqs=freqs,
    n_cycles=freqs,
    use_fft=True,
)

tfr_per_subject = []
for subject, ep in tqdm(semantics_epochs.items(), total=len(semantics_epochs), desc="TFR"):
    # Baseline-correct a copy so `semantics_epochs` is left untouched.
    this_ep = ep.copy().apply_baseline((None, 0))
    tfr = this_ep.compute_tfr(**tfr_kwargs)
    condition_xrs = {}
    for condition in ["match", "mismatch"]:
        # Average the per-epoch TFR over this condition's epochs.
        ev = tfr[condition].average()
        assert ev.data.shape == (len(ev.ch_names), len(ev.freqs), len(ev.times))
        condition_xrs[condition] = xr.DataArray(
            [ev.data],
            coords={
                "subject": [subject],
                "channel": ev.ch_names,
                "freq": ev.freqs,
                "time": ev.times,
            },
            dims=("subject", "channel", "freq", "time"),
        )
    tfr_per_subject.append(xr.Dataset(condition_xrs))

tfr_ds = xr.concat(tfr_per_subject, dim="subject")

# Create derivatives/tfr if needed before writing.
tfr_dir = out_dir.parent / "tfr"
tfr_dir.mkdir(parents=True, exist_ok=True)
tfr_ds.to_netcdf(tfr_dir / "tfr.netcdf")


In [None]:
def add_band_col(df, bands=("delta", "theta", "alpha", "beta", "gamma")):
    """Label each row's frequency with its EEG band and keep only the requested bands.

    Bins are right-inclusive: delta (0, 3], theta (3, 7], alpha (7, 13],
    beta (13, 30], gamma (30, 45] Hz. Rows whose ``freq`` falls outside every
    bin, or whose band is not in ``bands``, are dropped from the result.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a numeric ``freq`` column. For backward compatibility with
        callers that ignore the return value, a ``band`` column is also written
        onto ``df`` itself.
    bands : sequence of str
        Band labels to keep. Defaults to all five bands.

    Returns
    -------
    pandas.DataFrame
        A new frame holding only rows in ``bands``, with a categorical ``band``
        column whose unused categories have been removed.
    """
    freq_bounds = {"_": 0, "delta": 3, "theta": 7, "alpha": 13, "beta": 30, "gamma": 45}
    df["band"] = pd.cut(
        df["freq"],
        bins=list(freq_bounds.values()),
        labels=list(freq_bounds)[1:],
    )
    # .copy() makes the filtered rows an independent frame so the assignment below
    # does not trigger pandas' SettingWithCopyWarning.
    kept = df[df["band"].isin(bands)].copy()
    kept["band"] = kept["band"].cat.remove_unused_categories()
    return kept

def xr_to_df(xr, subject, channels=("Cz", "Fz")):
    """Flatten one subject's TFR Dataset into a long-format DataFrame.

    Parameters
    ----------
    xr : xarray.Dataset
        Dataset with ``match`` and ``mismatch`` variables over ``subject``,
        ``channel``, ``freq`` and ``time``. (The name shadows the ``xarray``
        module inside this function; kept for backward compatibility.)
    subject : str
        Subject label to select.
    channels : str or sequence of str
        Channel(s) to average together. A single name may be passed as a string.

    Returns
    -------
    pandas.DataFrame
        Columns ``time``, ``freq``, ``condition``, ``value`` and ``band``
        (rows filtered by ``add_band_col``).
    """
    if isinstance(channels, str):
        # list("Cz") would otherwise split the name into ["C", "z"].
        channels = [channels]
    # Averaging a size-1 channel dim leaves the values unchanged but drops the
    # dimension, so single- and multi-channel selections yield the same columns.
    arr = xr.sel(subject=subject, channel=list(channels)).mean("channel")
    df = (
        arr.to_dataframe()
        .reset_index()
        .drop(columns="subject")
        .melt(
            id_vars=["time", "freq"],
            value_vars=["match", "mismatch"],
            var_name="condition",
            value_name="value",
        )
    )
    return add_band_col(df)

# Export one long-format CSV per subject; create derivatives/tfr/csv if needed.
csv_dir = out_dir.parent / "tfr" / "csv"
csv_dir.mkdir(parents=True, exist_ok=True)

df_dict = {}
for subject in tqdm(tfr_ds.subject.values, total=len(tfr_ds.subject)):
    subject_df = xr_to_df(tfr_ds, subject)
    subject_df.to_csv(csv_dir / f"{subject}_tfr.csv", index=False)
    df_dict[subject] = subject_df

In [None]:
group_df = get_group_assignment()
# Each subject must map to exactly one group; duplicate labels would be selected
# twice (and double-counted) by the per-group `.sel(subject=...)` averages below.
assert group_df.index.is_unique, "duplicate subject IDs in the participants sheet"
group_df.head()

In [None]:
ROI = ["Cz", "Fz"]


def _mismatch_minus_match(subject_ids):
    """Grand-average (mismatch - match) TFR over the ROI channels for the given subjects."""
    roi_tfr = tfr_ds.sel(subject=subject_ids, channel=ROI).mean("channel")
    return (roi_tfr["mismatch"] - roi_tfr["match"]).mean("subject")


# One (freq, time) difference map per group, collected into a single Dataset.
group_ds_tfr = xr.Dataset(
    {
        group: _mismatch_minus_match(group_df[group_df["group"] == group].index)
        for group in ["ASD-Nonverbal", "ASD-Verbal", "TD"]
    }
)


In [None]:
def _group_erds_frame(group):
    """Long-format (freq, time, band, ERDS, group) frame for one group's difference map."""
    frame = group_ds_tfr[group].to_dataframe().reset_index()
    # Use add_band_col's return value: it carries the band filter and category clean-up.
    frame = add_band_col(frame)
    return frame.rename(columns={group: "ERDS"}).assign(group=group)


df = pd.concat(
    [_group_erds_frame(group) for group in ["ASD-Verbal", "ASD-Nonverbal", "TD"]],
    ignore_index=True,
)

# NOTE(review): "ERDS" here is the raw mismatch - match power difference (no TFR
# baseline normalisation is applied above); confirm the label is intended.
# Declaring hue on the grid lets add_legend() collect one entry per group.
g = sns.FacetGrid(df, col="band", hue="group", col_wrap=3, sharey=False, margin_titles=True)
g.map(sns.lineplot, "time", "ERDS")
g.set_axis_labels("Time (s)", "ERDS")
g.set_titles(col_template="{col_name}")
g.add_legend()
g.figure.subplots_adjust(top=0.9)


In [None]:
# Grand-average TFR of one condition over the ROI, one panel per group.
# The 2-15 Hz range and colour limits (0 to 1.2e-8) are fixed so the panels are comparable.
CONDITION = "match"

fig, axes = plt.subplots(1, 3, constrained_layout=True, figsize=(15, 5))
for group, this_ax in zip(["ASD-Nonverbal", "ASD-Verbal", "TD"], axes.flat):
    group_ids = group_df[group_df.group == group].index
    (
        tfr_ds[CONDITION]
        .sel(subject=group_ids, channel=ROI, freq=slice(2, 15))
        .mean("channel")
        .mean("subject")
        .plot(ax=this_ax, vmin=0, vmax=1.2e-8)
    )
    this_ax.set_title(f"{group} ({CONDITION})")
plt.show()