In [1]:
import pandas as pd
import joblib
import altair as alt
import glob
import sys
import os
import numpy as np
from pathlib import Path

sys.path.append(os.path.abspath(".."))
from src.models.feat_selection import ImportanceFeatureSelector

  from .autonotebook import tqdm as notebook_tqdm


# Tables

This notebook contains the code used to generate the tables for the report: the error metrics of each model at each threshold, and the class imbalance at each threshold.

In [9]:
# Gather the saved score dictionaries for every model / threshold combination.
# Files are looked up at ../results/<threshold>/<model_name>_scores.joblib, i.e.
# relative to the current working directory, so the notebook must be run from
# its own folder.
error_dicts = []
for model_name in ['gradient_boosting','random_forest','logistic_regression']:
    for file in Path('..').glob(f"results/*/{model_name}_scores.joblib"):
        # NOTE: joblib.load unpickles the file and can execute arbitrary code;
        # only load score files produced by this project.
        error_dict = joblib.load(file)
        # Human-readable model label, e.g. 'gradient_boosting' -> 'Gradient Boosting'.
        error_dict['Model'] =  model_name.replace('_', ' ').title()
        # The parent directory name is used as the threshold label, e.g. '50' -> '50%'.
        error_dict['Threshold'] = file.parent.name + '%'
        error_dicts.append(error_dict)

# Fail early with an actionable message: without this guard an empty list yields
# a column-less DataFrame and the sort below raises an opaque KeyError: 'Model'.
if not error_dicts:
    raise FileNotFoundError(
        f"No '*_scores.joblib' files found under {Path('..', 'results').resolve()}; "
        "check the working directory and that the model results have been generated."
    )

# One row per (model, threshold), restricted to the metrics shown in the report.
# Threshold labels are strings, so they are ordered lexicographically.
results_df = pd.DataFrame(error_dicts).sort_values(by=['Model','Threshold'])[['Model','Threshold','Accuracy','F1 Score','F2 Score','AUC','AP']]

In [10]:
# Error metrics of every model at the 50% threshold (threshold column hidden).
results_df[results_df['Threshold'] == "50%"].drop(columns='Threshold')

Unnamed: 0,Model,Accuracy,F1 Score,F2 Score,AUC,AP
0,Gradient Boosting,0.934,0.058,0.045,0.599,0.072
8,Logistic Regression,0.597,0.122,0.228,0.59,0.083
4,Random Forest,0.78,0.124,0.192,0.545,0.061


In [11]:
# Error metrics of every model at the 60% threshold (threshold column hidden).
results_df[results_df['Threshold'] == "60%"].drop(columns='Threshold')

Unnamed: 0,Model,Accuracy,F1 Score,F2 Score,AUC,AP
1,Gradient Boosting,0.875,0.121,0.102,0.601,0.131
9,Logistic Regression,0.604,0.224,0.362,0.638,0.163
5,Random Forest,0.737,0.258,0.361,0.676,0.157


In [12]:
# Error metrics of every model at the 70% threshold (threshold column hidden).
results_df[results_df['Threshold'] == "70%"].drop(columns='Threshold')

Unnamed: 0,Model,Accuracy,F1 Score,F2 Score,AUC,AP
3,Gradient Boosting,0.784,0.235,0.212,0.594,0.228
11,Logistic Regression,0.609,0.332,0.448,0.647,0.27
7,Random Forest,0.67,0.338,0.422,0.623,0.244


## Class Imbalance

In [13]:
# Class balance per threshold: every score dictionary carries the low/high rate
# percentages, so duplicate rows are collapsed after sorting by threshold label.
imbalance_columns = ['Threshold', '% Low Rate', '% High Rate']
imbalance_df = (
    pd.DataFrame(error_dicts)
    .sort_values(by=['Threshold'])
    .loc[:, imbalance_columns]
    .drop_duplicates()
)
imbalance_df

Unnamed: 0,Threshold,% Low Rate,% High Rate
0,50%,5.1,94.9
1,60%,9.3,90.7
3,70%,16.6,83.4
2,80%,31.6,68.4
