In [None]:
# Optional: uncomment to install/upgrade oolearning into the kernel's environment.
# %pip install --upgrade oolearning

In [1]:
import copy
import os
import oolearning as oo
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


from helpers import column_log, BinaryAucRocScore

# Widen pandas' rendering so wide frames and long cell values are shown in full.
pd.set_option('display.width', 500)
# `-1` used to mean "no column-width limit", but it was deprecated in pandas 1.0 and is
# rejected with a ValueError by newer releases; `None` is the supported spelling.
# Fall back to `-1` only on pandas < 1.0, whose validator accepts integers only.
try:
    pd.set_option('display.max_colwidth', None)
except ValueError:
    pd.set_option('display.max_colwidth', -1)
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
# Default figure size: `width` units wide, height derived from a 1.333 (~4:3) aspect ratio.
width = 10
plt.rcParams.update({'figure.figsize': [width, width/1.333]})

In [2]:
# Dataset location and the column to predict.
csv_file = '../census.csv'
target_variable = 'income'

# Labels of the two classes in the target column.
negative_class = '<=50K'
positive_class = '>50K'

# Load the CSV through oolearning's exploration wrapper; later cells read
# `explore.dataset` and `explore.categoric_features` from it.
explore = oo.ExploreClassificationDataset.from_csv(
    csv_file_path=csv_file,
    target_variable=target_variable,
)

explore.dataset.head(20)

Unnamed: 0,age,workclass,education_level,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,income
0,39,State-gov,Bachelors,13.0,Never-married,Adm-clerical,Not-in-family,White,Male,2174.0,0.0,40.0,United-States,<=50K
1,50,Self-emp-not-inc,Bachelors,13.0,Married-civ-spouse,Exec-managerial,Husband,White,Male,0.0,0.0,13.0,United-States,<=50K
2,38,Private,HS-grad,9.0,Divorced,Handlers-cleaners,Not-in-family,White,Male,0.0,0.0,40.0,United-States,<=50K
3,53,Private,11th,7.0,Married-civ-spouse,Handlers-cleaners,Husband,Black,Male,0.0,0.0,40.0,United-States,<=50K
4,28,Private,Bachelors,13.0,Married-civ-spouse,Prof-specialty,Wife,Black,Female,0.0,0.0,40.0,Cuba,<=50K
5,37,Private,Masters,14.0,Married-civ-spouse,Exec-managerial,Wife,White,Female,0.0,0.0,40.0,United-States,<=50K
6,49,Private,9th,5.0,Married-spouse-absent,Other-service,Not-in-family,Black,Female,0.0,0.0,16.0,Jamaica,<=50K
7,52,Self-emp-not-inc,HS-grad,9.0,Married-civ-spouse,Exec-managerial,Husband,White,Male,0.0,0.0,45.0,United-States,>50K
8,31,Private,Masters,14.0,Never-married,Prof-specialty,Not-in-family,White,Female,14084.0,0.0,50.0,United-States,>50K
9,42,Private,Bachelors,13.0,Married-civ-spouse,Exec-managerial,Husband,White,Male,5178.0,0.0,40.0,United-States,>50K


In [3]:
# Ratio of negative to positive rows in the target column; it is passed to the XGBoost
# hyper-parameters below as `scale_pos_weight`.
# Both counts look the column up through `target_variable` (instead of a hard-coded
# `.income` attribute) so the cell keeps working if the target column is renamed.
n_positive = np.sum(explore.dataset[target_variable] == positive_class)
n_negative = np.sum(explore.dataset[target_variable] == negative_class)
scale_pos_weight_calc = n_negative / n_positive
scale_pos_weight_calc

3.034796573875803

In [4]:
def create_net_capital(x):
    """Return a copy of `x` with an added 'net capital' column.

    The new column is `x['capital-gain'] - x['capital-loss']`; the input frame
    itself is left unmodified.
    """
    net_capital = x['capital-gain'] - x['capital-loss']
    return x.assign(**{'net capital': net_capital})

In [5]:
# Pre-processing steps, listed in the order they are handed to the trainer.
global_transformations = [
    # kaggle test file has white space around values, so strip the categoric columns
    oo.StatelessColumnTransformer(
        columns=explore.categoric_features,
        custom_function=lambda column: column.str.strip(),
    ),
    oo.ImputationTransformer(),
    # apply helpers.column_log to the two capital columns
    oo.StatelessColumnTransformer(
        columns=['capital-gain', 'capital-loss'],
        custom_function=column_log,
    ),
    # add 'net capital'; this step is listed after the column_log step above
    oo.StatelessTransformer(custom_function=create_net_capital),
    oo.CenterScaleTransformer(),
    oo.DummyEncodeTransformer(oo.CategoricalEncoding.ONE_HOT),
]

In [6]:
# Base models for the ensemble, each paired with a fixed set of hyper-parameters.
model_infos = [
    # Random forest using the extra-trees implementation
    oo.ModelInfo(
        model=oo.RandomForestClassifier(extra_trees_implementation=True),
        hyper_params=oo.RandomForestHP(
            criterion='gini',
            num_features=None,
            max_features=1.0,
            n_estimators=3750,
            max_depth=14,
        ),
    ),
    # Random forest (default implementation)
    oo.ModelInfo(
        model=oo.RandomForestClassifier(),
        hyper_params=oo.RandomForestHP(
            criterion='gini',
            num_features=None,
            max_features=0.2,
            n_estimators=1815,
            max_depth=20,
            min_samples_split=16,
            min_samples_leaf=2,
            min_weight_fraction_leaf=0.0,
            max_leaf_nodes=None,
            min_impurity_decrease=0,
        ),
    ),
    # AdaBoost
    oo.ModelInfo(
        model=oo.AdaBoostClassifier(),
        hyper_params=oo.AdaBoostClassifierHP(
            n_estimators=4250,
            learning_rate=0.45,
            algorithm='SAMME.R',
            # Tree-specific hyper-params
            criterion='gini',
            splitter='best',
            max_features=0.3,
            max_depth=2,
            min_samples_split=0.7,
            min_samples_leaf=0.004,
            min_weight_fraction_leaf=0.,
            max_leaf_nodes=None,
            min_impurity_decrease=0.,
            class_weight=None,
        ),
    ),
    # Logistic regression
    oo.ModelInfo(
        model=oo.LogisticClassifier(),
        hyper_params=oo.LogisticClassifierHP(
            penalty='l2',
            regularization_inverse=0.245,
        ),
    ),
    # Linear SVM
    oo.ModelInfo(
        model=oo.SvmLinearClassifier(),
        hyper_params=oo.SvmLinearClassifierHP(
            penalty='l2',
            penalty_c=10,
        ),
    ),
    # XGBoost; scale_pos_weight is the negative/positive ratio computed in an earlier cell
    oo.ModelInfo(
        model=oo.XGBoostClassifier(),
        hyper_params=oo.XGBoostTreeHP(
            objective=oo.XGBObjective.BINARY_LOGISTIC,
            learning_rate=0.045,
            n_estimators=3000,
            max_depth=3,
            min_child_weight=5,
            gamma=0.15,
            subsample=1,
            colsample_bytree=0.4,
            reg_alpha=0,
            reg_lambda=2,
            scale_pos_weight=scale_pos_weight_calc,
        ),
    ),
]

In [7]:
# Evaluator for the predicted probabilities; the converter is only told which label is
# the positive class (no explicit threshold is passed here).
evaluator = oo.TwoClassProbabilityEvaluator(
    converter=oo.TwoClassThresholdConverter(positive_class=positive_class),
)

# Soft-voting ensemble over `model_infos`, aggregated with np.median.
# The stratified splitter holds out 20% of the rows (holdout_ratio=0.2); the trainer
# exposes separate training and holdout evaluators, which the following cells inspect.
trainer = oo.ModelTrainer(
    model=oo.ModelAggregator(
        base_models=model_infos,
        aggregation_strategy=oo.SoftVotingAggregationStrategy(aggregation=np.median),
    ),
    # each trainer gets its own clones of the shared transformation objects
    model_transformations=[transformation.clone() for transformation in global_transformations],
    splitter=oo.ClassificationStratifiedDataSplitter(holdout_ratio=0.2),
    evaluator=evaluator,
)
predictions = trainer.train_predict_eval(
    data=explore.dataset,
    target_variable=target_variable,
)

In [8]:
# Quality metrics reported by the trainer's evaluator for the training split.
trainer.training_evaluator.all_quality_metrics

{'AUC ROC': 0.9396031339033886,
 'AUC Precision/Recall': 0.8541285779816499,
 'Kappa': 0.6563073562997247,
 'F1 Score': 0.7332966562748334,
 'Two-Class Accuracy': 0.8793985128672914,
 'Error Rate': 0.12060148713270863,
 'True Positive Rate': 0.66897167075619,
 'True Negative Rate': 0.9487339678806365,
 'False Positive Rate': 0.051266032119363494,
 'False Negative Rate': 0.3310283292438099,
 'Positive Predictive Value': 0.8113079940484241,
 'Negative Predictive Value': 0.896887159533074,
 'Prevalence': 0.247837023523233,
 'No Information Rate': 0.752162976476767,
 'Total Observations': 36177}

In [9]:
# Quality metrics reported by the trainer's evaluator for the holdout split
# (the 20% of rows set aside by the splitter).
trainer.holdout_evaluator.all_quality_metrics

{'AUC ROC': 0.9165018830570498,
 'AUC Precision/Recall': 0.8076945992789192,
 'Kappa': 0.6028800505158199,
 'F1 Score': 0.6912487708947885,
 'Two-Class Accuracy': 0.8611387506909896,
 'Error Rate': 0.1388612493090105,
 'True Positive Rate': 0.6271186440677966,
 'True Negative Rate': 0.9382625312362193,
 'False Positive Rate': 0.06173746876378069,
 'False Negative Rate': 0.3728813559322034,
 'Positive Predictive Value': 0.7699890470974808,
 'Negative Predictive Value': 0.884194486771021,
 'Prevalence': 0.2478717523493643,
 'No Information Rate': 0.7521282476506357,
 'Total Observations': 9045}