In [1]:
# @title Import Libraries
import pandas as pd
import time
from realtabformer import REaLTabFormer
from transformers import GPT2Config
from src.data_processing import csv_data_split

In [2]:
# Load the breast-cancer CSV via the project helper, which yields three frames
# (train / test / sample — exact split semantics live in src.data_processing).
# Only `train_data` is used to fit the generators further down.
# NOTE(review): the displayed frame includes an `ID` column; if it is a record
# identifier, consider dropping it before fitting a generative model so it is
# not learned and reproduced in synthetic rows.
DATA_CSV = "../data/breast-cancer-wisconsin.csv"
train_data, test_data, sample_data = csv_data_split(DATA_CSV)
train_data

Unnamed: 0,ID,CT,UCSi,UCSh,Madh,SECS,BN,BC,NN,Mi,Class
198,1017061,1,1,1,1,0,1,3,1,1,0
359,501111,5,1,1,0,0,1,0,1,1,0
481,1181567,1,1,1,1,1,1,1,1,1,0
125,1177007,3,1,1,1,0,1,3,1,1,0
598,1016631,0,3,1,1,0,1,0,1,1,0
...,...,...,...,...,...,...,...,...,...,...,...
485,1001565,6,1,1,1,0,1,3,1,1,0
260,301107,10,8,8,0,3,1,8,7,8,1
364,657753,3,1,1,1,3,1,0,0,1,0
623,1077790,5,1,1,3,0,1,1,1,1,0


In [3]:
# Timing log: fit_and_track() appends one (model name, seconds) row per fit.
RESULT_COLUMNS = ["Model", "Time (s)"]
results = pd.DataFrame(columns=RESULT_COLUMNS)
results

Unnamed: 0,Model,Time (s)


In [4]:
def fit_and_track(model, data, model_name, num_bootstrap=20, target_col="Class",
                  save_dir="../models", results_table=None):
    """Fit a model, record its training time in a results table, then save it.

    Parameters
    ----------
    model : object
        Anything exposing ``fit(data, num_bootstrap=..., target_col=...)`` and
        ``save(path)`` (here: a REaLTabFormer instance).
    data : pandas.DataFrame
        Training data, passed unchanged to ``model.fit``.
    model_name : str
        Label written to the results table and used as the save sub-path.
    num_bootstrap : int, default 20
        Forwarded to ``model.fit``.
    target_col : str, default "Class"
        Forwarded to ``model.fit``.
    save_dir : str, default "../models"
        Directory prefix; the model is saved to ``f"{save_dir}/{model_name}"``.
    results_table : pandas.DataFrame, optional
        Two-column table ``[model name, seconds]`` to append to. Defaults to the
        notebook-level ``results`` frame created in an earlier cell.

    Returns
    -------
    float
        Elapsed training time in seconds (only ``model.fit`` is timed).
    """
    if results_table is None:
        results_table = results  # notebook-global timing log

    # perf_counter is monotonic and high-resolution; time.time() is wall-clock
    # and can jump if the system clock is adjusted mid-training.
    start_time = time.perf_counter()
    model.fit(data, num_bootstrap=num_bootstrap, target_col=target_col)
    elapsed_time = time.perf_counter() - start_time

    print(f"Model: {model_name}")

    results_table.loc[len(results_table)] = [model_name, elapsed_time]
    # Appending a mixed [str, float] row to an empty object-dtype frame can leave
    # the time column as object dtype (the saved table shows uneven decimals);
    # cast so sorting / plotting / describe() treat it as numeric.
    results_table["Time (s)"] = results_table["Time (s)"].astype("float64")

    model.save(f"{save_dir}/{model_name}")
    return elapsed_time

In [5]:
# Training hyper-parameters shared by all three model sizes, so the runs differ
# only in the GPT-2 backbone.
EPOCHS = 50
MASK_RATE = 0.15

# Small custom backbone.
config_small = GPT2Config(
    n_embd=512,
    n_layer=4,
    n_head=8
)

# transformers' library-default GPT2Config.
config_large = GPT2Config()


def build_rtf(tabular_config=None, epochs=EPOCHS, mask_rate=MASK_RATE):
    """Create a tabular REaLTabFormer with the shared hyper-parameters.

    When ``tabular_config`` is None the argument is omitted entirely, so
    REaLTabFormer uses its own built-in default backbone (the "regular" model).
    """
    kwargs = {"model_type": "tabular", "epochs": epochs, "mask_rate": mask_rate}
    if tabular_config is not None:
        kwargs["tabular_config"] = tabular_config
    return REaLTabFormer(**kwargs)


rtf_model_small = build_rtf(config_small)
rtf_model_reg = build_rtf()
rtf_model_large = build_rtf(config_large)

# Fit models and track performance; each run appends a row to `results`.
# NOTE: training can stop early when the critic stops improving (the saved log
# shows this for rtf_large), so the times are not necessarily for equal epochs.
for rtf_model, rtf_name in [
    (rtf_model_small, "rtf_small"),
    (rtf_model_reg, "rtf_regular"),
    (rtf_model_large, "rtf_large"),
]:
    fit_and_track(rtf_model, train_data, rtf_name)


# Display the results
results

Computing the sensitivity threshold...
Using parallel computation!!!




Bootstrap round:   0%|          | 0/20 [00:00<?, ?it/s]

Sensitivity threshold summary:
count    20.000000
mean     -0.001361
std       0.011289
min      -0.018333
25%      -0.009722
50%      -0.001111
75%       0.006806
max       0.022222
dtype: float64
Sensitivity threshold: 0.014833333333333339 qt_max: 0.05


Map:   0%|          | 0/546 [00:00<?, ? examples/s]

Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 5,                     sensitivity_threshold: 0.014833333333333339,                         val_sensitivity: -0.015481481481481478,                             val_sensitivities: [-0.01666666666666667, -0.02388888888888889, -0.019444444444444445, -0.016666666666666666, -0.005555555555555556, -0.025, -0.020555555555555556, -0.005000000000000001, -0.011666666666666667, -0.011111111111111112, -0.005, -0.015, -0.020555555555555556, -0.016666666666666666, -0.019444444444444445]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 10,                     sensitivity_threshold: 0.014833333333333339,                         val_sensitivity: -0.017370370370370373,                             val_sensitivities: [-0.018333333333333333, -0.018333333333333333, -0.01611111111111111, -0.021666666666666667, -0.015, -0.02388888888888889, -0.020555555555555556, -0.007777777777777779, -0.015000000000000001, -0.015555555555555557, -0.0022222222222222235, -0.02388888888888889, -0.01777777777777778, -0.021666666666666667, -0.02277777777777778]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 15,                     sensitivity_threshold: 0.014833333333333339,                         val_sensitivity: -0.007407407407407408,                             val_sensitivities: [-0.012777777777777779, -0.018333333333333333, -0.012222222222222223, -0.01611111111111111, -0.0027777777777777783, -0.017222222222222222, -0.01, 0.01888888888888889, 0.008333333333333333, 0.01611111111111111, 0.006666666666666666, -0.017222222222222222, -0.01666666666666667, -0.017777777777777778, -0.02]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 20,                     sensitivity_threshold: 0.014833333333333339,                         val_sensitivity: -0.015148148148148145,                             val_sensitivities: [-0.011666666666666667, -0.02388888888888889, -0.014444444444444446, -0.014444444444444446, 0.007222222222222221, -0.02388888888888889, -0.02, -0.004444444444444445, -0.009444444444444445, -0.014444444444444446, -0.003333333333333333, -0.025, -0.020555555555555556, -0.02388888888888889, -0.025]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 25,                     sensitivity_threshold: 0.014833333333333339,                         val_sensitivity: -0.014111111111111112,                             val_sensitivities: [-0.01388888888888889, -0.020555555555555556, -0.021111111111111112, -0.015555555555555555, -0.015, -0.02277777777777778, -0.013333333333333334, -0.006111111111111111, -0.0011111111111111118, 0.0022222222222222214, -0.005555555555555556, -0.020555555555555556, -0.019444444444444445, -0.020555555555555556, -0.018333333333333333]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 30,                     sensitivity_threshold: 0.014833333333333339,                         val_sensitivity: -0.005296296296296296,                             val_sensitivities: [-0.011111111111111112, -0.018333333333333333, -0.011111111111111112, -0.011111111111111112, 0.009444444444444445, -0.02277777777777778, -0.008333333333333333, 0.011111111111111112, 0.0022222222222222227, 0.0061111111111111106, 0.02, -0.006666666666666668, -0.012222222222222223, -0.014444444444444444, -0.012222222222222221]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 35,                     sensitivity_threshold: 0.014833333333333339,                         val_sensitivity: -0.008518518518518517,                             val_sensitivities: [-0.011111111111111112, -0.013888888888888888, -0.01388888888888889, -0.012777777777777779, -0.0027777777777777775, -0.025, -0.011666666666666667, 0.0077777777777777776, 0.0, -0.003333333333333334, 0.012777777777777777, -0.01611111111111111, -0.011111111111111112, -0.012777777777777779, -0.013888888888888888]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 40,                     sensitivity_threshold: 0.014833333333333339,                         val_sensitivity: -0.008,                             val_sensitivities: [-0.0061111111111111106, -0.012777777777777777, -0.009444444444444445, -0.005555555555555556, 0.008888888888888889, -0.021666666666666667, -0.01388888888888889, 0.001666666666666667, 0.0005555555555555552, -0.006666666666666666, 0.0027777777777777783, -0.013333333333333332, -0.018333333333333333, -0.012222222222222223, -0.013888888888888888]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 45,                     sensitivity_threshold: 0.014833333333333339,                         val_sensitivity: -0.009518518518518518,                             val_sensitivities: [-0.0077777777777777776, -0.014444444444444444, -0.010555555555555556, -0.006666666666666667, 0.005000000000000001, -0.021111111111111112, -0.013333333333333334, -0.0016666666666666679, -0.0061111111111111106, -0.0016666666666666679, -0.0011111111111111118, -0.020555555555555556, -0.01611111111111111, -0.010555555555555556, -0.01611111111111111]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 50,                     sensitivity_threshold: 0.014833333333333339,                         val_sensitivity: -0.012851851851851852,                             val_sensitivities: [-0.01611111111111111, -0.021666666666666667, -0.014444444444444444, -0.013888888888888888, -0.005000000000000001, -0.025, -0.01611111111111111, -0.0022222222222222227, -0.006666666666666668, -0.008888888888888889, 0.005555555555555555, -0.017777777777777778, -0.012777777777777779, -0.017222222222222222, -0.020555555555555556]
Model: rtf_small
Copying artefacts from: best-disc-model
Copying artefacts from: mean-best-disc-model
Copying artefacts from: not-best-disc-model
Copying artefacts from: last-epoch-model
Computing the sensitivity threshold...
Using parallel computation!!!




Bootstrap round:   0%|          | 0/20 [00:00<?, ?it/s]

Sensitivity threshold summary:
count    20.000000
mean      0.004500
std       0.013377
min      -0.021667
25%      -0.004861
50%       0.002500
75%       0.013889
max       0.032222
dtype: float64
Sensitivity threshold: 0.02694444444444445 qt_max: 0.05


Map:   0%|          | 0/546 [00:00<?, ? examples/s]

Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 5,                     sensitivity_threshold: 0.02694444444444445,                         val_sensitivity: -0.01114814814814815,                             val_sensitivities: [-0.0005555555555555565, -0.016666666666666666, -0.008333333333333333, -0.006666666666666666, 0.004999999999999999, -0.020555555555555556, -0.012777777777777779, -0.0022222222222222222, -0.01, -0.0038888888888888888, -0.0027777777777777775, -0.020555555555555556, -0.018333333333333333, -0.02388888888888889, -0.025]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 10,                     sensitivity_threshold: 0.02694444444444445,                         val_sensitivity: -0.0035185185185185193,                             val_sensitivities: [0.0005555555555555557, -0.018333333333333333, -0.006666666666666668, -0.012777777777777779, 0.012222222222222221, -0.02388888888888889, -0.007777777777777779, 0.007777777777777778, -0.0027777777777777783, 0.0027777777777777775, 0.020555555555555556, -0.009444444444444445, -0.005000000000000001, -0.0016666666666666679, -0.008333333333333333]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 15,                     sensitivity_threshold: 0.02694444444444445,                         val_sensitivity: -0.007814814814814814,                             val_sensitivities: [-0.0027777777777777775, -0.012777777777777777, -0.015, -0.008333333333333333, 0.008888888888888889, -0.02388888888888889, -0.018333333333333333, -0.01, -0.016666666666666666, -0.01, 0.012777777777777777, -0.007222222222222222, -0.006111111111111112, -0.007222222222222222, -0.0005555555555555557]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 20,                     sensitivity_threshold: 0.02694444444444445,                         val_sensitivity: -0.011259259259259257,                             val_sensitivities: [-0.011666666666666667, -0.017222222222222222, -0.012777777777777779, -0.01388888888888889, 0.0005555555555555544, -0.018333333333333333, -0.011666666666666667, -0.0005555555555555561, -0.005555555555555555, -0.0077777777777777776, 0.0011111111111111113, -0.020555555555555556, -0.012777777777777779, -0.01611111111111111, -0.021666666666666667]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 25,                     sensitivity_threshold: 0.02694444444444445,                         val_sensitivity: -0.002148148148148148,                             val_sensitivities: [-0.0061111111111111106, -0.015, -0.006666666666666666, -0.006111111111111111, 0.010555555555555556, -0.008333333333333333, -0.0038888888888888896, 0.010555555555555556, 0.011111111111111112, 0.009444444444444445, 0.008888888888888887, -0.0022222222222222227, -0.015555555555555557, -0.007222222222222222, -0.011666666666666665]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 30,                     sensitivity_threshold: 0.02694444444444445,                         val_sensitivity: 0.004703703703703704,                             val_sensitivities: [0.010555555555555556, 0.0033333333333333322, 0.0077777777777777776, 0.008333333333333333, 0.043888888888888894, -0.01611111111111111, -0.011111111111111112, 0.0033333333333333327, -0.0038888888888888896, -0.007222222222222222, 0.024444444444444442, 0.008888888888888889, -0.0027777777777777783, 0.006111111111111112, -0.005000000000000001]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 35,                     sensitivity_threshold: 0.02694444444444445,                         val_sensitivity: 0.005851851851851852,                             val_sensitivities: [0.0005555555555555557, -0.01, 0.006666666666666666, 0.009444444444444443, 0.024444444444444446, -0.008333333333333333, 0.0, 0.0077777777777777776, 0.011111111111111112, 0.01388888888888889, 0.03666666666666667, 0.0044444444444444444, -0.007222222222222224, 0.0033333333333333322, -0.004999999999999999]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 40,                     sensitivity_threshold: 0.02694444444444445,                         val_sensitivity: 0.022259259259259263,                             val_sensitivities: [0.03055555555555555, 0.0038888888888888888, 0.021111111111111112, 0.018333333333333333, 0.04666666666666666, 0.000555555555555555, 0.021666666666666667, 0.04833333333333333, 0.03777777777777778, 0.029444444444444447, 0.03666666666666667, 0.01, 0.01, 0.012222222222222221, 0.006666666666666666]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 45,                     sensitivity_threshold: 0.02694444444444445,                         val_sensitivity: 0.01877777777777778,                             val_sensitivities: [0.025555555555555554, -0.0016666666666666666, 0.02333333333333333, 0.01, 0.051111111111111114, -0.01, 0.004444444444444444, 0.031111111111111114, 0.01111111111111111, 0.025555555555555554, 0.06055555555555556, 0.008333333333333331, 0.012777777777777779, 0.02388888888888889, 0.005555555555555555]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 50,                     sensitivity_threshold: 0.02694444444444445,                         val_sensitivity: 0.02196296296296296,                             val_sensitivities: [0.035, 0.01611111111111111, 0.03222222222222222, 0.035, 0.05666666666666666, -0.0016666666666666666, 0.0077777777777777776, 0.020555555555555556, 0.021666666666666667, 0.015000000000000001, 0.052777777777777785, 0.010555555555555554, 0.0077777777777777776, 0.013333333333333332, 0.006666666666666667]
Model: rtf_regular
Copying artefacts from: best-disc-model
Copying artefacts from: mean-best-disc-model
Copying artefacts from: not-best-disc-model
Copying artefacts from: last-epoch-model
Computing the sensitivity threshold...
Using parallel computation!!!




Bootstrap round:   0%|          | 0/20 [00:00<?, ?it/s]

Sensitivity threshold summary:
count    20.000000
mean      0.005917
std       0.014750
min      -0.016111
25%      -0.006806
50%       0.005833
75%       0.015000
max       0.039444
dtype: float64
Sensitivity threshold: 0.02466666666666668 qt_max: 0.05


Map:   0%|          | 0/546 [00:00<?, ? examples/s]

Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 5,                     sensitivity_threshold: 0.02466666666666668,                         val_sensitivity: -0.01474074074074074,                             val_sensitivities: [-0.02, -0.025, -0.021111111111111112, -0.020555555555555556, -0.012777777777777779, -0.02277777777777778, -0.015000000000000001, 0.0027777777777777775, -0.007222222222222224, 0.002222222222222222, -0.0022222222222222227, -0.021666666666666667, -0.01666666666666667, -0.019444444444444445, -0.021666666666666667]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 10,                     sensitivity_threshold: 0.02466666666666668,                         val_sensitivity: -0.014518518518518517,                             val_sensitivities: [-0.015000000000000001, -0.02388888888888889, -0.018333333333333333, -0.017222222222222222, -0.006666666666666666, -0.02277777777777778, -0.014444444444444446, -8.673617379884035e-19, -0.008333333333333333, 0.0011111111111111105, -0.008888888888888889, -0.025, -0.020555555555555556, -0.01611111111111111, -0.021666666666666667]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 15,                     sensitivity_threshold: 0.02466666666666668,                         val_sensitivity: -0.007703703703703705,                             val_sensitivities: [-0.01666666666666667, -0.02277777777777778, -0.02, -0.015, -0.0033333333333333322, -0.017222222222222222, -0.008888888888888889, 0.011111111111111112, 0.0027777777777777766, 0.0038888888888888883, 0.013333333333333334, -0.008333333333333333, -0.01388888888888889, -0.009444444444444445, -0.011111111111111112]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 20,                     sensitivity_threshold: 0.02466666666666668,                         val_sensitivity: -0.007629629629629629,                             val_sensitivities: [-0.005555555555555557, -0.01888888888888889, -0.011666666666666667, -0.008888888888888889, 0.0066666666666666645, -0.019444444444444445, -0.011111111111111112, -0.0005555555555555557, 0.0022222222222222214, 0.002777777777777777, 0.008333333333333333, -0.014444444444444444, -0.015555555555555557, -0.012222222222222221, -0.01611111111111111]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 25,                     sensitivity_threshold: 0.02466666666666668,                         val_sensitivity: -0.0029629629629629637,                             val_sensitivities: [0.0044444444444444444, -0.0022222222222222227, -0.0022222222222222235, 0.003888888888888888, 0.03277777777777778, -0.018333333333333333, -0.014444444444444446, 0.0011111111111111105, 0.0011111111111111105, -0.005, 0.006666666666666666, -0.013888888888888888, -0.015000000000000001, -0.01, -0.013333333333333332]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 30,                     sensitivity_threshold: 0.02466666666666668,                         val_sensitivity: 0.004518518518518518,                             val_sensitivities: [-0.003333333333333334, -0.014444444444444446, -0.0016666666666666679, -0.002222222222222222, 0.013333333333333332, -0.012777777777777777, -0.003333333333333334, 0.01611111111111111, 0.011666666666666667, 0.01388888888888889, 0.049444444444444444, -0.001666666666666667, 0.001666666666666667, 0.0016666666666666655, -0.0005555555555555557]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 35,                     sensitivity_threshold: 0.02466666666666668,                         val_sensitivity: 0.015111111111111112,                             val_sensitivities: [0.006666666666666666, -0.0011111111111111115, -0.0027777777777777783, 0.009444444444444445, 0.017222222222222222, -0.006666666666666666, 0.0016666666666666666, 0.02611111111111111, 0.018888888888888893, 0.022222222222222223, 0.07, 0.016666666666666666, 0.021666666666666667, 0.02611111111111111, 0.0005555555555555557]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Saving not-best model...
Critic round: 40,                     sensitivity_threshold: 0.02466666666666668,                         val_sensitivity: 0.03240740740740741,                             val_sensitivities: [0.028333333333333332, 0.017222222222222222, 0.031111111111111114, 0.024999999999999998, 0.050555555555555555, 0.006666666666666666, 0.017777777777777778, 0.04666666666666667, 0.04, 0.04555555555555556, 0.052222222222222225, 0.04666666666666667, 0.015555555555555555, 0.03666666666666667, 0.026111111111111113]


Step,Training Loss


  0%|          | 0/270 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 384 samples generated. Sampling efficiency is: 100.0000%
Critic round: 45,                     sensitivity_threshold: 0.02466666666666668,                         val_sensitivity: 0.040185185185185185,                             val_sensitivities: [0.03833333333333334, 0.020555555555555556, 0.03111111111111111, 0.016666666666666663, 0.04777777777777778, 0.02277777777777778, 0.032777777777777774, 0.05388888888888889, 0.05444444444444445, 0.05888888888888889, 0.08555555555555557, 0.04444444444444445, 0.03111111111111111, 0.043333333333333335, 0.02111111111111111]
Stopping training, no improvement in critic...
Model: rtf_large
Copying artefacts from: best-disc-model
Copying artefacts from: mean-best-disc-model
Copying artefacts from: not-best-disc-model
Copying artefacts from: last-epoch-model


Unnamed: 0,Model,Time (s)
0,rtf_small,794.140514
1,rtf_regular,2143.99205
2,rtf_large,3969.005209
