In [1]:
import pandas as pd
import numpy as np

In [2]:
train_csv = pd.read_csv('data/train.csv')
train_csv

Unnamed: 0,eeg_id,eeg_sub_id,eeg_label_offset_seconds,spectrogram_id,spectrogram_sub_id,spectrogram_label_offset_seconds,label_id,patient_id,expert_consensus,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote
0,1628180742,0,0.0,353733,0,0.0,127492639,42516,Seizure,3,0,0,0,0,0
1,1628180742,1,6.0,353733,1,6.0,3887563113,42516,Seizure,3,0,0,0,0,0
2,1628180742,2,8.0,353733,2,8.0,1142670488,42516,Seizure,3,0,0,0,0,0
3,1628180742,3,18.0,353733,3,18.0,2718991173,42516,Seizure,3,0,0,0,0,0
4,1628180742,4,24.0,353733,4,24.0,3080632009,42516,Seizure,3,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
106795,351917269,6,12.0,2147388374,6,12.0,4195677307,10351,LRDA,0,0,0,3,0,0
106796,351917269,7,14.0,2147388374,7,14.0,290896675,10351,LRDA,0,0,0,3,0,0
106797,351917269,8,16.0,2147388374,8,16.0,461435451,10351,LRDA,0,0,0,3,0,0
106798,351917269,9,18.0,2147388374,9,18.0,3786213131,10351,LRDA,0,0,0,3,0,0


In [3]:
def read_parquet_cache(path):
    cache = {}

    def read_parquet(id_):
        if id_ not in cache:
            cache[id_] = pd.read_parquet(f'{path}{id_}.parquet')
        return cache[id_]

    return read_parquet

read_eeg = read_parquet_cache(path='data/train_eegs/')
read_eeg_test = read_parquet_cache(path='data/test_eegs/')
read_spg = read_parquet_cache(path='data/train_spectrograms/')
read_spg_test = read_parquet_cache(path='data/test_spectrograms/')

In [4]:
def eeg_window(row, train=True):
    eeg_data = read_eeg(row.eeg_id) if train else read_eeg_test(row.eeg_id)
    if train:
        eeg_offset = int(row.eeg_label_offset_seconds)
        eeg_data = eeg_data.iloc[(200 * eeg_offset):(200 * (eeg_offset + 50))]
    return eeg_data

eeg_window(train_csv.iloc[0])

Unnamed: 0,Fp1,F3,C3,P3,F7,T3,T5,O1,Fz,Cz,Pz,Fp2,F4,C4,P4,F8,T4,T6,O2,EKG
0,-80.519997,-70.540001,-80.110001,-108.750000,-120.330002,-88.620003,-101.750000,-104.489998,-99.129997,-90.389999,-97.040001,-77.989998,-88.830002,-112.120003,-108.110001,-95.949997,-98.360001,-121.730003,-106.449997,7.920000
1,-80.449997,-70.330002,-81.760002,-107.669998,-120.769997,-90.820000,-104.260002,-99.730003,-99.070000,-92.290001,-96.019997,-84.500000,-84.989998,-115.610001,-103.860001,-97.470001,-89.290001,-115.500000,-102.059998,29.219999
2,-80.209999,-75.870003,-82.050003,-106.010002,-117.500000,-87.489998,-99.589996,-96.820000,-119.680000,-99.360001,-91.110001,-99.440002,-104.589996,-127.529999,-113.349998,-95.870003,-96.019997,-123.879997,-105.790001,45.740002
3,-84.709999,-75.339996,-87.480003,-108.970001,-121.410004,-94.750000,-105.370003,-100.279999,-113.839996,-102.059998,-95.040001,-99.230003,-101.220001,-125.769997,-111.889999,-97.459999,-97.180000,-128.940002,-109.889999,83.870003
4,-90.570000,-80.790001,-93.000000,-113.870003,-129.960007,-102.860001,-118.599998,-101.099998,-107.660004,-102.339996,-98.510002,-95.300003,-88.930000,-115.639999,-99.800003,-97.500000,-88.730003,-114.849998,-100.250000,97.769997
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,-140.039993,-128.100006,-137.339996,-160.830002,-153.630005,-136.279999,-137.009995,-93.349998,-145.130005,-155.830002,-124.650002,-123.250000,-127.709999,-169.759995,-68.489998,-117.669998,-69.239998,-115.309998,-123.860001,65.010002
9996,-152.169998,-161.449997,-173.210007,-165.320007,-143.570007,-124.150002,-127.339996,-87.309998,-160.919998,-158.360001,-121.870003,-129.550003,-121.470001,-120.339996,-68.029999,-135.130005,-105.190002,-114.330002,-121.029999,47.090000
9997,-149.619995,-147.479996,-171.960007,-152.589996,-137.279999,-105.550003,-122.220001,-80.010002,-156.039993,-155.119995,-116.360001,-118.099998,-113.690002,-102.760002,-67.839996,-120.410004,-109.099998,-116.419998,-119.099998,95.589996
9998,-126.860001,-122.889999,-125.879997,-130.339996,-134.779999,-134.350006,-127.080002,-76.739998,-137.649994,-146.800003,-111.720001,-114.199997,-106.739998,-104.699997,-60.240002,-154.119995,-129.639999,-110.029999,-116.239998,72.980003


In [5]:
def spg_window(row, train=True):
    spg_data = read_spg(row.spectrogram_id) if train else read_spg_test(row.spectrogram_id)
    if train:
        spg_offset = int(row.spectrogram_label_offset_seconds)
        spg_data = spg_data.loc[(spg_data.time >= spg_offset) & (spg_data.time < spg_offset + 600)]
        spg_data = spg_data.drop(columns=['time'])
    return spg_data

spg_window(train_csv.iloc[0]).iloc[142:147]

Unnamed: 0,LL_0.59,LL_0.78,LL_0.98,LL_1.17,LL_1.37,LL_1.56,LL_1.76,LL_1.95,LL_2.15,LL_2.34,...,RP_18.16,RP_18.36,RP_18.55,RP_18.75,RP_18.95,RP_19.14,RP_19.34,RP_19.53,RP_19.73,RP_19.92
142,4.86,5.32,6.45,6.56,5.19,5.16,9.11,14.18,10.6,11.63,...,0.22,0.16,0.14,0.14,0.12,0.1,0.15,0.19,0.27,0.46
143,7.61,10.67,14.49,13.66,11.09,9.34,13.84,13.98,14.08,16.969999,...,0.13,0.19,0.2,0.18,0.21,0.18,0.15,0.15,0.29,0.44
144,5.66,9.55,16.73,16.299999,13.72,19.07,11.3,13.48,15.11,15.89,...,0.23,0.24,0.19,0.17,0.17,0.16,0.23,0.27,0.27,0.23
145,4.16,3.9,5.5,7.18,10.09,16.67,16.450001,19.700001,20.91,15.38,...,0.14,0.14,0.14,0.14,0.16,0.25,0.28,0.28,0.28,0.19
146,3.39,4.59,5.64,7.58,9.07,13.0,22.280001,30.65,29.84,28.66,...,0.21,0.17,0.15,0.08,0.11,0.11,0.11,0.12,0.08,0.08


In [6]:
def eeg_features(eeg_df, w=2):
    features = []
    # divide the 50s (sampled at 200Hz) window into 5 10s windows, index of the centre (3rd) window = 200*2 = 400 to 600 (excl.)
    for i in range(-w, w + 1):
        df = eeg_df.iloc[(400 + 200 * i):(600 + 200 * i)].mean(axis=0)
        df.index = [f'{label}_mean_{i}' for label in df.index]
        features.append(df) 

        df = eeg_df.iloc[(400 + 200 * i):(600 + 200 * i)].std(axis=0)
        df.index = [f'{label}_std_{i}' for label in df.index]
        features.append(df) 
    return pd.concat(features, axis=0)

# eeg_window(train_csv.iloc[0]).std(axis=0).mean()

# eeg_features(eeg_window(train_csv.iloc[0]))

train_csv.iloc[:10].apply(lambda row: eeg_features(eeg_window(row)), axis=1)

Unnamed: 0,Fp1_mean_-2,F3_mean_-2,C3_mean_-2,P3_mean_-2,F7_mean_-2,T3_mean_-2,T5_mean_-2,O1_mean_-2,Fz_mean_-2,Cz_mean_-2,...,Pz_std_2,Fp2_std_2,F4_std_2,C4_std_2,P4_std_2,F8_std_2,T4_std_2,T6_std_2,O2_std_2,EKG_std_2
0,-121.486092,-117.349953,-111.771599,-123.695099,-153.6073,-109.883148,-116.213058,-109.470352,-125.683395,-102.833351,...,14.946475,14.121632,14.244462,15.730001,19.757803,13.546809,13.426604,19.294004,12.898739,220.580551
1,-121.052208,-123.975342,-118.901901,-141.278351,-140.983307,-115.817947,-122.768456,-110.977547,-137.546295,-114.380608,...,20.191065,20.420134,18.895329,19.520639,19.432758,23.48761,19.623608,18.136122,15.016601,221.974289
2,-90.24881,-110.285797,-111.920097,-138.237564,-92.077248,-103.727791,-115.890549,-106.882904,-136.722946,-120.243301,...,18.915499,18.179863,18.19734,20.241385,23.951912,19.742643,18.851969,20.649118,12.514921,221.843094
3,-133.292603,-138.523941,-132.449707,-153.768051,-146.673645,-121.04055,-127.577751,-120.916458,-155.112396,-131.483704,...,12.233391,12.732665,14.876089,15.382173,18.535763,10.742869,13.385337,16.191242,9.787531,226.133621
4,-121.893593,-124.286751,-116.233849,-135.954895,-143.630157,-107.437103,-116.323906,-115.659554,-143.832886,-119.959503,...,14.71642,21.02224,21.440746,22.215027,25.95714,22.83021,22.980839,27.40246,13.296265,231.386322
5,-122.946594,-130.387054,-127.61734,-146.714401,-144.733658,-119.533798,-122.998802,-121.8433,-150.449905,-125.496056,...,16.63859,20.982431,18.965645,18.717407,21.742239,20.127821,19.233074,22.791016,10.40808,233.179504
6,-108.862503,-119.141891,-114.889153,-126.948242,-119.082497,-102.839851,-106.427536,-126.999001,-244.85231,-121.613655,...,21.522009,18.321779,20.729116,30.628258,31.503038,18.51379,24.307104,29.182018,16.50773,238.243484
7,-103.736847,-101.755753,-115.621056,-111.653801,-138.533463,-95.976448,-96.453903,-190.291092,-334.381042,-186.611099,...,23.005024,21.489351,26.47086,33.241924,31.765247,24.075081,32.068352,28.493212,16.528858,230.89566
8,-114.489456,-116.515495,-126.644646,-122.278557,-146.186951,-107.421799,-108.305954,-197.434494,-322.792969,-215.042892,...,21.042933,17.751595,21.859701,37.939106,26.767536,22.105419,36.574669,26.149015,15.613263,224.594116
9,18.540298,4.57415,18.653101,-0.54245,32.446899,17.05345,22.895401,12.206549,-10.976049,-16.659849,...,6.180764,8.319602,9.881698,8.705383,10.76403,5.042293,4.117795,4.201314,5.20881,19.947599


In [7]:
def spg_features(spg_df, w=2):
    features = []
    # divide the 600s window into 10s windows from the centre, and 2 5s windows on either side, the centre one (295s to 305s) has index (295-1)/2 = 147 to 152 (excl.)
    for i in range(-w, w + 1):
        df = spg_df.iloc[(147 + 5 * i):(152 + 5 * i)].mean(axis=0)
        df.index = [f'{label}_mean_{i}' for label in df.index]
        features.append(df) 

        df = spg_df.iloc[(147 + 5 * i):(152 + 5 * i)].std(axis=0)
        df.index = [f'{label}_std_{i}' for label in df.index]
        features.append(df) 
    return pd.concat(features, axis=0)

# spg_window(train_csv.iloc[0]).std(axis=0).mean()

# spg_features(spg_window(train_csv.iloc[0]))

train_csv.iloc[256:259].apply(lambda row: spg_features(spg_window(row)), axis=1)

Unnamed: 0,LL_0.59_mean_-2,LL_0.78_mean_-2,LL_0.98_mean_-2,LL_1.17_mean_-2,LL_1.37_mean_-2,LL_1.56_mean_-2,LL_1.76_mean_-2,LL_1.95_mean_-2,LL_2.15_mean_-2,LL_2.34_mean_-2,...,RP_18.16_std_2,RP_18.36_std_2,RP_18.55_std_2,RP_18.75_std_2,RP_18.95_std_2,RP_19.14_std_2,RP_19.34_std_2,RP_19.53_std_2,RP_19.73_std_2,RP_19.92_std_2
256,9.53,13.104001,18.334,22.152,25.763998,32.104,33.365997,35.935997,42.995998,40.5,...,,,,,,,,,,
257,20.136,25.09,30.256001,34.214001,33.116001,37.596001,42.056,56.767998,63.057995,65.961998,...,,,,,,,,,,
258,28.48,32.402,34.860004,41.416,35.681999,41.344002,52.950001,71.830002,77.528,78.995995,...,,,,,,,,,,


In [8]:
def actual_median(s):
    return s.iloc[(s - s.median()).abs().argsort().iloc[0]]

actual_median(pd.Series([1])), actual_median(pd.Series([1, 2])), actual_median(pd.Series([1, 2, 3])), actual_median(pd.Series([1, 2, 3, 4]))

(1, 1, 2, 2)

In [9]:
# remove duplicate eeg_ids (keeping the median one only)

df = train_csv.copy()

df = df.groupby('eeg_id')[['eeg_label_offset_seconds']].agg(actual_median)

df = pd.merge(df, train_csv, on=['eeg_id', 'eeg_label_offset_seconds'], how='left')

df

Unnamed: 0,eeg_id,eeg_label_offset_seconds,eeg_sub_id,spectrogram_id,spectrogram_sub_id,spectrogram_label_offset_seconds,label_id,patient_id,expert_consensus,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote
0,568657,6.0,1,789577333,1,6.0,3640441665,20654,Other,0,0,3,0,2,7
1,582999,18.0,5,1552638400,5,18.0,1179854295,20230,LPD,0,12,0,1,0,1
2,642382,0.0,0,14960202,12,1008.0,3254468733,5955,Other,0,0,0,0,0,1
3,751790,0.0,0,618728447,4,908.0,2898467035,38549,GPD,0,0,1,0,0,0
4,778705,0.0,0,52296320,0,0.0,3255875127,40955,Other,0,0,0,0,0,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
17084,4293354003,0.0,0,1188113564,0,0.0,447244163,16610,GRDA,0,0,0,0,1,1
17085,4293843368,0.0,0,1549502620,0,0.0,1618953053,15065,GRDA,0,0,0,0,1,1
17086,4294455489,0.0,0,2105480289,0,0.0,469526364,56,Other,0,0,0,0,0,1
17087,4294858825,6.0,2,657299228,2,6.0,3251917981,4312,Other,0,0,0,0,1,14


In [10]:
features_eeg = df.apply(lambda row: eeg_features(eeg_window(row)), axis=1)
features_eeg

Unnamed: 0,Fp1_mean_-2,F3_mean_-2,C3_mean_-2,P3_mean_-2,F7_mean_-2,T3_mean_-2,T5_mean_-2,O1_mean_-2,Fz_mean_-2,Cz_mean_-2,...,Pz_std_2,Fp2_std_2,F4_std_2,C4_std_2,P4_std_2,F8_std_2,T4_std_2,T6_std_2,O2_std_2,EKG_std_2
0,-79.983055,-81.323952,-18.681099,-47.863651,-187.592651,-29.112251,-78.728798,2.192550,39.954250,-62.504601,...,45.104725,99.185669,73.851074,59.138035,62.510010,95.654678,92.684517,68.819267,115.077171,10176.622070
1,1.592150,-10.108700,3.661750,35.589249,0.186300,-8.618750,-18.085199,-11.783500,-24.511650,-5.612250,...,13.066787,24.789822,21.046587,14.673660,13.633693,22.423006,28.336014,19.079601,12.253638,1.978712
2,-8.631650,24.218800,-21.270103,-13.353702,4.786050,-28.232948,-22.083349,-30.318253,2.992500,-2.851850,...,13.886168,28.948927,18.576426,154.855148,13.299860,21.070679,14.141970,11.991049,17.199745,39.931129
3,-27.117598,-7.614099,-9.483701,7.277600,-5.889450,-6.114050,8.603250,18.276499,-21.721901,-26.365049,...,377.822021,406.619690,392.210266,391.809448,375.332611,398.178040,387.610443,390.054474,386.756653,148.025818
4,0.638400,-11.500050,-42.373798,-43.817551,1.073700,-6.478049,-24.141251,-21.696600,-68.657951,-29.719650,...,79.100151,128.575195,13.842060,15.516143,448.307007,32.760445,22.795494,18.996525,18.658754,28.780083
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
17084,-33.408051,-32.803398,-40.341198,-46.259800,-30.711050,-25.807602,-36.809303,-33.608646,-25.831551,-9.568650,...,18.370039,24.297533,18.158159,17.120998,9.450650,16.399866,23.705229,14.046872,9.358005,4.358585
17085,-27.330853,-37.369099,-19.988249,-25.696051,17.756750,-24.991299,-51.183102,-5.006700,3.874850,-21.615349,...,8.810526,25.546494,21.822578,15.407010,8.811084,36.052746,29.845596,19.349653,16.687122,14.719275
17086,9999.000977,9999.000977,9999.000977,9999.000977,9999.000977,9999.000977,9999.000977,9999.000977,9999.000977,9999.000977,...,0.450512,0.450512,0.450512,0.450512,0.450512,0.450512,0.450512,0.450512,0.450512,0.000000
17087,-1912.692505,-1467.127808,-1404.491577,-1338.480347,-2198.383545,-1344.772705,-1744.977295,-1378.312866,-1518.126099,-2016.969482,...,9.318738,20.667873,43.191296,30.294207,29.289835,21.108393,25.434523,30.786783,9.881910,5.757161


In [11]:
features_spg = df.apply(lambda row: spg_features(spg_window(row)), axis=1)
features_spg

Unnamed: 0,LL_0.59_mean_-2,LL_0.78_mean_-2,LL_0.98_mean_-2,LL_1.17_mean_-2,LL_1.37_mean_-2,LL_1.56_mean_-2,LL_1.76_mean_-2,LL_1.95_mean_-2,LL_2.15_mean_-2,LL_2.34_mean_-2,...,RP_18.16_std_2,RP_18.36_std_2,RP_18.55_std_2,RP_18.75_std_2,RP_18.95_std_2,RP_19.14_std_2,RP_19.34_std_2,RP_19.53_std_2,RP_19.73_std_2,RP_19.92_std_2
0,148.945999,184.276001,207.082001,226.774002,257.697998,238.506012,248.500000,199.630005,105.377991,83.867996,...,5.916032,3.873358,2.106198,1.200825,1.087658,1.692164,2.942052,4.868412,7.457153,9.085946
1,36.301998,45.697998,41.302002,27.203999,19.720001,11.401999,7.579999,6.974000,7.126000,4.717999,...,0.008367,0.005477,0.008367,0.016733,0.016733,0.019494,0.014832,0.011402,0.018708,0.018166
2,8.346001,8.978000,9.190000,6.072000,4.040000,3.398000,2.332000,2.278000,1.384000,1.368000,...,0.047434,0.046368,0.038471,0.031145,0.013038,0.038079,0.049800,0.048270,0.040249,0.013038
3,29.602001,37.526001,47.644001,69.781998,60.209999,76.968002,76.788002,88.792007,75.302002,73.494003,...,0.041473,0.026833,0.019494,0.017889,0.024900,0.028636,0.033466,0.039115,0.061074,0.021679
4,31.275997,31.728001,26.624002,25.548000,21.024000,21.142002,20.284000,12.744000,10.117999,8.334000,...,0.085732,0.130115,0.122066,0.087579,0.070922,0.112606,0.161493,0.131833,0.122147,0.105688
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
17084,42.470001,47.935997,42.306000,28.694000,15.622000,11.334000,8.382000,7.934000,6.390000,4.768000,...,0.004472,0.008367,0.008367,0.007071,0.004472,0.005477,0.005477,0.005477,0.020000,0.020736
17085,17.914001,23.774000,27.480000,31.572001,34.119999,31.160000,29.743998,23.577999,10.226000,6.192000,...,0.108074,0.101833,0.082280,0.131719,0.129808,0.118870,0.064420,0.028636,0.171406,0.223652
17086,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,,,,,,,,,,
17087,1172.994019,392.623993,128.654007,42.380005,15.893999,7.095999,5.480000,7.466001,6.636000,5.038000,...,0.008944,0.010000,0.008367,0.013038,0.012247,0.008367,0.008367,0.010954,0.004472,0.005477


In [12]:
from sklearn.preprocessing import StandardScaler

data_processed = df.copy()

col_features = list(features_eeg.columns) + list(features_spg.columns)
col_targets = list(df.columns[-6:])

y = data_processed[col_targets]
y = y.div(y.sum(axis=1), axis=0)

data_processed[col_targets] = y

data_processed = pd.concat([data_processed, features_eeg, features_spg], axis=1)
data_processed = data_processed.dropna()
data_processed = data_processed.reset_index()

std_scaler = StandardScaler()

data_processed[col_features] = std_scaler.fit_transform(data_processed[col_features])

data_processed

Unnamed: 0,index,eeg_id,eeg_label_offset_seconds,eeg_sub_id,spectrogram_id,spectrogram_sub_id,spectrogram_label_offset_seconds,label_id,patient_id,expert_consensus,...,RP_18.16_std_2,RP_18.36_std_2,RP_18.55_std_2,RP_18.75_std_2,RP_18.95_std_2,RP_19.14_std_2,RP_19.34_std_2,RP_19.53_std_2,RP_19.73_std_2,RP_19.92_std_2
0,0,568657,6.0,1,789577333,1,6.0,3640441665,20654,Other,...,-0.046766,-0.048632,-0.048490,-0.048534,-0.049626,-0.052233,-0.052701,-0.039578,-0.029061,-0.028728
1,1,582999,18.0,5,1552638400,5,18.0,1179854295,20230,LPD,...,-0.047230,-0.048961,-0.048664,-0.048629,-0.049710,-0.052360,-0.052897,-0.039719,-0.029137,-0.028822
2,2,642382,0.0,0,14960202,12,1008.0,3254468733,5955,Other,...,-0.047227,-0.048958,-0.048662,-0.048627,-0.049710,-0.052359,-0.052895,-0.039718,-0.029136,-0.028822
3,3,751790,0.0,0,618728447,4,908.0,2898467035,38549,GPD,...,-0.047227,-0.048959,-0.048663,-0.048629,-0.049709,-0.052359,-0.052896,-0.039718,-0.029136,-0.028822
4,4,778705,0.0,0,52296320,0,0.0,3255875127,40955,Other,...,-0.047224,-0.048950,-0.048655,-0.048623,-0.049705,-0.052353,-0.052887,-0.039716,-0.029136,-0.028821
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
16917,17083,4293306306,0.0,0,819682076,3,168.0,1974235411,37409,GPD,...,-0.047227,-0.048959,-0.048661,-0.048626,-0.049708,-0.052358,-0.052896,-0.039719,-0.029137,-0.028822
16918,17084,4293354003,0.0,0,1188113564,0,0.0,447244163,16610,GRDA,...,-0.047230,-0.048961,-0.048664,-0.048629,-0.049711,-0.052361,-0.052897,-0.039719,-0.029137,-0.028822
16919,17085,4293843368,0.0,0,1549502620,0,0.0,1618953053,15065,GRDA,...,-0.047222,-0.048953,-0.048658,-0.048619,-0.049701,-0.052352,-0.052894,-0.039719,-0.029135,-0.028820
16920,17087,4294858825,6.0,2,657299228,2,6.0,3251917981,4312,Other,...,-0.047230,-0.048961,-0.048664,-0.048629,-0.049710,-0.052361,-0.052897,-0.039719,-0.029137,-0.028822


In [14]:
import catboost as cat
from catboost import CatBoostClassifier, Pool
from sklearn.model_selection import GroupKFold
from kaggle_kl_div import score

class_ids = {'Seizure':0, 'LPD':1, 'GPD':2, 'LRDA':3, 'GRDA':4, 'Other':5}

group_k_fold = GroupKFold(n_splits=5)

all_oof = []
all_true = []

for i, (train_ids, valid_ids) in enumerate(group_k_fold.split(data_processed, None, data_processed.patient_id)):
    model = CatBoostClassifier(task_type='GPU', loss_function='MultiClass')
    
    train_pool = Pool(
        data=data_processed.loc[train_ids, col_features],
        label=data_processed.loc[train_ids, 'expert_consensus'].map(class_ids)
    )

    valid_pool = Pool(
        data=data_processed.loc[valid_ids, col_features],
        label=data_processed.loc[valid_ids, 'expert_consensus'].map(class_ids)
    )

    model.fit(train_pool, verbose=100, eval_set=valid_pool)
    model.save_model(f'model_f{i}.cat')

    oof = model.predict_proba(valid_pool)
    all_oof.append(oof)
    all_true.append(data_processed.loc[valid_ids, col_targets])


all_oof = np.concatenate(all_oof)
all_true = np.concatenate(all_true)

df_oof = pd.DataFrame(all_oof.copy())
df_oof['id'] = np.arange(len(df_oof))

df_true = pd.DataFrame(all_true.copy())
df_true['id'] = np.arange(len(df_true))

score(solution=df_true, submission=df_oof, row_id_column_name='id')

Learning rate set to 0.136809
0:	learn: 1.6488005	test: 1.6783835	best: 1.6783835 (0)	total: 47.5ms	remaining: 47.5s
100:	learn: 0.7863702	test: 1.2089519	best: 1.2089508 (96)	total: 3.99s	remaining: 35.5s
200:	learn: 0.6097796	test: 1.1890269	best: 1.1870142 (187)	total: 7.89s	remaining: 31.4s
300:	learn: 0.4952500	test: 1.1878609	best: 1.1829302 (228)	total: 11.7s	remaining: 27.3s
400:	learn: 0.4059426	test: 1.1934979	best: 1.1829302 (228)	total: 15.6s	remaining: 23.3s
500:	learn: 0.3391463	test: 1.1927486	best: 1.1829302 (228)	total: 19.2s	remaining: 19.2s
600:	learn: 0.2876501	test: 1.1956894	best: 1.1829302 (228)	total: 22.8s	remaining: 15.1s
700:	learn: 0.2450520	test: 1.2045513	best: 1.1829302 (228)	total: 26.3s	remaining: 11.2s
800:	learn: 0.2089440	test: 1.2141736	best: 1.1829302 (228)	total: 29.8s	remaining: 7.4s
900:	learn: 0.1806291	test: 1.2165176	best: 1.1829302 (228)	total: 33.3s	remaining: 3.66s
999:	learn: 0.1567555	test: 1.2268648	best: 1.1829302 (228)	total: 36.7s	re

0.8827525088411158

In [15]:
from kaggle_kl_div import score

df_equal = pd.DataFrame(np.ones(df_oof.shape) / 6)
df_equal['id'] = np.arange(len(df_equal))

df_true = pd.DataFrame(all_true.copy())
df_true['id'] = np.arange(len(df_true))

score(solution=df_true, submission=df_equal, row_id_column_name='id')

1.4750432192092449

In [16]:
test_csv = pd.read_csv('data/test.csv')
test_csv

Unnamed: 0,spectrogram_id,eeg_id,patient_id
0,853520,3911565283,6885


In [17]:
eeg_window(test_csv.iloc[0], train=False)

Unnamed: 0,Fp1,F3,C3,P3,F7,T3,T5,O1,Fz,Cz,Pz,Fp2,F4,C4,P4,F8,T4,T6,O2,EKG
0,9.210000,-47.459999,15.100000,8.220000,-16.900000,-22.99,-25.820000,-10.090000,28.370001,-3.010000,-27.299999,101.040001,35.110001,14.540000,18.330000,28.540001,44.090000,69.650002,30.74,171.679993
1,-3.590000,-30.290001,32.380001,10.800000,-68.980003,-21.60,-15.080000,-9.210000,26.360001,-8.980000,-32.279999,95.800003,26.389999,4.820000,10.540000,20.559999,32.060001,59.439999,23.32,178.279999
2,-26.040001,-60.070000,2.370000,-10.150000,-34.689999,-31.40,-31.920000,-26.980000,-1.940000,-28.770000,-49.770000,73.449997,-3.680000,-17.320000,-16.150000,-8.270000,5.330000,45.180000,9.49,306.739990
3,-3.040000,-36.250000,29.559999,14.530000,-14.010000,-11.90,-14.230000,-6.310000,26.040001,-2.770000,-25.030001,91.010002,22.610001,6.900000,9.930000,15.480000,33.580002,69.620003,31.01,223.259995
4,-4.630000,-20.160000,25.190001,1.190000,-44.580002,-23.51,-30.709999,-17.600000,25.420000,-8.860000,-33.959999,89.449997,19.440001,-2.080000,6.110000,8.380000,24.180000,55.869999,19.91,170.759995
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,-26.889999,-45.480000,-17.250000,-23.570000,19.059999,-9.40,-27.120001,-21.580000,-75.760002,-65.800003,-88.790001,-30.090000,-49.830002,-75.339996,-61.139999,-71.889999,-53.299999,-8.130000,-12.38,-34.799999
9996,-24.049999,-41.689999,-13.450000,-26.219999,14.210000,0.02,-30.030001,-22.219999,-75.440002,-68.639999,-91.099998,-33.180000,-45.610001,-78.809998,-61.259998,-71.889999,-55.009998,-12.320000,-15.15,-27.799999
9997,-34.500000,-55.340000,-25.959999,-30.670000,8.890000,-9.74,-38.520000,-30.330000,-87.080002,-70.690002,-92.320000,-37.349998,-57.290001,-80.209999,-67.320000,-72.919998,-57.110001,-12.330000,-15.20,21.980000
9998,-16.110001,-35.980000,-8.570000,-12.020000,28.580000,5.45,-20.510000,-10.300000,-65.459999,-50.730000,-71.650002,-15.970000,-36.380001,-59.660000,-46.310001,-51.520000,-39.740002,6.770000,3.74,-5.800000


In [18]:
spg_window(test_csv.iloc[0], train=False)

Unnamed: 0,time,LL_0.59,LL_0.78,LL_0.98,LL_1.17,LL_1.37,LL_1.56,LL_1.76,LL_1.95,LL_2.15,...,RP_18.16,RP_18.36,RP_18.55,RP_18.75,RP_18.95,RP_19.14,RP_19.34,RP_19.53,RP_19.73,RP_19.92
0,1,14.910000,17.110001,11.660000,11.73,6.08,4.54,4.31,3.38,2.05,...,0.07,0.06,0.05,0.06,0.05,0.05,0.06,0.05,0.04,0.05
1,3,11.130000,10.950000,10.770000,5.07,4.03,3.24,3.61,2.98,1.54,...,0.05,0.04,0.04,0.04,0.04,0.04,0.03,0.03,0.03,0.02
2,5,10.880000,10.570000,8.790000,5.33,2.44,1.48,1.83,0.99,0.89,...,0.04,0.04,0.04,0.03,0.03,0.04,0.04,0.05,0.06,0.06
3,7,19.450001,18.200001,17.719999,13.38,4.17,1.88,1.84,1.22,1.27,...,0.03,0.03,0.05,0.08,0.07,0.07,0.08,0.03,0.03,0.03
4,9,21.650000,22.530001,23.160000,17.00,7.19,3.89,3.65,2.72,2.35,...,0.04,0.04,0.05,0.05,0.06,0.05,0.05,0.05,0.04,0.03
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
295,591,15.580000,18.209999,14.020000,15.96,4.36,4.98,2.68,2.22,2.03,...,0.48,0.59,0.59,0.73,0.44,0.41,0.56,0.60,0.61,0.60
296,593,17.209999,20.219999,20.889999,17.16,9.15,4.14,2.49,2.71,1.60,...,0.26,0.37,0.41,0.36,0.48,0.36,0.39,0.46,0.34,0.32
297,595,9.610000,13.320000,9.190000,11.50,8.11,5.53,5.57,3.69,3.19,...,0.58,0.37,0.17,0.14,0.13,0.30,0.36,0.39,0.56,0.29
298,597,8.430000,11.840000,13.640000,10.56,8.63,5.80,2.98,1.48,0.96,...,0.54,0.22,0.17,0.16,0.11,0.38,0.45,0.45,0.45,0.34


In [25]:
features_eeg = test_csv.apply(lambda row: eeg_features(eeg_window(row, train=False)), axis=1)
features_spg = test_csv.apply(lambda row: spg_features(spg_window(row, train=False)), axis=1)

preds = []
for i in range(5):
    model = CatBoostClassifier(task_type='GPU')
    model.load_model(f'model_f{i}.cat')

    test_pool = Pool(
        data = pd.concat([features_eeg, features_spg], axis=1)
    )

    pred = model.predict_proba(test_pool)
    preds.append(pred)


pred = np.mean(preds, axis=0)
pred.round(3)

array([[0.127, 0.001, 0.001, 0.001, 0.001, 0.869]])

In [26]:
submission_csv = pd.DataFrame({'eeg_id': test_csv.eeg_id.values})
submission_csv[col_targets] = pred
submission_csv

Unnamed: 0,eeg_id,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote
0,3911565283,0.12671,0.000609,0.000942,0.00132,0.001134,0.869285


In [27]:
submission_csv.to_csv('submission.csv', index=False)