-
-
Notifications
You must be signed in to change notification settings - Fork 25.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Bug fix and new feature: fix implementation of average precision score and add eleven-point interpolated option #7356
Changes from 3 commits
3673ee1
62d7e18
ee8bb5c
4c46f53
10e80f3
c2371dd
da07e85
59e805d
3f64c6d
852e043
a92d774
5cf87b0
bca5cb8
be236e2
8a11570
5088b90
e78d8a8
a9a0cd6
fc7a72c
bf73e74
89005d1
91d466f
64f28ee
ff2e31a
de7b660
99db671
11124ec
48ba926
bef0c01
ee82a96
3a56b67
65b657a
8ab2e69
eabc3bd
5ee2a22
1cbb03b
88156c1
8bfd59a
81412bd
1b4539c
4fae063
3cc067a
b769998
323a59c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,8 +37,10 @@ | |
from .base import _average_binary_score | ||
|
||
|
||
def auc(x, y, reorder=False): | ||
"""Compute Area Under the Curve (AUC) using the trapezoidal rule | ||
def auc(x, y, reorder=False, interpolation='linear', | ||
interpolation_direction='right'): | ||
"""Estimate Area Under the Curve (AUC) using finitely many points and an | ||
interpolation strategy. | ||
|
||
This is a general function, given points on a curve. For computing the | ||
area under the ROC-curve, see :func:`roc_auc_score`. | ||
|
@@ -55,6 +57,25 @@ def auc(x, y, reorder=False): | |
If True, assume that the curve is ascending in the case of ties, as for | ||
an ROC curve. If the curve is non-ascending, the result will be wrong. | ||
|
||
interpolation : string ['trapezoid' (default), 'step'] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment The reason will be displayed to describe this comment to others. Learn more. linear, not trapezoid ;) There was a problem hiding this comment. Choose a reason for hiding this comment The reason will be displayed to describe this comment to others. Learn more. Good catch! |
||
This determines the type of interpolation performed on the data. | ||
|
||
``'linear'``: | ||
Use the trapezoidal rule (linearly interpolating between points). | ||
``'step'``: | ||
Use a step function where we ascend/descend from each point to the | ||
y-value of the subsequent point. | ||
|
||
interpolation_direction : string ['right' (default), 'left'] | ||
This determines the direction to interpolate from. The value is ignored | ||
unless interpolation is 'step'. | ||
|
||
``'right'``: | ||
Intermediate points inherit their y-value from the subsequent | ||
point. | ||
``'left'``: | ||
Intermediate points inherit their y-value from the previous point. | ||
|
||
Returns | ||
------- | ||
auc : float | ||
|
@@ -100,20 +121,48 @@ def auc(x, y, reorder=False): | |
raise ValueError("Reordering is not turned on, and " | ||
"the x array is not increasing: %s" % x) | ||
|
||
area = direction * np.trapz(y, x) | ||
if isinstance(area, np.memmap): | ||
# Reductions such as .sum used internally in np.trapz do not return a | ||
# scalar by default for numpy.memmap instances contrary to | ||
# regular numpy.ndarray instances. | ||
area = area.dtype.type(area) | ||
if interpolation == 'linear': | ||
|
||
area = direction * np.trapz(y, x) | ||
if isinstance(area, np.memmap): | ||
# Reductions such as .sum used internally in np.trapz do not return | ||
# a scalar by default for numpy.memmap instances contrary to | ||
# regular numpy.ndarray instances. | ||
area = area.dtype.type(area) | ||
|
||
elif interpolation == 'step': | ||
|
||
# we need the data to start in ascending order | ||
if direction == -1: | ||
x, y = list(reversed(x)), list(reversed(y)) | ||
|
||
if interpolation_direction == 'right': | ||
# The left-most y-value is not used | ||
area = sum(np.diff(x) * np.array(y)[1:]) | ||
|
||
elif interpolation_direction == 'left': | ||
# The right-most y-value is not used | ||
area = sum(np.diff(x) * np.array(y)[:-1]) | ||
|
||
else: | ||
raise ValueError("interpolation_direction '{}' not recognised." | ||
" Should be one of ['right', 'left']".format( | ||
interpolation_direction)) | ||
|
||
else: | ||
raise ValueError("interpolation value '{}' not recognized. " | ||
"Should be one of ['linear', 'step']".format( | ||
interpolation)) | ||
|
||
return area | ||
|
||
|
||
def average_precision_score(y_true, y_score, average="macro", | ||
sample_weight=None): | ||
sample_weight=None, interpolation="linear"): | ||
"""Compute average precision (AP) from prediction scores | ||
|
||
This score corresponds to the area under the precision-recall curve. | ||
This score corresponds to the area under the precision-recall curve, where | ||
points are joined using either linear or step-wise interpolation. | ||
|
||
Note: this implementation is restricted to the binary classification task | ||
or multilabel classification task. | ||
|
@@ -127,8 +176,7 @@ def average_precision_score(y_true, y_score, average="macro", | |
|
||
y_score : array, shape = [n_samples] or [n_samples, n_classes] | ||
Target scores, can either be probability estimates of the positive | ||
class, confidence values, or non-thresholded measure of decisions | ||
(as returned by "decision_function" on some classifiers). | ||
class, confidence values, or binary decisions. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment The reason will be displayed to describe this comment to others. Learn more. Why did you change this? Also this is not true, this shouldn't be binary decisions! There was a problem hiding this comment. Choose a reason for hiding this comment The reason will be displayed to describe this comment to others. Learn more. No idea! I'll revert it. |
||
|
||
average : string, [None, 'micro', 'macro' (default), 'samples', 'weighted'] | ||
If ``None``, the scores for each class are returned. Otherwise, | ||
|
@@ -149,15 +197,30 @@ def average_precision_score(y_true, y_score, average="macro", | |
sample_weight : array-like of shape = [n_samples], optional | ||
Sample weights. | ||
|
||
interpolation : string ['linear' (default), 'step'] | ||
Determines the kind of interpolation used when computing AUC. If there | ||
are many repeated scores, 'step' is recommended to avoid under- or | ||
over-estimating the AUC. See `Roam Analytics blog post | ||
There was a problem hiding this comment. Choose a reason for hiding this comment The reason will be displayed to describe this comment to others. Learn more. This doesn't seem a good source to me. I'm happy with citing the IR book, though. There was a problem hiding this comment. Choose a reason for hiding this comment The reason will be displayed to describe this comment to others. Learn more. No problem - I'll remove it and add the IR book to the references. There was a problem hiding this comment. Choose a reason for hiding this comment The reason will be displayed to describe this comment to others. Learn more. @amueller Related question: in the references section, is it okay to include the link to my blog post? I'm never sure whether the references are to justify the implementation, or to provide useful guidance for users. I feel like my post might be helpful for the latter but not for the former! |
||
<https://github.com/roaminsight/roamresearch/blob/master/BlogPosts/ | ||
Average_precision/Average_precision_post.ipynb>`_ | ||
for details. | ||
|
||
``'linear'``: | ||
Linearly interpolates between operating points. | ||
``'step'``: | ||
Uses a step function to interpolate between operating points. | ||
|
||
Returns | ||
------- | ||
average_precision : float | ||
|
||
References | ||
---------- | ||
.. [1] `Wikipedia entry for the Average precision | ||
<https://en.wikipedia.org/wiki/Average_precision>`_ | ||
|
||
<http://en.wikipedia.org/wiki/Average_precision>`_ | ||
.. [2] `Roam Analytics blog post | ||
<https://github.com/roaminsight/roamresearch/blob/master/BlogPosts/ | ||
Average_precision/Average_precision_post.ipynb>`_ | ||
See also | ||
-------- | ||
roc_auc_score : Area under the ROC curve | ||
|
@@ -178,8 +241,20 @@ def average_precision_score(y_true, y_score, average="macro", | |
def _binary_average_precision(y_true, y_score, sample_weight=None): | ||
precision, recall, thresholds = precision_recall_curve( | ||
y_true, y_score, sample_weight=sample_weight) | ||
return auc(recall, precision) | ||
|
||
return auc(recall, precision, interpolation=interpolation, | ||
interpolation_direction='right') | ||
|
||
if interpolation == "linear": | ||
# Check for number of unique predictions. If this is substantially less | ||
# than the number of predictions, linear interpolation is likely to be | ||
# biased. | ||
n_discrete_predictions = len(np.unique(y_score)) | ||
if n_discrete_predictions < 0.75 * len(y_score): | ||
warnings.warn("Number of unique scores is less than 75% of the " | ||
"number of scores provided. Linear interpolation " | ||
"is likely to be biased in this case. You may wish " | ||
"to use step interpolation instead. See docstring " | ||
"for details.") | ||
return _average_binary_score(_binary_average_precision, y_true, y_score, | ||
average, sample_weight=sample_weight) | ||
|
||
|
@@ -253,7 +328,7 @@ def _binary_roc_auc_score(y_true, y_score, sample_weight=None): | |
|
||
fpr, tpr, tresholds = roc_curve(y_true, y_score, | ||
sample_weight=sample_weight) | ||
return auc(fpr, tpr, reorder=True) | ||
return auc(fpr, tpr, reorder=True, interpolation='linear') | ||
|
||
return _average_binary_score( | ||
_binary_roc_auc_score, y_true, y_score, average, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did you remove these newlines? They are important and make things much more readable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's my IDE trying to be clever when I copied/pasted. I've restored them.