[MRG] Plotting API starting with ROC curve #14357

Merged
merged 46 commits into from Aug 8, 2019
Changes from 17 commits
Commits
aa753b1
WIP
thomasjpfan Jul 5, 2019
d5ba421
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan Jul 10, 2019
a0e4723
ENH Adds plot_roc_curve
thomasjpfan Jul 10, 2019
8ae4c70
DOC Adds docs
thomasjpfan Jul 11, 2019
eaac39c
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan Jul 11, 2019
763d723
BUG Fix
thomasjpfan Jul 11, 2019
2330e06
DOC Adds label to parameters
thomasjpfan Jul 11, 2019
4e33b28
DOC Adds label as a parameter
thomasjpfan Jul 11, 2019
40381ae
API Update with kwargs
thomasjpfan Jul 11, 2019
5f83a80
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan Jul 14, 2019
663fc58
BLD Add tests to setup
thomasjpfan Jul 14, 2019
dea2b1b
DOC Update docs
thomasjpfan Jul 14, 2019
3e65c77
DOC Update ordering
thomasjpfan Jul 14, 2019
b4a8d0e
ENH Adds auc to labels
thomasjpfan Jul 17, 2019
eba2453
CLN Updates example with plotting api
thomasjpfan Jul 17, 2019
74b8d7b
TST Updates test
thomasjpfan Jul 17, 2019
272845f
CLN Moves name to plot
thomasjpfan Jul 17, 2019
c51da17
TST Adds more tests
thomasjpfan Jul 18, 2019
5fc6f29
DOC Removes line_kw parameter docstring
thomasjpfan Jul 22, 2019
8832788
DOC Fix docs
thomasjpfan Jul 22, 2019
56ff821
CLN Removes unused import
thomasjpfan Jul 22, 2019
43f1787
CLN Uses kwargs
thomasjpfan Jul 22, 2019
bdf782a
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan Jul 23, 2019
6cc8712
CLN Does computation in plot_* function
thomasjpfan Jul 23, 2019
7051e5e
DOC Adds type
thomasjpfan Jul 23, 2019
d461543
CLN Address comments
thomasjpfan Jul 29, 2019
bda4435
BUG Spelling
thomasjpfan Jul 29, 2019
af4209c
TST Assert message
thomasjpfan Jul 29, 2019
014851e
CLN Moves plot to _plot
thomasjpfan Jul 29, 2019
50fb1cd
WIP
thomasjpfan Jul 29, 2019
e620d8c
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan Jul 30, 2019
4d8534c
DOC Adds user guide
thomasjpfan Jul 30, 2019
46cc0c5
BLD Fix
thomasjpfan Jul 30, 2019
6771d8d
BLD Build docs [doc build]
thomasjpfan Jul 31, 2019
dd7f3cf
CLN Address comments
thomasjpfan Jul 31, 2019
e668471
CLN Address comments
thomasjpfan Aug 1, 2019
94f29e3
STY Update styling
thomasjpfan Aug 1, 2019
9546755
STY Update styling
thomasjpfan Aug 1, 2019
71b19a3
DOC Adds note about parameters stored as attributes
thomasjpfan Aug 1, 2019
f7f2e2e
CLN Updates to roc_auc
thomasjpfan Aug 5, 2019
e2f42fa
ENH Adds check for response_method
thomasjpfan Aug 5, 2019
30925a8
DOC Adds parameters
thomasjpfan Aug 5, 2019
2e8db1f
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan Aug 5, 2019
fab9d07
CLN Renames to display
thomasjpfan Aug 5, 2019
759d1b4
CLN Address comments
thomasjpfan Aug 7, 2019
140287d
improve error message
qinhanmin2014 Aug 8, 2019
60 changes: 60 additions & 0 deletions doc/developers/contributing.rst
Expand Up @@ -1647,3 +1647,63 @@ make this task easier and faster (in no particular order).
<https://git-scm.com/docs/git-grep#_examples>`_) is also extremely
useful to see every occurrence of a pattern (e.g. a function call or a
variable) in the code base.


.. _plotting_api:

Plotting API
------------
Member

This should be ==== I think: right now this is considered a subsection of "Reading the existing code base"

Member

Also I think this whole section might be better suited for the user guide instead of the contributing guide


Scikit-learn defines a simple API for creating visualizations of models. The
key features of this API are that calculations run only once and that plotting
elements can be positioned and styled flexibly. This logic is encapsulated in a
visualizer object: the calculations are done during construction and
plotting is done in a ``plot`` method. The visualizer object's ``__init__``
method takes only the parameters needed to calculate the items in the
visualization and saves those items as attributes. The ``plot`` method takes
parameters that only have to do with visualization, such as a matplotlib axes.
The ``plot`` method stores the matplotlib artists as attributes, allowing
for style adjustments through the visualizer object. Along with the visualizer
object, a ``plot_*`` helper function is defined to create, plot and return the
object. This function takes the parameters of both the ``__init__``
and ``plot`` methods of the visualizer object and passes them to the
respective methods.

For example, the `RocCurveVisualizer` defines the following methods:
Member

... and attributes


.. code-block:: python

    class RocCurveVisualizer:
        def __init__(self, estimator, X, y, *, pos_label=None,
                     sample_weight=None, drop_intermediate=True,
                     response_method="predict_proba", label=None):
            ...
            self.fpr_ = ...
            self.tpr_ = ...
            self.label_ = ...

        def plot(self, ax=None):
            ...
            self.line_ = ...
            self.ax_ = ax
            self.figure_ = ax.figure

    def plot_roc_curve(estimator, X, y, *, pos_label=None,
                       sample_weight=None, drop_intermediate=True,
                       response_method="predict_proba", label=None, ax=None):
        viz = RocCurveVisualizer(estimator,
                                 X,
                                 y,
                                 sample_weight=sample_weight,
                                 pos_label=pos_label,
                                 drop_intermediate=drop_intermediate,
                                 response_method=response_method,
                                 label=label)
        viz.plot(ax=ax)
        return viz

Note that the ``__init__`` method defines attributes that are going to be used
Member

this paragraph can be integrated into the previous one IMO.

for plotting and the ``plot`` method defines attributes that are related to
the matplotlib object itself. The line artist is stored as an attribute to
allow for customizations after calling ``plot``.
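
The pattern described in this section can be sketched in a few lines. This is only an illustration under stated assumptions: `SimpleRocVisualizer` is a hypothetical stand-in (not the class added by this PR), it takes labels and scores directly instead of an estimator, and its ROC computation is simplified (0/1 labels, no tied scores).

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


class SimpleRocVisualizer:
    """Hypothetical stand-in for the visualizer object described above."""

    def __init__(self, y_true, y_score):
        # All computation happens once, at construction time.
        y_true = np.asarray(y_true)
        order = np.argsort(-np.asarray(y_score), kind="mergesort")
        hits = y_true[order]
        tps = np.cumsum(hits)        # true positives at each threshold
        fps = np.cumsum(1 - hits)    # false positives at each threshold
        self.tpr_ = np.r_[0.0, tps / tps[-1]]
        self.fpr_ = np.r_[0.0, fps / fps[-1]]
        # Trapezoidal area under the (fpr, tpr) curve.
        widths = np.diff(self.fpr_)
        heights = (self.tpr_[1:] + self.tpr_[:-1]) / 2
        self.auc_ = float(np.sum(widths * heights))

    def plot(self, ax=None, **kwargs):
        # Only visualization concerns live here; nothing is recomputed.
        if ax is None:
            _, ax = plt.subplots()
        self.line_ = ax.plot(self.fpr_, self.tpr_, **kwargs)[0]
        self.ax_ = ax
        self.figure_ = ax.figure
        return self


viz = SimpleRocVisualizer([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
fig, (left, right) = plt.subplots(1, 2)
viz.plot(ax=left)
viz.line_.set_linestyle("--")   # restyle through the stored artist
viz.plot(ax=right, color="k")   # plot again elsewhere without recomputing
```

Because the artist is kept as `line_`, styling can be changed after the fact, and the same computed values can be drawn on several axes.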
22 changes: 22 additions & 0 deletions doc/modules/classes.rst
Expand Up @@ -1003,6 +1003,28 @@ See the :ref:`metrics` section of the user guide for further details.
metrics.pairwise_distances_chunked


Plotting metrics
Member

That's a confusing name considering that right above we have "clustering metrics", "pairwise metrics", etc.

Maybe just "plotting" or "plotting tools" is enough?

----------------

.. automodule:: sklearn.metrics.plot
Member

So we've decided to use sklearn.XXX.plot.plot_XXX, right?
We'll put things like plot_decision_boundary in sklearn.inspection, right?

   :no-members:
   :no-inherited-members:

.. currentmodule:: sklearn

.. autosummary::
   :toctree: generated/
   :template: function.rst

   metrics.plot_roc_curve

.. autosummary::
   :toctree: generated/
   :template: class_without_init.rst

   metrics.plot.RocCurveVisualizer


.. _mixture_ref:

:mod:`sklearn.mixture`: Gaussian Mixture Models
49 changes: 24 additions & 25 deletions examples/model_selection/plot_roc_crossval.py
Expand Up @@ -37,6 +37,7 @@

from sklearn import svm, datasets
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import plot_roc_curve
from sklearn.model_selection import StratifiedKFold

# #############################################################################
Expand All @@ -61,44 +62,42 @@
classifier = svm.SVC(kernel='linear', probability=True,
                     random_state=random_state)

visualizers = []
tprs = []
aucs = []
mean_fpr = np.linspace(0, 1, 100)

i = 0
for train, test in cv.split(X, y):
    probas_ = classifier.fit(X[train], y[train]).predict_proba(X[test])
    # Compute ROC curve and area the curve
    fpr, tpr, thresholds = roc_curve(y[test], probas_[:, 1])
    tprs.append(interp(mean_fpr, fpr, tpr))
    tprs[-1][0] = 0.0
    roc_auc = auc(fpr, tpr)
    aucs.append(roc_auc)
    plt.plot(fpr, tpr, lw=1, alpha=0.3,
             label='ROC fold %d (AUC = %0.2f)' % (i, roc_auc))

    i += 1
fig, ax = plt.subplots()
for i, (train, test) in enumerate(cv.split(X, y)):
    classifier.fit(X[train], y[train])
    viz = plot_roc_curve(classifier, X[test], y[test],
                         name='ROC fold {}'.format(i),
                         line_kw={'alpha': 0.3, 'lw': 1}, ax=ax)
    visualizers.append(viz)
Member

You never use this list, remove it?


    interp_tpr = interp(mean_fpr, viz.fpr_, viz.tpr_)
    interp_tpr[0] = 0.0
    tprs.append(interp_tpr)
    aucs.append(viz.auc_)

plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r',
Member

isn't ax.plot here too ?

         label='Chance', alpha=.8)

mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0
mean_auc = auc(mean_fpr, mean_tpr)
std_auc = np.std(aucs)
plt.plot(mean_fpr, mean_tpr, color='b',
         label=r'Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (mean_auc, std_auc),
         lw=2, alpha=.8)
ax.plot(mean_fpr, mean_tpr, color='b',
        label=r'Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (mean_auc, std_auc),
        lw=2, alpha=.8)

std_tpr = np.std(tprs, axis=0)
tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2,
                 label=r'$\pm$ 1 std. dev.')

plt.xlim([-0.05, 1.05])
plt.ylim([-0.05, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic example')
plt.legend(loc="lower right")
ax.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2,
                label=r'$\pm$ 1 std. dev.')

ax.set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05],
       title="Receiver operating characteristic example")
ax.legend(loc="lower right")
plt.show()
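
The averaging step in this example can be isolated as a small sketch. The two folds below are made-up `(fpr, tpr)` arrays standing in for `viz.fpr_` and `viz.tpr_`, and `np.interp` is used on the assumption that it behaves like the `interp` imported by the example. Each fold is resampled onto a shared FPR grid so the curves can be averaged point by point.

```python
import numpy as np

# Hypothetical per-fold ROC points; in the example they come from each viz.
folds = [
    (np.array([0.0, 0.0, 0.5, 1.0]), np.array([0.0, 0.5, 1.0, 1.0])),
    (np.array([0.0, 0.25, 1.0]), np.array([0.0, 0.75, 1.0])),
]

mean_fpr = np.linspace(0, 1, 100)
tprs = []
for fpr, tpr in folds:
    # Resample this fold's curve onto the common grid.
    interp_tpr = np.interp(mean_fpr, fpr, tpr)
    # A ROC curve may have several points at fpr == 0, so pin the start.
    interp_tpr[0] = 0.0
    tprs.append(interp_tpr)

mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0  # every ROC curve ends at (1, 1)

# One-standard-deviation band, clipped to the valid [0, 1] range.
std_tpr = np.std(tprs, axis=0)
tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
```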
4 changes: 4 additions & 0 deletions sklearn/metrics/__init__.py
Expand Up @@ -71,6 +71,9 @@
from .scorer import SCORERS
from .scorer import get_scorer

from .plot import plot_roc_curve


__all__ = [
    'accuracy_score',
    'adjusted_mutual_info_score',
Expand Down Expand Up @@ -119,6 +122,7 @@
    'pairwise_distances_argmin_min',
    'pairwise_distances_chunked',
    'pairwise_kernels',
    'plot_roc_curve',
    'precision_recall_curve',
    'precision_recall_fscore_support',
    'precision_score',
6 changes: 6 additions & 0 deletions sklearn/metrics/plot/__init__.py
@@ -0,0 +1,6 @@
from .roc_curve import plot_roc_curve, RocCurveVisualizer

__all__ = [
    'plot_roc_curve',
    'RocCurveVisualizer'
]
191 changes: 191 additions & 0 deletions sklearn/metrics/plot/roc_curve.py
@@ -0,0 +1,191 @@
from .. import auc
from .. import roc_curve

from ...utils import check_matplotlib_support # noqa
Member

why do we need noqa? (same below)
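
For context on the deferred matplotlib check used in this file, here is a minimal sketch of the optional-dependency pattern that a helper such as `check_matplotlib_support` is presumably built around. `require_module` and `plot_something` are hypothetical names for illustration, not scikit-learn APIs.

```python
import importlib


def require_module(module_name, caller_name):
    """Hypothetical helper: raise a readable ImportError if a module is missing."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            "{} requires {}, which is not installed".format(
                caller_name, module_name)
        ) from exc


def plot_something(module_name="json"):
    # The dependency is resolved inside the function, not at import time,
    # so importing this file never needs the optional package.
    backend = require_module(module_name, "plot_something")
    return backend.__name__
```

With this shape, users without the optional package can still import the module and only see an error when they call a plotting function.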



class RocCurveVisualizer:
    """ROC Curve visualization
Member

We should add a note to say that

  • this object should not be instantiated directly, and indicate to plot_roc_curve instead.
  • API of __init__ may break without deprecation

Member

I wouldn't say it shouldn't be instantiated, I'd say something like that this is usually instantiated by the function and should only be manually instantiated for advanced use cases?

Member

Sounds good to me yes


    Parameters
    ----------

    estimator : estimator instance
        Trained classifier.

    X : {array-like, sparse matrix}, shape (n_samples, n_features)
        Input values.

    y : array-like, shape (n_samples,)
        Target values.

    pos_label : int, str or None, default=None
        The label of the positive class.
        When `pos_label=None`, if y_true is in {-1, 1} or {0, 1},
        `pos_label` is set to 1, otherwise an error will be raised.

    sample_weight : array-like of shape = [n_samples], optional
        Sample weights.

    drop_intermediate : boolean, default=True
        Whether to drop some suboptimal thresholds which would not appear
        on a plotted ROC curve. This is useful in order to create lighter
        ROC curves.

    response_method : 'predict_proba', 'decision_function', or 'auto' \
            default='auto'
        Method to call estimator to get target scores

    Attributes
    ----------
    fpr_ : ndarray
        False positive rate.
    tpr_ : ndarray
        True positive rate.
    auc_ : float
        Area under ROC curve.
    name_ : str
        Name of estimator.
    line_ : matplotlib Artist
        ROC Curve.
    ax_ : matplotlib Axes
Axes with ROC curv
Member

Suggested change
Axes with ROC curv
Axes with ROC Curve

    figure_ : matplotlib Figure
        Figure containing the curve
    """

    def __init__(self, estimator, X, y, *, pos_label=None, sample_weight=None,
                 drop_intermediate=True, response_method="auto"):
        """Computes and stores values needed for visualization"""

        if response_method != "auto":
            prediction_method = getattr(estimator, response_method, None)
            if prediction_method is None:
                raise ValueError(
                    "response method {} not defined".format(response_method))
        else:
            predict_proba = getattr(estimator, 'predict_proba', None)
            decision_function = getattr(estimator, 'decision_function', None)
            prediction_method = predict_proba or decision_function

        if prediction_method is None:
            if response_method == 'predict_proba':
                raise ValueError('The estimator has no predict_proba method')
            else:
                raise ValueError(
                    'The estimator has no decision_function method')

        y_pred = prediction_method(X)

Member

I wonder if we should move the computing part outside of the init, in a compute (fit ?) method which looks more like sklearn. Then plot_XXX would do

viz = XXXVisualizer()
viz.compute()
viz.plot()

Member Author

I have considered this and sided with doing the computation in __init__ to simplify the Visualizer API.

With this compute proposal the API looks like:

viz = ROCVisualizer(response_method="predict_proba")
viz.compute(estimator, X, y)
viz.plot()

This comes down to how useful a Visualizer object is without computing. Our estimators make use of this "delayed" computation through fit because we want to put them into Pipelines, etc. For Visualizer, I am not envisioning a need for delaying computation.

Member

I agree with @thomasjpfan. I don't see a point in building visualization pipelines as yellowbrick does. The only purpose seems to be to "look more like sklearn", resulting in a harder to understand api and more complex code.
I don't think "looking like an estimator" is a good goal if you do something entirely different.
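
The two API shapes discussed in this thread can be contrasted with a toy sketch. Both classes are hypothetical and the "computation" is a plain sum standing in for the ROC calculation. The point is only that a separate `compute()` step introduces a state in which the object exists but cannot be used yet.

```python
class EagerCurve:
    """Hypothetical visualizer that computes in __init__ (the PR's choice)."""

    def __init__(self, values):
        self.total_ = sum(values)

    def describe(self):
        return "total = {}".format(self.total_)


class LazyCurve:
    """Hypothetical visualizer with a separate compute() step."""

    def compute(self, values):
        self.total_ = sum(values)
        return self

    def describe(self):
        # Raises AttributeError if compute() was never called.
        return "total = {}".format(self.total_)


eager = EagerCurve([1, 2, 3])   # always ready to use
lazy = LazyCurve()              # not usable until compute() runs
```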

        if y_pred.ndim != 1:
            if y_pred.shape[1] > 2:
                raise ValueError("Estimator must be a binary classifier")
            y_pred = y_pred[:, 1]
        fpr, tpr, _ = roc_curve(y, y_pred, pos_label=pos_label,
                                drop_intermediate=drop_intermediate)

        self.fpr_ = fpr
        self.tpr_ = tpr
        self.auc_ = auc(fpr, tpr)
        self.name_ = estimator.__class__.__name__

    def plot(self, ax=None, line_kw=None, name=None):
        """Plot visualization

        Parameters
        ----------
        ax : Matplotlib Axes, default=None
            Axes object to plot on.

        line_kw : dict or None, default=None
            Keyword arguments to pass to matplotlib's ``plot``.

        name : str or None, default=None
            Name of ROC Curve for labeling. If `None`, use the name of the
            estimator.
        """
        check_matplotlib_support('plot_roc_curve')  # noqa
        import matplotlib.pyplot as plt  # noqa

        if ax is None:
            fig, ax = plt.subplots()

        if name is None:
            name = self.name_

        label = "{} (AUC = {:0.2f})".format(name, self.auc_)
        line_kwargs = {"label": label}
        if line_kw is not None:
            line_kwargs.update(**line_kw)
        self.line_ = ax.plot(self.fpr_, self.tpr_, **line_kwargs)[0]
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.legend()

        self.ax_ = ax
        self.figure_ = ax.figure
        return self


def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None,
Member

If there's only one artist, I think we don't need to have a dict as an argument and can just do **kwargs instead of having line_kw={...}
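
The suggestion above could look roughly like the following sketch. `build_line_kwargs` is a hypothetical helper, not part of the PR. It starts from the default label that `plot` builds and lets any keyword passed by the caller override it, which is the same merge that `line_kw` performs today.

```python
def build_line_kwargs(name, auc_value, **kwargs):
    """Hypothetical helper: default label first, caller's kwargs override."""
    line_kwargs = {"label": "{} (AUC = {:0.2f})".format(name, auc_value)}
    line_kwargs.update(kwargs)
    return line_kwargs


defaults_only = build_line_kwargs("SVC", 0.9)
styled = build_line_kwargs("SVC", 0.9, alpha=0.3, lw=1)
relabeled = build_line_kwargs("SVC", 0.9, label="custom")
```

The resulting dict would then be unpacked into `ax.plot(fpr, tpr, **line_kwargs)`.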

                   drop_intermediate=True, response_method="auto",
                   name=None, ax=None, line_kw=None):
    """Plot Receiver operating characteristic (ROC) curve

    Note: this implementation is restricted to the binary classification task.

    Read more in the :ref:`User Guide <roc_metrics>`.

    Parameters
    ----------

    estimator : estimator instance
        Trained classifier.

    X : {array-like, sparse matrix}, shape (n_samples, n_features)
        Input values.

    y : array-like, shape (n_samples,)
        Target values.

    pos_label : int or str, default=None
        The label of the positive class.
        When `pos_label=None`, if y_true is in {-1, 1} or {0, 1},
        `pos_label` is set to 1, otherwise an error will be raised.

    sample_weight : array-like of shape = [n_samples], optional (default=None)
        Sample weights.

    drop_intermediate : boolean, default=True
        Whether to drop some suboptimal thresholds which would not appear
        on a plotted ROC curve. This is useful in order to create lighter
        ROC curves.

    response_method : 'predict_proba', 'decision_function', or 'auto' \
            default='auto'
        Method to call estimator to get target scores
Member

Please describe 'auto'
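
As read from the `__init__` code in this diff, `'auto'` appears to mean "use `predict_proba` if the estimator has it, otherwise fall back to `decision_function`". The sketch below restates that lookup as a standalone function. `resolve_response_method` and the two dummy estimator classes are hypothetical names used only for illustration.

```python
def resolve_response_method(estimator, response_method="auto"):
    """Hypothetical restatement of the lookup done in __init__."""
    if response_method != "auto":
        method = getattr(estimator, response_method, None)
        if method is None:
            raise ValueError(
                "response method {} not defined".format(response_method))
        return method
    # 'auto': prefer predict_proba, then decision_function.
    for name in ("predict_proba", "decision_function"):
        method = getattr(estimator, name, None)
        if method is not None:
            return method
    raise ValueError(
        "estimator has neither predict_proba nor decision_function")


class Both:
    def predict_proba(self, X):
        return "proba"

    def decision_function(self, X):
        return "margin"


class MarginOnly:
    def decision_function(self, X):
        return "margin"
```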


    name : str or None, default=None
        Name of ROC Curve for labeling. If `None`, use the name of the
        estimator.

    ax : matplotlib axes, default=None
        Axes object to plot on.

    line_kw : dict or None, default=None
        Keyword arguments to pass to matplotlib's ``plot``.

    Returns
    -------
    viz : :class:`sklearn.metrics.plot.RocCurveVisualizer`
        Object that stores computed values.
    """
    viz = RocCurveVisualizer(estimator,
                             X,
                             y,
                             sample_weight=sample_weight,
                             pos_label=pos_label,
                             drop_intermediate=drop_intermediate,
                             response_method=response_method)
    viz.plot(ax=ax, line_kw=line_kw, name=name)
    return viz
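
The shape handling in `__init__` above (1-D decision scores pass through, 2-column probabilities are reduced to the second column, anything wider is rejected) can be tried in isolation. `positive_class_scores` is a hypothetical wrapper around the same checks, written here only to make the behavior easy to exercise.

```python
import numpy as np


def positive_class_scores(y_pred):
    """Hypothetical wrapper around the ndim/shape checks in __init__."""
    y_pred = np.asarray(y_pred)
    if y_pred.ndim != 1:
        if y_pred.shape[1] > 2:
            raise ValueError("Estimator must be a binary classifier")
        # predict_proba output: keep the column for the second class.
        y_pred = y_pred[:, 1]
    return y_pred


from_proba = positive_class_scores([[0.2, 0.8], [0.7, 0.3]])
from_margin = positive_class_scores([0.1, -0.4])
```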
Empty file.