In [1]:
import sys
from pathlib import Path
import pandas as pd
import tarfile
import urllib

def load_metadata():
  file_path = Path("data/train.csv")
  return pd.read_csv(file_path)
  
metadata = load_metadata()

def extract_eeg():
  eeg_dir = Path("../data/eeg")
  tarball_path = Path("data/eeg.tar.gz")
  if not tarball_path.is_file():
    url = 'https://dl.dropboxusercontent.com/scl/fi/5sina48c4naaxv6uze0fv/eeg.tar.gz?rlkey=r7ec191extynfcm8fy0tsiws5&dl=0'
    urllib.request.urlretrieve(url, tarball_path)
    with tarfile.open(tarball_path) as eeg_tarball:
      eeg_tarball.extractall()
    
extract_eeg()
metadata = metadata.drop_duplicates(subset='eeg_id')  # Dropping duplicate EEG IDs, 860 samples in total
metadata


Unnamed: 0,eeg_id,eeg_sub_id,eeg_label_offset_seconds,spectrogram_id,spectrogram_sub_id,spectrogram_label_offset_seconds,label_id,patient_id,expert_consensus,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote
0,4144388963,140,604.0,1156825996,140,604.0,1451266906,59489,GRDA,0,0,0,0,3,0
1,2353475448,30,64.0,1002394133,30,64.0,4000072340,5339,LRDA,0,0,0,3,0,0
2,1618328341,9,52.0,900482955,9,52.0,4140697659,20198,GRDA,0,0,0,0,3,0
3,979865826,7,90.0,1626043434,7,90.0,919550440,1069,Other,1,1,4,1,4,5
4,521108392,0,0.0,827447277,0,0.0,1717414556,13134,Other,0,0,0,0,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
993,2114773317,4,20.0,1279113670,4,20.0,3366675421,44615,GRDA,0,0,0,0,3,0
994,3470749836,3,18.0,1959154324,3,18.0,1120809501,37552,Other,0,0,0,0,1,15
997,1322226281,2,18.0,1740512896,2,18.0,1697286566,49448,Other,0,0,0,0,0,3
998,628369060,15,98.0,13143748,17,292.0,1650460145,34998,GPD,0,3,7,0,2,4


In [2]:
# Load into Dask
import dask.dataframe as dd
import numpy as np
import glob

channel_order = ['Fp1', 'Fp2',
            'F7', 'F3', 'Fz', 'F4', 'F8', 
            'T3', 'C3', 'Cz', 'C4', 'T4', 
            'T5', 'P3', 'Pz', 'P4', 'T6', 
            'O1', 'O2',
            ]
sfreq = 200
eeg_ids = metadata['eeg_id'].to_list()

ddf_list = []
for eeg_id in eeg_ids:
  f_name = f'data/eeg/{eeg_id}.parquet'
  temp_ddf = dd.read_parquet(f_name).drop('EKG', axis=1)[channel_order]
  temp_ddf['eeg_id'] = str(eeg_id)
  temp_ddf = temp_ddf.set_index('eeg_id')
  ddf_list.append(temp_ddf)

ddf = dd.concat(ddf_list)
ddf



Unnamed: 0_level_0,Fp1,Fp2,F7,F3,Fz,F4,F8,T3,C3,Cz,C4,T4,T5,P3,Pz,P4,T6,O1,O2
npartitions=860,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32
,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...


In [3]:
df = ddf.compute()
df

Unnamed: 0_level_0,Fp1,Fp2,F7,F3,Fz,F4,F8,T3,C3,Cz,C4,T4,T5,P3,Pz,P4,T6,O1,O2
eeg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
4144388963,4.340000,-30.620001,33.320000,14.510000,-27.790001,21.660000,-37.009998,11.040000,-7.910000,-19.620001,-7.380000,9.900000,5.670000,-9.180000,-0.63,9.350000,9.260000,-13.74,-7.970000
4144388963,1.900000,-28.379999,24.629999,2.340000,-28.870001,18.320000,-29.200001,21.520000,-12.530000,-18.240000,-3.400000,9.310000,7.950000,-7.210000,2.29,14.290000,12.010000,-11.14,-3.830000
4144388963,18.959999,-14.940000,51.020000,-7.860000,-28.110001,9.710000,-16.610001,37.540001,-17.780001,-19.290001,-14.670000,14.190000,9.460000,-8.130000,-0.56,8.720000,9.810000,-8.72,-6.950000
4144388963,23.230000,-13.990000,45.340000,11.850000,-25.930000,0.320000,1.800000,36.529999,-7.400000,-19.320000,-26.650000,2.440000,7.960000,-6.550000,-0.41,4.450000,5.170000,-6.82,-5.790000
4144388963,3.680000,-24.469999,23.020000,12.820000,-26.860001,21.740000,-0.360000,-8.000000,-4.640000,-18.670000,-16.360001,53.560001,6.290000,-7.360000,0.08,8.530000,10.930000,-10.94,-5.860000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
735834491,-39.990002,-11.380000,14.590000,-20.500000,-15.240000,-23.340000,-57.529999,-12.530000,7.420000,-20.930000,-16.010000,1.680000,-3.070000,-1.150000,2.43,-0.170000,-23.500000,-6.85,3.350000
735834491,-74.470001,-50.410000,-10.230000,-46.110001,-35.630001,-53.380001,-85.099998,-32.070000,-14.630000,-38.250000,-40.669998,-25.490000,-22.780001,-22.990000,-15.43,-26.559999,-49.869999,-26.16,-22.389999
735834491,-66.209999,-31.180000,-4.690000,-39.560001,-26.389999,-49.680000,-75.089996,-27.030001,-7.510000,-28.959999,-32.930000,-10.000000,-10.790000,-11.880000,-5.56,-20.100000,-42.110001,-11.15,-14.310000
735834491,-27.549999,6.770000,27.480000,-10.090000,-2.940000,-2.880000,-43.990002,1.900000,20.740000,-7.720000,-3.350000,20.650000,21.969999,22.360001,15.14,10.030000,-2.800000,14.06,11.530000


In [4]:
# Extracting top 3 channels based on variance for all samples
# 1000 samples computation duration = approx. 10 minutes
from src.feature_extraction import calculate_all_samples

top_channels_df = calculate_all_samples(df, eeg_ids, 860) # 10 samples for testing
top_channels_df # NEED TO EXTRACT FEATURES FROM CHANNELS IN THIS DATA STRUCTURE

Unnamed: 0,0,1,2
4144388963,Fp1,Fp2,F7
2353475448,Fp1,F3,F4
1618328341,Fp2,Fp1,F7
979865826,O2,C3,Fp2
521108392,O2,Fp1,Fp2
...,...,...,...
2114773317,T5,O2,Fp2
3470749836,F7,Fp1,T3
1322226281,Fp1,Fp2,T4
628369060,Fp2,Fz,F4


In [5]:
sig1 = ddf.loc['2161044411'].compute()
sig1

Unnamed: 0_level_0,Fp1,Fp2,F7,F3,Fz,F4,F8,T3,C3,Cz,C4,T4,T5,P3,Pz,P4,T6,O1,O2
eeg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
2161044411,-4.990000,1.080000,36.369999,38.619999,20.420000,12.780000,-7.110000,26.900000,16.740000,12.02,-10.150000,30.350000,19.709999,-45.099998,0.530000,-4.360000,-16.740000,-25.250000,-11.820000
2161044411,-8.370000,-8.270000,34.750000,26.770000,9.920000,-2.990000,-14.990000,-14.420000,1.290000,-0.29,-25.510000,22.200001,-0.790000,-56.259998,-10.000000,-13.280000,-23.160000,-32.880001,-18.370001
2161044411,-21.330000,-16.969999,11.280000,17.570000,-1.770000,-13.180000,-21.129999,13.460000,-11.670000,-10.41,-32.599998,15.000000,-11.220000,-64.160004,-17.889999,-20.670000,-28.570000,-39.419998,-25.190001
2161044411,-24.629999,-15.810000,14.430000,3.310000,-2.780000,-11.130000,-20.280001,46.669998,-34.200001,-10.58,-31.580000,16.480000,-16.330000,-67.500000,-18.520000,-19.030001,-26.760000,-38.450001,-24.690001
2161044411,-11.920000,-16.430000,53.830002,14.350000,1.120000,-12.720000,-25.570000,43.470001,-37.750000,-9.70,-34.330002,12.290000,-14.730000,-69.379997,-18.480000,-22.090000,-29.139999,-42.130001,-25.360001
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2161044411,4.140000,66.199997,49.080002,38.450001,19.770000,49.560001,14.240000,110.129997,21.420000,7.18,-19.959999,82.129997,118.800003,-7.350000,13.990000,-4.540000,-27.879999,-35.990002,-18.440001
2161044411,-2.330000,25.850000,31.440001,9.480000,8.530000,16.730000,8.510000,84.879997,11.390000,-2.26,-17.290001,79.260002,103.970001,-18.889999,8.190000,1.130000,-4.010000,-46.700001,-16.150000
2161044411,12.400000,49.529999,46.630001,16.400000,17.370001,42.180000,24.100000,86.029999,28.450001,9.74,-13.860000,93.599998,108.610001,-10.400000,17.160000,8.990000,3.950000,-41.369999,-5.970000
2161044411,-14.560000,24.459999,32.910000,5.550000,1.770000,-8.940000,5.880000,77.070000,16.389999,-3.58,-21.520000,84.860001,95.550003,-21.540001,6.980000,-1.140000,-7.150000,-52.160000,-14.420000


In [6]:
## MNE setup
import mne

mne_info = mne.create_info(ch_names=sig1.columns.tolist(), sfreq=200, ch_types='eeg')
mne_info.set_montage('standard_1020')
    
sig1_data = np.array(sig1.transpose())
sig1_data = np.nan_to_num(sig1_data)
    
raw = mne.io.RawArray(sig1_data, mne_info)
raw.apply_function(lambda x: x / 20e6, picks='eeg')

Creating RawArray with float64 data, n_channels=19, n_times=14800
    Range : 0 ... 14799 =      0.000 ...    73.995 secs
Ready.


0,1
Measurement date,Unknown
Experimenter,Unknown
Participant,Unknown

0,1
Digitized points,22 points
Good channels,19 EEG
Bad channels,
EOG channels,Not available
ECG channels,Not available

0,1
Sampling frequency,200.00 Hz
Highpass,0.00 Hz
Lowpass,100.00 Hz
Duration,00:01:14 (HH:MM:SS)


In [7]:
# Apply filters
from src.preprocessing import notch_filter, bp_filter, standardize

l_freq = 1.0
h_freq = 70.0

df = notch_filter(df, 60)
df = bp_filter(df, 1.0, 70)
df = standardize(df)
df

Creating RawArray with float64 data, n_channels=19, n_times=29720400
    Range : 0 ... 29720399 =      0.000 ... 148601.995 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 59 - 61 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 59.35
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 59.10 Hz)
- Upper passband edge: 60.65 Hz
- Upper transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 60.90 Hz)
- Filter length: 1321 samples (6.605 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    9.2s


Creating RawArray with float64 data, n_channels=19, n_times=29720400
    Range : 0 ... 29720399 =      0.000 ... 148601.995 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 70 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 70.00 Hz
- Upper transition bandwidth: 17.50 Hz (-6 dB cutoff frequency: 78.75 Hz)
- Filter length: 661 samples (3.305 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:   16.8s


Unnamed: 0_level_0,Fp1,Fp2,F7,F3,Fz,F4,F8,T3,C3,Cz,C4,T4,T5,P3,Pz,P4,T6,O1,O2
eeg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
4144388963,0.011283,0.896449,0.884797,-0.471276,-1.721524,-0.902853,0.007172,0.263442,1.648809,0.755831,-1.701600,-0.560082,0.427429,0.276038,1.843657,0.776621,-1.418749,-0.672857,-0.342590
4144388963,0.284405,1.963306,1.120008,-1.206780,-1.485544,-0.266160,0.365642,0.573411,1.098524,0.258417,-1.549086,-0.998262,0.130180,-0.064835,1.054635,0.806012,-1.547966,-0.825199,0.289290
4144388963,-0.063982,1.312294,0.481871,-1.693770,-0.905352,-0.047898,0.621270,1.882831,0.724230,-1.192859,-1.231428,-0.623811,0.102722,0.923889,1.294931,0.028620,-1.612263,-0.461945,0.460650
4144388963,0.291819,1.601841,0.914265,-1.652425,-1.019082,-0.195470,-0.038537,1.380901,0.651691,-1.274143,-0.972838,0.108583,0.922287,1.299612,0.494318,-0.898237,-1.545120,-0.671329,0.601865
4144388963,0.877254,0.687077,-0.156826,-1.409273,-0.634556,0.393029,0.347775,1.700889,1.026435,-1.763409,-0.690831,0.799441,0.475176,1.362960,0.077414,-2.019527,-0.835284,-0.158702,-0.079042
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
735834491,1.452691,-0.210558,-0.492525,1.014448,0.617263,0.309325,0.126559,-0.480378,1.228044,1.413593,-0.133942,-0.160932,-0.331164,0.101733,0.637515,0.202091,-0.825493,-2.419308,-2.048960
735834491,0.516668,-1.086366,-0.980298,0.100393,0.140696,0.641374,1.289093,0.017183,-1.416835,-0.685799,0.518389,1.573385,1.760938,-0.203149,-1.688528,-0.201256,0.111158,-1.456036,1.048988
735834491,2.826115,0.790033,0.444021,-0.741464,-0.915052,1.056363,-0.561535,-0.736579,1.206857,-0.420858,-1.176683,0.040314,0.361949,0.344895,0.167977,0.207614,-0.712273,-1.639740,-0.541954
735834491,-0.087751,-1.047592,-0.870010,0.011761,1.863578,2.515086,0.000094,-1.271951,-0.370206,-0.070828,0.043683,0.199790,1.064105,0.854764,-0.283924,0.413710,-0.754051,-1.558746,-0.651513


In [8]:
from src.feature_extraction import extract_features_all_samples
features_df = extract_features_all_samples(df, top_channels_df)
features_df

Unnamed: 0,std,mean,max,min,var,med,skew,kurt,ent,mom,pow
4144388963,1.063415,0.004105,3.894739,-3.876958,1.132815,0.004889,0.000267,2.531627,6.376908,3.287149,49767790277.470947
2353475448,1.034221,0.004227,3.6184,-3.584313,1.075266,0.011796,-0.020749,2.596045,5.501748,3.065472,1128791543.023513
1618328341,1.051975,0.009905,3.677406,-3.533425,1.108152,-0.005811,0.044025,2.503976,5.313307,3.105324,576119837.570129
979865826,1.039166,0.002955,3.769261,-3.642915,1.084759,0.002044,0.022057,2.698057,5.469183,3.216387,850472736.558914
521108392,1.102458,-0.001273,3.71895,-3.669903,1.216623,-0.001189,-0.002571,2.599759,4.94512,3.880334,121664648.885842
...,...,...,...,...,...,...,...,...,...,...,...
2114773317,1.039171,-0.003363,3.857879,-3.563809,1.082236,-0.004157,0.024376,2.710328,5.27491,3.192541,366436797.5669
3470749836,1.031516,0.002624,3.487895,-3.511269,1.068229,0.016101,-0.035007,2.643787,5.120299,3.057462,209385365.810265
1322226281,1.046505,0.007396,3.723876,-3.673957,1.100163,0.011183,0.037173,2.66164,5.271773,3.254092,388890929.978747
628369060,1.002498,-0.003406,3.470537,-3.45019,1.006382,-0.006814,0.0119,2.597972,5.511878,2.631447,1163395361.111273


In [9]:
from sklearn.model_selection import train_test_split
y = metadata[['seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']]
y

X_train, X_test, y_train, y_test = train_test_split(features_df, y, test_size=0.2, random_state=42)


In [10]:
from sklearn.tree import DecisionTreeClassifier
from sklearn.multioutput import MultiOutputClassifier

# Initialize the base classifier
base_clf = DecisionTreeClassifier()

# Create the Binary Relevance classifier
clf = MultiOutputClassifier(base_clf)

# Train the classifier
clf.fit(X_train, y_train)

# Make predictions
y_pred = clf.predict(X_test)


In [11]:
from sklearn.metrics import classification_report

# Convert y_test DataFrame to a NumPy array
y_test_np = y_test.to_numpy()

# Print the classification report for each label
for i in range(y_test_np.shape[1]):
    print(f"Classification Report for Label {i}:\n")
    print(classification_report(y_test_np[:, i], y_pred[:, i]))

Classification Report for Label 0:

              precision    recall  f1-score   support

           0       0.70      0.68      0.69       118
           1       0.00      0.00      0.00         6
           2       0.00      0.00      0.00         1
           3       0.32      0.25      0.28        44
           4       0.00      0.00      0.00         2
           5       0.00      0.00      0.00         1
          12       0.00      0.00      0.00         0

    accuracy                           0.53       172
   macro avg       0.15      0.13      0.14       172
weighted avg       0.56      0.53      0.54       172

Classification Report for Label 1:

              precision    recall  f1-score   support

           0       0.69      0.68      0.69       120
           1       0.08      0.05      0.06        21
           2       0.22      0.22      0.22         9
           3       0.12      0.20      0.15         5
           4       0.00      0.00      0.00         4
      

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize

In [12]:
# Relevant code for testing purposes

# print(np.var(df['Fp1'], axis=0)) #variance for one col/channel
# fpl = df['Pz'].fillna(0).to_numpy() #converting to numby array for easier computation
#print(np.var(df['Fp1'], axis=0)) #variance for one col/channel

# one sample and their channels
# sig1 = df.loc[['521108392']]
# sig1
# sig1['Fp1'] # one sample and single channel
# print(np.var(fpl, axis=0))

# variance for one channel(Fp1) in one signal(4144388963)
# np.var(sig1['F7'].to_numpy())

# File created to test the correctness of extracted values using MATLAB
# Save Fp1 channel data into a MATLAB file
# import scipy.io
# scipy.io.savemat('Fp1_data.mat', {'Fp1_data': sig1['Fp1']})