Skip to content

Commit

Permalink
add support for spark decision tree regressor
Browse files Browse the repository at this point in the history
  • Loading branch information
QuentinAmbard committed Oct 25, 2019
2 parents 6a3f622 + 51760b7 commit ab5b9e7
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 39 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Expand Up @@ -19,4 +19,5 @@ shap/_cext.so
/notebooks/deep_explainer/mnist_data
/data/CRIC*
/docs/artwork/local
/shap/_cext*
/shap/_cext*
.idea
56 changes: 35 additions & 21 deletions shap/explainers/tree.py
Expand Up @@ -381,6 +381,7 @@ def __init__(self, model, data=None, data_missing=None):
# we use names like keras
objective_name_map = {
"mse": "squared_error",
"variance": "squared_error",
"friedman_mse": "squared_error",
"reg:linear": "squared_error",
"reg:squarederror": "squared_error",
Expand Down Expand Up @@ -521,28 +522,35 @@ def __init__(self, model, data=None, data_missing=None):

self.trees = [Tree(e.tree_, scaling=model.learning_rate, data=data, data_missing=data_missing) for e in model.estimators_[:,0]]
self.objective = objective_name_map.get(model.criterion, None)
elif str(type(model)).endswith("pyspark.ml.classification.RandomForestClassificationModel'>") \
or str(type(model)).endswith("pyspark.ml.classification.GBTClassificationModel'>"):
import pyspark
elif "pyspark.ml" in str(type(model)):
assert_import("pyspark")
self.original_model = model
self.model_type = "pyspark"
self.trees = [Tree(tree, scaling=model.treeWeights[i]) for i, tree in enumerate(model.trees)]
if model._java_obj.getImpurity() == 'variance':
assert False, "Unsupported objective: variance"
# model._java_obj.getImpurity() can be gini, entropy or variance.
self.objective = objective_name_map.get(model._java_obj.getImpurity(), None)
self.tree_output = "raw_value"
elif str(type(model)).endswith("pyspark.ml.classification.DecisionTreeClassificationModel'>"):
import pyspark
self.original_model = model
self.model_type = "pyspark"
self.trees = [Tree(model, scaling=1)]
#model._java_obj.getImpurity() can be gini, entropy or variance.
if model._java_obj.getImpurity() == 'variance':
#TODO handle variance as loss?
assert False, "Unsupported objective: variance"
self.objective = objective_name_map.get(model._java_obj.getImpurity(), None)
#TODO base_offset?
self.tree_output = "raw_value"
if "Classification" in str(type(model)):
normalize = True
self.tree_output = "probability"
else:
normalize = False
self.tree_output = "raw_value"
# Spark Random forest, create 1 weighted (avg) tree per sub-model
if str(type(model)).endswith("pyspark.ml.classification.RandomForestClassificationModel'>") \
or str(type(model)).endswith("pyspark.ml.regression.RandomForestRegressionModel'>"):
sum_weight = sum(model.treeWeights) # output is average of trees
self.trees = [Tree(tree, normalize=normalize, scaling=model.treeWeights[i]/sum_weight) for i, tree in enumerate(model.trees)]
# Spark GBT, create 1 weighted (learning rate) tree per sub-model
elif str(type(model)).endswith("pyspark.ml.classification.GBTClassificationModel'>") \
or str(type(model)).endswith("pyspark.ml.regression.GBTRegressionModel'>"):
self.objective = "squared_error" # GBT subtree use the variance
self.tree_output = "raw_value"
self.trees = [Tree(tree, normalize=False, scaling=model.treeWeights[i]) for i, tree in enumerate(model.trees)]
# Spark Basic model (single tree)
elif str(type(model)).endswith("pyspark.ml.classification.DecisionTreeClassificationModel'>") \
or str(type(model)).endswith("pyspark.ml.regression.DecisionTreeRegressionModel'>"):
self.trees = [Tree(model, normalize=normalize, scaling=1)]
else:
assert False, "Unsupported Spark model type: " + str(type(model))
elif str(type(model)).endswith("xgboost.core.Booster'>"):
import xgboost
self.original_model = model
Expand Down Expand Up @@ -835,7 +843,8 @@ def __init__(self, tree, normalize=False, scaling=1.0, data=None, data_missing=N
self.values = tree["value"] * scaling
self.node_sample_weight = tree["node_sample_weight"]

elif str(type(tree)).endswith("pyspark.ml.classification.DecisionTreeClassificationModel'>"):
elif str(type(tree)).endswith("pyspark.ml.classification.DecisionTreeClassificationModel'>") \
or str(type(tree)).endswith("pyspark.ml.regression.DecisionTreeRegressionModel'>"):
#model._java_obj.numNodes() doesn't give leaves, need to recompute the size
def getNumNodes(node, size):
size = size + 1
Expand All @@ -855,7 +864,10 @@ def getNumNodes(node, size):
self.node_sample_weight = np.full(num_nodes, -2, dtype=np.float64)
def buildTree(index, node):
index = index + 1
self.values[index] = [e for e in node.impurityStats().stats()] #NDarray(numLabel): 1 per label: number of item for each label which went through this node
if tree._java_obj.getImpurity() == 'variance':
self.values[index] = [node.prediction()] #prediction for the node
else:
self.values[index] = [e for e in node.impurityStats().stats()] #for gini: NDarray(numLabel): 1 per label: number of item for each label which went through this node
self.node_sample_weight[index] = node.impurityStats().count() #weighted count of element trough this node

if node.subtreeDepth() == 0:
Expand All @@ -877,6 +889,8 @@ def buildTree(index, node):
#default Not supported with mlib? (TODO)
self.children_default = self.children_left
self.values = np.asarray(self.values)
if normalize:
self.values = (self.values.T / self.values.sum(1)).T
self.values = self.values * scaling

elif type(tree) == dict and 'tree_structure' in tree:
Expand Down
84 changes: 67 additions & 17 deletions tests/explainers/test_tree.py
Expand Up @@ -187,17 +187,17 @@ def test_xgboost_mixed_types():
shap_values = shap.TreeExplainer(bst).shap_values(X)
shap.dependence_plot(0, shap_values, X, show=False)

def test_pyspark_decision_tree():
def test_pyspark_classifier_decision_tree():
try:
import pyspark
import sklearn.datasets
from pyspark.sql import SparkSession
from pyspark import SparkContext, SparkConf
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.classification import RandomForestClassifier, DecisionTreeClassifier, GBTClassifier
import pandas as pd
except:
print("Skipping test_pyspark_decision_tree!")
print("Skipping test_pyspark_classifier_decision_tree!")
return
import shap

Expand All @@ -210,22 +210,71 @@ def test_pyspark_decision_tree():
iris = VectorAssembler(inputCols=col[:-1],outputCol="features").transform(iris)
iris = StringIndexer(inputCol="type", outputCol="label").fit(iris).transform(iris)

dt = DecisionTreeClassifier(labelCol="label", featuresCol="features")
model = dt.fit(iris)
explainer = shap.TreeExplainer(model)
X = pd.DataFrame(data=iris_sk.data, columns=iris_sk.feature_names)[:100] # pylint: disable=E1101
classifiers = [GBTClassifier(labelCol="label", featuresCol="features"),
RandomForestClassifier(labelCol="label", featuresCol="features"),
DecisionTreeClassifier(labelCol="label", featuresCol="features")]
for classifier in classifiers:
model = classifier.fit(iris)
explainer = shap.TreeExplainer(model)
X = pd.DataFrame(data=iris_sk.data, columns=iris_sk.feature_names)[:100] # pylint: disable=E1101

shap_values = explainer.shap_values(X)
expected_values = explainer.expected_value
shap_values = explainer.shap_values(X)
expected_values = explainer.expected_value

# validate values sum to the margin prediction of the model plus expected_value
predictions = model.transform(iris).select("rawPrediction")\
.rdd.map(lambda x:[float(y) for y in x['rawPrediction']]).toDF(['class0','class1']).toPandas()
diffs = expected_values[0] + shap_values[0].sum(1) - predictions.class0
assert np.max(np.abs(diffs)) < 1e-6, "SHAP values don't sum to model output for class0!"
diffs = expected_values[1] + shap_values[1].sum(1) - predictions.class1
assert np.max(np.abs(diffs)) < 1e-6, "SHAP values don't sum to model output for class1!"
assert (np.abs(expected_values - predictions.mean()) < 1e-6).all(), "Bad expected_value!"
predictions = model.transform(iris).select("rawPrediction")\
.rdd.map(lambda x:[float(y) for y in x['rawPrediction']]).toDF(['class0','class1']).toPandas()

if str(type(model)).endswith("GBTClassificationModel'>"):
diffs = expected_values + shap_values.sum(1) - predictions.class1
assert np.max(np.abs(diffs)) < 1e-6, "SHAP values don't sum to model output for class0!"
else:
normalizedPredictions = (predictions.T / predictions.sum(1)).T
diffs = expected_values[0] + shap_values[0].sum(1) - normalizedPredictions.class0
assert np.max(np.abs(diffs)) < 1e-6, "SHAP values don't sum to model output for class0!"+model
diffs = expected_values[1] + shap_values[1].sum(1) - normalizedPredictions.class1
assert np.max(np.abs(diffs)) < 1e-6, "SHAP values don't sum to model output for class1!"+model
assert (np.abs(expected_values - normalizedPredictions.mean()) < 1e-1).all(), "Bad expected_value!"+model
spark.stop()

def test_pyspark_regression_decision_tree():
    """Check SHAP additivity for the Spark ML tree regressors.

    Fits a GBTRegressor, a RandomForestRegressor and a DecisionTreeRegressor
    on the first 100 iris rows, predicting sepal length from the three other
    measurements. For each fitted model the test asserts that:

    * ``expected_value + shap_values.sum(1)`` matches the model's
      ``prediction`` column to within 1e-6, and
    * ``expected_value`` is within 1e-1 of the mean prediction.

    The test prints a message and returns early when pyspark, scikit-learn
    or pandas cannot be imported.
    """
    try:
        # Optional dependencies: only a missing package should skip the test,
        # any other failure must surface instead of being silently swallowed.
        from pyspark import SparkConf
        from pyspark.sql import SparkSession
        from pyspark.ml.feature import VectorAssembler
        from pyspark.ml.regression import DecisionTreeRegressor, GBTRegressor, RandomForestRegressor
        import sklearn.datasets
        import pandas as pd
    except ImportError:
        print("Skipping test_pyspark_regression_decision_tree!")
        return
    import shap

    iris_sk = sklearn.datasets.load_iris()
    iris = pd.DataFrame(
        data=np.c_[iris_sk['data'], iris_sk['target']],
        columns=iris_sk['feature_names'] + ['target'],
    )[:100]

    # Features handed to the explainer: every measurement except the label.
    # The frame does not depend on the fitted model, so build it once.
    # `columns=` replaces the positional axis argument, which pandas >= 2.0
    # no longer accepts.
    X = pd.DataFrame(data=iris_sk.data, columns=iris_sk.feature_names).drop(columns=['sepal length (cm)'])[:100]  # pylint: disable=E1101

    spark = SparkSession.builder.config(conf=SparkConf().set("spark.master", "local[*]")).getOrCreate()
    try:
        # Simple regressor: try to predict sepal length based on the other features
        col = ["sepal_length", "sepal_width", "petal_length", "petal_width", "type"]
        iris = spark.createDataFrame(iris, col).drop("type")
        iris = VectorAssembler(inputCols=col[1:-1], outputCol="features").transform(iris)

        regressors = [
            GBTRegressor(labelCol="sepal_length", featuresCol="features"),
            RandomForestRegressor(labelCol="sepal_length", featuresCol="features"),
            DecisionTreeRegressor(labelCol="sepal_length", featuresCol="features"),
        ]
        for regressor in regressors:
            model = regressor.fit(iris)
            explainer = shap.TreeExplainer(model)

            shap_values = explainer.shap_values(X)
            expected_values = explainer.expected_value

            # validate values sum to the prediction of the model plus expected_value
            predictions = model.transform(iris).select("prediction").toPandas()
            diffs = expected_values + shap_values.sum(1) - predictions["prediction"]
            assert np.max(np.abs(diffs)) < 1e-6, "SHAP values don't sum to model output!"
            assert (np.abs(expected_values - predictions.mean()) < 1e-1).all(), "Bad expected_value!"
    finally:
        # Always release the Spark session, even when an assertion fails, so
        # it does not leak into the tests that run afterwards.
        spark.stop()

def test_sklearn_random_forest_multiclass():
Expand Down Expand Up @@ -955,6 +1004,7 @@ def test_xgboost_classifier_independent_margin():

assert np.allclose(shap_values.sum(1) + e.expected_value, model.predict(X, output_margin=True))


def test_xgboost_classifier_independent_probability():
try:
import xgboost
Expand Down

0 comments on commit ab5b9e7

Please sign in to comment.