In [None]:
%matplotlib qt
import numpy as np
import matplotlib.pyplot as plt
import mne

from mne.decoding import (
    GeneralizingEstimator,
    LinearModel,
    Scaler,
    SlidingEstimator,
    Vectorizer,
    cross_val_multiscore,
    get_coef,
)

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold

Define the triggers and look at their timing.

In [None]:
# Load the saved epochs for one recording and keep only the EEG channels.
epochs = mne.read_epochs("/Volumes/Extreme_SSD/payam_data/Tinreg/epochs/gwld-epo.fif", preload=True)
epochs.pick(picks="eeg")
# Reject epochs using a 100e-6 (100 µV) threshold on the EEG channels.
epochs.drop_bad(reject=dict(eeg=100e-6))

# Work on a copy of the name -> code mapping: popping the "New Segment/" marker
# from `epochs.event_id` directly would silently edit the Epochs object itself.
even_ids = dict(epochs.event_id)
even_ids.pop("New Segment/", None)

# Timeline of the remaining triggers. The sampling rate is read from the
# recording rather than hard-coded, so the time axis stays correct if the data
# were resampled.
fig, ax = plt.subplots(1, 1, figsize=(10, 4), layout="tight")
mne.viz.plot_events(epochs.events, sfreq=epochs.info["sfreq"], event_id=even_ids, axes=ax)
ax.get_legend().remove()
ax.spines[["right", "top"]].set_visible(False)

In [None]:
## ordered trials: every condition whose name ends with "or", taken as-is
ord_ids = [name for name in even_ids if name.endswith("or")]
epochs_ord = epochs[ord_ids]

## random trials: one Epochs object per "...rndm" condition, equalized in
## trial count (in place) and then merged back into a single object
rnd_ids = [name for name in even_ids if name.endswith("rndm")]
eps_list = [epochs[name] for name in rnd_ids]
mne.epochs.equalize_epoch_counts(eps_list, method="mintime")
epochs_rnd = mne.concatenate_epochs(eps_list)

## covariance of the balanced random trials, computed with tmax=0.0
cov = mne.compute_covariance(epochs_rnd, tmax=0.0)

## split both sets into their "std" and "tin" conditions for tones f1..f4
ids = range(1, 5)
epochs_ord_std = epochs_ord[[f"f{k}_std_or" for k in ids]]
epochs_ord_tin = epochs_ord[[f"f{k}_tin_or" for k in ids]]

epochs_rnd_std = epochs_rnd[[f"f{k}_std_rndm" for k in ids]]
epochs_rnd_tin = epochs_rnd[[f"f{k}_tin_rndm" for k in ids]]

## the merged objects are no longer needed once the four subsets exist
del epochs_ord, epochs_rnd

Test the classifiers on ordered trials.

In [None]:
# Plot the diagonal of the score matrix against the epoch time axis.
# NOTE(review): `avg_scores` is not created in any visible cell, so it must
# already exist in the kernel; presumably a square (train time x test time)
# matrix with len(epochs_rnd_std.times) entries per side -- TODO confirm.
fig, ax = plt.subplots()
ax.plot(epochs_rnd_std.times, np.diag(avg_scores), label="score")
# 0.25 is the chance level of a 4-class accuracy score (an AUC would have its
# chance level at 0.5), so the y-axis is labelled as accuracy.
ax.axhline(0.25, color="k", linestyle="--", label="chance")
ax.set_xlabel("Times")
ax.set_ylabel("Accuracy")
ax.legend()
# Vertical line at time zero.
ax.axvline(0.0, color="k", linestyle="-")
ax.set_title("Decoding random prestimulus")

In [None]:
# Show the full score matrix as an image, with both axes spanning the epoch
# time range.
# NOTE(review): `avg_scores` is not created in any visible cell; presumably
# rows are training times and columns are testing times -- TODO confirm.
fig, ax = plt.subplots(1, 1)
im = ax.imshow(
    avg_scores,
    interpolation="lanczos",
    origin="lower",
    cmap="RdBu_r",
    # extent = (left, right, bottom, top): first/last time point on both axes
    extent=epochs_rnd_std.times[[0, -1, 0, -1]],
    vmin=0.0,
    vmax=0.4,
)
ax.set_xlabel("Testing Time (s)")
ax.set_ylabel("Training Time (s)")
ax.set_title("Temporal generalization")
# Cross-hair at time zero on both axes.
ax.axvline(0, color="k")
ax.axhline(0, color="k")
cbar = plt.colorbar(im, ax=ax)
# The colour range (0 to 0.4) brackets the 0.25 chance level of a 4-class
# accuracy score, so the colorbar is labelled as accuracy rather than AUC.
cbar.set_label("Accuracy")

In [None]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import mne

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

from mne.viz import plot_events
from mne.datasets import fetch_fsaverage
from mne.minimum_norm import make_inverse_operator, apply_inverse

from mne import (
                    read_epochs, 
                    concatenate_epochs,
                    open_report,
                    compute_covariance,
                    make_forward_solution
                )

from mne.decoding import (
                            GeneralizingEstimator,
                            LinearModel,
                            Scaler,
                            Vectorizer,
                            cross_val_multiscore,
                            get_coef,
                            )

In [None]:
def run_source_analysis(coef_patt, epochs, tmin=None):
    """Project per-class classifier patterns onto the fsaverage cortex.

    Every class pattern is wrapped in an ``mne.EvokedArray`` and inverted with
    dSPM. The forward model is built from the fsaverage files returned by
    ``fetch_fsaverage`` and the noise covariance is computed from ``epochs``
    with ``tmax=0.0``.

    Parameters
    ----------
    coef_patt : ndarray
        Patterns indexed as ``[channel, class, time]``.
    epochs : mne.Epochs
        Epochs supplying the channel info and the noise covariance. The object
        is copied before the average reference is added, so the caller's
        epochs are left untouched.
    tmin : float | None
        Time (in seconds) of the first pattern sample. When ``None`` the
        patterns are assumed to cover the *last* ``coef_patt.shape[-1]``
        samples of ``epochs.times``. This is ``epochs.times[0]`` for
        full-length patterns, and the first post-stimulus sample for patterns
        fitted on ``times >= 0`` only.

    Returns
    -------
    stcs : list
        One source estimate per class, in the order of the class axis of
        ``coef_patt``.

    Raises
    ------
    ValueError
        If ``tmin`` is ``None`` and the patterns have more time samples than
        ``epochs``.
    """
    n_classes, n_times = coef_patt.shape[1], coef_patt.shape[-1]

    if tmin is None:
        n_epoch_times = len(epochs.times)
        if n_times > n_epoch_times:
            raise ValueError(
                f"coef_patt has {n_times} time samples but epochs only has "
                f"{n_epoch_times}; pass tmin explicitly."
            )
        # Align the patterns with the end of the epoch (see docstring).
        tmin = epochs.times[n_epoch_times - n_times]

    # Add the average-reference projector on a copy: it costs one extra copy of
    # the data but avoids modifying the caller's object as a side effect.
    epochs = epochs.copy().set_eeg_reference("average", projection=True)

    # One "evoked" object per class so the inverse operator can be applied.
    evokeds = [
        mne.EvokedArray(coef_patt[:, i_cls, :], epochs.info, tmin=tmin)
        for i_cls in range(n_classes)
    ]

    noise_cov = compute_covariance(epochs, tmax=0.0)

    # Template anatomy: transform, source space and BEM solution of fsaverage.
    fs_dir = fetch_fsaverage()
    trans = fs_dir / "bem" / "fsaverage-trans.fif"
    src = fs_dir / "bem" / "fsaverage-ico-5-src.fif"
    bem = fs_dir / "bem" / "fsaverage-5120-5120-5120-bem-sol.fif"

    fwd = make_forward_solution(
        epochs.info,
        trans=trans,
        src=src,
        bem=bem,
        meg=False,
        eeg=True,
    )
    inv = make_inverse_operator(epochs.info, fwd, noise_cov)

    return [
        apply_inverse(
            evoked,
            inv,
            lambda2=1.0 / 9.0,
            method="dSPM",
            pick_ori="normal",
        )
        for evoked in evokeds
    ]

In [None]:
def decode(subject, saving_dir, epochs_rnd_std, epochs_rnd_tin, epochs_ord_std, epochs_ord_tin, n_splits=2):
    """Train time-generalized decoders on random trials and save the results.

    For each condition ("standard" from ``epochs_rnd_std``, "tinnitus" from
    ``epochs_rnd_tin``) the function:

    1. fits a Scaler -> Vectorizer -> LDA pipeline at every post-stimulus time
       point (``times >= 0``) and saves its filters and patterns under
       ``saving_dir / "coeffs"``;
    2. cross-validates post -> post and post -> pre generalization and saves
       the per-fold scores under ``saving_dir / "scores"``;
    3. projects the patterns to source space with ``run_source_analysis`` and
       saves one source estimate per class under ``saving_dir / "stcs"``.

    Parameters
    ----------
    subject : str
        Identifier used as prefix of every output file name.
    saving_dir : pathlib.Path
        Root output directory; the three sub-directories are created if needed.
    epochs_rnd_std, epochs_rnd_tin : mne.Epochs
        Random-order epochs of the two conditions; the class labels are the
        event codes in ``events[:, 2]``.
    epochs_ord_std, epochs_ord_tin : mne.Epochs
        Ordered-trial epochs. Accepted for interface compatibility but not
        used yet -- no testing on ordered trials is implemented here.
    n_splits : int
        Number of stratified cross-validation folds (default 2).

    Returns
    -------
    stcs : list
        Source estimates of the LAST condition processed ("tinnitus"), as
        before. The estimates of every condition are available on disk.
    """
    scores_dir = saving_dir / "scores"
    coeffs_dir = saving_dir / "coeffs"
    stcs_dir = saving_dir / "stcs"
    for out_dir in (scores_dir, coeffs_dir, stcs_dir):
        # parents=True so a not-yet-existing saving_dir does not abort the run
        out_dir.mkdir(parents=True, exist_ok=True)

    conditions = zip(["standard", "tinnitus"], [epochs_rnd_std, epochs_rnd_tin])
    for label, epochs_rnd in conditions:

        post_mask = epochs_rnd.times >= 0
        pre_mask = epochs_rnd.times < 0
        X = epochs_rnd.get_data()
        y = epochs_rnd.events[:, 2]

        X_post_rnd = X[:, :, post_mask]
        X_pre_rnd = X[:, :, pre_mask]

        ## define and fit generalization object
        clf = make_pipeline(
            Scaler(epochs_rnd.info),
            Vectorizer(),
            LinearModel(LinearDiscriminantAnalysis(solver="svd")),
        )
        gen = GeneralizingEstimator(clf, scoring="accuracy", n_jobs=1, verbose=True)

        ## train on all post-stimulus data (no CV) only to extract the weights
        gen.fit(X_post_rnd, y)
        coef_filt = get_coef(gen, "filters_", inverse_transform=False)
        coef_patt = get_coef(gen, "patterns_", inverse_transform=True)[0]  # (n_chs, n_class, n_time)

        np.save(coeffs_dir / f"{subject}_rnd_params_{label}.npy", coef_filt)
        np.save(coeffs_dir / f"{subject}_rnd_patterns_{label}.npy", coef_patt)

        ## train post -> test post with cv
        cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
        scores_post_post = cross_val_multiscore(gen, X_post_rnd, y, cv=cv, n_jobs=1)

        ## train post -> test pre: same folds, so the pre-stimulus test epochs
        ## are never the epochs the decoder was trained on
        scores_post_pre = []
        for train_idx, test_idx in cv.split(X_post_rnd, y):
            gen.fit(X_post_rnd[train_idx], y[train_idx])  # train on post
            scores_post_pre.append(gen.score(X_pre_rnd[test_idx], y[test_idx]))  # test on pre
        scores_post_pre = np.array(scores_post_pre)

        ## save scores
        np.save(scores_dir / f"{subject}_rnd_post2post_{label}.npy", scores_post_post)
        np.save(scores_dir / f"{subject}_rnd_post2pre_{label}.npy", scores_post_pre)

        ## source space projection of the post-stimulus patterns
        stcs = run_source_analysis(coef_patt, epochs_rnd)

        ## `stcs` is overwritten by the next condition, so persist each
        ## condition's estimates (one file stem per class index)
        for i_cls, stc in enumerate(stcs):
            stc.save(stcs_dir / f"{subject}_rnd_{label}_class{i_cls}", overwrite=True)
    return stcs

In [None]:
from pathlib import Path
# Run the decoding pipeline for subject "gwld"; every output file is written
# below `saving_dir`.
saving_dir = Path("/Volumes/Extreme_SSD/payam_data/Tinreg")
stcs = decode(
    subject="gwld",
    saving_dir=saving_dir,
    epochs_rnd_std=epochs_rnd_std,
    epochs_rnd_tin=epochs_rnd_tin,
    epochs_ord_std=epochs_ord_std,
    epochs_ord_tin=epochs_ord_tin,
)

In [None]:
import json

# Read the subject -> value mapping stored as JSON.
with open('/Users/payamsadeghishabestari/TinReg/sample/subjects_dict.json', 'r') as f:
    mapping = json.load(f)

# One entry per subject: the mapped value, or False when the subject has no
# entry in the mapping.
# NOTE(review): `subjects` is not defined in any visible cell -- it must
# already exist in the kernel.
ssps = [mapping.get(subject, False) for subject in subjects]

In [None]:
# Rebuild the per-subject list: the mapped value, or False when the subject
# has no entry in `mapping`.
# NOTE(review): relies on `mapping` and `subjects` already being defined.
ssps = [mapping.get(subject, False) for subject in subjects]

{'dvob': True,
 'mpuf': True,
 'xrtt': True,
 'vfav': True,
 'pfyq': True,
 'jgdi': True,
 'xweo': True,
 'gwld': True,
 'typy': False,
 'lqzz': False,
 'rojk': False,
 'lwgd': False,
 'ydat': False,
 'swon': False,
 'zjee': False,
 'onne': False,
 'qtsq': False,
 'ztsi': False,
 'zdfy': False,
 'jnvs': False,
 'udul': False,
 'bkai': False,
 'nrjq': False,
 'euyi': False,
 'tzdz': False,
 'tioe': False,
 'dyqe': False,
 'asjt': False,
 'dsno': False,
 'spoh': False,
 'ynkf': False,
 'exoi': False,
 'vuio': False,
 'nxyw': False,
 'mkbb': False,
 'bctw': False,
 'lxot': False,
 'tmlj': False,
 'sgwy': False,
 'xcwq': False}