## Imports and Set Up

In [None]:
import os
import random
import warnings
from concurrent.futures import ThreadPoolExecutor
from sklearn.model_selection import train_test_split
from scipy.optimize import minimize

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from colorama import Fore, Style
from IPython.display import clear_output
from lightgbm import LGBMClassifier, LGBMRegressor
from matplotlib import pyplot as plt
from sklearn.base import clone
from sklearn.ensemble import VotingClassifier, VotingRegressor, StackingClassifier, StackingRegressor
from sklearn.impute import KNNImputer
from sklearn.metrics import (accuracy_score, cohen_kappa_score,
                             confusion_matrix, f1_score, mean_absolute_error,
                             mean_squared_error, precision_score, recall_score,
                             classification_report)
from sklearn.model_selection import (GridSearchCV, KFold, RandomizedSearchCV,
                                     cross_val_score)
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression, Ridge, Lasso
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
from xgboost import XGBClassifier, XGBRegressor
from catboost import CatBoostClassifier, CatBoostRegressor
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

In [295]:
warnings.filterwarnings('ignore')
pd.options.display.max_columns = None

In [296]:
def set_seed(seed_value=2024):
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed(seed_value)
    torch.backends.cudnn.deterministic = True

set_seed(2024)

## Data Processing

### Load in Files

In [297]:
file_path = '/Users/naliniramanathan/projects/ml_course/final_project/kaggle/input/cmi-piu'

In [298]:
TRAIN_CSV = f'{file_path}/train.csv'
TEST_CSV = f'{file_path}/test.csv'
SAMPLE_SUBMISSION_CSV = f'{file_path}/sample_submission.csv'
SERIES_TRAIN_DIR = f'{file_path}/series_train.parquet'
SERIES_TEST_DIR = f'{file_path}/series_test.parquet'


train_df = pd.read_csv(TRAIN_CSV)
test_df = pd.read_csv(TEST_CSV)
sample_submission_df = pd.read_csv(SAMPLE_SUBMISSION_CSV)

# Drop all the PCIAT variables as they are not present in the test data
for col in train_df.columns:
    if 'PCIAT' in col:
        train_df.drop(col, axis=1, inplace=True)

In [299]:
# Function to process individual time series files
def process_time_series(file_name, directory):
    df = pd.read_parquet(os.path.join(directory, file_name, 'part-0.parquet'))
    df = df.drop('step', axis=1)
    stats = df.describe().values.flatten()
    record_id = file_name.split('=')[1]
    return stats, record_id

In [300]:
# Function to load and aggregate time series data
def load_time_series_data(directory):
    file_names = os.listdir(directory)
    stats_list = []
    ids_list = []

    with ThreadPoolExecutor() as executor:
        results = list(tqdm(executor.map(lambda fname: process_time_series(fname, directory), file_names),
                            total=len(file_names)))

    for stats, record_id in results:
        stats_list.append(stats)
        ids_list.append(record_id)

    stats_df = pd.DataFrame(stats_list, columns=[f'stat_{i}' for i in range(len(stats_list[0]))])
    stats_df['id'] = ids_list
    return stats_df

In [301]:
train_series_df = load_time_series_data(SERIES_TRAIN_DIR)
# test_series_df = load_time_series_data(SERIES_TEST_DIR)

100%|██████████| 996/996 [00:26<00:00, 38.17it/s]


In [302]:
train_df, test_df = train_test_split(train_df, test_size=0.2, random_state=2024)
test_series_df = train_series_df[train_series_df.id.isin(test_df.id)]
train_series_df = train_series_df[train_series_df.id.isin(train_df.id)]


In [303]:
train_series_df

Unnamed: 0,stat_0,stat_1,stat_2,stat_3,stat_4,stat_5,stat_6,stat_7,stat_8,stat_9,stat_10,stat_11,stat_12,stat_13,stat_14,stat_15,stat_16,stat_17,stat_18,stat_19,stat_20,stat_21,stat_22,stat_23,stat_24,stat_25,stat_26,stat_27,stat_28,stat_29,stat_30,stat_31,stat_32,stat_33,stat_34,stat_35,stat_36,stat_37,stat_38,stat_39,stat_40,stat_41,stat_42,stat_43,stat_44,stat_45,stat_46,stat_47,stat_48,stat_49,stat_50,stat_51,stat_52,stat_53,stat_54,stat_55,stat_56,stat_57,stat_58,stat_59,stat_60,stat_61,stat_62,stat_63,stat_64,stat_65,stat_66,stat_67,stat_68,stat_69,stat_70,stat_71,stat_72,stat_73,stat_74,stat_75,stat_76,stat_77,stat_78,stat_79,stat_80,stat_81,stat_82,stat_83,stat_84,stat_85,stat_86,stat_87,stat_88,stat_89,stat_90,stat_91,stat_92,stat_93,stat_94,stat_95,id
1,412332.0,412332.0,412332.0,412332.0,412332.0,412332.0,412332.0,412332.0,412332.0,412332.0,412332.0,412332.0,0.047866,0.003234,-0.249981,0.023465,-18.722479,0.216525,68.818016,3841.463379,4.323937e+13,3.809581,2.475539,24.915834,0.523361,0.440953,0.646356,0.052377,49.601418,0.407909,278.388855,165.153732,2.500622e+13,1.971711,0.499402,6.901020,-1.777734,-2.433394,-1.005808,0.0,-89.819664,0.0,0.0,3098.166748,0.0,1.0,2.0,13.0,-0.266012,-0.277724,-0.829161,0.000011,-56.706230,0.0,2.738918,3741.000000,2.147500e+13,2.0,2.0,19.0,0.009822,0.008072,-0.383322,0.006272,-23.246984,0.0,7.405453,3807.000000,4.344500e+13,4.0,2.0,25.0,0.445334,0.261080,0.160221,0.020526,9.357183,0.0,18.088059,3963.333252,6.492000e+13,6.0,3.0,31.0,1.859814,1.518311,1.510279,3.006919,89.322289,1.0,2648.000000,4181.0,8.639500e+13,7.0,3.0,37.0,cefdb7fe
2,384228.0,384228.0,384228.0,384228.0,384228.0,384228.0,384228.0,384228.0,384228.0,384228.0,384228.0,384228.0,-0.088861,0.045154,-0.212270,0.004798,-21.601578,0.611183,9.674905,3838.082031,4.347224e+13,4.096849,1.000000,36.297169,0.300376,0.371397,0.840793,0.027920,67.862083,0.480341,47.099369,145.080429,2.495826e+13,1.986910,0.000000,6.425583,-2.163437,-3.142938,-1.001401,0.0,-89.842148,0.0,0.0,3098.166748,0.0,1.0,1.0,25.0,-0.345387,-0.088445,-0.992394,0.000000,-88.350620,0.0,1.677741,3759.000000,2.183000e+13,2.0,1.0,31.0,0.005507,0.002966,-0.636735,0.000079,-40.234131,1.0,6.103776,3818.000000,4.366000e+13,4.0,1.0,36.0,0.027835,0.067280,0.786521,0.003858,52.158220,1.0,10.394940,3941.000000,6.521000e+13,6.0,1.0,42.0,1.017271,1.381445,1.041023,4.491224,88.801147,1.0,1157.250000,4152.0,8.639500e+13,7.0,1.0,47.0,58391429
3,311959.0,311959.0,311959.0,311959.0,311959.0,311959.0,311959.0,311959.0,311959.0,311959.0,311959.0,311959.0,-0.080044,0.058017,-0.269036,0.045412,-17.783920,0.000000,132.968567,3874.709473,4.857577e+13,3.851570,3.000000,48.147717,0.601244,0.595210,0.375176,0.095587,25.640087,0.000000,434.704041,133.961487,1.791852e+13,1.985380,0.000000,9.506028,-1.962057,-2.844661,-1.021510,0.0,-89.553940,0.0,0.0,3683.000000,0.0,1.0,3.0,31.0,-0.605807,-0.445863,-0.536875,0.005877,-33.453754,0.0,2.168796,3771.000000,3.450500e+13,2.0,3.0,41.0,-0.190000,0.093646,-0.278828,0.022030,-16.665411,0.0,7.046413,3829.000000,4.836500e+13,4.0,3.0,48.0,0.472721,0.594089,-0.049271,0.041659,-3.008855,0.0,29.298994,3971.250000,6.338000e+13,6.0,3.0,56.0,1.148359,3.186745,2.724948,4.054967,89.521629,0.0,2648.500000,4181.0,8.639500e+13,7.0,3.0,67.0,2ca2206f
4,377160.0,377160.0,377160.0,377160.0,377160.0,377160.0,377160.0,377160.0,377160.0,377160.0,377160.0,377160.0,-0.067303,0.187596,-0.381376,0.018893,-30.760855,0.657652,10.622702,3829.777344,4.324427e+13,4.015813,1.000000,19.931402,0.282446,0.528214,0.669038,0.075740,53.370640,0.468158,39.889423,147.425171,2.503205e+13,1.965738,0.000000,6.311698,-3.150714,-4.179972,-1.019038,0.0,-89.815567,0.0,0.0,3098.166748,0.0,1.0,1.0,9.0,-0.169914,-0.029689,-0.996593,0.000286,-80.339439,0.0,4.218323,3747.000000,2.142500e+13,2.0,1.0,14.0,-0.049600,-0.005727,-0.594736,0.009630,-36.482208,1.0,6.641018,3812.000000,4.347750e+13,4.0,1.0,20.0,0.016026,0.712304,-0.000686,0.012452,0.043783,1.0,10.536869,3941.000000,6.497000e+13,6.0,1.0,25.0,2.427422,2.343212,2.094606,5.087605,89.960457,1.0,2408.199951,4133.0,8.639500e+13,7.0,1.0,31.0,19455336
5,526776.0,526776.0,526776.0,526776.0,526776.0,526776.0,526776.0,526776.0,526776.0,526776.0,526776.0,526776.0,0.011678,0.114502,-0.723974,0.002158,-62.733452,0.784815,25.791946,3894.840332,4.320400e+13,3.868753,1.000000,70.551842,0.168990,0.434890,0.475918,0.019802,41.028954,0.397332,56.427204,159.122070,2.479460e+13,2.038921,0.000000,8.967446,-1.999768,-1.199915,-2.148964,0.0,-89.697090,0.0,0.0,3354.000000,0.0,1.0,1.0,55.0,-0.005133,-0.026829,-0.990580,0.000000,-88.332832,1.0,1.372340,3771.000000,2.194500e+13,2.0,1.0,63.0,0.020726,-0.003964,-0.981840,0.000009,-86.213867,1.0,6.011971,3912.000000,4.319500e+13,4.0,1.0,70.0,0.026316,0.091347,-0.678192,0.000896,-44.343567,1.0,25.867804,4023.000000,6.449000e+13,6.0,1.0,78.0,1.143887,2.047239,1.386601,2.261413,89.643997,1.0,1770.199951,4185.0,8.639500e+13,7.0,1.0,86.0,ca33a5e7
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
990,429084.0,429084.0,429084.0,429084.0,429084.0,429084.0,429084.0,429084.0,429084.0,429084.0,429084.0,429084.0,0.054041,0.110304,-0.932455,0.004260,-75.901634,0.970956,3.264637,3842.950928,4.325764e+13,4.203792,4.000000,55.396873,0.181127,0.173883,0.243926,0.007777,19.675661,0.163449,10.167803,158.916595,2.501305e+13,2.004653,0.000000,7.178455,-1.323896,-1.420903,-1.029749,0.0,-88.951241,0.0,0.0,3098.166748,0.0,1.0,4.0,43.0,-0.002213,0.150396,-0.991953,0.003147,-81.185448,1.0,0.500000,3747.000000,2.145000e+13,2.0,4.0,49.0,0.015264,0.155581,-0.991110,0.004280,-80.995773,1.0,0.500000,3806.833252,4.349000e+13,4.0,4.0,55.0,0.017470,0.172698,-0.989982,0.005075,-80.114990,1.0,0.500000,3958.000000,6.494500e+13,6.0,4.0,62.0,0.950812,0.996813,1.007953,2.163906,85.700821,1.0,184.600006,4174.0,8.639500e+13,7.0,4.0,68.0,05e94f88
992,72533.0,72533.0,72533.0,72533.0,72533.0,72533.0,72533.0,72533.0,72533.0,72533.0,72533.0,72533.0,-0.227848,-0.117374,-0.109314,0.070964,-7.892540,0.000000,23.489414,4024.496338,4.969159e+13,4.140736,1.000000,129.086533,0.505669,0.523566,0.546675,0.129670,39.408108,0.000000,47.289368,94.724304,1.763440e+13,1.713176,0.000000,9.245427,-2.023180,-2.103810,-1.051996,0.0,-89.811768,0.0,0.0,3830.000000,0.0,1.0,1.0,119.0,-0.638697,-0.532812,-0.542981,0.010700,-35.077408,0.0,3.294118,3935.000000,3.508500e+13,3.0,1.0,121.0,-0.311040,-0.107125,-0.139521,0.028815,-8.490570,0.0,8.333333,4011.000000,5.160500e+13,4.0,1.0,129.0,0.110838,0.243214,0.238977,0.073403,14.122293,0.0,24.666666,4111.000000,6.424500e+13,6.0,1.0,139.0,1.163996,2.271134,1.489056,2.631471,89.823715,0.0,2255.500000,4171.0,8.639500e+13,7.0,1.0,149.0,2840643b
993,401964.0,401964.0,401964.0,401964.0,401964.0,401964.0,401964.0,401964.0,401964.0,401964.0,401964.0,401964.0,0.470508,0.069204,-0.119591,0.044411,-8.286257,0.003582,35.851643,3852.268066,4.335045e+13,3.916291,3.000000,103.655518,0.451216,0.433979,0.556107,0.117473,37.932995,0.057099,153.920746,163.820465,2.485187e+13,1.930933,0.000000,6.727608,-0.989845,-5.156441,-1.008375,0.0,-89.379982,0.0,0.0,3098.166748,0.0,1.0,3.0,92.0,0.289556,-0.165256,-0.602924,0.002373,-37.793286,0.0,1.672412,3747.000000,2.184500e+13,2.0,3.0,98.0,0.604042,0.086445,-0.154146,0.009824,-9.497402,0.0,5.251727,3819.000000,4.369000e+13,4.0,3.0,104.0,0.801560,0.355886,0.296767,0.035004,17.242387,0.0,17.481647,3976.000000,6.470000e+13,6.0,3.0,109.0,2.261910,1.023213,2.497662,4.914551,89.513603,1.0,2561.399902,4184.0,8.639500e+13,7.0,3.0,115.0,1b329556
994,401880.0,401880.0,401880.0,401880.0,401880.0,401880.0,401880.0,401880.0,401880.0,401880.0,401880.0,401880.0,0.001327,-0.233908,-0.117364,0.038436,-9.717763,0.307449,28.920029,3845.557129,4.336256e+13,3.917826,1.335443,73.671005,0.509233,0.547586,0.579904,0.146882,44.027718,0.457175,151.875031,164.799362,2.486164e+13,1.930397,0.472146,6.722071,-2.147763,-3.231679,-1.083508,0.0,-89.907837,0.0,0.0,3098.166748,0.0,1.0,1.0,62.0,-0.195067,-0.910050,-0.515145,0.000000,-31.985840,0.0,3.099985,3747.000000,2.184000e+13,2.0,1.0,68.0,-0.036661,-0.098634,-0.101701,0.002339,-5.965729,0.0,6.100000,3809.333252,4.368000e+13,4.0,1.0,74.0,0.357545,0.096994,0.242008,0.027419,13.758414,1.0,15.486906,3964.000000,6.475000e+13,6.0,2.0,79.0,2.621233,2.709097,2.616484,5.241471,89.762779,1.0,2623.000000,4187.0,8.639500e+13,7.0,2.0,85.0,62b873a2


In [304]:
test_series_df

Unnamed: 0,stat_0,stat_1,stat_2,stat_3,stat_4,stat_5,stat_6,stat_7,stat_8,stat_9,stat_10,stat_11,stat_12,stat_13,stat_14,stat_15,stat_16,stat_17,stat_18,stat_19,stat_20,stat_21,stat_22,stat_23,stat_24,stat_25,stat_26,stat_27,stat_28,stat_29,stat_30,stat_31,stat_32,stat_33,stat_34,stat_35,stat_36,stat_37,stat_38,stat_39,stat_40,stat_41,stat_42,stat_43,stat_44,stat_45,stat_46,stat_47,stat_48,stat_49,stat_50,stat_51,stat_52,stat_53,stat_54,stat_55,stat_56,stat_57,stat_58,stat_59,stat_60,stat_61,stat_62,stat_63,stat_64,stat_65,stat_66,stat_67,stat_68,stat_69,stat_70,stat_71,stat_72,stat_73,stat_74,stat_75,stat_76,stat_77,stat_78,stat_79,stat_80,stat_81,stat_82,stat_83,stat_84,stat_85,stat_86,stat_87,stat_88,stat_89,stat_90,stat_91,stat_92,stat_93,stat_94,stat_95,id
0,269335.0,269335.0,269335.0,269335.0,269335.0,269335.0,269335.0,269335.0,269335.0,269335.0,269335.0,269335.0,-0.478973,-0.037643,-0.215956,0.061542,-14.676989,0.000000,41.468441,3876.515869,4.758289e+13,4.015780,1.000000,46.020077,0.429458,0.518862,0.422680,0.129371,28.847290,0.000000,180.024918,121.012596,1.727138e+13,1.998955,0.000000,9.290177,-3.298790,-3.262288,-1.134395,0.0,-89.820724,0.0,0.0,3706.000000,0.0,1.0,1.0,28.0,-0.816840,-0.353651,-0.545926,0.007793,-34.635059,0.000000,1.500000,3779.250000,3.406250e+13,2.0,1.0,39.0,-0.578881,-0.006695,-0.211011,0.027122,-12.874933,0.0,5.495555,3841.000000,4.772500e+13,4.0,1.0,45.0,-0.250647,0.300744,0.050556,0.068639,2.749033,0.0,21.666666,3970.000000,6.146000e+13,6.0,1.0,55.0,1.159667,2.525316,1.802745,4.568309,89.673332,0.0,2659.666748,4179.0,8.639500e+13,7.0,1.0,63.0,0d01bbf2
11,399120.0,399120.0,399120.0,399120.0,399120.0,399120.0,399120.0,399120.0,399120.0,399120.0,399120.0,399120.0,-0.285605,0.071007,-0.313368,0.024183,-21.146635,0.007216,45.585548,3842.319336,4.319523e+13,4.181058,2.000000,341.493896,0.531390,0.478450,0.482372,0.069384,33.808033,0.082754,216.744736,164.662521,2.488954e+13,2.003139,0.000000,6.681085,-1.589197,-2.173970,-1.043743,0.0,-89.192787,0.0,0.0,3098.166748,0.0,1.0,2.0,330.0,-0.724618,-0.253596,-0.691327,0.000023,-44.900260,0.000000,1.737113,3741.000000,2.169000e+13,2.0,2.0,336.0,-0.422316,0.108709,-0.405710,0.004300,-24.966161,0.0,4.489776,3807.500000,4.317500e+13,4.0,2.0,341.0,0.092856,0.424834,-0.061616,0.021861,-3.668349,0.0,10.569219,3959.666748,6.470500e+13,6.0,2.0,347.0,1.006139,1.341341,1.023566,3.108276,89.264053,1.0,2619.399902,4182.0,8.639500e+13,7.0,2.0,353.0,b3b200af
19,287571.0,287571.0,287571.0,287571.0,287571.0,287571.0,287571.0,287571.0,287571.0,287571.0,287571.0,287571.0,0.202230,-0.003141,-0.188392,0.079971,-13.381116,0.000000,140.634109,3886.381104,4.926739e+13,4.325092,2.000000,73.435066,0.567678,0.553317,0.444292,0.190162,30.720694,0.000000,411.122864,128.497330,1.796438e+13,1.870637,0.000000,9.301646,-3.653923,-4.190508,-1.022067,0.0,-89.691467,0.0,0.0,3722.000000,0.0,1.0,2.0,56.0,-0.253580,-0.401566,-0.506908,0.008786,-33.052214,0.000000,3.000000,3777.000000,3.523500e+13,3.0,2.0,66.0,0.326310,0.024360,-0.215055,0.032396,-13.348002,0.0,10.250000,3847.000000,4.951000e+13,5.0,2.0,73.0,0.671350,0.393137,0.051707,0.079640,2.613756,0.0,46.987309,3979.750000,6.369500e+13,6.0,2.0,82.0,2.654759,5.195954,2.832015,6.454380,89.583122,0.0,2652.500000,4175.0,8.639500e+13,7.0,2.0,89.0,051680a0
23,403320.0,403320.0,403320.0,403320.0,403320.0,403320.0,403320.0,403320.0,403320.0,403320.0,403320.0,403320.0,0.011216,-0.018847,-0.697816,0.002489,-58.750317,0.853317,11.358668,3845.631592,4.336458e+13,4.011366,1.000000,30.630705,0.135910,0.305648,0.625384,0.023102,54.402809,0.343040,42.282906,162.969589,2.481832e+13,1.912128,0.000000,6.748073,-1.469559,-1.060594,-1.014847,0.0,-89.863556,0.0,0.0,3098.166748,0.0,1.0,1.0,19.0,-0.019339,-0.018319,-0.994136,0.000015,-89.149544,1.000000,1.253009,3747.000000,2.191500e+13,2.0,1.0,25.0,-0.004935,0.010467,-0.992521,0.000073,-86.951851,1.0,4.353283,3812.000000,4.367500e+13,4.0,1.0,31.0,0.008185,0.049206,-0.734066,0.001015,-48.064898,1.0,11.292164,3964.000000,6.468000e+13,6.0,1.0,36.0,0.992514,1.794758,1.305533,2.336523,89.392189,1.0,2637.199951,4177.0,8.639500e+13,7.0,1.0,42.0,90161e10
25,87538.0,87538.0,87538.0,87538.0,87538.0,87538.0,87538.0,87538.0,87538.0,87538.0,87538.0,87538.0,0.003999,0.110171,0.043007,0.059819,2.681419,0.000000,22.743656,4052.015625,5.245292e+13,4.567091,3.983584,18.057495,0.622313,0.511731,0.497145,0.138316,34.413090,0.000000,65.989326,72.575485,1.651591e+13,1.667935,0.221310,5.575263,-1.253277,-2.877334,-1.007265,0.0,-89.879013,0.0,0.0,3844.000000,0.0,1.0,1.0,12.0,-0.577394,-0.227775,-0.320940,0.010689,-19.702371,0.000000,2.256177,3999.270813,3.953500e+13,3.0,4.0,14.0,0.011879,0.106812,0.021848,0.026004,1.106987,0.0,7.171294,4070.000000,5.321500e+13,4.0,4.0,16.0,0.579564,0.521798,0.424460,0.057313,25.821259,0.0,22.843782,4111.000000,6.540500e+13,6.0,4.0,20.0,1.417182,2.422021,1.825919,3.503941,89.174927,0.0,2468.199951,4189.0,8.639500e+13,7.0,4.0,43.0,3e5d5b58
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
973,379188.0,379188.0,379188.0,379188.0,379188.0,379188.0,379188.0,379188.0,379188.0,379188.0,379188.0,379188.0,-0.082540,-0.079869,-0.042417,0.056889,-3.423715,0.077992,94.957237,3832.540283,4.320096e+13,3.934776,2.000000,157.984329,0.569989,0.571310,0.524226,0.187297,36.632317,0.264059,334.423859,166.462967,2.497332e+13,1.983275,0.000000,6.349608,-3.014132,-2.904584,-1.042274,0.0,-89.806732,0.0,0.0,3098.166748,0.0,1.0,2.0,147.0,-0.576627,-0.527424,-0.432366,0.000961,-26.814069,0.000000,2.660053,3731.416748,2.154000e+13,2.0,2.0,152.0,-0.104755,-0.086252,-0.027888,0.010186,-1.884405,0.0,8.333333,3809.083252,4.326000e+13,4.0,2.0,158.0,0.394502,0.337412,0.306536,0.031147,17.537446,0.0,29.819900,3953.000000,6.485500e+13,6.0,2.0,163.0,2.019584,3.029424,2.999776,4.654961,88.891907,1.0,2628.333252,4173.0,8.639500e+13,7.0,2.0,169.0,350dbeba
975,368388.0,368388.0,368388.0,368388.0,368388.0,368388.0,368388.0,368388.0,368388.0,368388.0,368388.0,368388.0,0.250754,-0.075845,-0.192110,0.026559,-14.090879,0.046042,27.160236,3808.138184,4.341863e+13,3.970097,1.000000,10.668621,0.632445,0.436654,0.538393,0.044102,37.970463,0.208189,133.221649,143.152313,2.483842e+13,1.999779,0.000000,6.166293,-1.041804,-1.277343,-1.040714,0.0,-89.494568,0.0,0.0,3098.166748,0.0,1.0,1.0,0.0,-0.242655,-0.388633,-0.657494,0.000571,-41.172942,0.000000,2.471314,3730.000000,2.192500e+13,2.0,1.0,5.0,0.433162,-0.075005,-0.222548,0.012290,-12.876396,0.0,8.222222,3790.750000,4.385500e+13,4.0,1.0,11.0,0.822297,0.202641,0.207740,0.033273,11.851877,0.0,18.350832,3912.000000,6.480000e+13,6.0,1.0,16.0,1.246738,1.366126,1.011202,1.806151,89.266899,1.0,2610.500000,4143.0,8.639500e+13,7.0,1.0,21.0,029a19c9
977,351003.0,351003.0,351003.0,351003.0,351003.0,351003.0,351003.0,351003.0,351003.0,351003.0,351003.0,351003.0,-0.345470,0.122593,-0.093052,0.044392,-6.657838,0.000000,29.102085,3879.713623,4.601385e+13,4.145261,3.000000,155.721741,0.551300,0.484132,0.463365,0.081445,31.521679,0.000000,145.294983,145.071259,2.622573e+13,1.984711,0.000000,9.082359,-2.023905,-2.581890,-1.051024,0.0,-89.899117,0.0,0.0,3372.000000,0.0,1.0,3.0,140.0,-0.791670,-0.195786,-0.445745,0.006845,-28.216097,0.000000,2.155172,3765.000000,2.200750e+13,2.0,3.0,148.0,-0.516473,0.165282,-0.106230,0.019534,-6.528247,0.0,5.525997,3841.000000,5.022500e+13,4.0,3.0,156.0,-0.024668,0.483355,0.221420,0.051605,13.148877,0.0,14.123962,3996.750000,6.908000e+13,6.0,3.0,163.0,1.032127,1.183022,1.394631,3.043137,89.263519,0.0,2625.250000,4198.0,8.639500e+13,7.0,3.0,174.0,a5dbb00a
989,349404.0,349404.0,349404.0,349404.0,349404.0,349404.0,349404.0,349404.0,349404.0,349404.0,349404.0,349404.0,0.131730,-0.008608,0.187249,0.040975,12.479903,0.000000,18.632076,3868.046143,4.718307e+13,4.119758,1.869106,22.022997,0.679721,0.435364,0.486258,0.058817,33.226883,0.000000,46.156944,141.434158,1.763411e+13,1.925532,0.337285,9.614893,-1.313665,-1.571729,-0.997084,0.0,-89.331566,0.0,0.0,3486.250000,0.0,1.0,1.0,5.0,-0.562262,-0.277304,-0.163658,0.008704,-9.783342,0.000000,3.903253,3759.000000,3.282875e+13,3.0,2.0,14.0,0.308594,-0.013777,0.218678,0.022349,12.757167,0.0,8.982333,3827.166748,4.759750e+13,4.0,2.0,22.0,0.764384,0.249961,0.580601,0.049643,36.235672,0.0,19.799999,3976.000000,6.163000e+13,6.0,2.0,30.0,1.047408,1.649819,1.601738,2.212217,89.548485,0.0,2558.000000,4187.0,8.639500e+13,7.0,2.0,40.0,22b4191a


### Encoding of Time Series Data

In [305]:
class TimeSeriesAutoencoder(nn.Module):
    def __init__(self, input_size, encoding_size):
        super(TimeSeriesAutoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_size, encoding_size * 3),
            nn.ReLU(),
            nn.Linear(encoding_size * 3, encoding_size * 2),
            nn.ReLU(),
            nn.Linear(encoding_size * 2, encoding_size),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Linear(encoding_size, input_size * 2),
            nn.ReLU(),
            nn.Linear(input_size * 2, input_size * 3),
            nn.ReLU(),
            nn.Linear(input_size * 3, input_size),
            nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

In [306]:
def encodedDataPrep(train, test):
    scaler = StandardScaler()
    scaled_data_train = scaler.fit_transform(train.drop('id', axis=1))
    scaled_data_test = scaler.transform(test.drop('id', axis=1))
    tensor_data_train = torch.FloatTensor(scaled_data_train)
    tensor_data_test = torch.FloatTensor(scaled_data_test)
    return tensor_data_train, tensor_data_test

In [307]:
def get_encoded_features(tensor_data, train=True, autoencoder=None, encoding_dim=60, epochs=100, batch_size=32):
    input_dim = tensor_data.shape[1]
    autoencoder = TimeSeriesAutoencoder(input_dim, encoding_dim)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(autoencoder.parameters())
    for epoch in range(epochs):
        for i in range(0, len(tensor_data), batch_size):
            batch = tensor_data[i:i + batch_size]
            optimizer.zero_grad()
            outputs = autoencoder(batch)
            loss = criterion(outputs, batch)
            loss.backward()
            optimizer.step()

        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')

    with torch.no_grad():
        encoded_data = autoencoder.encoder(tensor_data).numpy()

    encoded_df = pd.DataFrame(encoded_data, columns=[f'Enc_{i + 1}' for i in range(encoded_data.shape[1])])
    return encoded_df, autoencoder

In [308]:
def encode_test_data(autoencoder, test_tensor):
    with torch.no_grad():
        encoded_data = autoencoder.encoder(test_tensor).numpy()
    encoded_df = pd.DataFrame(encoded_data, columns=[f'Enc_{i + 1}' for i in range(encoded_data.shape[1])])
    return encoded_df

In [309]:
tensor_data_train, tensor_data_test = encodedDataPrep(train_series_df, test_series_df)
train_encoded, autoencoder = get_encoded_features(tensor_data_train)
test_encoded = encode_test_data(autoencoder, tensor_data_test)

train_encoded['id'] = train_series_df['id']
test_encoded['id'] = test_series_df['id']

train_df = train_df.merge(train_encoded, on='id', how='left')
test_df = test_df.merge(test_encoded, on='id', how='left')

Epoch [10/100], Loss: 0.6212
Epoch [20/100], Loss: 0.5033
Epoch [30/100], Loss: 0.4650
Epoch [40/100], Loss: 0.4470
Epoch [50/100], Loss: 0.4372
Epoch [60/100], Loss: 0.4205
Epoch [70/100], Loss: 0.4159
Epoch [80/100], Loss: 0.4128
Epoch [90/100], Loss: 0.4107
Epoch [100/100], Loss: 0.4110


### Imputation of Missing Numerical Values

In [310]:
# Imputing missing values using KNN imputer
def impute_missing_values(df):
    imputer = KNNImputer(n_neighbors=5)
    numeric_columns = df.select_dtypes(include=['float64','float32','int64']).columns # We treat the time series data as a median for now but others could be used
    imputed_array = imputer.fit_transform(df[numeric_columns])
    imputed_df = pd.DataFrame(imputed_array, columns=numeric_columns)
    for col in df.columns:
        if col not in numeric_columns:
            imputed_df[col] = df[col]
    return imputed_df

In [311]:
# train_df.replace([np.inf, -np.inf], np.nan, inplace=True) # Debug why we have inf values

train_df = impute_missing_values(train_df)
test_df = impute_missing_values(test_df)

train_df['sii'] = train_df['sii'].round().astype(int)
test_df['sii'] = test_df['sii'].round().astype(int)

# Imputation is needed in test set for some cases but not others to revisit

print(train_df.isna().sum())
print(test_df.isna().sum())

Basic_Demos-Age           0
Basic_Demos-Sex           0
CGAS-CGAS_Score           0
Physical-BMI              0
Physical-Height           0
                       ... 
BIA-Season             1450
PAQ_A-Season           2778
PAQ_C-Season           1794
SDS-Season             1057
PreInt_EduHx-Season     331
Length: 120, dtype: int64
Basic_Demos-Age          0
Basic_Demos-Sex          0
CGAS-CGAS_Score          0
Physical-BMI             0
Physical-Height          0
                      ... 
BIA-Season             365
PAQ_A-Season           707
PAQ_C-Season           445
SDS-Season             285
PreInt_EduHx-Season     89
Length: 120, dtype: int64


### Categorical Processing

In [312]:
categorical_cols = list(train_df.select_dtypes(include=['object']).columns) #sii (outcome var) is categorical but we are encoding that differently
categorical_cols.remove('id')
print(categorical_cols)

def preprocess_categorical(df):
    for col in categorical_cols:
        df[col] = df[col].fillna('Missing').astype('category')
    return df

train_df = preprocess_categorical(train_df)
test_df = preprocess_categorical(test_df)

train_df = pd.get_dummies(train_df, columns = categorical_cols, drop_first=True, dtype='int')
test_df = pd.get_dummies(test_df, columns = categorical_cols, drop_first=True, dtype='int')

['Basic_Demos-Enroll_Season', 'CGAS-Season', 'Physical-Season', 'Fitness_Endurance-Season', 'FGC-Season', 'BIA-Season', 'PAQ_A-Season', 'PAQ_C-Season', 'SDS-Season', 'PreInt_EduHx-Season']


In [313]:
train_df = train_df.drop('id', axis=1)
test_df = test_df.drop('id', axis=1)

In [314]:
# Define the Quadratic Weighted Kappa metric
def quadratic_weighted_kappa(y_actual, y_predicted):
    return cohen_kappa_score(y_actual, y_predicted, weights='quadratic')

## Model 1 - Voting Classifier

In [315]:
def train_and_evaluate(model, param_dist, n_iter=10):
    X = train_df.drop('sii', axis=1)
    y = train_df['sii'].astype(int)
    kf = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
    search = RandomizedSearchCV(model, param_dist, n_iter=n_iter, cv=kf, n_jobs=-1, random_state=42, verbose=2)
    search.fit(X, y)
    best_model = search.best_estimator_
    y_pred = best_model.predict(X)
    print(f"Best Parameters: {search.best_params_}")
    return best_model, y_pred, y

In [None]:
# Training and evaluation function
def train_and_evaluate_classifier(model, param_dist, n_iter=10):
    best_model, y_pred, y = train_and_evaluate(model, param_dist, n_iter)
    train_kappa = quadratic_weighted_kappa(y, y_pred)
    train_accuracy = accuracy_score(y, y_pred)
    train_f1 = f1_score(y, y_pred, average='micro')
    print(f"Best Train QWK: {train_kappa:.4f}")
    print(f"Best Train Accuracy: {train_accuracy:.4f}")
    return best_model

In [317]:
# Define the classifiers
rf = RandomForestClassifier(random_state=42)
svc = SVC(probability=True, random_state=42)
xgb_model = XGBClassifier(random_state=42, objective='multi:softprob', num_class=4, verbosity=0, learning_rate=0.05)
lgb_model = LGBMClassifier(random_state=42, objective='multiclass', num_class=4, verbose=-1, learning_rate=0.05)
catb_model = CatBoostClassifier(random_state=42, objective='MultiClass', verbose=False, learning_rate=0.05)

In [318]:
param_dist = {
            # 'rf__n_estimators': [100, 200, 300], 'rf__max_depth': [3], 'rf__min_samples_split': [2, 5, 10], 'rf__min_samples_leaf': [1], 'rf__max_features': ['auto', 'sqrt', 'log2'],
            #   'svc__C': [0.1, 1, 10], 'svc__gamma': ['scale', 'auto'], 'svc__kernel': ['linear', 'rbf'],
              'xgb__n_estimators': [200], 'xgb__max_depth': [2, 3],
              'lgb__n_estimators': [100], 'lgb__num_leaves': [5, 10],
              'cat__iterations': [100, 200], 'cat__depth': [2, 4]
              }

In [319]:

# Create an ensemble using VotingClassifier with soft voting
ensemble_model = VotingClassifier(
    estimators=[('lgb', lgb_model), 
                ('xgb', xgb_model), 
                ('cat', catb_model)
                # ('rf', rf) 
                # ,('svc', svc)
                ],
    voting='soft',
    # weights=[4.0, 4.0, 5.0],
    n_jobs=-1
)

In [320]:
best_model = train_and_evaluate_classifier(ensemble_model, param_dist, n_iter=2)

Fitting 3 folds for each of 2 candidates, totalling 6 fits
[CV] END cat__depth=2, cat__iterations=100, lgb__n_estimators=100, lgb__num_leaves=5, xgb__max_depth=3, xgb__n_estimators=200; total time=   5.0s
[CV] END cat__depth=2, cat__iterations=100, lgb__n_estimators=100, lgb__num_leaves=5, xgb__max_depth=2, xgb__n_estimators=200; total time=   5.0s
[CV] END cat__depth=2, cat__iterations=100, lgb__n_estimators=100, lgb__num_leaves=5, xgb__max_depth=3, xgb__n_estimators=200; total time=   5.0s
[CV] END cat__depth=2, cat__iterations=100, lgb__n_estimators=100, lgb__num_leaves=5, xgb__max_depth=2, xgb__n_estimators=200; total time=   5.0s
[CV] END cat__depth=2, cat__iterations=100, lgb__n_estimators=100, lgb__num_leaves=5, xgb__max_depth=2, xgb__n_estimators=200; total time=   5.1s
[CV] END cat__depth=2, cat__iterations=100, lgb__n_estimators=100, lgb__num_leaves=5, xgb__max_depth=3, xgb__n_estimators=200; total time=   3.5s
Best Parameters: {'xgb__n_estimators': 200, 'xgb__max_depth': 3, 

In [None]:
X_test = test_df.drop('sii', axis=1)
y_test = test_df['sii'].astype(int)
y_test_pred = best_model.predict(X_test)
test_kappa = quadratic_weighted_kappa(y_test, y_test_pred)
test_accuracy = accuracy_score(y_test, y_test_pred)
test_f1 = f1_score(y_test, y_test_pred, average='micro')
print(f"Test QWK: {Fore.GREEN}{Style.BRIGHT}{test_kappa:.4f}{Style.RESET_ALL}")
print(f"Test Accuracy: {Fore.GREEN}{Style.BRIGHT}{test_accuracy:.4f}{Style.RESET_ALL}")
print(f"Test F1: {Fore.GREEN}{Style.BRIGHT}{test_f1:.4f}{Style.RESET_ALL}")

Test QWK: [32m[1m0.3721[0m
Test Accuracy: [32m[1m0.6414[0m


## Model 1.5 - Stacking Classifier

In [322]:

# Create an ensemble using VotingClassifier with soft voting
ensemble_model = StackingClassifier(
    estimators=[('lgb', lgb_model), 
                ('xgb', xgb_model), 
                ('cat', catb_model)
                # ('rf', rf) 
                # ,('svc', svc)
                ],
    # weights=[4.0, 4.0, 5.0],
    n_jobs=-1
)

In [323]:
best_model = train_and_evaluate_classifier(ensemble_model, param_dist, n_iter=2)

Fitting 3 folds for each of 2 candidates, totalling 6 fits
[CV] END cat__depth=2, cat__iterations=100, lgb__n_estimators=100, lgb__num_leaves=5, xgb__max_depth=3, xgb__n_estimators=200; total time=  30.8s
[CV] END cat__depth=2, cat__iterations=100, lgb__n_estimators=100, lgb__num_leaves=5, xgb__max_depth=2, xgb__n_estimators=200; total time=  30.9s
[CV] END cat__depth=2, cat__iterations=100, lgb__n_estimators=100, lgb__num_leaves=5, xgb__max_depth=3, xgb__n_estimators=200; total time=  30.9s
[CV] END cat__depth=2, cat__iterations=100, lgb__n_estimators=100, lgb__num_leaves=5, xgb__max_depth=2, xgb__n_estimators=200; total time=  31.3s
[CV] END cat__depth=2, cat__iterations=100, lgb__n_estimators=100, lgb__num_leaves=5, xgb__max_depth=2, xgb__n_estimators=200; total time=  31.3s
[CV] END cat__depth=2, cat__iterations=100, lgb__n_estimators=100, lgb__num_leaves=5, xgb__max_depth=3, xgb__n_estimators=200; total time=  31.5s
Best Parameters: {'xgb__n_estimators': 200, 'xgb__max_depth': 2, 

In [None]:
X_test = test_df.drop('sii', axis=1)
y_test = test_df['sii'].astype(int)
y_test_pred = best_model.predict(X_test)
test_kappa = quadratic_weighted_kappa(y_test, y_test_pred)
test_accuracy = accuracy_score(y_test, y_test_pred)
test_f1 = f1_score(y_test, y_test_pred, average='micro')
print(f"Test QWK: {Fore.GREEN}{Style.BRIGHT}{test_kappa:.4f}{Style.RESET_ALL}")
print(f"Test Accuracy: {Fore.GREEN}{Style.BRIGHT}{test_accuracy:.4f}{Style.RESET_ALL}")
print(f"Test F1: {Fore.GREEN}{Style.BRIGHT}{test_f1:.4f}{Style.RESET_ALL}")

Test QWK: [32m[1m0.3739[0m
Test Accuracy: [32m[1m0.6326[0m


## Model 2 - Voting Regressor

In [325]:
# Function to apply thresholds to continuous predictions
def apply_thresholds(predictions, thresholds):
    return np.digitize(predictions, bins=thresholds)

# Function to optimize thresholds to maximize QWK
def optimize_thresholds(y_true, predictions):
    def loss_func(thresh):
        # Ensure thresholds are sorted
        thresh_sorted = np.sort(thresh)
        preds = apply_thresholds(predictions, thresh_sorted)
        return -quadratic_weighted_kappa(y_true, preds)
    
    initial_thresholds = [0.5, 1.5, 2.5]  # Initial guesses for thresholds
    bounds = [(0, 3)] * 3  # Assuming classes are 0,1,2,3
    result = minimize(loss_func, initial_thresholds, method='Nelder-Mead')
    return result.x

In [326]:
def train_and_evaluate_regressor(model, param_dist, n_iter=10):
    best_model, y_pred, y = train_and_evaluate(model, param_dist, n_iter)
    thresholds = optimize_thresholds(y, y_pred)
    y_pred = apply_thresholds(y_pred, thresholds)
    train_kappa = quadratic_weighted_kappa(y, y_pred)
    train_accuracy = accuracy_score(y, y_pred)
    print(f"Best Train QWK: {train_kappa:.4f}")
    print(f"Best Train Accuracy: {train_accuracy:.4f}")
    return best_model, thresholds

In [327]:
svm = SVC(random_state=42)
xgb_model = XGBRegressor(random_state=42, verbosity=0, learning_rate=0.05)
lgb_model = LGBMRegressor(random_state=42, verbose=-1, learning_rate=0.05)
catb_model = CatBoostRegressor(random_state=42, verbose=False, learning_rate=0.05)

In [328]:
param_dist = {
            # 'rf__n_estimators': [100, 200, 300], 'rf__max_depth': [3], 'rf__min_samples_split': [2, 5, 10], 'rf__min_samples_leaf': [1], 'rf__max_features': ['auto', 'sqrt', 'log2'],
            #   'svc__C': [0.1, 1, 10], 'svc__gamma': ['scale', 'auto'], 'svc__kernel': ['linear', 'rbf'],
              'xgb__n_estimators': [200], 'xgb__max_depth': [2, 3],
              'lgb__n_estimators': [100], 'lgb__num_leaves': [5, 10],
              'cat__iterations': [100, 200], 'cat__depth': [2, 4]
              }

In [329]:
# Create an ensemble using VotingClassifier with soft voting
ensemble_model = VotingRegressor(
    estimators=[('lgb', lgb_model), 
                ('xgb', xgb_model), 
                ('cat', catb_model)
                # ('rf', rf) 
                # ,('svc', svc)
                ],
    # weights=[4.0, 4.0, 5.0],
    n_jobs=-1
)

In [330]:
best_model, thresholds = train_and_evaluate_regressor(ensemble_model, param_dist, n_iter=2)

Fitting 3 folds for each of 2 candidates, totalling 6 fits
[CV] END cat__depth=2, cat__iterations=100, lgb__n_estimators=100, lgb__num_leaves=5, xgb__max_depth=2, xgb__n_estimators=200; total time=   1.6s
[CV] END cat__depth=2, cat__iterations=100, lgb__n_estimators=100, lgb__num_leaves=5, xgb__max_depth=3, xgb__n_estimators=200; total time=   1.6s
[CV] END cat__depth=2, cat__iterations=100, lgb__n_estimators=100, lgb__num_leaves=5, xgb__max_depth=3, xgb__n_estimators=200; total time=   1.6s
[CV] END cat__depth=2, cat__iterations=100, lgb__n_estimators=100, lgb__num_leaves=5, xgb__max_depth=2, xgb__n_estimators=200; total time=   1.7s
[CV] END cat__depth=2, cat__iterations=100, lgb__n_estimators=100, lgb__num_leaves=5, xgb__max_depth=2, xgb__n_estimators=200; total time=   1.7s
[CV] END cat__depth=2, cat__iterations=100, lgb__n_estimators=100, lgb__num_leaves=5, xgb__max_depth=3, xgb__n_estimators=200; total time=   1.6s
Best Parameters: {'xgb__n_estimators': 200, 'xgb__max_depth': 3, 

In [None]:
X_test = test_df.drop('sii', axis=1)
y_test = test_df['sii'].astype(int)
y_test_pred = best_model.predict(X_test)
y_test_pred = apply_thresholds(y_test_pred, thresholds)
test_kappa = quadratic_weighted_kappa(y_test, y_test_pred)
test_accuracy = accuracy_score(y_test, y_test_pred)
test_f1 = f1_score(y_test, y_test_pred, average='micro')
print(f"Test QWK: {Fore.GREEN}{Style.BRIGHT}{test_kappa:.4f}{Style.RESET_ALL}")
print(f"Test Accuracy: {Fore.GREEN}{Style.BRIGHT}{test_accuracy:.4f}{Style.RESET_ALL}")
print(f"Test F1: {Fore.GREEN}{Style.BRIGHT}{test_f1:.4f}{Style.RESET_ALL}")

Test QWK: [32m[1m0.3805[0m
Test Accuracy: [32m[1m0.5859[0m


## Model 2.5 - Stacking Regressor

In [None]:
# Create an ensemble using VotingClassifier with soft voting
ensemble_model = StackingRegressor(
    estimators=[('lgb', lgb_model), 
                ('xgb', xgb_model), 
                ('cat', catb_model)
                # ('rf', rf) 
                # ,('svc', svc)
                ],
    # weights=[4.0, 4.0, 5.0],
    n_jobs=-1
)

In [None]:
best_model, thresholds = train_and_evaluate_regressor(ensemble_model, param_dist, n_iter=2)

In [None]:
X_test = test_df.drop('sii', axis=1)
y_test = test_df['sii'].astype(int)
y_test_pred = best_model.predict(X_test)
y_test_pred = apply_thresholds(y_test_pred, thresholds)
test_kappa = quadratic_weighted_kappa(y_test, y_test_pred)
test_accuracy = accuracy_score(y_test, y_test_pred)
test_f1 = f1_score(y_test, y_test_pred, average='micro')
print(f"Test QWK: {Fore.GREEN}{Style.BRIGHT}{test_kappa:.4f}{Style.RESET_ALL}")
print(f"Test Accuracy: {Fore.GREEN}{Style.BRIGHT}{test_accuracy:.4f}{Style.RESET_ALL}")
print(f"Test F1: {Fore.GREEN}{Style.BRIGHT}{test_f1:.4f}{Style.RESET_ALL}")