Skip to content

Commit

Permalink
Update slicer pin to 0.0.8 and add test (#3560)
Browse files Browse the repository at this point in the history
* install latest slicer commit and add test

* add comment

* fix test

* Pin slicer to 0.08

* Fix typo in pin

---------

Co-authored-by: connortann <71127464+connortann@users.noreply.github.com>
  • Loading branch information
CloseChoice and connortann committed Apr 5, 2024
1 parent bbbe821 commit dffc346
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyproject.toml
Expand Up @@ -18,7 +18,7 @@ dependencies = [
'pandas',
'tqdm>=4.27.0',
'packaging>20.9',
'slicer==0.0.7',
'slicer==0.0.8',
'numba',
'cloudpickle'
]
Expand Down
16 changes: 16 additions & 0 deletions tests/explainers/test_deep.py
Expand Up @@ -275,6 +275,22 @@ def test_tf_keras_imdb_lstm(random_seed):
np.testing.testing_allclose(sums, diff, atol=1e-02), "Sum of SHAP values does not match difference!"


def test_tf_deep_imbdb_transformers():
# GH 3522
transformers = pytest.importorskip('transformers')

from shap import models

# data from datasets imdb dataset
short_data = ['I lov', 'Worth', 'its a', 'STAR ', 'First', 'I had', 'Isaac', 'It ac', 'Techn', 'Hones']
classifier = transformers.pipeline("sentiment-analysis", return_all_scores=True)
pmodel = models.TransformersPipeline(classifier, rescale_to_logits=True)
explainer3 = shap.Explainer(pmodel, classifier.tokenizer)
shap_values3 = explainer3(short_data[:10])
shap.plots.text(shap_values3[:, :, 1])
shap.plots.bar(shap_values3[:, :, 1].mean(0))


def test_tf_deep_multi_inputs_multi_outputs():
tf = pytest.importorskip('tensorflow')

Expand Down

0 comments on commit dffc346

Please sign in to comment.