## Generate Models to Use in GAM Coach UI

In this notebook, we show how to generate model JSON files to set up your own GAM Coach UI.

In [10]:
import numpy as np
import pandas as pd
import gamcoach as coach
import pickle
import re

from tqdm import tqdm
from io import StringIO
from collections import Counter
from matplotlib import pyplot as plt
from interpret.glassbox import ExplainableBoostingClassifier
from interpret import show
from sklearn.model_selection import train_test_split
from sklearn import metrics
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn import linear_model
from time import time
from json import load, dump

# Fixed random seed: passed to the EBM as random_state and used to seed NumPy
# before the random samples are drawn, so reruns give the same outputs.
SEED = 922

In [2]:
# Load the training split. The file has no header row (header=None), so pandas
# numbers the columns 0..14; the last column (14) holds the income label.
adult_data = pd.read_csv('./data/adult.data', header=None)
adult_data.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14
0,39,State-gov,77516,Bachelors,13,Never-married,Adm-clerical,Not-in-family,White,Male,2174,0,40,United-States,<=50K
1,50,Self-emp-not-inc,83311,Bachelors,13,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,0,13,United-States,<=50K
2,38,Private,215646,HS-grad,9,Divorced,Handlers-cleaners,Not-in-family,White,Male,0,0,40,United-States,<=50K
3,53,Private,234721,11th,7,Married-civ-spouse,Handlers-cleaners,Husband,Black,Male,0,0,40,United-States,<=50K
4,28,Private,338409,Bachelors,13,Married-civ-spouse,Prof-specialty,Wife,Black,Female,0,0,40,Cuba,<=50K


In [3]:
# Load the held-out test split (same 15 unnamed columns as the training file).
# Note that the labels in this file carry a trailing period, e.g. "<=50K.".
adult_data_test = pd.read_csv('./data/adult.test', header=None)
adult_data_test.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14
0,25,Private,226802,11th,7,Never-married,Machine-op-inspct,Own-child,Black,Male,0,0,40,United-States,<=50K.
1,38,Private,89814,HS-grad,9,Married-civ-spouse,Farming-fishing,Husband,White,Male,0,0,50,United-States,<=50K.
2,28,Local-gov,336951,Assoc-acdm,12,Married-civ-spouse,Protective-serv,Husband,White,Male,0,0,40,United-States,>50K.
3,44,Private,160323,Some-college,10,Married-civ-spouse,Machine-op-inspct,Husband,Black,Male,7688,0,40,United-States,>50K.
4,18,?,103497,Some-college,10,Never-married,?,Own-child,White,Female,0,0,30,United-States,<=50K.


In [4]:
# Sanity check: (number of rows, number of columns) of the training split
adult_data.shape

(32561, 15)

In [5]:
# From https://github.com/amirhk/mace/blob/01e6a405ff74e24dc3438a005cd60892154d189d/_data_main/fair_adult_data.py
adult_attrs = [
    "age",
    "workclass",
    "fnlwgt",
    "education",
    "education_num",
    "marital_status",
    "occupation",
    "relationship",
    "race",
    "sex",
    "capital_gain",
    "capital_loss",
    "hours_per_week",
    "native_country",
]

# Positions (in the raw file) of the columns used as model features.
# "fnlwgt" (2) and "race" (8) are left out; column 14 is the label.
selected_features = [0, 1, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13]

# Positions of capital_gain and capital_loss *within* selected_features.
# Both are log10-transformed before training.
log10_feature_positions = [8, 9]


def _prepare_adult_split(raw_frame):
    """Turn one raw Adult data frame into a feature matrix and 0/1 labels.

    Args:
        raw_frame: DataFrame read with ``header=None`` whose last column is
            the income label.

    Returns:
        A ``(features, labels)`` tuple. ``features`` is an object array with
        the ``selected_features`` columns, where capital gain/loss have been
        replaced by ``log10(value + 1e-6)``. ``labels`` is a list of ints:
        0 for "<=50K" and 1 for anything else.
    """
    # Fancy indexing copies, so the source DataFrame is never modified
    features = raw_frame.to_numpy()[:, selected_features]

    # The small epsilon keeps log10 finite for the many zero entries
    features[:, log10_feature_positions] = np.log10(
        features[:, log10_feature_positions].astype(float) + 1e-6
    )

    # The raw labels have a leading space, and the test file also appends a
    # period. Normalise both so one comparison works for every split instead
    # of silently mapping every unmatched label to 1.
    labels = [
        0 if str(raw_label).strip().rstrip('.') == '<=50K' else 1
        for raw_label in raw_frame.iloc[:, -1].tolist()
    ]
    return features, labels


x_train, y_train_ints = _prepare_adult_split(adult_data)
y_train = np.array(y_train_ints)

x_test, y_test_ints = _prepare_adult_split(adult_data_test)
y_test = np.array(y_test_ints)

In [6]:
# Names of the model features, in the same column order as x_train / x_test
adult_feature_names = np.array(adult_attrs)[selected_features]

# Positions (within adult_feature_names) of the numeric features:
# age, education_num, capital_gain, capital_loss, hours_per_week
adult_cont_indexes = [0, 3, 8, 9, 10]

# Every feature that is not listed as continuous is treated as nominal
adult_feature_types = [
    "continuous" if is_continuous else "nominal"
    for is_continuous in np.isin(
        np.arange(len(adult_feature_names)), adult_cont_indexes
    )
]

In [7]:
# Train an EBM classifier
# The feature names/types tell the EBM which columns are continuous and which
# are nominal; random_state=SEED makes the fit reproducible.
adult_ebm = ExplainableBoostingClassifier(
    feature_names=adult_feature_names,
    feature_types=adult_feature_types,
    random_state=SEED,
)
adult_ebm.fit(x_train, y_train)


In [11]:
# Report accuracy on the held-out test split
y_predict = adult_ebm.predict(x_test)
print('accuracy', metrics.accuracy_score(y_test, y_predict))

# Persist the fitted model. The context manager guarantees the file is flushed
# and closed even if pickling fails part-way through.
with open('./model_pickles/adult-ebm-ca.pickle', 'wb') as pickle_file:
    pickle.dump(adult_ebm, pickle_file)

accuracy 0.8736564093114674


In [12]:
# Users can provide a partial dictionary to add descriptions for some features/levels
# Each entry maps a feature name (as listed in adult_attrs) to a two-item list:
# [display name, description]. Only features selected for the model appear here.
feature_info = {
    "age": ["Age", "Age (years)"],
    "workclass": ["Workclass", "Types of workclass"],
    "education": ["Education", "Types of education level"],
    "education_num": ["Education Years", "Years of education"],
    "marital_status": ["Marital Status", "Marital status"],
    "occupation": ["Occupation", "Occupation"],
    "relationship": ["Relationship", "Relationship"],
    "sex": ["Sex", "Sex"],
    "capital_gain": ["Capital gain", "Capital gain in a year"],
    "capital_loss": ["Capital loss", "Capital loss in a year"],
    "hours_per_week": ["Work Hours", "Hours to work every week"],
    "native_country": ["Native Country", "Native Country"],
}

# Display names for the levels of each nominal feature, in the form
# {feature name: {level id: [display name, description]}}.
# All descriptions are left empty here; level ids start at 1.
# NOTE(review): the ids appear to follow the alphabetical order of the raw
# category strings, with the missing-value marker "?" first (shown as
# "Unknown") - confirm this matches the encoding coach.get_model_data expects.
feature_level_info = {
    "workclass": {
        1: ["Unknown", ""],
        2: ["Federal government", ""],
        3: ["Local government", ""],
        4: ["Never worked", ""],
        5: ["Private", ""],
        6: ["Self-employed (inc)", ""],
        7: ["Self-employed", ""],
        8: ["State government", ""],
        9: ["Without pay", ""],
    },
    "education": {
        1: ["10th", ""],
        2: ["11th", ""],
        3: ["12th", ""],
        4: ["1st-4th", ""],
        5: ["5th-6th", ""],
        6: ["7th-8th", ""],
        7: ["9th", ""],
        8: ["Assoc academy", ""],
        9: ["Assoc voc", ""],
        10: ["Bachelors", ""],
        11: ["Doctorate", ""],
        12: ["High school", ""],
        13: ["Masters", ""],
        14: ["Preschool", ""],
        15: ["Professional school", ""],
        16: ["Some college", ""],
    },
    "marital_status": {
        1: ["Divorced", ""],
        2: ["Married (armed force spouse)", ""],
        3: ["Married (civilian spouse)", ""],
        4: ["Married (absent spouse)", ""],
        5: ["Never married", ""],
        6: ["Separated", ""],
        7: ["Widowed", ""],
    },
    "occupation": {
        1: ["Unknown", ""],
        2: ["Adm-clerical", ""],
        3: ["Armed-Forces", ""],
        4: ["Craft-repair", ""],
        5: ["Exec-managerial", ""],
        6: ["Farming-fishing", ""],
        7: ["Handlers-cleaners", ""],
        8: ["Machine-op-inspct", ""],
        9: ["Other-service", ""],
        10: ["Priv-house-serv", ""],
        11: ["Prof-specialty", ""],
        12: ["Protective-serv", ""],
        13: ["Sales", ""],
        14: ["Tech-support", ""],
        15: ["Transport-moving", ""],
    },
    "relationship": {
        1: ["Husband", ""],
        2: ["Not-in-family", ""],
        3: ["Other-relative", ""],
        4: ["Own-child", ""],
        5: ["Unmarried", ""],
        6: ["Wife", ""],
    },
    "sex": {1: ["Female", ""], 2: ["Male", ""]},
    "native_country": {
        1: ["Unknown", ""],
        2: ["Cambodia", ""],
        3: ["Canada", ""],
        4: ["China", ""],
        5: ["Columbia", ""],
        6: ["Cuba", ""],
        7: ["Dominican-Republic", ""],
        8: ["Ecuador", ""],
        9: ["El-Salvador", ""],
        10: ["England", ""],
        11: ["France", ""],
        12: ["Germany", ""],
        13: ["Greece", ""],
        14: ["Guatemala", ""],
        15: ["Haiti", ""],
        16: ["Holand-Netherlands", ""],
        17: ["Honduras", ""],
        18: ["Hong", ""],
        19: ["Hungary", ""],
        20: ["India", ""],
        21: ["Iran", ""],
        22: ["Ireland", ""],
        23: ["Italy", ""],
        24: ["Jamaica", ""],
        25: ["Japan", ""],
        26: ["Laos", ""],
        27: ["Mexico", ""],
        28: ["Nicaragua", ""],
        29: ["Outlying-US(Guam-USVI-etc)", ""],
        30: ["Peru", ""],
        31: ["Philippines", ""],
        32: ["Poland", ""],
        33: ["Portugal", ""],
        34: ["Puerto-Rico", ""],
        35: ["Scotland", ""],
        36: ["South", ""],
        37: ["Taiwan", ""],
        38: ["Thailand", ""],
        39: ["Trinadad&Tobago", ""],
        40: ["United-States", ""],
        41: ["Vietnam", ""],
        42: ["Yugoslavia", ""],
    },
}

# Developers can provide initial configurations for individual features. If a
# configuration is not given for a feature, it uses the default value.
# The keys used in this notebook are:
#   - "difficulty": an integer between 1 and 6: 1 (very easy to change),
#     2 (easy), 3 (default), 4 (hard), 5 (very hard), 6 (impossible to change)
#   - "requiresInt": set to True for features whose values should be integers
#   - "requiresIncreasing": set for "age" only
#     (presumably the value may only go up - confirm in the GAM Coach docs)
#   - "usesTransform": 'log10' for capital_gain / capital_loss, the two
#     columns that were log10-transformed when the training data was prepared
# NOTE(review): the original comment also mentioned a "search range" option;
# no such key is used here - check the GAM Coach docs for its exact name.

feature_config = {
    "age": {"requiresInt": True, "requiresIncreasing": True},
    # "workclass": {"requiresInt": True},
    # "education": {"requiresInt": True},
    "education_num": {"requiresInt": True},
    # "marital_status": {"requiresInt": True},
    # "occupation": {"requiresInt": True},
    # "relationship": {"requiresInt": True},
    "sex": {"difficulty": 5, "requiresInt": True},
    "capital_gain": {"requiresInt": True, "usesTransform": 'log10'},
    "capital_loss": {"requiresInt": True, "usesTransform": 'log10'},
    "hours_per_week": {"requiresInt": True},
    "native_country": {"difficulty": 5, "requiresInt": True},
}

# Human-readable names for the two output classes.
# NOTE(review): presumably listed in label order, i.e. label 0 ("<=50K") is
# shown as "loan rejection" and label 1 as "loan approval" - confirm.
model_info = {"classes": ["loan rejection", "loan approval"]}


In [13]:
# Bundle the fitted EBM, the training features, and the descriptive metadata
# defined above into one object; it is written to adult-classifier.json below.
model_data = coach.get_model_data(
    adult_ebm,
    x_train,
    feature_info=feature_info,
    feature_level_info=feature_level_info,
    model_info=model_info,
    feature_config=feature_config,
)

100%|██████████| 22/22 [00:00<00:00, 32.50it/s]


In [16]:
def dump_rejected_samples(cur_ebm, name, x_test, cont_indexes,
                          n_samples=500, prob_range=(0.4, 0.45)):
    """Write a random set of near-threshold samples to a JSON file.

    Rows of ``x_test`` whose predicted probability for class index 1 lies
    strictly inside ``prob_range`` are the candidates. ``n_samples`` of them
    are drawn *with replacement* (so duplicates are possible) after seeding
    NumPy with the notebook-level ``SEED``. The result is written to
    ``./outputs/{name}-classifier-random-samples.json``.

    Args:
        cur_ebm: Fitted classifier exposing ``predict_proba``.
        name: Prefix used in the output file name.
        x_test: 2-D array of samples, one row per sample.
        cont_indexes: Column positions of continuous features. Values in all
            other columns are converted to strings before dumping.
        n_samples: Number of rows to draw (default 500).
        prob_range: ``(low, high)`` exclusive bounds on the class-1
            probability (default ``(0.4, 0.45)``).

    Raises:
        ValueError: If no row of ``x_test`` falls inside ``prob_range``.
    """
    low, high = prob_range

    # Candidate rows: class-1 probability strictly between the two bounds
    class_one_prob = cur_ebm.predict_proba(x_test)[:, 1]
    candidate_indexes = np.where(
        np.logical_and(class_one_prob > low, class_one_prob < high)
    )[0]

    if len(candidate_indexes) == 0:
        raise ValueError(
            f'No sample has a predicted probability between {low} and {high}'
        )

    # Seed the global NumPy RNG so the same rows are picked on every run
    np.random.seed(SEED)
    picks = np.random.choice(len(candidate_indexes), size=n_samples)
    random_indexes = candidate_indexes[picks]

    # Set membership keeps the per-value check cheap
    continuous_columns = set(cont_indexes)

    random_samples = []
    for row_index in random_indexes:
        cur_row = x_test[row_index, :].tolist()
        # Continuous values stay numeric; nominal values are stored as strings
        random_samples.append([
            value if column in continuous_columns else str(value)
            for column, value in enumerate(cur_row)
        ])

    # The context manager makes sure the JSON file is flushed and closed
    with open(f'./outputs/{name}-classifier-random-samples.json', 'w') as out_file:
        dump(random_samples, out_file)

In [17]:
# Write the model description. The context manager guarantees the JSON file
# is flushed and closed before the notebook moves on.
with open('./outputs/adult-classifier.json', 'w') as model_file:
    dump(model_data, model_file)

# Write the random test samples that accompany the model file
dump_rejected_samples(adult_ebm, 'adult', x_test, adult_cont_indexes)

## Put JSON Files in the Right Folders

Finally, copy `adult-classifier-random-samples.json` into `gamcoach-ui/src/config/data` and `adult-classifier.json` into `gamcoach-ui/public/data`. GAM Coach will then run with this model and data file!