In [2]:
import pandas as pd
import torch
import torch.nn as nn
import numpy as np
import datetime
from sklearn.metrics import mean_squared_error, roc_auc_score
from trees import construct_tree

In [3]:
# Load the dataset and cast every column (the 'label' column included) to float32.
df = pd.read_csv('./dataset/data.csv').astype('float32')
# Relabel rows whose label is 21 as 19; all other label values are left as they are.
df['label'] = df['label'].replace(21, 19)
# Preview the first four rows.
print(df.head(4))

   Assets - Total      Cash  Debt in Current Liabilities - Total  \
0        0.147774 -0.067606                            -0.026326   
1       -0.128259 -0.269847                            -0.097770   
2        0.124385  0.561633                             0.139320   
3       -0.241668 -0.175692                            -0.129999   

   Long-Term Debt - Total  Earnings Before Interest  Gross Profit (Loss)  \
0                0.748226                  0.942850             1.175610   
1               -0.175437                 -0.157587            -0.274136   
2               -0.149001                  0.509364             1.047681   
3               -0.388112                 -0.382308            -0.386092   

   Liabilities - Total  Retained Earnings  EBTI/REV  \
0             0.211900          -0.941212 -0.099914   
1            -0.123878          -0.126145  0.449058   
2            -0.020466           0.742475  0.085299   
3            -0.225056          -0.131515  0.038316   

  

In [9]:
# Feature matrix: every column of df except 'label', as a float32 tensor.
X = torch.tensor(df.drop(columns='label').to_numpy(), dtype=torch.float32)

# Labels: the 'label' column cast to int32 and shaped as a [num_rows, 1] column vector.
y = torch.tensor(df['label'].to_numpy(), dtype=torch.int).unsqueeze(1)

In [10]:
# Count the rows per label value (value_counts lists the most frequent label first).
class_counts = df['label'].value_counts()
print(class_counts)

8.0     352
12.0    290
9.0     280
11.0    266
7.0     262
10.0    246
13.0    219
14.0    217
6.0     164
5.0     144
15.0     85
4.0      61
3.0      35
16.0     32
2.0      13
0.0      10
1.0       8
17.0      8
19.0      6
18.0      5
Name: label, dtype: int64


In [11]:
def ordinal_criterion(predictions, targets):
  """Per-sample squared error against a cumulative ordinal target encoding.

  Ordinal regression with encoding as in https://arxiv.org/pdf/0704.1028.pdf:
  a label ``k`` is encoded as ones in columns ``0..k`` and zeros afterwards,
  i.e. 0 -> [1, 0, 0, ...] and 2 -> [1, 1, 1, 0, ...].

  Args:
    predictions: tensor of shape [batch_size, num_labels].
    targets: class labels, one per row of ``predictions``; any shape that
      flattens to ``batch_size`` entries (e.g. [batch_size] or
      [batch_size, 1]). Non-integer values are truncated toward zero.

  Returns:
    Tensor of shape [batch_size]: the squared error of each sample summed
    over its label columns.
  """
  # One label per row, as a column so it broadcasts against the column indices.
  labels = torch.as_tensor(targets, device=predictions.device).reshape(-1, 1).long()
  columns = torch.arange(predictions.shape[1], device=predictions.device)

  # Column j is 1 exactly when j <= label. Built in one broadcast comparison
  # instead of a Python loop; a negative label yields an all-zero row rather
  # than wrapping around through a negative slice end.
  modified_target = (columns <= labels).to(predictions.dtype)

  return nn.MSELoss(reduction='none')(predictions, modified_target).sum(axis=1)

In [12]:
def fl_split(split, x_train, y_train, n_splits=50, random_state=42):
    """Partition a training set into stratified shares of the requested sizes.

    The data is first cut into ``n_splits`` stratified chunks (the test folds
    of a shuffled ``StratifiedKFold``). Each entry of ``split`` then receives
    the next ``int(share * n_splits)`` consecutive chunks.

    Args:
        split: iterable of fractions, one per partition (e.g. ``[0.6, 0.4]``).
            Shares are truncated to whole chunks, so chunks left over after
            truncation are not assigned to any partition.
        x_train: feature tensor, indexed along dim 0 by the fold index arrays.
        y_train: label tensor aligned with ``x_train``; used for stratification.
        n_splits: number of stratified chunks the data is cut into.
        random_state: seed for the shuffling done by ``StratifiedKFold``.

    Returns:
        ``(x_split, y_split)``: two lists with one concatenated tensor per
        entry of ``split``.

    Raises:
        ValueError: if a share maps to zero chunks, or no chunks are left for
            it, so the partition would be empty.
    """
    # Imported locally: the notebook-level import of StratifiedKFold lives in a
    # cell that comes after this definition.
    from sklearn.model_selection import StratifiedKFold

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    chunks = [
        (x_train[chunk_index], y_train[chunk_index])
        for _, chunk_index in skf.split(x_train, y_train)
    ]

    x_split = []
    y_split = []

    start = 0
    for share in split:
        # Round before truncating so float noise cannot drop a chunk
        # (e.g. 0.58 * 50 == 28.999999999999996 must count as 29, not 28).
        n_chunks = int(round(share * n_splits, 9))
        if n_chunks < 1 or start >= len(chunks):
            raise ValueError(
                f"share {share!r} of split {split!r} maps to an empty partition "
                f"(n_splits={n_splits}, chunks already assigned={start})"
            )
        selected = chunks[start:start + n_chunks]
        x_split.append(torch.cat([x_chunk for x_chunk, _ in selected], 0))
        y_split.append(torch.cat([y_chunk for _, y_chunk in selected], 0))
        start += n_chunks

    return x_split, y_split

In [14]:
from sklearn.model_selection import StratifiedKFold
from imblearn.over_sampling import RandomOverSampler

def _replicate_rare_classes(X_train, y_train, classes, min_count):
    """Append whole-class copies of the rows of each listed class.

    For a class with ``n`` rows in the incoming fold, ``(min_count - n) // n + 1``
    copies of all of its rows are appended (nothing when that number is not
    positive). Classes are processed in the given order; that order fixes the
    row order of the augmented set.
    """
    base_X, base_y = X_train, y_train  # always look classes up in the un-augmented fold
    for label in classes:
        rows = torch.where(base_y == label)[0]
        n_rows = rows.shape[0]
        if n_rows == 0:
            # Class absent from this fold: nothing to copy (avoids a ZeroDivisionError).
            continue
        n_copies = (min_count - n_rows) // n_rows + 1
        if n_copies <= 0:
            continue
        samples = base_X[rows]
        X_train = torch.cat([X_train] + [samples] * n_copies, dim=0)
        y_train = torch.cat((y_train, torch.full((n_rows * n_copies, 1), label)), dim=0)
    return X_train, y_train


def _append_random_rows_of_class(X_train, y_train, label, n_extra, generator):
    """Append up to ``n_extra`` randomly chosen rows of class ``label``."""
    samples = X_train[torch.where(y_train == label)[0]]
    # Take at most what exists so a small class cannot raise an IndexError.
    n_take = min(n_extra, samples.shape[0])
    if n_take == 0:
        return X_train, y_train
    order = torch.randperm(samples.shape[0], generator=generator)
    X_train = torch.cat((X_train, samples[order[:n_take]]), dim=0)
    y_train = torch.cat((y_train, torch.full((n_take, 1), label)), dim=0)
    return X_train, y_train


def run_local(split, X=X, y=y):
    """Run 5-fold cross-validation, training one model per partition of each fold.

    For every fold the training part is augmented (rare classes replicated,
    a few extra class-4 rows appended), partitioned with ``fl_split(split, ...)``
    and each partition is balanced with ``RandomOverSampler`` before a model is
    built by ``construct_tree``. Every model is scored on the fold's test part:
    MSE on the predicted labels, and ROC AUC on the binary indicator
    ``label > 9`` computed from the hard predictions.

    Args:
        split: fractions passed to ``fl_split``; one model is trained per entry.
        X: feature tensor. Defaults to the module-level ``X`` as it was when
            this cell was executed.
        y: label tensor of shape [num_rows, 1]. Defaults to the module-level
            ``y`` as it was when this cell was executed.

    Returns:
        None. The split, elapsed time and the per-model MSE / AUC lists are printed.
    """
    # Order matters: it fixes the row order of the augmented training set,
    # which feeds the shuffled stratified chunking inside fl_split.
    rare_classes = (1, 17, 18, 19, 3, 16, 2, 0)
    rare_class_min_count = 50
    extra_class, extra_rows = 4, 5
    default_threshold = 9

    mses = []
    aucs = []

    kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    # Dedicated seeded generator so the random class-4 picks are reproducible.
    rng = torch.Generator().manual_seed(42)

    start = datetime.datetime.now()
    for fold_idx, (train_idx, test_idx) in enumerate(kf.split(X, y)):
        print(f"Fold {fold_idx + 1}:")

        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        X_train, y_train = _replicate_rare_classes(
            X_train, y_train, rare_classes, rare_class_min_count)
        X_train, y_train = _append_random_rows_of_class(
            X_train, y_train, extra_class, extra_rows, rng)

        # Identical for every partition of this fold, so computed once.
        y_default_test = np.where(y_test > default_threshold, 1, 0)

        x_split, y_split = fl_split(split, X_train, y_train)
        for X_part, y_part in zip(x_split, y_split):
            ros = RandomOverSampler(random_state=42)
            X_balanced, y_balanced = ros.fit_resample(X_part, y_part)

            global_tree = construct_tree(X_balanced, y_balanced, learning_rate=0.1, max_depth=6,n_estimators=40,
               min_child_weight=1,subsample=1,colsample_bytree=1, gamma=0, task_type='multiclass')
            # lr was 0.3

            y_pred = global_tree.predict(X_test)
            y_default_pred = np.where(y_pred > default_threshold, 1, 0)

            aucs.append(roc_auc_score(y_default_test, y_default_pred))
            mses.append(mean_squared_error(y_test, y_pred))

    end = datetime.datetime.now()
    elapsed = end - start

    print('split',split)
    print('Training time: ', elapsed)
    print('mse',mses)
    print('auc',aucs)


In [15]:
# Baseline: a single partition that receives all 50 stratified chunks of each
# training fold, so one model is trained per fold.
split =[1]
run_local(split)

Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [1]
Training time:  0:01:00.534214
mse [2.6635859519408505, 2.1072088724584104, 2.831792975970425, 2.8518518518518516, 2.0592592592592593]
auc [0.9148872180451127, 0.8923707957342084, 0.8816678058783323, 0.8652516327314637, 0.911311124526645]


In [16]:
# Two equal partitions (25 of the 50 chunks each); run_local trains and scores
# one model per partition, giving two results per fold.
split =[0.5, 0.5]
run_local(split)

Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [0.5, 0.5]
Training time:  0:01:59.520033
mse [2.833641404805915, 3.4750462107208873, 2.9168207024029575, 2.9353049907578557, 3.144177449168207, 2.9944547134935307, 3.0277777777777777, 4.0018518518518515, 3.024074074074074, 3.096296296296296]
auc [0.8891866028708135, 0.8502665755297335, 0.8723680612523926, 0.8833128247197156, 0.8797265892002734, 0.8706356801093642, 0.8596125349871029, 0.8501591570166291, 0.8647028154327425, 0.8906344327973218]


In [17]:
# Two unequal partitions: 30 and 20 of the 50 chunks.
split =[0.6, 0.4]
run_local(split)

Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [0.6, 0.4]
Training time:  0:01:04.716421
mse [2.811460258780037, 3.3475046210720887, 2.536044362292052, 2.9981515711645104, 3.2162661737523104, 3.3031423290203326, 2.8703703703703702, 3.7314814814814814, 2.714814814814815, 3.261111111111111]
auc [0.8931305536568693, 0.8836705399863294, 0.8965203718895269, 0.8756904566584632, 0.8670608339029391, 0.8612987012987013, 0.8781351188189451, 0.8350666813017947, 0.8983727567092916, 0.8706712035563361]


In [18]:
# Two strongly unequal partitions: 40 and 10 of the 50 chunks.
split =[0.8, 0.2]
run_local(split)

Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [0.8, 0.2]
Training time:  0:01:33.227374
mse [2.512014787430684, 4.419593345656192, 2.186691312384473, 4.105360443622921, 2.9205175600739373, 4.197781885397412, 2.85, 4.383333333333334, 2.0462962962962963, 4.007407407407407]
auc [0.8890635680109364, 0.8425632262474368, 0.8982567678424939, 0.8646704949412086, 0.8830553656869446, 0.8442583732057416, 0.8687366225783436, 0.8116184622139289, 0.9019675100159157, 0.8358350255200044]


In [19]:
# Three near-equal partitions. fl_split truncates each share to whole chunks:
# 17 + 16 + 16 = 49, so 1 of the 50 chunks is not used for training.
split = [0.34, 0.33, 0.33]
run_local(split)

Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [0.34, 0.33, 0.33]
Training time:  0:00:45.625201
mse [3.656192236598891, 3.288354898336414, 3.502772643253235, 3.4473197781885396, 3.2735674676524953, 3.033271719038817, 4.2014787430683915, 3.6913123844731976, 3.611829944547135, 3.9518518518518517, 4.131481481481481, 4.62962962962963, 2.8925925925925924, 3.312962962962963, 2.9407407407407407]
auc [0.874395078605605, 0.8671223513328776, 0.8574777853725222, 0.8884468143286848, 0.8759912496581899, 0.8774268526114302, 0.8408065618591936, 0.8576623376623377, 0.8559056732740944, 0.8370561440096592, 0.8298666373964108, 0.8092997091268316, 0.8833351627243291, 0.8720569672356073, 0.8571291367103893]


In [22]:
# Three partitions: 30, 10 and 10 of the 50 chunks.
split = [0.6, 0.2, 0.2]
run_local(split)

Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [0.6, 0.2, 0.2]
Training time:  0:01:40.775436
mse [2.898336414048059, 3.709796672828096, 4.295748613678374, 2.536044362292052, 3.9926062846580406, 4.151571164510166, 3.2957486136783736, 4.524953789279113, 4.316081330868761, 2.8407407407407406, 4.309259259259259, 4.687037037037037, 2.9166666666666665, 4.611111111111111, 3.901851851851852]
auc [0.8802187286397811, 0.8728844839371155, 0.8535953520164047, 0.9016543614984961, 0.8574241181296144, 0.8610473065354115, 0.8687559808612441, 0.859788106630212, 0.8445044429254955, 0.8838290982931782, 0.8534246199440207, 0.8022748477031997, 0.8983178749794193, 0.8331321003238021, 0.8468936940892376]


In [24]:
# Three partitions: 40, 5 and 5 of the 50 chunks.
split = [0.8, 0.1, 0.1]
run_local(split)

Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [0.8, 0.1, 0.1]
Training time:  0:00:40.394809
mse [2.6155268022181146, 6.282809611829944, 5.240295748613678, 2.1404805914972274, 5.6044362292051755, 5.0499075785582255, 2.9537892791127542, 5.654343807763401, 4.828096118299445, 2.738888888888889, 5.148148148148148, 6.127777777777778, 2.0148148148148146, 4.825925925925926, 4.538888888888889]
auc [0.889002050580998, 0.8204374572795625, 0.8392344497607656, 0.8997675690456659, 0.8192302433688816, 0.8331966092425487, 0.8851811346548188, 0.816678058783322, 0.8337183868762816, 0.8613824707754789, 0.8098485264255529, 0.8025492563525602, 0.9057817902420284, 0.8242275396520498, 0.8445749410021404]


In [20]:
# Five equal partitions of 10 chunks each.
split =[0.2, 0.2, 0.2, 0.2, 0.2]
run_local(split)

Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [0.2, 0.2, 0.2, 0.2, 0.2]
Training time:  0:00:42.391626
mse [3.7486136783733826, 4.245841035120148, 3.7282809611829943, 3.55452865064695, 4.306839186691312, 3.990757855822551, 3.4953789279112755, 4.404805914972274, 3.9926062846580406, 4.129390018484289, 4.818853974121996, 5.280961182994455, 4.166358595194085, 4.524953789279113, 3.9038817005545288, 5.261111111111111, 4.6722222222222225, 5.187037037037037, 3.9277777777777776, 4.159259259259259, 3.8222222222222224, 3.574074074074074, 3.675925925925926, 4.433333333333334, 4.181481481481481]
auc [0.8617293233082707, 0.8503896103896105, 0.8523308270676692, 0.8729460013670539, 0.8500205058099796, 0.84361498496035, 0.8691960623461854, 0.8499521465682254, 0.8574241181296144, 0.8611225047853432, 0.8479562542720438, 0.8466917293233083, 0.8281408065618592, 0.859788106630212, 0.8478947368421053, 0.7904478349157564, 0.824721475220899, 0.7948658141704626, 0.8481148125788925, 0.8076944185280721, 0.8588990

In [23]:
# Five partitions: one of 30 chunks and four of 5 chunks each.
split =[0.6, 0.1, 0.1, 0.1, 0.1]
run_local(split)

Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [0.6, 0.1, 0.1, 0.1, 0.1]
Training time:  0:01:30.092586
mse [2.913123844731978, 5.32532347504621, 5.308687615526802, 6.282809611829944, 4.868761552680222, 2.5378927911275415, 5.11275415896488, 5.3345656192236595, 5.550831792975971, 5.0499075785582255, 3.086876155268022, 5.979667282809612, 4.717190388170056, 5.8317929759704255, 4.828096118299445, 2.914814814814815, 6.114814814814815, 5.7592592592592595, 5.148148148148148, 6.127777777777778, 2.6574074074074074, 6.303703703703704, 6.529629629629629, 4.825925925925926, 5.148148148148148]
auc [0.8894326725905674, 0.8440396445659604, 0.8237662337662337, 0.8204374572795625, 0.8316541353383459, 0.9018047579983594, 0.8357601859447634, 0.8237626469783976, 0.8302502050861362, 0.8331966092425487, 0.8707587149692413, 0.821948051948052, 0.8286944634313056, 0.8147983595352016, 0.8337183868762816, 0.8797952911475769, 0.8352313264914111, 0.7846989737116513, 0.8098485264255529, 0.8025492563525602, 0.9020223

In [21]:
# Five partitions: one of 40 chunks and four small ones. 0.05 * 50 = 2.5 is
# truncated to 2 chunks each, so 40 + 4 * 2 = 48 and 2 of the 50 chunks go unused.
split = [0.8, 0.05, 0.05, 0.05, 0.05]
run_local(split)

Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [0.8, 0.05, 0.05, 0.05, 0.05]
Training time:  0:00:56.680693
mse [2.7837338262476896, 11.645101663585953, 10.475046210720887, 8.950092421441774, 7.55268022181146, 2.3160813308687613, 7.99815157116451, 10.179297597042513, 7.963031423290204, 8.66173752310536, 2.9408502772643255, 8.850277264325323, 9.360443622920517, 9.968576709796674, 8.036968576709796, 2.9296296296296296, 8.264814814814814, 8.190740740740742, 8.792592592592593, 9.537037037037036, 1.9592592592592593, 8.396296296296295, 7.968518518518518, 6.572222222222222, 5.85]
auc [0.8836705399863294, 0.7458646616541353, 0.751995898838004, 0.7895898838004102, 0.7744292549555707, 0.9019551544982226, 0.7429997265518186, 0.8088118676510802, 0.8169674596663933, 0.8018731200437518, 0.8793574846206426, 0.7838892686261107, 0.7921189336978811, 0.752734107997266, 0.8172317156527683, 0.8669666867899676, 0.7654492069590034, 0.7282393941057022, 0.7164672630481312, 0.736869546128094, 0.9019126282860436,

In [25]:
# Ten equal partitions of 5 chunks each.
split = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
run_local(split)

Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
Training time:  0:00:42.736208
mse [4.565619223659889, 5.844731977818854, 4.959334565619224, 5.55822550831793, 4.728280961182994, 5.354898336414048, 5.384473197781886, 5.382624768946395, 6.282809611829944, 5.2865064695009245, 5.787430683918669, 5.410351201478743, 5.693160813308688, 5.188539741219963, 5.157116451016636, 4.959334565619224, 5.11275415896488, 5.3345656192236595, 5.693160813308688, 5.0499075785582255, 5.809611829944547, 5.280961182994455, 6.2255083179297594, 5.2199630314232905, 5.404805914972274, 5.377079482439926, 5.979667282809612, 4.717190388170056, 5.656192236598891, 4.828096118299445, 6.227777777777778, 6.8277777777777775, 6.985185185185185, 5.916666666666667, 5.574074074074074, 5.866666666666666, 5.5018518518518515, 6.203703703703703, 5.261111111111111, 6.127777777777778, 5.688888888888889, 6.074074074074074, 5.207407407407407, 5.333333333333333, 4.555555555555555, 5.48703

In [26]:
# Ten partitions: one of 30 chunks and nine small ones. Both 0.05 (2.5 -> 2) and
# 0.04 (2.0) map to 2 chunks, so 30 + 9 * 2 = 48 and 2 of the 50 chunks go unused.
split = [0.6, 0.05, 0.05, 0.05, 0.05, 0.04, 0.04, 0.04, 0.04, 0.04]
run_local(split)

Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [0.6, 0.05, 0.05, 0.05, 0.05, 0.04, 0.04, 0.04, 0.04, 0.04]
Training time:  0:01:30.793916
mse [3.0055452865064693, 7.805914972273568, 8.719038817005545, 7.67467652495379, 8.876155268022181, 7.502772643253235, 11.645101663585953, 10.475046210720887, 8.950092421441774, 7.271719038817006, 2.4177449168207024, 7.44547134935305, 9.975970425138632, 6.9741219963031424, 7.711645101663586, 7.800369685767098, 7.99815157116451, 7.602587800369686, 7.963031423290204, 8.66173752310536, 3.120147874306839, 9.062846580406655, 7.66728280961183, 7.377079482439926, 7.3345656192236595, 6.850277264325324, 8.850277264325323, 9.360443622920517, 9.968576709796674, 8.036968576709796, 2.9092592592592594, 7.192592592592592, 8.87037037037037, 7.853703703703704, 9.783333333333333, 9.11111111111111, 7.683333333333334, 8.190740740740742, 8.792592592592593, 9.537037037037036, 2.716666666666667, 7.07037037037037, 7.618518518518519, 8.72037037037037, 6.587037037037037, 7.231

In [27]:
# Ten partitions: one of 40 chunks and nine single-chunk ones. 0.03 * 50 = 1.5 is
# truncated to 1 chunk, so 40 + 9 * 1 = 49 and 1 of the 50 chunks goes unused.
split = [0.8, 0.03, 0.03, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02]
run_local(split)

Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [0.8, 0.03, 0.03, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02]
Training time:  0:01:12.116310
mse [2.523105360443623, 13.286506469500925, 14.471349353049908, 12.950092421441774, 13.484288354898336, 11.796672828096119, 10.571164510166358, 9.0, 9.33271719038817, 10.907578558225508, 2.22365988909427, 8.399260628465804, 13.98336414048059, 11.7818853974122, 12.744916820702404, 11.11275415896488, 15.425138632162662, 8.23475046210721, 10.861367837338262, 7.707948243992607, 3.157116451016636, 8.940850277264325, 12.595194085027726, 13.595194085027726, 14.1090573012939, 10.968576709796674, 12.090573012939002, 9.11645101663586, 10.351201478743068, 14.353049907578558, 2.7925925925925927, 8.622222222222222, 10.62037037037037, 9.57037037037037, 10.603703703703705, 11.622222222222222, 10.533333333333333, 14.372222222222222, 14.251851851851852, 9.08148148148148, 1.9796296296296296, 11.738888888888889, 12.075925925925926, 12.424074074074074, 10.035185185185185