In [None]:
from pathlib import Path
import mffpy
import kinnd
import mne

In [None]:
# Root folder of the Listen study data; every other path below is derived from it.
directory = Path("/Volumes", "UBUNTU18", "USC", "listen")

# Load Raw Data

In [None]:
def check_directory(directory):
    """Validate that ``directory`` points at an existing filesystem location.

    Only existence is checked (``Path.exists``), so an existing regular file
    passes as well.

    Parameters
    ----------
    directory : str | Path
        Location to validate.

    Returns
    -------
    Path
        ``directory`` converted to a ``Path`` object.

    Raises
    ------
    IOError
        If ``directory`` is neither a string nor a ``Path`` object.
    FileNotFoundError
        If nothing exists at ``directory``.
    """
    if isinstance(directory, (str, Path)):
        dir_path = Path(directory)
        if dir_path.exists():
            return dir_path
        raise FileNotFoundError(f"Directory {directory} does not exist.")
    raise IOError(
        f"directory must be a string or Path object, not {type(directory)}"
    )

def read_raw_listen(filename):
    """Read a Listen MFF recording into an annotated ``mne.io.RawArray``.

    The EEG samples are read with ``mffpy`` (unit set to volts), attached to
    the standard montage named in the recording's ``info1`` file, and the
    ``img+`` / ``snd+`` events are converted to MNE annotations.

    Parameters
    ----------
    filename : str | Path
        Path to the ``.mff`` recording directory. The events file
        ``Events_ECI TCP-IP 55513.xml`` is read from inside it.

    Returns
    -------
    mne.io.RawArray
        The recording with its measurement date set (UTC) and annotations
        named ``image_match``, ``image_mismatch``, ``word_match`` and
        ``word_mismatch``.

    Raises
    ------
    KeyError
        If the recording's montage name has no known MNE equivalent, or an
        event carries a ``cel#`` other than 1 or 2.
    """
    import datetime
    from pathlib import Path

    import mffpy
    import pytz

    pacific = pytz.timezone("US/Pacific")

    def _pacific_to_utc(timestamp):
        """Treat ``timestamp``'s wall-clock time as US/Pacific and return it in UTC."""
        # pytz zones must be attached with ``localize``. Passing one to
        # ``datetime.replace(tzinfo=...)`` silently applies the zone's
        # historical LMT offset (-07:53) instead of PST/PDT.
        localized = pacific.localize(timestamp.replace(tzinfo=None))
        # The stdlib UTC object is used for the result, as the original code did.
        return localized.astimezone(datetime.timezone.utc)

    mff_reader = mffpy.Reader(filename)
    mff_reader.set_unit("EEG", "V")

    # Basic Information
    meas_date = _pacific_to_utc(mff_reader.startdatetime)
    sfreq = mff_reader.sampling_rates["EEG"]

    # Montage: translate the name stored in the recording to MNE's name.
    with mff_reader.directory.filepointer("info1") as fp:
        mff_info = mffpy.XML.from_file(fp)
    montage_map = {"HydroCel GSN 128 1.0": "GSN-HydroCel-129"}
    montage_name = mff_info.generalInformation["montageName"]
    if montage_name not in montage_map:
        raise KeyError(
            f"Unsupported montage {montage_name!r}; "
            f"known montages: {sorted(montage_map)}"
        )
    montage = mne.channels.make_standard_montage(montage_map[montage_name])

    # Samples
    eeg, _ = mff_reader.get_physical_samples()["EEG"]

    # Events. ``Path(...)`` lets ``filename`` be given as a plain string too.
    # NOTE(review): the events file name is hard-coded; confirm that every
    # recording uses this exact name.
    categories = mffpy.XML.from_file(Path(filename) / "Events_ECI TCP-IP 55513.xml")
    events = categories.get_content()["event"]

    # Create MNE Objects
    ch_names = montage.ch_names
    ch_types = ["eeg"] * len(ch_names)
    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
    info.set_montage(montage)
    info.set_meas_date(meas_date)
    raw = mne.io.RawArray(eeg, info)

    # Annotations
    cel_map = {"1": "match", "2": "mismatch"}
    stim_map = {"img+": "image", "snd+": "word"}

    onsets = []
    durations = []
    descriptions = []
    for event in events:
        if event["code"] not in stim_map:
            continue
        onset = _pacific_to_utc(event["beginTime"])
        # Annotation onsets are expressed relative to the recording start.
        onsets.append((onset - raw.info["meas_date"]).total_seconds())
        # Presumably milliseconds -> seconds; confirm against the MFF spec.
        durations.append(event["duration"] / 1000)
        condition = cel_map[str(event["keys"]["cel#"])]
        descriptions.append(f"{stim_map[event['code']]}_{condition}")
    raw.set_annotations(mne.Annotations(onsets, durations, descriptions))
    return raw

def get_listen_fpaths(directory=None):
    """Group the Listen EEG file paths by subject and task.

    Parameters
    ----------
    directory : str | Path | None
        Root directory handed to ``kinnd.paths.get_eeg_fpaths``. ``None`` is
        reserved for a future automatic search and currently raises.

    Returns
    -------
    dict
        ``{"sub-<id>": {task: path, ...}, ...}`` where ``task`` is one of
        ``"phonemes"``, ``"resting"`` or ``"semantics"`` and ``<id>`` is the
        second underscore-separated field of the lower-cased file name.

    Raises
    ------
    NotImplementedError
        If ``directory`` is ``None``.
    ValueError
        If a file name contains none of the known task names.
    """
    from collections import defaultdict

    if directory is None:
        raise NotImplementedError("Automatic search for Listen files not implemented yet.")
    directory = check_directory(directory)
    fpaths = kinnd.paths.get_eeg_fpaths(study="listen", directory=directory)

    # Checked in this order; the first task name found in the file name wins.
    tasks = ("phonemes", "resting", "semantics")
    subject_dict = defaultdict(dict)
    for fpath in fpaths:
        fname = fpath.name.lower()
        key = f"sub-{fname.split('_')[1]}"
        task = next((name for name in tasks if name in fname), None)
        if task is None:
            raise ValueError(
                f"File {fpath} does not appear to belong to a Listen task. "
                "Expected 'phonemes', 'resting', or 'semantics' in the filename."
            )
        subject_dict[key][task] = fpath
    return dict(subject_dict)

In [None]:
# Map each subject ("sub-<id>") to its task recordings, then show the subjects found.
listen_fpaths = get_listen_fpaths(directory=directory)
listen_fpaths.keys()

# Process EEG data

In [None]:
import pylossless as ll 

# Start from the pylossless defaults and override only what this study needs.
config = ll.config.Config()
config.load_default()
config_fpath = Path(".") / "listen_config.yaml"

config["filtering"]["notch_filter_args"]["freqs"] = [60]

# Every outlier-detection step gets the same settings.
outlier_sections = (
    config["noisy_channels"],
    config["noisy_epochs"],
    config["uncorrelated_channels"],
    config["uncorrelated_epochs"],
    config["ica"]["noisy_ic_epochs"],
)
for section in outlier_sections:
    outliers_kwargs = section["outliers_kwargs"]
    outliers_kwargs["k"] = 6
    outliers_kwargs["lower"] = 0.25
    outliers_kwargs["upper"] = 0.75

# Written to disk because the pipeline below is constructed from the file path.
config.save(config_fpath)

In [None]:
from tqdm.notebook import tqdm

# Run the pylossless pipeline on every subject's semantics recording and save
# the cleaned raw, the ICA solution and the IC labels under
# derivatives/pylossless/<subject>/.
for subject in tqdm(listen_fpaths, total=len(listen_fpaths)):
    semantics_fpath = listen_fpaths[subject].get("semantics")
    if semantics_fpath is None:
        print(f"### Skipping {subject} because no semantics file found.")
        continue
    dpath = directory / "derivatives" / "pylossless" / subject
    if dpath.exists():
        print(f"### Skipping {subject} because output already exists.")
        continue
    print(f"### Processing {subject} ###")
    raw = read_raw_listen(semantics_fpath)
    raw.info["bads"].extend(["E125", "E126", "E127", "E128"])

    # Find Breaks
    break_annots = mne.preprocessing.annotate_break(raw)
    raw.set_annotations(raw.annotations + break_annots)

    # Pipeline
    pipeline = ll.LosslessPipeline(config_fpath)
    pipeline.run_with_raw(raw)
    rejection_policy = ll.RejectionPolicy()
    cleaned_raw = rejection_policy.apply(pipeline)

    # The folder is created only after the pipeline succeeded, so a crashed
    # run does not leave an empty folder that makes the subject look finished.
    # parents=True is needed because "derivatives/pylossless" does not exist
    # on a fresh dataset; without it the first subject failed here, after all
    # the expensive work was already done.
    dpath.mkdir(exist_ok=False, parents=True)
    eeg_out = dpath / f"{subject}_ses-01_task-semantics_eeg.fif"
    cleaned_raw.save(eeg_out)
    ica_out = dpath / f"{subject}_ses-01_task-semantics_ica.fif"
    pipeline.ica2.save(ica_out)
    labels_out = dpath / f"{subject}_ses-01_task-semantics_ic-labels.csv"
    pipeline.flags["ic"].to_csv(labels_out)
    # Release the large objects before the next subject is loaded.
    del cleaned_raw, pipeline

# Load processed Raws and create Epochs

In [None]:
import xarray as xr


def load_eeg(fpath):
    """Open the recording at ``fpath`` with ``mne.io.read_raw`` and return it."""
    raw = mne.io.read_raw(fpath)
    return raw

def make_epochs(raw, tmin=-0.1, tmax=2, preload=False):
    """Epoch ``raw`` around the image annotations.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording carrying ``image_match`` / ``image_mismatch`` annotations.
    tmin, tmax : float
        Epoch window passed to ``mne.Epochs``.
    preload : bool
        Passed to ``mne.Epochs``.

    Returns
    -------
    mne.Epochs
        Epochs with conditions named ``"match"`` (1) and ``"mismatch"`` (2);
        no baseline correction is applied.
    """
    annotation_codes = {"image_match": 1, "image_mismatch": 2}
    # Same integer codes, shorter condition names on the Epochs object.
    condition_codes = {"match": 1, "mismatch": 2}

    events, _ = mne.events_from_annotations(
        raw,
        regexp="image_match|image_mismatch",
        event_id=annotation_codes,
    )
    return mne.Epochs(
        raw,
        events,
        event_id=condition_codes,
        tmin=tmin,
        tmax=tmax,
        baseline=None,
        preload=preload,
    )


def get_evoked(epochs, event_id=None):
    """Baseline-correct and average epochs without modifying the input.

    Parameters
    ----------
    epochs : mne.Epochs
        Epochs to average.
    event_id : str | list | None
        Optional selection applied as ``epochs[event_id]`` before averaging.

    Returns
    -------
    mne.Evoked
        The average after applying the baseline ``(None, 0)``.
    """
    # ``apply_baseline`` works in place, so it must run on an object the
    # caller does not hold. Indexing is relied on to return a new Epochs
    # object; when no selection is made an explicit copy is taken instead.
    if event_id is not None:
        selected = epochs[event_id]
    else:
        selected = epochs.copy()
    return selected.apply_baseline((None, 0)).average()


def epochs_to_xr(epochs):
    """Wrap ``epochs.get_data()`` in an ``xr.DataArray``.

    The array is labelled with dims ``("epoch", "channel", "time")``; epochs
    are numbered ``0..len(epochs)-1``, channels and times come from
    ``epochs.ch_names`` and ``epochs.times``.
    """
    return xr.DataArray(
        epochs.get_data(),
        coords={
            "epoch": range(len(epochs)),
            "time": epochs.times,
            "channel": epochs.ch_names,
        },
        dims=("epoch", "channel", "time"),
    )


def evoked_to_xr(evoked, subject):
    """Wrap one subject's evoked data in an ``xr.DataArray``.

    ``evoked.data`` is placed in a length-1 list so the result has a leading
    ``subject`` dimension and can be concatenated across subjects. Dims are
    ``("subject", "channel", "time")``.
    """
    return xr.DataArray(
        [evoked.data],
        coords={
            "subject": [subject],
            "channel": evoked.ch_names,
            "time": evoked.times,
        },
        dims=("subject", "channel", "time"),
    )


def get_evoked_xr(evoked, subject):
    """Return ``evoked`` as a ``("subject", "channel", "time")`` ``xr.DataArray``.

    Kept as an alias of ``evoked_to_xr`` so existing calls keep working; the
    two functions previously carried identical copies of the same code.
    """
    return evoked_to_xr(evoked, subject)

In [None]:
# Channels excluded from the PSD topomaps in the Power section, where they are
# used both to subset the PSD channels and to drop channels from the plotting
# info. NOTE(review): presumably the outer "skirt" ring of the net - confirm.
SKIRT = [
    "E81", "E88", "E94", "E99", "E107", "E113", "E119", "E120", "E121", "E125",
    "E126", "E8", "E9", "E14", "E17", "E21", "E22", "E25", "E127", "E32", "E128",
    "E43", "E48", "E49", "E56", "E63", "E68", "E73"]

In [None]:
# NOTE(review): `raw` is whatever the processing loop above assigned last. On a
# fresh kernel where every subject was skipped, `raw` is undefined and this
# cell raises NameError.
raw.plot(theme="light")

In [None]:
# Filter, interpolate, re-reference and epoch every cleaned recording, then
# optionally save the epochs under derivatives/epochs/.
derivatives_dir = directory / "derivatives" / "pylossless"

derivative_paths = list(derivatives_dir.rglob("sub-*_eeg.fif"))
# Each file is expected to sit directly inside its subject's folder.
subjects = [p.parent.name for p in derivative_paths]

SAVE_EPOCHS = True
evoked_da = []
raws = dict()
for file, subject in zip(derivative_paths, subjects):
    raw = load_eeg(file).load_data()
    raw.filter(None, 50)
    # reset_bads=False keeps the bad-channel list in info after interpolation.
    raw.interpolate_bads(reset_bads=False)
    raw.set_eeg_reference()
    raws[subject] = raw.copy()
    # NOTE: `epochs` stays defined after the loop (last subject); the topomap
    # cells further down read `epochs.info`.
    epochs = make_epochs(raw, preload=True)
    if SAVE_EPOCHS:
        out_path = file.parent.parent.parent / "epochs" / f"{file.stem}_epo.fif"
        # The epochs folder does not exist on a fresh dataset.
        out_path.parent.mkdir(parents=True, exist_ok=True)
        epochs.save(out_path, overwrite=True)

# Descriptives

In [None]:
from tqdm.notebook import tqdm
# epo_fpaths = kinnd.paths.get_epo_fpaths(study="listen", directory=directory)
epo_dir = directory / "derivatives" / "epochs"
# Sorted so the subject order (and everything concatenated from it later) is
# the same on every run; raw glob order depends on the filesystem.
epo_fpaths = sorted(epo_dir.glob("*_epo.fif"))
subjects = [f.stem.split("_")[0] for f in epo_fpaths]

# Per-subject counts, in the same order as `subjects`.
n_events = {"match": [], "mismatch": []}
n_bad_channels = []

epos = dict()
for epo_fpath, subject in tqdm(list(zip(epo_fpaths, subjects)), total=len(epo_fpaths)):
    ep = mne.read_epochs(epo_fpath)
    n_events["match"].append(len(ep["match"]))
    n_events["mismatch"].append(len(ep["mismatch"]))
    n_bad_channels.append(len(ep.info["bads"]))
    # Stored directly: copying and then deleting the original only doubled
    # peak memory for each subject.
    epos[subject] = ep
    del ep

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

sns.set_style("darkgrid")
colors = sns.color_palette()

# Reference values drawn as dashed lines.
EXPECTED_EVENTS = 60
NUM_CHANNELS = 129

fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# Left panel: number of epochs per condition, one point per subject.
sns.violinplot(data=n_events, ax=axes[0])
sns.stripplot(data=n_events, ax=axes[0], color=colors[4])
axes[0].set_title("Number of Events per Condition")
axes[0].set_ylabel("Number of Events")
axes[0].set_xticklabels(["Match", "Mismatch"])
axes[0].axhline(EXPECTED_EVENTS, color="red", linestyle="--", label="Expected Events")
# Without a legend the reference line's label is never shown.
axes[0].legend()

# Right panel: channels not marked bad, one point per subject.
num_good_channels = [NUM_CHANNELS - n for n in n_bad_channels]
sns.violinplot(data=num_good_channels, ax=axes[1])
sns.stripplot(data=num_good_channels, ax=axes[1], color=colors[4])
axes[1].set_title("Number of Good Channels per Subject")
axes[1].set_ylabel("Number of Good Channels")
axes[1].axhline(NUM_CHANNELS, color="red", linestyle="--", label="Total Channels")
axes[1].legend()


# Evoked

In [None]:
# Build the list here so that re-running this cell does not append every
# subject a second time to a list left over from an earlier cell.
evoked_da = []
for subject, epo in tqdm(epos.items(), total=len(epos)):
    evoked_ds = xr.Dataset(
        {
            "match": get_evoked_xr(get_evoked(epo["match"]), subject),
            "mismatch": get_evoked_xr(get_evoked(epo["mismatch"]), subject),
        }
    )
    evoked_da.append(evoked_ds)
# One dataset with dims (subject, channel, time) and one variable per condition.
evoked_semantics = xr.concat(evoked_da, dim="subject")
evoked_semantics.to_netcdf(directory / "derivatives" / "evoked.nc")

In [None]:
# Display the combined per-subject evoked dataset built in the previous cell.
evoked_semantics

In [None]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 2, constrained_layout=True, figsize=(12, 6))

# NOTE(review): `epochs` is the last subject's Epochs left over from the
# epoching loop; this cell fails on a kernel where that loop has not run.
info = epochs.info
ch_names = evoked_semantics.coords["channel"].values
topo_kwargs = dict(vlim=(-.000025, .000025), show=False, names=ch_names)

# One topomap per condition: mean over subjects and over all times >= 1.
for axis, condition, title in zip(ax, ("match", "mismatch"), ("Match", "Mismatch")):
    late_mean = (
        evoked_semantics[condition]
        .sel(time=slice(1, None))
        .mean(["subject", "time"])
    )
    data_dict = dict(zip(ch_names, late_mean.values))
    im = kinnd.viz.plot_topomap(data_dict, info, axes=axis, **topo_kwargs)
    axis.set_title(title)

# Both panels share the same vlim, so one colorbar serves both.
fig.colorbar(im, ax=ax, shrink=0.5, cmap="RdBu_r")
fig.show()

In [None]:
# Channels averaged together for the time-course plot.
ROI = [
    "E29", "E13", "E6", "E112", "E111", "E28", "E20", "E12", "E5", "E118",
    "E39", "E7", "E106", "E105",
]

fig, ax = plt.subplots(constrained_layout=True)

# Long format: one row per (subject, time, condition) for seaborn.
df = evoked_semantics.sel(channel=ROI).mean("channel").to_dataframe().reset_index()
df = df.melt(id_vars=["subject", "time"], var_name="condition", value_name="amplitude")
sns.lineplot(data=df, x="time", y="amplitude", hue="condition", ax=ax, n_boot=1000)

# Shade the 0-1 s window labelled "Image".
ax.axvspan(0, 1, color=colors[4], alpha=0.2)
ax.text(0.5, 0, "Image", fontsize=14)

# NOTE(review): presumably the onset of the spoken word - confirm.
ax.axvline(1, color=colors[5], linestyle="--")

# Labels so the figure can be read on its own; the data were loaded in volts.
ax.set_title("ROI average by condition")
ax.set_xlabel("Time (s)")
ax.set_ylabel("Amplitude (V)")

# Power

In [None]:
def psd_to_xr(psd, subject):
    """Wrap one subject's spectrum in an ``xr.DataArray``.

    ``psd.get_data()`` is placed in a length-1 list so the result has a
    leading ``subject`` dimension. Dims are ``("subject", "channel", "freq")``.
    """
    return xr.DataArray(
        [psd.get_data()],
        coords={
            "subject": [subject],
            "channel": psd.ch_names,
            "freq": psd.freqs,
        },
        dims=("subject", "channel", "freq"),
    )

psd_da = []
for subject, epo in tqdm(epos.items(), total=len(epos)):
    # NOTE: this clears the bad-channel list on the object stored in `epos`
    # (in-place). Presumably done so the PSD keeps every channel - confirm.
    epo.info["bads"] = []
    psd = epo.compute_psd(method="welch", fmin=1, fmax=50)
    # Average the spectra over epochs separately for each condition.
    per_condition = {
        condition: psd_to_xr(psd[condition].average(), subject)
        for condition in ("match", "mismatch")
    }
    this_ds = xr.Dataset(per_condition)
    psd_da.append(this_ds)
psd_semantics = xr.concat(psd_da, dim="subject")

In [None]:
import numpy as np
import seaborn as sns

# Same seaborn style as set in the descriptives plotting cell.
sns.set_style("darkgrid")

def plot_psd(semantics_psd, ax, condition="match"):
    """Plot log10 power spectra for one condition.

    Draws one thin line per channel (median over subjects) and one thick line
    for the median over subjects and channels.

    Parameters
    ----------
    semantics_psd : xr.Dataset
        Dataset with one variable per condition and dims
        ``subject``, ``channel`` and ``freq``.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    condition : str
        Name of the dataset variable to plot.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that was passed in.
    """

    def _log_power(frame):
        """Keep the rows of ``condition`` and put their power on a log10 scale."""
        # .copy() so the assignment below writes to an independent frame
        # rather than to a slice of `frame` (avoids SettingWithCopyWarning).
        selected = frame.loc[frame["condition"] == condition].copy()
        selected["power"] = np.log10(selected["power"])
        return selected

    # One spectrum per channel: median over subjects.
    df = semantics_psd.median("subject").to_dataframe().reset_index()
    df = df.melt(id_vars=["channel", "freq"], var_name="condition", value_name="power")
    df = _log_power(df)

    sns.lineplot(
        data=df,
        x="freq",
        y="power",
        hue="channel",
        ax=ax,
        palette=[sns.color_palette()[0]],
        alpha=0.5,
        linewidth=.5,
        n_boot=100,
        legend=False,
    )

    # Grand average, computed from the dataset that was passed in. It used to
    # read the notebook-level `psd_semantics` regardless of the argument.
    df_grand_av = semantics_psd.median(["subject", "channel"]).to_dataframe().reset_index()
    df_grand_av = df_grand_av.melt(id_vars=["freq"], var_name="condition", value_name="power")
    df_grand_av = _log_power(df_grand_av)

    sns.lineplot(
        data=df_grand_av,
        x="freq",
        y="power",
        ax=ax,
        color=sns.color_palette("tab10")[1],
        linewidth=2,
        errorbar=None,
    )
    ax.set_title(condition)
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("Power (log10)")
    ax.set_xlim(0, 50)
    return ax

# One panel per condition, match on the left.
fig, ax = plt.subplots(1, 2, figsize=(12, 6))

for axis, condition in zip(ax, ("match", "mismatch")):
    plot_psd(psd_semantics, axis, condition)

plt.show()

In [None]:
# NOTE(review): `data_dict`, `info` and `ch_names` are whatever the most
# recently run topomap cell left behind, so the result depends on run order.
fig, ax = plt.subplots(figsize=(8, 8))
topo_values = list(data_dict.values())
mne.viz.plot_topomap(topo_values, info, show=False, names=ch_names, axes=ax)
fig.show()

In [None]:
# Channel names that remain after excluding the SKIRT channels.
psd_semantics.sel(channel=~np.isin(psd_semantics["channel"], SKIRT)).coords["channel"].values

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(12, 6))

# `vlim` is deliberately not passed here. It is only computed further down, so
# referencing it at this point raised NameError on a fresh kernel, and the
# options were reset to `show=False` straight afterwards anyway.
topo_kwargs = dict(show=False)
drop_chs = ["E125", "E126", "E127", "E128"]
# Boolean mask over the PSD channels: True for channels outside SKIRT.
SUBSET = ~np.isin(psd_semantics["channel"], SKIRT)

# NOTE(review): `epochs` is the last subject's Epochs left over from the
# epoching loop; this cell fails on a kernel where that loop has not run.
info = epochs.copy().drop_channels(SKIRT).info

ch_names = psd_semantics.sel(channel=SUBSET).coords["channel"].values


def _band_power_by_channel(condition):
    """Map channel name to 7-11 Hz power (median over subjects, mean over freq)."""
    band_power = (
        psd_semantics[condition]
        .sel(channel=SUBSET, freq=slice(7, 11))
        .median(["subject"])
        .mean(["freq"])
    )
    return dict(zip(ch_names, band_power.values))


data_dict = _band_power_by_channel("match")
# Computed for reference only; it is not applied to either plot.
vlim = np.quantile(list(data_dict.values()), [.05, .95])
im = kinnd.viz.plot_topomap(data_dict, info, axes=ax[0], **topo_kwargs)
ax[0].set_title("Match")

data_dict = _band_power_by_channel("mismatch")
im = kinnd.viz.plot_topomap(data_dict, info, axes=ax[1], **topo_kwargs)
ax[1].set_title("Mismatch")

# NOTE(review): with no shared vlim each panel is scaled on its own, so this
# colorbar describes only the last (mismatch) image.
fig.colorbar(im, ax=ax, shrink=0.5, cmap="RdBu_r")
fig.show()