ENH Adds HTML visualizations for estimators #14180

Merged
merged 104 commits into from Apr 29, 2020
104 commits
7e75b73
ENH Adds export_html
thomasjpfan Jun 25, 2019
2ce3552
CLN Checks for jupyter context
thomasjpfan Jun 25, 2019
f8d2681
Merge remote-tracking branch 'upstream/master' into html_viz
thomasjpfan Jun 28, 2019
8b015e1
ENH Updates style
thomasjpfan Jun 28, 2019
893dbfc
TST refactor test_numeric_stability (#14221)
thomasjpfan Jun 30, 2019
f7bfb0c
Merge remote-tracking branch 'upstream/master' into html_viz
thomasjpfan Jul 2, 2019
597a99b
Merge remote-tracking branch 'upstream/master' into html_viz
thomasjpfan Jul 8, 2019
9ccdbd0
CLN Renames function
thomasjpfan Jul 8, 2019
4df33f8
ENH Adds sphinx extension to visiualize
thomasjpfan Jul 18, 2019
8f57191
MNT Sets font color
thomasjpfan Jul 19, 2019
677e35a
ENH Update styling
thomasjpfan Jul 21, 2019
343773a
STY Update styling
thomasjpfan Jul 22, 2019
1ac0e4c
Merge remote-tracking branch 'upstream/master' into html_viz
thomasjpfan Jul 24, 2019
50ed9f3
STY Removes underline
thomasjpfan Jul 25, 2019
1598cad
ENH Updates style
thomasjpfan Jul 31, 2019
464f6d8
ENH Update style
thomasjpfan Aug 1, 2019
b8b3ab0
Merge remote-tracking branch 'upstream/master' into html_viz
thomasjpfan Aug 1, 2019
6cb813b
Merge remote-tracking branch 'upstream/master' into html_viz
thomasjpfan Aug 5, 2019
90a4d92
Merge remote-tracking branch 'upstream/master' into html_viz
thomasjpfan Aug 5, 2019
f4d882c
STY Update styling
thomasjpfan Aug 5, 2019
57e4d0c
Merge remote-tracking branch 'upstream/master' into html_viz
thomasjpfan Oct 10, 2019
4288fbf
Merge remote-tracking branch 'upstream/master' into html_viz
thomasjpfan Mar 11, 2020
c4cfe63
CLN Address comments
thomasjpfan Mar 11, 2020
2c700fe
Merge remote-tracking branch 'upstream/master' into html_viz
thomasjpfan Mar 12, 2020
2749e07
Merge remote-tracking branch 'upstream/master' into html_viz
thomasjpfan Mar 12, 2020
92be3e9
ENH Makes display_estimator privatte
thomasjpfan Mar 13, 2020
1b47170
ENN Major visual changes
thomasjpfan Mar 13, 2020
733bade
ENH Update viz
thomasjpfan Mar 13, 2020
ae98ae9
STY Update
thomasjpfan Mar 13, 2020
7b1de5f
STY Update
thomasjpfan Mar 13, 2020
741bc13
CLN Cleaner code
thomasjpfan Mar 13, 2020
b0dd3f2
CLN Improves logic
thomasjpfan Mar 13, 2020
1b14ce2
CLN More polish
thomasjpfan Mar 13, 2020
e03362f
CLN More polish
thomasjpfan Mar 13, 2020
975c823
STY Minor adjustment
thomasjpfan Mar 13, 2020
ecb3ae6
ENH Adds a _sk_rep_html method
thomasjpfan Mar 14, 2020
3451ab0
CLN Less diffs
thomasjpfan Mar 14, 2020
407cfff
CLN Imports higher
thomasjpfan Mar 14, 2020
80d9b10
ENH Better support for dark themes
thomasjpfan Mar 14, 2020
791374b
DOC Includes note about html
thomasjpfan Mar 14, 2020
d297bc7
STY Update
thomasjpfan Mar 14, 2020
9fde84a
STY Update
thomasjpfan Mar 14, 2020
f254e1d
CLN
thomasjpfan Mar 14, 2020
48aebee
CLN Moves code around
thomasjpfan Mar 14, 2020
d44c38e
ENH Adds stacking viz
thomasjpfan Mar 14, 2020
4d8d72a
Merge remote-tracking branch 'upstream/master' into html_viz
thomasjpfan Mar 16, 2020
3811190
ENH Better viz
thomasjpfan Mar 17, 2020
9df8a4b
CLN Improves code quality
thomasjpfan Mar 17, 2020
50ee0b4
STY Update
thomasjpfan Mar 18, 2020
212ba21
TST Fix
thomasjpfan Mar 18, 2020
2a81f16
WIP
thomasjpfan Mar 18, 2020
6c11293
WIP
thomasjpfan Mar 18, 2020
da83a68
CLN Address comments
thomasjpfan Mar 18, 2020
55a20e7
ENH Update sphinx extension
thomasjpfan Mar 20, 2020
169964d
WIP
thomasjpfan Mar 20, 2020
b8a84e0
Merge remote-tracking branch 'upstream/master' into html_viz
thomasjpfan Mar 20, 2020
1e1bd1b
REV Less diffs
thomasjpfan Mar 20, 2020
ce0fc2c
WIP
thomasjpfan Mar 22, 2020
93da060
CLN Address comments
thomasjpfan Mar 22, 2020
1c08495
Merge remote-tracking branch 'upstream/master' into html_viz
thomasjpfan Mar 22, 2020
0a30ced
STY Flake8
thomasjpfan Mar 22, 2020
f740537
CLN More refactoring
thomasjpfan Mar 26, 2020
f656d8b
CLN Outputs repr in latex
thomasjpfan Mar 27, 2020
856ce5d
CLN Adds more tests
thomasjpfan Mar 27, 2020
b5c26b0
STY Update
thomasjpfan Mar 27, 2020
f56060c
TST Fix
thomasjpfan Mar 30, 2020
d6c7de6
Merge remote-tracking branch 'upstream/master' into html_viz
thomasjpfan Apr 13, 2020
52a640a
CLN Move to utils
thomasjpfan Apr 13, 2020
66ebce9
DOC Adds html representation in another example
thomasjpfan Apr 13, 2020
24029d3
DOC Adds reference to _repr_html_
thomasjpfan Apr 14, 2020
adc977b
FIx bug
thomasjpfan Apr 14, 2020
9df7573
CLN Rename secret protocol
thomasjpfan Apr 19, 2020
ef02574
CLN Address comments
thomasjpfan Apr 22, 2020
4f12e71
Merge remote-tracking branch 'upstream/master' into html_viz
thomasjpfan Apr 23, 2020
dbd7f2c
Merge remote-tracking branch 'upstream/master' into html_viz
thomasjpfan Apr 26, 2020
c1b451c
CLN Address comments
thomasjpfan Apr 27, 2020
17f05e8
BUG Fix
thomasjpfan Apr 27, 2020
47d72ba
DOC Use function
thomasjpfan Apr 27, 2020
bae645b
REV Less diffs
thomasjpfan Apr 27, 2020
c616802
REV Remove
thomasjpfan Apr 27, 2020
1fe69fa
BUG Fix
thomasjpfan Apr 27, 2020
8d23d5b
REV Inner estimators do not show changes
thomasjpfan Apr 27, 2020
3d41caf
ENH Uses _repr_mimebundle_
thomasjpfan Apr 27, 2020
1cc87b6
CLN Updates file names
thomasjpfan Apr 27, 2020
3af9151
DOC Remove sphinx extension
thomasjpfan Apr 27, 2020
0b4a64d
CLN Uses None
thomasjpfan Apr 27, 2020
689d3f2
ENH Uses _repr_mimebundle_
thomasjpfan Apr 27, 2020
e0062f1
Merge remote-tracking branch 'upstream/master' into html_viz
thomasjpfan Apr 27, 2020
694124f
REV Less diffs
thomasjpfan Apr 27, 2020
058e45e
ENH Defines _repr_html_
thomasjpfan Apr 27, 2020
662b1c3
Merge remote-tracking branch 'upstream/master' into html_viz
thomasjpfan Apr 27, 2020
d174a87
Merge branch 'master' into html_viz
jnothman Apr 27, 2020
c5c03da
Merge remote-tracking branch 'upstream/master' into html_viz
thomasjpfan Apr 28, 2020
a642185
CLN Address comments
thomasjpfan Apr 28, 2020
185986b
visual_html -> visual_repr
jnothman Apr 28, 2020
419cc09
CLN Changes name again
thomasjpfan Apr 28, 2020
bc03910
CLN Update config name to estimator='display'
thomasjpfan Apr 28, 2020
d57151c
Merge remote-tracking branch 'upstream/master' into html_viz
thomasjpfan Apr 28, 2020
4ed258f
CLN Uses estimator_display
thomasjpfan Apr 28, 2020
bfa953a
CLN Now it's display
thomasjpfan Apr 29, 2020
19f43ba
Merge remote-tracking branch 'upstream/master' into html_viz
thomasjpfan Apr 29, 2020
e688cb7
CLN Address comments
thomasjpfan Apr 29, 2020
b994664
BUG Fix
thomasjpfan Apr 29, 2020
57fea52
CLN Address comments
thomasjpfan Apr 29, 2020
1 change: 1 addition & 0 deletions doc/modules/classes.rst
Expand Up @@ -1569,6 +1569,7 @@ Plotting
utils.deprecated
utils.estimator_checks.check_estimator
utils.estimator_checks.parametrize_with_checks
utils.estimator_html_repr
utils.extmath.safe_sparse_dot
utils.extmath.randomized_range_finder
utils.extmath.randomized_svd
25 changes: 25 additions & 0 deletions doc/modules/compose.rst
Expand Up @@ -528,6 +528,31 @@ above example would be::
('countvectorizer', CountVectorizer(),
'title')])

.. _visualizing_composite_estimators:

Visualizing Composite Estimators
================================

Estimators can be displayed with an HTML representation when shown in a
Jupyter notebook. This can be useful to diagnose or visualize a Pipeline with
many estimators. This visualization is activated by setting the
`display` option in :func:`sklearn.set_config`::

    >>> from sklearn import set_config
    >>> set_config(display='diagram')   # doctest: +SKIP
    >>> # displays HTML representation in a Jupyter context
    >>> column_trans  # doctest: +SKIP

An example of the HTML output can be seen in the
**HTML representation of Pipeline** section of
:ref:`sphx_glr_auto_examples_compose_plot_column_transformer_mixed_types.py`.
As an alternative, the HTML can be written to a file using
:func:`~sklearn.utils.estimator_html_repr`::

    >>> from sklearn.utils import estimator_html_repr
    >>> with open('my_estimator.html', 'w') as f:  # doctest: +SKIP
    ...     f.write(estimator_html_repr(clf))

.. topic:: Examples:

* :ref:`sphx_glr_auto_examples_compose_plot_column_transformer.py`
8 changes: 8 additions & 0 deletions doc/whats_new/v0.23.rst
Expand Up @@ -567,6 +567,9 @@ Changelog
:mod:`sklearn.utils`
....................

- |Feature| Adds :func:`utils.estimator_html_repr` for returning an
  HTML representation of an estimator. :pr:`14180` by `Thomas Fan`_.

- |Enhancement| improve error message in :func:`utils.validation.column_or_1d`.
:pr:`15926` by :user:`Loïc Estève <lesteve>`.

Expand Down Expand Up @@ -605,6 +608,11 @@ Changelog
Miscellaneous
.............

- |MajorFeature| Adds an HTML representation of estimators to be shown in
  a Jupyter notebook or lab. This visualization is activated by setting the
  `display` option in :func:`sklearn.set_config`. :pr:`14180` by
  `Thomas Fan`_.

- |Enhancement| ``scikit-learn`` now works with ``mypy`` without errors.
:pr:`16726` by `Roman Yurchak`_.

9 changes: 9 additions & 0 deletions examples/compose/plot_column_transformer_mixed_types.py
Expand Up @@ -87,6 +87,15 @@
clf.fit(X_train, y_train)
print("model score: %.3f" % clf.score(X_test, y_test))

##############################################################################
# HTML representation of ``Pipeline``
###############################################################################
# When the ``Pipeline`` is printed out in a Jupyter notebook, an HTML
# representation of the estimator is displayed as follows:
from sklearn import set_config
set_config(display='diagram')
clf

###############################################################################
# Use ``ColumnTransformer`` by selecting column by data types
###############################################################################
19 changes: 18 additions & 1 deletion sklearn/_config.py
Expand Up @@ -7,6 +7,7 @@
'assume_finite': bool(os.environ.get('SKLEARN_ASSUME_FINITE', False)),
'working_memory': int(os.environ.get('SKLEARN_WORKING_MEMORY', 1024)),
'print_changed_only': True,
'display': 'text',
}


Expand All @@ -27,7 +28,7 @@ def get_config():


def set_config(assume_finite=None, working_memory=None,
print_changed_only=None):
print_changed_only=None, display=None):
"""Set global scikit-learn configuration

.. versionadded:: 0.19
Expand Down Expand Up @@ -59,6 +60,13 @@ def set_config(assume_finite=None, working_memory=None,

.. versionadded:: 0.21

display : {'text', 'diagram'}, optional
If 'diagram', estimators will be displayed as a diagram in a Jupyter lab
or notebook context. If 'text', estimators will be displayed as
text. Default is 'text'.

.. versionadded:: 0.23

See Also
--------
config_context: Context manager for global scikit-learn configuration
Expand All @@ -70,6 +78,8 @@ def set_config(assume_finite=None, working_memory=None,
        _global_config['working_memory'] = working_memory
    if print_changed_only is not None:
        _global_config['print_changed_only'] = print_changed_only
    if display is not None:
        _global_config['display'] = display


@contextmanager
Expand Down Expand Up @@ -100,6 +110,13 @@ def config_context(**new_config):
.. versionchanged:: 0.23
Default changed from False to True.

display : {'text', 'diagram'}, optional
If 'diagram', estimators will be displayed as a diagram in a Jupyter lab
or notebook context. If 'text', estimators will be displayed as
text. Default is 'text'.

.. versionadded:: 0.23

Notes
-----
All settings, not just those presently modified, will be returned to
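
Note for readers of this diff: the new `display` option follows the same pattern as the existing options, where `set_config` only overwrites keys that were passed explicitly and `config_context` restores the previous values on exit. The following is a minimal, self-contained sketch of that pattern, not the scikit-learn source; the two-key dictionary and the function bodies are simplified assumptions for illustration.

```python
from contextlib import contextmanager

# Hypothetical stand-in for the module-level configuration dict.
_global_config = {"print_changed_only": True, "display": "text"}


def get_config():
    """Return a copy so callers cannot mutate the global state by accident."""
    return _global_config.copy()


def set_config(print_changed_only=None, display=None):
    """Update only the options that were passed explicitly (not None)."""
    if print_changed_only is not None:
        _global_config["print_changed_only"] = print_changed_only
    if display is not None:
        _global_config["display"] = display


@contextmanager
def config_context(**new_config):
    """Temporarily apply new_config, restoring every setting on exit."""
    old_config = get_config()
    set_config(**new_config)
    try:
        yield
    finally:
        set_config(**old_config)


with config_context(display="diagram"):
    inside = get_config()["display"]   # value while the context is active
outside = get_config()["display"]      # value after the context has exited
```

Using `None` as the "not passed" sentinel means `None` itself can never be stored as an option value, which is acceptable here because every option has a concrete default.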
13 changes: 13 additions & 0 deletions sklearn/base.py
Expand Up @@ -17,9 +17,11 @@
import numpy as np

from . import __version__
from ._config import get_config
from .utils import _IS_32BIT
from .utils.validation import check_X_y
from .utils.validation import check_array
from .utils._estimator_html_repr import estimator_html_repr
from .utils.validation import _deprecate_positional_args

_DEFAULT_TAGS = {
Expand Down Expand Up @@ -435,6 +437,17 @@ def _validate_data(self, X, y=None, reset=True,

return out

    def _repr_html_(self):
        """HTML representation of estimator"""
        return estimator_html_repr(self)

    def _repr_mimebundle_(self, **kwargs):
        """Mime bundle used by jupyter kernels to display estimator"""
        output = {"text/plain": repr(self)}
        if get_config()["display"] == 'diagram':
            output["text/html"] = estimator_html_repr(self)
        return output
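
Note for readers of this diff: the `_repr_mimebundle_` method above always offers a plain-text entry and adds an HTML entry only when the `display` option is `'diagram'`. The sketch below reproduces that branching with a toy class so it can be run without scikit-learn. The `_config` dict, `toy_html_repr`, and `ToyEstimator` are hypothetical names invented for this illustration.

```python
# Hypothetical stand-ins: a one-key config dict and a toy HTML renderer.
_config = {"display": "text"}


def toy_html_repr(obj):
    """Toy renderer; the helper used in the diff is estimator_html_repr."""
    return f"<pre>{obj!r}</pre>"


class ToyEstimator:
    def __repr__(self):
        return "ToyEstimator()"

    def _repr_mimebundle_(self, **kwargs):
        # Plain text is always offered; HTML only when the option is set.
        output = {"text/plain": repr(self)}
        if _config["display"] == "diagram":
            output["text/html"] = toy_html_repr(self)
        return output


text_only = ToyEstimator()._repr_mimebundle_()
_config["display"] = "diagram"
with_html = ToyEstimator()._repr_mimebundle_()
_config["display"] = "text"  # restore the default
```

Here `text_only` holds a single `text/plain` key, while `with_html` carries both keys, which mirrors what `test_repr_mimebundle_` in this PR checks for a real estimator.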


class ClassifierMixin:
"""Mixin class for all classifiers in scikit-learn."""
6 changes: 6 additions & 0 deletions sklearn/compose/_column_transformer.py
Expand Up @@ -15,6 +15,7 @@
from joblib import Parallel, delayed

from ..base import clone, TransformerMixin
from ..utils._estimator_html_repr import _VisualBlock
from ..pipeline import _fit_transform_one, _transform_one, _name_estimators
from ..preprocessing import FunctionTransformer
from ..utils import Bunch
Expand Down Expand Up @@ -637,6 +638,11 @@ def _hstack(self, Xs):
Xs = [f.toarray() if sparse.issparse(f) else f for f in Xs]
return np.hstack(Xs)

    def _sk_visual_block_(self):
        names, transformers, name_details = zip(*self.transformers)
        return _VisualBlock('parallel', transformers,
                            names=names, name_details=name_details)
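
Note for readers of this diff: the method above relies on `zip(*...)` to transpose a list of `(name, transformer, columns)` triples into three parallel tuples. A small sketch of that idiom follows; the triples are hypothetical, and strings stand in for real transformer objects.

```python
# Hypothetical (name, transformer, columns) triples; strings stand in for
# real transformer objects.
transformers = [
    ("num", "StandardScaler()", ["age", "fare"]),
    ("cat", "OneHotEncoder()", ["sex"]),
]

# zip(*...) transposes the list of triples into three parallel tuples.
names, objs, name_details = zip(*transformers)
```

One property of this idiom worth keeping in mind: if the list is empty, `zip(*[])` yields nothing and the three-way unpacking raises `ValueError`.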


def _check_X(X):
"""Use check_array only on lists and other non-array-likes / sparse"""
27 changes: 27 additions & 0 deletions sklearn/ensemble/_stacking.py
Expand Up @@ -13,6 +13,7 @@
from ..base import clone
from ..base import ClassifierMixin, RegressorMixin, TransformerMixin
from ..base import is_classifier, is_regressor
from ..utils._estimator_html_repr import _VisualBlock

from ._base import _fit_single_estimator
from ._base import _BaseHeterogeneousEnsemble
Expand Down Expand Up @@ -233,6 +234,14 @@ def predict(self, X, **predict_params):
self.transform(X), **predict_params
)

    def _sk_visual_block_(self, final_estimator):
        names, estimators = zip(*self.estimators)
        parallel = _VisualBlock('parallel', estimators, names=names,
                                dash_wrapped=False)
        serial = _VisualBlock('serial', (parallel, final_estimator),
                              dash_wrapped=False)
        return _VisualBlock('serial', [serial])


class StackingClassifier(ClassifierMixin, _BaseStacking):
"""Stack of estimators with a final classifier.
Expand Down Expand Up @@ -496,6 +505,15 @@ def transform(self, X):
"""
return self._transform(X)

    def _sk_visual_block_(self):
        # If final_estimator's default changes then this should be
        # updated.
        if self.final_estimator is None:
            final_estimator = LogisticRegression()
        else:
            final_estimator = self.final_estimator
Comment on lines +511 to +514

Member:
I'm not sure about this one. Would it be fair to say an unfitted estimator here doesn't have an html repr? Or have None as given by the user? Or change the final_estimator's default value to LogisticRegression?

Member Author:
In the current implementation, an estimator always has an unfitted estimator. In this stacker, if None is given by the user, it will become LogisticRegression.

Member:
Then maybe leave a comment about that here and also during validation where LogisticRegression replaces None so that in the future they don't go out of sync?

Member Author:
Ah okay (missed this comment)

        return super()._sk_visual_block_(final_estimator)


class StackingRegressor(RegressorMixin, _BaseStacking):
"""Stack of estimators with a final regressor.
Expand Down Expand Up @@ -665,3 +683,12 @@ def transform(self, X):
Prediction outputs for each estimator.
"""
return self._transform(X)

    def _sk_visual_block_(self):
        # If final_estimator's default changes then this should be
        # updated.
        if self.final_estimator is None:
            final_estimator = RidgeCV()
        else:
            final_estimator = self.final_estimator
        return super()._sk_visual_block_(final_estimator)
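
Note for readers of this diff: the stacking estimators describe their layout as nested blocks, with the base estimators in a parallel block followed in series by the final estimator. The sketch below shows how such a nested description can be walked recursively. `Block` and `outline` are hypothetical, simplified stand-ins; the private `_VisualBlock` class has a different signature and renders HTML rather than text, and the extra outer wrapper used in the diff is omitted here.

```python
from dataclasses import dataclass, field


@dataclass
class Block:
    """Hypothetical, simplified stand-in for the private _VisualBlock."""
    kind: str                                   # 'serial', 'parallel', 'single'
    children: list = field(default_factory=list)
    label: str = ""


def outline(block, depth=0):
    """Return an indented text outline of a nested block tree."""
    pad = "  " * depth
    if block.kind == "single":
        return [f"{pad}{block.label}"]
    lines = [f"{pad}[{block.kind}]"]
    for child in block.children:
        lines.extend(outline(child, depth + 1))
    return lines


# Same shape as the stacking layout: base estimators side by side,
# followed by the final estimator.
parallel = Block("parallel", [Block("single", label="rf"),
                              Block("single", label="svc")])
stack = Block("serial", [parallel,
                         Block("single", label="LogisticRegression")])
lines = outline(stack)
```

Because each composite estimator only returns a description of its own children, a single recursive renderer can handle arbitrarily deep nesting, such as a pipeline inside a stacker.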
5 changes: 5 additions & 0 deletions sklearn/ensemble/_voting.py
Expand Up @@ -32,6 +32,7 @@
from ..utils.validation import column_or_1d
from ..utils.validation import _deprecate_positional_args
from ..exceptions import NotFittedError
from ..utils._estimator_html_repr import _VisualBlock


class _BaseVoting(TransformerMixin, _BaseHeterogeneousEnsemble):
Expand Down Expand Up @@ -104,6 +105,10 @@ def n_features_in_(self):

return self.estimators_[0].n_features_in_

    def _sk_visual_block_(self):
        names, estimators = zip(*self.estimators)
        return _VisualBlock('parallel', estimators, names=names)


class VotingClassifier(ClassifierMixin, _BaseVoting):
"""Soft Voting/Majority Rule classifier for unfitted estimators.
20 changes: 20 additions & 0 deletions sklearn/pipeline.py
Expand Up @@ -18,6 +18,7 @@
from joblib import Parallel, delayed

from .base import clone, TransformerMixin
from .utils._estimator_html_repr import _VisualBlock
from .utils.metaestimators import if_delegate_has_method
from .utils import Bunch, _print_elapsed_time
from .utils.validation import check_memory
Expand Down Expand Up @@ -623,6 +624,21 @@ def n_features_in_(self):
# delegate to first step (which will call _check_is_fitted)
return self.steps[0][1].n_features_in_

    def _sk_visual_block_(self):
        _, estimators = zip(*self.steps)

        def _get_name(name, est):
            if est is None or est == 'passthrough':
                return f'{name}: passthrough'
            # Is an estimator
            return f'{name}: {est.__class__.__name__}'
        names = [_get_name(name, est) for name, est in self.steps]
        name_details = [str(est) for est in estimators]
        return _VisualBlock('serial', estimators,
                            names=names,
                            name_details=name_details,
                            dash_wrapped=False)
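
Note for readers of this diff: the `_get_name` helper above labels each pipeline step with its name and class, and treats both `None` and the string `'passthrough'` as a passthrough step. The following standalone sketch applies the same rule to a hypothetical step list; `Scaler` and `step_label` are names invented for this illustration.

```python
class Scaler:
    """Hypothetical placeholder standing in for a real transformer."""


def step_label(name, est):
    """Label a step using the same rule as the _get_name helper in the diff."""
    if est is None or est == "passthrough":
        return f"{name}: passthrough"
    return f"{name}: {est.__class__.__name__}"


steps = [("scale", Scaler()), ("skip", "passthrough"), ("drop", None)]
labels = [step_label(name, est) for name, est in steps]
```

The `est is None` check comes first, so the string comparison only runs on real objects or strings.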


def _name_estimators(estimators):
"""Generate names for estimators."""
Expand Down Expand Up @@ -1004,6 +1020,10 @@ def n_features_in_(self):
# X is passed to all transformers so we just delegate to the first one
return self.transformer_list[0][1].n_features_in_

    def _sk_visual_block_(self):
        names, transformers = zip(*self.transformer_list)
        return _VisualBlock('parallel', transformers, names=names)


def make_union(*transformers, **kwargs):
"""
14 changes: 14 additions & 0 deletions sklearn/tests/test_base.py
Expand Up @@ -23,6 +23,7 @@

from sklearn.base import TransformerMixin
from sklearn.utils._mocking import MockDataFrame
from sklearn import config_context
import pickle


Expand Down Expand Up @@ -511,3 +512,16 @@ def fit(self, X, y=None):
params = est.get_params()

assert params['param'] is None


def test_repr_mimebundle_():
    # Checks the display configuration flag controls the json output
    tree = DecisionTreeClassifier()
    output = tree._repr_mimebundle_()
    assert "text/plain" in output
    assert "text/html" not in output

    with config_context(display='diagram'):
        output = tree._repr_mimebundle_()
        assert "text/plain" in output
        assert "text/html" in output
9 changes: 6 additions & 3 deletions sklearn/tests/test_config.py
Expand Up @@ -4,15 +4,17 @@

def test_config_context():
assert get_config() == {'assume_finite': False, 'working_memory': 1024,
'print_changed_only': True}
'print_changed_only': True,
'display': 'text'}

# Not using as a context manager affects nothing
config_context(assume_finite=True)
assert get_config()['assume_finite'] is False

with config_context(assume_finite=True):
assert get_config() == {'assume_finite': True, 'working_memory': 1024,
'print_changed_only': True}
'print_changed_only': True,
'display': 'text'}
assert get_config()['assume_finite'] is False

with config_context(assume_finite=True):
Expand All @@ -37,7 +39,8 @@ def test_config_context():
assert get_config()['assume_finite'] is True

assert get_config() == {'assume_finite': False, 'working_memory': 1024,
'print_changed_only': True}
'print_changed_only': True,
'display': 'text'}

# No positional arguments
assert_raises(TypeError, config_context, True)
3 changes: 2 additions & 1 deletion sklearn/utils/__init__.py
Expand Up @@ -25,6 +25,7 @@
from ..exceptions import DataConversionWarning
from .deprecation import deprecated
from .fixes import np_version
from ._estimator_html_repr import estimator_html_repr
from .validation import (as_float_array,
assert_all_finite,
check_random_state, column_or_1d, check_array,
Expand Down Expand Up @@ -52,7 +53,7 @@
"check_symmetric", "indices_to_mask", "deprecated",
"parallel_backend", "register_parallel_backend",
"resample", "shuffle", "check_matplotlib_support", "all_estimators",
"DataConversionWarning"
"DataConversionWarning", "estimator_html_repr"
]

IS_PYPY = platform.python_implementation() == 'PyPy'