New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Plotting API starting with ROC curve #14357
Changes from 17 commits
aa753b1
d5ba421
a0e4723
8ae4c70
eaac39c
763d723
2330e06
4e33b28
40381ae
5f83a80
663fc58
dea2b1b
3e65c77
b4a8d0e
eba2453
74b8d7b
272845f
c51da17
5fc6f29
8832788
56ff821
43f1787
bdf782a
6cc8712
7051e5e
d461543
bda4435
af4209c
014851e
50fb1cd
e620d8c
4d8534c
46cc0c5
6771d8d
dd7f3cf
e668471
94f29e3
9546755
71b19a3
f7f2e2e
e2f42fa
30925a8
2e8db1f
fab9d07
759d1b4
140287d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1647,3 +1647,63 @@ make this task easier and faster (in no particular order). | |
<https://git-scm.com/docs/git-grep#_examples>`_) is also extremely | ||
useful to see every occurrence of a pattern (e.g. a function call or a | ||
variable) in the code base. | ||
|
||
|
||
.. _plotting_api: | ||
|
||
Plotting API | ||
------------ | ||
|
||
Scikit-learn defines a simple API for creating visualizations for model. The | ||
key features of this API is to run calculations once and the flexibility to | ||
position and style the plotting elements. This logic is encapsulated into a | ||
visualizer object where the calculations are done during construction and | ||
plotting is done in a ``plot`` method. The visualizer object's ``__init__`` | ||
method contains only parameters that are needed to calculate the items in the | ||
visualization and saving the items as attributes. The ``plot`` method takes in | ||
parameters that only have to do with visualization, such as a matplotlib axes. | ||
The ``plot`` method will store the matplotlib artists as attributes allowing | ||
for style adjustments through the visualizer object. Along with the visualizer | ||
object, a ``plot_*`` helper function is defined to create, plot and return the | ||
object. This function will contain the parameters from both the ``__init__`` | ||
and ``plot`` methods for the visualizer object and pass the parameters to the | ||
respective methods. | ||
|
||
For example, the `RocCurveVisualizer` defines the following methods: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ... and attributes |
||
|
||
.. code-block:: python | ||
|
||
class RocCurveVisualizer: | ||
def __init__(self, estimator, X, y, *, pos_label=None, | ||
sample_weight=None, drop_intermediate=True, | ||
response_method="predict_proba", label=None): | ||
... | ||
self.fpr_ = ... | ||
self.tpr_ = ... | ||
self.label_ = ... | ||
|
||
def plot(self, ax=None): | ||
... | ||
self.line_ = ... | ||
self.ax_ = ax | ||
self.figure_ = ax.figure_ | ||
|
||
def plot_roc_curve(estimator, X, y, *, pos_label=None, | ||
sample_weight=None, drop_intermediate=True, | ||
response_method="predict_proba", label=None, ax=None): | ||
viz = RocCurveVisualizer(estimator, | ||
X, | ||
y, | ||
sample_weight=sample_weight, | ||
pos_label=pos_label, | ||
drop_intermediate=drop_intermediate, | ||
response_method=response_method, | ||
label=label) | ||
viz.plot(ax=ax) | ||
return viz | ||
``` | ||
|
||
Note that the ``__init__`` method defines attributes that are going to be used | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this paragraph can be integrated into the previous one IMO. |
||
for plotting and the ``plot`` method defines attributes that are related to the | ||
the matplotlib object itself. The line artist is stored as an attribute to | ||
allow for customizations after calling ``plot``. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1003,6 +1003,28 @@ See the :ref:`metrics` section of the user guide for further details. | |
metrics.pairwise_distances_chunked | ||
|
||
|
||
Plotting metrics | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's a confusing name considering that right above we have "clustering metrics", "pairwise metrics", etc. Maybe just "plotting" or "plotting tools" is enough? |
||
---------------- | ||
|
||
.. automodule:: sklearn.metrics.plot | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So we've decide to use sklearn.XXX.plot.plot_XXX, right? |
||
:no-members: | ||
:no-inherited-members: | ||
|
||
.. currentmodule:: sklearn | ||
|
||
.. autosummary:: | ||
:toctree: generated/ | ||
:template: function.rst | ||
|
||
metrics.plot_roc_curve | ||
|
||
.. autosummary:: | ||
:toctree: generated/ | ||
:template: class_without_init.rst | ||
|
||
metrics.plot.RocCurveVisualizer | ||
|
||
|
||
.. _mixture_ref: | ||
|
||
:mod:`sklearn.mixture`: Gaussian Mixture Models | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,6 +37,7 @@ | |
|
||
from sklearn import svm, datasets | ||
from sklearn.metrics import roc_curve, auc | ||
from sklearn.metrics import plot_roc_curve | ||
from sklearn.model_selection import StratifiedKFold | ||
|
||
# ############################################################################# | ||
|
@@ -61,44 +62,42 @@ | |
classifier = svm.SVC(kernel='linear', probability=True, | ||
random_state=random_state) | ||
|
||
visualizers = [] | ||
tprs = [] | ||
aucs = [] | ||
mean_fpr = np.linspace(0, 1, 100) | ||
|
||
i = 0 | ||
for train, test in cv.split(X, y): | ||
probas_ = classifier.fit(X[train], y[train]).predict_proba(X[test]) | ||
# Compute ROC curve and area the curve | ||
fpr, tpr, thresholds = roc_curve(y[test], probas_[:, 1]) | ||
tprs.append(interp(mean_fpr, fpr, tpr)) | ||
tprs[-1][0] = 0.0 | ||
roc_auc = auc(fpr, tpr) | ||
aucs.append(roc_auc) | ||
plt.plot(fpr, tpr, lw=1, alpha=0.3, | ||
label='ROC fold %d (AUC = %0.2f)' % (i, roc_auc)) | ||
|
||
i += 1 | ||
fig, ax = plt.subplots() | ||
for i, (train, test) in enumerate(cv.split(X, y)): | ||
classifier.fit(X[train], y[train]) | ||
viz = plot_roc_curve(classifier, X[test], y[test], | ||
name='ROC fold {}'.format(i), | ||
line_kw={'alpha': 0.3, 'lw': 1}, ax=ax) | ||
visualizers.append(viz) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You never use this list, remove it? |
||
|
||
interp_tpr = interp(mean_fpr, viz.fpr_, viz.tpr_) | ||
interp_tpr[0] = 0.0 | ||
tprs.append(interp_tpr) | ||
aucs.append(viz.auc_) | ||
|
||
plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. isn't ax.plot here too ? |
||
label='Chance', alpha=.8) | ||
|
||
mean_tpr = np.mean(tprs, axis=0) | ||
mean_tpr[-1] = 1.0 | ||
mean_auc = auc(mean_fpr, mean_tpr) | ||
std_auc = np.std(aucs) | ||
plt.plot(mean_fpr, mean_tpr, color='b', | ||
label=r'Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (mean_auc, std_auc), | ||
lw=2, alpha=.8) | ||
ax.plot(mean_fpr, mean_tpr, color='b', | ||
label=r'Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (mean_auc, std_auc), | ||
lw=2, alpha=.8) | ||
|
||
std_tpr = np.std(tprs, axis=0) | ||
tprs_upper = np.minimum(mean_tpr + std_tpr, 1) | ||
tprs_lower = np.maximum(mean_tpr - std_tpr, 0) | ||
plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2, | ||
label=r'$\pm$ 1 std. dev.') | ||
|
||
plt.xlim([-0.05, 1.05]) | ||
plt.ylim([-0.05, 1.05]) | ||
plt.xlabel('False Positive Rate') | ||
plt.ylabel('True Positive Rate') | ||
plt.title('Receiver operating characteristic example') | ||
plt.legend(loc="lower right") | ||
ax.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2, | ||
label=r'$\pm$ 1 std. dev.') | ||
|
||
ax.set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05], | ||
title="Receiver operating characteristic example") | ||
ax.legend(loc="lower right") | ||
plt.show() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from .roc_curve import plot_roc_curve, RocCurveVisualizer | ||
|
||
__all__ = [ | ||
'plot_roc_curve', | ||
'RocCurveVisualizer' | ||
] |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,191 @@ | ||||||
from .. import auc | ||||||
from .. import roc_curve | ||||||
|
||||||
from ...utils import check_matplotlib_support # noqa | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do we need noqa? (same below) |
||||||
|
||||||
|
||||||
class RocCurveVisualizer: | ||||||
"""ROC Curve visualization | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should add a note to say that
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wouldn't say it shouldn't be instantiated, I'd say something like that this is usually instantiated by the function and should only be manually instantiated for advanced use cases? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good to me yes |
||||||
|
||||||
Parameters | ||||||
---------- | ||||||
|
||||||
estimator : estimator instance | ||||||
Trained classifier. | ||||||
|
||||||
X : {array-like, sparse matrix}, shape (n_samples, n_features) | ||||||
Input values. | ||||||
|
||||||
y : array-like, shape (n_samples, ) | ||||||
Target values. | ||||||
|
||||||
pos_label : int str or None, default=None | ||||||
The label of the positive class. | ||||||
When `pos_label=None`, if y_true is in {-1, 1} or {0, 1}, | ||||||
`pos_label` is set to 1, otherwise an error will be raised. | ||||||
|
||||||
sample_weight : array-like of shape = [n_samples], optional | ||||||
Sample weights. | ||||||
|
||||||
drop_intermediate : boolean, default=True | ||||||
Whether to drop some suboptimal thresholds which would not appear | ||||||
on a plotted ROC curve. This is useful in order to create lighter | ||||||
ROC curves. | ||||||
|
||||||
response_method : 'predict_proba', 'decision_function', or 'auto' \ | ||||||
default='auto' | ||||||
Method to call estimator to get target scores | ||||||
|
||||||
Attributes | ||||||
---------- | ||||||
fpr_ : ndarray | ||||||
False positive rate. | ||||||
tpr_ : ndarray | ||||||
True positive rate. | ||||||
auc_ : | ||||||
Area under ROC curve. | ||||||
name_ : str | ||||||
Name of estimator. | ||||||
line_ : matplotlib Artist | ||||||
ROC Curve. | ||||||
ax_ : matplotlib Axes | ||||||
Axes with ROC curv | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
figure_ : matplotlib Figure | ||||||
Figure containing the curve | ||||||
""" | ||||||
|
||||||
def __init__(self, estimator, X, y, *, pos_label=None, sample_weight=None, | ||||||
drop_intermediate=True, response_method="auto"): | ||||||
"""Computes and stores values needed for visualization""" | ||||||
|
||||||
if response_method != "auto": | ||||||
prediction_method = getattr(estimator, response_method, None) | ||||||
if prediction_method is None: | ||||||
raise ValueError( | ||||||
"response method {} not defined".format(response_method)) | ||||||
else: | ||||||
predict_proba = getattr(estimator, 'predict_proba', None) | ||||||
decision_function = getattr(estimator, 'decision_function', None) | ||||||
prediction_method = predict_proba or decision_function | ||||||
|
||||||
if prediction_method is None: | ||||||
if response_method == 'predict_proba': | ||||||
raise ValueError('The estimator has no predict_proba method') | ||||||
else: | ||||||
raise ValueError( | ||||||
'The estimator has no decision_function method') | ||||||
|
||||||
y_pred = prediction_method(X) | ||||||
|
||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if we should move the computing part outside of the init, in a
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have considered this and sided with doing the computation in With this viz = ROCVisualizer(response_method="predict_proba")
viz.compute(estimator, X, y)
viz.plot() This comes down how useful a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree with @thomasjpfan. I don't see a point in building visualization pipelines as yellowbrick does. The only purpose seems to be to "look more like sklearn", resulting in a harder to understand api and more complex code. |
||||||
if y_pred.ndim != 1: | ||||||
if y_pred.shape[1] > 2: | ||||||
raise ValueError("Estimator must be a binary classifier") | ||||||
y_pred = y_pred[:, 1] | ||||||
fpr, tpr, _ = roc_curve(y, y_pred, pos_label=pos_label, | ||||||
drop_intermediate=drop_intermediate) | ||||||
|
||||||
self.fpr_ = fpr | ||||||
self.tpr_ = tpr | ||||||
self.auc_ = auc(fpr, tpr) | ||||||
self.name_ = estimator.__class__.__name__ | ||||||
|
||||||
def plot(self, ax=None, line_kw=None, name=None): | ||||||
"""Plot visualization | ||||||
|
||||||
Parameters | ||||||
---------- | ||||||
ax : Matplotlib Axes, default=None | ||||||
Axes object to plot on. | ||||||
|
||||||
line_kw : dict or None, default=None | ||||||
Keyword arguments to pass to. | ||||||
|
||||||
name : str or None, default=None | ||||||
Name of ROC Curve for labeling. If `None`, use the name of the | ||||||
estimator. | ||||||
""" | ||||||
check_matplotlib_support('plot_roc_curve') # noqa | ||||||
import matplotlib.pyplot as plt # noqa | ||||||
|
||||||
if ax is None: | ||||||
fig, ax = plt.subplots() | ||||||
|
||||||
if name is None: | ||||||
name = self.name_ | ||||||
|
||||||
label = "{} (AUC = {:0.2f})".format(name, self.auc_) | ||||||
line_kwargs = {"label": label} | ||||||
if line_kw is not None: | ||||||
line_kwargs.update(**line_kw) | ||||||
self.line_ = ax.plot(self.fpr_, self.tpr_, **line_kwargs)[0] | ||||||
ax.set_xlabel("False Positive Rate") | ||||||
ax.set_ylabel("True Positive Rate") | ||||||
ax.legend() | ||||||
|
||||||
self.ax_ = ax | ||||||
self.figure_ = ax.figure | ||||||
return self | ||||||
|
||||||
|
||||||
def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If there's only one artist, I think we don't need to have a dict as an argument and can just do |
||||||
drop_intermediate=True, response_method="auto", | ||||||
name=None, ax=None, line_kw=None): | ||||||
"""Plot Receiver operating characteristic (ROC) curve | ||||||
|
||||||
Note: this implementation is restricted to the binary classification task. | ||||||
|
||||||
Read more in the :ref:`User Guide <roc_metrics>`. | ||||||
|
||||||
Parameters | ||||||
---------- | ||||||
|
||||||
estimator : estimator instance | ||||||
Trained classifier. | ||||||
|
||||||
X : {array-like, sparse matrix}, shape (n_samples, n_features) | ||||||
Input values. | ||||||
|
||||||
y : array-like, shape (n_samples,, ) | ||||||
Target values. | ||||||
|
||||||
pos_label : int or str, default=None | ||||||
The label of the positive class. | ||||||
When `pos_label=None`, if y_true is in {-1, 1} or {0, 1}, | ||||||
`pos_label` is set to 1, otherwise an error will be raised. | ||||||
|
||||||
sample_weight : array-like of shape = [n_samples], optional (default=None) | ||||||
Sample weights. | ||||||
|
||||||
drop_intermediate : boolean, default=True | ||||||
Whether to drop some suboptimal thresholds which would not appear | ||||||
on a plotted ROC curve. This is useful in order to create lighter | ||||||
ROC curves. | ||||||
|
||||||
response_method : 'predict_proba', 'decision_function', or 'auto' \ | ||||||
default='auto' | ||||||
Method to call estimator to get target scores | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please describe 'auto' |
||||||
|
||||||
name : str or None, default=None | ||||||
Name of ROC Curve for labeling. If `None`, use the name of the | ||||||
estimator. | ||||||
|
||||||
ax : matplotlib axes, default=None | ||||||
axes object to plot on | ||||||
|
||||||
line_kw : dict or None, default=None | ||||||
Keyword arguments to pass to | ||||||
|
||||||
Returns | ||||||
------- | ||||||
viz : :class:`sklearn.metrics.plot.RocCurveVisualizer` | ||||||
object that stores computed values | ||||||
""" | ||||||
viz = RocCurveVisualizer(estimator, | ||||||
X, | ||||||
y, | ||||||
sample_weight=sample_weight, | ||||||
pos_label=pos_label, | ||||||
drop_intermediate=drop_intermediate, | ||||||
response_method=response_method) | ||||||
viz.plot(ax=ax, line_kw=line_kw, name=name) | ||||||
return viz |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be
====
I think: right now this is considered a subsection of "Reading the existing code base"There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also I think this whole section might be better suited for the user guide instead of the contributing guide