From a0219d00d9f713b8f06771ddb50d91a6fbe66beb Mon Sep 17 00:00:00 2001 From: connortann <71127464+connortann@users.noreply.github.com> Date: Thu, 13 Jul 2023 13:38:22 +0100 Subject: [PATCH] FIX: remove ipython import warning from top-level shap import (#3090) * Delay ipython warning until function called * Better ImportError * Promote warning to ImportError * Remove unused import --- shap/__init__.py | 6 +++--- shap/plots/_image.py | 16 +++++++++++----- shap/plots/_text.py | 23 +++++++++++++++++------ 3 files changed, 31 insertions(+), 14 deletions(-) diff --git a/shap/__init__.py b/shap/__init__.py index 41a7a5336..be0baf211 100644 --- a/shap/__init__.py +++ b/shap/__init__.py @@ -1,5 +1,4 @@ # flake8: noqa -import warnings __version__ = "0.42.0" @@ -25,12 +24,12 @@ # plotting (only loaded if matplotlib is present) def unsupported(*args, **kwargs): - warnings.warn(_no_matplotlib_warning) + raise ImportError(_no_matplotlib_warning) class UnsupportedModule: def __getattribute__(self, item): - raise ValueError(_no_matplotlib_warning) + raise ImportError(_no_matplotlib_warning) try: @@ -39,6 +38,7 @@ def __getattribute__(self, item): except ImportError: have_matplotlib = False if have_matplotlib: + from . import plots from .plots._beeswarm import summary_legacy as summary_plot from .plots._decision import decision as decision_plot, multioutput_decision as multioutput_decision_plot from .plots._scatter import dependence_legacy as dependence_plot diff --git a/shap/plots/_image.py b/shap/plots/_image.py index 10c643c46..7fa9a4f59 100644 --- a/shap/plots/_image.py +++ b/shap/plots/_image.py @@ -1,20 +1,20 @@ import json import random import string -import warnings from typing import Optional import matplotlib.pyplot as pl import numpy as np from matplotlib.colors import Colormap -from .._explanation import Explanation -from ..utils import ordinal_str - try: from IPython.display import HTML, display + have_ipython = True except ImportError: - warnings.warn("IPython could not be loaded!") + have_ipython = False + +from .._explanation import Explanation +from ..utils import ordinal_str from ..utils._legacy import kmeans from . import colors @@ -190,6 +190,12 @@ def image_to_text(shap_values): for each sample """ + if not have_ipython: + msg = ( + "IPython is required for this function but is not installed." + " Fix this with `pip install ipython`." + ) + raise ImportError(msg) if len(shap_values.values.shape) == 5: for i in range(shap_values.values.shape[0]): diff --git a/shap/plots/_text.py b/shap/plots/_text.py index ffead86b6..7e3944547 100644 --- a/shap/plots/_text.py +++ b/shap/plots/_text.py @@ -112,7 +112,7 @@ def values_min_max(values, base_values): separator=separator, xmin=xmin, xmax=xmax, cmax=cmax, display=False ) if display: - ipython_display(HTML(out)) + _ipython_display_html(out) return else: return out @@ -211,7 +211,7 @@ def values_min_max(values, base_values): out += "" out += "" if display: - ipython_display(HTML(out)) + _ipython_display_html(out) return else: return out @@ -256,7 +256,7 @@ def values_min_max(values, base_values): separator=separator, xmin=xmin, xmax=xmax, cmax=cmax, display=False ) if display: - ipython_display(HTML(out)) + _ipython_display_html(out) return else: return out @@ -325,7 +325,7 @@ def values_min_max(values, base_values): out += "" if display: - ipython_display(HTML(out)) + _ipython_display_html(out) return else: return out @@ -831,7 +831,7 @@ def merge_tokens(new_tokens, new_values, group_sizes, i): + "" \ + "" - return ipython_display(HTML(out)) + return _ipython_display_html(out) def text_to_text(shap_values): @@ -887,7 +887,7 @@ def text_to_text(shap_values): """ - ipython_display(HTML(javascript + html)) + _ipython_display_html(javascript + html) def saliency_plot(shap_values): @@ -1339,3 +1339,14 @@ def unpack_shap_explanation_contents(shap_values): clustering = getattr(shap_values, "clustering", None) return np.array(values), clustering + + +def _ipython_display_html(data): + """Check IPython is installed, then display HTML""" + if not have_ipython: + msg = ( + "IPython is required for this function but is not installed." + " Fix this with `pip install ipython`." + ) + raise ImportError(msg) + return ipython_display(HTML(data))