In [1]:
import os
import sys
import inspect

sys.path.insert(1, os.path.join(sys.path[0], '..'))

import fatapi
from fatapi.data import Data
from fatapi.model import BlackBox, Model, DensityEstimator
import numpy as np
from fatapi.methods import FACEMethod
from sklearn.neural_network import MLPClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KernelDensity
from sklearn.model_selection import train_test_split

# Synthetic binary-classification data. random_state=1 is used for generation,
# the split and the network, so every run reproduces the same numbers.
X, y = make_classification(n_samples=100, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, random_state=1
)

# The black-box classifier whose decisions FACE will be asked to explain.
clf = MLPClassifier(random_state=1, max_iter=300)
clf.fit(X_train, y_train)

# Sanity check: predicted labels for the first five held-out rows, then accuracy.
print(clf.predict(X_test[:5]))
print(clf.score(X_test, y_test))

# Wrap the held-out split for fatapi. make_classification yields purely numeric
# features; encoded=True presumably tells fatapi no further encoding is needed
# (TODO confirm against fatapi's Data docs).
data_X = Data(dataset=X_test, dtype="data", encoded=True)
data_y = Data(dataset=y_test, dtype="target", encoded=True)

def conditionf(**kwargs):
    """Always-satisfied condition passed to FACEMethod via ``conditions=``.

    FACEMethod requires this callable to return a boolean. Returning True
    unconditionally means the condition never rejects anything. NOTE(review):
    the keyword arguments FACEMethod supplies are not visible here; confirm
    them against fatapi's FACEMethod before writing a real condition.
    """
    del kwargs  # accepted for FACEMethod's calling convention, intentionally unused
    return True

# Rows of data_X / data_y to generate FACE results for. The original
# (misspelled) name is kept because later cells may reference it.
row_indicies = [0, 1, 2, 3, 4]

blackb = BlackBox(clf)
face_model = Model(data_X, data_y, blackbox=blackb)

# KDE bandwidth search over 100 log-spaced candidates from 1e-2 to 1e1.
# X_test has only 25 rows (train_test_split's default 25% of 100 samples), so
# cv=20 leaves just 1-2 rows per validation fold; expect noisy fold scores.
bandwidths = 10 ** np.linspace(-2, 1, 100)
grid = GridSearchCV(
    KernelDensity(kernel="gaussian"),
    {"bandwidth": bandwidths},
    cv=20,
)
# Fit on the same (encoded, normalised) dataset that face_model wraps.
grid.fit(data_X.dataset)
dens_est = grid.best_estimator_
# Reuse dens_est instead of re-reading grid.best_estimator_ (same object).
dens_estt = DensityEstimator(estimator=dens_est)

face_method = FACEMethod(
    factuals=data_X.get_rows_as_data(row_indicies),
    factuals_target=data_y.get_rows_as_data(row_indicies),
    model=face_model,
    kernel_type="kde",
    epsilon=0.7,
    t_prediction=0.5,
    t_density=0.01,
    conditions=conditionf,
    density_estimator=dens_estt,
)
# Keep the result so later cells can inspect it without re-running explain().
face_graph = face_method.explain()
print(f"G: {face_graph}")

[1 0 1 0 1]
0.88
G: [[  0.           0.           0.           0.           0.
    0.           0.           0.           0.           0.
    0.           0.           0.           0.           0.
    0.           0.           0.           0.           0.
    0.           0.           0.           0.           0.        ]
 [  0.           0.           0.           0.           0.
  142.09837797 111.87947344   0.           0.         150.39504627
  114.23647882   0.           0.         150.19989541   0.
  131.58366065 157.47805115 112.31574955 132.74981186 145.83920597
    0.           0.         173.94974939 168.88400231   0.        ]
 [  0.           0.           0.           0.           0.
    0.           0.           0.           0.           0.
    0.           0.           0.           0.           0.
    0.           0.           0.           0.           0.
    0.           0.           0.           0.           0.        ]
 [  0.           0.           0.           0.       