ENH Adds plot_det_curve and associated display #18176

Merged
merged 22 commits into from
Aug 29, 2020
Changes from 21 commits
4 changes: 3 additions & 1 deletion doc/modules/classes.rst
@@ -947,7 +947,7 @@ details.
    metrics.cohen_kappa_score
    metrics.confusion_matrix
    metrics.dcg_score
-   metrics.detection_error_tradeoff_curve
+   metrics.det_curve
    metrics.f1_score
    metrics.fbeta_score
    metrics.hamming_loss
@@ -1100,6 +1100,7 @@ See the :ref:`visualizations` section of the user guide for further details.
    :template: function.rst

    metrics.plot_confusion_matrix
+   metrics.plot_det_curve
    metrics.plot_precision_recall_curve
    metrics.plot_roc_curve

@@ -1108,6 +1109,7 @@ See the :ref:`visualizations` section of the user guide for further details.
    :template: class.rst

    metrics.ConfusionMatrixDisplay
+   metrics.DetCurveDisplay
    metrics.PrecisionRecallDisplay
    metrics.RocCurveDisplay

4 changes: 2 additions & 2 deletions doc/modules/model_evaluation.rst
@@ -306,7 +306,7 @@ Some of these are restricted to the binary classification case:

    precision_recall_curve
    roc_curve
-   detection_error_tradeoff_curve
+   det_curve


 Others also work in the multiclass case:
@@ -1443,7 +1443,7 @@ to the given limit.
 Detection error tradeoff (DET)
 ------------------------------

-The function :func:`detection_error_tradeoff_curve` computes the
+The function :func:`det_curve` computes the
 detection error tradeoff curve (DET) curve [WikipediaDET2017]_.
 Quoting Wikipedia:

2 changes: 2 additions & 0 deletions doc/visualizations.rst
@@ -78,6 +78,7 @@ Functions

    inspection.plot_partial_dependence
    metrics.plot_confusion_matrix
+   metrics.plot_det_curve
    metrics.plot_precision_recall_curve
    metrics.plot_roc_curve

@@ -91,5 +92,6 @@ Display Objects

    inspection.PartialDependenceDisplay
    metrics.ConfusionMatrixDisplay
+   metrics.DetCurveDisplay
    metrics.PrecisionRecallDisplay
    metrics.RocCurveDisplay
8 changes: 6 additions & 2 deletions doc/whats_new/v0.24.rst
@@ -280,11 +280,15 @@ Changelog
 :mod:`sklearn.metrics`
 ......................

-- |Feature| Added :func:`metrics.detection_error_tradeoff_curve` to compute
-  Detection Error Tradeoff curve classification metric.
+- |Feature| Added :func:`metrics.det_curve` to compute Detection Error Tradeoff
+  curve classification metric.
   :pr:`10591` by :user:`Jeremy Karnowski <jkarnows>` and
   :user:`Daniel Mohns <dmohns>`.

+- |Feature| Added :func:`metrics.plot_det_curve` and :class:`DetCurveDisplay`
+  to ease the plot of DET curves.
+  :pr:`18176` by :user:`Guillaume Lemaitre <glemaitre>`.
+
 - |Feature| Added :func:`metrics.mean_absolute_percentage_error` metric and
   the associated scorer for regression problems. :issue:`10708` fixed with the
   PR :pr:`15007` by :user:`Ashutosh Hathidara <ashutosh1919>`. The scorer and
52 changes: 12 additions & 40 deletions examples/model_selection/plot_det.py
@@ -8,9 +8,9 @@
 for the same classification task.

 DET curves are commonly plotted in normal deviate scale.
-To achieve this we transform the error rates as returned by the
-:func:`~sklearn.metrics.detection_error_tradeoff_curve` function and the axis
-scale using :func:`scipy.stats.norm`.
+To achieve this `plot_det_curve` transforms the error rates as returned by the
+:func:`~sklearn.metrics.det_curve` and the axis scale using
+:func:`scipy.stats.norm`.

 The point of this example is to demonstrate two properties of DET curves,
 namely:
@@ -39,8 +39,8 @@
 - See :func:`sklearn.metrics.roc_curve` for further information about ROC
   curves.

-- See :func:`sklearn.metrics.detection_error_tradeoff_curve` for further
-  information about DET curves.
+- See :func:`sklearn.metrics.det_curve` for further information about
+  DET curves.

 - This example is loosely based on
   :ref:`sphx_glr_auto_examples_classification_plot_classifier_comparison.py`
@@ -51,15 +51,13 @@

 from sklearn.datasets import make_classification
 from sklearn.ensemble import RandomForestClassifier
-from sklearn.metrics import detection_error_tradeoff_curve
+from sklearn.metrics import plot_det_curve
 from sklearn.metrics import plot_roc_curve
 from sklearn.model_selection import train_test_split
 from sklearn.pipeline import make_pipeline
 from sklearn.preprocessing import StandardScaler
 from sklearn.svm import LinearSVC

-from scipy.stats import norm
-
 N_SAMPLES = 1000

 classifiers = {
@@ -79,43 +77,17 @@
 # prepare plots
 fig, [ax_roc, ax_det] = plt.subplots(1, 2, figsize=(11, 5))

-# first prepare the ROC curve
-ax_roc.set_title('Receiver Operating Characteristic (ROC) curves')
-ax_roc.grid(linestyle='--')
-
-# second prepare the DET curve
-ax_det.set_title('Detection Error Tradeoff (DET) curves')
-ax_det.set_xlabel('False Positive Rate')
-ax_det.set_ylabel('False Negative Rate')
-ax_det.set_xlim(-3, 3)
-ax_det.set_ylim(-3, 3)
-ax_det.grid(linestyle='--')
-
-# customized ticks for DET curve plot to represent normal deviate scale
-ticks = [0.001, 0.01, 0.05, 0.20, 0.5, 0.80, 0.95, 0.99, 0.999]
-tick_locs = norm.ppf(ticks)
-tick_lbls = [
-    '{:.0%}'.format(s) if (100*s).is_integer() else '{:.1%}'.format(s)
-    for s in ticks
-]
-plt.sca(ax_det)
-plt.xticks(tick_locs, tick_lbls)
-plt.yticks(tick_locs, tick_lbls)
-
-# iterate over classifiers
 for name, clf in classifiers.items():
     clf.fit(X_train, y_train)

-    if hasattr(clf, "decision_function"):
-        y_score = clf.decision_function(X_test)
-    else:
-        y_score = clf.predict_proba(X_test)[:, 1]
-
     plot_roc_curve(clf, X_test, y_test, ax=ax_roc, name=name)
-    det_fpr, det_fnr, _ = detection_error_tradeoff_curve(y_test, y_score)
+    plot_det_curve(clf, X_test, y_test, ax=ax_det, name=name)

+ax_roc.set_title('Receiver Operating Characteristic (ROC) curves')
+ax_det.set_title('Detection Error Tradeoff (DET) curves')
+
-    # transform errors into normal deviate scale
-    ax_det.plot(norm.ppf(det_fpr), norm.ppf(det_fnr), label=name)
+ax_roc.grid(linestyle='--')
+ax_det.grid(linestyle='--')

 plt.legend()
 plt.show()
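
For readers comparing the two versions of this example, the manual steps that the diff removes can be sketched roughly as follows. This is not code from the PR: the toy labels and scores are invented for illustration, and the sketch only assumes that `det_curve` returns `(fpr, fnr, thresholds)` and that `scipy.stats.norm.ppf` provides the normal deviate transform, as the example's docstring describes.

```python
import numpy as np
from scipy.stats import norm
from sklearn.metrics import det_curve

# Toy binary labels and classifier scores (illustrative values, not from the PR).
y_true = np.array([0, 0, 1, 1, 0, 1, 0, 1])
y_score = np.array([0.10, 0.40, 0.35, 0.80, 0.20, 0.70, 0.55, 0.60])

# False positive and false negative rates at each decision threshold.
fpr, fnr, thresholds = det_curve(y_true, y_score)

# Map the rates onto the normal deviate scale.
# Rates of exactly 0 or 1 map to -inf or +inf and fall outside any finite axis range.
fpr_dev = norm.ppf(fpr)
fnr_dev = norm.ppf(fnr)

# Tick positions for a few reference error rates on the same scale.
ticks = [0.01, 0.05, 0.20, 0.50, 0.80, 0.95, 0.99]
tick_locs = norm.ppf(ticks)
```

Plotting `fpr_dev` against `fnr_dev` and placing ticks at `tick_locs` approximates what the removed block did by hand; after this change the example delegates that work to `plot_det_curve`.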
8 changes: 6 additions & 2 deletions sklearn/metrics/__init__.py
@@ -7,7 +7,7 @@
 from ._ranking import auc
 from ._ranking import average_precision_score
 from ._ranking import coverage_error
-from ._ranking import detection_error_tradeoff_curve
+from ._ranking import det_curve
 from ._ranking import dcg_score
 from ._ranking import label_ranking_average_precision_score
 from ._ranking import label_ranking_loss
@@ -77,6 +77,8 @@
 from ._scorer import SCORERS
 from ._scorer import get_scorer

+from ._plot.det_curve import plot_det_curve
+from ._plot.det_curve import DetCurveDisplay
 from ._plot.roc_curve import plot_roc_curve
 from ._plot.roc_curve import RocCurveDisplay
 from ._plot.precision_recall_curve import plot_precision_recall_curve
@@ -105,7 +107,8 @@
     'coverage_error',
     'dcg_score',
     'davies_bouldin_score',
-    'detection_error_tradeoff_curve',
+    'DetCurveDisplay',
+    'det_curve',
     'euclidean_distances',
     'explained_variance_score',
     'f1_score',
@@ -142,6 +145,7 @@
     'pairwise_distances_chunked',
     'pairwise_kernels',
     'plot_confusion_matrix',
+    'plot_det_curve',
     'plot_precision_recall_curve',
     'plot_roc_curve',
     'PrecisionRecallDisplay',
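
With `det_curve` and `DetCurveDisplay` exported from `sklearn.metrics`, a DET plot can presumably also be built from precomputed error rates. The sketch below is not taken from the PR. It assumes the keyword-only constructor `DetCurveDisplay(fpr=..., fnr=..., estimator_name=...)`, uses invented toy scores, and constructs the display directly rather than calling `plot_det_curve`, since that helper function is not available in every scikit-learn release.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import DetCurveDisplay, det_curve

# Toy binary labels and classifier scores (illustrative values, not from the PR).
y_true = np.array([0, 0, 1, 1, 0, 1, 0, 1])
y_score = np.array([0.10, 0.40, 0.35, 0.80, 0.20, 0.70, 0.55, 0.60])

# Compute the error rates once, then hand them to the display object.
fpr, fnr, _ = det_curve(y_true, y_score)
display = DetCurveDisplay(fpr=fpr, fnr=fnr, estimator_name="toy scores")

# Draw onto an explicit Axes so the plot can share a figure with other curves.
fig, ax = plt.subplots()
display.plot(ax=ax)
ax.set_title("Detection Error Tradeoff (DET) curve")
```

Passing `ax` explicitly mirrors how the updated example draws ROC and DET curves side by side on `ax_roc` and `ax_det`.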