## Flow

In [61]:
# RQA -- in between - raw quality assessment (amplitude * power frequency ratio * alpha band psd ratio)
# Extract file ---> collective data ---> Bad channel removal ----- Not needed [sampling rate correction (500Hz)] -----
# filtering (butterworth bandpass - tune coefficient)  ----- rereferencing ----- epoching ----- ERP ----- baseline removal ---- 

Variables

In [62]:
SCALINGS = 4e-4
SFREQ = 500
WINDOW_SIZE = 1

Importing all packages

In [63]:
import numpy as np
import pandas as pd
from scipy import fftpack
from scipy.fft import fft
import time
import mne
import matplotlib
import matplotlib.pyplot as plt
import os
from scipy.stats import kurtosis, zscore
from mne.preprocessing import create_ecg_epochs, create_eog_epochs, read_ica
from mne.time_frequency import tfr_morlet, tfr_array_morlet, morlet, AverageTFR
from itertools import product
import pywt

matplotlib.use('TkAgg')

## Reading the EEG data

In [64]:
folder_path = os.getcwd()+'\\Depression-Sample-dataset-AIIMS\\'
items = os.listdir(folder_path)
active_or_sham_list = [item for item in items if os.path.isdir(os.path.join(folder_path, item))]
for active_or_sham in active_or_sham_list:
    patient_folder_path = os.path.join(folder_path, active_or_sham)
    items = os.listdir(patient_folder_path)
    patients_list = [item for item in items if os.path.isdir(os.path.join(patient_folder_path, item))]
    # patients_list = ['Hemlata', 'PreetiSingh', 'VinodKumarSharma']
    for patient in patients_list:
        pre_post_int_folder_path = os.path.join(patient_folder_path, patient)
        items = os.listdir(pre_post_int_folder_path)
        pre_post_int_list = [item for item in items if os.path.isdir(os.path.join(pre_post_int_folder_path, item))]
        for var in pre_post_int_list:
            if var=='Pre':
                pre_path = os.path.join(pre_post_int_folder_path, var)
                ao_files_list = ['20230831110419_Hemlata_05.10.23_01_AO', '20230718202004_Preeti singh_22.08.23-01_AO', 
                                 '20230829195917_VinodKumarSharma_25.9.23_01_AO', '20230825020604_JitenderKumar_29.08.23_01_AO',
                                  '20230827074250_SeemaKumari_11.09.23_01_AO']
                file_path = pre_path + '\\' + ao_files_list[0] + '.easy'
                break # Remove for all pre, post and intervention for a patient
        break # Remove for all patients in active or sham
    break # Remove for both active and sham

# Manual file_path
# sham_or_active = 'Sham\\' 
# patient = 'SeemaKumari'
# pre_post_intervention = '\\pre\\'
# directory = folder_path + sham_or_active + patient + pre_post_intervention
# file_path = directory + '20230827074800_SeemaKumari_11.09.23_01_GNG' + '.easy'

# TODO: Remove ?
# OBSERVED: Active: pre all spike (AO/GNG/Eye close) ---- post except 17 and 25 all spike (AO/GNG/Eye close)
# OBSERVED: Sham: pre all spike in Eye close only ------ post reduced spikes in Eye close only

In [65]:
df = pd.read_csv(file_path, sep='\t')
channel_str='Channel 1:P8\
	Channel 2:T8\
	Channel 3:CP6\
	Channel 4:FC6\
	Channel 5:F8\
	Channel 6:F4\
	Channel 7:C4\
	Channel 8:P4\
	Channel 9:AF4\
	Channel 10:Fp2\
	Channel 11:Fp1\
	Channel 12:AF3\
	Channel 13:Fz\
	Channel 14:FC2\
	Channel 15:Cz\
	Channel 16:CP2\
	Channel 17:PO3\
	Channel 18:O1\
	Channel 19:Oz\
	Channel 20:O2\
	Channel 21:PO4\
	Channel 22:Pz\
	Channel 23:CP1\
	Channel 24:FC1\
	Channel 25:P3\
	Channel 26:C3\
	Channel 27:F3\
	Channel 28:F7\
	Channel 29:FC5\
	Channel 30:CP5\
	Channel 31:T7\
	Channel 32:P7'

string_channels = channel_str.replace('\t', ':').split(':')
channel_names = [string_channels[i] for i in range(len(string_channels)) if i % 2 != 0]
channel_names.append('ax')
channel_names.append('ay')
channel_names.append('az')
channel_names.append('trigger')
channel_names.append('timestamp(ms)')
all_channels = np.array(channel_names[:-5])
df.columns=channel_names
df.head()

Unnamed: 0,P8,T8,CP6,FC6,F8,F4,C4,P4,AF4,Fp2,...,F7,FC5,CP5,T7,P7,ax,ay,az,trigger,timestamp(ms)
0,25379810,29103155,28599472,15288889,40199265,28470199,29530134,28225105,12977421,10060115,...,27105488,28082879,26403810,22191074,25382844,0,0,0,0,1693460058992
1,25378626,29099195,28596399,15280034,40182550,28472527,29529559,28223555,12977802,10054182,...,27097417,28086114,26406471,22194416,25384645,0,0,0,0,1693460058994
2,25376580,29093544,28595577,15272144,40165525,28474566,29528571,28223016,12979022,10046370,...,27087846,28086630,26409398,22195833,25387344,0,0,0,0,1693460058996
3,25373540,29087155,28593108,15264796,40151809,28470077,29524346,28219665,12975505,10034833,...,27078459,28081278,26408405,22191979,25386476,0,0,0,0,1693460058998
4,25367807,29077171,28587184,15255541,40138599,28466434,29519236,28214265,12972645,10021083,...,27069769,28078585,26403375,22186997,25382070,0,0,0,0,1693460059000


Data analysis

In [66]:
transposed_data=df.T

# Create a MNE-Python info object and specifying sampling rate of data
ch_names = df.columns.tolist()[:-5]
ch_types = ['eeg' for i in range(32)]
info = mne.create_info(ch_names=ch_names,ch_types=ch_types, sfreq=500)

# Convert all EEG units to nV
raw = mne.io.RawArray(transposed_data.values[:-5,:]/1e9, info)

print(f"num of channels: {raw.info.get('nchan')}")
print(f'Shape of the data: {raw.get_data().shape}')

Creating RawArray with float64 data, n_channels=32, n_times=149999
    Range : 0 ... 149998 =      0.000 ...   299.996 secs
Ready.


num of channels: 32
Shape of the data: (32, 149999)


Setting custom Montage

In [67]:
mne.channels
mont1020 = mne.channels.make_standard_montage('standard_1020')
mont1005 = mne.channels.make_standard_montage('standard_1005')

ind = [i for (i, channel) in enumerate(mont1020.ch_names) if channel in all_channels]
mont1020_new = mont1020.copy()

mont1020_new.ch_names = [mont1020.ch_names[x] for x in ind]
kept_channel_info = [mont1020.dig[x+3] for x in ind]

# Keeping the first three rows as they are the fiducial points information
mont1020_new.dig = mont1020.dig[0:3]+kept_channel_info

raw.set_montage(mont1020_new)
mont1020_new.plot()

<Figure size 640x640 with 1 Axes>

#### Time amplitude plot

In [68]:
def time_amplitude(raw, title):
    fig = raw.plot(
        n_channels=32, 
        scalings=SCALINGS
        )
    fig.savefig(f'MNE-graphs/time-amplitude/{title}-EEG.png')

    print(raw.info)
    # TODO: Extract statistical features from time domain such as mean, median, variance, skewness, kurtosis

#### Power spectral density plot

In [69]:
def psd(raw, title):
    fig = raw.plot_psd(picks=raw.info['ch_names'])
    fig.savefig(f'MNE-graphs/psd-frequency/{title}.png')

#### Wavelet plot

In [70]:
# f = scale2frequency(wavelet, scale)/sampling_period

In [71]:
# def wavelet(raw, title):
#     signal = raw.get_data()[0]
#     t = np.linspace(0, 299, len(signal))
#     coefficients, frequencies = pywt.cwt(signal, scales=np.arange(1, 128), wavelet='cmor')  

#     plt.figure(figsize=(10, 6))
#     plt.imshow(np.abs(coefficients), aspect='auto', cmap='jet', extent=[0, 299, 1, 128])
#     plt.colorbar(label="Magnitude")
#     plt.ylabel("Scale")
#     plt.xlabel("Time")
#     plt.title("CWT")
#     plt.show()

### EEG Quality Assessment Tool

In [43]:
def calculate_eqi_metrics(eeg_data_window):
    average_amplitude_spectrum = []
    rms_amplitude = []
    max_gradient = []
    zcr = []
    kurt = []
    num_channels = len(eeg_data_window.info['ch_names'])

    for i in range(num_channels):
        amplitude=eeg_data_window.get_data()[i]

        # Metric 1: Average Single-Sided Amplitude Spectrum (0.01-45Hz range)
        average_amplitude_spectrum.append(np.mean(np.abs(np.fft.fft(amplitude))))

        # Metric 2: RMS Amplitude
        rms_amplitude.append(np.sqrt(np.mean(amplitude ** 2)))

        # Metric 3: Maximum Gradient
        max_gradient.append(np.max(np.abs(np.diff(amplitude))))

        # Metric 4: Zero-Crossing Rate (ZCR)
        zcr.append(np.mean(np.abs(np.diff(np.sign(amplitude))))) # TODO: Amplitude all in positive why ?

        # Metric 5: Kurtosis
        kurt.append(kurtosis(amplitude, axis=0))
    
    return average_amplitude_spectrum, rms_amplitude, max_gradient, zcr, kurt

def calculate_eqi(eeg_data):
    num_samples = len(eeg_data)
    num_windows = int(num_samples / (SFREQ * WINDOW_SIZE))
    num_channels = len(eeg_data.info['ch_names'])

    eqi_scores = [[0 for col in range(num_windows)] for row in range(num_channels)]
    average_eqi_score_per_channel = [0 for i in range(num_channels)]
    clean_data_percentage_per_channel = [0 for i in range(num_channels)]

    for i in range(num_windows):
        # In seconds
        start_idx = int(i * WINDOW_SIZE)
        end_idx = int((i + 1) * WINDOW_SIZE)
        eeg_data_window = eeg_data.copy().crop(tmin=start_idx, tmax=end_idx)

        eqi_metrics = calculate_eqi_metrics(eeg_data_window)
        # Z-score normalization
        for chan in range(num_channels):
            # 5 X 32
            eqi_metrics_channel = eqi_metrics[0][chan], eqi_metrics[1][chan], eqi_metrics[2][chan], eqi_metrics[3][chan], eqi_metrics[4][chan] 
            feature_vector_normalized = zscore(eqi_metrics_channel)
            eqi_score = np.sum(np.abs(feature_vector_normalized) > 1, axis=0)
            eqi_scores[chan][i]=eqi_score
    average_eqi_score_per_channel = np.mean(eqi_scores, axis=1)
    clean_data_percentage_per_channel = np.mean(np.array(eqi_scores) < 2, axis=1) * 100

    return (np.mean(average_eqi_score_per_channel), np.mean(clean_data_percentage_per_channel))

In [44]:
def find_score(eeg_data):
    average_eqi_score, clean_data_percentage = calculate_eqi(eeg_data)
    print(average_eqi_score, clean_data_percentage)
    return (average_eqi_score, clean_data_percentage)

#### Raw graph plot

In [45]:
title = "0--Raw graph with line filter"
time_amplitude(raw, title)
psd(raw, title)
# wavelet(raw, title)
print(raw.info)
find_score(raw)

<Info | 8 non-empty values
 bads: []
 ch_names: P8, T8, CP6, FC6, F8, F4, C4, P4, AF4, Fp2, Fp1, AF3, Fz, FC2, ...
 chs: 32 EEG
 custom_ref_applied: False
 dig: 35 items (3 Cardinal, 32 EEG)
 highpass: 0.0 Hz
 lowpass: 250.0 Hz
 meas_date: unspecified
 nchan: 32
 projs: []
 sfreq: 500.0 Hz
>
NOTE: plot_psd() is a legacy function. New code should use .compute_psd().plot().
Effective window size : 4.096 (s)
<Info | 8 non-empty values
 bads: []
 ch_names: P8, T8, CP6, FC6, F8, F4, C4, P4, AF4, Fp2, Fp1, AF3, Fz, FC2, ...
 chs: 32 EEG
 custom_ref_applied: False
 dig: 35 items (3 Cardinal, 32 EEG)
 highpass: 0.0 Hz
 lowpass: 250.0 Hz
 meas_date: unspecified
 nchan: 32
 projs: []
 sfreq: 500.0 Hz
>
1.0840301003344481 95.48494983277592


(1.0840301003344481, 95.48494983277592)

Channels marked as bad:
none


# Preprocessing

1. Bad channels removal

Removing Amplitude, flatlined, standard, kurtosis threshold channels

In [46]:
# # Thresholds - https://ieeexplore.ieee.org/abstract/document/6346834?casa_token=zFQWJXGAa80AAAAA:A84Ep-dTINstXMDRX12vpvBuDb2TLLvRR5jMK9wSuAUdTx4nZWvxH_bpzSqCKCigwsHwEmpqaw
# # Amplitude thresholds
# # TODO: Set all and observe all for different patients, can apply hyperparameter tuning
# amplitude_threshold = 1.5e-3
# flatline_threshold = 1e-4
# std_threshold = 1e-3
# kurtosis_threshold = 5.0

# bad_channels_amplitude = [raw.ch_names[i] for i in range(raw.info['nchan']) if (max(raw._data[i, :]) > amplitude_threshold and min(raw._data[i, :]) < -amplitude_threshold )]
# print(len(bad_channels_amplitude), bad_channels_amplitude, "111111")
# bad_channels_flatline = [raw.ch_names[i] for i in range(raw.info['nchan']) if (np.all(np.abs(raw._data[i, :]) < flatline_threshold))]
# print(len(bad_channels_flatline), bad_channels_flatline, "22222")
# bad_channels_std = [raw.ch_names[i] for i in range(raw.info['nchan']) if (np.std(raw._data[i, :]) > std_threshold )]
# print(len(bad_channels_std), bad_channels_std, "3333")
# bad_channels_kurtosis = [raw.ch_names[i] for i in range(raw.info['nchan']) if (kurtosis(raw._data[i, :]) > kurtosis_threshold)]
# print(len(bad_channels_kurtosis), bad_channels_kurtosis, "4444")

# all_bad_channels = list(set(bad_channels_amplitude + bad_channels_flatline + bad_channels_std + bad_channels_kurtosis))
#                     # + bad_channels_power_spectrum + bad_channels_frequency_range)
# raw_cleaned = raw.copy()
# raw_cleaned.info['bads'] = list(set(raw_cleaned.info['bads']+(all_bad_channels)))

1 ['Oz'] 111111
0 [] 22222
4 ['Cz', 'Oz', 'CP5', 'P7'] 3333
1 ['Oz'] 4444


Interpolation required to estimate missing or bad channel values.

In [60]:
title = '1--Interpolated bad channels graph'
# Dropping bad channels

# raw_deleted = raw_cleaned.copy() 
# raw_deleted.drop_channels(all_bad_channels)
# time_amplitude(raw_deleted, title)
# psd(raw_interpolated, title)
# find_score(raw_deleted)

#Interpolating bad channels

# raw_cleaned.interpolate_bads(reset_bads = True)
# time_amplitude(raw_cleaned, title)
# psd(raw_cleaned, title)
# find_score(raw_cleaned)

raw_cleaned = raw.copy()
# time_amplitude(raw_cleaned, title)
# psd(raw_cleaned, title)

2. Band pass filtering

Experimenting configurations to minimize EQI

In [48]:
# # applied band pass filter of 0.01-45 Hz for depression detection 
# l_freq = 0.01
# h_freq = 45
# fir_phase = ['zero', 'minimum']
# fir_window = ['hamming', 'hann']
# iir_phase = ['zero', 'zero-double', 'forward']
# fir_combinations = list(product(fir_phase, fir_window))

# method = 'fir'
# eqi = 1000
# clean = 0
# for index, combination in enumerate(fir_combinations, start=1):
#     raw_filter = raw_cleaned.copy()
#     raw_filter.filter(method= method,
#     phase= combination[0],
#     fir_window= combination[1],
#     l_freq= l_freq,
#     h_freq= h_freq)
#     eqi_1, clean_1 = find_score(raw_filter)
#     if eqi_1 < eqi or clean_1 > clean:
#         eqi, clean = eqi_1, clean_1
#         combination = combination

# method = 'iir'
# for phase in iir_phase:
#     raw_filter = raw_cleaned.copy()
#     raw_filter.filter(method= method,
#         phase= phase,
#         l_freq= l_freq,
#         h_freq= h_freq)
#     eqi_1, clean_1 = find_score(raw_filter)
#     if eqi_1 < eqi or clean_1 > clean:
#         eqi, clean = eqi_1, clean_1
#         combination = phase

# print("Band pass filter configurations: ", eqi, clean, combination)

Filtering data with optimal band pass configurations

In [49]:
# title = "2--Band pass filtered graph"
# if len(combination) == 2:
#     raw_cleaned.filter(method= 'fir',
#     phase= combination[0],
#     fir_window= combination[1],
#     l_freq= l_freq,
#     h_freq= h_freq)
# elif len(combination) == 1:
#     raw_cleaned.filter(method= 'iir',
#         phase= combination,
#         l_freq= l_freq,
#         h_freq= h_freq)

l_freq = 0.01
h_freq = 45
raw_cleaned.filter(method= 'fir',
    phase= 'minimum',
    fir_window= 'hann',
    l_freq= l_freq,
    h_freq= h_freq)

time_amplitude(raw_cleaned, title)
psd(raw_cleaned, title)
# print("Band pass filter configurations: ", eqi, clean, combination)

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.01 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, non-linear phase, causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hann window with 0.0546 passband ripple and 44 dB stopband attenuation
- Lower transition bandwidth: 0.01 Hz
- Upper transition bandwidth: 11.25 Hz
- Filter length: 155001 samples (310.002 s)



  raw_cleaned.filter(method= 'fir',
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.6s


<Info | 9 non-empty values
 bads: 4 items (CP5, P7, Cz, Oz)
 ch_names: P8, T8, CP6, FC6, F8, F4, C4, P4, AF4, Fp2, Fp1, AF3, Fz, FC2, ...
 chs: 32 EEG
 custom_ref_applied: False
 dig: 35 items (3 Cardinal, 32 EEG)
 highpass: 0.0 Hz
 lowpass: 45.0 Hz
 meas_date: unspecified
 nchan: 32
 projs: []
 sfreq: 500.0 Hz
>
NOTE: plot_psd() is a legacy function. New code should use .compute_psd().plot().
Effective window size : 4.096 (s)


Channels marked as bad:
['CP5', 'P7', 'Cz', 'Oz']


3. Rereferencing

Calculates the mean voltage from all electrodes at each time point and subtracts this mean from the voltage at each individual electrode.

In [50]:
# eqi = 1000
# clean = 0
# raw_referenced = raw_cleaned.copy()
# raw_referenced.set_eeg_reference('average', projection=True).apply_proj() 
# eqi_1, clean_1 = find_score(raw_referenced)
# if eqi_1 < eqi or clean_1 > clean:
#     eqi, clean = eqi_1, clean_1
#     reference = 'average'

# ref_channels = ['Cz']
# raw_referenced = raw_cleaned.copy()
# raw_referenced.set_eeg_reference(ref_channels=ref_channels)
# eqi_1, clean_1 = find_score(raw_referenced)
# if eqi_1 < eqi or clean_1 > clean:
#     eqi, clean = eqi_1, clean_1
#     reference = ref_channels
# print("Rereference configurations: ", eqi, clean, reference)

In [51]:
title = "3---Rereferenced"
# if reference == 'average':
raw_cleaned.set_eeg_reference('average', projection=True).apply_proj() 
# else:
    # raw_cleaned.set_eeg_reference(ref_channels=reference)

time_amplitude(raw_cleaned, title)
psd(raw_cleaned, title)
# print("Rereference configurations: ", eqi, clean)

EEG channel type selected for re-referencing
Adding average EEG reference projection.
1 projection items deactivated
Average reference projection was added, but has not been applied yet. Use the apply_proj method to apply it.
Created an SSP operator (subspace dimension = 1)
1 projection items activated
SSP projectors applied...
<Info | 10 non-empty values
 bads: 4 items (CP5, P7, Cz, Oz)
 ch_names: P8, T8, CP6, FC6, F8, F4, C4, P4, AF4, Fp2, Fp1, AF3, Fz, FC2, ...
 chs: 32 EEG
 custom_ref_applied: False
 dig: 35 items (3 Cardinal, 32 EEG)
 highpass: 0.0 Hz
 lowpass: 45.0 Hz
 meas_date: unspecified
 nchan: 32
 projs: Average EEG reference: on
 sfreq: 500.0 Hz
>
NOTE: plot_psd() is a legacy function. New code should use .compute_psd().plot().
Effective window size : 4.096 (s)


Channels marked as bad:
['CP5', 'P7', 'Cz', 'Oz', 'C4']


<!-- 4. ECG/EOG Correction using ICA on Raw data

ICA decomposes data into different components each representing a spatial pattern. Excluding a component means discarding that specific spatial pattern and its contribution to original data -->

In [52]:
# ica = mne.preprocessing.ICA(
#     n_components=20, random_state=0)
# ica.fit(raw_cleaned)
# ica.plot_components()

In [53]:
# # Automatically : Find bad ECG and EOG for all channels and plotting epochs data without Removal of ICA components
# bad_indices_eog = []
# channels_eog = []
# bad_indices_ecg = []
# channels_ecg = []

# for channel in raw_cleaned.info['ch_names']:
#     eog, scores_eog = ica.find_bads_eog(raw_cleaned, channel, threshold='auto')
#     if len(eog):
#         bad_indices_eog.append(eog)
#         channels_eog.append(channel)
    
#     ecg = ica.find_bads_ecg(raw_cleaned, channel, threshold='auto')
#     if len(ecg[0]):
#         bad_indices_ecg.append(ecg[0])
    
# print("EOG Bad indices", bad_indices_eog, channels_eog)
# print("ECG Bad indices", bad_indices_ecg, channels_ecg)

# excluded_components = []
# for i in bad_indices_eog:
#     for j in i:
#         excluded_components.append(j)
# for i in bad_indices_ecg:
#     for j in i:
#         excluded_components.append(j)
# excluded_components = list(set(excluded_components))
# print("Excluded components", excluded_components)

In [54]:
# # Manually : Observe bad ECG and EOG from ICA components
# # excluded_components = []
# print("Excluded components", excluded_components)
# print(bad_indices_ecg, bad_indices_eog)

In [55]:
# # Apply removed ICA components for both ECG and EOG for each channel and plot epoched data with removal of ICA componetns
# raw_corrected = ica.apply(raw_cleaned, exclude=excluded_components) # ICA algo identifies spatially independent components in EEG 
# raw_corrected.plot(n_channels=len(raw_cleaned.info['ch_names']), scalings=SCALINGS)

In [56]:
# # Calculate EQI score
# find_score(raw_corrected)

In [58]:
# Wavelet denoising: 
signal = raw_cleaned.get_data()
t = np.linspace(0, 299, len(signal))
# Perform a multi-level wavelet decomposition
coeffs = pywt.wavedec(signal, 'db1', level=4)

# Set a threshold to nullify smaller coefficients (assumed to be noise)
threshold = 0.5
coeffs_thresholded = [pywt.threshold(c, threshold, mode='soft') for c in coeffs]

# Reconstruct the signal from the thresholded coefficients
denoised_signal = pywt.waverec(coeffs_thresholded, 'db1')

# Plotting the noisy and denoised signals
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(t, raw_cleaned.get_data())
plt.title("Noisy Signal")
plt.subplot(1, 2, 2)
plt.plot(t, denoised_signal)
plt.title("Denoised Signal")
plt.tight_layout()
plt.show()

5. Epoching

5.1 Function which converts Epoch to raw data for Signal Quality Assessment purpose only

In [None]:
def convert_epoch_to_raw(epochs):
    def map_indices(i, j, k):
        # Assuming i ranges from 0 to 149, j ranges from 0 to 31, and k ranges from 0 to 400
        # Mapping to the reshaped_data indices
        reshaped_i = j
        reshaped_j = i * len(epochs.get_data()[0][0])  + k  # Assuming 401 is the size of the second dimension in the original data

        return reshaped_i, reshaped_j

    reshaped_data = np.random.rand(32, 60150)
    for i in range(len(epochs.get_data())):
        for j in range(len(epochs.get_data()[0])):
            for k in range(len(epochs.get_data()[0][0])):
                reshaped_i, reshaped_j = map_indices(i, j, k)

                reshaped_data[reshaped_i, reshaped_j] = epochs.get_data()[i][j][k]
    return reshaped_data

5.2 Adding trigger channel

In [None]:
trigger_col = df.values[:, -2]
trigger_times = []
event_id = {}
event_list = []

for i in range(len(trigger_col)):
    if trigger_col[i]!=0:
        trigger_times.append(i+2)
        event_id[str(trigger_col[i])] = trigger_col[i]
        event_list.append(trigger_col[i])

# Create an events array (trigger value, previous, sample number)
events = np.column_stack((trigger_times, np.zeros_like(trigger_times), event_list))
# fig = raw_cleaned.plot(n_channels=len(raw_cleaned.info['ch_names']), events=events, event_id=event_id, event_color={1:'r', 2:'g', 3:'b'}, scalings=SCALINGS)

5.3 Remove corrupted events

In [None]:
# Remove corrupted event
# From event_id
if '3' in event_id:
    del event_id['3']

# From events
indexes=[]
for i in range(len(events)):
    if events[i][2] == 3:
        indexes.append(i)
events = np.delete(events, indexes, 0)

fig = raw_cleaned.plot(n_channels=len(raw_cleaned.info['ch_names']), events=events, event_id=event_id, event_color={1:'r', 2:'g'}, scalings=SCALINGS)
fig.savefig(f'MNE-graphs/time-amplitude/4---Trigger channels added.png')

5.4 Defining and segmenting Epochs

In [None]:
# Applying baseline correction (mode: mean) - baseline=(tmin, 0)
# gonogo - 200 msec around stimulus
tmin = -0.1
tmax = 0.7
epochs = mne.Epochs(raw_cleaned, events=events, event_id=None, tmin=tmin, tmax=tmax, baseline=(tmin, 0), detrend=1, preload=True, picks=['eeg'])
# epochs.drop_bad() # TODO
fig = epochs.plot(n_channels=len(raw_cleaned.info['ch_names']), event_color={1:'r', 2:'g'}, events=events, scalings=SCALINGS)

# Calculate EQI score
# reshaped_data = convert_epoch_to_raw(epochs)
# epoch_to_raw_data = mne.io.RawArray(reshaped_data, info)
# find_score(epoch_to_raw_data)

Experiment baseline periods 

In [None]:
# Define the baseline period relative to the event onset
# baseline = (-0.05, 0.2) 
# epochs.apply_baseline(baseline=baseline)

# Calculate EQI score
# reshaped_data = convert_epoch_to_raw(epochs)
# epoch_to_raw_data = mne.io.RawArray(reshaped_data, info)
# find_score(epoch_to_raw_data)

6. Artifact Rejection (EOG/ECG) using ICA on Epoched data

Doing ICA detection after epoching because ICA decomposition needs strong signals which we will get in our epochs

In [None]:
ica = mne.preprocessing.ICA(
    n_components=20, random_state=0)
ica.fit(epochs)
ica.plot_components()

Excluding bads

In [None]:
# Find bad ECG and EOG for all channels and plotting epochs data without Removal of ICA cmponents
bad_indices_eog = []
channels_eog = []
bad_indices_ecg = []
channels_ecg = []

for channel in raw_cleaned.info['ch_names']:
    eog, scores_eog = ica.find_bads_eog(epochs, channel, threshold='auto')
    if len(eog):
        bad_indices_eog.append(eog)
        channels_eog.append(channel)
    
    ecg, scores_ecg = ica.find_bads_ecg(epochs, channel, threshold='auto')
    if len(ecg):
        bad_indices_ecg.append(ecg)
        channels_ecg.append(channel)
    
print("EOG Bad indices", bad_indices_eog, channels_eog)
print("ECG Bad indices", bad_indices_ecg, channels_ecg)

excluded_components = []
for i in bad_indices_eog:
    for j in i:
        excluded_components.append(j)
for i in bad_indices_ecg:
    for j in i:
        excluded_components.append(j)
excluded_components = list(set(excluded_components))
print("Excluded components", excluded_components)

In [None]:
# Apply removed ICA components for both ECG and EOG for each channel and plot epoched data with removal of ICA componetns
cleaned_epochs = ica.apply(epochs, exclude=excluded_components) # ICA algo identifies spatially independent components in EEG 
cleaned_epochs.plot(n_channels=len(raw_cleaned.info['ch_names']), event_color={1:'r', 2:'g'}, events=events, scalings=SCALINGS)

In [None]:
# Calculate EQI score
reshaped_data = convert_epoch_to_raw(cleaned_epochs)
epoch_to_raw_data = mne.io.RawArray(reshaped_data, info)
find_score(epoch_to_raw_data)

In [None]:
# Wavelet denoising: 
signal = raw_cleaned.get_data()
t = np.linspace(0, 299, len(signal))
# Perform a multi-level wavelet decomposition
coeffs = pywt.wavedec(signal, 'db1', level=4)

# Set a threshold to nullify smaller coefficients (assumed to be noise)
threshold = 0.5
coeffs_thresholded = [pywt.threshold(c, threshold, mode='soft') for c in coeffs]

# Reconstruct the signal from the thresholded coefficients
denoised_signal = pywt.waverec(coeffs_thresholded, 'db1')

# Plotting the noisy and denoised signals
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(t, raw_cleaned)
plt.title("Noisy Signal")
plt.subplot(1, 2, 2)
plt.plot(t, denoised_signal)
plt.title("Denoised Signal")
plt.tight_layout()
plt.show()

8. Averaging and estimating evoked responses

This reduces noise and enhances signal to noise ratio

In [None]:
# epochs = cleaned_epochs # Uncomment if applying EOG/ECG correction on epoched data
individual_nogo_epoch_plots = []
individual_go_epoch_plots = []
nogo = epochs['1'].average()
go = epochs['2'].average()
# discarded = epochs['3'].average()
evokeds = dict(go=go, nogo=nogo)

To plot overlay of individual epochs for each channel: Uncomment below

In [None]:
# %matplotlib inline
# nogo_data = nogo.get_data()
# go_data = go.get_data()

# # Convert to msec
# times = epochs.times *1000 # For each epoch -- Sampling frequency (500 samples/sec) * duration of each epoch (800 msec) 

# for channel_index, channel_name in enumerate(channel_names[:32]):

#     # 'F8', 'F4', 'C4', 'P4' - Frontal channels
#     # 'Cz', 'C3', 'CP1' - central channels
#     # 'Pz', 'F3' - Parietal channels
#     gonogo_channels = ['F8', 'F4', 'C4', 'P4', 'Cz', 'C3', 'CP1', 'Pz', 'F3']

#     if channel_name in gonogo_channels:

#         # Plot Nogo overlay of Individual epochs for each channel
#         plt.figure(figsize=(21, 10))
#         for ind_epoch, epoch in enumerate(epochs):
#             if next(iter(list(epochs[ind_epoch].event_id.values()))) == 1:
#                 plt.plot(times, epoch[channel_index], label=f'Epoch {ind_epoch}')
#         plt.xlabel('Time (ms)')
#         plt.ylabel('Amplitude (V)')
#         plt.title(f'NOGO - Overlay of Individual Epochs for channel {channel_name}')
#         plt.plot(times, nogo_data[channel_index], linewidth='5')
#         plt.show()
        
#         # Plot Go overlay of Individual epochs for each channel
#         plt.figure(figsize=(21, 10))
#         for ind_epoch, epoch in enumerate(epochs):
#             # print(type(epochs[ind_epoch]), type(epoch)) # <class 'mne.epochs.Epochs'> <class 'numpy.ndarray'> 32X401
#             if next(iter(list(epochs[ind_epoch].event_id.values()))) == 2:
#                 plt.plot(times, epoch[channel_index], label=f'Epoch {ind_epoch+1}')
#         plt.xlabel('Time (ms)')
#         plt.ylabel('Amplitude (V)')
#         plt.title(f'GO - Overlay of Individual Epochs for channel {channel_name}')
#         plt.plot(times, go_data[channel_index], linewidth='5')
       


In [None]:
# # Plot dropped epochs
# epochs.plot_drop_log()

To look for individual go and nogo plots for all 9 channels ( normal and 10X scaled)

In [None]:
# for channel_index, channel_name in enumerate(channel_names[:32]):

#     # 'F8', 'F4', 'C4', 'P4' - Frontal channels
#     # 'Cz', 'C3', 'CP1' - central channels
#     # 'Pz', 'F3' - Parietal channels
#     gonogo_channels = ['F8', 'F4', 'C4', 'P4', 'Cz', 'C3', 'CP1', 'Pz', 'F3']
#     lis = []
#     lis.append(channel_name)
#     if channel_name in gonogo_channels:
#         nogo.plot(picks=lis, titles='Nogo', ylim = dict(eeg=[-5e1, 4e1]), time_unit = 'ms')
#         nogo.plot(picks=lis, titles='Nogo 10X scaled', ylim = dict(eeg=[-4e1, 3e1]),  time_unit = 'ms')

#         go.plot(picks=lis, titles='Go', ylim = dict(eeg=[-4e1, 3e1]),  time_unit = 'ms')
#         go.plot(picks=lis, titles='Go 10X scaled', ylim = dict(eeg=[-4e1, 3e1]), time_unit = 'ms')

7. ERP Components analysis

In [None]:
print("Bads in Go", go.info['bads'])
print("Bads in nogo", nogo.info['bads'])
evokeds = dict(go=go, nogo=nogo)
fig = go.plot(time_unit='ms', titles='Go-average epoch plot for all channels')
fig = nogo.plot(time_unit='ms', titles='NoGo-average epoch plot for all channels')

Excluding bads

In [None]:
# go.info['bads'] = []
# nogo.info['bads'] = []

In [None]:
# evokeds = dict(go=go, nogo=nogo)
# fig = go.plot(time_unit='ms', titles='Go-average epoch plot for all channels')
# fig.savefig(f'MNE-graphs/go-nogo/go-plot-32channels.png')
# fig = nogo.plot(time_unit='ms', titles='NoGo-average epoch plot for all channels')
# fig.savefig(f'MNE-graphs/go-nogo/nogo-plot-32channels.png')

In [None]:
gonogo_channels = ['F8', 'F4', 'C4', 'P4', 'Cz', 'C3', 'CP1', 'Pz', 'F3']
erp_components = {'P300': [0, 250, 400], 'N2': [0, 180, 390] }

In [None]:
for key, value in erp_components.items():
    for channel_name in gonogo_channels:
        fig = mne.viz.plot_compare_evokeds(evokeds, picks=channel_name, vlines=value, time_unit='ms', title=f'{key} for channel {channel_name}')
        # fig[0].savefig(f'MNE-graphs/go-nogo/{channel_name}.png')
    break