In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from sklearn.metrics import log_loss
from sklearn.decomposition import PCA

import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow.keras.backend as K
import tensorflow.keras.layers as L
import tensorflow.keras.models as M
import time
import sys
import os

In [5]:
train_df = pd.read_csv('/content/drive/MyDrive/lish-moa/train_features.csv')
test_df = pd.read_csv('/content/drive/MyDrive/lish-moa/test_features.csv')
targetns = pd.read_csv('/content/drive/MyDrive/lish-moa/train_targets_nonscored.csv')
train_target_df = pd.read_csv('/content/drive/MyDrive/lish-moa/train_targets_scored.csv')
# sub = pd.read_csv('/content/drive/MyDrive/lish_moa/sample_submission.csv')

In [6]:
target_cols = train_target_df.columns[1:]
N_TARGETS = len(target_cols)
print(train_df.shape)

(23814, 876)


In [7]:
cells = [col for col in train_df.columns if col.startswith('c-')]
genes = [col for col in train_df.columns if col.startswith('g-')]

In [8]:
# For g- features
n_comp = 50
data = pd.concat([pd.DataFrame(train_df[genes]), pd.DataFrame(test_df[genes])])
data2 = (PCA(n_components = 50, random_state = 100).fit_transform(data[genes]))
train2 = data2[:train_df.shape[0]]
test2 = data2[-test_df.shape[0]:]

train2 = pd.DataFrame(train2, columns = [f'pca_G-{i}' for i in range(50)])
test2 = pd.DataFrame(test2, columns = [f'pca_G-{i}' for i in range(50)])

train_df = pd.concat((train_df, train2), axis = 1)
test_df = pd.concat((test_df, test2), axis = 1)

In [9]:
# For c- features
data = pd.concat([pd.DataFrame(train_df[cells]), pd.DataFrame(test_df[cells])])
data2 = (PCA(n_components = 15, random_state = 100).fit_transform(data[cells]))
train2 = data2[:train_df.shape[0]]
test2 = data2[-test_df.shape[0]:]

train2 = pd.DataFrame(train2, columns = [f'pca_C-{i}' for i in range(15)])
test2 = pd.DataFrame(test2, columns = [f'pca_C-{i}' for i in range(15)])
train_df = pd.concat((train_df, train2), axis = 1)
test_df = pd.concat((test_df, test2), axis = 1)
train_df

Unnamed: 0,sig_id,cp_type,cp_time,cp_dose,g-0,g-1,g-2,g-3,g-4,g-5,g-6,g-7,g-8,g-9,g-10,g-11,g-12,g-13,g-14,g-15,g-16,g-17,g-18,g-19,g-20,g-21,g-22,g-23,g-24,g-25,g-26,g-27,g-28,g-29,g-30,g-31,g-32,g-33,g-34,g-35,...,pca_G-25,pca_G-26,pca_G-27,pca_G-28,pca_G-29,pca_G-30,pca_G-31,pca_G-32,pca_G-33,pca_G-34,pca_G-35,pca_G-36,pca_G-37,pca_G-38,pca_G-39,pca_G-40,pca_G-41,pca_G-42,pca_G-43,pca_G-44,pca_G-45,pca_G-46,pca_G-47,pca_G-48,pca_G-49,pca_C-0,pca_C-1,pca_C-2,pca_C-3,pca_C-4,pca_C-5,pca_C-6,pca_C-7,pca_C-8,pca_C-9,pca_C-10,pca_C-11,pca_C-12,pca_C-13,pca_C-14
0,id_000644bb2,trt_cp,24,D1,1.0620,0.5577,-0.2479,-0.6208,-0.1944,-1.0120,-1.0220,-0.0326,0.5548,-0.0921,1.1830,0.1530,0.5574,-0.4015,0.1789,-0.6528,-0.7969,0.6342,0.1778,-0.3694,-0.5688,-1.1360,-1.1880,0.6940,0.4393,0.2664,0.1907,0.1628,-0.2853,0.5819,0.2934,-0.5584,-0.0916,-0.3010,-0.1537,0.2198,...,0.720634,-1.002048,0.864533,0.496370,-0.851521,0.498302,-1.052455,1.709184,0.409845,0.298407,-0.593563,-0.452313,1.407378,1.046697,-0.066408,0.554682,0.276440,0.265879,1.071292,0.580373,0.590998,-0.994302,1.727720,0.551437,0.560823,-7.285008,0.608206,-0.007576,0.191647,-0.755983,-0.440573,-0.124919,-1.222842,0.197772,-0.934195,0.488818,-0.403431,-0.010301,-0.776819,0.338812
1,id_000779bfc,trt_cp,72,D1,0.0743,0.4087,0.2991,0.0604,1.0190,0.5207,0.2341,0.3372,-0.4047,0.8507,-1.1520,-0.4201,-0.0958,0.4590,0.0803,0.2250,0.5293,0.2839,-0.3494,0.2883,0.9449,-0.1646,-0.2657,-0.3372,0.3135,-0.4316,0.4773,0.2075,-0.4216,-0.1161,-0.0499,-0.2627,0.9959,-0.2483,0.2655,-0.2102,...,-0.781220,1.686975,2.666582,-0.439571,-0.167359,1.155236,1.383034,1.226590,0.479089,1.105055,-1.004126,0.254221,-2.836420,0.758174,0.308480,1.702743,0.955828,-0.311353,-0.773748,1.493518,0.510786,0.631809,0.496264,-0.013238,0.292292,-7.417466,-0.756860,0.138957,-0.748048,-0.667145,0.086473,0.639086,0.508997,-0.305727,0.258950,1.047416,0.181604,0.127128,0.098714,0.899737
2,id_000a6266a,trt_cp,48,D1,0.6280,0.5817,1.5540,-0.0764,-0.0323,1.2390,0.1715,0.2155,0.0065,1.2300,-0.4797,-0.5631,-0.0366,-1.8300,0.6057,-0.3278,0.6042,-0.3075,-0.1147,-0.0570,-0.0799,-0.8181,-1.5320,0.2307,0.4901,0.4780,-1.3970,4.6240,-0.0437,1.2870,-1.8530,0.6069,0.4290,0.1783,0.0018,-1.1800,...,-1.858141,-1.402234,-2.853905,0.467562,-0.647710,2.040741,-1.630025,-0.047758,-2.582135,1.945418,1.510833,2.480615,-0.392530,0.630351,2.219793,0.608508,0.946207,3.674519,3.727714,-1.482477,-0.672522,1.160756,-0.751993,0.332657,-0.581010,-2.247580,0.227317,0.090409,0.970996,-0.187612,-0.117653,0.754535,-0.328586,-0.894865,-0.147314,-0.147462,0.013202,0.322070,0.548096,-0.341222
3,id_0015fd391,trt_cp,48,D1,-0.5138,-0.2491,-0.2656,0.5288,4.0620,-0.8095,-1.9590,0.1792,-0.1321,-1.0600,-0.8269,-0.3584,-0.8511,-0.5844,-2.5690,0.8183,-0.0532,-0.8554,0.1160,-2.3520,2.1200,-1.1580,-0.7191,-0.8004,-1.4670,-0.0107,-0.8995,0.2406,-0.2479,-1.0890,-0.7575,0.0881,-2.7370,0.8745,0.5787,-1.6740,...,-1.418338,-0.613210,-1.850996,-2.519921,2.337761,0.028901,5.639400,0.268690,2.644683,0.964086,1.134717,2.977621,1.599544,0.355653,-0.468149,0.548169,-1.761214,1.110526,0.950816,1.488872,-0.930547,1.113490,-0.713351,-0.111871,-1.532594,13.943314,7.308942,0.512171,-3.352433,-1.340207,0.598344,0.762685,-0.020627,-0.864042,0.186749,0.406903,0.433483,-1.166913,0.312000,1.736286
4,id_001626bd3,trt_cp,72,D2,-0.3254,-0.4009,0.9700,0.6919,1.4180,-0.8244,-0.2800,-0.1498,-0.8789,0.8630,-0.2219,-0.5121,-0.9577,1.1750,0.2042,0.1970,0.1244,-1.7090,-0.3543,-0.5160,-0.3330,-0.2685,0.7649,0.2057,1.3720,0.6835,0.8056,-0.3754,-1.2090,0.2965,-0.0712,0.6389,0.6674,-0.0783,1.1740,-0.7110,...,-0.509489,1.318480,1.846217,-1.181367,1.739973,1.682794,0.493895,0.786108,1.717408,-0.396927,-0.229921,0.385068,-1.238370,0.039057,-2.221722,0.217201,-1.538193,0.908307,-1.286020,1.937282,1.417942,-1.785443,0.809808,-0.950254,0.983837,-6.307752,0.262810,0.267303,0.636032,-0.202712,-0.019117,-0.310744,-0.143201,-0.102514,-0.369357,0.097451,-0.510072,-0.025396,0.241370,-0.391453
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
23809,id_fffb1ceed,trt_cp,24,D2,0.1394,-0.0636,-0.1112,-0.5080,-0.4713,0.7201,0.5773,0.3055,-0.4726,0.1269,0.2531,0.1730,-0.4532,-1.0790,0.2474,-0.4550,0.3588,0.1600,-0.7362,-0.1103,0.8550,-0.4139,0.5541,0.2310,-0.5573,-0.4397,-0.9260,-0.2424,-0.6686,0.2326,0.6456,0.0136,-0.5141,-0.6320,0.7166,-0.1736,...,0.312503,-0.771408,0.454353,0.556108,-2.079935,0.503498,-2.729901,-0.211907,-0.108778,2.406187,-0.368170,0.942625,1.084293,-0.083321,-0.010082,-0.575770,1.446017,2.017576,0.076590,-2.079627,-1.887206,-2.127474,-0.228028,-1.059640,-2.143961,-6.398104,0.103819,0.055113,-0.727397,-0.094678,-0.444480,0.773718,0.412298,-0.323740,0.274428,-0.198010,0.781093,-0.764532,-0.154121,1.404189
23810,id_fffb70c0c,trt_cp,24,D2,-1.3260,0.3478,-0.3743,0.9905,-0.7178,0.6621,-0.2252,-0.5565,0.5112,0.6727,-0.1851,2.8650,-0.2140,-0.6153,0.8362,0.5584,-0.2589,0.1292,0.0148,0.0949,-0.2182,-0.9235,0.0749,-1.5910,-0.8359,-0.9217,0.3013,0.1716,0.0880,0.1842,0.1835,0.5436,-0.0533,-0.0491,0.9543,0.4626,...,0.355809,-0.099948,-0.331101,0.667607,0.846298,2.420827,-1.080704,1.476196,-1.615844,-1.351519,1.927152,-0.878880,0.738303,-1.958191,1.508200,-0.444276,-0.499055,-0.079298,2.614913,-0.545869,-1.845803,0.900534,0.007127,-2.352780,3.091127,-4.029873,-1.009079,1.117199,0.148042,0.008588,-1.373380,-0.408353,-1.406653,-0.926603,1.188596,0.144073,0.038613,-0.162614,-0.696604,-0.038072
23811,id_fffc1c3f4,ctl_vehicle,48,D2,0.3942,0.3756,0.3109,-0.7389,0.5505,-0.0159,-0.2541,0.1745,-0.0340,0.4865,-0.1854,0.0716,0.1729,-0.0434,0.1542,-0.2192,-0.0302,-0.4218,0.4057,-0.5372,0.1521,-0.2651,0.2310,-0.8101,0.4943,0.6905,-0.3720,-1.4110,0.4516,1.2300,-0.1949,-1.3280,-0.4276,-0.0040,-0.3086,-0.2355,...,0.514248,-0.547889,-1.805144,-1.009904,-1.326962,0.023632,-0.450699,-0.058584,1.938601,-2.600804,-0.319968,-2.758825,3.076324,0.423707,4.199681,1.850683,1.617905,-2.477354,0.736502,0.670233,0.394235,1.108688,-1.726877,0.462453,1.123788,-8.302736,0.451157,0.642989,-0.618732,0.661144,-0.468037,-0.013892,-0.278587,-0.202004,-0.453634,0.034899,0.558306,-0.170959,-0.027128,0.040471
23812,id_fffcb9e7c,trt_cp,24,D1,0.6660,0.2324,0.4392,0.2044,0.8531,-0.0343,0.0323,0.0463,0.4299,-0.7985,0.5742,0.1421,2.2700,0.2046,0.5363,-1.7330,0.1450,0.6097,0.2024,0.9865,-0.7805,0.9608,0.3440,2.7650,0.4925,0.6698,0.2374,-0.3372,0.8771,-2.6560,-0.2000,-0.2043,0.6797,-0.0248,-0.0927,1.8480,...,1.316908,-5.138563,-0.359762,4.836056,-2.732917,0.024836,-1.313128,2.002749,0.001628,-0.421295,0.523341,-0.095842,-0.845837,-0.606549,-1.264786,0.230480,-3.733861,1.440790,-0.653776,1.701608,-1.184729,1.552738,-0.493535,-0.767306,-1.969371,-7.838803,0.176177,0.255380,0.311314,0.445997,-1.112928,-0.012667,-0.076785,-0.359969,0.351484,0.945044,-1.402858,-0.366168,0.455150,0.779150


In [10]:
from sklearn.feature_selection import VarianceThreshold

train_copy = train_df
var_thresh = VarianceThreshold(0.8)
data = train_df.append(test_df)
data_transformed = var_thresh.fit_transform(data.iloc[:, 4:])
data_transformed.shape

(27796, 868)

In [11]:
train_df_trans = data_transformed[ : train_df.shape[0]]
test_df_trans = data_transformed[-test_df.shape[0] : ]

train_df = pd.DataFrame(train_df[['sig_id', 'cp_type', 'cp_time', 'cp_dose']].values.reshape(-1, 4), columns = ['sig_id', 'cp_type', 'cp_time', 'cp_dose'])
# train_df.head
train_df = pd.concat([train_df, pd.DataFrame(train_df_trans)], axis = 1)


test_df = pd.DataFrame(test_df[['sig_id', 'cp_type', 'cp_time', 'cp_dose']].values.reshape(-1, 4), columns = ['sig_id', 'cp_type', 'cp_time', 'cp_dose'])
# train_df.head
test_df = pd.concat([test_df, pd.DataFrame(test_df_trans)], axis = 1)
train_df.head

<bound method NDFrame.head of              sig_id      cp_type cp_time  ...       865       866       867
0      id_000644bb2       trt_cp      24  ... -0.934195  0.488818 -0.403431
1      id_000779bfc       trt_cp      72  ...  0.258950  1.047416  0.181604
2      id_000a6266a       trt_cp      48  ... -0.147314 -0.147462  0.013202
3      id_0015fd391       trt_cp      48  ...  0.186749  0.406903  0.433483
4      id_001626bd3       trt_cp      72  ... -0.369357  0.097451 -0.510072
...             ...          ...     ...  ...       ...       ...       ...
23809  id_fffb1ceed       trt_cp      24  ...  0.274428 -0.198010  0.781093
23810  id_fffb70c0c       trt_cp      24  ...  1.188596  0.144073  0.038613
23811  id_fffc1c3f4  ctl_vehicle      48  ... -0.453634  0.034899  0.558306
23812  id_fffcb9e7c       trt_cp      24  ...  0.351484  0.945044 -1.402858
23813  id_ffffdd77b       trt_cp      72  ...  4.458739  0.583752  2.390563

[23814 rows x 872 columns]>

In [12]:
search_row = dict(train_copy.iloc[0, 4:])
col_rela = {}
for i in np.arange(0, 868):
  for k, v in search_row.items():
    if train_df[i][0] == v.all():
      col_rela[i] = k
train_df = train_df.rename(columns = col_rela)
test_df = test_df.rename(columns = col_rela)


In [13]:
SEED = 1925
EPOCHS = 25
BATCH_SIZE = 128
FOLDS = 5
REPEATS = 5
LR = 0.0005
N_TARGETS = len(target_cols)

In [14]:
def seed_everything(seed):
  np.random.seed(seed)
  os.environ['PYTHONHASHSEED'] = str(seed)
  tf.random.set_seed(seed)

In [15]:
def multi_log_loss(y_true, y_pred):
  losses = []
  for col in y_true.columns:
    losses.append(log_loss(y_true.loc[:, col], y_pred.loc[:, col]))
  return np.mean(losses)

In [16]:
def preprocess_df(data):
  data['cp_type'] = (data['cp_type'] == 'trt_cp').astype(int)
  data['cp_dose'] = (data['cp_dose'] == 'D2').astype(int)
  return data

In [17]:
x_train = preprocess_df(train_df.drop(columns = "sig_id"))
x_test = preprocess_df(test_df.drop(columns = "sig_id"))
y_train = train_target_df.drop(columns = "sig_id")
N_FEATURES = x_train.shape[1]

In [27]:
x_train = x_train.astype({'cp_time':int})
x_test = x_test.astype({'cp_time':int})
x_train

Unnamed: 0,cp_type,cp_time,cp_dose,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,...,828,829,830,831,832,833,834,835,836,837,838,839,840,841,842,843,844,845,846,847,848,849,850,851,852,853,854,855,856,857,858,859,860,861,862,863,864,865,866,867
0,1,24,0,1.0620,-0.2479,-0.6208,-0.1944,-1.0120,-0.0326,0.5548,-0.0921,1.1830,0.1530,0.5574,-0.4015,0.1789,-0.7969,0.6342,-0.5688,-1.1880,0.4393,0.1907,0.1628,-0.2853,0.5819,0.2934,-0.5584,-0.0916,-0.3010,-0.1537,0.2198,0.2965,-0.5055,-0.5119,-0.2162,-0.0347,-0.2566,-1.1980,0.3116,-1.0330,...,-1.023938,0.764982,0.999989,0.720634,-1.002048,0.864533,0.496370,-0.851521,0.498302,-1.052455,1.709184,0.409845,0.298407,-0.593563,-0.452313,1.407378,1.046697,-0.066408,0.554682,0.276440,0.265879,1.071292,0.580373,0.590998,-0.994302,1.727720,0.551437,0.560823,-7.285008,0.608206,-0.007576,0.191647,-0.755983,-0.440573,-0.124919,-1.222842,0.197772,-0.934195,0.488818,-0.403431
1,1,72,0,0.0743,0.2991,0.0604,1.0190,0.5207,0.3372,-0.4047,0.8507,-1.1520,-0.4201,-0.0958,0.4590,0.0803,0.5293,0.2839,0.9449,-0.2657,0.3135,0.4773,0.2075,-0.4216,-0.1161,-0.0499,-0.2627,0.9959,-0.2483,0.2655,-0.2102,0.1656,0.5300,-0.2568,-0.0455,0.1194,-0.3958,-1.1730,0.4509,1.9250,...,-0.003276,0.573495,-0.950573,-0.781220,1.686975,2.666582,-0.439571,-0.167359,1.155236,1.383034,1.226590,0.479089,1.105055,-1.004126,0.254221,-2.836420,0.758174,0.308480,1.702743,0.955828,-0.311353,-0.773748,1.493518,0.510786,0.631809,0.496264,-0.013238,0.292292,-7.417466,-0.756860,0.138957,-0.748048,-0.667145,0.086473,0.639086,0.508997,-0.305727,0.258950,1.047416,0.181604
2,1,48,0,0.6280,1.5540,-0.0764,-0.0323,1.2390,0.2155,0.0065,1.2300,-0.4797,-0.5631,-0.0366,-1.8300,0.6057,0.6042,-0.3075,-0.0799,-1.5320,0.4901,-1.3970,4.6240,-0.0437,1.2870,-1.8530,0.6069,0.4290,0.1783,0.0018,-1.1800,0.1256,-0.1219,5.4470,1.0310,0.3477,-0.5561,0.0357,-0.3636,-0.4653,...,-0.618204,1.603663,-1.608948,-1.858141,-1.402234,-2.853905,0.467562,-0.647710,2.040741,-1.630025,-0.047758,-2.582135,1.945418,1.510833,2.480615,-0.392530,0.630351,2.219793,0.608508,0.946207,3.674519,3.727714,-1.482477,-0.672522,1.160756,-0.751993,0.332657,-0.581010,-2.247580,0.227317,0.090409,0.970996,-0.187612,-0.117653,0.754535,-0.328586,-0.894865,-0.147314,-0.147462,0.013202
3,1,48,0,-0.5138,-0.2656,0.5288,4.0620,-0.8095,0.1792,-0.1321,-1.0600,-0.8269,-0.3584,-0.8511,-0.5844,-2.5690,-0.0532,-0.8554,2.1200,-0.7191,-1.4670,-0.8995,0.2406,-0.2479,-1.0890,-0.7575,0.0881,-2.7370,0.8745,0.5787,-1.6740,-1.6720,-1.2690,3.0900,-0.3814,-0.7229,-0.0010,0.1353,-1.6400,-0.7483,...,0.433757,3.776830,-1.493978,-1.418338,-0.613210,-1.850996,-2.519921,2.337761,0.028901,5.639400,0.268690,2.644683,0.964086,1.134717,2.977621,1.599544,0.355653,-0.468149,0.548169,-1.761214,1.110526,0.950816,1.488872,-0.930547,1.113490,-0.713351,-0.111871,-1.532594,13.943314,7.308942,0.512171,-3.352433,-1.340207,0.598344,0.762685,-0.020627,-0.864042,0.186749,0.406903,0.433483
4,1,72,1,-0.3254,0.9700,0.6919,1.4180,-0.8244,-0.1498,-0.8789,0.8630,-0.2219,-0.5121,-0.9577,1.1750,0.2042,0.1244,-1.7090,-0.3330,0.7649,1.3720,0.8056,-0.3754,-1.2090,0.2965,-0.0712,0.6389,0.6674,-0.0783,1.1740,-0.7110,-1.4470,1.0620,0.7888,-0.0848,0.1302,0.1224,0.9584,0.2126,0.6162,...,1.364715,-1.323377,0.287988,-0.509489,1.318480,1.846217,-1.181367,1.739973,1.682794,0.493895,0.786108,1.717408,-0.396927,-0.229921,0.385068,-1.238370,0.039057,-2.221722,0.217201,-1.538193,0.908307,-1.286020,1.937282,1.417942,-1.785443,0.809808,-0.950254,0.983837,-6.307752,0.262810,0.267303,0.636032,-0.202712,-0.019117,-0.310744,-0.143201,-0.102514,-0.369357,0.097451,-0.510072
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
23809,1,24,1,0.1394,-0.1112,-0.5080,-0.4713,0.7201,0.3055,-0.4726,0.1269,0.2531,0.1730,-0.4532,-1.0790,0.2474,0.3588,0.1600,0.8550,0.5541,-0.5573,-0.9260,-0.2424,-0.6686,0.2326,0.6456,0.0136,-0.5141,-0.6320,0.7166,-0.1736,0.3686,-0.1565,-0.7362,0.1318,0.1119,1.3410,0.0813,-0.2178,-1.1840,...,-0.120989,-0.290174,-3.014317,0.312503,-0.771408,0.454353,0.556108,-2.079935,0.503498,-2.729901,-0.211907,-0.108778,2.406187,-0.368170,0.942625,1.084293,-0.083321,-0.010082,-0.575770,1.446017,2.017576,0.076590,-2.079627,-1.887206,-2.127474,-0.228028,-1.059640,-2.143961,-6.398104,0.103819,0.055113,-0.727397,-0.094678,-0.444480,0.773718,0.412298,-0.323740,0.274428,-0.198010,0.781093
23810,1,24,1,-1.3260,-0.3743,0.9905,-0.7178,0.6621,-0.5565,0.5112,0.6727,-0.1851,2.8650,-0.2140,-0.6153,0.8362,-0.2589,0.1292,-0.2182,0.0749,-0.8359,0.3013,0.1716,0.0880,0.1842,0.1835,0.5436,-0.0533,-0.0491,0.9543,0.4626,0.0819,0.1586,1.2050,0.0384,-0.0843,-0.8834,-0.6190,0.2070,-0.0265,...,1.291513,-2.317664,0.645659,0.355809,-0.099948,-0.331101,0.667607,0.846298,2.420827,-1.080704,1.476196,-1.615844,-1.351519,1.927152,-0.878880,0.738303,-1.958191,1.508200,-0.444276,-0.499055,-0.079298,2.614913,-0.545869,-1.845803,0.900534,0.007127,-2.352780,3.091127,-4.029873,-1.009079,1.117199,0.148042,0.008588,-1.373380,-0.408353,-1.406653,-0.926603,1.188596,0.144073,0.038613
23811,0,48,1,0.3942,0.3109,-0.7389,0.5505,-0.0159,0.1745,-0.0340,0.4865,-0.1854,0.0716,0.1729,-0.0434,0.1542,-0.0302,-0.4218,0.1521,0.2310,0.4943,-0.3720,-1.4110,0.4516,1.2300,-0.1949,-1.3280,-0.4276,-0.0040,-0.3086,-0.2355,-0.5222,-1.1470,0.2844,0.0813,-0.4128,-0.2369,0.3507,0.2589,-0.2534,...,2.005605,0.351382,0.052283,0.514248,-0.547889,-1.805144,-1.009904,-1.326962,0.023632,-0.450699,-0.058584,1.938601,-2.600804,-0.319968,-2.758825,3.076324,0.423707,4.199681,1.850683,1.617905,-2.477354,0.736502,0.670233,0.394235,1.108688,-1.726877,0.462453,1.123788,-8.302736,0.451157,0.642989,-0.618732,0.661144,-0.468037,-0.013892,-0.278587,-0.202004,-0.453634,0.034899,0.558306
23812,1,24,0,0.6660,0.4392,0.2044,0.8531,-0.0343,0.0463,0.4299,-0.7985,0.5742,0.1421,2.2700,0.2046,0.5363,0.1450,0.6097,-0.7805,0.3440,0.4925,0.2374,-0.3372,0.8771,-2.6560,-0.2000,-0.2043,0.6797,-0.0248,-0.0927,1.8480,-2.7180,-5.3160,-1.0750,-0.1437,-0.1714,0.0329,0.8309,-0.3018,0.1599,...,-1.624329,-1.838336,-1.203729,1.316908,-5.138563,-0.359762,4.836056,-2.732917,0.024836,-1.313128,2.002749,0.001628,-0.421295,0.523341,-0.095842,-0.845837,-0.606549,-1.264786,0.230480,-3.733861,1.440790,-0.653776,1.701608,-1.184729,1.552738,-0.493535,-0.767306,-1.969371,-7.838803,0.176177,0.255380,0.311314,0.445997,-1.112928,-0.012667,-0.076785,-0.359969,0.351484,0.945044,-1.402858


In [25]:
def create_model():
  # model = tf.keras.Sequential([tf.keras.layers.Input(N_FEATURES), tf.keras.layers.BatchNormalization(),
  #                              tf.keras.layers.Dropout(0.2), 
  #                              tfa.layers.WeightNormalization(tf.keras.layers.Dense(2048, activation = "relu")),
  #                              tf.keras.layers.BatchNormalization(), tf.keras.layers.Dropout(0.5), 
  #                              tfa.layers.WeightNormalization(tf.keras.layers.Dense(2048, activation = "relu")),
  #                              tf.keras.layers.BatchNormalization(), tf.keras.layers.Dropout(0.5),
  #                              tfa.layers.WeightNormalization(tf.keras.layers.Dense(N_TARGETS, activation = "sigmoid"))])
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Input(N_FEATURES))
  model.add(tf.keras.layers.BatchNormalization())
  model.add(tf.keras.layers.Dropout(0.2))
  model.add(tfa.layers.WeightNormalization(tf.keras.layers.Dense(2048, activation = "relu")))
  model.add(tf.keras.layers.BatchNormalization())
  model.add(tf.keras.layers.Dropout(0.5))
  model.add(tfa.layers.WeightNormalization(tf.keras.layers.Dense(2048, activation = "relu")))
  model.add(tf.keras.layers.BatchNormalization())
  model.add(tf.keras.layers.Dropout(0.5))
  model.add(tfa.layers.WeightNormalization(tf.keras.layers.Dense(N_TARGETS, activation = "sigmoid")))
  model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = LR), loss = 'binary_crossentropy', metrics = ["accuracy"])
  return model

In [26]:
model = create_model()
model.summary()

Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
batch_normalization_15 (Batc (None, 871)               3484      
_________________________________________________________________
dropout_15 (Dropout)         (None, 871)               0         
_________________________________________________________________
weight_normalization_15 (Wei (None, 2048)              3573761   
_________________________________________________________________
batch_normalization_16 (Batc (None, 2048)              8192      
_________________________________________________________________
dropout_16 (Dropout)         (None, 2048)              0         
_________________________________________________________________
weight_normalization_16 (Wei (None, 2048)              8394753   
_________________________________________________________________
batch_normalization_17 (Batc (None, 2048)             

In [23]:
def build_train(resume_models = None, repeat_number = 0, folds = 5, skip_folds = 0):
  models = []
  oof_preds = y_train.copy()
  kfold = KFold(folds, shuffle = True)
  for fold, (train_ind, val_ind) in enumerate(kfold.split(x_train)):
    print(f'Training fold {fold + 1}')
    cb_lr_schedule = tf.keras.callbacks.ReduceLROnPlateau(monitor = 'val_loss', factor = 0.4, patience = 2, verbose = 1, min_delta = 0.0001, mode = 'auto')
    checkpoint_path = f'repeat:{repeat_number}_Fold:{fold}.hdf5'
    cb_checkpt = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, monitor = 'val_loss', verbose = 0,
                                                    save_best_only = True, save_weights_only = True, mode = 'min')
    model = create_model()
    model.fit(x_train.values[train_ind], y_train.values[train_ind], validation_data = (x_train.values[val_ind], y_train.values[val_ind]),
              callbacks = [cb_lr_schedule, cb_checkpt], epochs = EPOCHS, batch_size = BATCH_SIZE, verbose = 2)
    model.load_weights(checkpoint_path)
    oof_preds.loc[val_ind, :] = model.predict(x_train.values[val_ind])
    models.append(model)
    print('train:')
    print(list(zip(model.metrics_names, model.evaluate(x_train.values[train_ind], y_train.values[train_ind], verbose = 0, batch_size = 32))))
    print('val:')
    print(list(zip(model.metrics_names, model.evaluate(x_train.values[train_ind], y_train.values[train_ind], verbose = 0, batch_size = 32))))
    return models, oof_preds
  

In [24]:
models = []
oof_preds = []
seed_everything(SEED)
for i in range(REPEATS):
  m, oof = build_train(repeat_number = i, folds = FOLDS)
  models = models + m
  oof_preds.append(oof)



--------------------------------------------------
Training fold 1
Epoch 1/25
149/149 - 34s - loss: 0.5643 - accuracy: 0.0315 - val_loss: 0.2334 - val_accuracy: 0.0762
Epoch 2/25
149/149 - 34s - loss: 0.1019 - accuracy: 0.0489 - val_loss: 0.0435 - val_accuracy: 0.0909
Epoch 3/25
149/149 - 34s - loss: 0.0334 - accuracy: 0.0728 - val_loss: 0.0250 - val_accuracy: 0.1001
Epoch 4/25
149/149 - 34s - loss: 0.0240 - accuracy: 0.0811 - val_loss: 0.0212 - val_accuracy: 0.1064
Epoch 5/25
149/149 - 34s - loss: 0.0212 - accuracy: 0.0902 - val_loss: 0.0201 - val_accuracy: 0.1111
Epoch 6/25
149/149 - 34s - loss: 0.0196 - accuracy: 0.0938 - val_loss: 0.0184 - val_accuracy: 0.0884
Epoch 7/25
149/149 - 34s - loss: 0.0187 - accuracy: 0.1006 - val_loss: 0.0185 - val_accuracy: 0.0903
Epoch 8/25
149/149 - 34s - loss: 0.0181 - accuracy: 0.1047 - val_loss: 0.0173 - val_accuracy: 0.0968
Epoch 9/25
149/149 - 35s - loss: 0.0174 - accuracy: 0.1086 - val_loss: 0.0168 - val_accuracy: 0.0974
Epoch 10/25
149/149 - 

KeyboardInterrupt: ignored