Skip to content

Commit

Permalink
FIX the threshold by taking the opposite (to be adapted to the decisi…
Browse files Browse the repository at this point in the history
…on function)
  • Loading branch information
William de Vazelhes committed Feb 20, 2019
1 parent dc9e21d commit 402729f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
12 changes: 6 additions & 6 deletions metric_learn/base_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def predict(self, pairs):
The predicted learned metric value between samples in every pair.
"""
check_is_fitted(self, ['threshold_', 'transformer_'])
return - 2 * (self.decision_function(pairs) > self.threshold_) + 1
return 2 * (self.decision_function(pairs) > self.threshold_) - 1

def decision_function(self, pairs):
"""Returns the decision function used to classify the pairs.
Expand Down Expand Up @@ -387,13 +387,13 @@ def score(self, pairs, y):
return roc_auc_score(y, self.decision_function(pairs))

def set_default_threshold(self, pairs, y):
"""Returns a threshold that is the mean between the similar metrics
mean, and the dissimilar metrics mean"""
similar_threshold = np.mean(self.decision_function(
"""Returns a threshold that is the opposite of the mean between the similar
metrics mean and the dissimilar metrics mean"""
similar_threshold = np.mean(self.score_pairs(
pairs[(y == 1).ravel()]))
dissimilar_threshold = np.mean(self.decision_function(
dissimilar_threshold = np.mean(self.score_pairs(
pairs[(y == -1).ravel()]))
self.threshold_ = np.mean([similar_threshold, dissimilar_threshold])
self.threshold_ = - np.mean([similar_threshold, dissimilar_threshold])

This comment has been minimized.

Copy link
@bellet

bellet Feb 20, 2019

Member

can't we rather keep the threshold positive (i.e., which works with actual distances and not negative distances), and change the predict accordingly? It feels weird to have a negative threshold



class _QuadrupletsClassifierMixin(BaseMetricLearner):
Expand Down
2 changes: 1 addition & 1 deletion metric_learn/itml.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def fit(self, pairs, y, bounds=None):
Returns the instance.
"""
self._fit(pairs, y, bounds=bounds)
self.threshold_ = np.mean(self.bounds_)
self.threshold_ = - np.mean(self.bounds_)
return self


Expand Down

0 comments on commit 402729f

Please sign in to comment.