In [1]:
import pandas
import os
import numpy as np
import torch
import gates_models as gm
import pickle

from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import accuracy_score
from sklearn.metrics import make_scorer
from sklearn.metrics import roc_auc_score
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split

In [2]:
path = 'models'
main_path = os.path.join(path, 'round5-train-dataset')
models_path = os.path.join(main_path, 'models')
metadata_file = 'METADATA.csv'    
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
df = pandas.read_csv(os.path.join(main_path, metadata_file))
torch.backends.cudnn.enabled=False
use_amp = True if torch.cuda.is_available() else False # attempt to use mixed precision to accelerate embedding conversion process

In [None]:
cdrp_hgates_params = {'gamma':0.0025, 'iter':50, 'lr':0.1, 'eps':1e-3}
cdrp_igates_params = {'gamma':0.025, 'iter':50, 'lr':0.1, 'eps':1e-3}
data = gm.apply_cdrp_on_dataset(df, main_path, models_path, cdrp_hgates_params, cdrp_igates_params, use_amp, device)

with open('gates_data_weakest.pickle', 'wb') as handle:
    pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
cdrp_hgates_params = {'gamma':0.005, 'iter':50, 'lr':0.1, 'eps':1e-3}
cdrp_igates_params = {'gamma':0.05, 'iter':50, 'lr':0.1, 'eps':1e-3}
data = gm.apply_cdrp_on_dataset(df, main_path, models_path, cdrp_hgates_params, cdrp_igates_params, use_amp, device)

with open('gates_data_weak.pickle', 'wb') as handle:
    pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
cdrp_hgates_params = {'gamma':0.0075, 'iter':50, 'lr':0.1, 'eps':1e-3}
cdrp_igates_params = {'gamma':0.075, 'iter':50, 'lr':0.1, 'eps':1e-3}
data = gm.apply_cdrp_on_dataset(df, main_path, models_path, cdrp_hgates_params, cdrp_igates_params, use_amp, device)

with open('gates_data_medium.pickle', 'wb') as handle:
    pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)


In [None]:
cdrp_hgates_params = {'gamma':0.01, 'iter':50, 'lr':0.1, 'eps':1e-3}
cdrp_igates_params = {'gamma':0.1, 'iter':50, 'lr':0.1, 'eps':1e-3}
data = gm.apply_cdrp_on_dataset(df, main_path, models_path, cdrp_hgates_params, cdrp_igates_params, use_amp, device)

with open('gates_data_strong.pickle', 'wb') as handle:
    pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
cdrp_hgates_params = {'gamma':0.0125, 'iter':50, 'lr':0.1, 'eps':1e-3}
cdrp_igates_params = {'gamma':0.125, 'iter':50, 'lr':0.1, 'eps':1e-3}
data = gm.apply_cdrp_on_dataset(df, main_path, models_path, cdrp_hgates_params, cdrp_igates_params, use_amp, device)

with open('gates_data_strongest.pickle', 'wb') as handle:
    pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [53]:
num_bins = 10
conversion = 'histogram'

with open(os.path.join('gates_data_weakest.pickle'), 'rb') as handle:
    data_weakest =  gm.convert_data(pickle.load(handle), num_bins=num_bins, conversion=conversion)

with open(os.path.join('gates_data_weak.pickle'), 'rb') as handle:
    data_weak = gm.convert_data(pickle.load(handle), num_bins=num_bins, conversion=conversion)

with open(os.path.join('gates_data_medium.pickle'), 'rb') as handle:
    data_medium = gm.convert_data(pickle.load(handle), num_bins=num_bins, conversion=conversion)

with open(os.path.join('gates_data_strong.pickle'), 'rb') as handle:
    data_strong = gm.convert_data(pickle.load(handle), num_bins=num_bins, conversion=conversion)

with open(os.path.join('gates_data_strongest.pickle'), 'rb') as handle:
    data_strongest = gm.convert_data(pickle.load(handle), num_bins=num_bins, conversion=conversion)

combined_data = np.hstack((data_weakest['data'], data_weak['data'], data_medium['data'], data_strong['data'], data_strongest['data']))
model_labels = data_weakest['labels']

X_train, X_test, y_train, y_test = train_test_split(combined_data, model_labels, stratify=model_labels, random_state=1)

In [54]:

scoring = make_scorer(accuracy_score)

parameters = {'learning_rate': [0.15,0.1,0.05,0.01,0.005,0.001],  'n_estimators': [100,250,500,750,1000,1250,1500], 'max_depth': [3,5,7]}

clf = GridSearchCV(GradientBoostingClassifier(), parameters, scoring=scoring, refit=True, cv=2, n_jobs=-1).fit(X_train, y_train)
print(f'Acc: {clf.score(X_test, y_test):.2f} - AUC: {roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1]):.2f}')

with open('clf_gates.pickle', 'wb') as handle:
    pickle.dump(clf, handle, protocol=pickle.HIGHEST_PROTOCOL)

Acc: 0.78 - AUC: 0.85


In [64]:
scoring = make_scorer(accuracy_score)

calib_clf = CalibratedClassifierCV(base_estimator=clf, cv='prefit').fit(X_test, y_test)

print(f'Acc: {calib_clf.score(X_test, y_test):.2f} - AUC: {roc_auc_score(y_test, calib_clf.predict_proba(X_test)[:, 1]):.2f}')

with open('calib_clf.pickle', 'wb') as handle:
    pickle.dump(calib_clf, handle, protocol=pickle.HIGHEST_PROTOCOL)

Acc: 0.77 - AUC: 0.85


In [67]:
 preds = clf.predict_proba(X_test)
 np.exp(x)/sum(np.exp(x))

array([0.96454441, 0.99996667, 0.99999984, 1.        , 1.        ,
       0.83161124, 0.73448807, 1.        , 0.99064883, 1.        ,
       0.99702032, 1.        , 1.        , 0.99723031, 1.        ,
       1.        , 0.99999983, 1.        , 1.        , 1.        ,
       0.99986045, 1.        , 0.99946409, 0.6855799 , 0.99999916,
       0.99999999, 0.99999686, 1.        , 0.99907028, 1.        ,
       1.        , 0.99999605, 0.99668714, 0.99999992, 1.        ,
       1.        , 1.        , 1.        , 0.99698519, 0.99187957,
       1.        , 0.99996406, 1.        , 0.99919878, 0.9993244 ,
       0.99999999, 1.        , 0.99998031, 0.999483  , 0.99999969,
       0.99728983, 0.99974721, 0.99955613, 0.99999607, 1.        ,
       1.        , 1.        , 0.99984486, 1.        , 1.        ,
       0.99818406, 1.        , 1.        , 0.99999999, 1.        ,
       1.        , 1.        , 0.99977541, 1.        , 0.9775857 ,
       0.99998967, 0.99999983, 1.        , 1.        , 0.72539