In [None]:
%matplotlib qt
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import mne
from sklearn.pipeline import make_pipeline

from mne.decoding import (
    CSP,
    GeneralizingEstimator,
    LinearModel,
    Scaler,
    SlidingEstimator,
    Vectorizer,
    cross_val_multiscore,
    get_coef,
)
from sklearn.linear_model import LogisticRegression
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

In [None]:
## read data
# The preprocessed recording is looked up three directory levels above the
# current working directory, under subjects/<subject_id>/EEG/regularity/.
subject_id = "dvob"
project_root = Path.cwd().parent.parent.parent
fname = project_root.joinpath(
    "subjects", subject_id, "EEG", "regularity", "raw_prep.fif"
)
# preload=True loads the data into memory right away.
raw = mne.io.read_raw_fif(fname, preload=True)


## get events
# events_orig keeps the untouched event array; events_dict maps each
# annotation description to its integer event id.
events_orig, events_dict = mne.events_from_annotations(raw)

## remove "New Segment" and S140 (???) from events and update the events_dict
unwanted_descriptions = ("New Segment/", "Stimulus/S140")
# Pop every unwanted description that is actually present and remember its id.
unwanted_ids = [
    events_dict.pop(description)
    for description in unwanted_descriptions
    if description in events_dict
]
# Boolean-mask indexing returns a new array, so events_orig stays unchanged.
events = events_orig[~np.isin(events_orig[:, -1], unwanted_ids)]

# Trigger name -> trigger value (copied from trigger definition script).
trigger_dict = {
    "f1_std_or": 1,
    "f2_std_or": 2,
    "f3_std_or": 3,
    "f4_std_or": 4,
    "f1_std_rndm": 5,
    "f2_std_rndm": 6,
    "f3_std_rndm": 7,
    "f4_std_rndm": 8,
    "f1_tin_or": 11,
    "f2_tin_or": 12,
    "f3_tin_or": 13,
    "f4_tin_or": 14,
    "f1_tin_rndm": 15,
    "f2_tin_rndm": 16,
    "f3_tin_rndm": 17,
    "f4_tin_rndm": 18,
}

# Trigger name -> event id used in the events array. An annotation description
# is matched to a trigger when it ends with a space followed by the trigger
# value (e.g. a description ending in " 11" matches value 11).
events_dict_new = {
    trigger_name: annotation_id
    for description, annotation_id in events_dict.items()
    for trigger_name, trigger_value in trigger_dict.items()
    if description.endswith(f" {trigger_value}")
}

# A new block starts wherever the gap between two consecutive event onsets
# (column 0 of the events array) exceeds diff_thr.
diff_thr = 750
onset_gaps = np.diff(events[:, 0])
split_indices = np.flatnonzero(onset_gaps > diff_thr) + 1  # +1: index of the event after the gap
blocks = np.split(events, split_indices)

## some checks
n_blocks = len(blocks)
assert n_blocks == 12, f"Something fishy with blocking, got {n_blocks} blocks instead of 12"
for block_number, block_events in enumerate(blocks, start=1):
    n_triggers = len(block_events)
    assert n_triggers == 500, f"Number of triggers in block {block_number} is {n_triggers}, must be 500."

## concatenating similar blocks
# Zero-based positions (within `blocks`) of the blocks belonging to each condition.
std_ord_block_idxs = [0, 2, 4]
std_rnd_block_idxs = [1, 3, 5]
tin_ord_block_idxs = [6, 8, 10]
tin_rnd_block_idxs = [7, 9, 11]

# Placeholder event row; prepending it to each group (np.insert at row 0)
# is currently disabled, so this value is not used below.
first_event = [0, 0, 99999]

condition_block_idxs = {
    "std_ord": std_ord_block_idxs,
    "std_rnd": std_rnd_block_idxs,
    "tin_ord": tin_ord_block_idxs,
    "tin_rnd": tin_rnd_block_idxs,
}
# np.concatenate builds a fresh array per condition, so later edits to
# blocks_dict do not write back into `blocks`.
blocks_dict = {
    condition: np.concatenate([blocks[i] for i in idxs])
    for condition, idxs in condition_block_idxs.items()
}
    

## binary classification for 2 groups



## entropy level decoding (nothing removed)
# Replace the per-tone event codes with one label per condition, so the
# decoder separates the four conditions rather than the individual tones.
# NOTE: this overwrites the event-code column of blocks_dict in place.
entropy_labels = {"std_ord": 101, "std_rnd": 102, "tin_ord": 103, "tin_rnd": 104}
for title, label in entropy_labels.items():
    blocks_dict[title][:, -1] = label
events_decode_1 = np.concatenate(list(blocks_dict.values()))

# The groups are concatenated condition by condition (blocks 0,2,4 then
# 1,3,5, ...), which leaves the events out of temporal order. mne.Epochs
# warns about events that are not chronologically ordered, so sort the rows
# by onset (column 0). The order within each condition is unchanged.
chronological_order = np.argsort(events_decode_1[:, 0], kind="stable")
events_decode_1 = events_decode_1[chronological_order]


# Epoch window: 0.3 s before to 0.7 s after each event, baseline-corrected
# with the interval from the epoch start up to the event onset.
epoch_window = dict(tmin=-0.3, tmax=0.7, baseline=(None, 0))
epochs = mne.Epochs(raw, events_decode_1, preload=True, **epoch_window)

# Keep EEG channels only (modifies `epochs` in place).
epochs.pick(picks="eeg")

# X: epoched data; y: the event code of every epoch that was kept
# (third/last column of the events array).
X, y = epochs.get_data(), epochs.events[:, -1]

# Per-time-point classifier: standardize the features, then LDA
# (LDA handles the multi-class case directly).
scaling_step = StandardScaler()
lda_step = LinearDiscriminantAnalysis(solver="svd")
clf = make_pipeline(scaling_step, lda_step)

# Temporal generalization: train at each time point, test at every time point.
# Accuracy is used as the score for this multi-class problem.
time_gen = GeneralizingEstimator(clf, scoring="accuracy")
time_gen.fit(X, y)  # leaves time_gen fitted on all epochs

# 2-fold cross-validated scores; the first axis of scores_cv indexes the folds.
scores_cv = cross_val_multiscore(time_gen, X, y, cv=2, n_jobs=1)
scores = scores_cv.mean(axis=0)  # TODO: fix scores shape and save them

## extract spatial patterns

## 4 targets: std_or, std_rnd, tin_or, tin_rnd
## so here all epochs go into X, y (it does not matter which frequency they are; in the end y should contain 4 distinct labels)


## sound-to-sound decoding
## 4 targets, no matter which block... (y should again contain 4 distinct labels)


## get coef for source

## 4 targets, f1, f2, f3, f4 of standard -> clf_std trained on rnd ones (blocks 1, 3, 5)
## 4 targets, f1, f2, f3, f4 of tinnitus -> clf_tin trained on rnd ones (blocks 7, 9, 11)


In [None]:
# Balance the occurrences of one target event ("carrier") by the type of the
# event that immediately precedes it, within a single block.
ev_id = 8
# NOTE(review): this assumes the integer codes stored in `blocks` equal the
# trigger values; if mne.events_from_annotations assigned different ids, look
# the code up via events_dict_new instead -- TODO confirm.
block_events = blocks[1]

# Row positions of the carrier inside the block. A carrier at row 0 has no
# predecessor (index -1 would silently wrap around to the last row), so drop it.
indices_of_carrier = np.where(block_events[:, -1] == ev_id)[0]
indices_of_carrier = indices_of_carrier[indices_of_carrier > 0]

# Events directly preceding each carrier; row k of selected_events belongs to
# the carrier at indices_of_carrier[k].
selected_events = block_events[indices_of_carrier - 1]
unique, counts = np.unique(selected_events[:, 2], return_counts=True)
print(unique)
print(counts)
counts_dict = dict(zip(unique, counts))
min_count = min(counts_dict.values(), default=0)  # the minimum count to balance

# Randomly downsample each preceding-event type to the minimum count.
# A seeded generator keeps the selection reproducible across re-runs.
rng = np.random.default_rng(42)
balanced_indices = []
for trial_id in unique:
    # Positions (within selected_events / indices_of_carrier) whose preceding
    # event is of this type.
    trial_indices = np.where(selected_events[:, 2] == trial_id)[0]
    selected = rng.choice(trial_indices, min_count, replace=False)
    # Map back to row positions of the carrier events in the block.
    balanced_indices.extend(indices_of_carrier[selected])
balanced_indices.sort()  # keep the selected carriers in chronological order

len(blocks[1][balanced_indices])