In [1]:
import pandas as pd
import numpy as np
import pathlib
import mne
from mne.time_frequency import tfr_morlet



In [2]:
def get_filtered_eeg(raw, low_freq=8, high_freq=30, n_components=15,
                     line_freq=50, eog_proxy="Fp1", n_jobs=20, random_state=97):
    """Notch + band-pass filter a recording and remove EOG-like ICA components.

    MNE's ``notch_filter``, ``filter`` and ``ICA.apply`` operate in place and
    return the same instance, so ``raw`` itself is modified and returned.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording with data loaded into memory (the notebook reads with
        ``preload=True``).
    low_freq, high_freq : float
        Band-pass edges in Hz passed to ``Raw.filter``.
    n_components : int
        Number of ICA components to fit.
    line_freq : float
        Power-line frequency removed by the notch filter (default 50 Hz).
    eog_proxy : str
        Channel used in place of a dedicated EOG channel when looking for
        ocular components (default ``"Fp1"``).
    n_jobs : int
        Number of parallel jobs for the band-pass filter.
    random_state : int
        ICA seed, so the decomposition is reproducible between runs.

    Returns
    -------
    mne.io.Raw
        The filtered, ICA-cleaned recording (same object as ``raw``).
    """
    raw_filtered = raw.notch_filter(freqs=line_freq)
    raw_filtered = raw_filtered.filter(l_freq=low_freq, h_freq=high_freq,
                                       fir_design='firwin', n_jobs=n_jobs)
    ica = mne.preprocessing.ICA(n_components=n_components, max_iter="auto",
                                random_state=random_state)
    ica.fit(raw_filtered)
    # No EOG channel is recorded, so a frontal channel stands in for it when
    # matching components against the ocular pattern.
    eog_indices, _ = ica.find_bads_eog(raw_filtered, ch_name=eog_proxy)
    ica.exclude = eog_indices
    ica.apply(raw_filtered)
    return raw_filtered

def get_trials(rawData, epoch_length, n_groups=4, fs=256, duration=599):
    """Cut every recording into fixed-length epochs and regroup them into trials.

    ``rawData`` is split into ``n_groups`` contiguous, equal-sized chunks: the
    first ``len(rawData) // n_groups`` recordings form group 0, the next chunk
    group 1, and so on.

    Parameters
    ----------
    rawData : list of mne.io.Raw
        Recordings ordered group by group; every group must contribute the
        same number of subjects.
    epoch_length : float
        Length of each epoch in seconds.
    n_groups : int
        Number of groups ``rawData`` is split into.
    fs : float
        Sampling rate (Hz) used to turn epoch onsets into sample indices; must
        match the recordings' sampling rate.
    duration : float
        Span in seconds over which epoch onsets are laid out.

    Returns
    -------
    list
        ``trials[k][i]`` is a list holding the ``i``-th epoch (a single-epoch
        ``mne.Epochs``) of every subject in group ``k``. The epochs of a trial
        are not concatenated; use ``mne.concatenate_epochs`` when a single
        ``Epochs`` object is needed.

    Raises
    ------
    ValueError
        If ``rawData`` is empty or its length is not a multiple of
        ``n_groups`` (the contiguous split would otherwise mis-assign subjects).
    """
    n_sub = len(rawData)
    if n_sub == 0 or n_sub % n_groups:
        raise ValueError(
            f"expected a non-empty list whose length is a multiple of "
            f"n_groups={n_groups}, got {n_sub} recordings")
    n_sub_per_group = n_sub // n_groups

    # Synthetic events: one onset every `epoch_length` seconds starting at 0.
    n_epochs = int(duration // epoch_length)
    events = np.array([[int(i * epoch_length * fs), 0, 1] for i in range(n_epochs)])
    event_id = 1
    # NOTE(review): MNE's tmax is inclusive (5 s at 256 Hz gives 1281 samples),
    # so consecutive epochs share one boundary sample.
    tmin, tmax = 0, epoch_length
    epochsList = [
        mne.Epochs(raw, events, event_id, tmin, tmax, baseline=None, preload=True)
        for raw in rawData
    ]
    n_epoch_per_sub = len(epochsList[0])

    return [
        [
            [epochsList[k * n_sub_per_group + j][i] for j in range(n_sub_per_group)]
            for i in range(n_epoch_per_sub)
        ]
        for k in range(n_groups)
    ]

def get_tfr(combinedEpochs, low_freq=8, high_freq=30):
    """Morlet-wavelet power and inter-trial coherence for every trial of every group.

    Parameters
    ----------
    combinedEpochs : list of list
        ``combinedEpochs[k][i]`` is trial ``i`` of group ``k``: either an
        ``mne.Epochs`` object, or a list of ``mne.Epochs`` (the shape returned
        by ``get_trials``), which is concatenated before the transform.
    low_freq, high_freq : int
        Analysed frequencies are ``low_freq, low_freq + 1, ..., high_freq - 1``
        Hz (``high_freq`` is excluded, as with ``np.arange``).

    Returns
    -------
    powerD, itcD : list of list
        ``powerD[k][i]`` / ``itcD[k][i]`` are the power and ITC objects
        returned by ``tfr_morlet`` for trial ``i`` of group ``k``.
    """
    frequencies = np.arange(low_freq, high_freq, 1)
    # n_cycles / f = 0.5 s at every frequency, i.e. a constant time window.
    n_cycles = frequencies / 2.
    powerD, itcD = [], []
    for group in combinedEpochs:
        group_power, group_itc = [], []
        for trial in group:
            if isinstance(trial, (list, tuple)):
                # get_trials() yields one single-epoch Epochs per subject;
                # tfr_morlet needs them merged into one Epochs object.
                trial = mne.concatenate_epochs(list(trial))
            power, itc = mne.time_frequency.tfr_morlet(
                trial, freqs=frequencies, n_cycles=n_cycles, use_fft=True,
                output='power', return_itc=True, n_jobs=-1)
            group_power.append(power)
            group_itc.append(itc)
        powerD.append(group_power)
        itcD.append(group_itc)
    return powerD, itcD

In [3]:
types = 'med2'
sub_per_type = 2
data_dir = pathlib.Path("./data/files")
# One list file per group, one BDF path per line. glob() returns files in a
# filesystem-dependent order, so sort to keep the file -> group index stable.
supersetfiles = sorted(data_dir.glob(f'*{types}*.txt'))
FILES = [[] for _ in range(4)]
if len(supersetfiles) > len(FILES):
    raise ValueError(
        f"expected at most {len(FILES)} list files, found {len(supersetfiles)}: {supersetfiles}")
# NOTE(review): listed paths (e.g. '../sub-081/...') are later opened relative
# to the current working directory, not to data_dir -- confirm that is intended.
for group_files, ssfile in zip(FILES, supersetfiles):
    with open(ssfile, 'r') as file:
        # Keep the first `sub_per_type` non-empty paths of each list file.
        for line in file:
            if len(group_files) >= sub_per_type:
                break
            path = line.strip()
            if path:
                group_files.append(path)

In [4]:
# Every group must list the same number of recordings: later cells flatten the
# groups and split them back into equal, contiguous chunks.
assert all(len(group) == sub_per_type for group in FILES), [len(group) for group in FILES]
FILES

[['../sub-081/eeg/sub-081_task-med2_eeg.bdf',
  '../sub-095/eeg/sub-095_task-med2_eeg.bdf'],
 ['../sub-032/eeg/sub-032_task-med2_eeg.bdf',
  '../sub-034/eeg/sub-034_task-med2_eeg.bdf'],
 ['../sub-078/eeg/sub-078_task-med2_eeg.bdf',
  '../sub-067/eeg/sub-067_task-med2_eeg.bdf'],
 ['../sub-013/eeg/sub-013_task-med2_eeg.bdf',
  '../sub-020/eeg/sub-020_task-med2_eeg.bdf']]

In [5]:
%%capture
rawData = [[] for _ in range(4)]
for k in range(4):
    for i in range(sub_per_type):
        try:
            raw = mne.io.read_raw_bdf(FILES[k][i],preload=True)
            total_time_sec = raw.times[-1] - raw.times[0]
            if total_time_sec >= 599:
                n = raw.info['ch_names']
                if(len(n)==73):
                    raw = raw.drop_channels(n[-9:])
                elif(len(n)==80):
                    raw = raw.drop_channels(n[-16:])
                raw = raw.crop(tmin=0, tmax=599)
                raw_ds = raw.resample(256, n_jobs='cuda')
                rawf=get_filtered_eeg(raw, low_freq=8, high_freq=30,n_components=15)
                rawData[k].append(rawf)
        except:
            print("hi")

In [6]:
# Flatten the per-group lists (group 0 first). get_trials() splits the flat list
# back into equal contiguous chunks, so every group must have kept the same
# number of subjects; otherwise subjects would land in the wrong group.
group_sizes = [len(group) for group in rawData]
assert len(set(group_sizes)) == 1, f"unequal number of subjects per group: {group_sizes}"
rawData = [raw for group in rawData for raw in group]
rawData

[<RawEDF | sub-081_task-med2_eeg.bdf, 64 x 153344 (599.0 s), ~74.9 MB, data loaded>,
 <RawEDF | sub-095_task-med2_eeg.bdf, 64 x 153344 (599.0 s), ~74.9 MB, data loaded>,
 <RawEDF | sub-032_task-med2_eeg.bdf, 64 x 153344 (599.0 s), ~74.9 MB, data loaded>,
 <RawEDF | sub-034_task-med2_eeg.bdf, 64 x 153344 (599.0 s), ~74.9 MB, data loaded>,
 <RawEDF | sub-078_task-med2_eeg.bdf, 64 x 153344 (599.0 s), ~74.9 MB, data loaded>,
 <RawEDF | sub-067_task-med2_eeg.bdf, 64 x 153344 (599.0 s), ~74.9 MB, data loaded>,
 <RawEDF | sub-013_task-med2_eeg.bdf, 64 x 153344 (599.0 s), ~74.9 MB, data loaded>,
 <RawEDF | sub-020_task-med2_eeg.bdf, 64 x 153344 (599.0 s), ~74.9 MB, data loaded>]

In [7]:
# Prototype the phase analysis on the first recording only.
first_raw = rawData[0]
sampleRaw = [first_raw]
sampleRaw

[<RawEDF | sub-081_task-med2_eeg.bdf, 64 x 153344 (599.0 s), ~74.9 MB, data loaded>]

In [22]:
EPOCH_LENGTH_S = 5
DURATION_S = 599
n_sub = len(sampleRaw)
# Read the rate from the data (256 Hz after the resampling step) instead of
# hard-coding it, so onsets stay correct if the resampling rate changes.
sfreq = sampleRaw[0].info['sfreq']
# Synthetic events: one onset every EPOCH_LENGTH_S seconds, starting at sample 0.
events = np.array([[int(i * EPOCH_LENGTH_S * sfreq), 0, 1]
                   for i in range(int(DURATION_S // EPOCH_LENGTH_S))])
event_id = 1
# tmax is inclusive in MNE: 5 s at 256 Hz gives 1281 samples per epoch.
tmin, tmax = 0, EPOCH_LENGTH_S
epochsList = [mne.Epochs(raw, events, event_id, tmin, tmax, baseline=None, preload=True)
              for raw in sampleRaw]
n_epoch_per_sub = len(epochsList[0])

Not setting metadata
119 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 119 events and 1281 original time points ...
0 bad epochs dropped


In [23]:
# Summary of the first subject's epochs.
sample_epochs = epochsList[0]
sample_epochs

0,1
Number of events,119
Events,1: 119
Time range,0.000 – 5.000 s
Baseline,off


In [24]:
# Wavelet frequencies 8, 9, ..., 29 Hz (np.arange excludes the stop value 30).
frequencies = np.arange(8, 30, 1)
# Cycles scale with frequency: n_cycles / f = 0.5 s, the same window at every frequency.
n_cycles = frequencies / 2.
fs = 256  # Hz; the sampling rate the recordings were resampled to

In [53]:
# Complex Morlet coefficients of the first epoch; complex_signal[0] drops the
# single-epoch axis, leaving (channels, freqs, times) -- (64, 22, 1281) here.
complex_signal = mne.time_frequency.tfr_array_morlet(
    epochsList[0][0].get_data(), sfreq=fs, freqs=frequencies, n_cycles=n_cycles,
    use_fft=True, output='complex', n_jobs=20)
coeffs = complex_signal[0]

# Pairwise phase difference phi_i - phi_j for every channel pair, in (-pi, pi].
# angle(z_i * conj(z_j)) is the phase difference; angle(z_i - z_j) is not, since
# it also depends on both amplitudes. Result: (64, 64, 22, 1281); the complex
# intermediate alone needs roughly 1.8 GB.
differences = np.angle(coeffs[:, np.newaxis] * np.conj(coeffs[np.newaxis, :]))

In [29]:
# Expected layout: (channel_i, channel_j, freq, time).
np.shape(differences)

(64, 64, 22, 1281)