In [1]:
import os
import sys
import inspect

sys.path.insert(1, os.path.join(sys.path[0], '..'))

import fatapi
from fatapi.data import Data
from fatapi.model import BlackBox, Model, DensityEstimator
import numpy as np
from fatapi.methods import FACEMethod
from sklearn.neural_network import MLPClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KernelDensity
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=100, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=1)
clf = MLPClassifier(random_state=1, max_iter=300).fit(X_train, y_train)

print(f"Predicted classes of X_test[:5, :]: {clf.predict(X_test[:5, :])}")
print(f"Classification accuracy: {clf.score(X_test, y_test)}\n")
data_X = Data(dataset=X_test, dtype="data", encoded=True)
data_y = Data(dataset=y_test, dtype="target", encoded=True)

# Has to return a boolean 
def conditionf(**kwargs):
    return True

row_indicies = [0,1,2,3,4]
blackb = BlackBox(clf)
face_model = Model(data_X, data_y, blackbox=blackb)
bandwidths = 10 ** np.linspace(-2, 1, 100)  
grid = GridSearchCV(KernelDensity(kernel='gaussian'),
                    {'bandwidth': bandwidths},
                    cv=20)

grid.fit(data_X.dataset)
dens_est = grid.best_estimator_
dens_estt = DensityEstimator(estimator=dens_est)
factuals = data_X.get_rows_as_data(row_indicies)
factuals_target = data_y.get_rows_as_data(row_indicies)

face_method = FACEMethod(factuals=factuals, 
                         factuals_target=factuals_target, 
                         model=face_model, kernel_type="kde", 
                         t_prediction=0.5, epsilon=0.7,
                         t_density=0.0, t_radius_limit=1.10, n_neighbours=20,
                         K=10, conditions=conditionf, density_estimator=dens_estt)
face_method.explain()
print(f"Graph [Distances]: {face_method.get_graph()}\n")
print(f"Paths [Indexes]: {face_method.get_explain_paths()}\n")
print(f"Candidates for Counterfactuals [Indexes]: {face_method.get_explain_candidates()}")

counterfactuals_as_indexes = face_method.get_counterfactuals(as_indexes=True)
counterfactuals = face_method.get_counterfactuals()
counterfactuals_data, counterfactuals_target = face_method.get_counterfactuals_as_data()

print(f"Counterfactuals [Indexes]: {face_method.get_counterfactuals(True)}")

print(f"\nfor factual X[{row_indicies[0]}] (as data: {factuals.dataset[0]}), the counterfactual is X[{counterfactuals_as_indexes[0]}] (as classification: {counterfactuals[0]})")
print(f"\nCounterfactual for X[0] (X[6]) as target (Y) and data (X): \nX[{counterfactuals_as_indexes[0]}]: {counterfactuals_data[0]}, Y[{counterfactuals_as_indexes[0]}]: {counterfactuals_target[0]}")

The history saving thread hit an unexpected error (DatabaseError('database disk image is malformed')).History will not be written to the database.
Predicted classes of X_test[:5, :]: [1 0 1 0 1]
Classification accuracy: 0.88



ValueError: Invalid argument in __init__: Data: [[ 6.06548400e-01  8.16957655e-01  1.05132077e+00  1.65712464e+00
  -4.59717681e-01 -5.88963928e-01  6.50323214e-01  5.03170861e-02
   1.70548352e+00  1.41767401e+00 -7.37289628e-01  3.09816759e-01
   7.78174179e-01 -1.12478707e+00 -1.28393266e+00  8.07509886e-02
   9.49961101e-02 -3.64538050e-01  1.64665066e-01 -7.67803746e-01]
 [-2.90545028e-03 -1.18951588e+00 -2.92578935e-01  1.49640531e+00
  -1.20115566e+00 -9.55756520e-02  5.37705087e-01 -3.48471140e-01
  -2.99094967e-01  1.67072922e+00 -6.06303023e-02  3.29489967e-01
   4.40956001e-01  3.83606926e-01  8.56514986e-01 -2.82005902e-01
   8.15600360e-01  2.19477494e-01  9.64022632e-01  4.99224881e-02]
 [-2.79144404e+00  1.06080576e+00 -2.26261533e+00  4.89219194e-01
   2.09012280e-01  7.90061048e-01  2.02447357e-01  1.52186577e+00
  -1.76440808e+00 -1.11683226e+00  2.00556158e+00 -1.76936557e-01
   1.55356522e+00  5.85308978e-01 -1.47841532e-02 -1.22606996e+00
  -4.19717870e-01 -4.82726862e-01  4.52713622e-01 -3.84824623e-02]
 [-2.51852197e-01  3.29535191e-01 -3.01603199e+00  3.00253676e-01
   1.16533544e+00  1.12643073e+00 -2.19552167e-01  8.05913307e-01
   1.96645295e-01 -4.29595674e-01 -1.32648965e+00  1.00819561e+00
   1.11548937e+00  1.34104147e+00  9.31688688e-01  2.00514053e+00
  -1.61964569e+00 -1.48941229e-01  3.08204134e-01 -1.87626349e-01]
 [-9.77773002e-01  1.11298159e+00 -3.69255902e-01  1.00568668e+00
   1.06032751e+00 -3.27882802e-01 -3.48984191e-01  6.14726276e-01
  -1.71116766e+00  3.53567216e-01  1.71957132e-01 -5.22356465e-01
  -1.39528303e+00 -9.08018711e-01 -1.24490005e+00 -2.60466059e-01
   2.65642403e-01  9.81122462e-02  4.90561044e-01  4.45096710e-01]], Categoricals: [], Numericals: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], IsEncoded: True is not of type <class 'numpy.ndarray'>