In [None]:
# load packages
import pandas as pd
import numpy as np
import random
import shap
import logging
# Root-logger setup: records at INFO and above are emitted, each one formatted as
# "<timestamp> - <level name> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

In [None]:
# Progress marker emitted before the utils / models / datasets star-imports run.
logging.info("Importing modules")

In [None]:
from utils import *
from models import *
from datasets import *

In [None]:
import warnings
warnings.filterwarnings("ignore", category=pd.errors.PerformanceWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

In [None]:
# Input locations consumed by data_preparation below.
# Adjust these paths to match your local layout.
quality_df_dir = "./results/quality_scores_per_subject.csv"  # per-subject quality scores (CSV)
features_dir = "dataset_sample/features_df/"  # features directory (note the trailing slash)
info_dir = "dataset_sample/participant_info.csv"  # participant info (CSV)

In [None]:
logging.info("Preparing the data")
# data_preparation is provided by the star-imports above; its three results are
# unpacked into the cleaned frame, the feature list and the retained subject IDs.
# NOTE(review): threshold=0.2 is presumably the quality cut-off that decides which
# subjects land in good_quality_sids — confirm against data_preparation's definition.
clean_df, new_features, good_quality_sids = data_preparation(
    info_dir=info_dir,
    features_dir=features_dir,
    quality_df_dir=quality_df_dir,
    threshold=0.2,
)

In [None]:
logging.info("Splitting data into train, validation, and test sets")
# split_data (from the star-imports) yields the frame and the feature list that
# every later modelling step uses.
SW_df, final_features = split_data(
    clean_df,
    good_quality_sids,
    new_features,
)

In [None]:
# Subject-level split: a fixed number of subjects for training and validation,
# every remaining subject goes to the test partition.
SPLIT_SEED = 0         # seed for the global `random` generator
N_TRAIN_SUBJECTS = 56  # subjects sampled for training
N_VAL_SUBJECTS = 8     # subjects sampled for validation

# list() preserves the incoming order (so the sampled split is unchanged for a list)
# and also accepts a set, which random.sample rejects on Python 3.11+.
candidate_sids = list(good_quality_sids)
if len(candidate_sids) < N_TRAIN_SUBJECTS + N_VAL_SUBJECTS:
    # Fail early with an actionable message instead of random.sample's generic
    # "Sample larger than population" error.
    raise ValueError(
        f"Need at least {N_TRAIN_SUBJECTS + N_VAL_SUBJECTS} good-quality subjects "
        f"for the train/validation split, got {len(candidate_sids)}"
    )

# NOTE: this seeds the *global* generator on purpose; any later code that draws
# from `random` keeps seeing the same seeded stream as before.
random.seed(SPLIT_SEED)
train_sids = random.sample(candidate_sids, N_TRAIN_SUBJECTS)
train_sid_set = set(train_sids)  # O(1) membership checks
remaining_sids = [subj for subj in candidate_sids if subj not in train_sid_set]
val_sids = random.sample(remaining_sids, N_VAL_SUBJECTS)
val_sid_set = set(val_sids)
test_sids = [subj for subj in remaining_sids if subj not in val_sid_set]

if not test_sids:
    # Every subject was consumed by train/validation; later test metrics would be empty.
    logging.warning("No subjects left for the test set after the train/validation split")

In [None]:
# Candidate grouping variables; one of them is selected for this run.
group_variables = [
    "AHI_Severity",
    "Obesity",
]
# get_variable selects by position:
#   idx=0 -> ['AHI_Severity'] (first entry)
#   idx=1 -> ['Obesity']      (second entry)
group_variable = get_variable(group_variables, idx=0)

In [None]:
# Build (features, labels, group) for each partition with the same call, in the
# order train -> validation -> test.
# NOTE(review): train_test_split is presumably the project helper from the
# star-imports (it is called with subject IDs and yields three values) — confirm.
(
    (X_train, y_train, group_train),
    (X_val, y_val, group_val),
    (X_test, y_test, group_test),
) = (
    train_test_split(SW_df, partition_sids, final_features, group_variable)
    for partition_sids in (train_sids, val_sids, test_sids)
)

In [None]:
logging.info("Resampling the training data")
# Only the training partition is resampled; validation and test stay untouched.
X_train_resampled, y_train_resampled, group_train_resampled = resample_data(
    X_train,
    y_train,
    group_train,
    group_variable,
)

In [None]:
logging.info("Running LightGBM model")
# The engine receives the resampled training data together with the validation split.
final_lgb_model = LightGBM_engine(
    X_train_resampled,
    y_train_resampled,
    X_val,
    y_val,
)

In [None]:
logging.info("Calculating training scores for LightGBM model")
# Probabilities are computed for the training subjects with the "lgb" model type;
# the results table is then built from the original (non-resampled) training split.
prob_ls_train, len_train, true_ls_train = compute_probabilities(
    train_sids,
    SW_df,
    final_features,
    "lgb",
    final_lgb_model,
    group_variable,
)
lgb_train_results_df = LightGBM_result(
    final_lgb_model, X_train, y_train, prob_ls_train, true_ls_train
)

In [None]:
logging.info("Calculating testing scores for LightGBM model")
# Same procedure as for the training scores, applied to the held-out test subjects.
prob_ls_test, len_test, true_ls_test = compute_probabilities(
    test_sids,
    SW_df,
    final_features,
    "lgb",
    final_lgb_model,
    group_variable,
)
lgb_test_results_df = LightGBM_result(
    final_lgb_model, X_test, y_test, prob_ls_test, true_ls_test
)

In [None]:
logging.info("Identifying best features using SHAP")
# Tree explainer over the fitted LightGBM model; SHAP values are computed on the
# original (non-resampled) training features.
explainer = shap.TreeExplainer(final_lgb_model)
shap_values = explainer.shap_values(X_train)
# The bar summary plot of these values is currently disabled:
# shap.summary_plot(shap_values, X_train, plot_type="bar", feature_names=final_features)

# --- LSTM post-processing on top of the LightGBM outputs ---
logging.info("Creating train data for LSTM")
# Training loader built from the LightGBM training-subject outputs.
dataloader_train = LSTM_dataloader(prob_ls_train, len_train, true_ls_train, batch_size=32)

logging.info("Running LSTM model")
LSTM_model = LSTM_engine(
    dataloader_train,
    num_epoch=300,
    hidden_layer_size=32,
    learning_rate=0.001,
)

logging.info("Testing LSTM model")
# The test loader uses batch_size=1.
dataloader_test = LSTM_dataloader(prob_ls_test, len_test, true_ls_test, batch_size=1)

lgb_lstm_test_results_df = LSTM_eval(
    LSTM_model,
    dataloader_test,
    true_ls_test,
    'LightGBM_LSTM',
)

In [None]:
logging.info("Creating Transformer dataset for LightGBM")
# NOTE: these assignments replace the LSTM loaders held in dataloader_train / dataloader_test.
dataloader_train = Transformer_dataloader(
    prob_ls_train,
    len_train,
    true_ls_train,
    batch_size=16,
)
dataloader_test = Transformer_dataloader(
    prob_ls_test,
    len_test,
    true_ls_test,
    batch_size=1,
)

In [None]:
logging.info("Running Transformer model for LightGBM post-processing")
# Train with the engine's default hyper-parameters (only num_epoch is set here),
# then evaluate on the test loader.
transformer_model = Transformer_engine(
    dataloader_train,
    num_epoch=300,
)
lgb_transformer_test_results_df = Transformer_eval(
    transformer_model,
    dataloader_test,
    true_ls_test,
    'LightGBM_Transformer',
)

logging.info("Running GPBoost model")
# Unlike the LightGBM engine, GPBoost also receives the group columns of both splits.
final_gpb_model = GPBoost_engine(
    X_train_resampled,
    group_train_resampled,
    y_train_resampled,
    X_val,
    y_val,
    group_val,
)

# NOTE: from here on prob_ls_* / len_* / true_ls_* hold the GPBoost outputs and
# overwrite the LightGBM ones computed earlier.
logging.info("Calculating training scores for GPBoost model")
prob_ls_train, len_train, true_ls_train = compute_probabilities(
    train_sids,
    SW_df,
    final_features,
    'gpb',
    final_gpb_model,
    group_variable,
)
gpb_train_results_df = GPBoost_result(
    final_gpb_model, X_train, y_train, group_train, prob_ls_train, true_ls_train
)

logging.info("Calculating testing scores for GPBoost model")
prob_ls_test, len_test, true_ls_test = compute_probabilities(
    test_sids,
    SW_df,
    final_features,
    'gpb',
    final_gpb_model,
    group_variable,
)
gpb_test_results_df = GPBoost_result(
    final_gpb_model, X_test, y_test, group_test, prob_ls_test, true_ls_test
)

logging.info("Creating LSTM dataset for GPBoost")
# Loaders built from the GPBoost outputs: batch_size=32 for training, 1 for testing.
dataloader_train = LSTM_dataloader(prob_ls_train, len_train, true_ls_train, batch_size=32)
dataloader_test = LSTM_dataloader(prob_ls_test, len_test, true_ls_test, batch_size=1)

logging.info("Running LSTM model for GPBoost")
# Same LSTM hyper-parameters as the LightGBM + LSTM run.
LSTM_model = LSTM_engine(
    dataloader_train,
    num_epoch=300,
    hidden_layer_size=32,
    learning_rate=0.001,
)
gpb_lstm_test_results_df = LSTM_eval(
    LSTM_model,
    dataloader_test,
    true_ls_test,
    'GPBoost_LSTM',
)

logging.info("Running Transformer model for GPBoost post-processing")
# Build Transformer-specific loaders from the GPBoost outputs, mirroring the LightGBM
# post-processing path (Transformer_dataloader, batch_size=16 for training and 1 for
# testing). Previously the Transformer was trained and evaluated on the loaders
# produced by LSTM_dataloader, which is inconsistent with the LightGBM path.
# NOTE(review): batch_size=16 is copied from the LightGBM Transformer loaders — confirm
# it is the intended value together with accumulation_steps=8.
dataloader_train = Transformer_dataloader(prob_ls_train, len_train, true_ls_train, batch_size=16)
dataloader_test = Transformer_dataloader(prob_ls_test, len_test, true_ls_test, batch_size=1)

# Hyper-parameters are spelled out explicitly for this run.
transformer_model = Transformer_engine(
    dataloader_train,
    num_epoch=300,
    d_model=128,
    nhead=4,
    num_layers=2,
    learning_rate=0.001,
    accumulation_steps=8
)
gpb_transformer_test_results_df = Transformer_eval(transformer_model, dataloader_test, true_ls_test, 'GPBoost_Transformer')

In [None]:
# Overall test-set results table.
# Enabled rows: LightGBM and LightGBM + Transformer.
# Disabled (add them back into the list to include them):
#   lgb_lstm_test_results_df, gpb_test_results_df,
#   gpb_lstm_test_results_df, gpb_transformer_test_results_df
overall_result = pd.concat(
    [lgb_test_results_df, lgb_transformer_test_results_df]
)
print(group_variable)
print(overall_result)