make precision_recall_curve plots less confusing #7372

Closed
amueller opened this issue Sep 8, 2016 · 5 comments
Labels
Easy (Well-defined and straightforward way to resolve), Enhancement, help wanted, module:metrics

Comments

@amueller
Member

amueller commented Sep 8, 2016

Currently the "precision_recall_curve" plots in the examples are confusing because they connect the operating points by linear interpolation, which is too optimistic about the precision between those points.
We could use plt.step instead of plt.plot, though that might look a bit odd. We could also interpolate the curve on a fixed recall grid to get something like the Pascal VOC curve.

In addition, I think we should have an explanation of the different ways average precision and the precision recall curve are calculated.

I might work on that tomorrow.

@ndingwall
Contributor

What's wrong with plt.step? The area under each of these curves corresponds to the average precision computed using the Wikipedia definition.

[Image: screen shot 2016-09-16 at 6 02 54 pm]

Pascal VOC might be more complicated. Maybe we could add dotted vertical lines indicating each recall threshold, and circle the operating points that are selected for averaging.
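
The correspondence between the step curve and average precision can be checked numerically. The following is a minimal sketch on toy data (15 positives, 85 negatives, constant scores); the names `step_area` and `trapezoid_area` are illustrative and not part of scikit-learn.

```python
import numpy as np
from sklearn.metrics import auc, average_precision_score, precision_recall_curve

# Toy data (an assumption for illustration): 15 positives, 85 negatives,
# and a model that gives every sample the same score.
y_true = np.concatenate((np.ones(15), np.zeros(85)))
y_score = np.full(len(y_true), 0.5)

# precision_recall_curve returns recall in decreasing order, ending at
# the synthetic point (recall=0, precision=1).
precision, recall, _ = precision_recall_curve(y_true, y_score)

# Area under the step curve: each recall increment is weighted by the
# precision at its higher-recall end, i.e. sum((R_n - R_{n-1}) * P_n).
step_area = -np.sum(np.diff(recall) * precision[:-1])

# Area under the linearly interpolated curve (trapezoidal rule).
trapezoid_area = auc(recall, precision)

print("average_precision_score:", average_precision_score(y_true, y_score))
print("step area:              ", step_area)
print("trapezoid area:         ", trapezoid_area)
```

On this data the step area should match `average_precision_score` (the positive-class prevalence, 0.15), while the trapezoid also counts the ramp up to the synthetic (0, 1) point and comes out much larger.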

@amueller
Member Author

Nothing is wrong with this. I think this is the right way. Maybe some shading to illustrate the area?

So I think there are two things here: fixing the existing examples with plt.step, and adding a new example that compares the different methods. If you think adding the IR book method is too complicated, I'd be happy with just adding the 11-point method and the previously used linear interpolation (and saying that linear interpolation is bad), and maybe giving an example of why the 11-point method might be bad.
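
For reference, the 11-point method mentioned above can be sketched as follows. This is a hedged illustration, not scikit-learn API: `eleven_point_average_precision` is a hypothetical helper, and dropping the synthetic (recall=0, precision=1) end point is a choice made here so that a constant-score model is not credited with perfect precision at recall 0.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve


def eleven_point_average_precision(y_true, y_score):
    """11-point interpolated AP (hypothetical helper, not a scikit-learn function)."""
    precision, recall, _ = precision_recall_curve(y_true, y_score)
    # Drop the synthetic final point (recall=0, precision=1), which has no
    # corresponding threshold.
    precision, recall = precision[:-1], recall[:-1]
    interpolated = []
    for level in np.linspace(0, 1, 11):
        # Interpolated precision: best precision at any recall >= level.
        # A tiny tolerance guards against float round-off in the grid.
        mask = recall >= level - 1e-12
        interpolated.append(precision[mask].max())
    return float(np.mean(interpolated))


# Toy data (an assumption for illustration): constant scores, 15% positives.
y_true = np.concatenate((np.ones(15), np.zeros(85)))
y_score = np.full(len(y_true), 0.5)
print(eleven_point_average_precision(y_true, y_score))
```

Because only 11 recall levels are sampled, the result depends on where the operating points fall relative to the grid, which is one reason it can disagree with the step-area definition.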

@ndingwall
Contributor

ndingwall commented Sep 20, 2016

@amueller I've now pushed an update to my branch for this example. I've also written a minimal example that highlights the difference between the old and new implementations, but I'm not sure where that would go. For now, here it is:

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import auc
from sklearn.metrics import average_precision_score
from sklearn.metrics import precision_recall_curve

# Generate data where every sample gets the same score (e.g. output of DummyClassifier)
y_true = np.concatenate((np.ones(15), np.zeros(85)))
y_score = np.full(len(y_true), 0.5)
p, r, _ = precision_recall_curve(y_true, y_score)
# Reverse the arrays so recall is in ascending order
p = p[::-1]
r = r[::-1]
# Compute metrics
average_precision = average_precision_score(y_true, y_score)
linear_area = auc(r, p)
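# NOTE: the `interpolation` keyword is assumed to come from the branch
# mentioned above; it may not exist in a released scikit-learn.
interpolated_av_pr = average_precision_score(y_true, y_score,
                                             interpolation='eleven_point')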

# Plot figures
fig = plt.figure()
ax = plt.subplot(111)

# Plot P-R curve with linear interpolation
plt.plot(r, p, color='b')
plt.fill_between(r, 0, p, alpha=0.2, 
                 label='(Old) Linear interpolation (AUC {:0.2f})'.format(
                         linear_area))

# Plot P-R curve with step interpolation
plt.step(r, p, color='r')
plt.fill_between([v for v in r for _ in (0, 1)][:-1], 0, 
                 [v for v in p for _ in (0, 1)][1:], 
                 alpha=0.2, color='r',
                 label='(New) Step interpolation (AUC {:0.2f})'.format(
                         average_precision))

# Plot the 11 operating points used for 11-point interpolated average precision
recall_circles = []
precision_circles = []
# np.linspace gives exactly 11 thresholds (0.0, 0.1, ..., 1.0) with no float overshoot
for threshold in np.linspace(0, 1, 11):
    # Count curve points (ignoring the synthetic recall=0 point) with recall >= threshold
    i = sum(r[1:] >= threshold)
    recall_circles.append(r[-i])
    precision_circles.append(p[-i])
for this_r, this_p in zip(recall_circles, precision_circles):
    plt.text(this_r + 0.0075, this_p + 0.01, "{:3.3f}".format(this_p),
             color='g')
plt.scatter(recall_circles, precision_circles, marker='o', s=100,
            facecolor='none', edgecolor='g',
            label='11-point interpolated precisions (mean = {:0.2f})'.format(
                interpolated_av_pr))

# Set limits, etc
plt.title("Old vs new implementations of `average_precision_score`\n"
          "on a dummy model that makes constant predictions")
plt.xlim((0, 1))
plt.ylim((0, 1))
box = ax.get_position()
ax.set_position([box.x0, box.y0 + box.height * 0.1, 
                 box.width, box.height * 0.9])
plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05), frameon=False)
plt.show()

This produces the following plot:

[Image: screen shot 2016-09-20 at 3 31 32 pm]

@amueller amueller added Easy Well-defined and straightforward way to resolve Need Contributor Sprint labels Jul 14, 2017
@kurchi1205

I want to make my first contribution. Can I work on this?

@glemaitre
Member

I am closing this issue because PrecisionRecallDisplay already produces this type of plot. We should probably revisit the problem and add a parameter so that the PR curve and the AP can be interpolated in a similar manner.
