[MRG] Plotting API starting with ROC curve #14357
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1647,3 +1647,53 @@ make this task easier and faster (in no particular order). | |
<https://git-scm.com/docs/git-grep#_examples>`_) is also extremely | ||
useful to see every occurrence of a pattern (e.g. a function call or a | ||
variable) in the code base. | ||
|
||
|
||
.. _plotting_api: | ||
|
||
Plotting API | ||
============ | ||
|
||
Scikit-learn defines a simple API for creating visualizations for machine | ||
learning. The key features of this API are to run calculations once and to have | |
the flexibility to adjust the visualizations after the fact. This logic is | ||
encapsulated into a visualizer object where the computed data is stored and | ||
the plotting is done in a `plot` method. The visualizer object's `__init__` | ||
method contains only the data needed to create the visualization. The `plot` | ||
method takes in parameters that only have to do with visualization, such as a | ||
matplotlib axes. The `plot` method will store the matplotlib artists as | ||
attributes allowing for style adjustments through the visualizer object. A | ||
`plot_*` helper function accepts parameters to do the computation and the | ||
parameters used for plotting. After the function creates the visualizer with | ||
Review comment: "After the function"... This sentence sounds grammatically odd. Also, single backticks everywhere now?
Review comment: ping
Review comment: This sentence was updated at the end. (Now just updated one more time to be hopefully clearer)
|
||
the computed values, it calls the visualizer's plot method. Note that the | ||
`plot` method defines attributes related to matplotlib, such as the line | ||
artist. This allows for customizations after calling the `plot` method. | ||
|
||
For example, the `RocCurveVisualizer` defines the following methods and | ||
attributes: | ||
|
||
.. code-block:: python | ||
|
||
    class RocCurveVisualizer: | ||
        def __init__(self, fpr, tpr, auc_roc, estimator_name): | ||
            ... | ||
            self.fpr = ... | ||
            self.tpr = ... | ||
            self.auc_roc = ... | ||
            self.estimator_name = estimator_name | ||
|
||
        def plot(self, ax=None, name=None, **kwargs): | ||
            ... | ||
            self.line_ = ... | ||
            self.ax_ = ax | ||
            self.figure_ = ax.figure | ||
|
||
    def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None, | ||
                       drop_intermediate=True, response_method="auto", | ||
                       name=None, ax=None, **kwargs): | ||
        # do computation | ||
Review comment: I read this comment as referring to the line below, which is not the intent. Maybe
    # ...
    # Do computation
    # ... |
||
        viz = RocCurveVisualizer(fpr, tpr, auc_roc, | ||
                                 estimator.__class__.__name__) | ||
        return viz.plot(ax=ax, name=name, **kwargs) | ||
|
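The "compute once, adjust the plot afterwards" idea described in the documentation above can be sketched in a self-contained way. This is only an illustration, not the code from this PR: `RocSketch` is a hypothetical stand-in for the visualizer, the labels and scores are toy values, and the `Agg` backend is chosen just so the snippet runs without a display. The only scikit-learn calls used are `sklearn.metrics.roc_curve` and `sklearn.metrics.auc`.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
from sklearn.metrics import auc, roc_curve


class RocSketch:
    """Hypothetical stand-in for the visualizer: stores computed values only."""

    def __init__(self, fpr, tpr, auc_roc, estimator_name):
        self.fpr = fpr
        self.tpr = tpr
        self.auc_roc = auc_roc
        self.estimator_name = estimator_name

    def plot(self, ax=None, name=None, **kwargs):
        if ax is None:
            _, ax = plt.subplots()
        name = self.estimator_name if name is None else name
        # An explicit ``label`` keyword wins over the generated one.
        kwargs.setdefault(
            "label", "{} (AUC = {:0.2f})".format(name, self.auc_roc))
        # Keep the line artist so callers can restyle it later.
        self.line_ = ax.plot(self.fpr, self.tpr, **kwargs)[0]
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.legend()
        self.ax_ = ax
        self.figure_ = ax.figure
        return self


# Toy labels and scores, not real results.
y_true = [0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8]

# The computation happens once ...
fpr, tpr, _ = roc_curve(y_true, y_score)
viz = RocSketch(fpr, tpr, auc(fpr, tpr), "ToyClassifier")

# ... and the stored values can be drawn again without recomputing.
fig, (ax_left, ax_right) = plt.subplots(1, 2)
viz.plot(ax=ax_left)
viz.plot(ax=ax_right, name="same data, new label", linestyle="--")

# Style adjustments after the fact go through the stored artist.
viz.line_.set_color("red")
```

Note that in this sketch `line_`, `ax_` and `figure_` always refer to the most recent `plot` call, so the colour change above affects only the right-hand axes.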
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1007,6 +1007,28 @@ See the :ref:`metrics` section of the user guide for further details. | |
metrics.pairwise_distances_chunked | ||
|
||
|
||
Plotting tools | ||
-------------- | ||
|
||
.. automodule:: sklearn.metrics.plot | ||
Review comment: So we've decided to use sklearn.XXX.plot.plot_XXX, right? |
||
   :no-members: | ||
   :no-inherited-members: | ||
|
||
.. currentmodule:: sklearn | ||
|
||
.. autosummary:: | ||
   :toctree: generated/ | ||
   :template: function.rst | ||
|
||
   metrics.plot_roc_curve | ||
|
||
.. autosummary:: | ||
   :toctree: generated/ | ||
   :template: class_without_init.rst | ||
|
||
   metrics.plot.RocCurveVisualizer | ||
|
||
|
||
.. _mixture_ref: | ||
|
||
:mod:`sklearn.mixture`: Gaussian Mixture Models | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from .roc_curve import plot_roc_curve, RocCurveVisualizer | ||
|
||
__all__ = [ | ||
    'plot_roc_curve', | ||
    'RocCurveVisualizer' | ||
] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
from .. import auc | ||
from .. import roc_curve | ||
|
||
from ...utils import check_matplotlib_support | ||
|
||
|
||
class RocCurveVisualizer: | ||
    """ROC Curve visualization. | ||
|
||
    Parameters | ||
    ---------- | ||
    fpr : ndarray | ||
        False positive rate. | ||
    tpr : ndarray | ||
        True positive rate. | ||
    auc_roc : float | ||
        Area under ROC curve. | ||
    estimator_name : str | ||
        Name of estimator. | ||
|
||
    Attributes | ||
    ---------- | ||
    line_ : matplotlib Artist | ||
        ROC Curve. | ||
    ax_ : matplotlib Axes | ||
        Axes with ROC Curve | ||
    figure_ : matplotlib Figure | ||
        Figure containing the curve | ||
    """ | ||
|
||
    def __init__(self, fpr, tpr, auc_roc, estimator_name): | ||
        self.fpr = fpr | ||
Review comment: Should these be private? They are not documented.
Review comment: They are documented as parameters.
Review comment: But not as attributes, hence my question. |
||
        self.tpr = tpr | ||
        self.auc_roc = auc_roc | ||
        self.estimator_name = estimator_name | ||
|
||
    def plot(self, ax=None, name=None, **kwargs): | ||
        """Plot visualization | ||
|
||
        Extra keyword arguments will be passed to matplotlib's ``plot``. | ||
|
||
        Parameters | ||
        ---------- | ||
        ax : Matplotlib Axes or None, default=None | ||
            Axes object to plot on. | ||
|
||
        name : str or None, default=None | ||
            Name of ROC Curve for labeling. If `None`, use the name of the | ||
            estimator. | ||
        """ | ||
        check_matplotlib_support('plot_roc_curve') | ||
        import matplotlib.pyplot as plt | ||
|
||
        if ax is None: | ||
            fig, ax = plt.subplots() | ||
|
||
        name = self.estimator_name if name is None else name | ||
|
||
        if 'label' not in kwargs: | ||
            label = "{} (AUC = {:0.2f})".format(name, self.auc_roc) | ||
            kwargs['label'] = label | ||
        self.line_ = ax.plot(self.fpr, self.tpr, **kwargs)[0] | ||
        ax.set_xlabel("False Positive Rate") | ||
        ax.set_ylabel("True Positive Rate") | ||
        ax.legend() | ||
|
||
        self.ax_ = ax | ||
        self.figure_ = ax.figure | ||
        return self | ||
|
||
|
||
def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None, | ||
Review comment: If there's only one artist, I think we don't need to have a dict as an argument and can just do |
||
                   drop_intermediate=True, response_method="auto", | ||
                   name=None, ax=None, **kwargs): | ||
    """Plot Receiver operating characteristic (ROC) curve | ||
|
||
    Extra keyword arguments will be passed to matplotlib's `plot`. | ||
Review comment: Should be double quotes to avoid resolving |
||
|
||
    Parameters | ||
    ---------- | ||
    estimator : estimator instance | ||
        Trained classifier. | ||
|
||
    X : {array-like, sparse matrix}, shape (n_samples, n_features) | ||
        Input values. | ||
|
||
    y : array-like, shape (n_samples, ) | ||
        Target values. | ||
|
||
    pos_label : int or str, default=None | ||
        The label of the positive class. | ||
        When `pos_label=None`, if `y` is in {-1, 1} or {0, 1}, | ||
        `pos_label` is set to 1, otherwise an error will be raised. | ||
|
||
    sample_weight : array-like, shape (n_samples, ) or None, default=None | ||
        Sample weights. | ||
|
||
    drop_intermediate : boolean, default=True | ||
        Whether to drop some suboptimal thresholds which would not appear | ||
        on a plotted ROC curve. This is useful in order to create lighter | ||
        ROC curves. | ||
|
||
    response_method : 'predict_proba', 'decision_function', or 'auto' \ | ||
    default='auto' | ||
        Specifies whether to use `predict_proba` or `decision_function` as the | ||
        target response. If set to 'auto', `predict_proba` is tried first | ||
        and if it does not exist `decision_function` is tried next. | ||
|
||
    name : str or None, default=None | ||
        Name of ROC Curve for labeling. If `None`, use the name of the | ||
        estimator. | ||
|
||
    ax : matplotlib axes, default=None | ||
        axes object to plot on | ||
|
||
    Returns | ||
    ------- | ||
    viz : :class:`sklearn.metrics.plot.RocCurveVisualizer` | ||
        object that stores computed values | ||
    """ | ||
    if response_method != "auto": | ||
        prediction_method = getattr(estimator, response_method, None) | ||
        if prediction_method is None: | ||
            raise ValueError( | ||
                "response method {} is not defined".format(response_method)) | ||
    else: | ||
        predict_proba = getattr(estimator, 'predict_proba', None) | ||
        decision_function = getattr(estimator, 'decision_function', None) | ||
        prediction_method = predict_proba or decision_function | ||
|
||
        if prediction_method is None: | ||
            raise ValueError('response methods not defined') | ||
|
||
    y_pred = prediction_method(X) | ||
|
||
    if y_pred.ndim != 1: | ||
        if y_pred.shape[1] > 2: | ||
            raise ValueError("Estimator must be a binary classifier") | ||
        y_pred = y_pred[:, 1] | ||
    fpr, tpr, _ = roc_curve(y, y_pred, pos_label=pos_label, | ||
                            sample_weight=sample_weight, | ||
                            drop_intermediate=drop_intermediate) | ||
    auc_roc = auc(fpr, tpr) | ||
    viz = RocCurveVisualizer(fpr, tpr, auc_roc, estimator.__class__.__name__) | ||
    return viz.plot(ax=ax, name=name, **kwargs) |
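The `response_method` handling above is easier to follow in isolation. The sketch below mirrors that branch with a hypothetical helper, `resolve_response_method`, and three dummy classes that exist only for illustration; none of these names are part of the PR. Like the diff, it does not restrict explicit values to the documented choices, it only checks that the attribute exists.

```python
def resolve_response_method(estimator, response_method="auto"):
    """Return the method used to score samples (hypothetical helper)."""
    if response_method != "auto":
        method = getattr(estimator, response_method, None)
        if method is None:
            raise ValueError(
                "response method {} is not defined".format(response_method))
        return method
    # "auto": prefer predict_proba, then fall back to decision_function.
    for candidate in ("predict_proba", "decision_function"):
        method = getattr(estimator, candidate, None)
        if method is not None:
            return method
    raise ValueError("response methods not defined")


class ProbaOnly:
    """Dummy estimator exposing only predict_proba."""

    def predict_proba(self, X):
        return [[0.2, 0.8] for _ in X]


class MarginOnly:
    """Dummy estimator exposing only decision_function."""

    def decision_function(self, X):
        return [1.5 for _ in X]


class Neither:
    """Dummy estimator with no scoring method at all."""


print(resolve_response_method(ProbaOnly()).__name__)
print(resolve_response_method(MarginOnly()).__name__)
try:
    resolve_response_method(Neither())
except ValueError as exc:
    print("error:", exc)
```

With `"auto"`, an estimator that defines both methods resolves to `predict_proba`, which matches the order described in the docstring.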
Review comment: It'd be nice to actually have a very simple example that illustrates this feature.
Review comment: link to example of user guide now?