Skip to content

Commit

Permalink
FIX: remove ipython import warning from top-level shap import (#3090)
Browse files Browse the repository at this point in the history
* Delay ipython warning until function called

* Better ImportError

* Promote warning to ImportError

* Remove unused import
  • Loading branch information
connortann committed Jul 13, 2023
1 parent bf0adf8 commit a0219d0
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 14 deletions.
6 changes: 3 additions & 3 deletions shap/__init__.py
@@ -1,5 +1,4 @@
# flake8: noqa
import warnings

__version__ = "0.42.0"

Expand All @@ -25,12 +24,12 @@

# plotting (only loaded if matplotlib is present)
def unsupported(*args, **kwargs):
warnings.warn(_no_matplotlib_warning)
raise ImportError(_no_matplotlib_warning)


class UnsupportedModule:
def __getattribute__(self, item):
raise ValueError(_no_matplotlib_warning)
raise ImportError(_no_matplotlib_warning)


try:
Expand All @@ -39,6 +38,7 @@ def __getattribute__(self, item):
except ImportError:
have_matplotlib = False
if have_matplotlib:
from . import plots
from .plots._beeswarm import summary_legacy as summary_plot
from .plots._decision import decision as decision_plot, multioutput_decision as multioutput_decision_plot
from .plots._scatter import dependence_legacy as dependence_plot
Expand Down
16 changes: 11 additions & 5 deletions shap/plots/_image.py
@@ -1,20 +1,20 @@
import json
import random
import string
import warnings
from typing import Optional

import matplotlib.pyplot as pl
import numpy as np
from matplotlib.colors import Colormap

from .._explanation import Explanation
from ..utils import ordinal_str

try:
from IPython.display import HTML, display
have_ipython = True
except ImportError:
warnings.warn("IPython could not be loaded!")
have_ipython = False

from .._explanation import Explanation
from ..utils import ordinal_str
from ..utils._legacy import kmeans
from . import colors

Expand Down Expand Up @@ -190,6 +190,12 @@ def image_to_text(shap_values):
for each sample
"""
if not have_ipython:
msg = (
"IPython is required for this function but is not installed."
" Fix this with `pip install ipython`."
)
raise ImportError(msg)

if len(shap_values.values.shape) == 5:
for i in range(shap_values.values.shape[0]):
Expand Down
23 changes: 17 additions & 6 deletions shap/plots/_text.py
Expand Up @@ -112,7 +112,7 @@ def values_min_max(values, base_values):
separator=separator, xmin=xmin, xmax=xmax, cmax=cmax, display=False
)
if display:
ipython_display(HTML(out))
_ipython_display_html(out)
return
else:
return out
Expand Down Expand Up @@ -211,7 +211,7 @@ def values_min_max(values, base_values):
out += "</div>"
out += "</div>"
if display:
ipython_display(HTML(out))
_ipython_display_html(out)
return
else:
return out
Expand Down Expand Up @@ -256,7 +256,7 @@ def values_min_max(values, base_values):
separator=separator, xmin=xmin, xmax=xmax, cmax=cmax, display=False
)
if display:
ipython_display(HTML(out))
_ipython_display_html(out)
return
else:
return out
Expand Down Expand Up @@ -325,7 +325,7 @@ def values_min_max(values, base_values):
out += "</div>"

if display:
ipython_display(HTML(out))
_ipython_display_html(out)
return
else:
return out
Expand Down Expand Up @@ -831,7 +831,7 @@ def merge_tokens(new_tokens, new_values, group_sizes, i):
+ "</div>" \
+ "</div>"

return ipython_display(HTML(out))
return _ipython_display_html(out)

def text_to_text(shap_values):

Expand Down Expand Up @@ -887,7 +887,7 @@ def text_to_text(shap_values):
</script>
"""

ipython_display(HTML(javascript + html))
_ipython_display_html(javascript + html)

def saliency_plot(shap_values):

Expand Down Expand Up @@ -1339,3 +1339,14 @@ def unpack_shap_explanation_contents(shap_values):
clustering = getattr(shap_values, "clustering", None)

return np.array(values), clustering


def _ipython_display_html(data):
"""Check IPython is installed, then display HTML"""
if not have_ipython:
msg = (
"IPython is required for this function but is not installed."
" Fix this with `pip install ipython`."
)
raise ImportError(msg)
return ipython_display(HTML(data))

0 comments on commit a0219d0

Please sign in to comment.