Variables

In [1]:
# --- Plotting ---------------------------------------------------------------
# Scaling value handed to raw.plot() in time_amplitude().
SCALINGS = 0.0004
# Forwarded as `show=` to the MNE plotting calls.
SHOW = False

# --- Signal parameters ------------------------------------------------------
# Sampling frequency; default target rate of resampling().
SFREQ = 500
# NOTE(review): not referenced in the visible cells — presumably a window
# length used further down; confirm its unit.
WINDOW_SIZE = 1

# --- Band-pass edges, used as defaults by band_pass_filter() ----------------
FMIN = 0.01
FMAX = 45

# Named frequency bands as (low edge, high edge) pairs.
FREQ_BANDS = dict(
    theta=(4, 8),
    alpha=(8, 13),
    beta=(13, 30),
)

# Overwritten by data_transformation_easy() with the EEG channel names.
ALL_CHANNELS = list()

Importing all packages

In [2]:
import numpy as np
import pandas as pd
from scipy import fftpack
from scipy.fft import fft
import time
import mne
import matplotlib
import matplotlib.pyplot as plt
import os
from scipy.stats import kurtosis, zscore, ttest_rel
from mne.preprocessing import create_ecg_epochs, create_eog_epochs, read_ica
from mne.time_frequency import tfr_morlet, tfr_array_morlet, morlet, AverageTFR
from itertools import product
import pywt
from scipy.signal import hilbert
from mne_connectivity import spectral_connectivity_epochs
from mne_connectivity.viz import plot_sensors_connectivity
import networkx as nx
import numpy as np 
import re
from scipy.cluster.hierarchy import dendrogram, linkage
import glob

matplotlib.use('TkAgg')

## Reading the EEG data

In [3]:
def data_transformation_easy(file_path, sfreq=500):
    """Load a tab-separated 32-channel EEG recording into an MNE ``RawArray``.

    The file must contain exactly 37 columns: 32 EEG channels followed by
    three accelerometer axes (``ax``, ``ay``, ``az``), a ``trigger`` column
    and a ``timestamp(ms)`` column. Only the 32 EEG columns are kept.

    Side effect: the module-level ``ALL_CHANNELS`` is overwritten with a
    NumPy array holding the 32 EEG channel names.

    Parameters
    ----------
    file_path : str or path-like
        Location of the tab-separated recording.
    sfreq : float, optional
        Sampling frequency of the recording. Defaults to 500, the value that
        used to be hard-coded in the body.

    Returns
    -------
    mne.io.RawArray
        EEG data laid out channels-first (32 rows), divided by 1e9 to convert
        the stored nV values to V.

    Raises
    ------
    ValueError
        Raised by pandas when the file does not have exactly 37 columns.
    """
    global ALL_CHANNELS

    # EEG channels in file column order (column 1 = P8 ... column 32 = P7).
    eeg_names = [
        'P8', 'T8', 'CP6', 'FC6', 'F8', 'F4', 'C4', 'P4',
        'AF4', 'Fp2', 'Fp1', 'AF3', 'Fz', 'FC2', 'Cz', 'CP2',
        'PO3', 'O1', 'Oz', 'O2', 'PO4', 'Pz', 'CP1', 'FC1',
        'P3', 'C3', 'F3', 'F7', 'FC5', 'CP5', 'T7', 'P7',
    ]
    aux_names = ['ax', 'ay', 'az', 'trigger', 'timestamp(ms)']

    # NOTE(review): read_csv treats the first line as a header by default and
    # that header is overwritten just below. If the recording has no header
    # line, its first sample is silently discarded — confirm against a real
    # file and pass header=None if so.
    df = pd.read_csv(file_path, sep='\t')

    ALL_CHANNELS = np.array(eeg_names)

    # Assigning all 37 names doubles as a column-count check.
    df.columns = eeg_names + aux_names

    # Keep only the EEG columns, channels-first. Samples are stored in nV:
    # e.g. 33129984 nV / 1e9 = 0.033129984 V.
    eeg_volts = df[eeg_names].to_numpy().T / 1e9

    # A single string for ch_types applies that type to every channel.
    info = mne.create_info(ch_names=eeg_names, ch_types='eeg', sfreq=sfreq)
    return mne.io.RawArray(eeg_volts, info)

Set montage

In [5]:
def set_montage(raw):
    """Attach standard 10-20 electrode positions to ``raw`` and return it.

    The complete ``standard_1020`` template is handed to ``raw.set_montage``.
    The template may list more electrodes than the recording contains; only
    the channels present in ``raw`` need a matching entry. This avoids
    relying on the module-level ``ALL_CHANNELS`` having been filled by an
    earlier cell, and avoids slicing the template's digitization list by
    hand.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording whose channel names follow the 10-20 naming scheme.
        It is modified in place.

    Returns
    -------
    mne.io.Raw
        The same ``raw`` object, now carrying electrode positions.

    Raises
    ------
    ValueError
        Raised by MNE (default ``on_missing='raise'``) when ``raw`` contains
        a channel that the template has no position for.
    """
    mont1020 = mne.channels.make_standard_montage('standard_1020')
    raw.set_montage(mont1020)
    return raw

Calculate absolute power

In [6]:
# def calculate_absolute_power(raw_data):
#     FMIN=0.01
#     FMAX=45
#     psds, freqs = mne.time_frequency.psd_array_welch(raw_data.get_data(), fmin=FMIN, fmax=FMAX, sfreq=SFREQ)
#     absolute_powers = {}
#     for band, (FMIN, FMAX) in FREQ_BANDS.items():
#         idx_band = np.logical_and(freqs >= FMIN, freqs <= FMAX)
#         absolute_power = np.trapz(psds[:, idx_band], dx=(freqs[1] - freqs[0]), axis=-1)
#         absolute_powers[band] = absolute_power
#     total_absolute_power = sum(absolute_powers.values())
#     return total_absolute_power


Calculate average components of a graph

In [7]:
def calculate_avergae_components(G):
    """Summarise graph metrics computed separately on each connected component.

    Every connected component of ``G`` is analysed as its own subgraph. The
    scalar metrics are then averaged over components with equal weight per
    component (not weighted by component size).

    Parameters
    ----------
    G : networkx.Graph
        Graph to analyse.

    Returns
    -------
    list
        ``[avg_path_length, betweenness, global_eff, clustering, local_eff]``
        where ``betweenness`` is a dict mapping each node to its betweenness
        centrality within its own component, ordered by node key, and the
        other four entries are the per-component means.

    Raises
    ------
    ZeroDivisionError
        When ``G`` yields no connected components, so there is nothing to
        average.
    """
    def _mean(values):
        return sum(values) / len(values)

    path_lengths = []
    global_effs = []
    clusterings = []
    local_effs = []
    betweenness = {}

    for nodes in list(nx.connected_components(G)):
        part = G.subgraph(nodes)
        path_lengths.append(nx.average_shortest_path_length(part))
        # Components are disjoint, so merging never overwrites a node.
        betweenness.update(nx.betweenness_centrality(part))
        global_effs.append(nx.global_efficiency(part))
        clusterings.append(nx.average_clustering(part))
        local_effs.append(nx.local_efficiency(part))

    betweenness_sorted = {node: betweenness[node] for node in sorted(betweenness)}

    return [
        _mean(path_lengths),
        betweenness_sorted,
        _mean(global_effs),
        _mean(clusterings),
        _mean(local_effs),
    ]

#### Time amplitude plot

In [8]:
def time_amplitude(raw, title):
    """Plot the recording over time, save the figure and print ``raw.info``.

    The figure is written to ``MNE-graphs/time-amplitude/<title>-EEG.png``.
    The target folder is created when it does not exist yet, so the function
    also works on a fresh checkout.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to plot.
    title : str
        Stem of the output file name.

    Returns
    -------
    None
    """
    fig = raw.plot(n_channels=32, scalings=SCALINGS, show=SHOW)

    out_path = f'MNE-graphs/time-amplitude/{title}-EEG.png'
    # savefig does not create missing folders; make sure the target exists.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path)

    print(raw.info)
    # TODO: Extract statistical features from time domain such as mean, median, variance, skewness, kurtosis

#### Power spectral density plot

In [9]:
def psd(raw, title):
    """Plot the power spectral density of every channel and save it as PNG.

    The figure is written to ``MNE-graphs/psd-frequency/<title>.png``. The
    target folder is created when it does not exist yet.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to analyse; all names in ``raw.info['ch_names']`` are
        plotted.
    title : str
        Stem of the output file name.

    Returns
    -------
    None
    """
    # NOTE(review): recent MNE releases document plot_psd as legacy in favour
    # of raw.compute_psd().plot() — confirm against the installed version.
    fig = raw.plot_psd(picks=raw.info['ch_names'], show=SHOW)

    out_path = f'MNE-graphs/psd-frequency/{title}.png'
    # savefig does not create missing folders; make sure the target exists.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path)

#### Wavelet plot

In [10]:
# def wavelet(raw1, raw2=None, title):
#     raw1_avg = np.mean(raw1.get_data(), axis=0)
#     t = np.arange(0, 150, 1/SFREQ)
#     ts = t[:-1]
#     wavelet = 'db13'
#     level = 5 # level of decomposition based on your signal characteristics

#     # coeffs_multi_channel = []
#     # for i in range(32): 
#     coeffs1 = pywt.wavedec(raw1_avg, wavelet, level=level)
#         # coeffs_multi_channel.append(coeffs)
    
#     plt.figure(figsize=(10, 6))

#     # raw1
#     plt.subplot(4, 1, 1)
#     plt.plot(ts, raw1_avg, label='xyz')
#     plt.title('Channel 1')
#     plt.xlabel('Time')
#     plt.ylabel('Amplitude')
#     plt.legend()

#     # raw1 coeffs
#     plt.subplot(4, 1, 2)
#     for i in range(level+1):
#         plt.plot(t, pywt.upcoef('a', coeffs1[i], wavelet, level=level)[:len(t)], label=f'Level {i}')
#     plt.title('DWT')
#     plt.xlabel('Time')
#     plt.ylabel('Amplitude')
#     plt.legend()

#     if raw2:
#         raw2_avg = np.mean(raw2.get_data(), axis=0)
#         coeffs2 = pywt.wavedec(raw2_avg, wavelet, level=level)

#         # raw2
#         plt.subplot(4, 1, 3)
#         plt.plot(t, raw2_avg, label='ddd')
#         plt.title('Channel 1')
#         plt.xlabel('Time')
#         plt.ylabel('Amplitude')
#         plt.legend()

#         # raw2 coeffs
#         plt.subplot(4, 1, 4)
#         for i in range(level+1):
#             plt.plot(t, pywt.upcoef('a', coeffs2[i], wavelet, level=level)[:len(t)], label=f'Level {i}')
#         plt.title('DWT')
#         plt.xlabel('Time')
#         plt.ylabel('Amplitude')
#         plt.legend()

#     plt.tight_layout()
#     plt.show()

Band pass filtering

In [11]:
def band_pass_filter(raw, l_freq=FMIN, h_freq=FMAX):
    """Band-pass filter ``raw`` in place and return it.

    A FIR filter with a Hann window and minimum phase is applied through
    ``raw.filter``.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to filter; it is modified in place.
    l_freq : float, optional
        Lower pass-band edge. Defaults to the value of ``FMIN`` at the time
        this function was defined.
    h_freq : float, optional
        Upper pass-band edge. Defaults to the value of ``FMAX`` at the time
        this function was defined.

    Returns
    -------
    mne.io.Raw
        The same ``raw`` object.
    """
    fir_design = dict(method='fir', phase='minimum', fir_window='hann')
    raw.filter(l_freq=l_freq, h_freq=h_freq, **fir_design)
    return raw

Rereferencing

Calculates the mean voltage from all electrodes at each time point and subtracts this mean from the voltage at each individual electrode.

In [12]:
def rereferencing(raw):
    """Re-reference ``raw`` to the channel average and return it.

    The average reference is first registered as a projector
    (``projection=True``) and then applied with ``apply_proj``.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to re-reference.

    Returns
    -------
    mne.io.Raw
        The ``raw`` object that was passed in.
    """
    with_projector = raw.set_eeg_reference('average', projection=True)
    with_projector.apply_proj()
    return raw

Artifact Rejection (EOG/ECG) using Wavelet decomposition

In [13]:
# def wavelet_decompose(raw):
#     info = raw.info
#     coeffs = pywt.wavedec(raw.get_data(), 'db13', level=4)
#     threshold = 0.00001 # applying a shrinkage function that smoothly brings coefficients below the threshold to zero
#     coeffs_thresholded = [pywt.threshold(c, threshold, mode='soft') for c in coeffs] 
#     denoised_signal = pywt.waverec(coeffs_thresholded, 'db13')
#     raw = mne.io.RawArray(denoised_signal, info)
#     return raw

Crop signal

In [14]:
def crop(raw, l, h):
    """Crop ``raw`` to the span between ``l`` and ``h``.

    Thin wrapper that forwards both bounds positionally to ``raw.crop``.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to crop.
    l : float
        Lower bound of the span to keep.
    h : float
        Upper bound of the span to keep.

    Returns
    -------
    object
        Whatever ``raw.crop`` returns.
    """
    cropped = raw.crop(l, h)
    return cropped

Resampling

In [15]:
def resampling(raw, SFREQ=SFREQ):
    """Resample ``raw`` to a new sampling frequency.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to resample.
    SFREQ : float, optional
        Target sampling frequency. Defaults to the value the module-level
        ``SFREQ`` had when this function was defined.

    Returns
    -------
    object
        Whatever ``raw.resample`` returns.
    """
    target_rate = SFREQ
    resampled = raw.resample(sfreq=target_rate)
    return resampled

Drop channels

Dropping extra channels in both groups
TODO: Instead of dropping we can convert 10-10 64 channels montage to 10-20 32 channels

In [16]:
def drop_channels(raw, raw_MDD):
    """Remove from ``raw`` every channel that ``raw_MDD`` does not have.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to trim; it is modified through ``raw.drop_channels``.
    raw_MDD : mne.io.Raw
        Reference recording whose channel names are kept.

    Returns
    -------
    mne.io.Raw
        The ``raw`` object that was passed in.
    """
    reference_names = raw_MDD.info['ch_names']
    # Build the complete list first, then drop everything in one call.
    extras = [name for name in raw.info['ch_names'] if name not in reference_names]
    raw.drop_channels(extras)
    return raw

Group 1 preprocessing

In [17]:
def g1_preprocess(raw, tmin=10, tmax=110):
    """Run the group-1 preprocessing chain on ``raw``.

    Steps, in order: attach the montage, crop to ``[tmin, tmax]``, band-pass
    filter between ``FMIN`` and ``FMAX``, re-reference to the average, and
    clear the list of bad channels.

    Parameters
    ----------
    raw : mne.io.Raw
        Group-1 recording.
    tmin : float, optional
        Start of the span kept by ``raw.crop``. Defaults to 10, the value
        that used to be hard-coded.
    tmax : float, optional
        End of the span kept by ``raw.crop``. Defaults to 110, the value
        that used to be hard-coded.

    Returns
    -------
    mne.io.Raw
        The preprocessed recording.
    """
    raw = set_montage(raw)
    raw = raw.crop(tmin, tmax)
    raw = band_pass_filter(raw, l_freq=FMIN, h_freq=FMAX)
    raw = rereferencing(raw)
    # Start downstream analysis with no channel marked as bad.
    raw.info['bads'] = []
    return raw

Group 2 preprocessing

In [18]:
def g2_preprocessing(raw, raw_1, crop=True, tmin=50, tmax=200):
    """Run the group-2 preprocessing chain on ``raw``.

    Steps, in order: drop channels that ``raw_1`` does not have, optionally
    attach the montage and crop to ``[tmin, tmax]``, resample, band-pass
    filter between ``FMIN`` and ``FMAX``, and clear the list of bad channels.

    Parameters
    ----------
    raw : mne.io.Raw
        Group-2 recording (per the original note, epochs are passed with
        ``crop=False``).
    raw_1 : mne.io.Raw
        Recording whose channel list is used as the set of channels to keep.
    crop : bool, optional
        When true, the montage is set and the data is cropped. Inside this
        function the name shadows the module-level ``crop`` helper, which is
        therefore not used here.
    tmin : float, optional
        Start of the span kept by ``raw.crop``. Defaults to 50, the value
        that used to be hard-coded.
    tmax : float, optional
        End of the span kept by ``raw.crop``. Defaults to 200, the value
        that used to be hard-coded.

    Returns
    -------
    mne.io.Raw
        The preprocessed recording.
    """
    # For epochs crop=False
    raw = drop_channels(raw, raw_1)
    if crop:
        # NOTE(review): the montage is only attached on this branch, and
        # unlike g1_preprocess no average re-reference is applied anywhere in
        # this chain — confirm both are intended.
        raw = set_montage(raw)
        raw = raw.crop(tmin, tmax)
    raw = resampling(raw)
    # Use the shared config constants (currently 0.01 and 45) so that both
    # groups are always filtered with the same band edges.
    raw = band_pass_filter(raw, l_freq=FMIN, h_freq=FMAX)
    # Start downstream analysis with no channel marked as bad.
    raw.info['bads'] = []
    return raw