In [10]:
%matplotlib inline

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import gzip
import pickle
import datetime

from sklearn.svm import LinearSVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.multiclass import OneVsRestClassifier

from sklearn.metrics import f1_score
from sklearn.model_selection import RandomizedSearchCV

In [11]:
dataset_list = ['yeast', 'diabete', 'woman']
columns_ylist = ['Class', 'readmitted', 'service']
blackbox_list = ['rf', 'svm', 'mlp']

In [12]:
dt_params = {
    'max_depth': [None, 10, 20, 30, 40, 50, 70, 80, 90, 100],
    'min_samples_split': [2**i for i in range(1, 10)],
    'min_samples_leaf': [2**i for i in range(1, 10)],
}

In [17]:
import warnings
warnings.filterwarnings("ignore")

cv = 5
for idx, dataset in enumerate(dataset_list):
    print(datetime.datetime.now(), 'dataset: %s' % dataset)
    df_2e = pd.read_csv('../dataset/%s_2e.csv' % dataset)
    
    cols_Y = [col for col in df_2e.columns if col.startswith(columns_ylist[idx])]
    cols_X = [col for col in df_2e.columns if col not in cols_Y]
    
    X2e = df_2e[cols_X].values
    y2e = df_2e[cols_Y].values
    
    for blackbox_name in blackbox_list:
        print(datetime.datetime.now(), '\tblack box: %s' % blackbox_name)
        
        bb = pickle.load(gzip.open('../models_hold_out/%s_%s.pickle.gz' % (blackbox_name, dataset), 'rb'))
        y = bb.predict(X2e)
        
        dt = DecisionTreeClassifier()
        sop = np.prod([len(v) for k, v in dt_params.items()])
        n_iter_search = min(100, sop)
        random_search = RandomizedSearchCV(dt, param_distributions=dt_params,
                                           scoring='f1_micro', n_iter=n_iter_search, cv=cv)
        random_search.fit(X2e, y)
        best_params = random_search.best_params_
        dt.set_params(**best_params)
        
        pred_2e = bb.predict(X2e)
        print(datetime.datetime.now(), '\t  F1: %.4f' % f1_score(y, pred_2e, average='micro'))
        
        pickle_file = gzip.open('../global_dt/%s_%s.pickle.gz' % (blackbox_name, dataset), 'wb')
        pickle.dump(dt, pickle_file)
        pickle_file.close()

2018-10-26 09:59:00.505916 dataset: yeast
2018-10-26 09:59:00.539023 	black box: rf
2018-10-26 09:59:03.517211 	  F1: 1.0000
