Skip to content

Commit

Permalink
ENH Adds HTML visualizations for estimators (#14180)
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasjpfan authored and adrinjalali committed Apr 30, 2020
1 parent b5a1417 commit 425564b
Show file tree
Hide file tree
Showing 15 changed files with 732 additions and 5 deletions.
1 change: 1 addition & 0 deletions doc/modules/classes.rst
Expand Up @@ -1569,6 +1569,7 @@ Plotting
utils.deprecated
utils.estimator_checks.check_estimator
utils.estimator_checks.parametrize_with_checks
utils.estimator_html_repr
utils.extmath.safe_sparse_dot
utils.extmath.randomized_range_finder
utils.extmath.randomized_svd
Expand Down
25 changes: 25 additions & 0 deletions doc/modules/compose.rst
Expand Up @@ -528,6 +528,31 @@ above example would be::
('countvectorizer', CountVectorizer(),
'title')])

.. _visualizing_composite_estimators:

Visualizing Composite Estimators
================================

Estimators can be displayed with a HTML representation when shown in a
jupyter notebook. This can be useful to diagnose or visualize a Pipeline with
many estimators. This visualization is activated by setting the
`display` option in :func:`sklearn.set_config`::

>>> from sklearn import set_config
>>> set_config(display='diagram') # doctest: +SKIP
>>> # diplays HTML representation in a jupyter context
>>> column_trans # doctest: +SKIP

An example of the HTML output can be seen in the
**HTML representation of Pipeline** section of
:ref:`sphx_glr_auto_examples_compose_plot_column_transformer_mixed_types.py`.
As an alternative, the HTML can be written to a file using
:func:`~sklearn.utils.estimator_html_repr`::

>>> from sklearn.utils import estimator_html_repr
>>> with open('my_estimator.html', 'w') as f: # doctest: +SKIP
... f.write(estimator_html_repr(clf))

.. topic:: Examples:

* :ref:`sphx_glr_auto_examples_compose_plot_column_transformer.py`
Expand Down
8 changes: 8 additions & 0 deletions doc/whats_new/v0.23.rst
Expand Up @@ -567,6 +567,9 @@ Changelog
:mod:`sklearn.utils`
....................

- |Feature| Adds :func:`utils.estimator_html_repr` for returning a
HTML representation of an estimator. :pr:`14180` by `Thomas Fan`_.

- |Enhancement| improve error message in :func:`utils.validation.column_or_1d`.
:pr:`15926` by :user:`Loïc Estève <lesteve>`.

Expand Down Expand Up @@ -605,6 +608,11 @@ Changelog
Miscellaneous
.............

- |MajorFeature| Adds a HTML representation of estimators to be shown in
a jupyter notebook or lab. This visualization is acitivated by setting the
`display` option in :func:`sklearn.set_config`. :pr:`14180` by
`Thomas Fan`_.

- |Enhancement| ``scikit-learn`` now works with ``mypy`` without errors.
:pr:`16726` by `Roman Yurchak`_.

Expand Down
9 changes: 9 additions & 0 deletions examples/compose/plot_column_transformer_mixed_types.py
Expand Up @@ -87,6 +87,15 @@
clf.fit(X_train, y_train)
print("model score: %.3f" % clf.score(X_test, y_test))

##############################################################################
# HTML representation of ``Pipeline``
###############################################################################
# When the ``Pipeline`` is printed out in a jupyter notebook an HTML
# representation of the estimator is displayed as follows:
from sklearn import set_config
set_config(display='diagram')
clf

###############################################################################
# Use ``ColumnTransformer`` by selecting column by data types
###############################################################################
Expand Down
19 changes: 18 additions & 1 deletion sklearn/_config.py
Expand Up @@ -7,6 +7,7 @@
'assume_finite': bool(os.environ.get('SKLEARN_ASSUME_FINITE', False)),
'working_memory': int(os.environ.get('SKLEARN_WORKING_MEMORY', 1024)),
'print_changed_only': True,
'display': 'text',
}


Expand All @@ -27,7 +28,7 @@ def get_config():


def set_config(assume_finite=None, working_memory=None,
print_changed_only=None):
print_changed_only=None, display=None):
"""Set global scikit-learn configuration
.. versionadded:: 0.19
Expand Down Expand Up @@ -59,6 +60,13 @@ def set_config(assume_finite=None, working_memory=None,
.. versionadded:: 0.21
display : {'text', 'diagram'}, optional
If 'diagram', estimators will be displayed as text in a jupyter lab
of notebook context. If 'text', estimators will be displayed as
text. Default is 'text'.
.. versionadded:: 0.23
See Also
--------
config_context: Context manager for global scikit-learn configuration
Expand All @@ -70,6 +78,8 @@ def set_config(assume_finite=None, working_memory=None,
_global_config['working_memory'] = working_memory
if print_changed_only is not None:
_global_config['print_changed_only'] = print_changed_only
if display is not None:
_global_config['display'] = display


@contextmanager
Expand Down Expand Up @@ -100,6 +110,13 @@ def config_context(**new_config):
.. versionchanged:: 0.23
Default changed from False to True.
display : {'text', 'diagram'}, optional
If 'diagram', estimators will be displayed as text in a jupyter lab
of notebook context. If 'text', estimators will be displayed as
text. Default is 'text'.
.. versionadded:: 0.23
Notes
-----
All settings, not just those presently modified, will be returned to
Expand Down
13 changes: 13 additions & 0 deletions sklearn/base.py
Expand Up @@ -17,9 +17,11 @@
import numpy as np

from . import __version__
from ._config import get_config
from .utils import _IS_32BIT
from .utils.validation import check_X_y
from .utils.validation import check_array
from .utils._estimator_html_repr import estimator_html_repr
from .utils.validation import _deprecate_positional_args

_DEFAULT_TAGS = {
Expand Down Expand Up @@ -435,6 +437,17 @@ def _validate_data(self, X, y=None, reset=True,

return out

def _repr_html_(self):
"""HTML representation of estimator"""
return estimator_html_repr(self)

def _repr_mimebundle_(self, **kwargs):
"""Mime bundle used by jupyter kernels to display estimator"""
output = {"text/plain": repr(self)}
if get_config()["display"] == 'diagram':
output["text/html"] = estimator_html_repr(self)
return output


class ClassifierMixin:
"""Mixin class for all classifiers in scikit-learn."""
Expand Down
6 changes: 6 additions & 0 deletions sklearn/compose/_column_transformer.py
Expand Up @@ -15,6 +15,7 @@
from joblib import Parallel, delayed

from ..base import clone, TransformerMixin
from ..utils._estimator_html_repr import _VisualBlock
from ..pipeline import _fit_transform_one, _transform_one, _name_estimators
from ..preprocessing import FunctionTransformer
from ..utils import Bunch
Expand Down Expand Up @@ -637,6 +638,11 @@ def _hstack(self, Xs):
Xs = [f.toarray() if sparse.issparse(f) else f for f in Xs]
return np.hstack(Xs)

def _sk_visual_block_(self):
names, transformers, name_details = zip(*self.transformers)
return _VisualBlock('parallel', transformers,
names=names, name_details=name_details)


def _check_X(X):
"""Use check_array only on lists and other non-array-likes / sparse"""
Expand Down
27 changes: 27 additions & 0 deletions sklearn/ensemble/_stacking.py
Expand Up @@ -13,6 +13,7 @@
from ..base import clone
from ..base import ClassifierMixin, RegressorMixin, TransformerMixin
from ..base import is_classifier, is_regressor
from ..utils._estimator_html_repr import _VisualBlock

from ._base import _fit_single_estimator
from ._base import _BaseHeterogeneousEnsemble
Expand Down Expand Up @@ -233,6 +234,14 @@ def predict(self, X, **predict_params):
self.transform(X), **predict_params
)

def _sk_visual_block_(self, final_estimator):
names, estimators = zip(*self.estimators)
parallel = _VisualBlock('parallel', estimators, names=names,
dash_wrapped=False)
serial = _VisualBlock('serial', (parallel, final_estimator),
dash_wrapped=False)
return _VisualBlock('serial', [serial])


class StackingClassifier(ClassifierMixin, _BaseStacking):
"""Stack of estimators with a final classifier.
Expand Down Expand Up @@ -496,6 +505,15 @@ def transform(self, X):
"""
return self._transform(X)

def _sk_visual_block_(self):
# If final_estimator's default changes then this should be
# updated.
if self.final_estimator is None:
final_estimator = LogisticRegression()
else:
final_estimator = self.final_estimator
return super()._sk_visual_block_(final_estimator)


class StackingRegressor(RegressorMixin, _BaseStacking):
"""Stack of estimators with a final regressor.
Expand Down Expand Up @@ -665,3 +683,12 @@ def transform(self, X):
Prediction outputs for each estimator.
"""
return self._transform(X)

def _sk_visual_block_(self):
# If final_estimator's default changes then this should be
# updated.
if self.final_estimator is None:
final_estimator = RidgeCV()
else:
final_estimator = self.final_estimator
return super()._sk_visual_block_(final_estimator)
5 changes: 5 additions & 0 deletions sklearn/ensemble/_voting.py
Expand Up @@ -32,6 +32,7 @@
from ..utils.validation import column_or_1d
from ..utils.validation import _deprecate_positional_args
from ..exceptions import NotFittedError
from ..utils._estimator_html_repr import _VisualBlock


class _BaseVoting(TransformerMixin, _BaseHeterogeneousEnsemble):
Expand Down Expand Up @@ -104,6 +105,10 @@ def n_features_in_(self):

return self.estimators_[0].n_features_in_

def _sk_visual_block_(self):
names, estimators = zip(*self.estimators)
return _VisualBlock('parallel', estimators, names=names)


class VotingClassifier(ClassifierMixin, _BaseVoting):
"""Soft Voting/Majority Rule classifier for unfitted estimators.
Expand Down
20 changes: 20 additions & 0 deletions sklearn/pipeline.py
Expand Up @@ -18,6 +18,7 @@
from joblib import Parallel, delayed

from .base import clone, TransformerMixin
from .utils._estimator_html_repr import _VisualBlock
from .utils.metaestimators import if_delegate_has_method
from .utils import Bunch, _print_elapsed_time
from .utils.validation import check_memory
Expand Down Expand Up @@ -623,6 +624,21 @@ def n_features_in_(self):
# delegate to first step (which will call _check_is_fitted)
return self.steps[0][1].n_features_in_

def _sk_visual_block_(self):
_, estimators = zip(*self.steps)

def _get_name(name, est):
if est is None or est == 'passthrough':
return f'{name}: passthrough'
# Is an estimator
return f'{name}: {est.__class__.__name__}'
names = [_get_name(name, est) for name, est in self.steps]
name_details = [str(est) for est in estimators]
return _VisualBlock('serial', estimators,
names=names,
name_details=name_details,
dash_wrapped=False)


def _name_estimators(estimators):
"""Generate names for estimators."""
Expand Down Expand Up @@ -1004,6 +1020,10 @@ def n_features_in_(self):
# X is passed to all transformers so we just delegate to the first one
return self.transformer_list[0][1].n_features_in_

def _sk_visual_block_(self):
names, transformers = zip(*self.transformer_list)
return _VisualBlock('parallel', transformers, names=names)


def make_union(*transformers, **kwargs):
"""
Expand Down
14 changes: 14 additions & 0 deletions sklearn/tests/test_base.py
Expand Up @@ -23,6 +23,7 @@

from sklearn.base import TransformerMixin
from sklearn.utils._mocking import MockDataFrame
from sklearn import config_context
import pickle


Expand Down Expand Up @@ -511,3 +512,16 @@ def fit(self, X, y=None):
params = est.get_params()

assert params['param'] is None


def test_repr_mimebundle_():
# Checks the display configuration flag controls the json output
tree = DecisionTreeClassifier()
output = tree._repr_mimebundle_()
assert "text/plain" in output
assert "text/html" not in output

with config_context(display='diagram'):
output = tree._repr_mimebundle_()
assert "text/plain" in output
assert "text/html" in output
9 changes: 6 additions & 3 deletions sklearn/tests/test_config.py
Expand Up @@ -4,15 +4,17 @@

def test_config_context():
assert get_config() == {'assume_finite': False, 'working_memory': 1024,
'print_changed_only': True}
'print_changed_only': True,
'display': 'text'}

# Not using as a context manager affects nothing
config_context(assume_finite=True)
assert get_config()['assume_finite'] is False

with config_context(assume_finite=True):
assert get_config() == {'assume_finite': True, 'working_memory': 1024,
'print_changed_only': True}
'print_changed_only': True,
'display': 'text'}
assert get_config()['assume_finite'] is False

with config_context(assume_finite=True):
Expand All @@ -37,7 +39,8 @@ def test_config_context():
assert get_config()['assume_finite'] is True

assert get_config() == {'assume_finite': False, 'working_memory': 1024,
'print_changed_only': True}
'print_changed_only': True,
'display': 'text'}

# No positional arguments
assert_raises(TypeError, config_context, True)
Expand Down
3 changes: 2 additions & 1 deletion sklearn/utils/__init__.py
Expand Up @@ -25,6 +25,7 @@
from ..exceptions import DataConversionWarning
from .deprecation import deprecated
from .fixes import np_version
from ._estimator_html_repr import estimator_html_repr
from .validation import (as_float_array,
assert_all_finite,
check_random_state, column_or_1d, check_array,
Expand Down Expand Up @@ -52,7 +53,7 @@
"check_symmetric", "indices_to_mask", "deprecated",
"parallel_backend", "register_parallel_backend",
"resample", "shuffle", "check_matplotlib_support", "all_estimators",
"DataConversionWarning"
"DataConversionWarning", "estimator_html_repr"
]

IS_PYPY = platform.python_implementation() == 'PyPy'
Expand Down

0 comments on commit 425564b

Please sign in to comment.