-
Notifications
You must be signed in to change notification settings - Fork 232
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG+1] Threshold for pairs learners #168
Changes from 1 commit
676ab86
cc1c3e6
f95c456
9ffe8f7
3354fb1
12cb5f1
dd8113e
1c8cd29
d12729a
dc9e21d
402729f
aaac3de
e5b1e47
a0cb3ca
8d5fc50
0f14b25
a6458a2
fada5cc
32a4889
5cf71b9
c2bc693
e96ee00
3ed3430
69c6945
bc39392
facc546
f0ca65e
a6ec283
49fbbd7
960b174
c91acf7
a742186
9ec1ead
986fed3
3f5d6d1
7b5e4dd
a3ec02c
ccc66eb
6dff15b
719d018
551d161
594c485
14713c6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,8 @@ | ||
from sklearn.base import BaseEstimator | ||
from sklearn.metrics.ranking import _binary_clf_curve | ||
from sklearn.utils.extmath import stable_cumsum | ||
from sklearn.utils.validation import _is_arraylike, check_is_fitted | ||
from sklearn.metrics import roc_auc_score, precision_recall_curve, roc_curve | ||
from sklearn.metrics import roc_auc_score, roc_curve | ||
import numpy as np | ||
from abc import ABCMeta, abstractmethod | ||
import six | ||
|
@@ -490,20 +491,39 @@ def calibrate_threshold(self, pairs_valid, y_valid, strategy='accuracy', | |
scores_sorted = scores[scores_sorted_idces] | ||
# true labels ordered by decision_function value: (higher first) | ||
y_ordered = y_valid[scores_sorted_idces] | ||
# we need to add a threshold that will reject all points | ||
scores_sorted = np.concatenate([[scores_sorted[0] + 1], scores_sorted]) | ||
|
||
# finds the threshold that maximizes the accuracy: | ||
cum_tp = stable_cumsum(y_ordered == 1) # cumulative number of true | ||
# positives | ||
# we need to add the point where all samples are rejected: | ||
cum_tp = np.concatenate([[0.], cum_tp]) | ||
cum_tn_inverted = stable_cumsum(y_ordered[::-1] == -1) | ||
cum_tn = np.concatenate([[0], cum_tn_inverted[:-1]])[::-1] | ||
cum_tn = np.concatenate([[0.], cum_tn_inverted])[::-1] | ||
cum_accuracy = (cum_tp + cum_tn) / n_samples | ||
imax = np.argmax(cum_accuracy) | ||
# note: we want a positive threshold (distance), so we take - threshold | ||
self.threshold_ = - scores_sorted[imax] | ||
if imax == len(scores_sorted): # if the best is to accept all points | ||
# we set the threshold to (minus) [the lowest score - 1] | ||
self.threshold_ = - (scores_sorted[imax] - 1) | ||
else: | ||
# otherwise, we set the threshold to the mean between the lowest | ||
# accepted score and the highest accepted score | ||
self.threshold_ = - np.mean(scores_sorted[imax: imax + 2]) | ||
# note: if the best is to reject all points it's already one of the | ||
# thresholds (scores_sorted[0] + 1) | ||
return self | ||
|
||
if strategy == 'f_beta': | ||
precision, recall, thresholds = precision_recall_curve( | ||
fps, tps, thresholds = _binary_clf_curve( | ||
y_valid, self.decision_function(pairs_valid), pos_label=1) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need to take part of the code (but not all) from |
||
precision = tps / (tps + fps) | ||
precision[np.isnan(precision)] = 0 | ||
recall = tps / tps[-1] | ||
|
||
# here the thresholds are decreasing | ||
# We ignore the warnings here, in the same taste as | ||
# https://github.com/scikit-learn/scikit-learn/blob/62d205980446a1abc1065 | ||
# f4332fd74eee57fcf73/sklearn/metrics/classification.py#L1284 | ||
|
@@ -516,26 +536,45 @@ def calibrate_threshold(self, pairs_valid, y_valid, strategy='accuracy', | |
f_beta[np.isnan(f_beta)] = 0. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need to set nans to zero otherwise they will be considered higher than the others (also discussed in https://github.com/scikit-learn/scikit-learn/pull/10117/files#r262115773) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. again, probably good to mention it in a comment in the code There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed, done |
||
imax = np.argmax(f_beta) | ||
# note: we want a positive threshold (distance), so we take - threshold | ||
self.threshold_ = - thresholds[imax] | ||
if imax == len(thresholds): # the best is to accept all points | ||
# we set the threshold to (minus) [the lowest score - 1] | ||
self.threshold_ = - (thresholds[imax] - 1) | ||
else: | ||
# otherwise, we set the threshold to the mean between the lowest | ||
# accepted score and the highest rejected score | ||
self.threshold_ = - np.mean(thresholds[imax: imax + 2]) | ||
# Note: we don't need to deal with rejecting all points (i.e. threshold = | ||
# max_scores + 1), since this can never happen to be optimal | ||
# (see a more detailed discussion in test_calibrate_threshold_extreme) | ||
return self | ||
|
||
fpr, tpr, thresholds = roc_curve(y_valid, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note that there is a particular corner case that is not explicitely coded here but comes from scikit-learn: we should decide if that's what we want or not: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note also that some points can be deleted in the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would say that it is fine this way? It does not seem very important anyway so going with the sklearn approach is fine |
||
self.decision_function(pairs_valid), | ||
pos_label=1) | ||
pos_label=1, drop_intermediate=False) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need |
||
# here the thresholds are decreasing | ||
fpr, tpr, thresholds = fpr, tpr, thresholds | ||
|
||
if strategy == 'max_tpr': | ||
indices = np.where(1 - fpr >= min_rate)[0] | ||
max_tpr_index = np.argmax(tpr[indices]) | ||
# note: we want a positive threshold (distance), so we take - threshold | ||
self.threshold_ = - thresholds[indices[max_tpr_index]] | ||
if strategy in ['max_tpr', 'max_tnr']: | ||
if strategy == 'max_tpr': | ||
indices = np.where(1 - fpr >= min_rate)[0] | ||
imax = np.argmax(tpr[indices]) | ||
|
||
if strategy == 'max_tnr': | ||
indices = np.where(tpr >= min_rate)[0] | ||
max_tnr_index = np.argmax(1 - fpr[indices]) | ||
if strategy == 'max_tnr': | ||
indices = np.where(tpr >= min_rate)[0] | ||
imax = np.argmax(1 - fpr[indices]) | ||
|
||
imax_valid = indices[imax] | ||
# note: we want a positive threshold (distance), so we take - threshold | ||
self.threshold_ = - thresholds[indices[max_tnr_index]] | ||
return self | ||
if indices[imax] == len(thresholds): # we want to accept everything | ||
self.threshold_ = - (thresholds[imax_valid] - 1) | ||
elif indices[imax] == 0: # we want to reject everything | ||
# thanks to roc_curve, the first point should be always max_threshold | ||
# + 1 (we should always go through the "if" statement in roc_curve), | ||
# see: https://github.com/scikit-learn/scikit-learn/pull/13523 | ||
self.threshold_ = - (thresholds[imax_valid]) | ||
else: | ||
self.threshold_ = - np.mean(thresholds[imax_valid: imax_valid + 2]) | ||
return self | ||
|
||
|
||
class _QuadrupletsClassifierMixin(BaseMetricLearner): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we need to include the last
cum_tn_inverted
now, since it's the one where all samples will be rejected