In [2]:
import numpy as np
import pandas as pd
import os
import re
from sklearn.base import clone
from sklearn.metrics import cohen_kappa_score
from sklearn.model_selection import StratifiedKFold
from scipy.optimize import minimize
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm

from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from keras.models import Model
from keras.layers import Input, Dense
from keras.optimizers import Adam

from colorama import Fore, Style
from IPython.display import clear_output
import warnings
from lightgbm import LGBMRegressor
from xgboost import XGBRegressor
from catboost import CatBoostRegressor
from sklearn.ensemble import VotingRegressor, RandomForestRegressor, GradientBoostingRegressor
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
warnings.filterwarnings('ignore')
pd.options.display.max_columns = None
import torch
from torch import nn

SEED = 42
n_splits = 5

In [3]:
def process_file(filename, dirname):
    df = pd.read_parquet(os.path.join(dirname, filename, 'part-0.parquet'))
    df.drop('step', axis=1, inplace=True)
    return df.describe().values.reshape(-1), filename.split('=')[1]

def load_time_series(dirname) -> pd.DataFrame:
    ids = os.listdir(dirname)
    
    with ThreadPoolExecutor() as executor:
        results = list(tqdm(executor.map(lambda fname: process_file(fname, dirname), ids), total=len(ids)))
    
    stats, indexes = zip(*results)
    
    df = pd.DataFrame(stats, columns=[f"stat_{i}" for i in range(len(stats[0]))])
    df['id'] = indexes
    return df


In [4]:
train = pd.read_csv('/kaggle/input/child-mind-institute-problematic-internet-use/train.csv')
test = pd.read_csv('/kaggle/input/child-mind-institute-problematic-internet-use/test.csv')
sample = pd.read_csv('/kaggle/input/child-mind-institute-problematic-internet-use/sample_submission.csv')

train_ts = load_time_series("/kaggle/input/child-mind-institute-problematic-internet-use/series_train.parquet")
test_ts = load_time_series("/kaggle/input/child-mind-institute-problematic-internet-use/series_test.parquet")

df_train = train_ts.drop('id', axis=1)
df_test = test_ts.drop('id', axis=1)

100%|██████████| 996/996 [01:17<00:00, 12.93it/s]
100%|██████████| 2/2 [00:00<00:00,  8.73it/s]


In [5]:
train_ts.head()

Unnamed: 0,stat_0,stat_1,stat_2,stat_3,stat_4,stat_5,stat_6,stat_7,stat_8,stat_9,stat_10,stat_11,stat_12,stat_13,stat_14,stat_15,stat_16,stat_17,stat_18,stat_19,stat_20,stat_21,stat_22,stat_23,stat_24,stat_25,stat_26,stat_27,stat_28,stat_29,stat_30,stat_31,stat_32,stat_33,stat_34,stat_35,stat_36,stat_37,stat_38,stat_39,stat_40,stat_41,stat_42,stat_43,stat_44,stat_45,stat_46,stat_47,stat_48,stat_49,stat_50,stat_51,stat_52,stat_53,stat_54,stat_55,stat_56,stat_57,stat_58,stat_59,stat_60,stat_61,stat_62,stat_63,stat_64,stat_65,stat_66,stat_67,stat_68,stat_69,stat_70,stat_71,stat_72,stat_73,stat_74,stat_75,stat_76,stat_77,stat_78,stat_79,stat_80,stat_81,stat_82,stat_83,stat_84,stat_85,stat_86,stat_87,stat_88,stat_89,stat_90,stat_91,stat_92,stat_93,stat_94,stat_95,id
0,50458.0,50458.0,50458.0,50458.0,50458.0,50458.0,50458.0,50458.0,50458.0,50458.0,50458.0,50458.0,-0.054638,-0.163923,-0.114302,0.045252,-7.805897,0.0,46.009533,4027.514893,54154750000000.0,4.43886,2.0,30.202068,0.633126,0.513286,0.500372,0.132576,34.917873,0.0,205.862213,108.451317,18769760000000.0,1.825557,0.0,11.773107,-1.812031,-2.63138,-1.798073,0.0,-89.987045,0.0,0.0,3829.0,0.0,1.0,2.0,15.0,-0.70166,-0.619076,-0.536432,0.007953,-32.948602,0.0,2.520257,3958.0,43251250000000.0,3.0,2.0,17.0,0.015846,-0.14181,-0.104193,0.019257,-6.358004,0.0,8.230733,4029.0,56305000000000.0,5.0,2.0,28.0,0.437897,0.148919,0.22377,0.036048,13.09575,0.0,24.75,4146.0,69780000000000.0,6.0,2.0,38.0,1.850391,3.580182,1.738203,5.314874,89.422226,0.0,2626.199951,4187.0,86395000000000.0,7.0,2.0,57.0,0745c390
1,340584.0,340584.0,340584.0,340584.0,340584.0,340584.0,340584.0,340584.0,340584.0,340584.0,340584.0,340584.0,0.113277,0.093139,-0.106038,0.02896,-6.065619,0.046508,56.437958,3829.466064,43311490000000.0,3.840885,2.0,232.909103,0.507897,0.541129,0.603787,0.096825,44.034721,0.208482,206.625092,167.600983,25091360000000.0,1.957999,0.0,5.701968,-1.807955,-2.887664,-1.004992,0.0,-89.654587,0.0,0.0,3098.166748,0.0,1.0,2.0,223.0,-0.231743,-0.2576,-0.595426,0.000367,-37.326844,0.0,4.0,3724.0,21285000000000.0,2.0,2.0,228.0,0.094074,0.068143,-0.2285,0.005257,-13.454103,0.0,10.05048,3812.0,43605000000000.0,4.0,2.0,233.0,0.517859,0.542323,0.312333,0.020598,18.462269,0.0,27.490936,3958.0,65110000000000.0,5.0,2.0,238.0,1.928769,3.234613,2.475326,3.966906,89.08033,1.0,2628.199951,4146.0,86395000000000.0,7.0,2.0,243.0,eaab7a96
2,40003.0,40003.0,40003.0,40003.0,40003.0,40003.0,40003.0,40003.0,40003.0,40003.0,40003.0,40003.0,-0.499738,0.046381,-0.181152,0.056544,-11.934993,0.0,77.30513,4106.425781,44816770000000.0,3.148264,3.0,100.144516,0.454021,0.510668,0.412588,0.140594,27.367514,0.0,274.848145,50.734318,20381560000000.0,1.169176,0.0,5.653936,-1.903281,-3.150104,-1.020313,0.0,-89.540176,0.0,0.0,3853.0,45000000000.0,1.0,3.0,97.0,-0.873151,-0.255299,-0.485521,0.005643,-30.154542,0.0,2.918126,4089.625,28885000000000.0,3.0,3.0,98.0,-0.644505,0.088542,-0.191693,0.018467,-11.570901,0.0,7.863636,4111.0,47270000000000.0,3.0,3.0,99.0,-0.242422,0.381953,0.088555,0.048282,5.009753,0.0,21.022933,4140.0,60945000000000.0,4.0,3.0,100.0,1.02151,1.016589,1.746797,5.066334,86.987267,0.0,2618.199951,4183.0,86365000000000.0,7.0,3.0,134.0,8ec2cc63
3,223915.0,223915.0,223915.0,223915.0,223915.0,223915.0,223915.0,223915.0,223915.0,223915.0,223915.0,223915.0,0.00743,0.007583,-0.19651,0.053544,-12.847143,0.0,9.369678,3958.604492,48366420000000.0,4.273992,2.303057,60.025017,0.5861,0.542189,0.474437,0.103401,32.552841,0.0,54.104408,122.706802,18687730000000.0,2.023705,1.487018,7.396456,-1.684624,-2.405738,-1.023798,0.0,-89.968369,0.0,0.0,3468.0,0.0,1.0,1.0,48.0,-0.530198,-0.412805,-0.556091,0.009947,-34.965618,0.0,0.893617,3841.0,35260000000000.0,3.0,1.0,53.0,0.022344,0.009674,-0.245181,0.027653,-15.000056,0.0,2.340206,3947.0,48810000000000.0,4.0,1.0,60.0,0.536801,0.443383,0.084469,0.057278,4.816339,0.0,6.2,4064.0,63300000000000.0,6.0,4.0,67.0,5.908,2.083693,1.269051,6.134459,89.976074,0.0,2502.0,6000.0,86395000000000.0,7.0,4.0,72.0,b2987a65
4,15420.0,15420.0,15420.0,15420.0,15420.0,15420.0,15420.0,15420.0,15420.0,15420.0,15420.0,15420.0,0.086653,-0.115162,-0.138969,0.040399,-11.009835,0.0,5.049157,3992.347656,58338950000000.0,4.541829,4.0,46.192024,0.509845,0.494897,0.639449,0.090201,47.933723,0.0,15.590773,126.12159,21462060000000.0,2.081796,0.0,18.615358,-1.675859,-1.071042,-1.012266,0.0,-89.770241,0.0,0.0,3815.083252,35000000000.0,1.0,4.0,20.0,-0.224805,-0.444297,-0.685736,0.005364,-46.348264,0.0,1.438378,3837.333252,51613750000000.0,3.0,4.0,32.0,0.053034,-0.087422,-0.22543,0.024135,-13.665493,0.0,2.897436,4000.0,64270000000000.0,4.0,4.0,42.0,0.544297,0.153125,0.347474,0.04369,20.726226,0.0,4.942201,4087.0,73936250000000.0,7.0,4.0,69.0,3.231563,1.03362,1.071875,2.774382,89.300034,0.0,1046.800049,4199.0,86015000000000.0,7.0,4.0,76.0,7b8842c3


In [6]:
train_ts.shape

(996, 97)

In [7]:
train_ts.describe()

Unnamed: 0,stat_0,stat_1,stat_2,stat_3,stat_4,stat_5,stat_6,stat_7,stat_8,stat_9,stat_10,stat_11,stat_12,stat_13,stat_14,stat_15,stat_16,stat_17,stat_18,stat_19,stat_20,stat_21,stat_22,stat_23,stat_24,stat_25,stat_26,stat_27,stat_28,stat_29,stat_30,stat_31,stat_32,stat_33,stat_34,stat_35,stat_36,stat_37,stat_38,stat_39,stat_40,stat_41,stat_42,stat_43,stat_44,stat_45,stat_46,stat_47,stat_48,stat_49,stat_50,stat_51,stat_52,stat_53,stat_54,stat_55,stat_56,stat_57,stat_58,stat_59,stat_60,stat_61,stat_62,stat_63,stat_64,stat_65,stat_66,stat_67,stat_68,stat_69,stat_70,stat_71,stat_72,stat_73,stat_74,stat_75,stat_76,stat_77,stat_78,stat_79,stat_80,stat_81,stat_82,stat_83,stat_84,stat_85,stat_86,stat_87,stat_88,stat_89,stat_90,stat_91,stat_92,stat_93,stat_94,stat_95
count,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0,996.0
mean,315832.478916,315832.478916,315832.478916,315832.478916,315832.478916,315832.478916,315832.478916,315832.478916,315832.478916,315832.478916,315832.478916,315832.478916,-0.062037,0.02006,-0.139247,0.038453,-11.243804,0.22199,40.240402,3879.672672,45720100000000.0,3.986639,2.381168,62.926103,0.473792,0.458367,0.561343,0.097117,41.676962,0.213969,143.560387,145.218718,22698280000000.0,1.90803,0.175343,7.510345,-1.913478,-2.391393,-1.076813,2.287877e-08,-89.612836,0.0,0.0,3297.473023,254206800000.0,1.028112,2.139558,50.879518,-0.412507,-0.270569,-0.591028,0.003753,-42.065305,0.096726,3.057059,3785.084045,26861260000000.0,2.168675,2.288153,56.677711,-0.076217,0.019573,-0.178983,0.013236,-12.801342,0.213298,7.384388,3856.479669,46113620000000.0,3.990964,2.38253,62.691767,0.268717,0.316502,0.256802,0.035415,16.156027,0.313309,20.112669,3987.292253,64943000000000.0,5.764056,2.492972,68.882028,1.672994,2.157886,1.875645,3.674156,88.888246,0.650602,2335.652809,4180.122239,86289800000000.0,6.933735,2.615462,78.834337
std,133011.574731,133011.574731,133011.574731,133011.574731,133011.574731,133011.574731,133011.574731,133011.574731,133011.574731,133011.574731,133011.574731,133011.574731,0.199042,0.13571,0.270067,0.024918,21.777049,0.297515,41.9195,74.553552,4020327000000.0,0.363256,1.108057,85.784405,0.138864,0.099471,0.11989,0.057625,10.998348,0.189562,124.209778,29.467088,3678057000000.0,0.278682,0.357336,3.530746,0.918718,0.895326,0.287006,5.115021e-07,0.616176,0.0,0.0,301.485667,2506944000000.0,0.256009,1.124242,85.533569,0.289976,0.209238,0.317817,0.005225,27.091537,0.288826,6.1144,88.860669,8174405000000.0,0.533936,1.168386,85.553824,0.265431,0.161416,0.396653,0.011159,30.380237,0.407141,7.882592,89.134341,4804747000000.0,0.479655,1.181625,86.025381,0.292563,0.210885,0.404028,0.025769,30.269773,0.462983,23.602476,62.814012,2608501000000.0,0.638429,1.182345,86.089421,0.881167,0.934672,0.911899,1.459785,3.313411,0.477019,898.044846,93.299368,898497100000.0,0.504877,1.158635,86.447984
min,927.0,927.0,927.0,927.0,927.0,927.0,927.0,927.0,927.0,927.0,927.0,927.0,-0.638839,-0.691523,-1.04802,0.000106,-89.119598,0.0,0.314382,3661.307861,34177550000000.0,1.059442,1.0,-131.253174,0.02478,0.021742,0.018929,0.001534,2.132784,0.0,0.244579,6.687519,6972893000000.0,0.0,0.0,0.0,-8.040816,-5.429414,-8.040491,0.0,-90.0,0.0,0.0,3080.916748,0.0,1.0,1.0,-143.0,-0.956868,-0.999538,-1.050378,0.0,-89.658388,0.0,0.0,3534.125061,12673750000000.0,1.0,1.0,-137.0,-0.73341,-0.999044,-1.047357,0.0,-89.396248,0.0,0.0,3689.0,25347500000000.0,1.0,1.0,-131.0,-0.514271,-0.77181,-1.046654,0.0,-89.021507,0.0,0.0,3800.0,47045000000000.0,1.0,1.0,-125.0,0.395664,0.694427,0.167161,0.140138,11.565893,0.0,26.5,3996.0,69805000000000.0,2.0,1.0,-119.0
25%,253592.75,253592.75,253592.75,253592.75,253592.75,253592.75,253592.75,253592.75,253592.75,253592.75,253592.75,253592.75,-0.191542,-0.029575,-0.252719,0.019996,-18.165702,0.0,14.050416,3839.530212,43219630000000.0,3.910494,1.053306,17.500581,0.409186,0.428497,0.479145,0.053988,33.613696,0.0,43.41829,131.529034,18947090000000.0,1.919302,0.0,6.437423,-2.193893,-2.990001,-1.048741,0.0,-89.860825,0.0,0.0,3098.166748,0.0,1.0,1.0,5.75,-0.664013,-0.376795,-0.799268,0.000117,-54.144475,0.0,1.5,3741.0,21528750000000.0,2.0,1.0,11.0,-0.266271,-0.020694,-0.318459,0.004261,-19.463583,0.0,4.32164,3809.270752,43301880000000.0,4.0,1.0,17.0,0.029631,0.207801,0.079591,0.016678,4.365174,0.0,10.592305,3958.0,64720000000000.0,6.0,1.0,23.0,1.120858,1.295791,1.230632,2.790156,88.972979,0.0,2512.600098,4170.0,86395000000000.0,7.0,2.0,30.0
50%,383544.0,383544.0,383544.0,383544.0,383544.0,383544.0,383544.0,383544.0,383544.0,383544.0,383544.0,383544.0,-0.033658,0.008939,-0.141485,0.035104,-10.27545,0.051458,24.780988,3848.289673,43407710000000.0,4.001631,2.021691,33.756311,0.513798,0.476344,0.556131,0.091879,40.038872,0.208968,104.244904,160.041496,24912380000000.0,1.961505,0.0,6.74157,-1.771064,-2.412578,-1.016421,0.0,-89.748867,0.0,0.0,3098.166748,0.0,1.0,2.0,21.0,-0.488205,-0.287686,-0.608817,0.00103,-38.661846,0.0,2.336633,3747.0,21832500000000.0,2.0,2.0,27.0,-0.013372,0.006974,-0.174688,0.010069,-10.770489,0.0,6.271717,3818.0,43678750000000.0,4.0,2.0,33.0,0.224109,0.316078,0.244292,0.029585,14.188232,0.0,15.806186,3970.0,64964380000000.0,6.0,2.0,40.0,1.486703,2.112647,1.740934,3.806256,89.377281,1.0,2613.625,4180.0,86395000000000.0,7.0,3.0,51.0
75%,402597.0,402597.0,402597.0,402597.0,402597.0,402597.0,402597.0,402597.0,402597.0,402597.0,402597.0,402597.0,0.051972,0.056401,-0.007245,0.05261,-0.747711,0.396362,53.456012,3891.32074,48365970000000.0,4.111787,3.193248,69.702797,0.577053,0.514313,0.627864,0.133454,47.860791,0.403056,218.857834,164.848499,25007370000000.0,2.006777,0.232729,7.397724,-1.346679,-1.701394,-1.00584,0.0,-89.550341,0.0,0.0,3665.0,0.0,1.0,3.0,57.0,-0.138313,-0.151781,-0.442593,0.006798,-27.738929,0.0,3.155954,3786.125,35032500000000.0,2.0,3.0,64.25,0.036609,0.042826,0.004579,0.020944,0.031694,0.0,8.537187,3853.0,48976250000000.0,4.0,4.0,70.0,0.523139,0.406952,0.473484,0.050647,28.368755,1.0,22.721819,3994.0,65230000000000.0,6.0,4.0,76.25,1.899864,2.77945,2.234637,4.593709,89.651743,1.0,2637.0,4187.0,86395000000000.0,7.0,4.0,90.0
max,756212.0,756212.0,756212.0,756212.0,756212.0,756212.0,756212.0,756212.0,756212.0,756212.0,756212.0,756212.0,0.509981,0.875131,0.943503,0.15683,81.13588,0.985155,367.226227,4163.574707,67337720000000.0,6.0142,4.0,736.465149,0.739901,0.791727,0.984865,0.361639,85.388504,0.499687,712.101196,232.576157,31374520000000.0,2.681513,1.49999,49.941177,-0.069883,-0.04276,-0.968112,1.217405e-05,-79.857498,0.0,0.0,4134.0,42000000000000.0,5.0,4.0,725.0,0.354788,1.022626,0.999427,0.04671,87.328514,1.0,133.333328,4170.0,64795000000000.0,6.0,4.0,731.0,0.638854,1.02751,1.01786,0.060281,88.542992,1.0,133.333328,4175.0,69785000000000.0,7.0,4.0,736.0,0.942128,1.029023,1.024162,0.140011,89.120226,1.0,479.5,4181.0,78097500000000.0,7.0,4.0,742.0,8.022779,7.90695,8.125557,11.3262,89.98114,1.0,20445.5,6000.0,86395000000000.0,7.0,4.0,748.0


In [8]:
temp_df = train[['id','sii']]

time_series_df_with_target = pd.merge(train_ts,temp_df,how="left",on='id')

temp_df = test[['id']]

time_series_df_without_target = pd.merge(test_ts,temp_df,how="left",on='id')

In [9]:
time_series_df_with_target.head()

Unnamed: 0,stat_0,stat_1,stat_2,stat_3,stat_4,stat_5,stat_6,stat_7,stat_8,stat_9,stat_10,stat_11,stat_12,stat_13,stat_14,stat_15,stat_16,stat_17,stat_18,stat_19,stat_20,stat_21,stat_22,stat_23,stat_24,stat_25,stat_26,stat_27,stat_28,stat_29,stat_30,stat_31,stat_32,stat_33,stat_34,stat_35,stat_36,stat_37,stat_38,stat_39,stat_40,stat_41,stat_42,stat_43,stat_44,stat_45,stat_46,stat_47,stat_48,stat_49,stat_50,stat_51,stat_52,stat_53,stat_54,stat_55,stat_56,stat_57,stat_58,stat_59,stat_60,stat_61,stat_62,stat_63,stat_64,stat_65,stat_66,stat_67,stat_68,stat_69,stat_70,stat_71,stat_72,stat_73,stat_74,stat_75,stat_76,stat_77,stat_78,stat_79,stat_80,stat_81,stat_82,stat_83,stat_84,stat_85,stat_86,stat_87,stat_88,stat_89,stat_90,stat_91,stat_92,stat_93,stat_94,stat_95,id,sii
0,50458.0,50458.0,50458.0,50458.0,50458.0,50458.0,50458.0,50458.0,50458.0,50458.0,50458.0,50458.0,-0.054638,-0.163923,-0.114302,0.045252,-7.805897,0.0,46.009533,4027.514893,54154750000000.0,4.43886,2.0,30.202068,0.633126,0.513286,0.500372,0.132576,34.917873,0.0,205.862213,108.451317,18769760000000.0,1.825557,0.0,11.773107,-1.812031,-2.63138,-1.798073,0.0,-89.987045,0.0,0.0,3829.0,0.0,1.0,2.0,15.0,-0.70166,-0.619076,-0.536432,0.007953,-32.948602,0.0,2.520257,3958.0,43251250000000.0,3.0,2.0,17.0,0.015846,-0.14181,-0.104193,0.019257,-6.358004,0.0,8.230733,4029.0,56305000000000.0,5.0,2.0,28.0,0.437897,0.148919,0.22377,0.036048,13.09575,0.0,24.75,4146.0,69780000000000.0,6.0,2.0,38.0,1.850391,3.580182,1.738203,5.314874,89.422226,0.0,2626.199951,4187.0,86395000000000.0,7.0,2.0,57.0,0745c390,1.0
1,340584.0,340584.0,340584.0,340584.0,340584.0,340584.0,340584.0,340584.0,340584.0,340584.0,340584.0,340584.0,0.113277,0.093139,-0.106038,0.02896,-6.065619,0.046508,56.437958,3829.466064,43311490000000.0,3.840885,2.0,232.909103,0.507897,0.541129,0.603787,0.096825,44.034721,0.208482,206.625092,167.600983,25091360000000.0,1.957999,0.0,5.701968,-1.807955,-2.887664,-1.004992,0.0,-89.654587,0.0,0.0,3098.166748,0.0,1.0,2.0,223.0,-0.231743,-0.2576,-0.595426,0.000367,-37.326844,0.0,4.0,3724.0,21285000000000.0,2.0,2.0,228.0,0.094074,0.068143,-0.2285,0.005257,-13.454103,0.0,10.05048,3812.0,43605000000000.0,4.0,2.0,233.0,0.517859,0.542323,0.312333,0.020598,18.462269,0.0,27.490936,3958.0,65110000000000.0,5.0,2.0,238.0,1.928769,3.234613,2.475326,3.966906,89.08033,1.0,2628.199951,4146.0,86395000000000.0,7.0,2.0,243.0,eaab7a96,0.0
2,40003.0,40003.0,40003.0,40003.0,40003.0,40003.0,40003.0,40003.0,40003.0,40003.0,40003.0,40003.0,-0.499738,0.046381,-0.181152,0.056544,-11.934993,0.0,77.30513,4106.425781,44816770000000.0,3.148264,3.0,100.144516,0.454021,0.510668,0.412588,0.140594,27.367514,0.0,274.848145,50.734318,20381560000000.0,1.169176,0.0,5.653936,-1.903281,-3.150104,-1.020313,0.0,-89.540176,0.0,0.0,3853.0,45000000000.0,1.0,3.0,97.0,-0.873151,-0.255299,-0.485521,0.005643,-30.154542,0.0,2.918126,4089.625,28885000000000.0,3.0,3.0,98.0,-0.644505,0.088542,-0.191693,0.018467,-11.570901,0.0,7.863636,4111.0,47270000000000.0,3.0,3.0,99.0,-0.242422,0.381953,0.088555,0.048282,5.009753,0.0,21.022933,4140.0,60945000000000.0,4.0,3.0,100.0,1.02151,1.016589,1.746797,5.066334,86.987267,0.0,2618.199951,4183.0,86365000000000.0,7.0,3.0,134.0,8ec2cc63,0.0
3,223915.0,223915.0,223915.0,223915.0,223915.0,223915.0,223915.0,223915.0,223915.0,223915.0,223915.0,223915.0,0.00743,0.007583,-0.19651,0.053544,-12.847143,0.0,9.369678,3958.604492,48366420000000.0,4.273992,2.303057,60.025017,0.5861,0.542189,0.474437,0.103401,32.552841,0.0,54.104408,122.706802,18687730000000.0,2.023705,1.487018,7.396456,-1.684624,-2.405738,-1.023798,0.0,-89.968369,0.0,0.0,3468.0,0.0,1.0,1.0,48.0,-0.530198,-0.412805,-0.556091,0.009947,-34.965618,0.0,0.893617,3841.0,35260000000000.0,3.0,1.0,53.0,0.022344,0.009674,-0.245181,0.027653,-15.000056,0.0,2.340206,3947.0,48810000000000.0,4.0,1.0,60.0,0.536801,0.443383,0.084469,0.057278,4.816339,0.0,6.2,4064.0,63300000000000.0,6.0,4.0,67.0,5.908,2.083693,1.269051,6.134459,89.976074,0.0,2502.0,6000.0,86395000000000.0,7.0,4.0,72.0,b2987a65,0.0
4,15420.0,15420.0,15420.0,15420.0,15420.0,15420.0,15420.0,15420.0,15420.0,15420.0,15420.0,15420.0,0.086653,-0.115162,-0.138969,0.040399,-11.009835,0.0,5.049157,3992.347656,58338950000000.0,4.541829,4.0,46.192024,0.509845,0.494897,0.639449,0.090201,47.933723,0.0,15.590773,126.12159,21462060000000.0,2.081796,0.0,18.615358,-1.675859,-1.071042,-1.012266,0.0,-89.770241,0.0,0.0,3815.083252,35000000000.0,1.0,4.0,20.0,-0.224805,-0.444297,-0.685736,0.005364,-46.348264,0.0,1.438378,3837.333252,51613750000000.0,3.0,4.0,32.0,0.053034,-0.087422,-0.22543,0.024135,-13.665493,0.0,2.897436,4000.0,64270000000000.0,4.0,4.0,42.0,0.544297,0.153125,0.347474,0.04369,20.726226,0.0,4.942201,4087.0,73936250000000.0,7.0,4.0,69.0,3.231563,1.03362,1.071875,2.774382,89.300034,0.0,1046.800049,4199.0,86015000000000.0,7.0,4.0,76.0,7b8842c3,2.0


In [10]:
time_series_df_with_target.shape

(996, 98)

In [11]:
from torch.utils.data import Dataset,DataLoader
from sklearn.preprocessing import StandardScaler

class CustomDataset(Dataset):
    
    def __init__(self, dataframe):
        # Apply StandardScaler to the input features
        self.scaler = StandardScaler()
        if 'sii' in dataframe.columns:
            self.train = True
            features = dataframe.drop(['id', 'sii'], axis=1)  # Drop ID and target column

            self.targets = dataframe['sii'].values  # Keep target values (sii)
        else:
            self.train = False
            features = dataframe.drop(['id'],axis=1)
            
        self.scaled_data = self.scaler.fit_transform(features)  # Scale features
            

    def __len__(self):
        return len(self.scaled_data)
    
    def __getitem__(self, idx):
        # Return the scaled input features and target value
        if self.train:
            return torch.tensor(self.scaled_data[idx], dtype=torch.float32), torch.tensor(self.targets[idx], dtype=torch.long)  # Ensure targets are long for classification
        else:
            return torch.tensor(self.scaled_data[idx], dtype=torch.float32)
        
train_dataset = CustomDataset(time_series_df_with_target)
train_dataloader = DataLoader(train_dataset,batch_size=32,shuffle=True)
test_dataset = CustomDataset(time_series_df_without_target)
test_dataloader = DataLoader(test_dataset,batch_size=1,shuffle=True)

In [12]:
class LSTMEncoder(nn.Module):
    
    def __init__(self, input_size, hidden_size, latent_dim, num_classes, num_layers=1):
        super(LSTMEncoder, self).__init__()
        # LSTM layer
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        # Fully connected layers for latent dimension and classification
        self.fc_latent = nn.Linear(hidden_size, latent_dim)  # Latent space
        self.fc_class = nn.Linear(latent_dim, num_classes)   # Classifier layer for output
        
    def forward(self, x): 
        _, (hn, _) = self.lstm(x)  # Get hidden state from LSTM (hn)
        latent = self.fc_latent(hn[-1])  # Compress the last hidden state to latent space
        out = self.fc_class(latent)  # Classify using the latent space
        return latent, out

In [13]:
# Example parameters
input_size = time_series_df_with_target.shape[1] - 2  # Number of features (excluding the ID column)
hidden_size = 64  # Number of hidden units in LSTM
latent_dim = 30  # Number of dimensions to encode into
num_classes = len(time_series_df_with_target['sii'].unique()) 
num_layers = 2  # LSTM layers


model = LSTMEncoder(input_size, hidden_size, latent_dim, num_classes)

# Define a loss function and optimizer
criterion = nn.CrossEntropyLoss()  # For multi-class classification
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


In [14]:
# Training loop
epochs = 40  # Number of epochs to train

for epoch in range(epochs):
    total_loss = 0
    for batch_data, batch_targets in train_dataloader:
        optimizer.zero_grad()
        # Forward pass: latent representation and predictions
        latent, predictions = model(batch_data.unsqueeze(1))
        # Compute the loss
        loss = criterion(predictions, batch_targets)
        total_loss+=loss.item()
        loss.backward()  # Backpropagate the error
        optimizer.step()  # Update the weights

    print(f"Epoch [{epoch+1}], Loss: {total_loss/32}")

Epoch [1], Loss: 1.2404455803334713
Epoch [2], Loss: 1.035022884607315
Epoch [3], Loss: 0.9608190562576056
Epoch [4], Loss: 0.9374386016279459
Epoch [5], Loss: 0.909640483558178
Epoch [6], Loss: 0.8784117512404919
Epoch [7], Loss: 0.8636265266686678
Epoch [8], Loss: 0.8390360716730356
Epoch [9], Loss: 0.8201069105416536
Epoch [10], Loss: 0.7985943667590618
Epoch [11], Loss: 0.7723584808409214
Epoch [12], Loss: 0.7490436993539333
Epoch [13], Loss: 0.7086466662585735
Epoch [14], Loss: 0.6841372419148684
Epoch [15], Loss: 0.6511627063155174
Epoch [16], Loss: 0.634468924254179
Epoch [17], Loss: 0.596195443533361
Epoch [18], Loss: 0.5766428112983704
Epoch [19], Loss: 0.5415463871322572
Epoch [20], Loss: 0.528366825543344
Epoch [21], Loss: 0.5323669826611876
Epoch [22], Loss: 0.4941811729222536
Epoch [23], Loss: 0.45144467521458864
Epoch [24], Loss: 0.43001858051866293
Epoch [25], Loss: 0.40831130370497704
Epoch [26], Loss: 0.3851572214625776
Epoch [27], Loss: 0.36273577623069286
Epoch [28],

In [15]:
import pandas as pd

final_list = []

# After training, you can use the encoder to extract both the latent dimensions and class predictions
with torch.no_grad():
    for batch_data, batch_targets in train_dataloader:
        batch_data = batch_data.unsqueeze(1)  # Add sequence length dim if needed
        latent, predictions = model(batch_data)
        final_list.append(latent)

# Step 1: Concatenate all the latent representations into a single tensor
# Assuming each 'latent' is of shape (batch_size, latent_dim)
all_latents = torch.cat(final_list, dim=0)  # Concatenate along the first dimension (batch dimension)

# Step 2: Convert the concatenated tensor to a NumPy array
latent_array = all_latents.numpy()  # Convert to NumPy array

# Step 3: Create a DataFrame from the NumPy array
num_latent_dims = latent_array.shape[1]  # Get the number of latent dimensions
train_latent = pd.DataFrame(latent_array, columns=[f'enc_{i + 1}' for i in range(num_latent_dims)])


In [17]:
num_latent_dims

30

In [16]:
train_latent.head()

Unnamed: 0,enc_1,enc_2,enc_3,enc_4,enc_5,enc_6,enc_7,enc_8,enc_9,enc_10,enc_11,enc_12,enc_13,enc_14,enc_15,enc_16,enc_17,enc_18,enc_19,enc_20,enc_21,enc_22,enc_23,enc_24,enc_25,enc_26,enc_27,enc_28,enc_29,enc_30
0,1.384762,0.081567,0.740432,0.829361,-1.525097,-0.014673,1.912793,1.337686,1.398968,-1.206113,-1.328526,0.751323,1.472244,0.411321,0.80827,1.704962,1.089582,-0.881723,-1.222761,-1.608002,-1.650093,-0.590714,0.013961,0.971984,-1.505845,-0.187743,-0.895593,0.99433,-0.772713,-1.647657
1,-0.05748,1.311167,-0.836831,-0.521505,-0.689205,-1.514201,-0.095523,0.801894,0.902057,-0.61335,0.153963,1.188354,0.476232,-0.811944,0.581134,0.326969,0.956164,0.788822,-0.658555,-0.124069,-0.700108,-0.80497,1.09445,0.695873,-0.586935,-0.792594,-0.354745,-0.011389,-0.844777,-0.540772
2,-1.44204,0.530479,-0.712663,0.974248,0.817819,-0.636752,-0.819578,-0.117547,-0.063284,0.861366,1.05374,0.040443,-0.128282,-1.227911,-0.79271,-0.749145,-0.259506,0.45814,1.041381,0.329298,0.123221,0.180655,0.304033,-1.010097,0.361208,-1.152425,1.507697,0.057937,-0.008707,0.134124
3,-1.601115,1.618527,-0.748296,1.473149,-0.20864,-1.553808,-0.00301,1.076604,0.64095,-0.053615,0.28418,1.312809,0.912976,-1.963471,-0.481466,-0.291403,0.594934,0.18974,0.201411,-0.718823,-0.537076,-0.897316,0.95596,-0.36636,-0.832996,-2.919144,1.403833,0.61279,-1.311536,-1.082387
4,-1.331578,1.242839,-0.703403,1.143044,-0.006306,-1.140545,0.095361,1.004541,0.336048,0.148075,0.451261,1.048767,0.61306,-1.53002,-0.493065,-0.203204,0.635847,0.270278,0.154042,-0.44988,-0.550983,-0.575723,0.897552,-0.200369,-0.48423,-2.11012,1.086367,0.218793,-1.090903,-0.631838


In [18]:
train_latent['id'] = time_series_df_with_target['id']

In [19]:
train_latent.head()

Unnamed: 0,enc_1,enc_2,enc_3,enc_4,enc_5,enc_6,enc_7,enc_8,enc_9,enc_10,enc_11,enc_12,enc_13,enc_14,enc_15,enc_16,enc_17,enc_18,enc_19,enc_20,enc_21,enc_22,enc_23,enc_24,enc_25,enc_26,enc_27,enc_28,enc_29,enc_30,id
0,1.384762,0.081567,0.740432,0.829361,-1.525097,-0.014673,1.912793,1.337686,1.398968,-1.206113,-1.328526,0.751323,1.472244,0.411321,0.80827,1.704962,1.089582,-0.881723,-1.222761,-1.608002,-1.650093,-0.590714,0.013961,0.971984,-1.505845,-0.187743,-0.895593,0.99433,-0.772713,-1.647657,0745c390
1,-0.05748,1.311167,-0.836831,-0.521505,-0.689205,-1.514201,-0.095523,0.801894,0.902057,-0.61335,0.153963,1.188354,0.476232,-0.811944,0.581134,0.326969,0.956164,0.788822,-0.658555,-0.124069,-0.700108,-0.80497,1.09445,0.695873,-0.586935,-0.792594,-0.354745,-0.011389,-0.844777,-0.540772,eaab7a96
2,-1.44204,0.530479,-0.712663,0.974248,0.817819,-0.636752,-0.819578,-0.117547,-0.063284,0.861366,1.05374,0.040443,-0.128282,-1.227911,-0.79271,-0.749145,-0.259506,0.45814,1.041381,0.329298,0.123221,0.180655,0.304033,-1.010097,0.361208,-1.152425,1.507697,0.057937,-0.008707,0.134124,8ec2cc63
3,-1.601115,1.618527,-0.748296,1.473149,-0.20864,-1.553808,-0.00301,1.076604,0.64095,-0.053615,0.28418,1.312809,0.912976,-1.963471,-0.481466,-0.291403,0.594934,0.18974,0.201411,-0.718823,-0.537076,-0.897316,0.95596,-0.36636,-0.832996,-2.919144,1.403833,0.61279,-1.311536,-1.082387,b2987a65
4,-1.331578,1.242839,-0.703403,1.143044,-0.006306,-1.140545,0.095361,1.004541,0.336048,0.148075,0.451261,1.048767,0.61306,-1.53002,-0.493065,-0.203204,0.635847,0.270278,0.154042,-0.44988,-0.550983,-0.575723,0.897552,-0.200369,-0.48423,-2.11012,1.086367,0.218793,-1.090903,-0.631838,7b8842c3


In [20]:
train.head()

Unnamed: 0,id,Basic_Demos-Enroll_Season,Basic_Demos-Age,Basic_Demos-Sex,CGAS-Season,CGAS-CGAS_Score,Physical-Season,Physical-BMI,Physical-Height,Physical-Weight,Physical-Waist_Circumference,Physical-Diastolic_BP,Physical-HeartRate,Physical-Systolic_BP,Fitness_Endurance-Season,Fitness_Endurance-Max_Stage,Fitness_Endurance-Time_Mins,Fitness_Endurance-Time_Sec,FGC-Season,FGC-FGC_CU,FGC-FGC_CU_Zone,FGC-FGC_GSND,FGC-FGC_GSND_Zone,FGC-FGC_GSD,FGC-FGC_GSD_Zone,FGC-FGC_PU,FGC-FGC_PU_Zone,FGC-FGC_SRL,FGC-FGC_SRL_Zone,FGC-FGC_SRR,FGC-FGC_SRR_Zone,FGC-FGC_TL,FGC-FGC_TL_Zone,BIA-Season,BIA-BIA_Activity_Level_num,BIA-BIA_BMC,BIA-BIA_BMI,BIA-BIA_BMR,BIA-BIA_DEE,BIA-BIA_ECW,BIA-BIA_FFM,BIA-BIA_FFMI,BIA-BIA_FMI,BIA-BIA_Fat,BIA-BIA_Frame_num,BIA-BIA_ICW,BIA-BIA_LDM,BIA-BIA_LST,BIA-BIA_SMM,BIA-BIA_TBW,PAQ_A-Season,PAQ_A-PAQ_A_Total,PAQ_C-Season,PAQ_C-PAQ_C_Total,PCIAT-Season,PCIAT-PCIAT_01,PCIAT-PCIAT_02,PCIAT-PCIAT_03,PCIAT-PCIAT_04,PCIAT-PCIAT_05,PCIAT-PCIAT_06,PCIAT-PCIAT_07,PCIAT-PCIAT_08,PCIAT-PCIAT_09,PCIAT-PCIAT_10,PCIAT-PCIAT_11,PCIAT-PCIAT_12,PCIAT-PCIAT_13,PCIAT-PCIAT_14,PCIAT-PCIAT_15,PCIAT-PCIAT_16,PCIAT-PCIAT_17,PCIAT-PCIAT_18,PCIAT-PCIAT_19,PCIAT-PCIAT_20,PCIAT-PCIAT_Total,SDS-Season,SDS-SDS_Total_Raw,SDS-SDS_Total_T,PreInt_EduHx-Season,PreInt_EduHx-computerinternet_hoursday,sii
0,00008ff9,Fall,5,0,Winter,51.0,Fall,16.877316,46.0,50.8,,,,,,,,,Fall,0.0,0.0,,,,,0.0,0.0,7.0,0.0,6.0,0.0,6.0,1.0,Fall,2.0,2.66855,16.8792,932.498,1492.0,8.25598,41.5862,13.8177,3.06143,9.21377,1.0,24.4349,8.89536,38.9177,19.5413,32.6909,,,,,Fall,5.0,4.0,4.0,0.0,4.0,0.0,0.0,4.0,0.0,0.0,4.0,0.0,4.0,4.0,4.0,4.0,4.0,4.0,2.0,4.0,55.0,,,,Fall,3.0,2.0
1,000fd460,Summer,9,0,,,Fall,14.03559,48.0,46.0,22.0,75.0,70.0,122.0,,,,,Fall,3.0,0.0,,,,,5.0,0.0,11.0,1.0,11.0,1.0,3.0,0.0,Winter,2.0,2.57949,14.0371,936.656,1498.65,6.01993,42.0291,12.8254,1.21172,3.97085,1.0,21.0352,14.974,39.4497,15.4107,27.0552,,,Fall,2.34,Fall,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Fall,46.0,64.0,Summer,0.0,0.0
2,00105258,Summer,10,1,Fall,71.0,Fall,16.648696,56.5,75.6,,65.0,94.0,117.0,Fall,5.0,7.0,33.0,Fall,20.0,1.0,10.2,1.0,14.7,2.0,7.0,1.0,10.0,1.0,10.0,1.0,5.0,0.0,,,,,,,,,,,,,,,,,,,,Summer,2.17,Fall,5.0,2.0,2.0,1.0,2.0,1.0,1.0,2.0,1.0,1.0,1.0,0.0,1.0,1.0,1.0,0.0,2.0,2.0,1.0,1.0,28.0,Fall,38.0,54.0,Summer,2.0,0.0
3,00115b9f,Winter,9,0,Fall,71.0,Summer,18.292347,56.0,81.6,,60.0,97.0,117.0,Summer,6.0,9.0,37.0,Summer,18.0,1.0,,,,,5.0,0.0,7.0,0.0,7.0,0.0,7.0,1.0,Summer,3.0,3.84191,18.2943,1131.43,1923.44,15.5925,62.7757,14.074,4.22033,18.8243,2.0,30.4041,16.779,58.9338,26.4798,45.9966,,,Winter,2.451,Summer,4.0,2.0,4.0,0.0,5.0,1.0,0.0,3.0,2.0,2.0,3.0,0.0,3.0,0.0,0.0,3.0,4.0,3.0,4.0,1.0,44.0,Summer,31.0,45.0,Winter,0.0,1.0
4,0016bb22,Spring,18,1,Summer,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,Summer,1.04,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,


In [21]:
train = pd.merge(train, train_latent, how="left", on='id')

In [22]:
train.head()

Unnamed: 0,id,Basic_Demos-Enroll_Season,Basic_Demos-Age,Basic_Demos-Sex,CGAS-Season,CGAS-CGAS_Score,Physical-Season,Physical-BMI,Physical-Height,Physical-Weight,Physical-Waist_Circumference,Physical-Diastolic_BP,Physical-HeartRate,Physical-Systolic_BP,Fitness_Endurance-Season,Fitness_Endurance-Max_Stage,Fitness_Endurance-Time_Mins,Fitness_Endurance-Time_Sec,FGC-Season,FGC-FGC_CU,FGC-FGC_CU_Zone,FGC-FGC_GSND,FGC-FGC_GSND_Zone,FGC-FGC_GSD,FGC-FGC_GSD_Zone,FGC-FGC_PU,FGC-FGC_PU_Zone,FGC-FGC_SRL,FGC-FGC_SRL_Zone,FGC-FGC_SRR,FGC-FGC_SRR_Zone,FGC-FGC_TL,FGC-FGC_TL_Zone,BIA-Season,BIA-BIA_Activity_Level_num,BIA-BIA_BMC,BIA-BIA_BMI,BIA-BIA_BMR,BIA-BIA_DEE,BIA-BIA_ECW,BIA-BIA_FFM,BIA-BIA_FFMI,BIA-BIA_FMI,BIA-BIA_Fat,BIA-BIA_Frame_num,BIA-BIA_ICW,BIA-BIA_LDM,BIA-BIA_LST,BIA-BIA_SMM,BIA-BIA_TBW,PAQ_A-Season,PAQ_A-PAQ_A_Total,PAQ_C-Season,PAQ_C-PAQ_C_Total,PCIAT-Season,PCIAT-PCIAT_01,PCIAT-PCIAT_02,PCIAT-PCIAT_03,PCIAT-PCIAT_04,PCIAT-PCIAT_05,PCIAT-PCIAT_06,PCIAT-PCIAT_07,PCIAT-PCIAT_08,PCIAT-PCIAT_09,PCIAT-PCIAT_10,PCIAT-PCIAT_11,PCIAT-PCIAT_12,PCIAT-PCIAT_13,PCIAT-PCIAT_14,PCIAT-PCIAT_15,PCIAT-PCIAT_16,PCIAT-PCIAT_17,PCIAT-PCIAT_18,PCIAT-PCIAT_19,PCIAT-PCIAT_20,PCIAT-PCIAT_Total,SDS-Season,SDS-SDS_Total_Raw,SDS-SDS_Total_T,PreInt_EduHx-Season,PreInt_EduHx-computerinternet_hoursday,sii,enc_1,enc_2,enc_3,enc_4,enc_5,enc_6,enc_7,enc_8,enc_9,enc_10,enc_11,enc_12,enc_13,enc_14,enc_15,enc_16,enc_17,enc_18,enc_19,enc_20,enc_21,enc_22,enc_23,enc_24,enc_25,enc_26,enc_27,enc_28,enc_29,enc_30
0,00008ff9,Fall,5,0,Winter,51.0,Fall,16.877316,46.0,50.8,,,,,,,,,Fall,0.0,0.0,,,,,0.0,0.0,7.0,0.0,6.0,0.0,6.0,1.0,Fall,2.0,2.66855,16.8792,932.498,1492.0,8.25598,41.5862,13.8177,3.06143,9.21377,1.0,24.4349,8.89536,38.9177,19.5413,32.6909,,,,,Fall,5.0,4.0,4.0,0.0,4.0,0.0,0.0,4.0,0.0,0.0,4.0,0.0,4.0,4.0,4.0,4.0,4.0,4.0,2.0,4.0,55.0,,,,Fall,3.0,2.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
1,000fd460,Summer,9,0,,,Fall,14.03559,48.0,46.0,22.0,75.0,70.0,122.0,,,,,Fall,3.0,0.0,,,,,5.0,0.0,11.0,1.0,11.0,1.0,3.0,0.0,Winter,2.0,2.57949,14.0371,936.656,1498.65,6.01993,42.0291,12.8254,1.21172,3.97085,1.0,21.0352,14.974,39.4497,15.4107,27.0552,,,Fall,2.34,Fall,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Fall,46.0,64.0,Summer,0.0,0.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
2,00105258,Summer,10,1,Fall,71.0,Fall,16.648696,56.5,75.6,,65.0,94.0,117.0,Fall,5.0,7.0,33.0,Fall,20.0,1.0,10.2,1.0,14.7,2.0,7.0,1.0,10.0,1.0,10.0,1.0,5.0,0.0,,,,,,,,,,,,,,,,,,,,Summer,2.17,Fall,5.0,2.0,2.0,1.0,2.0,1.0,1.0,2.0,1.0,1.0,1.0,0.0,1.0,1.0,1.0,0.0,2.0,2.0,1.0,1.0,28.0,Fall,38.0,54.0,Summer,2.0,0.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
3,00115b9f,Winter,9,0,Fall,71.0,Summer,18.292347,56.0,81.6,,60.0,97.0,117.0,Summer,6.0,9.0,37.0,Summer,18.0,1.0,,,,,5.0,0.0,7.0,0.0,7.0,0.0,7.0,1.0,Summer,3.0,3.84191,18.2943,1131.43,1923.44,15.5925,62.7757,14.074,4.22033,18.8243,2.0,30.4041,16.779,58.9338,26.4798,45.9966,,,Winter,2.451,Summer,4.0,2.0,4.0,0.0,5.0,1.0,0.0,3.0,2.0,2.0,3.0,0.0,3.0,0.0,0.0,3.0,4.0,3.0,4.0,1.0,44.0,Summer,31.0,45.0,Winter,0.0,1.0,-1.33054,0.509823,-0.574174,0.882908,0.222762,-0.632775,-0.285588,0.342157,0.169526,0.343326,0.792544,0.529145,0.204933,-1.037695,-0.527566,-0.449814,0.258553,-0.051728,0.654262,0.050083,-0.236651,-0.316079,0.043747,-0.540096,-0.246509,-1.398425,1.055014,0.30213,-0.574991,-0.321656
4,0016bb22,Spring,18,1,Summer,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,Summer,1.04,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,


In [23]:
from sklearn.impute import SimpleImputer, KNNImputer

imputer = KNNImputer(n_neighbors=5)
numeric_cols = train.select_dtypes(include=['float64', 'int64']).columns
imputed_data = imputer.fit_transform(train[numeric_cols])
train_imputed = pd.DataFrame(imputed_data, columns=numeric_cols)
train_imputed['sii'] = train_imputed['sii'].round().astype(int)
for col in train.columns:
    if col not in numeric_cols:
        train_imputed[col] = train[col]
        
train = train_imputed


In [24]:
train

Unnamed: 0,Basic_Demos-Age,Basic_Demos-Sex,CGAS-CGAS_Score,Physical-BMI,Physical-Height,Physical-Weight,Physical-Waist_Circumference,Physical-Diastolic_BP,Physical-HeartRate,Physical-Systolic_BP,Fitness_Endurance-Max_Stage,Fitness_Endurance-Time_Mins,Fitness_Endurance-Time_Sec,FGC-FGC_CU,FGC-FGC_CU_Zone,FGC-FGC_GSND,FGC-FGC_GSND_Zone,FGC-FGC_GSD,FGC-FGC_GSD_Zone,FGC-FGC_PU,FGC-FGC_PU_Zone,FGC-FGC_SRL,FGC-FGC_SRL_Zone,FGC-FGC_SRR,FGC-FGC_SRR_Zone,FGC-FGC_TL,FGC-FGC_TL_Zone,BIA-BIA_Activity_Level_num,BIA-BIA_BMC,BIA-BIA_BMI,BIA-BIA_BMR,BIA-BIA_DEE,BIA-BIA_ECW,BIA-BIA_FFM,BIA-BIA_FFMI,BIA-BIA_FMI,BIA-BIA_Fat,BIA-BIA_Frame_num,BIA-BIA_ICW,BIA-BIA_LDM,BIA-BIA_LST,BIA-BIA_SMM,BIA-BIA_TBW,PAQ_A-PAQ_A_Total,PAQ_C-PAQ_C_Total,PCIAT-PCIAT_01,PCIAT-PCIAT_02,PCIAT-PCIAT_03,PCIAT-PCIAT_04,PCIAT-PCIAT_05,PCIAT-PCIAT_06,PCIAT-PCIAT_07,PCIAT-PCIAT_08,PCIAT-PCIAT_09,PCIAT-PCIAT_10,PCIAT-PCIAT_11,PCIAT-PCIAT_12,PCIAT-PCIAT_13,PCIAT-PCIAT_14,PCIAT-PCIAT_15,PCIAT-PCIAT_16,PCIAT-PCIAT_17,PCIAT-PCIAT_18,PCIAT-PCIAT_19,PCIAT-PCIAT_20,PCIAT-PCIAT_Total,SDS-SDS_Total_Raw,SDS-SDS_Total_T,PreInt_EduHx-computerinternet_hoursday,sii,id,Basic_Demos-Enroll_Season,CGAS-Season,Physical-Season,Fitness_Endurance-Season,FGC-Season,BIA-Season,PAQ_A-Season,PAQ_C-Season,PCIAT-Season,SDS-Season,PreInt_EduHx-Season,enc_1,enc_2,enc_3,enc_4,enc_5,enc_6,enc_7,enc_8,enc_9,enc_10,enc_11,enc_12,enc_13,enc_14,enc_15,enc_16,enc_17,enc_18,enc_19,enc_20,enc_21,enc_22,enc_23,enc_24,enc_25,enc_26,enc_27,enc_28,enc_29,enc_30
0,5.0,0.0,51.0,16.877316,46.00,50.8,23.0,61.2,86.4,110.6,4.0,5.8,27.0,0.0,0.0,17.56,1.8,16.18,1.4,0.0,0.0,7.0,0.0,6.0,0.0,6.0,1.0,2.0,2.668550,16.87920,932.4980,1492.000,8.255980,41.58620,13.81770,3.061430,9.213770,1.0,24.43490,8.895360,38.91770,19.54130,32.69090,1.9120,2.2220,5.0,4.0,4.0,0.0,4.0,0.0,0.0,4.0,0.0,0.0,4.0,0.0,4.0,4.0,4.0,4.0,4.0,4.0,2.0,4.0,55.0,48.4,62.2,3.0,2,00008ff9,Fall,Winter,Fall,,Fall,Fall,,,Fall,,Fall,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
1,9.0,0.0,70.0,14.035590,48.00,46.0,22.0,75.0,70.0,122.0,4.6,6.6,24.2,3.0,0.0,16.04,1.6,15.50,1.6,5.0,0.0,11.0,1.0,11.0,1.0,3.0,0.0,2.0,2.579490,14.03710,936.6560,1498.650,6.019930,42.02910,12.82540,1.211720,3.970850,1.0,21.03520,14.974000,39.44970,15.41070,27.05520,2.6260,2.3400,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,46.0,64.0,0.0,0,000fd460,Summer,,Fall,,Fall,Winter,,Fall,Fall,Fall,Summer,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
2,10.0,1.0,71.0,16.648696,56.50,75.6,24.8,65.0,94.0,117.0,5.0,7.0,33.0,20.0,1.0,10.20,1.0,14.70,2.0,7.0,1.0,10.0,1.0,10.0,1.0,5.0,0.0,2.6,3.431454,19.10500,1106.4030,1889.264,17.199762,60.10940,14.83936,4.265620,17.650582,2.6,28.81348,14.096188,56.67794,27.61536,46.01322,2.0938,2.1700,5.0,2.0,2.0,1.0,2.0,1.0,1.0,2.0,1.0,1.0,1.0,0.0,1.0,1.0,1.0,0.0,2.0,2.0,1.0,1.0,28.0,38.0,54.0,2.0,0,00105258,Summer,Fall,Fall,Fall,Fall,,,Summer,Fall,Fall,Summer,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
3,9.0,0.0,71.0,18.292347,56.00,81.6,25.4,60.0,97.0,117.0,6.0,9.0,37.0,18.0,1.0,14.50,1.6,16.92,2.2,5.0,0.0,7.0,0.0,7.0,0.0,7.0,1.0,3.0,3.841910,18.29430,1131.4300,1923.440,15.592500,62.77570,14.07400,4.220330,18.824300,2.0,30.40410,16.779000,58.93380,26.47980,45.99660,1.7980,2.4510,4.0,2.0,4.0,0.0,5.0,1.0,0.0,3.0,2.0,2.0,3.0,0.0,3.0,0.0,0.0,3.0,4.0,3.0,4.0,1.0,44.0,31.0,45.0,0.0,1,00115b9f,Winter,Fall,Summer,Summer,Summer,Summer,,Winter,Summer,Summer,Winter,-1.330540,0.509823,-0.574174,0.882908,0.222762,-0.632775,-0.285588,0.342157,0.169526,0.343326,0.792544,0.529145,0.204933,-1.037695,-0.527566,-0.449814,0.258553,-0.051728,0.654262,0.050083,-0.236651,-0.316079,0.043747,-0.540096,-0.246509,-1.398425,1.055014,0.302130,-0.574991,-0.321656
4,18.0,1.0,69.4,26.713639,62.54,123.8,33.6,67.4,79.0,116.8,4.4,8.4,18.8,12.8,0.2,28.48,2.0,28.80,2.0,1.4,0.0,10.1,0.6,9.5,0.6,10.7,0.8,2.4,4.382366,26.06698,1394.9880,2144.724,29.722340,90.84782,16.01834,10.048682,56.672180,2.4,35.37708,25.748480,86.46560,47.54038,65.09940,1.0400,2.0724,4.0,1.8,2.0,2.0,3.2,1.2,2.8,3.6,2.8,2.8,4.2,1.0,2.2,2.0,2.0,2.6,2.0,1.6,1.0,1.4,40.0,42.0,58.8,2.6,1,0016bb22,Spring,Summer,,,,,Summer,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3955,13.0,0.0,60.0,16.362460,59.50,82.4,25.0,71.0,70.0,104.0,4.8,7.2,23.8,16.0,0.0,18.00,1.0,19.90,2.0,10.0,1.0,8.0,1.0,9.0,1.0,12.0,1.0,3.0,4.522770,16.36420,1206.8800,2051.700,19.461100,70.81170,14.06290,2.301380,11.588300,1.0,33.37090,17.979700,66.28890,29.77900,52.83200,2.7338,3.2600,3.0,3.0,3.0,2.0,3.0,2.0,2.0,2.0,2.0,1.0,2.0,0.0,2.0,0.0,1.0,0.0,2.0,1.0,1.0,0.0,32.0,35.0,50.0,1.0,1,ff8a2de4,Fall,Spring,Fall,,Fall,Fall,,Winter,Winter,Winter,Fall,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
3956,10.0,0.0,58.6,18.764678,53.50,76.4,27.0,60.0,78.0,118.0,4.4,6.2,23.6,0.0,0.0,17.20,1.8,16.52,1.4,4.0,0.0,0.0,0.0,0.0,0.0,12.0,1.0,3.0,2.940418,18.27962,1010.4646,1785.136,12.685996,49.89076,14.06412,4.215510,15.749254,2.4,25.20942,11.995308,46.95032,21.39602,37.89540,2.2040,2.3400,1.8,2.2,3.0,0.6,2.2,1.0,0.2,1.0,0.4,1.6,1.2,0.2,0.4,0.4,1.4,0.8,2.0,0.8,0.8,0.4,22.4,38.6,54.8,0.0,0,ffa9794a,Winter,,Spring,,Spring,Spring,,Winter,,,Winter,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
3957,11.0,0.0,68.0,21.441500,60.00,109.8,28.6,79.0,99.0,116.0,4.6,6.2,32.4,15.0,1.0,18.50,2.0,15.80,2.0,0.0,0.0,10.0,1.0,10.0,1.0,14.0,1.0,2.0,4.413050,21.44380,1253.7400,2005.990,20.482500,75.80330,14.80430,6.639520,33.996700,2.0,33.98050,21.340300,71.39030,28.77920,54.46300,2.7338,2.7290,5.0,5.0,3.0,0.0,5.0,1.0,0.0,2.0,0.0,2.0,1.0,0.0,1.0,3.0,0.0,0.0,1.0,1.0,0.0,1.0,31.0,56.0,77.0,0.0,1,ffcd4dbd,Fall,Spring,Winter,,Winter,Winter,,Winter,Winter,Winter,Fall,-1.211039,0.590514,-0.076698,1.430093,-0.168171,-0.544026,0.275189,0.877730,0.470519,0.289310,0.395656,0.732232,0.738506,-1.107202,-0.360576,-0.023297,0.442283,-0.665836,0.724116,-0.547622,-0.613015,-0.391986,-0.133329,-0.524546,-0.914075,-2.213551,0.932166,1.091346,-1.080955,-1.186035
3958,13.0,0.0,70.0,12.235895,70.70,87.0,27.6,59.0,61.0,113.0,3.8,4.6,25.0,19.0,0.6,23.18,2.0,24.90,2.2,3.8,0.4,9.8,0.4,10.3,0.8,11.6,1.0,4.0,6.661680,12.23720,1414.3400,2970.120,26.532300,92.90920,13.06840,-0.831170,-5.909170,2.0,41.37150,25.005400,86.24750,45.43400,67.90380,2.7338,3.3000,2.0,1.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,1.0,2.0,0.0,1.0,1.0,2.0,1.0,1.0,1.0,1.0,1.0,19.0,33.0,47.0,1.0,0,ffed1dd5,Spring,Spring,Winter,,Spring,Summer,,Spring,Spring,Spring,Spring,-0.673213,1.569033,-0.315775,0.728833,-0.865939,-1.166083,0.271233,1.033212,0.544014,-0.681839,-0.150384,0.695773,0.975266,-0.886019,0.258724,0.353051,0.752052,0.004597,-0.473847,-0.405804,-1.431344,-0.526796,0.433676,0.193217,-0.658152,-1.178668,0.344414,0.759541,-1.077679,-0.546399


In [25]:
final_list = []

# After training, you can use the encoder to extract both the latent dimensions and class predictions
with torch.no_grad():
    for batch_data in test_dataloader:
        batch_data = batch_data.unsqueeze(1)  # Add sequence length dim if needed
        latent, predictions = model(batch_data)
        final_list.append(latent)

# Step 1: Concatenate all the latent representations into a single tensor
# Assuming each 'latent' is of shape (batch_size, latent_dim)
all_latents = torch.cat(final_list, dim=0)  # Concatenate along the first dimension (batch dimension)

# Step 2: Convert the concatenated tensor to a NumPy array
latent_array = all_latents.numpy()  # Convert to NumPy array

# Step 3: Create a DataFrame from the NumPy array
num_latent_dims = latent_array.shape[1]  # Get the number of latent dimensions
test_latent = pd.DataFrame(latent_array, columns=[f'enc_{i + 1}' for i in range(num_latent_dims)])

In [26]:
test_latent['id'] = time_series_df_without_target['id']
test = pd.merge(test, test_latent, how="left", on='id')

In [27]:
test.head()

Unnamed: 0,id,Basic_Demos-Enroll_Season,Basic_Demos-Age,Basic_Demos-Sex,CGAS-Season,CGAS-CGAS_Score,Physical-Season,Physical-BMI,Physical-Height,Physical-Weight,Physical-Waist_Circumference,Physical-Diastolic_BP,Physical-HeartRate,Physical-Systolic_BP,Fitness_Endurance-Season,Fitness_Endurance-Max_Stage,Fitness_Endurance-Time_Mins,Fitness_Endurance-Time_Sec,FGC-Season,FGC-FGC_CU,FGC-FGC_CU_Zone,FGC-FGC_GSND,FGC-FGC_GSND_Zone,FGC-FGC_GSD,FGC-FGC_GSD_Zone,FGC-FGC_PU,FGC-FGC_PU_Zone,FGC-FGC_SRL,FGC-FGC_SRL_Zone,FGC-FGC_SRR,FGC-FGC_SRR_Zone,FGC-FGC_TL,FGC-FGC_TL_Zone,BIA-Season,BIA-BIA_Activity_Level_num,BIA-BIA_BMC,BIA-BIA_BMI,BIA-BIA_BMR,BIA-BIA_DEE,BIA-BIA_ECW,BIA-BIA_FFM,BIA-BIA_FFMI,BIA-BIA_FMI,BIA-BIA_Fat,BIA-BIA_Frame_num,BIA-BIA_ICW,BIA-BIA_LDM,BIA-BIA_LST,BIA-BIA_SMM,BIA-BIA_TBW,PAQ_A-Season,PAQ_A-PAQ_A_Total,PAQ_C-Season,PAQ_C-PAQ_C_Total,SDS-Season,SDS-SDS_Total_Raw,SDS-SDS_Total_T,PreInt_EduHx-Season,PreInt_EduHx-computerinternet_hoursday,enc_1,enc_2,enc_3,enc_4,enc_5,enc_6,enc_7,enc_8,enc_9,enc_10,enc_11,enc_12,enc_13,enc_14,enc_15,enc_16,enc_17,enc_18,enc_19,enc_20,enc_21,enc_22,enc_23,enc_24,enc_25,enc_26,enc_27,enc_28,enc_29,enc_30
0,00008ff9,Fall,5,0,Winter,51.0,Fall,16.877316,46.0,50.8,,,,,,,,,Fall,0.0,0.0,,,,,0.0,0.0,7.0,0.0,6.0,0.0,6.0,1.0,Fall,2.0,2.66855,16.8792,932.498,1492.0,8.25598,41.5862,13.8177,3.06143,9.21377,1.0,24.4349,8.89536,38.9177,19.5413,32.6909,,,,,,,,Fall,3.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
1,000fd460,Summer,9,0,,,Fall,14.03559,48.0,46.0,22.0,75.0,70.0,122.0,,,,,Fall,3.0,0.0,,,,,5.0,0.0,11.0,1.0,11.0,1.0,3.0,0.0,Winter,2.0,2.57949,14.0371,936.656,1498.65,6.01993,42.0291,12.8254,1.21172,3.97085,1.0,21.0352,14.974,39.4497,15.4107,27.0552,,,Fall,2.34,Fall,46.0,64.0,Summer,0.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
2,00105258,Summer,10,1,Fall,71.0,Fall,16.648696,56.5,75.6,,65.0,94.0,117.0,Fall,5.0,7.0,33.0,Fall,20.0,1.0,10.2,1.0,14.7,2.0,7.0,1.0,10.0,1.0,10.0,1.0,5.0,0.0,,,,,,,,,,,,,,,,,,,,Summer,2.17,Fall,38.0,54.0,Summer,2.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
3,00115b9f,Winter,9,0,Fall,71.0,Summer,18.292347,56.0,81.6,,60.0,97.0,117.0,Summer,6.0,9.0,37.0,Summer,18.0,1.0,,,,,5.0,0.0,7.0,0.0,7.0,0.0,7.0,1.0,Summer,3.0,3.84191,18.2943,1131.43,1923.44,15.5925,62.7757,14.074,4.22033,18.8243,2.0,30.4041,16.779,58.9338,26.4798,45.9966,,,Winter,2.451,Summer,31.0,45.0,Winter,0.0,-1.51527,1.563666,-0.652432,1.443732,-0.336878,-1.275869,0.148527,1.06132,0.712389,-0.133747,0.293381,1.016995,0.969857,-1.659968,-0.410516,-0.078542,0.931415,-0.104727,0.155349,-0.339625,-0.831246,-0.966991,0.561038,-0.231293,-0.804451,-2.255823,1.08183,0.705595,-1.426094,-0.815867
4,0016bb22,Spring,18,1,Summer,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,Summer,1.04,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,


In [28]:
featuresCols = ['Basic_Demos-Enroll_Season', 'Basic_Demos-Age', 'Basic_Demos-Sex',
                'CGAS-Season', 'CGAS-CGAS_Score', 'Physical-Season', 'Physical-BMI',
                'Physical-Height', 'Physical-Weight', 'Physical-Waist_Circumference',
                'Physical-Diastolic_BP', 'Physical-HeartRate', 'Physical-Systolic_BP',
                'Fitness_Endurance-Season', 'Fitness_Endurance-Max_Stage',
                'Fitness_Endurance-Time_Mins', 'Fitness_Endurance-Time_Sec',
                'FGC-Season', 'FGC-FGC_CU', 'FGC-FGC_CU_Zone', 'FGC-FGC_GSND',
                'FGC-FGC_GSND_Zone', 'FGC-FGC_GSD', 'FGC-FGC_GSD_Zone', 'FGC-FGC_PU',
                'FGC-FGC_PU_Zone', 'FGC-FGC_SRL', 'FGC-FGC_SRL_Zone', 'FGC-FGC_SRR',
                'FGC-FGC_SRR_Zone', 'FGC-FGC_TL', 'FGC-FGC_TL_Zone', 'BIA-Season',
                'BIA-BIA_Activity_Level_num', 'BIA-BIA_BMC', 'BIA-BIA_BMI',
                'BIA-BIA_BMR', 'BIA-BIA_DEE', 'BIA-BIA_ECW', 'BIA-BIA_FFM',
                'BIA-BIA_FFMI', 'BIA-BIA_FMI', 'BIA-BIA_Fat', 'BIA-BIA_Frame_num',
                'BIA-BIA_ICW', 'BIA-BIA_LDM', 'BIA-BIA_LST', 'BIA-BIA_SMM',
                'BIA-BIA_TBW', 'PAQ_A-Season', 'PAQ_A-PAQ_A_Total', 'PAQ_C-Season',
                'PAQ_C-PAQ_C_Total', 'SDS-Season', 'SDS-SDS_Total_Raw',
                'SDS-SDS_Total_T', 'PreInt_EduHx-Season',
                'PreInt_EduHx-computerinternet_hoursday', 'sii']

featuresCols += ['enc_1', 'enc_2', 'enc_3', 'enc_4', 'enc_5', 'enc_6', 'enc_7', 'enc_8',
       'enc_9', 'enc_10', 'enc_11', 'enc_12', 'enc_13', 'enc_14', 'enc_15',
       'enc_16', 'enc_17', 'enc_18', 'enc_19', 'enc_20','enc_21','enc_22','enc_23','enc_24','enc_25','enc_26','enc_27','enc_28','enc_29','enc_30']

train = train[featuresCols]
train = train.dropna(subset='sii')
test  = test.drop('id',axis=1)

cat_c = ['Basic_Demos-Enroll_Season', 'CGAS-Season', 'Physical-Season', 
          'Fitness_Endurance-Season', 'FGC-Season', 'BIA-Season', 
          'PAQ_A-Season', 'PAQ_C-Season', 'SDS-Season', 'PreInt_EduHx-Season']

In [33]:
def update(df):
    global cat_c
    for c in cat_c: 
        df[c] = df[c].fillna('Missing')
        df[c] = df[c].astype('category')
    return df
        
train = update(train)
test = update(test)

def create_mapping(column, dataset):
    unique_values = dataset[column].unique()
    return {value: idx for idx, value in enumerate(unique_values)}

for col in cat_c:
    mapping = create_mapping(col, train)
    mappingTe = create_mapping(col, test)
    
    train[col] = train[col].replace(mapping).astype(int)
    test[col] = test[col].replace(mappingTe).astype(int)

def quadratic_weighted_kappa(y_true, y_pred):
    return cohen_kappa_score(y_true, y_pred, weights='quadratic')

def threshold_Rounder(oof_non_rounded, thresholds):
    return np.where(oof_non_rounded < thresholds[0], 0,
                    np.where(oof_non_rounded < thresholds[1], 1,
                             np.where(oof_non_rounded < thresholds[2], 2, 3)))

def evaluate_predictions(thresholds, y_true, oof_non_rounded):
    rounded_p = threshold_Rounder(oof_non_rounded, thresholds)
    return -quadratic_weighted_kappa(y_true, rounded_p)

In [34]:
def TrainML(model_class, test_data):
    X = train.drop(['sii'], axis=1)
    y = train['sii']

    SKF = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=SEED)
    
    train_S = []
    test_S = []
    
    oof_non_rounded = np.zeros(len(y), dtype=float) 
    oof_rounded = np.zeros(len(y), dtype=int) 
    test_preds = np.zeros((len(test_data), n_splits))

    for fold, (train_idx, test_idx) in enumerate(tqdm(SKF.split(X, y), desc="Training Folds", total=n_splits)):
        X_train, X_val = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_val = y.iloc[train_idx], y.iloc[test_idx]

        model = clone(model_class)
        model.fit(X_train, y_train)

        y_train_pred = model.predict(X_train)
        y_val_pred = model.predict(X_val)

        oof_non_rounded[test_idx] = y_val_pred
        y_val_pred_rounded = y_val_pred.round(0).astype(int)
        oof_rounded[test_idx] = y_val_pred_rounded

        train_kappa = quadratic_weighted_kappa(y_train, y_train_pred.round(0).astype(int))
        val_kappa = quadratic_weighted_kappa(y_val, y_val_pred_rounded)

        train_S.append(train_kappa)
        test_S.append(val_kappa)
        
        test_preds[:, fold] = model.predict(test_data)
        
        print(f"Fold {fold+1} - Train QWK: {train_kappa:.4f}, Validation QWK: {val_kappa:.4f}")
        clear_output(wait=True)

    print(f"Mean Train QWK --> {np.mean(train_S):.4f}")
    print(f"Mean Validation QWK ---> {np.mean(test_S):.4f}")

    KappaOPtimizer = minimize(evaluate_predictions,
                              x0=[0.5, 1.5, 2.5], args=(y, oof_non_rounded), 
                              method='Nelder-Mead')
    assert KappaOPtimizer.success, "Optimization did not converge."
    
    oof_tuned = threshold_Rounder(oof_non_rounded, KappaOPtimizer.x)
    tKappa = quadratic_weighted_kappa(y, oof_tuned)

    print(f"----> || Optimized QWK SCORE :: {Fore.CYAN}{Style.BRIGHT} {tKappa:.3f}{Style.RESET_ALL}")

    tpm = test_preds.mean(axis=1)
    tpTuned = threshold_Rounder(tpm, KappaOPtimizer.x)
    
    submission = pd.DataFrame({
        'id': sample['id'],
        'sii': tpTuned
    })

    return submission

In [35]:
Params = {
    'learning_rate': 0.046,
    'max_depth': 12,
    'num_leaves': 478,
    'min_data_in_leaf': 13,
    'feature_fraction': 0.893,
    'bagging_fraction': 0.784,
    'bagging_freq': 4,
    'lambda_l1': 10,  # Increased from 6.59
    'lambda_l2': 0.01  # Increased from 2.68e-06
}


XGB_Params = {
    'learning_rate': 0.05,
    'max_depth': 6,
    'n_estimators': 200,
    'subsample': 0.8,
    'colsample_bytree': 0.8,
    'reg_alpha': 1,  # Increased from 0.1
    'reg_lambda': 5,  # Increased from 1
    'random_state': SEED,
    'tree_method': 'exact'
}


CatBoost_Params = {
    'learning_rate': 0.05,
    'depth': 6,
    'iterations': 200,
    'random_seed': SEED,
    'verbose': 0,
    'l2_leaf_reg': 10  # Increase this value
}

# Create model instances
Light = LGBMRegressor(**Params, random_state=SEED, verbose=-1, n_estimators=300)
XGB_Model = XGBRegressor(**XGB_Params)
CatBoost_Model = CatBoostRegressor(**CatBoost_Params)

# Combine models using Voting Regressor
voting_model = VotingRegressor(estimators=[
    ('lightgbm', Light),
    ('xgboost', XGB_Model),
    ('catboost', CatBoost_Model)
])

In [36]:
Submission1 = TrainML(voting_model, test)

# Save submission
Submission1.to_csv('submission.csv', index=False)

Training Folds: 100%|██████████| 5/5 [00:27<00:00,  5.48s/it]

Mean Train QWK --> 0.8096
Mean Validation QWK ---> 0.4964
----> || Optimized QWK SCORE :: [36m[1m 0.550[0m



