<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [None]:
from wildwood.dataset import load_car, load_adult, load_bank
from wildwood import ForestClassifier
import sys
import numpy as np

%config Completer.use_jedi = False

In [None]:
from sklearn.preprocessing import LabelBinarizer


In [None]:
import matplotlib.pyplot as plt

%matplotlib inline

def dynamic_print(stri):
    """Write ``stri`` to standard output preceded by a carriage return.

    The leading ``'\\r'`` moves the cursor back to the start of the line, so
    successive calls overwrite each other instead of stacking up; no newline
    is emitted. The stream is flushed right away so the text appears
    immediately.

    Parameters
    ----------
    stri : str
        Text to display. It is concatenated to ``'\\r'``, so a non-string
        argument raises ``TypeError``.
    """
    stream = sys.stdout
    stream.write("\r" + stri)
    stream.flush()


In [None]:
# Single seed reused for both the train/test extraction and the forest.
random_state = 0

dataset = load_car()

# Turn off one-hot encoding and standardization on the dataset object before
# extracting the split; the categorical columns are instead declared to the
# forest through `categorical_features` below.
dataset.one_hot_encode = False
dataset.standardize = False

X_train, X_test, y_train, y_test = dataset.extract(random_state=random_state)

# Estimator settings gathered in one place, then unpacked into the constructor.
forest_params = {
    "max_features": None,
    "class_weight": "balanced",
    "categorical_features": dataset.categorical_features_,
    "random_state": random_state,
    # "verbose": True,  # left disabled
}
clf = ForestClassifier(**forest_params)


In [None]:
# Fit the classifier on the training split returned by `dataset.extract`.
# The call is the last expression of the cell, so its return value is what the
# notebook displays as the cell output.
clf.fit(X_train, y_train)

In [None]:
from sklearn.metrics import average_precision_score, roc_auc_score, log_loss, accuracy_score


In [None]:
# Sweep the `dirichlet` and `step` attributes of the already fitted `clf` over a
# 2-D grid and record four test-set metrics for every grid point.
# Only `predict_proba` is called inside the loop (no call to `fit`).
# NOTE: the sweep mutates `clf`; after this cell it is left at the last grid
# point, i.e. the largest `dirichlet` and `step` values.

# Log-spaced grids: dirichlet in [2**-7, 2**0] (27 points), step in [2**-2, 2**4]
# (29 points).
dirichlet_values = 2.0 ** np.linspace(-7, 0, 27)
step_values = 2.0 ** np.linspace(-2, 4, 29)

task = dataset.task

# Sorted training labels. Column k of `predict_proba` is assumed to correspond
# to class_labels[k] (the scikit-learn convention) -- TODO confirm for
# ForestClassifier.
class_labels = np.unique(y_train)

# Indicator encoding of the test labels, used by the multiclass average precision.
y_test_binary = LabelBinarizer().fit_transform(y_test)

# values[m, i, j] holds metric m at (dirichlet_values[i], step_values[j]), with
# m = 0: roc auc, 1: average precision, 2: log loss, 3: accuracy.
# NaN-initialised (instead of zeros) so that a metric which is never computed,
# e.g. for an unexpected `task`, is not mistaken for a real score of 0.0.
values = np.full((4, len(dirichlet_values), len(step_values)), np.nan)

for i, dirichlet in enumerate(dirichlet_values):
    clf.dirichlet = dirichlet
    for j, step in enumerate(step_values):
        clf.step = step
        # Progress indicator, rewritten in place on a single line.
        dynamic_print(f"{i + 1}/{len(dirichlet_values)}\t{j + 1}/{len(step_values)}")

        y_scores = clf.predict_proba(X_test)

        # Map the arg-max column index back to the corresponding label so the
        # comparison with `y_test` does not rely on labels being 0..K-1.
        y_pred = class_labels[np.argmax(y_scores, axis=1)]

        if task == "binary-classification":
            # Binary metrics take the score of the second column only.
            values[0, i, j] = roc_auc_score(y_test, y_scores[:, 1])
            values[1, i, j] = average_precision_score(y_test, y_scores[:, 1])

        elif task == "multiclass-classification":
            values[0, i, j] = roc_auc_score(y_test, y_scores, multi_class="ovr", average="macro")
            values[1, i, j] = average_precision_score(y_test_binary, y_scores)

        values[2, i, j] = log_loss(y_test, y_scores)

        values[3, i, j] = accuracy_score(y_test, y_pred)


In [None]:
import seaborn as sns
import matplotlib.pylab as plb

def skip_ticks(ticks, skip=1):
    """Blank out tick labels so that only every ``skip``-th one is kept.

    Parameters
    ----------
    ticks : iterable
        Tick labels, in display order.
    skip : int, default=1
        Spacing between kept labels: positions 0, ``skip``, ``2 * skip``, ...
        keep their label, every other position gets an empty string.

    Returns
    -------
    list
        A list with one entry per element of ``ticks``.
    """
    thinned = []
    for position, label in enumerate(ticks):
        keep = position % skip == 0
        thinned.append(label if keep else '')
    return thinned
# Plot settings read by the heatmap cells of this notebook.
st = 5  # keep one tick label out of every `st`
# Display names, in the order in which the metrics are stored along values[m].
metrics_names = ["roc auc", "average precision", "log loss", "accuracy"]

# Heatmap of metric 0 ("roc auc") over the (dirichlet, step) grid.
metric_index = 0

# Tick labels rounded to two decimals and thinned out to stay readable.
step_labels = skip_ticks(np.around(step_values, decimals=2), st)
dirichlet_labels = skip_ticks(np.around(dirichlet_values, decimals=2), st)

ax = sns.heatmap(
    values[metric_index],
    xticklabels=step_labels,
    yticklabels=dirichlet_labels,
)
ax.set_xlabel("step")
ax.set_ylabel("dirichlet")
ax.set_title(metrics_names[metric_index] + " for " + dataset.name)
plb.show()

In [None]:
# Heatmap of metric 1 ("average precision") over the (dirichlet, step) grid.
metric_index = 1

# Tick labels rounded to two decimals and thinned out to stay readable.
step_labels = skip_ticks(np.around(step_values, decimals=2), st)
dirichlet_labels = skip_ticks(np.around(dirichlet_values, decimals=2), st)

ax = sns.heatmap(
    values[metric_index],
    xticklabels=step_labels,
    yticklabels=dirichlet_labels,
)
ax.set_xlabel("step")
ax.set_ylabel("dirichlet")
ax.set_title(metrics_names[metric_index] + " for " + dataset.name)
plb.show()

In [None]:
# Heatmap of metric 2 ("log loss") over the (dirichlet, step) grid.
metric_index = 2

# Tick labels rounded to two decimals and thinned out to stay readable.
step_labels = skip_ticks(np.around(step_values, decimals=2), st)
dirichlet_labels = skip_ticks(np.around(dirichlet_values, decimals=2), st)

ax = sns.heatmap(
    values[metric_index],
    xticklabels=step_labels,
    yticklabels=dirichlet_labels,
)
ax.set_xlabel("step")
ax.set_ylabel("dirichlet")
ax.set_title(metrics_names[metric_index] + " for " + dataset.name)
plb.show()

In [None]:
# Heatmap of metric 3 ("accuracy") over the (dirichlet, step) grid.
metric_index = 3

# Tick labels rounded to two decimals and thinned out to stay readable.
step_labels = skip_ticks(np.around(step_values, decimals=2), st)
dirichlet_labels = skip_ticks(np.around(dirichlet_values, decimals=2), st)

ax = sns.heatmap(
    values[metric_index],
    xticklabels=step_labels,
    yticklabels=dirichlet_labels,
)
ax.set_xlabel("step")
ax.set_ylabel("dirichlet")
ax.set_title(metrics_names[metric_index] + " for " + dataset.name)
plb.show()