In [5]:
# @title Import Libraries
import pandas as pd
import time
from realtabformer import REaLTabFormer
from transformers import GPT2Config
from src.data_processing import csv_data_split

In [6]:
# Read the breast-cancer CSV and unpack the three partitions returned by
# csv_data_split (train, test and sample); the last line displays the train set.
csv_path = "../data/breast-cancer-wisconsin.csv"
train_data, test_data, sample_data = csv_data_split(csv_path)
train_data

Unnamed: 0,ID,CT,UCSi,UCSh,Madh,SECS,BN,BC,NN,Mi,Class
198,1017061,1,1,1,1,0,1,3,1,1,0
359,501111,5,1,1,0,0,1,0,1,1,0
481,1181567,1,1,1,1,1,1,1,1,1,0
125,1177007,3,1,1,1,0,1,3,1,1,0
598,1016631,0,3,1,1,0,1,0,1,1,0
...,...,...,...,...,...,...,...,...,...,...,...
485,1001565,6,1,1,1,0,1,3,1,1,0
260,301107,10,8,8,0,3,1,8,7,8,1
364,657753,3,1,1,1,3,1,0,0,1,0
623,1077790,5,1,1,3,0,1,1,1,1,0


In [7]:
# Empty log table: one row per trained model, holding its name and the
# number of seconds the training run took.
results = pd.DataFrame(
    columns=[
        "Model",
        "Time (s)",
    ]
)

In [8]:
# function to train the models
def fit_and_track(model, data, model_name, num_bootstrap=20, target_col="Class",
                  results_table=None):
    """Fit ``model`` on ``data``, record how long training took, then save it.

    Parameters
    ----------
    model
        Object exposing ``fit(data, num_bootstrap=..., target_col=...)`` and
        ``save(path)``.
    data
        Training data, forwarded unchanged to ``model.fit``.
    model_name : str
        Label written to the results table and used as the directory name
        under ``../models/``.
    num_bootstrap : int, default 20
        Forwarded to ``model.fit``.
    target_col : str, default "Class"
        Forwarded to ``model.fit``.
    results_table : pandas.DataFrame, optional
        Table that receives the ``[model_name, elapsed_seconds]`` row. When
        omitted, the notebook-level ``results`` DataFrame is used.

    Returns
    -------
    None
        The function works through side effects: it prints the model name and
        training time, appends a row to the results table, and saves the model.
    """
    if results_table is None:
        # Fall back to the table created in an earlier notebook cell.
        results_table = results

    # perf_counter is monotonic, so the measurement cannot be skewed by
    # system clock adjustments that happen during a long training run.
    start_time = time.perf_counter()
    model.fit(data, num_bootstrap=num_bootstrap, target_col=target_col)
    elapsed_time = time.perf_counter() - start_time

    print(f"Model: {model_name}")
    print(f"Time (s): {elapsed_time:.2f}")

    # Append as the next row; len() is the next free label as long as the
    # table keeps the 0..n-1 labels this function itself produces.
    results_table.loc[len(results_table)] = [model_name, elapsed_time]

    model.save(f"../models/{model_name}")

In [9]:
# Training hyperparameters shared by every REaLTabFormer variant below. They
# live in one place so the three models can only differ in their GPT-2
# architecture and cannot drift apart through copy-paste edits.
rtf_common_kwargs = dict(
    model_type="tabular",
    epochs=50,
    batch_size=8,
    mask_rate=0.15,
)

# Small configuration: explicit GPT-2 sizes (512-wide embeddings, 4 layers, 8 heads)
config_small = GPT2Config(
    n_embd=512,
    n_layer=4,
    n_head=8
)
rtf_model_small = REaLTabFormer(tabular_config=config_small, **rtf_common_kwargs)

# Large configuration: GPT2Config with no overrides, i.e. the library's default sizes
config_large = GPT2Config()
rtf_model_large = REaLTabFormer(tabular_config=config_large, **rtf_common_kwargs)

# Regular configuration: no tabular_config is passed, so REaLTabFormer
# falls back to whatever configuration it builds by default
rtf_model_reg = REaLTabFormer(**rtf_common_kwargs)


# Fit models and track performance
# NOTE(review): the calls below are disabled, presumably because training is
# expensive -- confirm before re-enabling.
# fit_and_track(rtf_model_small, train_data, "rtf_small_test")
# fit_and_track(rtf_model_reg, train_data, "rtf_regular")
# fit_and_track(rtf_model_large, train_data, "rtf_large")


