In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import xarray as xr


import mne

import kinnd

In [None]:
# Data location: use the lab server when it is mounted, otherwise a local copy.
# ``directory=None`` is passed in the mounted case (presumably kinnd then
# resolves the server location itself — confirm in kinnd.utils.paths).
LAB_SERVER_IS_MOUNTED = kinnd.utils.paths.lab_server_is_mounted(strict=False)
LOCAL_DATA_DIR = Path("/Volumes", "UBUNTU18", "USC", "charlotte_semantics_data", "sem_esrp")
directory = None if LAB_SERVER_IS_MOUNTED else LOCAL_DATA_DIR

fpaths = kinnd.utils.paths.get_semantics_fpaths(directory=directory)

# Load Subject List

In [None]:
def get_group_assignment(filename=None):
    """Return a DataFrame containing the group assignment for each semantics subject.

    Parameters
    ----------
    filename : str | Path | None
        If None, load the file from the lab server at
        ``charlotte_semantics_data/ERSP_Particpants.xlsx``. If a str or Path, it must
        point to the file that contains the group assignment for each subject in the
        semantics study (named ``ERSP_Particpants.xlsx`` on the lab server).

    Returns
    -------
    df : pd.DataFrame
        A DataFrame indexed by ``subject`` (IDs of the form ``"sub-XX"``) with one
        column:

        - group : str
            One of ``"ASD-Nonverbal"``, ``"ASD-Verbal"`` or ``"TD"``.

    Raises
    ------
    TypeError
        If ``filename`` is not None, a str, or a Path.
    FileNotFoundError
        If ``filename`` is given but does not exist.
    """
    if filename is None:
        fname = kinnd.utils.paths.lab_server_path() / "charlotte_semantics_data" / "ERSP_Particpants.xlsx"
    else:
        # Validate the argument actually passed in (this previously checked an
        # unrelated global named ``directory``, which is None when the lab
        # server is mounted and so rejected every valid filename).
        if not isinstance(filename, (str, Path)):
            raise TypeError(f"filename must be a str or Path. Got {type(filename)}")
        fname = Path(filename).expanduser().resolve()
        if not fname.exists():
            raise FileNotFoundError(f"File not found: {fname}")
    participants_df = pd.read_excel(fname)

    # Each spreadsheet column lists the subject IDs of one group. Columns have
    # different lengths, so the padding NaNs are dropped.
    column_to_group = {"MV ASD": "ASD-Nonverbal", "V ASD": "ASD-Verbal", "TD": "TD"}
    df = pd.concat(
        [
            participants_df[column].to_frame(name="subject").dropna().assign(group=group)
            for column, group in column_to_group.items()
        ],
        ignore_index=True,
    )

    # e.g. "s1705" -> "sub-05": strip the "s17" prefix, drop leading zeros, re-pad to 2 digits.
    df["subject"] = df["subject"].map(lambda raw: "sub-" + str(int(raw.replace("s17", ""))).zfill(2))
    # sub-23 is excluded from the analysis (reason not recorded here — TODO document).
    df = df[df["subject"] != "sub-23"]
    return df.set_index("subject")

In [None]:
# The subject-to-group spreadsheet sits one level above LOCAL_DATA_DIR.
participants_xlsx = LOCAL_DATA_DIR.parent / "ERSP_Participants.xlsx"
df = get_group_assignment(filename=participants_xlsx)

# Load Epochs

In [None]:
from collections import defaultdict
from tqdm.notebook import tqdm

# Read every subject's epochs once; later cells work from this dict.
semantics_epochs = {
    subject: kinnd.studies.semantics.read_epochs_semantics(fpaths, subject, verbose="CRITICAL")
    for subject in tqdm(fpaths, total=len(fpaths))
}

In [None]:
def epochs_to_xarray(ep):
    """Convert an MNE Epochs object to an xarray Dataset.

    Parameters
    ----------
    ep : mne.Epochs
        Epochs whose ``event_id`` defines the conditions to export.

    Returns
    -------
    xr.Dataset
        One data variable per non-empty condition, each with dims
        ``("epoch", "channel", "time")``. The ``epoch`` coordinate holds the
        ``selection`` array of that condition's subset.
    """
    cond_xrs = {}
    for condition in ep.event_id:
        # Indexing Epochs returns a new Epochs object, so build the subset once
        # per condition instead of three times.
        cond_ep = ep[condition]
        if not len(cond_ep):
            continue  # conditions with no epochs are skipped
        cond_xrs[condition] = xr.DataArray(
            cond_ep.get_data(),
            coords={
                "epoch": cond_ep.selection,
                "channel": ep.ch_names,
                "time": ep.times,
            },
            dims=("epoch", "channel", "time"),
        )
    return xr.Dataset(cond_xrs)

# One xarray Dataset per subject (one variable per condition).
semantics_xr = {
    subject: epochs_to_xarray(subject_epochs)
    for subject, subject_epochs in tqdm(semantics_epochs.items())
}


In [None]:
if LAB_SERVER_IS_MOUNTED:
    # Persist the per-subject epoch Datasets under derivatives/epochs on the server.
    out_dir = kinnd.utils.paths.semantics_path() / "derivatives" / "epochs"
    assert out_dir.exists()

    for subject, subject_ds in tqdm(semantics_xr.items(), total=len(semantics_xr)):
        subject_ds.to_netcdf(out_dir / f"{subject}_epochs.netcdf")

# Evoked

In [None]:
# Baseline-correct and average each condition per subject, then stack all
# subjects into one Dataset with dims ("subject", "channel", "time").
evoked_per_subject = []
for subject, ep in semantics_epochs.items():
    condition_xrs = {}
    for condition in ("match", "mismatch"):
        ev = ep[condition].apply_baseline((None, 0)).average()
        assert ev.data.shape == (len(ev.ch_names), len(ev.times))
        condition_xrs[condition] = xr.DataArray(
            [ev.data],
            dims=("subject", "channel", "time"),
            coords={"subject": [subject], "channel": ev.ch_names, "time": ev.times},
        )
    evoked_per_subject.append(xr.Dataset(condition_xrs))

semantics_evoked_xr = xr.concat(evoked_per_subject, dim="subject")
if LAB_SERVER_IS_MOUNTED:
    semantics_evoked_xr.to_netcdf(out_dir.parent / "evoked" / "evoked.netcdf")

In [None]:
# Display the stacked evoked Dataset (subject x channel x time) for inspection.
semantics_evoked_xr

In [None]:
import seaborn as sns
sns.set_style("darkgrid")

# Mean Cz/Fz waveform for t >= 1 s in long format (one row per
# subject/time/condition), annotated with each subject's group.
ds = semantics_evoked_xr
ROI = ["Cz", "Fz"]
roi_mean = (
    ds[["match", "mismatch"]]
    .sel(channel=ROI, time=slice(1.0, None))
    .mean("channel")
)
ds_long = (
    roi_mean.to_dataframe()
    .reset_index()
    .melt(
        id_vars=["subject", "time"],
        value_vars=["match", "mismatch"],
        var_name="condition",
        value_name="amplitude",
    )
    .set_index("subject")
    .join(df)
)

td_rows = ds_long["group"] == "TD"
sns.lineplot(data=ds_long.loc[td_rows], x="time", y="amplitude", hue="condition", errorbar="sd")

In [None]:
# ``kinnd.viz`` is only imported explicitly in a later cell. Import it here so
# this cell does not depend on execution order (``import kinnd`` may or may not
# load the submodule — not visible from this notebook).
import kinnd.viz

kinnd.viz.plot_topomap

In [None]:
import kinnd.viz


mon = mne.channels.make_standard_montage("standard_1020")
# Sensor positions for the topomaps; create_info requires an sfreq, which is unused here.
info = mne.create_info(
    ch_names=ds["match"].channel.values.tolist(),
    sfreq=1000,
    ch_types="eeg",
)
info.set_montage(mon)


def get_difference_wave(ds):
    """Return the group-average mismatch - match amplitude per channel.

    Both conditions are restricted to ``time >= 1.0`` s and averaged over
    subjects and time before subtracting, so the result has only a
    ``channel`` dimension.
    """
    window = dict(time=slice(1.0, None))
    match = ds["match"].sel(**window).mean(["subject", "time"])
    mismatch = ds["mismatch"].sel(**window).mean(["subject", "time"])
    return mismatch - match


sel_kwargs = dict(time=slice(1.0, None))
group_names = ("TD", "ASD-Verbal", "ASD-Nonverbal")
diffs = {
    group: get_difference_wave(ds.sel(subject=df[df.group == group].index, **sel_kwargs))
    for group in group_names
}
diff_td, diff_asd_v, diff_asd_mv = (diffs[group] for group in group_names)

# One shared colour scale so the three maps are directly comparable.
all_diff_values = np.concatenate([diff.values for diff in diffs.values()])
vmin = float(all_diff_values.min())
vmax = float(all_diff_values.max())

# Seaborn styles only affect axes created *after* set_style, so set it first
# (calling it after plt.subplots left these axes in the previous style).
sns.set_style("white")
fig, ax = plt.subplots(1, 3, constrained_layout=True, figsize=(10, 5))

topo_kwargs = dict(show=False, vlim=(vmin, vmax), names=diff_td.channel.values)
for this_ax, group in zip(ax, group_names):
    diff = diffs[group]
    dat = dict(zip(diff.channel.values, diff.values))
    kinnd.viz.plot_topomap(dat, info, axes=this_ax, **topo_kwargs)
    this_ax.set_title(group)

fig.suptitle("Mismatch - Match EEG activity (1-2s)")
fig.show()

In [None]:
import itertools

groups = df.group.unique()
assert len(groups) == 3
conditions = ["match", "mismatch"]

# Mean amplitude per channel (t >= 1 s, averaged over subjects) for every
# group/condition pair. Computed once and reused for scaling and plotting.
group_dict = {}
for group, condition in itertools.product(groups, conditions):
    data = (
        ds[condition]
        .sel(subject=df[df.group == group].index, **sel_kwargs)
        .mean(["subject", "time"])
    )
    group_dict[(group, condition)] = dict(zip(data.channel.values, data.values))

# Shared colour limits across all maps.
all_values = np.concatenate([list(values.values()) for values in group_dict.values()])
vmin = float(all_values.min())
vmax = float(all_values.max())
vlim = (vmin, vmax)
topo_kwargs = dict(show=False, vlim=vlim, names=data.channel.values)

# Set the style explicitly, before creating the axes, instead of relying on
# whatever an earlier cell left behind.
sns.set_style("white")
fig, axes = plt.subplots(len(groups), len(conditions), figsize=(10, 10), constrained_layout=True)
for this_ax, (group, condition) in zip(axes.flatten(), group_dict):
    kinnd.viz.plot_topomap(
        group_dict[(group, condition)],
        info,
        axes=this_ax,
        **topo_kwargs,
    )
    this_ax.set_title(f"{group}: {condition}")
fig.suptitle("Mean EEG activity (1-2s)");

In [None]:
FRONT = ["F3", "Fz", "F4"]
MID = ["C3", "Cz", "C4"]
ROI = FRONT + MID  # frontal + central channels

# Seaborn styles only affect axes created afterwards, so set the style first.
sns.set_style("darkgrid")
fig, axes = plt.subplots(1, 3, constrained_layout=True, figsize=(10, 5))
colors = sns.color_palette()
melt_kwargs = dict(id_vars=["subject", "time"],
                   value_vars=["match", "mismatch"],
                   var_name="condition", value_name="amplitude"
                   )

for this_ax, this_group in zip(axes, ("TD", "ASD-Verbal", "ASD-Nonverbal")):
    df_ev = ds.sel(subject=df[df.group == this_group].index, channel=ROI).mean("channel").to_dataframe()
    df_ev = df_ev.reset_index().melt(**melt_kwargs)
    sns.lineplot(data=df_ev, x="time", y="amplitude", hue="condition", ax=this_ax)
    this_ax.set_ylim(-1.4*1e-5, .4*1e-5)  # rough guess based on the data
    # The ROI spans frontal *and* central channels, so label it accordingly.
    this_ax.set_title(f"{this_group}: Fronto-central ROI")
    this_ax.axvline(1.0, color="k", linestyle="--")
    # Shade the image (0-1 s) and word (1-2 s) windows.
    this_ax.axvspan(0.0, 0.99, color=colors[2], alpha=0.1)
    this_ax.axvspan(1.0, 2.0, color=colors[4], alpha=0.2)
    this_ax.text(0.5, 0, "Image", ha="center", va="center", fontsize=12)
    this_ax.text(1.5, -1.0*1e-5, "Word", ha="center", va="center", fontsize=12)
fig.show()

# Power

In [None]:
def add_band_col(df):
    """Label each row of ``df`` with its frequency band and keep only banded rows.

    Bands are right-closed bins on the ``freq`` column: delta (0, 3],
    theta (3, 7], alpha (7, 13], beta (13, 30], gamma (30, 50]. Rows whose
    frequency falls outside (0, 50] are dropped.

    NOTE(review): these edges differ from the ``freq_bands`` slices used for
    the topomaps (delta 2-4, theta 4-8, ...) — confirm which is intended.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain a numeric ``freq`` column. A categorical ``band`` column
        is added to it in place (some callers rely on this side effect).

    Returns
    -------
    pd.DataFrame
        A new frame with only the in-band rows, whose ``band`` column has
        unused categories removed.
    """
    band_edges = [0, 3, 7, 13, 30, 50]
    band_labels = ["delta", "theta", "alpha", "beta", "gamma"]
    df["band"] = pd.cut(df["freq"], bins=band_edges, labels=band_labels)
    # .copy() so the assignment below targets an independent frame rather than
    # a slice of ``df`` (avoids SettingWithCopyWarning / chained assignment).
    in_band = df.loc[df["band"].isin(band_labels)].copy()
    in_band["band"] = in_band["band"].cat.remove_unused_categories()
    return in_band

In [None]:
# Welch PSD (2-50 Hz) of the baseline-corrected 1-2 s window, averaged over
# epochs per condition, stacked into one Dataset (subject x channel x freq).
power_dict = {}
for subject, ep in semantics_epochs.items():
    spectrum = ep.copy().apply_baseline((None, 0)).crop(tmin=1.0, tmax=2.0)
    spectrum = spectrum.compute_psd(method="welch", fmin=2, fmax=50)
    pow_xrs = {}
    for condition in spectrum.event_id:
        cond_spectrum = spectrum[condition]
        if not len(cond_spectrum):
            continue
        cond_psd = cond_spectrum.average().data
        # Check the array that is actually stored. This previously averaged
        # *all* conditions on every iteration and checked that instead.
        assert cond_psd.shape == (len(spectrum.ch_names), len(spectrum.freqs))
        pow_xrs[condition] = xr.DataArray(
            [cond_psd],
            coords={
                "subject": [subject],
                "channel": spectrum.ch_names,
                "freq": spectrum.freqs,
            },
            dims=("subject", "channel", "freq"),
        )
    power_dict[subject] = xr.Dataset(pow_xrs)
power_ds = xr.concat(power_dict.values(), dim="subject")


In [None]:
def get_vlim(group_dict, groups=None, conditions=None):
    """Get shared ``(vmin, vmax)`` colour limits for a set of topomaps.

    Parameters
    ----------
    group_dict : dict
        Maps ``(group, condition)`` tuples to ``{channel: value}`` dicts.
    groups, conditions : iterable | None
        Which groups/conditions to include. If None (default), every group and
        condition appearing in ``group_dict``'s keys is used. (Previously these
        were read silently from notebook-level globals.)

    Returns
    -------
    tuple of float
        ``(vmin, vmax)`` over all selected channel values.

    Raises
    ------
    KeyError
        If a requested ``(group, condition)`` pair is missing from ``group_dict``.
    ValueError
        If no values are selected.
    """
    if groups is None:
        groups = list(dict.fromkeys(group for group, _ in group_dict))
    if conditions is None:
        conditions = list(dict.fromkeys(condition for _, condition in group_dict))
    values = [
        value
        for group in groups
        for condition in conditions
        for value in group_dict[(group, condition)].values()
    ]
    if not values:
        raise ValueError("group_dict contains no values to compute limits from")
    # A flat float array handles ragged channel sets and integer inputs; the old
    # isinstance(..., (int, float)) check rejected numpy integer scalars.
    values = np.asarray(values, dtype=float)
    return (float(values.min()), float(values.max()))

In [None]:
def get_power(power_ds, *, group_df, group, condition, freqs):
    """Return the group-mean band power of one condition, per channel.

    Power is summed over the frequency bins selected by ``freqs`` and then
    averaged over the subjects whose ``group_df.group`` equals ``group``.

    Returns
    -------
    dict
        ``{channel_name: power}``.
    """
    subjects = group_df[group_df.group == group].index
    band_power = power_ds[condition].sel(subject=subjects, freq=freqs).sum("freq")
    group_mean = band_power.mean("subject")
    return dict(zip(group_mean.channel.values, group_mean.values))


# Frequency bands as label-based slices. NOTE(review): xarray label slicing is
# inclusive at both ends, so a bin lying exactly on a shared edge (e.g. 8 Hz)
# is counted in both neighbouring bands — confirm this is intended.
freq_bands = dict(
    delta=slice(2, 4),
    theta=slice(4, 8),
    alpha=slice(8, 13),
    beta=slice(13, 30),
    gamma=slice(30, 50),
)

In [None]:
# Quick check: TD group-mean delta power in the "match" condition, per channel.
get_power(power_ds,
          group_df=df,
          group="TD",
          condition="match",
          freqs=freq_bands["delta"]
          )

In [None]:
# Display the stacked PSD Dataset (subject x channel x freq).
power_ds

In [None]:

def plot_power_topomaps(power_ds, *, group_df, band):
    """Plot one topomap of group-mean band power per group x condition.

    Parameters
    ----------
    power_ds : xr.Dataset
        One data variable per condition with ``subject``, ``channel`` and
        ``freq`` dimensions.
    group_df : pd.DataFrame
        Indexed by subject with a ``group`` column; one row of maps per group.
    band : str
        Key of ``freq_bands`` (case-insensitive).

    All maps share one colour scale so groups and conditions are comparable.
    """
    band_key = band.lower()
    groups = group_df.group.unique()
    conditions = list(power_ds.data_vars)
    panels = list(itertools.product(groups, conditions))

    group_power_dict = {
        (this_group, this_condition): get_power(
            power_ds,
            group_df=group_df,
            group=this_group,
            condition=this_condition,
            freqs=freq_bands[band_key],
        )
        for this_group, this_condition in panels
    }
    all_values = np.concatenate([list(power.values()) for power in group_power_dict.values()])
    vlim = (all_values.min(), all_values.max())
    channel_names = list(next(iter(group_power_dict.values())).keys())
    topo_kwargs = dict(show=False, vlim=vlim, names=channel_names)

    # Seaborn styles only apply to axes created afterwards, so set it first.
    sns.set_style("white")
    fig, ax = plt.subplots(
        len(groups), len(conditions), figsize=(10, 10), constrained_layout=True, squeeze=False
    )
    for this_ax, (group, condition) in zip(ax.flatten(), panels):
        kinnd.viz.plot_topomap(
            group_power_dict[(group, condition)],
            info,
            axes=this_ax,
            **topo_kwargs,
        )
        this_ax.set_title(f"{group}: {condition}")
    fig.suptitle(f"{band_key.capitalize()} power (1-2s)")

plot_power_topomaps(power_ds, group_df=df, band="delta")


In [None]:
plot_power_topomaps(power_ds, band="theta", group_df=df)  # absolute theta power

In [None]:
plot_power_topomaps(power_ds, band="alpha", group_df=df)  # absolute alpha power

In [None]:
plot_power_topomaps(power_ds, band="beta", group_df=df)  # absolute beta power

In [None]:
plot_power_topomaps(power_ds, band="gamma", group_df=df)  # absolute gamma power

In [None]:
def plot_psds(power_ds, *, group_df, log=True, vlim=None):
    """Plot every channel's group-mean power spectrum, one panel per group.

    Parameters
    ----------
    power_ds : xr.Dataset
        One data variable per condition with ``subject``, ``channel`` and
        ``freq`` dimensions.
    group_df : pd.DataFrame
        Indexed by subject with a ``group`` column; defines the panels.
    log : bool
        If True (default) plot ``log10(power)``, otherwise the raw values.
    vlim : tuple | None
        Optional y-axis limits (in plotted units) applied to every panel.
    """
    groups = group_df.group.unique()
    conditions = list(power_ds.data_vars)
    melt_kwargs = dict(
        id_vars=["channel", "freq"],
        value_vars=conditions,
        var_name="condition",
        value_name="power",
    )

    # Seaborn styles only affect axes created afterwards, so set it first.
    sns.set_style("darkgrid")
    fig, axes = plt.subplots(1, len(groups), constrained_layout=True, figsize=(10, 5), squeeze=False)

    for ax, group in zip(axes[0], groups):
        # Select subjects from the ``group_df`` argument (this previously read
        # the notebook-level ``df`` global, ignoring the parameter).
        psd_df = (
            power_ds.sel(subject=group_df[group_df.group == group].index)
            .mean("subject")
            .to_dataframe()
            .reset_index()
            .melt(**melt_kwargs)
        )
        if log:
            psd_df["power"] = np.log10(psd_df["power"])
        sns.lineplot(
            data=psd_df,
            x="freq",
            y="power",
            units="channel",
            hue="condition",
            ax=ax,
            errorbar=None,
            estimator=None,
            linewidth=.5,
            alpha=.5,
        )
        ax.set_title(group)
        if vlim is not None:
            ax.set_ylim(vlim)
        # The label must reflect whether the log transform was applied.
        ax.set_ylabel("log10(power)" if log else "power")
        ax.set_xlabel("Frequency (Hz)")

    fig.suptitle("Power Spectral Density (1-2s)")
    fig.show()

plot_psds(power_ds, group_df=df, vlim=(-14, np.log10(2*1e-11)))

# Relative Power

In [None]:
# Normalise each spectrum by its summed power over all frequency bins, so every
# subject/channel/condition spectrum sums to 1.
total_power = power_ds.sum(dim="freq")
power_ds_rel = power_ds / total_power

In [None]:
plot_psds(power_ds_rel, vlim=(0, 0.2), log=False, group_df=df)  # relative (unitless) spectra

In [None]:
plot_power_topomaps(power_ds_rel, band="delta", group_df=df)  # relative delta power

In [None]:
plot_power_topomaps(power_ds_rel, band="theta", group_df=df)  # relative theta power

In [None]:
plot_power_topomaps(power_ds_rel, band="alpha", group_df=df)  # relative alpha power

### TFR for all participants

In [None]:
RECOMPUTE = False

if RECOMPUTE:
    freqs = np.arange(2, 36)
    kwargs = dict(
        method="multitaper",
        freqs=freqs,
        n_cycles=freqs,
        use_fft=True,
    )

    tfr_dict = {}
    # tqdm already reports progress, so there is no per-subject print.
    for subject, ep in tqdm(semantics_epochs.items(), total=len(semantics_epochs)):
        this_ep = ep.copy().apply_baseline((None, 0))
        tfr = this_ep.compute_tfr(**kwargs)
        condition_xrs = {}
        for condition in ["match", "mismatch"]:
            ev = tfr[condition].average()
            assert ev.data.shape == (len(ev.ch_names), len(ev.freqs), len(ev.times))
            condition_xrs[condition] = xr.DataArray(
                [ev.data],
                coords={
                    "subject": [subject],
                    "channel": ev.ch_names,
                    "freq": ev.freqs,
                    "time": ev.times,
                },
                dims=("subject", "channel", "freq", "time"),
            )
        tfr_dict[subject] = xr.Dataset(condition_xrs)
    tfr_ds = xr.concat(tfr_dict.values(), dim="subject")

    if LAB_SERVER_IS_MOUNTED:
        out_name = out_dir.parent / "tfr" / "tfr.netcdf"
        tfr_ds.to_netcdf(out_name)

else:
    # ``directory`` is None when the lab server is mounted, so it cannot locate
    # the cached file in that case. Read from the same place the RECOMPUTE
    # branch writes to instead.
    if LAB_SERVER_IS_MOUNTED:
        tfr_fname = kinnd.utils.paths.semantics_path() / "derivatives" / "tfr" / "tfr.netcdf"
    else:
        tfr_fname = LOCAL_DATA_DIR.parent / "derivatives" / "tfr" / "tfr.netcdf"
    tfr_ds = xr.open_dataset(tfr_fname)

tfr_ds


In [None]:
def get_tfr_power(tfr_ds, *, group_df, group, roi, freqs):
    """Return ROI- and band-averaged TFR power for one group, in long format.

    Parameters
    ----------
    tfr_ds : xr.Dataset
        ``match``/``mismatch`` variables with ``subject``, ``channel``,
        ``freq`` and ``time`` dimensions.
    group_df : pd.DataFrame
        Indexed by subject with a ``group`` column.
    group : str
        Group whose subjects are selected.
    roi : list of str
        Channels averaged together.
    freqs : slice
        Frequency range averaged together (label-based selection).

    Returns
    -------
    pd.DataFrame
        Columns ``subject``, ``time``, ``condition`` (match/mismatch) and ``power``.
    """
    subjects = group_df[group_df.group == group].index
    roi_band = tfr_ds.sel(subject=subjects, channel=roi, freq=freqs).mean(["freq", "channel"])
    return roi_band.to_dataframe().reset_index().melt(
        id_vars=["subject", "time"],
        value_vars=["match", "mismatch"],
        var_name="condition",
        value_name="power",
    )


BAND = "theta"
ylim = (0, 6*1e-9)
# Spell out the fronto-central ROI so that later reassignments of the global
# ``ROI`` (e.g. to Cz/Fz only) cannot change this figure on re-run.
TFR_ROI = FRONT + MID

fig, ax = plt.subplots(1, 3, figsize=(10, 5), constrained_layout=True)
# One panel per group; ``tfr_df`` is left holding the last group's data.
for this_ax, group in zip(ax, ("TD", "ASD-Verbal", "ASD-Nonverbal")):
    tfr_df = get_tfr_power(tfr_ds, group_df=df, group=group, roi=TFR_ROI, freqs=freq_bands[BAND])
    # Seeded bootstrap so the confidence bands are reproducible.
    sns.lineplot(data=tfr_df, x="time", y="power", hue="condition", n_boot=100, seed=0, ax=this_ax)
    this_ax.set_title(group)
    this_ax.set_ylim(ylim)

fig.suptitle(f"{BAND} power")
fig.show()

In [None]:
# Display the long-format TFR power table from the previous cell.
tfr_df

In [None]:
def xr_to_df(xr, subject, channels=("Cz", "Fz")):
    """Convert one subject's TFR Dataset to a long-format, band-labelled DataFrame.

    Parameters
    ----------
    xr : xr.Dataset
        TFR Dataset with ``match``/``mismatch`` variables and ``subject``,
        ``channel``, ``freq`` and ``time`` dimensions. (This parameter name
        shadows the ``xarray`` alias inside the function; kept for compatibility.)
    subject : str
        Subject to extract.
    channels : sequence of str
        Channels to select; averaged together when more than one is given.

    Returns
    -------
    pd.DataFrame
        Columns ``time``, ``freq``, ``condition``, ``value`` and ``band``. Rows
        outside the bands of ``add_band_col`` are dropped.
    """
    arr = xr.sel(subject=subject, channel=list(channels))
    if len(channels) > 1:
        arr = arr.mean("channel")
    subject_df = arr.to_dataframe().reset_index().drop(columns="subject")
    subject_df = subject_df.melt(
        id_vars=["time", "freq"],
        value_vars=["match", "mismatch"],
        var_name="condition",
        value_name="value",
    )
    return add_band_col(subject_df)


if LAB_SERVER_IS_MOUNTED:
    csv_dir = out_dir.parent / "tfr" / "csv"
    csv_dir.mkdir(parents=True, exist_ok=True)
    df_dict = {}
    for subject in tqdm(tfr_ds.subject.values, total=len(tfr_ds.subject)):
        # Use a dedicated name. Assigning to ``df`` here used to overwrite the
        # group-assignment table that earlier cells read from.
        subject_tfr_df = xr_to_df(tfr_ds, subject)
        subject_tfr_df.to_csv(csv_dir / f"{subject}_tfr.csv", index=False)
        df_dict[subject] = subject_tfr_df
    df_dict = dict(df_dict)

In [None]:
# Load the group table from wherever the data currently live: the lab-server
# copy when it is mounted, otherwise the local spreadsheet (the lab-server path
# is presumably unavailable in that case).
if LAB_SERVER_IS_MOUNTED:
    group_df = get_group_assignment()
else:
    group_df = get_group_assignment(filename=LOCAL_DATA_DIR.parent / "ERSP_Participants.xlsx")
group_df.head()

In [None]:
ROI = ["Cz", "Fz"]

# Per group: average the ROI channels, take mismatch - match per subject, then
# average over subjects -> one (freq, time) difference map per group.
group_diffs = {}
for group in ("ASD-Nonverbal", "ASD-Verbal", "TD"):
    roi_ds = tfr_ds.sel(subject=group_df[group_df["group"] == group].index, channel=ROI).mean("channel")
    group_diffs[group] = (roi_ds["mismatch"] - roi_ds["match"]).mean("subject")
group_ds_tfr = xr.Dataset(group_diffs)


In [None]:
# Long-format table of each group's mismatch - match TFR difference, labelled
# by frequency band, for faceting by band.
group_frames = []
for group in ("ASD-Verbal", "ASD-Nonverbal", "TD"):
    group_frame = group_ds_tfr[group].to_dataframe().reset_index()
    # Keep the *returned* frame: add_band_col drops out-of-band rows and prunes
    # unused categories only on its return value.
    group_frame = add_band_col(group_frame)
    group_frame = group_frame.rename(columns={group: "ERDS"})
    group_frame["group"] = group
    group_frames.append(group_frame)

# Dedicated name so the group-assignment table ``df`` is not overwritten.
erds_df = pd.concat(group_frames, ignore_index=True)

# Put the hue on the grid so add_legend() has legend data to draw.
g = sns.FacetGrid(erds_df, col="band", hue="group", col_wrap=3, sharey=False, margin_titles=True)
g.map(sns.lineplot, "time", "ERDS", seed=0)
g.set_axis_labels("Time (s)", "ERDS")
g.set_titles(col_template="{col_name}")
g.add_legend()
g.figure.suptitle("Mismatch - Match TFR power by band (Cz/Fz)")
g.figure.subplots_adjust(top=0.9)


In [None]:
# Subjects belonging to each group (built once, not on every loop iteration).
subs = {
    group: group_df[group_df.group == group].index
    for group in ("ASD-Nonverbal", "ASD-Verbal", "TD")
}
mvs, asds, tds = subs["ASD-Nonverbal"], subs["ASD-Verbal"], subs["TD"]
conditions_to_plot = ("match", "mismatch")

# One row per condition, one column per group. Previously only "match" was
# drawn, leaving the second row of the 2 x 3 grid empty.
fig, ax = plt.subplots(
    len(conditions_to_plot), len(subs), constrained_layout=True, figsize=(15, 8), squeeze=False
)
for row, condition in enumerate(conditions_to_plot):
    for col, (group, subjects) in enumerate(subs.items()):
        this_ax = ax[row, col]
        (tfr_ds[condition].sel(subject=subjects, channel=ROI, freq=slice(2, 15))
                          .mean("channel")
                          .mean("subject")
                          .plot(ax=this_ax, vmin=0, vmax=1.2e-8)
                          )
        this_ax.set_title(f"{group}: {condition}")
fig.show()

# Global Field Power

In [None]:
# Per subject, band and condition: band-pass filter, remove the evoked
# response, take the Hilbert amplitude envelope, and average over epochs.
# NOTE(review): this stores per-channel envelope averages; a single
# global-field-power trace would still need combining across channels — confirm intent.
global_power_by_subject = {}

for subject, epochs in tqdm(semantics_epochs.items()):
    baselined = epochs.copy().apply_baseline((None, 0))

    band_gfps = {}
    for band, frequencies in freq_bands.items():
        fmin = frequencies.start
        fmax = frequencies.stop
        # Filter a fresh copy per band. ``Epochs.filter`` modifies in place, so
        # re-filtering one shared object chained the band-pass filters
        # (theta was computed from delta-filtered data, and so on).
        band_epochs = baselined.copy().filter(fmin, fmax, l_trans_bandwidth=1, h_trans_bandwidth=1)
        condition_gfps = {}
        for condition in ["match", "mismatch"]:
            # ``band_epochs[condition]`` returns a new object on every call, so
            # keep one handle and apply each in-place step to it. Previously the
            # evoked subtraction and the envelope were computed on throw-away
            # copies, and the plain filtered average was stored.
            cond_epochs = band_epochs[condition]
            cond_epochs.subtract_evoked()
            cond_epochs.apply_hilbert(envelope=True)
            this_gfp = cond_epochs.average()
            assert this_gfp.data.shape == (len(this_gfp.ch_names), len(this_gfp.times))

            condition_gfps[condition] = xr.DataArray(
                [[this_gfp.data]],
                coords={
                    "subject": [subject],
                    "band": [band],
                    "channel": this_gfp.ch_names,
                    "time": this_gfp.times,
                },
                dims=("subject", "band", "channel", "time"),
            )
        band_gfps[band] = xr.Dataset(condition_gfps)
    global_power_by_subject[subject] = xr.concat(band_gfps.values(), dim="band")
global_power = xr.concat(global_power_by_subject.values(), dim="subject")


In [None]:
# Display the stacked band-power Dataset (subject x band x channel x time).
global_power