-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Starting to work on the Surrogate * Implemented remaining functions of the SurrogateExplainer * Delete skanderkamoun.xml * Modified SurrogateExplainer as per comments * bug fixes and tests
- Loading branch information
1 parent
67b547e
commit 67371f2
Showing
5 changed files
with
106 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import pandas as pd | ||
from sklearn.tree import DecisionTreeClassifier | ||
|
||
from trelawney.surrogate_explainer import SurrogateExplainer | ||
|
||
|
||
def test_surogate_explainer_single(fake_dataset, fitted_decision_tree): | ||
explainer = SurrogateExplainer(DecisionTreeClassifier(max_depth=3)) | ||
explainer.fit(fitted_decision_tree, *fake_dataset) | ||
explanation = explainer.feature_importance(pd.DataFrame([[30, 0.1]], columns=['real', 'fake'])) | ||
assert abs(explanation['real']) > abs(explanation['fake']) | ||
|
||
|
||
def test_surogate_explainer_multiple(fake_dataset, fitted_decision_tree): | ||
explainer = SurrogateExplainer(DecisionTreeClassifier(max_depth=3)) | ||
explainer.fit(fitted_decision_tree, *fake_dataset) | ||
explanation = explainer.feature_importance(pd.DataFrame([[5, 0.1], [95, -0.5]], columns=['real', 'fake'])) | ||
assert abs(explanation['real']) > abs(explanation['fake']) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import operator | ||
from typing import List, Optional, Dict, Callable, Tuple, Union | ||
|
||
import pandas as pd | ||
import numpy as np | ||
import sklearn | ||
|
||
from trelawney.base_explainer import BaseExplainer | ||
import trelawney.tree_explainer | ||
# import trelawney.logreg_explainer | ||
|
||
|
||
class SurrogateExplainer(BaseExplainer): | ||
""" | ||
A surrogate model is a substitution model used to explain the initial model. Therefore, substitution models are | ||
generally simpler than the initial ones. Here, we use single trees and logistic regressions as surrogates. | ||
""" | ||
|
||
def __init__(self, surrogate_model: sklearn.base.BaseEstimator, ): | ||
if type(surrogate_model) not in [sklearn.tree.tree.DecisionTreeClassifier, | ||
sklearn.linear_model.base.LinearRegression]: | ||
raise NotImplementedError('SurrogateExplainer is only available for single trees (single_tree) and logistic' | ||
'regression (logistic_regression) at the time being.') | ||
self._surrogate = surrogate_model | ||
self._explainer = None | ||
self._x_train = None | ||
self._model_to_explain = None | ||
self._adequation_metric = None | ||
|
||
def fit(self, model: sklearn.base.BaseEstimator, x_train: pd.DataFrame, y_train: pd.DataFrame, ): | ||
self._model_to_explain = model | ||
self._x_train = x_train.values | ||
if type(self._surrogate) == sklearn.tree.tree.DecisionTreeClassifier: | ||
self._surrogate.fit(self._x_train, self._model_to_explain.predict(self._x_train)) | ||
self._adequation_metric = sklearn.metrics.accuracy_score | ||
self._explainer = trelawney.tree_explainer.TreeExplainer().fit(self._surrogate, x_train, y_train) | ||
else: | ||
raise NotImplementedError | ||
# self._surrogate.fit(x_train, self._model_to_explain.predict_probas(x_train)) | ||
# self._adequation_metric = sklearn.metrics.mean_squared_error | ||
# self._explainer = trelawney.logreg_explainer.LogRegExplainer().fit(self._surrogate) | ||
return self | ||
|
||
def adequation_score(self, metric: Union[Callable[[np.ndarray, np.ndarray], float], str] = 'auto', ): | ||
""" | ||
returns an adequation score between the output of the surrogate and the output of the initial model based on | ||
the x_train set given. | ||
""" | ||
if metric != 'auto': | ||
self._adequation_metric = metric | ||
return self._adequation_metric(self._model_to_explain.predict(self._x_train), | ||
self._surrogate.predict(self._x_train)) | ||
|
||
def feature_importance(self, x_explain: pd.DataFrame, n_cols: Optional[int] = None, ) -> Dict[str, float]: | ||
""" | ||
returns a relative importance of each feature globally as a dict. | ||
:param x_explain: the dataset to explain on | ||
:param n_cols: the maximum number of features to return | ||
""" | ||
return self._explainer.feature_importance(x_explain, n_cols) | ||
|
||
def explain_local(self, x_explain: pd.DataFrame, n_cols: Optional[int] = None, ) -> List[Dict[str, float]]: | ||
""" | ||
returns local relative importance of features for a specific observation. | ||
:param x_explain: the dataset to explain on | ||
:param n_cols: the maximum number of features to return | ||
""" | ||
return self._explainer.explain_local(x_explain, n_cols) | ||
|
||
def plot_tree(self, out_path: str = './tree_viz.png'): | ||
""" | ||
returns the colored plot of the decision tree and saves an Image in the wd. | ||
:param x_explain: the dataset to explain on | ||
:param n_cols: the maximum number of features to return | ||
:param out_file: name of the generated plot | ||
""" | ||
if type(self._surrogate) != sklearn.tree.tree.DecisionTreeClassifier: | ||
raise TypeError('plot_tree is only available for single tree surrogate') | ||
return self._explainer.plot_tree(out_path=out_path) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters