In [1]:
%matplotlib qt
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import mne

Load data

In [6]:
main_dir = Path("/Volumes/Extreme_SSD/payam_data/Tinreg")
epochs_dir = main_dir / "epochs"
scores_dir = main_dir / "scores"
coeffs_dir = main_dir / "coeffs"
stcs_dir = main_dir / "stcs"

df_subjects = pd.read_csv("../sample/tinreg_master.csv")
df_subjects["ID"] = df_subjects["ID"].str.lower()
df_subjects.dropna(inplace=True)
df_subjects.set_index("ID", inplace=True)
df_subjects["group"] = df_subjects["group"].map({"T": "Tinnitus", "C": "Control"})

cmap1 = sns.cubehelix_palette(rot=-.2).as_hex()
cmap2 = sns.cubehelix_palette(gamma=.5).as_hex()
subjects = np.unique([fname.stem[:4] for fname in scores_dir.iterdir()])

In [7]:
## How does the events look like?

fname = epochs_dir / "asjt-epo.fif"
epochs_sample = mne.read_epochs(fname, verbose=False)

fig, ax = plt.subplots(1, 1, figsize=(10, 4), layout="tight")
mne.viz.plot_events(
                    epochs_sample.events,
                    sfreq=epochs_sample.info["sfreq"],
                    event_id=epochs_sample.event_id,
                    axes=ax
                    )
ax.spines[["right", "top"]].set_visible(False)
ax.get_legend().remove()

  mne.viz.plot_events(


In [10]:
## How does the ERPs look like?

mod1 = "std"
mod2 = "rndm"
keys = [key for key in list(epochs_sample.event_id.keys()) if key.endswith(f"{mod1}_{mod2}")]

evs_co, evs_ti = [], []
for fname in epochs_dir.iterdir():
    if fname.stem.startswith("."):
        continue
    subject = fname.stem[:4]
    group = df_subjects.loc[subject, "group"]
    if group == "Control":
        evs_co.append(mne.read_epochs(fname, verbose=False)[keys].average())
    if group == "Tinnitus":
        evs_ti.append(mne.read_epochs(fname, verbose=False)[keys].average())

for mod in [evs_co, evs_ti]:
    for ev in mod:
        ev.drop_channels(["O1", "O2", "PO7", "PO8"], on_missing="warn")

evs_dict = {"Control": evs_co,
            "Tinnitus": evs_ti}

fig = mne.viz.plot_compare_evokeds(evs_dict, truncate_xaxis=False, time_unit='ms',
                                    colors={"Control": cmap1[3], "Tinnitus": cmap2[5]},
                                    picks="Cz", ylim=dict(eeg=[-1.5, 1.5]),
                                    title=f"{mod1} frequency & {mod2}")

for coll in fig[0].axes[0].collections:
    coll.set_alpha(0.1)
leg = fig[0].axes[0].get_legend()
leg.set_frame_on(False)
fig[0].set_size_inches(8, 4)

for line in fig[0].axes[0].get_lines():
    line.set_linewidth(2.5)


  ev.drop_channels(["O1", "O2", "PO7", "PO8"], on_missing="warn")
  ev.drop_channels(["O1", "O2", "PO7", "PO8"], on_missing="warn")
  ev.drop_channels(["O1", "O2", "PO7", "PO8"], on_missing="warn")
  ev.drop_channels(["O1", "O2", "PO7", "PO8"], on_missing="warn")
  ev.drop_channels(["O1", "O2", "PO7", "PO8"], on_missing="warn")
  ev.drop_channels(["O1", "O2", "PO7", "PO8"], on_missing="warn")
  ev.drop_channels(["O1", "O2", "PO7", "PO8"], on_missing="warn")
  ev.drop_channels(["O1", "O2", "PO7", "PO8"], on_missing="warn")
  ev.drop_channels(["O1", "O2", "PO7", "PO8"], on_missing="warn")


check

In [26]:
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import RidgeClassifierCV
from sklearn.svm import LinearSVC

import mne
from mne.datasets import sample
from mne.decoding import (
    CSP,
    GeneralizingEstimator,
    LinearModel,
    Scaler,
    SlidingEstimator,
    Vectorizer,
    cross_val_multiscore,
    get_coef,
)

In [27]:
all_scores = []
for subject in subjects:
    group = df_subjects.loc[subject, "group"]
    if group == "Control":
        fname = epochs_dir / f"{subject}-epo.fif"
        epochs = mne.read_epochs(fname)
        X = epochs[keys].get_data(picks="eeg")
        y = epochs[keys].events[:, 2]

        clf = make_pipeline(
                            StandardScaler(),
                            LinearModel(LinearSVC(C=1.0))
                            )
        time_gen = GeneralizingEstimator(clf, n_jobs=None, scoring="accuracy", verbose=True)
        scores = cross_val_multiscore(time_gen, X, y, cv=2, n_jobs=None)
        all_scores.append(np.mean(scores, axis=0))


Reading /Volumes/Extreme_SSD/payam_data/Tinreg/epochs/bctw-epo.fif ...
    Found the data of interest:
        t =    -400.00 ...     490.00 ms
        0 CTF compensation matrices available
Not setting metadata
6000 matching events found
No baseline correction applied
0 projection items activated


  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

Reading /Volumes/Extreme_SSD/payam_data/Tinreg/epochs/bkai-epo.fif ...
    Found the data of interest:
        t =    -400.00 ...     490.00 ms
        0 CTF compensation matrices available
Not setting metadata
6000 matching events found
No baseline correction applied
0 projection items activated


  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

Reading /Volumes/Extreme_SSD/payam_data/Tinreg/epochs/dyqe-epo.fif ...
    Found the data of interest:
        t =    -400.00 ...     490.00 ms
        0 CTF compensation matrices available
Not setting metadata
6000 matching events found
No baseline correction applied
0 projection items activated


  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

Reading /Volumes/Extreme_SSD/payam_data/Tinreg/epochs/gwld-epo.fif ...
    Found the data of interest:
        t =    -400.00 ...     490.00 ms
        0 CTF compensation matrices available
Not setting metadata
6000 matching events found
No baseline correction applied
0 projection items activated


  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

Reading /Volumes/Extreme_SSD/payam_data/Tinreg/epochs/jgdi-epo.fif ...
    Found the data of interest:
        t =    -400.00 ...     490.00 ms
        0 CTF compensation matrices available
Not setting metadata
6000 matching events found
No baseline correction applied
0 projection items activated


  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

Reading /Volumes/Extreme_SSD/payam_data/Tinreg/epochs/lxot-epo.fif ...
    Found the data of interest:
        t =    -400.00 ...     490.00 ms
        0 CTF compensation matrices available
Not setting metadata
6000 matching events found
No baseline correction applied
0 projection items activated


  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

Reading /Volumes/Extreme_SSD/payam_data/Tinreg/epochs/pfyq-epo.fif ...
    Found the data of interest:
        t =    -400.00 ...     490.00 ms
        0 CTF compensation matrices available
Not setting metadata
6001 matching events found
No baseline correction applied
0 projection items activated


  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

Reading /Volumes/Extreme_SSD/payam_data/Tinreg/epochs/rojk-epo.fif ...
    Found the data of interest:
        t =    -400.00 ...     490.00 ms
        0 CTF compensation matrices available
Not setting metadata
5965 matching events found
No baseline correction applied
0 projection items activated


  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

Reading /Volumes/Extreme_SSD/payam_data/Tinreg/epochs/sgwy-epo.fif ...
    Found the data of interest:
        t =    -400.00 ...     490.00 ms
        0 CTF compensation matrices available
Not setting metadata
6000 matching events found
No baseline correction applied
0 projection items activated


  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

Reading /Volumes/Extreme_SSD/payam_data/Tinreg/epochs/tmlj-epo.fif ...
    Found the data of interest:
        t =    -400.00 ...     490.00 ms
        0 CTF compensation matrices available
Not setting metadata
6000 matching events found
No baseline correction applied
0 projection items activated


  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

Reading /Volumes/Extreme_SSD/payam_data/Tinreg/epochs/tzdz-epo.fif ...
    Found the data of interest:
        t =    -400.00 ...     490.00 ms
        0 CTF compensation matrices available
Not setting metadata
6000 matching events found
No baseline correction applied
0 projection items activated


  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

Reading /Volumes/Extreme_SSD/payam_data/Tinreg/epochs/udul-epo.fif ...
    Found the data of interest:
        t =    -400.00 ...     490.00 ms
        0 CTF compensation matrices available
Not setting metadata
6000 matching events found
No baseline correction applied
0 projection items activated


  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

  0%|          | Fitting GeneralizingEstimator : 0/90 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/8100 [00:00<?,       ?it/s]

In [None]:
scores = np.mean(all_scores, axis=0)
scores.shape

In [31]:
fig, ax = plt.subplots()
ax.plot(epochs.times, np.diag(scores), label="score")
ax.axhline(0.25, color="k", linestyle="--", label="chance")
ax.set_xlabel("Times")
ax.set_ylabel("AUC")
ax.legend()
ax.axvline(0.0, color="k", linestyle="-")
ax.set_title("Decoding MEG sensors over time")

Text(0.5, 1.0, 'Decoding MEG sensors over time')

In [32]:
fig, ax = plt.subplots(1, 1)
im = ax.imshow(
    scores,
    interpolation="lanczos",
    origin="lower",
    cmap="RdBu_r",
    extent=epochs.times[[0, -1, 0, -1]],
    vmin=0.14,
    vmax=0.36,
)
ax.set_xlabel("Testing Time (s)")
ax.set_ylabel("Training Time (s)")
ax.set_title("Temporal generalization")
ax.axvline(0, color="k")
ax.axhline(0, color="k")
cbar = plt.colorbar(im, ax=ax)
cbar.set_label("AUC")

Trained and tested on random trials

In [33]:
mode_1 = "rnd"
mode_2 = "standard"
scores_dict = {}
groups_dict = {}

for subject in subjects:
    fname_po2po = scores_dir / f"{subject}_{mode_1}_post2post_{mode_2}.npy"
    score_po2po = np.load(fname_po2po, allow_pickle=True) # n_cv, n_time_post, n_time_post (5, 50, 50)
    
    fname_po2pr = scores_dir / f"{subject}_{mode_1}_post2pre_{mode_2}.npy"
    score_po2pr = np.load(fname_po2pr, allow_pickle=True) # n_cv, n_time_post, n_time_post (5, 50, 40)
    
    groups_dict[subject] = df_subjects.loc[subject, "group"]
    scores_dict[subject] = np.concatenate(
                                        [np.mean(score_po2pr, axis=0),
                                        np.mean(score_po2po, axis=0)],
                                        axis=-1
                                        )

scores_co = []
scores_ti = []
for subject, group in groups_dict.items():
    if group == "Control":
        scores_co.append(scores_dict[subject])
    if group == "Tinnitus":
        scores_ti.append(scores_dict[subject])

scores_co = np.array(scores_co)
scores_ti = np.array(scores_ti)    

In [34]:
## 1D plot
diag_po2po_co = np.array([np.diag(s) for s in scores_co[:, :, 40:]])
diag_po2pr_co = np.array([np.diag(s) for s in scores_co[:, :, :40]])
diag_po2po_ti = np.array([np.diag(s) for s in scores_ti[:, :, 40:]])
diag_po2pr_ti = np.array([np.diag(s) for s in scores_ti[:, :, :40]])


diag_co = np.concatenate([diag_po2pr_co, diag_po2po_co], axis=1)
diag_ti = np.concatenate([diag_po2pr_ti, diag_po2po_ti], axis=1)
mean1, std1 = diag_co.mean(axis=0), diag_co.std(axis=0)
mean2, std2 = diag_ti.mean(axis=0), diag_ti.std(axis=0)

fig, ax = plt.subplots(1, 1, layout="tight", figsize=(9, 4))
ax.plot(epochs.times, mean1, color=cmap1[3], label="Control", lw=2)
ax.fill_between(epochs.times, mean1-std1, mean1+std1, alpha=0.1, color=cmap1[3])
ax.plot(epochs.times, mean2, color=cmap2[5], label="Tinnitus", lw=2)
ax.fill_between(epochs.times, mean2-std2, mean2+std2, alpha=0.1, color=cmap2[5])

ax.vlines(0, 0.22, 0.28, linestyles="--", color="k")
ax.hlines(0.25, -0.4, 0.5, linestyles="--", color="grey")
ax.set_ylim([0.22, 0.28])
ax.set_xlim([-0.4, 0.5])
ax.spines[["right", "top"]].set_visible(False)
ax.legend(frameon=False)
ax.set_ylabel("Decoding accuracy")
ax.set_title("Decoding on Random Trials")
ax.set_xlabel("Time")

Text(0.5, 0, 'Time')

In [35]:
## 2D plot
fig, ax = plt.subplots(1, 1, layout="tight", figsize=(9, 4))
im = ax.imshow(
    scores_co.mean(axis=0),
    interpolation="lanczos",
    origin="lower",
    cmap="RdBu_r",
    extent=epochs.times[[0, -1, 40, -1]],
    vmin=0.235,
    vmax=0.265,
)
ax.set_xlabel("Testing Time (s)")
ax.set_ylabel("Training Time (s)")
ax.set_title("Trained on random trials")
ax.axvline(0, color="k")
ax.axhline(0, color="k")
cbar = plt.colorbar(im, ax=ax)
cbar.set_label("Decoding accuracy")

Trained on Random trials and tested on Ordered trials

In [36]:
subjects = np.unique([fname.stem[:4] for fname in scores_dir.iterdir()])
mode_1 = "ord"
mode_2 = "standard"
scores_dict = {}
groups_dict = {}

for subject in subjects:
    fname_po2po = scores_dir / f"{subject}_{mode_1}_post2post_{mode_2}.npy"
    score_po2po = np.load(fname_po2po, allow_pickle=True) # n_cv, n_time_post, n_time_post (5, 50, 50)
    
    fname_po2pr = scores_dir / f"{subject}_{mode_1}_post2pre_{mode_2}.npy"
    score_po2pr = np.load(fname_po2pr, allow_pickle=True) # n_cv, n_time_post, n_time_post (5, 50, 40)
    
    groups_dict[subject] = df_subjects.loc[subject, "group"]
    scores_dict[subject] = np.concatenate([score_po2pr, score_po2po], axis=-1)

scores_co = []
scores_ti = []
for subject, group in groups_dict.items():
    if group == "Control":
        scores_co.append(scores_dict[subject])
    if group == "Tinnitus":
        scores_ti.append(scores_dict[subject])

scores_co = np.array(scores_co)
scores_ti = np.array(scores_ti)  

In [37]:
scores_co

array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       ...,

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0.

In [None]:
## 1D plot
diag_po2po_co = np.array([np.diag(s) for s in scores_co[:, :, 40:]])
diag_po2pr_co = np.array([np.diag(s) for s in scores_co[:, :, :40]])
diag_po2po_ti = np.array([np.diag(s) for s in scores_ti[:, :, 40:]])
diag_po2pr_ti = np.array([np.diag(s) for s in scores_ti[:, :, :40]])


diag_co = np.concatenate([diag_po2pr_co, diag_po2po_co], axis=1)
diag_ti = np.concatenate([diag_po2pr_ti, diag_po2po_ti], axis=1)
mean1, std1 = diag_co.mean(axis=0), diag_co.std(axis=0)
mean2, std2 = diag_ti.mean(axis=0), diag_ti.std(axis=0)

fig, ax = plt.subplots(1, 1, layout="tight", figsize=(9, 4))
ax.plot(epochs.times, mean1, color=cmap1[3], label="Control", lw=2)
ax.fill_between(epochs.times, mean1-std1, mean1+std1, alpha=0.1, color=cmap1[3])
ax.plot(epochs.times, mean2, color=cmap2[5], label="Tinnitus", lw=2)
ax.fill_between(epochs.times, mean2-std2, mean2+std2, alpha=0.1, color=cmap2[5])

ax.vlines(0, 0.22, 0.28, linestyles="--", color="k")
ax.hlines(0.25, -0.4, 0.5, linestyles="--", color="grey")
# ax.set_ylim([0.22, 0.28])
ax.set_xlim([-0.4, 0.5])
ax.spines[["right", "top"]].set_visible(False)
ax.legend(frameon=False)
ax.set_ylabel("Decoding accuracy")
ax.set_title("Decoding on Random Trials")
ax.set_xlabel("Time")

Check why they are all 0

In [None]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import mne

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

from mne.viz import plot_events
from mne.datasets import fetch_fsaverage
from mne.minimum_norm import make_inverse_operator, apply_inverse

from mne import (
                    read_epochs, 
                    concatenate_epochs,
                    open_report,
                    compute_covariance,
                    make_forward_solution
                )

from mne.decoding import (
                            GeneralizingEstimator,
                            LinearModel,
                            Scaler,
                            Vectorizer,
                            cross_val_multiscore,
                            get_coef,
                            )


def split_epochs(subject, saving_dir):

    ## check paths
    ep_fname = saving_dir / "epochs" / f"{subject}-epo.fif"
    re_fname1 = saving_dir / "reports" / f"{subject}-report.h5" 
    re_fname2 = saving_dir / "reports" / f"{subject}-report.html"
    overwrite = True
    
    # if re_fname2.exists():
    #     return None

    ## read and modify epochs/report
    epochs = read_epochs(ep_fname, preload=True)
    report = open_report(re_fname1)
    sfreq = epochs.info["sfreq"]
    epochs.pick(picks="eeg")
    epochs.drop_bad(reject=dict(eeg=100e-6)) # maybe totally remove this
    even_ids = epochs.event_id
    even_ids.pop("New Segment/", None)

    fig_events, ax = plt.subplots(1, 1, figsize=(10, 4), layout="tight")
    plot_events(epochs.events, sfreq=sfreq, event_id=even_ids, axes=ax, show=False)
    ax.get_legend().remove()
    ax.spines[["right", "top"]].set_visible(False)
    fig_drop = epochs.plot_drop_log(show=False)

    report.add_figure(fig=fig_events, title="Events", image_format="PNG")
    report.add_figure(fig=fig_drop, title="Drop log", image_format="PNG")

    ## balance trials across the 4 classes in training
    rnd_ids = [key for key in even_ids if key.endswith("rndm")]
    eps_list = [epochs[rnd_id] for rnd_id in rnd_ids]
    mne.epochs.equalize_epoch_counts(eps_list, method="mintime")
    epochs_rnd = concatenate_epochs(eps_list)

    ord_ids = [key for key in even_ids if key.endswith("or")]
    epochs_ord = epochs[ord_ids]

    report.add_epochs(epochs_rnd, title="Random trials info", psd=False, projs=False)
    report.add_epochs(epochs_ord, title="Ordered trials info", psd=False, projs=False)

    ## compute covariance matrix from rnd epochs
    cov = compute_covariance(epochs_rnd, tmax=0.0)

    ## define epochs
    ids = range(1, 5) 
    epochs_rnd_std = epochs_rnd[[f"f{i}_std_rndm" for i in ids]]
    epochs_rnd_tin = epochs_rnd[[f"f{i}_tin_rndm" for i in ids]]

    epochs_ord_std = epochs_ord[[f"f{i}_std_or" for i in ids]]
    epochs_ord_tin = epochs_ord[[f"f{i}_tin_or" for i in ids]]

    # report.save(saving_dir / "reports" / f"{subject}-report.html",
    #             overwrite=overwrite, open_browser=False)

    del epochs_ord, epochs_rnd

    return epochs_rnd_std, epochs_rnd_tin, epochs_ord_std, epochs_ord_tin




def decode(subject, saving_dir, epochs_rnd_std, epochs_rnd_tin, epochs_ord_std, epochs_ord_tin):
    
    ###### train clf on random trials
    n_splits = 5
    scores_dir = saving_dir / "scores"
    coeffs_dir = saving_dir / "coeffs"
    stcs_dir = saving_dir / "stcs"
    [sel_dir.mkdir(exist_ok=True) for sel_dir in [scores_dir, coeffs_dir, stcs_dir]]

    labels = ["standard", "tinnitus"]
    for epochs_rnd, epochs_ord, label in zip([epochs_rnd_std, epochs_rnd_tin], [epochs_ord_std, epochs_ord_tin], labels):

        post_mask = epochs_rnd.times >= 0
        pre_mask  = epochs_rnd.times < 0
        X = epochs_rnd.get_data()
        y = epochs_rnd.events[:, 2]

        X_post_rnd = X[:, :, post_mask]
        X_pre_rnd = X[:, :, pre_mask]

        ## define and fit generalization object
        clf = make_pipeline(
                            Scaler(epochs_rnd.info),
                            Vectorizer(),           
                            LinearModel(LinearDiscriminantAnalysis(solver="svd"))
                            )
        gen = GeneralizingEstimator(clf, scoring="accuracy", n_jobs=1, verbose=True)

        ## train post -> test post with fit (to extract weights)
        gen.fit(X_post_rnd, y)
        coef_filt = get_coef(gen, "filters_", inverse_transform=False)
        coef_patt = get_coef(gen, "patterns_", inverse_transform=True)[0] # (n_chs, n_class, n_time)

        # np.save(coeffs_dir / f"{subject}_rnd_params_{label}.npy", coef_filt)
        # np.save(coeffs_dir / f"{subject}_rnd_patterns_{label}.npy", coef_patt)

        ## train post -> test post with cv
        cv = StratifiedKFold(n_splits=2, shuffle=True, random_state=42)
        scores_post_post = cross_val_multiscore(gen, X_post_rnd, y, cv=cv, n_jobs=1)

        ## train post -> test pre
        scores_post_pre = []
        for train_idx, test_idx in cv.split(X_post_rnd, y):
            gen.fit(X_post_rnd[train_idx], y[train_idx]) # train on post
            score = gen.score(X_pre_rnd[test_idx], y[test_idx]) # test on pre
            scores_post_pre.append(score)
        scores_post_pre = np.array(scores_post_pre)

        ## save scores and coeffs 
        # np.save(scores_dir / f"{subject}_rnd_post2post_{label}.npy", scores_post_post)
        # np.save(scores_dir / f"{subject}_rnd_post2pre_{label}.npy", scores_post_pre)

        ## source space decoding for random post
        # stcs = run_source_analysis(coef_patt, epochs_rnd)
        # [stc.save(stcs_dir / f"{subject}_rnd_class_{label}_{stc_idx + 1}") for stc_idx, stc in enumerate(stcs)]

        ## test on ordered tones
        X_ord = epochs_ord.get_data()
        y_ord = epochs_ord.events[:, 2]

        times_ord = epochs_ord.times
        post_mask_ord = epochs_ord.times >= 0
        pre_mask_ord  = epochs_ord.times < 0

        X_ord_post = X_ord[:, :, post_mask_ord]
        X_ord_pre  = X_ord[:, :, pre_mask_ord]

        

        gen.fit(X_post_rnd, y) # train again on random

        ## scores and coeffs
        if label == "standard":
            mapping_ord_2_rnd = dict(zip(range(1, 5), range(5, 9)))
        if label == "tinnitus":
            mapping_ord_2_rnd = dict(zip(range(11, 15), range(15, 19)))
        
        y_ord_mapped = np.array([mapping_ord_2_rnd[val] for val in y_ord])

        score_ord_post = gen.score(X_ord_post, y_ord_mapped)
        score_ord_pre = gen.score(X_ord_pre, y_ord_mapped)

        coef_filt_ord = get_coef(gen, "filters_", inverse_transform=False) # (n_chs, n_class, n_time)
        coef_patt_ord = get_coef(gen, "patterns_", inverse_transform=True)[0] # (n_chs, n_class, n_time) # check this later

        return score_ord_pre, score_ord_post, coef_filt_ord, coef_patt_ord

        ## save in numpy array  
        # np.save(scores_dir / f"{subject}_ord_post2post_{label}.npy", score_ord_post)
        # np.save(scores_dir / f"{subject}_ord_post2pre_{label}.npy", score_ord_pre)

        # np.save(coeffs_dir / f"{subject}_ord_params_{label}.npy", coef_filt_ord)
        # np.save(coeffs_dir / f"{subject}_ord_patterns_{label}.npy", coef_patt_ord)


def run_source_analysis(coef_patt, epochs):

    epochs.set_eeg_reference("average", projection=True)
    evokeds = []
    for i_cls in range(4):
        evokeds.append(
                        mne.EvokedArray(coef_patt[:, i_cls, :],
                                        epochs.info,
                                        tmin=epochs.times[0])
                        )

    noise_cov = compute_covariance(epochs, tmax=0.0)
    kwargs = {
                "subject": "fsaverage",
                "subjects_dir": None
            }

    fs_dir = fetch_fsaverage()
    trans = fs_dir / "bem" / "fsaverage-trans.fif"
    src = fs_dir / "bem" / "fsaverage-ico-5-src.fif"
    bem = fs_dir / "bem" / "fsaverage-5120-5120-5120-bem-sol.fif"

    fwd = make_forward_solution(
                                epochs.info,
                                trans=trans,
                                src=src,
                                bem=bem,
                                meg=False,
                                eeg=True
                                )
    inv = make_inverse_operator(
                                epochs.info,
                                fwd,
                                noise_cov
                                )
    stcs = []
    for evoked in evokeds:
        stcs.append(
                    apply_inverse(
                            evoked, 
                            inv,
                            lambda2=1.0 / 9.0,
                            method="dSPM",
                            pick_ori="normal"
                            )
                    )
    
    del fwd, inv
    return stcs

In [46]:
subject = "tzdz"
saving_dir = Path("/Volumes/Extreme_SSD/payam_data/Tinreg")
epochs_rnd_std, epochs_rnd_tin, epochs_ord_std, epochs_ord_tin = split_epochs(subject, saving_dir)

Reading /Volumes/Extreme_SSD/payam_data/Tinreg/epochs/tzdz-epo.fif ...
    Found the data of interest:
        t =    -400.00 ...     490.00 ms
        0 CTF compensation matrices available
Not setting metadata
6000 matching events found
No baseline correction applied
0 projection items activated
Embedding : jquery-3.6.0.min.js
Embedding : bootstrap.bundle.min.js
Embedding : bootstrap.min.css
Embedding : bootstrap-table/bootstrap-table.min.js
Embedding : bootstrap-table/bootstrap-table.min.css
Embedding : bootstrap-table/bootstrap-table-copy-rows.min.js
Embedding : bootstrap-table/bootstrap-table-export.min.js
Embedding : bootstrap-table/tableExport.min.js
Embedding : bootstrap-icons/bootstrap-icons.mne.min.css
Embedding : highlightjs/highlight.min.js
Embedding : highlightjs/atom-one-dark-reasonable.min.css
    Rejecting  epoch based on EEG : ['Fp1', 'TP10']
    Rejecting  epoch based on EEG : ['Fp1']
    Rejecting  epoch based on EEG : ['Fp1', 'AF7']
    Rejecting  epoch based on EEG 

  plot_events(epochs.events, sfreq=sfreq, event_id=even_ids, axes=ax, show=False)


Dropped 58 epochs: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57
Dropped 12 epochs: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11
Dropped 13 epochs: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12
Dropped 29 epochs: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28
Dropped 27 epochs: 34, 69, 70, 72, 73, 75, 76, 108, 109, 110, 112, 114, 139, 146, 152, 153, 154, 155, 173, 241, 242, 244, 288, 289, 291, 292, 301
Dropped 41 epochs: 101, 102, 103, 104, 107, 150, 176, 177, 192, 193, 194, 195, 197, 199, 200, 201, 202, 207, 209, 211, 214, 215, 216, 218, 220, 228, 229, 299, 300, 304, 305, 307, 308, 310, 313, 314, 318, 322, 323, 325, 328
Dropped 0 epochs: 
Dropped 17 epochs: 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 44, 45, 46, 107, 115, 118, 119
Not setting metadata
2696 matching events 

  epochs_rnd = concatenate_epochs(eps_list)
  cov = compute_covariance(epochs_rnd, tmax=0.0)


Reducing data rank from 63 -> 63
Estimating covariance using EMPIRICAL
Done.
Number of samples used : 110536
[done]


In [87]:
epochs = mne.read_epochs("/Volumes/Extreme_SSD/payam_data/Tinreg/epochs/tzdz-epo.fif")
epochs.event_id

Reading /Volumes/Extreme_SSD/payam_data/Tinreg/epochs/tzdz-epo.fif ...
    Found the data of interest:
        t =    -400.00 ...     490.00 ms
        0 CTF compensation matrices available
Not setting metadata
6000 matching events found
No baseline correction applied
0 projection items activated


{'f1_std_or': 1,
 'f2_std_or': 2,
 'f3_std_or': 3,
 'f4_std_or': 4,
 'f1_std_rndm': 5,
 'f2_std_rndm': 6,
 'f3_std_rndm': 7,
 'f4_std_rndm': 8,
 'f1_tin_or': 11,
 'f2_tin_or': 12,
 'f3_tin_or': 13,
 'f4_tin_or': 14,
 'f1_tin_rndm': 15,
 'f2_tin_rndm': 16,
 'f3_tin_rndm': 17,
 'f4_tin_rndm': 18}

In [80]:
epochs_ord_std.event_id

{'f1_std_or': 1, 'f2_std_or': 2, 'f3_std_or': 3, 'f4_std_or': 4}

In [84]:
score_ord_pre, score_ord_post, coef_filt_ord, coef_patt_ord = decode(subject, saving_dir, epochs_rnd_std, epochs_rnd_tin, epochs_ord_std, epochs_ord_tin)

  0%|          | Fitting GeneralizingEstimator : 0/50 [00:00<?,       ?it/s]

  0%|          | Fitting GeneralizingEstimator : 0/50 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/2500 [00:00<?,       ?it/s]

  0%|          | Fitting GeneralizingEstimator : 0/50 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/2500 [00:00<?,       ?it/s]

  0%|          | Fitting GeneralizingEstimator : 0/50 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/2000 [00:00<?,       ?it/s]

  0%|          | Fitting GeneralizingEstimator : 0/50 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/2000 [00:00<?,       ?it/s]

  0%|          | Fitting GeneralizingEstimator : 0/50 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/2500 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/2000 [00:00<?,       ?it/s]

In [None]:
epochs

array([[[ 0.00366324,  0.00429946,  0.00040037, ..., -0.00091414,
          0.00046053,  0.00220295],
        [-0.00082582, -0.00062171, -0.00334085, ..., -0.00645047,
         -0.00556678, -0.00874295],
        [-0.00036778, -0.00195855,  0.00038022, ...,  0.00240001,
          0.00378135,  0.0010115 ],
        [-0.00246964, -0.0017192 ,  0.00256026, ...,  0.00496461,
          0.0013249 ,  0.0055285 ]],

       [[ 0.00535992,  0.00740735,  0.0062571 , ...,  0.00839391,
          0.00860925,  0.00755414],
        [ 0.00338569,  0.00163175,  0.00474385, ...,  0.00751926,
          0.0039333 ,  0.00362832],
        [-0.01501283, -0.0116235 , -0.01141998, ..., -0.01151534,
         -0.01161128, -0.01258841],
        [ 0.00626723,  0.0025844 ,  0.00041903, ..., -0.00439784,
         -0.00093128,  0.00140595]],

       [[-0.00479005, -0.00587304, -0.0033084 , ..., -0.00223297,
          0.00032441, -0.00037668],
        [ 0.00081993,  0.00019521, -0.00048077, ...,  0.00462582,
          0.

In [None]:
clf = make_pipeline(
                            Scaler(epochs_rnd_std.info),
                            Vectorizer(),           
                            LinearModel(LinearDiscriminantAnalysis(solver="svd"))
                            )
gen = GeneralizingEstimator(clf, scoring="accuracy", n_jobs=1, verbose=True)
gen.fit(X_post_rnd, y)

y_ord = epochs_ord_std.events[:, 2]
gen.score(X_ord_post, y_ord)

  0%|          | Fitting GeneralizingEstimator : 0/50 [00:00<?,       ?it/s]

0,1,2
,base_estimator,Pipeline(step...Analysis()))])
,scoring,'accuracy'
,n_jobs,1
,position,0
,allow_2d,False
,verbose,True

0,1,2
,info,<Info | 11 no...eq: 100.0 Hz >
,scalings,
,with_mean,True
,with_std,True

0,1,2
,model,LinearDiscriminantAnalysis()

0,1,2
,solver,'svd'
,shrinkage,
,priors,
,n_components,
,store_covariance,False
,tol,0.0001
,covariance_estimator,


In [72]:
X_ord_post.shape

(1471, 63, 50)

In [76]:
np.unique(y_ord)

array([1, 2, 3, 4], dtype=int32)

In [74]:
print(X_post_rnd.shape, y.shape)
print(X_ord_post.shape, y_ord.shape)


(1348, 63, 50) (1348,)
(1471, 63, 50) (1471,)
