In [1]:
%matplotlib qt
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import mne

Load data

In [100]:
# Root of the Tinreg derivatives and its per-stage sub-directories.
# NOTE(review): absolute path on an external drive; consider making it configurable.
main_dir = Path("/Volumes/Extreme_SSD/payam_data/Tinreg")
epochs_dir, scores_dir, coeffs_dir, stcs_dir = (
    main_dir / sub for sub in ("epochs", "scores", "coeffs", "stcs")
)

# Subject table indexed by lower-cased ID. Rows with a missing value in ANY
# column are dropped, and the group codes are spelled out ("T" -> "Tinnitus",
# "C" -> "Control"); any other code becomes NaN.
df_subjects = (
    pd.read_csv("../sample/tinreg_master.csv")
    .assign(ID=lambda frame: frame["ID"].str.lower())
    .dropna()
    .set_index("ID")
    .assign(group=lambda frame: frame["group"].map({"T": "Tinnitus", "C": "Control"}))
)

# Hex color palettes; the plots below use cmap1[3] for Control and cmap2[5] for Tinnitus.
cmap1 = sns.cubehelix_palette(rot=-.2).as_hex()
cmap2 = sns.cubehelix_palette(gamma=.5).as_hex()

In [10]:
## What do the events look like?
# One subject's epochs act as the reference: later cells reuse its
# event_id (to select event names) and its time axis.
fname = epochs_dir / "asjt-epo.fif"
epochs_sample = mne.read_epochs(fname, verbose=False)

fig, ax = plt.subplots(1, 1, figsize=(10, 4), layout="tight")
mne.viz.plot_events(
    epochs_sample.events,
    axes=ax,
    event_id=epochs_sample.event_id,
    sfreq=epochs_sample.info["sfreq"],
)
# Minimal styling: no top/right frame and no legend.
ax.spines[["right", "top"]].set_visible(False)
legend = ax.get_legend()
legend.remove()

  mne.viz.plot_events(


In [None]:
## What do the ERPs look like?
# Average, per subject, all events whose name ends with "<mod1>_<mod2>",
# then compare the Control and Tinnitus groups at Cz.
mod1 = "tin"
mod2 = "rndm"
keys = [
    event_name
    for event_name in epochs_sample.event_id
    if event_name.endswith(f"{mod1}_{mod2}")
]

# Per-group lists of evoked responses; the dict doubles as the dispatch
# table while loading and as the input to plot_compare_evokeds.
evs_co, evs_ti = [], []
evs_dict = {"Control": evs_co,
            "Tinnitus": evs_ti}

for fname in epochs_dir.iterdir():
    # skip hidden files (e.g. ".DS_Store", macOS "._*" files)
    if fname.stem.startswith("."):
        continue
    subject = fname.stem[:4]  # subject ID = first 4 characters of the file name
    group = df_subjects.loc[subject, "group"]
    if group not in ("Control", "Tinnitus"):
        continue  # subjects with any other group label are ignored
    evs_dict[group].append(mne.read_epochs(fname, verbose=False)[keys].average())

# Remove occipital / parieto-occipital channels (warn if a subject lacks one).
for ev in evs_co + evs_ti:
    ev.drop_channels(["O1", "O2", "PO7", "PO8"], on_missing="warn")

fig = mne.viz.plot_compare_evokeds(evs_dict, truncate_xaxis=False, time_unit='ms',
                                    colors={"Control": cmap1[3], "Tinnitus": cmap2[5]},
                                    picks="Cz", ylim=dict(eeg=[-1.5, 1.5]),
                                    title=f"{mod1} frequency & {mod2}")

# Cosmetic tweaks on the single axes: faint shaded bands, thick lines,
# frameless legend, fixed figure size.
erp_ax = fig[0].axes[0]
for coll in erp_ax.collections:
    coll.set_alpha(0.1)
for line in erp_ax.get_lines():
    line.set_linewidth(2.5)
leg = erp_ax.get_legend()
leg.set_frame_on(False)
fig[0].set_size_inches(8, 4)


Check decoding results

In [69]:
# Subject IDs are the first 4 characters of each score file name. Hidden files
# (".DS_Store", macOS "._*" AppleDouble files on external drives) are skipped,
# as in the ERP cell; otherwise they yield bogus IDs whose score files do not exist.
subjects = np.unique([
    fname.stem[:4]
    for fname in scores_dir.iterdir()
    if not fname.stem.startswith(".")
])
mode_1 = "rnd"       # file-name tag of the trial type (random trials, per the plot titles)
mode_2 = "standard"  # file-name tag of the second condition selector
scores_dict = {}
groups_dict = {}

for subject in subjects:
    # NOTE(review): allow_pickle=True is only needed for object arrays and can
    # execute code from the file; acceptable only because these are our own outputs.
    fname_po2po = scores_dir / f"{subject}_{mode_1}_post2post_{mode_2}.npy"
    score_po2po = np.load(fname_po2po, allow_pickle=True)  # n_cv, n_time_post (train), n_time_post (test), e.g. (5, 50, 50)

    fname_po2pr = scores_dir / f"{subject}_{mode_1}_post2pre_{mode_2}.npy"
    score_po2pr = np.load(fname_po2pr, allow_pickle=True)  # n_cv, n_time_post (train), n_time_pre (test), e.g. (5, 50, 40)

    groups_dict[subject] = df_subjects.loc[subject, "group"]
    # Average over CV folds and place the pre-stimulus test times before the
    # post-stimulus ones -> (n_time_post, n_time_pre + n_time_post).
    scores_dict[subject] = np.concatenate(
        [np.mean(score_po2pr, axis=0), np.mean(score_po2po, axis=0)],
        axis=-1,
    )

# Stack per group -> (n_subjects, n_time_post, n_time_pre + n_time_post).
scores_co = np.array([scores_dict[s] for s, g in groups_dict.items() if g == "Control"])
scores_ti = np.array([scores_dict[s] for s, g in groups_dict.items() if g == "Tinnitus"])

In [112]:
## 1D plot: decoding accuracy along the diagonal of the score matrices
# Each score matrix is (n_train_time_post, n_test_time_pre + n_test_time_post);
# the first `n_pre` test columns are pre-stimulus (post -> pre), the rest post -> post.
n_pre = 40

# On the square post->post block np.diag gives train time == test time. On the
# non-square post->pre block it pairs training sample i with pre-stimulus test
# sample i (i < n_pre). NOTE(review): confirm this pairing is the intended one.
diag_po2po_co = np.array([np.diag(s) for s in scores_co[:, :, n_pre:]])
diag_po2pr_co = np.array([np.diag(s) for s in scores_co[:, :, :n_pre]])
diag_po2po_ti = np.array([np.diag(s) for s in scores_ti[:, :, n_pre:]])
diag_po2pr_ti = np.array([np.diag(s) for s in scores_ti[:, :, :n_pre]])

diag_co = np.concatenate([diag_po2pr_co, diag_po2po_co], axis=1)
diag_ti = np.concatenate([diag_po2pr_ti, diag_po2po_ti], axis=1)
# Across-subject mean and standard deviation (SD, not SEM) per time point.
mean1, std1 = diag_co.mean(axis=0), diag_co.std(axis=0)
mean2, std2 = diag_ti.mean(axis=0), diag_ti.std(axis=0)

# `epochs` is never defined in this notebook (it only existed as leftover
# kernel state). Use the sample epochs loaded above, and fail loudly if its
# time axis does not match the score columns (e.g. different sampling rate).
times = epochs_sample.times
assert len(times) == diag_co.shape[1], (
    f"time axis has {len(times)} samples but scores have {diag_co.shape[1]} columns"
)

fig, ax = plt.subplots(1, 1, layout="tight", figsize=(9, 4))
ax.plot(times, mean1, color=cmap1[3], label="Control", lw=2)
ax.fill_between(times, mean1 - std1, mean1 + std1, alpha=0.1, color=cmap1[3])
ax.plot(times, mean2, color=cmap2[5], label="Tinnitus", lw=2)
ax.fill_between(times, mean2 - std2, mean2 + std2, alpha=0.1, color=cmap2[5])

ax.vlines(0, 0.22, 0.28, linestyles="--", color="k")       # t = 0 (event)
ax.hlines(0.25, -0.4, 0.5, linestyles="--", color="grey")  # presumably chance level (4 classes) -- confirm
ax.set_ylim([0.22, 0.28])
ax.set_xlim([-0.4, 0.5])
ax.spines[["right", "top"]].set_visible(False)
ax.legend(frameon=False)
ax.set_ylabel("Decoding accuracy")
ax.set_title("Decoding on Random Trials")
ax.set_xlabel("Time (s)");  # trailing ';' suppresses the Text(...) repr

Text(0.5, 0, 'Time')

In [117]:
## 2D plot: temporal-generalization matrix averaged over the Control group
n_pre = 40  # the first n_pre test columns are pre-stimulus time points
mean_scores_co = scores_co.mean(axis=0)  # (n_train_time_post, n_test_time_pre + n_test_time_post)

# `epochs` is never defined in this notebook (it only existed as leftover
# kernel state); use the sample epochs loaded above and check it matches.
times = epochs_sample.times
assert len(times) == mean_scores_co.shape[1], (
    f"time axis has {len(times)} samples but scores have {mean_scores_co.shape[1]} columns"
)

fig, ax = plt.subplots(1, 1, layout="tight", figsize=(9, 4))
im = ax.imshow(
    mean_scores_co,
    interpolation="lanczos",
    origin="lower",
    cmap="RdBu_r",
    # extent = (left, right, bottom, top): x spans every test time, y spans
    # only the post-stimulus training times (from index n_pre onwards).
    extent=times[[0, -1, n_pre, -1]],
    vmin=0.235,
    vmax=0.265,
)
ax.set_xlabel("Testing Time (s)")
ax.set_ylabel("Training Time (s)")
ax.set_title("Trained on random trials (Control)")
ax.axvline(0, color="k")
ax.axhline(0, color="k")
cbar = plt.colorbar(im, ax=ax)
cbar.set_label("Decoding accuracy")