Skip to content

Commit

Permalink
Add Surrogate Explainer (#19)
Browse files Browse the repository at this point in the history
* Starting to work on the Surrogate

* Implemented remaining functions of the SurrogateExplainer

* Delete skanderkamoun.xml

* Modified SurrogateExplainer as per comments

* bug fixes and tests
  • Loading branch information
skanderkam authored and aredier committed Oct 8, 2019
1 parent 67b547e commit 67371f2
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 8 deletions.
18 changes: 18 additions & 0 deletions tests/test_surogate_explainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

from trelawney.surrogate_explainer import SurrogateExplainer


def test_surogate_explainer_single(fake_dataset, fitted_decision_tree):
    """
    With a single observation to explain, the surrogate's importance for 'real' must exceed the one for 'fake'
    in absolute value.
    """
    surrogate_tree = DecisionTreeClassifier(max_depth=3)
    surrogate_explainer = SurrogateExplainer(surrogate_tree)
    surrogate_explainer.fit(fitted_decision_tree, *fake_dataset)

    observation = pd.DataFrame([[30, 0.1]], columns=['real', 'fake'])
    importances = surrogate_explainer.feature_importance(observation)

    real_weight = abs(importances['real'])
    fake_weight = abs(importances['fake'])
    assert real_weight > fake_weight


def test_surogate_explainer_multiple(fake_dataset, fitted_decision_tree):
    """
    With several observations to explain, the surrogate's importance for 'real' must exceed the one for 'fake'
    in absolute value.
    """
    surrogate_tree = DecisionTreeClassifier(max_depth=3)
    surrogate_explainer = SurrogateExplainer(surrogate_tree)
    surrogate_explainer.fit(fitted_decision_tree, *fake_dataset)

    observations = pd.DataFrame([[5, 0.1], [95, -0.5]], columns=['real', 'fake'])
    importances = surrogate_explainer.feature_importance(observations)

    real_weight = abs(importances['real'])
    fake_weight = abs(importances['fake'])
    assert real_weight > fake_weight
2 changes: 2 additions & 0 deletions trelawney/lime_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class LimeExplainer(BaseExplainer):
>>> # creating and fiting the explainer
>>> explainer = LimeExplainer()
>>> explainer.fit(model, X, y)
<trelawney.lime_explainer.LimeExplainer object at ...>
>>> # explaining observation
>>> explanation = explainer.explain_local(pd.DataFrame([[5, 0.1]]))[0]
>>> abs(explanation['real']) > abs(explanation['fake'])
Expand All @@ -54,6 +55,7 @@ def fit(self, model: sklearn.base.BaseEstimator, x_train: pd.DataFrame, y_train:
class_names=self.class_names,
categorical_features=self.categorical_features,
discretize_continuous=True)
return self

def feature_importance(self, x_explain: pd.DataFrame, n_cols: Optional[int] = None) -> Dict[str, float]:
raise NotImplementedError('we are not sure global explaination is mathematically sound for LIME, it is still'
Expand Down
1 change: 1 addition & 0 deletions trelawney/shap_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def _find_right_explainer(self, x_train):
def fit(self, model: sklearn.base.BaseEstimator, x_train: pd.DataFrame, y_train: pd.DataFrame):
    """
    Runs the parent class' fit, then selects the underlying explainer from the training data.

    :param model: the model to explain
    :param x_train: the features the explainer is fitted on
    :param y_train: the targets associated with x_train
    :return: the explainer itself, so that calls can be chained
    """
    # the parent fit runs first, as in the original ordering, before the explainer is selected
    super().fit(model, x_train, y_train)
    selected_explainer = self._find_right_explainer(x_train)
    self._explainer = selected_explainer
    return self

def _get_shap_values(self, x_explain):
shap_values = self._explainer.shap_values(x_explain.values)
Expand Down
80 changes: 80 additions & 0 deletions trelawney/surrogate_explainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import operator
from typing import List, Optional, Dict, Callable, Tuple, Union

import numpy as np
import pandas as pd
import sklearn
import sklearn.exceptions
import sklearn.linear_model
import sklearn.metrics
import sklearn.tree

import trelawney.tree_explainer
from trelawney.base_explainer import BaseExplainer
# import trelawney.logreg_explainer


class SurrogateExplainer(BaseExplainer):
    """
    Explains a model through a surrogate: a simpler substitution model trained to reproduce the predictions of the
    initial model, which is then explained in its place.

    Only single decision trees are implemented as surrogates for the time being. A linear regression is accepted by
    the constructor but `fit` raises `NotImplementedError` for it.
    """

    def __init__(self, surrogate_model: sklearn.base.BaseEstimator):
        """
        :param surrogate_model: the estimator used as surrogate, a `DecisionTreeClassifier` or a `LinearRegression`
            (subclasses included); it is fitted by `fit`
        :raises NotImplementedError: if the surrogate is of any other type
        """
        # public scikit-learn paths are used on purpose: the private `sklearn.tree.tree` and
        # `sklearn.linear_model.base` modules are not part of the public API
        supported_surrogates = (sklearn.tree.DecisionTreeClassifier, sklearn.linear_model.LinearRegression)
        if not isinstance(surrogate_model, supported_surrogates):
            raise NotImplementedError(
                'SurrogateExplainer is only available for single trees (single_tree) and logistic '
                'regression (logistic_regression) at the time being.')
        self._surrogate = surrogate_model
        self._explainer = None
        self._x_train = None
        self._model_to_explain = None
        self._adequation_metric = None

    def _has_tree_surrogate(self) -> bool:
        """tells whether the surrogate is a single decision tree"""
        return isinstance(self._surrogate, sklearn.tree.DecisionTreeClassifier)

    def _check_is_fitted(self):
        """
        raises a `NotFittedError` when `fit` has not built the underlying explainer yet. `NotFittedError` derives
        from `AttributeError`, which is what using the unfitted attributes used to raise.
        """
        if self._explainer is None:
            raise sklearn.exceptions.NotFittedError(
                'this SurrogateExplainer is not fitted yet, call `fit` before using it')

    def fit(self, model: sklearn.base.BaseEstimator, x_train: pd.DataFrame, y_train: pd.DataFrame):
        """
        trains the surrogate on the predictions of the model to explain and builds the explainer of the surrogate.

        :param model: the (already fitted) model to explain, only its `predict` method is used
        :param x_train: the dataset the surrogate is trained on
        :param y_train: the true targets; they are forwarded to the underlying explainer, the surrogate itself
            learns from the predictions of `model`
        :return: the explainer itself, so that calls can be chained
        :raises NotImplementedError: if the surrogate is not a single decision tree
        """
        self._model_to_explain = model
        self._x_train = x_train.values
        if not self._has_tree_surrogate():
            # TODO: linear surrogates, the sketched plan was to fit the surrogate on the predicted probabilities
            #  of the model, use sklearn.metrics.mean_squared_error as adequation metric and explain the surrogate
            #  with a trelawney.logreg_explainer.LogRegExplainer
            raise NotImplementedError('only single tree surrogates can be fitted at the time being')
        # the surrogate learns to mimic the model, hence the predictions as targets instead of y_train
        self._surrogate.fit(self._x_train, self._model_to_explain.predict(self._x_train))
        self._adequation_metric = sklearn.metrics.accuracy_score
        self._explainer = trelawney.tree_explainer.TreeExplainer().fit(self._surrogate, x_train, y_train)
        return self

    def adequation_score(self, metric: Union[Callable[[np.ndarray, np.ndarray], float], str] = 'auto'):
        """
        returns an adequation score between the output of the surrogate and the output of the initial model based on
        the x_train set given.

        :param metric: 'auto' to use the metric selected by `fit` (accuracy for tree surrogates) or a callable
            receiving the predictions of the initial model first and the predictions of the surrogate second
        :return: the value returned by the metric
        :raises sklearn.exceptions.NotFittedError: if `fit` has not been called
        """
        self._check_is_fitted()
        # a custom metric only applies to this call, it must not replace the default selected by `fit`
        scorer = self._adequation_metric if metric == 'auto' else metric
        return scorer(self._model_to_explain.predict(self._x_train),
                      self._surrogate.predict(self._x_train))

    def feature_importance(self, x_explain: pd.DataFrame, n_cols: Optional[int] = None) -> Dict[str, float]:
        """
        returns a relative importance of each feature globally as a dict, as computed by the explainer of the
        surrogate.

        :param x_explain: the dataset to explain on
        :param n_cols: the maximum number of features to return
        :raises sklearn.exceptions.NotFittedError: if `fit` has not been called
        """
        self._check_is_fitted()
        return self._explainer.feature_importance(x_explain, n_cols)

    def explain_local(self, x_explain: pd.DataFrame, n_cols: Optional[int] = None) -> List[Dict[str, float]]:
        """
        returns local relative importance of features for a specific observation, as computed by the explainer of
        the surrogate.

        :param x_explain: the dataset to explain on
        :param n_cols: the maximum number of features to return
        :raises sklearn.exceptions.NotFittedError: if `fit` has not been called
        """
        self._check_is_fitted()
        return self._explainer.explain_local(x_explain, n_cols)

    def plot_tree(self, out_path: str = './tree_viz.png'):
        """
        plots the surrogate decision tree by delegating to the `plot_tree` method of the underlying tree explainer.

        :param out_path: the output path, forwarded unchanged to the tree explainer
        :return: whatever the `plot_tree` method of the tree explainer returns
        :raises TypeError: if the surrogate is not a single decision tree
        :raises sklearn.exceptions.NotFittedError: if `fit` has not been called
        """
        # NOTE(review): the TreeExplainer touched by the same change appends '.dot' and '.png' to out_path, the
        # default used here would then produce './tree_viz.png.png' -- confirm and align both defaults
        if not self._has_tree_surrogate():
            raise TypeError('plot_tree is only available for single tree surrogate')
        self._check_is_fitted()
        return self._explainer.plot_tree(out_path=out_path)

13 changes: 5 additions & 8 deletions trelawney/tree_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def fit(self, model: sklearn.base.BaseEstimator, x_train: pd.DataFrame, y_train:

self._model_to_explain = model
self._feature_names = x_train.columns
return self

def feature_importance(self, x_explain: pd.DataFrame, n_cols: Optional[int] = None) -> Dict[str, float]:
"""
Expand All @@ -52,16 +53,12 @@ def explain_local(self, x_explain: pd.DataFrame, n_cols: Optional[int] = None) -
"""
raise NotImplementedError('no consensus on which values can explain the path followed by an observation')

def plot_tree(self, out_path: str = './tree_viz.png'):
def plot_tree(self, out_path: str = './tree_viz'):
"""
creates a png file of the tree saved in out_path
:param out_path: the path to save the png representation of the tree to
"""
with tempfile.TemporaryDirectory() as dir_path:
dot_path = os.path.join(dir_path, 'tree.dot')

tree.export_graphviz(self._model_to_explain, out_file=dot_path + '.dot', filled=True, rounded=True,
special_characters=True, feature_names=self._feature_names,
class_names=self.class_names)
call(['dot', '-Tpng', dot_path, '-o', out_path, '-Gdpi=600'])
tree.export_graphviz(self._model_to_explain, out_file=out_path + '.dot', filled=True, rounded=True,
special_characters=True, feature_names=self._feature_names, class_names=self.class_names)
call(['dot', '-Tpng', out_path + '.dot', '-o', out_path + '.png', '-Gdpi=600'])

0 comments on commit 67371f2

Please sign in to comment.