In [1]:
import os
import sys
import inspect

sys.path.insert(1, os.path.join(sys.path[0], '..'))

import fatapi
from fatapi.data import Data
from fatapi.model import BlackBox, Model, DensityEstimator
import numpy as np
from fatapi.methods import FACEMethod
from sklearn.neural_network import MLPClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KernelDensity
from sklearn.model_selection import train_test_split

# Synthetic binary classification problem. Fixed seeds keep every
# Restart & Run All reproducible.
X, y = make_classification(n_samples=100, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, random_state=1
)

# The black-box classifier whose decisions FACE will explain.
clf = MLPClassifier(random_state=1, max_iter=300)
clf.fit(X_train, y_train)

# Sanity check: predicted labels for the first five test rows, then test accuracy.
print(clf.predict(X_test[:5]))
print(clf.score(X_test, y_test))

# Wrap the held-out split in fatapi containers: features and targets.
data_X = Data(dataset=X_test, dtype="data", encoded=True)
data_y = Data(dataset=y_test, dtype="target", encoded=True)

# Has to return a boolean 
def conditionf(**kwargs):
    return True

# Rows of the held-out split to explain. These are positions in data_X / data_y,
# which wrap X_test.
row_indicies = [0, 1, 2, 3, 4]

# Wrap the trained classifier so fatapi can query it.
blackb = BlackBox(clf)
face_model = Model(data_X, data_y, blackbox=blackb)

# Kernel-density bandwidth search: log-spaced candidates from 10**-2 to 10**1,
# chosen by k-fold cross-validation. The split is small (25 rows with
# train_test_split's default test_size=0.25), so each fold holds only 1-2 rows.
KDE_LOG10_BANDWIDTH_MIN = -2
KDE_LOG10_BANDWIDTH_MAX = 1
KDE_N_BANDWIDTHS = 100
KDE_CV_FOLDS = 20
bandwidths = 10 ** np.linspace(KDE_LOG10_BANDWIDTH_MIN, KDE_LOG10_BANDWIDTH_MAX, KDE_N_BANDWIDTHS)
grid = GridSearchCV(KernelDensity(kernel='gaussian'),
                    {'bandwidth': bandwidths},
                    cv=KDE_CV_FOLDS)
# data_X was constructed with encoded=True. Note that the make_classification
# features are NOT explicitly scaled or normalised anywhere in this notebook.
grid.fit(data_X.dataset)
dens_est = grid.best_estimator_
dens_estt = DensityEstimator(estimator=dens_est)

factuals = data_X.get_rows_as_data(row_indicies)
factuals_target = data_y.get_rows_as_data(row_indicies)

# FACE configuration. The semantics of each threshold are defined by fatapi's
# FACEMethod; these are simply the values this demo was run with.
face_method = FACEMethod(factuals=factuals,
                         factuals_target=factuals_target,
                         model=face_model, kernel_type="kde",
                         t_prediction=0.5, epsilon=0.7,
                         t_density=0.0, t_radius_limit=1.10, n_neighbours=20,
                         K=10, conditions=conditionf, density_estimator=dens_estt)
face_method.explain()
print(f"Graph: {face_method.get_graph()}")
print(f"Paths: {face_method.get_explain_paths()}")
print(f"Candidate Indexes: {face_method.get_explain_candidates()}")

counterfactuals_as_indexes = face_method.get_counterfactuals(as_indexes=True)
counterfactuals = face_method.get_counterfactuals()
counterfactuals_data, counterfactuals_target = face_method.get_counterfactuals_as_data()

# Reuse the indexes fetched above rather than querying FACE again, so the
# printed list is exactly the one used in the summary lines below.
print(f"Counterfactuals: {counterfactuals_as_indexes}")

print(f"\nfor factual X[{row_indicies[0]}] (as data: {factuals.dataset[0]}), the counterfactual is X[{counterfactuals_as_indexes[0]}] (as classification: {counterfactuals[0]})")
print(f"\nX[{counterfactuals_as_indexes[0]}]: {counterfactuals_data[0]}, Y[{counterfactuals_as_indexes[0]}]: {counterfactuals_target[0]}")

[1 0 1 0 1]
0.88
Graph: [[  0.         112.57176708 220.16807468 197.39020876 138.25354424
  131.45324515  94.94605226 107.60612432 159.69028866 167.35360196
  168.19924302 194.14507213 108.97600195 188.67056034 133.16751049
  138.68799647 150.78044241 157.27496418 142.51660328 198.03055776
  159.1367721  174.45406741 182.69219212 184.96573136 111.34771086]
 [112.57176708   0.         178.64358421 167.75957386 133.94370394
  142.09837797 111.87947344 146.07103048 122.9530168  150.39504627
  114.23647882 140.94246804 130.15339729 150.19989541 135.30434116
  131.58366065 157.47805115 112.31574955 132.74981186 145.83920597
  182.52857628 157.59849275 173.94974939 168.88400231 151.32914559]
 [220.16807468 178.64358421   0.         168.03301662 140.83258562
  256.44172687 219.18518211 174.74552491 182.2396636  231.83429055
  173.37380049 135.78565636 176.17626343 188.815039   157.62208836
  227.74336801 198.89022188 179.88652883 221.53469572 179.1279021
  252.13336688 191.92474883 226.25920