-
Notifications
You must be signed in to change notification settings - Fork 3.2k
/
test_partition.py
49 lines (32 loc) · 1.86 KB
/
test_partition.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""This file contains tests for the Partition explainer."""
import pickle
import shap
from . import common
def test_translation(basic_translation_scenario):
    """Run ``common.test_additivity`` with ``shap.explainers.PartitionExplainer``.

    The ``basic_translation_scenario`` argument is unpacked into a model, a
    tokenizer and data, which are forwarded to the shared check in that order.
    """
    scenario_model, scenario_tokenizer, scenario_data = basic_translation_scenario
    explainer_cls = shap.explainers.PartitionExplainer
    common.test_additivity(
        explainer_cls,
        scenario_model,
        scenario_tokenizer,
        scenario_data,
    )
def test_translation_auto(basic_translation_scenario):
    """Run ``common.test_additivity`` with the generic ``shap.Explainer``.

    No ``algorithm`` argument is given here; the model, tokenizer and data
    unpacked from ``basic_translation_scenario`` are forwarded unchanged.
    """
    scenario_model, scenario_tokenizer, scenario_data = basic_translation_scenario
    explainer_cls = shap.Explainer
    common.test_additivity(
        explainer_cls,
        scenario_model,
        scenario_tokenizer,
        scenario_data,
    )
def test_translation_algorithm_arg(basic_translation_scenario):
    """Run ``common.test_additivity`` with ``shap.Explainer`` and ``algorithm="partition"``.

    The model, tokenizer and data unpacked from ``basic_translation_scenario``
    are forwarded positionally; ``algorithm`` is passed as a keyword argument.
    """
    scenario_model, scenario_tokenizer, scenario_data = basic_translation_scenario
    explainer_cls = shap.Explainer
    common.test_additivity(
        explainer_cls,
        scenario_model,
        scenario_tokenizer,
        scenario_data,
        algorithm="partition",
    )
def test_tabular_single_output():
    """Run ``common.test_additivity`` on ``model.predict`` with a Partition masker.

    The model and data come from ``common.basic_xgboost_scenario(100)``; the
    masker is built from that same data with ``shap.maskers.Partition``.
    """
    xgb_model, tabular_data = common.basic_xgboost_scenario(100)
    predict_fn = xgb_model.predict
    partition_masker = shap.maskers.Partition(tabular_data)
    common.test_additivity(
        shap.explainers.PartitionExplainer,
        predict_fn,
        partition_masker,
        tabular_data,
    )
def test_tabular_multi_output():
    """Run ``common.test_additivity`` on ``model.predict_proba`` with a Partition masker.

    The model and data come from ``common.basic_xgboost_scenario(100)``; the
    masker is built from that same data with ``shap.maskers.Partition``.
    """
    xgb_model, tabular_data = common.basic_xgboost_scenario(100)
    predict_proba_fn = xgb_model.predict_proba
    partition_masker = shap.maskers.Partition(tabular_data)
    common.test_additivity(
        shap.explainers.PartitionExplainer,
        predict_proba_fn,
        partition_masker,
        tabular_data,
    )
def test_serialization(basic_translation_scenario):
    """Run ``common.test_serialization`` with ``shap.explainers.PartitionExplainer``.

    No saver or loader overrides are passed; the model, tokenizer and data
    unpacked from ``basic_translation_scenario`` are forwarded in that order.
    """
    scenario_model, scenario_tokenizer, scenario_data = basic_translation_scenario
    explainer_cls = shap.explainers.PartitionExplainer
    common.test_serialization(
        explainer_cls,
        scenario_model,
        scenario_tokenizer,
        scenario_data,
    )
def test_serialization_no_model_or_masker(basic_translation_scenario):
    """Call ``common.test_serialization`` with both savers set to ``None``.

    The loaders ignore their argument and hand back the model and tokenizer
    unpacked from ``basic_translation_scenario``.
    """
    scenario_model, scenario_tokenizer, scenario_data = basic_translation_scenario

    def reuse_model(_):
        # Ignore the argument; return the already-unpacked model.
        return scenario_model

    def reuse_tokenizer(_):
        # Ignore the argument; return the already-unpacked tokenizer.
        return scenario_tokenizer

    common.test_serialization(
        shap.explainers.Partition,
        scenario_model,
        scenario_tokenizer,
        scenario_data,
        model_saver=None,
        masker_saver=None,
        model_loader=reuse_model,
        masker_loader=reuse_tokenizer,
    )
def test_serialization_custom_model_save(basic_translation_scenario):
    """Call ``common.test_serialization`` with pickle-based model save/load hooks.

    ``pickle.dump`` is passed as ``model_saver`` and ``pickle.load`` as
    ``model_loader``; the model, tokenizer and data unpacked from
    ``basic_translation_scenario`` are forwarded positionally.
    """
    scenario_model, scenario_tokenizer, scenario_data = basic_translation_scenario
    explainer_cls = shap.explainers.PartitionExplainer
    common.test_serialization(
        explainer_cls,
        scenario_model,
        scenario_tokenizer,
        scenario_data,
        model_saver=pickle.dump,
        model_loader=pickle.load,
    )