[MRG] Plotting API starting with ROC curve #14357

Merged
merged 46 commits into from Aug 8, 2019
Changes from 28 commits
Commits
46 commits
aa753b1
WIP
thomasjpfan Jul 5, 2019
d5ba421
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan Jul 10, 2019
a0e4723
ENH Adds plot_roc_curve
thomasjpfan Jul 10, 2019
8ae4c70
DOC Adds docs
thomasjpfan Jul 11, 2019
eaac39c
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan Jul 11, 2019
763d723
BUG Fix
thomasjpfan Jul 11, 2019
2330e06
DOC Adds label to parameters
thomasjpfan Jul 11, 2019
4e33b28
DOC Adds label as a parameter
thomasjpfan Jul 11, 2019
40381ae
API Update with kwargs
thomasjpfan Jul 11, 2019
5f83a80
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan Jul 14, 2019
663fc58
BLD Add tests to setup
thomasjpfan Jul 14, 2019
dea2b1b
DOC Update docs
thomasjpfan Jul 14, 2019
3e65c77
DOC Update ordering
thomasjpfan Jul 14, 2019
b4a8d0e
ENH Adds auc to labels
thomasjpfan Jul 17, 2019
eba2453
CLN Updates example with plotting api
thomasjpfan Jul 17, 2019
74b8d7b
TST Updates test
thomasjpfan Jul 17, 2019
272845f
CLN Moves name to plot
thomasjpfan Jul 17, 2019
c51da17
TST Adds more tests
thomasjpfan Jul 18, 2019
5fc6f29
DOC Removes line_kw parameter docstring
thomasjpfan Jul 22, 2019
8832788
DOC Fix docs
thomasjpfan Jul 22, 2019
56ff821
CLN Removes unused import
thomasjpfan Jul 22, 2019
43f1787
CLN Uses kwargs
thomasjpfan Jul 22, 2019
bdf782a
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan Jul 23, 2019
6cc8712
CLN Does computation in plot_* function
thomasjpfan Jul 23, 2019
7051e5e
DOC Adds type
thomasjpfan Jul 23, 2019
d461543
CLN Address comments
thomasjpfan Jul 29, 2019
bda4435
BUG Spelling
thomasjpfan Jul 29, 2019
af4209c
TST Assert message
thomasjpfan Jul 29, 2019
014851e
CLN Moves plot to _plot
thomasjpfan Jul 29, 2019
50fb1cd
WIP
thomasjpfan Jul 29, 2019
e620d8c
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan Jul 30, 2019
4d8534c
DOC Adds user guide
thomasjpfan Jul 30, 2019
46cc0c5
BLD Fix
thomasjpfan Jul 30, 2019
6771d8d
BLD Build docs [doc build]
thomasjpfan Jul 31, 2019
dd7f3cf
CLN Address comments
thomasjpfan Jul 31, 2019
e668471
CLN Address comments
thomasjpfan Aug 1, 2019
94f29e3
STY Update styling
thomasjpfan Aug 1, 2019
9546755
STY Update styling
thomasjpfan Aug 1, 2019
71b19a3
DOC Adds note about parameters stored as attributes
thomasjpfan Aug 1, 2019
f7f2e2e
CLN Updates to roc_auc
thomasjpfan Aug 5, 2019
e2f42fa
ENH Adds check for response_method
thomasjpfan Aug 5, 2019
30925a8
DOC Adds parameters
thomasjpfan Aug 5, 2019
2e8db1f
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan Aug 5, 2019
fab9d07
CLN Renames to display
thomasjpfan Aug 5, 2019
759d1b4
CLN Address comments
thomasjpfan Aug 7, 2019
140287d
improve error message
qinhanmin2014 Aug 8, 2019
50 changes: 50 additions & 0 deletions doc/developers/contributing.rst
@@ -1647,3 +1647,53 @@ make this task easier and faster (in no particular order).
<https://git-scm.com/docs/git-grep#_examples>`_) is also extremely
useful to see every occurrence of a pattern (e.g. a function call or a
variable) in the code base.


.. _plotting_api:

Plotting API
============

Scikit-learn defines a simple API for creating visualizations for machine
learning. The key features of this API are to run calculations once and to have
the flexibility to adjust the visualizations after the fact. This logic is
Member

It'd be nice to actually have a very simple example that illustrates this feature.

Member

link to example of user guide now?

encapsulated into a visualizer object where the computed data is stored and
the plotting is done in a `plot` method. The visualizer object's `__init__`
method contains only the data needed to create the visualization. The `plot`
method takes in parameters that only have to do with visualization, such as a
matplotlib axes. The `plot` method will store the matplotlib artists as
attributes allowing for style adjustments through the visualizer object. A
`plot_*` helper function accepts parameters to do the computation and the
parameters used for plotting. After the function creates the visualizer with
Member

"After the function"...

This sentence sounds grammatically odd.

Also, single backticks everywhere now?

Member

ping

Member Author

This sentence was updated at the end. (Now just updated one more time to be hopefully clearer)

After the helper function creates the visualizer
with the computed values, it calls the visualizer's plot method.

the computed values, it calls the visualizer's plot method. Note that the
`plot` method defines attributes related to matplotlib, such as the line
artist. This allows for customizations after calling the `plot` method.

For example, the `RocCurveVisualizer` defines the following methods and
attributes:

.. code-block:: python

    class RocCurveVisualizer:
        def __init__(self, fpr, tpr, auc_roc, estimator_name):
            ...
            self.fpr = ...
            self.tpr = ...
            self.auc_roc = ...
            self.estimator_name = estimator_name

        def plot(self, ax=None, name=None, **kwargs):
            ...
            self.line_ = ...
            self.ax_ = ax
            self.figure_ = ax.figure

    def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None,
                       drop_intermediate=True, response_method="auto",
                       name=None, ax=None, **kwargs):
        # do computation
Member

I read this comment as referring to the line below, which is not the intent.

Maybe

# ...
# Do computation
# ...

        viz = RocCurveVisualizer(fpr, tpr, auc_roc,
                                 estimator.__class__.__name__)
        return viz.plot(ax=ax, name=name, **kwargs)
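
The pattern described in this section can be sketched roughly as follows, assuming matplotlib is installed. `SketchCurveVisualizer` and the curve values are hypothetical stand-ins, not the class added in this PR. The sketch only illustrates that the computed data lives on the object while `plot` records the artists it creates, so styling can change afterwards without recomputing anything.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


class SketchCurveVisualizer:
    """Hypothetical stand-in for the visualizer described above."""

    def __init__(self, fpr, tpr, auc_roc, estimator_name):
        # __init__ only stores data that has already been computed.
        self.fpr = fpr
        self.tpr = tpr
        self.auc_roc = auc_roc
        self.estimator_name = estimator_name

    def plot(self, ax=None, name=None, **kwargs):
        # plot only deals with presentation and records the artists it creates.
        if ax is None:
            _, ax = plt.subplots()
        name = self.estimator_name if name is None else name
        kwargs.setdefault(
            "label", "{} (AUC = {:0.2f})".format(name, self.auc_roc))
        self.line_ = ax.plot(self.fpr, self.tpr, **kwargs)[0]
        self.ax_ = ax
        self.figure_ = ax.figure
        return self


# Made-up curve values, for illustration only.
viz = SketchCurveVisualizer([0.0, 0.2, 1.0], [0.0, 0.8, 1.0], 0.8,
                            "MyClassifier")
viz.plot()

# Restyle the stored artist after the fact; nothing is recomputed.
first_line = viz.line_
first_line.set_linewidth(3)
first_line.set_linestyle("--")

# Reuse the same stored values on a second axes with a different name.
fig, ax = plt.subplots()
viz.plot(ax=ax, name="same data, new axes", color="k")
```

Returning `self` from `plot` is what lets a helper function hand the visualizer back to the caller in a single expression.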

22 changes: 22 additions & 0 deletions doc/modules/classes.rst
@@ -1007,6 +1007,28 @@ See the :ref:`metrics` section of the user guide for further details.
   metrics.pairwise_distances_chunked


Plotting tools
--------------

.. automodule:: sklearn.metrics.plot
Member

So we've decided to use sklearn.XXX.plot.plot_XXX, right?
We'll put things like plot_decision_boundary in sklearn.inspection, right?

   :no-members:
   :no-inherited-members:

.. currentmodule:: sklearn

.. autosummary::
   :toctree: generated/
   :template: function.rst

   metrics.plot_roc_curve

.. autosummary::
   :toctree: generated/
   :template: class_without_init.rst

   metrics.plot.RocCurveVisualizer


.. _mixture_ref:

:mod:`sklearn.mixture`: Gaussian Mixture Models
52 changes: 24 additions & 28 deletions examples/model_selection/plot_roc_crossval.py
@@ -36,7 +36,8 @@
import matplotlib.pyplot as plt

from sklearn import svm, datasets
-from sklearn.metrics import roc_curve, auc
+from sklearn.metrics import auc
+from sklearn.metrics import plot_roc_curve
from sklearn.model_selection import StratifiedKFold

# #############################################################################
@@ -65,40 +66,35 @@
aucs = []
mean_fpr = np.linspace(0, 1, 100)

-i = 0
-for train, test in cv.split(X, y):
-    probas_ = classifier.fit(X[train], y[train]).predict_proba(X[test])
-    # Compute ROC curve and area the curve
-    fpr, tpr, thresholds = roc_curve(y[test], probas_[:, 1])
-    tprs.append(interp(mean_fpr, fpr, tpr))
-    tprs[-1][0] = 0.0
-    roc_auc = auc(fpr, tpr)
-    aucs.append(roc_auc)
-    plt.plot(fpr, tpr, lw=1, alpha=0.3,
-             label='ROC fold %d (AUC = %0.2f)' % (i, roc_auc))
-
-    i += 1
-plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r',
-         label='Chance', alpha=.8)
+fig, ax = plt.subplots()
+for i, (train, test) in enumerate(cv.split(X, y)):
+    classifier.fit(X[train], y[train])
+    viz = plot_roc_curve(classifier, X[test], y[test],
+                         name='ROC fold {}'.format(i),
+                         alpha=0.3, lw=1, ax=ax)
+    interp_tpr = interp(mean_fpr, viz.fpr, viz.tpr)
+    interp_tpr[0] = 0.0
+    tprs.append(interp_tpr)
+    aucs.append(viz.auc_roc)
+
+ax.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r',
+        label='Chance', alpha=.8)

mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0
mean_auc = auc(mean_fpr, mean_tpr)
std_auc = np.std(aucs)
-plt.plot(mean_fpr, mean_tpr, color='b',
-         label=r'Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (mean_auc, std_auc),
-         lw=2, alpha=.8)
+ax.plot(mean_fpr, mean_tpr, color='b',
+        label=r'Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (mean_auc, std_auc),
+        lw=2, alpha=.8)

std_tpr = np.std(tprs, axis=0)
tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
-plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2,
-                 label=r'$\pm$ 1 std. dev.')
-
-plt.xlim([-0.05, 1.05])
-plt.ylim([-0.05, 1.05])
-plt.xlabel('False Positive Rate')
-plt.ylabel('True Positive Rate')
-plt.title('Receiver operating characteristic example')
-plt.legend(loc="lower right")
+ax.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2,
+                label=r'$\pm$ 1 std. dev.')
+
+ax.set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05],
+       title="Receiver operating characteristic example")
+ax.legend(loc="lower right")
plt.show()
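
The averaging step in this example may be easier to follow in isolation. The sketch below assumes NumPy and uses two made-up per-fold curves; `np.interp` stands in for the `interp` helper that the example imports outside this hunk. Each fold's TPR is resampled onto a shared FPR grid so that the folds can be averaged point by point.

```python
import numpy as np

# Two made-up per-fold ROC curves as (fpr, tpr) pairs, for illustration only.
folds = [
    (np.array([0.0, 0.5, 1.0]), np.array([0.0, 1.0, 1.0])),
    (np.array([0.0, 0.5, 1.0]), np.array([0.0, 0.5, 1.0])),
]

# Shared FPR grid: 0, 0.25, 0.5, 0.75, 1.
mean_fpr = np.linspace(0, 1, 5)

tprs = []
for fpr, tpr in folds:
    # Resample this fold's TPR onto the shared grid.
    interp_tpr = np.interp(mean_fpr, fpr, tpr)
    interp_tpr[0] = 0.0  # force the curve to start at the origin
    tprs.append(interp_tpr)

# Point-wise mean and spread across folds.
mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0  # force the curve to end at (1, 1)
std_tpr = np.std(tprs, axis=0)

# Band of one standard deviation, clipped to the valid [0, 1] range.
tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
```

In the example itself, the per-fold `fpr` and `tpr` arrays come from the visualizer returned by `plot_roc_curve`, so the curve is computed once and reused for both plotting and averaging.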
4 changes: 4 additions & 0 deletions sklearn/metrics/__init__.py
@@ -74,6 +74,9 @@
from .scorer import SCORERS
from .scorer import get_scorer

+from .plot import plot_roc_curve
+
+
__all__ = [
    'accuracy_score',
    'adjusted_mutual_info_score',
@@ -125,6 +128,7 @@
    'pairwise_distances_argmin_min',
    'pairwise_distances_chunked',
    'pairwise_kernels',
+    'plot_roc_curve',
    'precision_recall_curve',
    'precision_recall_fscore_support',
    'precision_score',
6 changes: 6 additions & 0 deletions sklearn/metrics/plot/__init__.py
@@ -0,0 +1,6 @@
from .roc_curve import plot_roc_curve, RocCurveVisualizer

__all__ = [
    'plot_roc_curve',
    'RocCurveVisualizer'
]
144 changes: 144 additions & 0 deletions sklearn/metrics/plot/roc_curve.py
@@ -0,0 +1,144 @@
from .. import auc
from .. import roc_curve

from ...utils import check_matplotlib_support


class RocCurveVisualizer:
    """ROC Curve visualization.

    Parameters
    ----------
    fpr : ndarray
        False positive rate.
    tpr : ndarray
        True positive rate.
    auc_roc : float
        Area under ROC curve.
    estimator_name : str
        Name of estimator.

    Attributes
    ----------
    line_ : matplotlib Artist
        ROC Curve.
    ax_ : matplotlib Axes
        Axes with ROC Curve
    figure_ : matplotlib Figure
        Figure containing the curve
    """

    def __init__(self, fpr, tpr, auc_roc, estimator_name):
        self.fpr = fpr
Member

Should these be private? They are not documented.

Member Author

They are documented as parameters.

Member

But not as attributes, hence my question.

        self.tpr = tpr
        self.auc_roc = auc_roc
        self.estimator_name = estimator_name

    def plot(self, ax=None, name=None, **kwargs):
        """Plot visualization

        Extra keyword arguments will be passed to matplotlib's ``plot``.

        Parameters
        ----------
        ax : Matplotlib Axes or None, default=None
            Axes object to plot on.

        name : str or None, default=None
            Name of ROC Curve for labeling. If `None`, use the name of the
            estimator.
        """
        check_matplotlib_support('plot_roc_curve')
        import matplotlib.pyplot as plt

        if ax is None:
            fig, ax = plt.subplots()

        name = self.estimator_name if name is None else name

        if 'label' not in kwargs:
            label = "{} (AUC = {:0.2f})".format(name, self.auc_roc)
            kwargs['label'] = label
        self.line_ = ax.plot(self.fpr, self.tpr, **kwargs)[0]
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.legend()

        self.ax_ = ax
        self.figure_ = ax.figure
        return self


def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None,
Member

If there's only one artist, I think we don't need to have a dict as an argument and can just do **kwargs instead of having line_kw={...}

                   drop_intermediate=True, response_method="auto",
                   name=None, ax=None, **kwargs):
    """Plot Receiver operating characteristic (ROC) curve

    Extra keyword arguments will be passed to matplotlib's `plot`.
Member

Should be double quotes to avoid resolving


    Parameters
    ----------
    estimator : estimator instance
        Trained classifier.

    X : {array-like, sparse matrix}, shape (n_samples, n_features)
        Input values.

    y : array-like, shape (n_samples, )
        Target values.

    pos_label : int or str, default=None
        The label of the positive class.
        When `pos_label=None`, if `y` is in {-1, 1} or {0, 1},
        `pos_label` is set to 1, otherwise an error will be raised.

    sample_weight : array-like, shape (n_samples, ) or None, default=None
        Sample weights.

    drop_intermediate : boolean, default=True
        Whether to drop some suboptimal thresholds which would not appear
        on a plotted ROC curve. This is useful in order to create lighter
        ROC curves.

    response_method : 'predict_proba', 'decision_function', or 'auto' \
            default='auto'
        Specifies whether to use `predict_proba` or `decision_function` as the
        target response. If set to 'auto', `predict_proba` is tried first
        and if it does not exist `decision_function` is tried next.

    name : str or None, default=None
        Name of ROC Curve for labeling. If `None`, use the name of the
        estimator.

    ax : matplotlib axes, default=None
        axes object to plot on

    Returns
    -------
    viz : :class:`sklearn.metrics.plot.RocCurveVisualizer`
        object that stores computed values
    """
    if response_method != "auto":
        prediction_method = getattr(estimator, response_method, None)
        if prediction_method is None:
            raise ValueError(
                "response method {} is not defined".format(response_method))
    else:
        predict_proba = getattr(estimator, 'predict_proba', None)
        decision_function = getattr(estimator, 'decision_function', None)
        prediction_method = predict_proba or decision_function

        if prediction_method is None:
            raise ValueError('response methods not defined')

    y_pred = prediction_method(X)

    if y_pred.ndim != 1:
        if y_pred.shape[1] > 2:
            raise ValueError("Estimator must be a binary classifier")
        y_pred = y_pred[:, 1]
    # Forward sample_weight, which is documented above but was not passed on.
    fpr, tpr, _ = roc_curve(y, y_pred, pos_label=pos_label,
                            sample_weight=sample_weight,
                            drop_intermediate=drop_intermediate)
    auc_roc = auc(fpr, tpr)
    viz = RocCurveVisualizer(fpr, tpr, auc_roc, estimator.__class__.__name__)
    return viz.plot(ax=ax, name=name, **kwargs)
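
The `response_method` lookup in `plot_roc_curve` can be exercised on its own. The sketch below copies that logic into a standalone helper; `resolve_response_method` and the two estimator classes are hypothetical names used only for illustration and are not part of this PR.

```python
def resolve_response_method(estimator, response_method="auto"):
    """Return the bound prediction method, mirroring the lookup above."""
    if response_method != "auto":
        # An explicit request must exist on the estimator.
        method = getattr(estimator, response_method, None)
        if method is None:
            raise ValueError(
                "response method {} is not defined".format(response_method))
        return method

    # 'auto': prefer predict_proba, fall back to decision_function.
    predict_proba = getattr(estimator, "predict_proba", None)
    decision_function = getattr(estimator, "decision_function", None)
    method = predict_proba or decision_function
    if method is None:
        raise ValueError("response methods not defined")
    return method


class OnlyDecisionFunction:
    """Hypothetical estimator exposing only decision_function."""

    def decision_function(self, X):
        return [0.0 for _ in X]


class BothMethods(OnlyDecisionFunction):
    """Hypothetical estimator exposing both response methods."""

    def predict_proba(self, X):
        return [[0.5, 0.5] for _ in X]


auto_both = resolve_response_method(BothMethods())
auto_fallback = resolve_response_method(OnlyDecisionFunction())
```

With `'auto'`, an estimator that offers both methods resolves to `predict_proba`, and one that offers only `decision_function` falls back to it. An explicit request for a missing method raises `ValueError`.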
Empty file.