[MRG+1] Threshold for pairs learners #168

Merged
merged 43 commits on Apr 15, 2019
Changes from 4 commits
Commits (43)
676ab86
add some tests for testing that different scores work using the scori…
Feb 4, 2019
cc1c3e6
ENH: Add tests and basic threshold implementation
Feb 5, 2019
f95c456
Add support for LSML and more generally quadruplets
Feb 6, 2019
9ffe8f7
Make CalibratedClassifierCV work (for preprocessor case) thanks to cl…
Feb 6, 2019
3354fb1
Fix some tests and PEP8 errors
Feb 7, 2019
12cb5f1
change the sign in decision function
Feb 19, 2019
dd8113e
Add docstring for threshold_ and classes_ in the base _PairsClassifie…
Feb 19, 2019
1c8cd29
remove quadruplets from the test with scikit learn custom scorings
Feb 19, 2019
d12729a
Remove argument y in quadruplets learners and lsml
Feb 20, 2019
dc9e21d
FIX fix docstrings of decision functions
Feb 20, 2019
402729f
FIX the threshold by taking the opposite (to be adapted to the decisi…
Feb 20, 2019
aaac3de
Fix tests to have no y for quadruplets' estimator fit
Feb 21, 2019
e5b1e47
Remove isin to be compatible with old numpy versions
Feb 21, 2019
a0cb3ca
Fix threshold so that it has a positive value and add small test
Feb 21, 2019
8d5fc50
Fix threshold for itml
Feb 21, 2019
0f14b25
FEAT: Add calibrate_threshold and tests
Mar 4, 2019
a6458a2
MAINT: remove starred syntax for compatibility with older versions of…
Mar 5, 2019
fada5cc
Remove debugging prints and make tests for ITML pass, while waiting f…
Mar 5, 2019
32a4889
FIX: from __future__ import division to pass tests for python 2.7
Mar 5, 2019
5cf71b9
Add some documentation for calibration
Mar 11, 2019
c2bc693
DOC: fix style
Mar 11, 2019
e96ee00
Merge branch 'master' into feat/add_threshold
Mar 21, 2019
3ed3430
Address most comments from aurelien's reviews
Mar 21, 2019
69c6945
Remove classes_ attribute and test for CalibratedClassifierCV
Mar 21, 2019
bc39392
Rename make_args_inc_quadruplets into remove_y_quadruplets
Mar 21, 2019
facc546
TST: Fix remaining threshold into min_rate
Mar 21, 2019
f0ca65e
Remove default_threshold and put calibrate_threshold instead
Mar 21, 2019
a6ec283
Use calibrate_threshold for ITML, and remove description
Mar 21, 2019
49fbbd7
ENH: use calibrate_threshold by default and display its parameters fr…
Mar 21, 2019
960b174
Add a small test to test automatic calibration
Mar 21, 2019
c91acf7
Update documentation of the default threshold
Mar 21, 2019
a742186
Inverse sense for threshold comparison to be more intuitive
Mar 21, 2019
9ec1ead
Address remaining review comments
Mar 21, 2019
986fed3
MAINT: Rename threshold_params into calibration_params
Mar 26, 2019
3f5d6d1
TST: Add test for extreme cases
Mar 26, 2019
7b5e4dd
MAINT: rename threshold_params into calibration_params
Mar 26, 2019
a3ec02c
MAINT: rename threshold_params into calibration_params
Mar 26, 2019
ccc66eb
FIX: Make tests work, and add the right threshold (mean between lowes…
Mar 27, 2019
6dff15b
Merge branch 'master' into feat/add_threshold
Mar 27, 2019
719d018
Go back to previous version of finding the threshold
Apr 2, 2019
551d161
Extract method for validating calibration parameters
Apr 2, 2019
594c485
Validate calibration params before fit
Apr 2, 2019
14713c6
Address https://github.com/metric-learn/metric-learn/pull/168#discuss…
Apr 2, 2019
31 changes: 26 additions & 5 deletions metric_learn/base_metric.py
@@ -1,8 +1,8 @@
from numpy.linalg import cholesky
from scipy.spatial.distance import euclidean
from sklearn.base import BaseEstimator
from sklearn.utils.validation import _is_arraylike
from sklearn.metrics import roc_auc_score
from sklearn.utils.validation import _is_arraylike, check_is_fitted
from sklearn.metrics import roc_auc_score, accuracy_score
import numpy as np
from abc import ABCMeta, abstractmethod
import six
@@ -296,6 +296,7 @@ def get_mahalanobis_matrix(self):

class _PairsClassifierMixin(BaseMetricLearner):

classes_ = [-1, 1]
Member Author:

We need estimators to have a classes_ attribute for CalibratedClassifierCV to work. Typically, scikit-learn estimators deduce it at fit time, but in our case we will always have the same two classes (see also issue #164). That's why I put it as a class attribute, to be more explicit.

Classes have to be in this order ([-1, 1]), i.e. sorted, otherwise there is a bug, because of https://github.com/scikit-learn/scikit-learn/blob/1a850eb5b601f3bf0f88a43090f83c51b3d8c593/sklearn/calibration.py#L312-L313:
self.label_encoder_.classes_ is a sorted list of classes (so it will be [-1, 1] in every case). So when it transforms self.base_estimator.classes_, it will return [0, 1] if our estimator has [-1, 1] as classes, and [1, 0] if it has [1, -1]. In the latter case, an IndexError is raised because here (https://github.com/scikit-learn/scikit-learn/blob/1a850eb5b601f3bf0f88a43090f83c51b3d8c593/sklearn/calibration.py#L357) it will try to access column 1 (and not 0) of Y, which does not exist.
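
To illustrate, here is a small sketch using scikit-learn's LabelEncoder directly (the same transform CalibratedClassifierCV applies internally; not code from this PR):

from sklearn.preprocessing import LabelEncoder

# label_encoder_.classes_ inside CalibratedClassifierCV is always sorted: [-1, 1]
le = LabelEncoder().fit([-1, 1])

print(le.transform([-1, 1]))   # [0 1]: valid column indices into the binarized Y
print(le.transform([1, -1]))   # [1 0]: column 1 of a single-column Y does not exist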

Also, we have to put -1 and 1 in this order so that for label_binarize (which is called by CalibratedClassifierCV), -1 (the first element of [-1, 1]) is considered the negative label and 1 the positive label. (That's how label_binarize works; see the examples below.)

Examples (warning: pos_label and neg_label do not say which input classes are positive or negative, as one might think, but rather how we want them to appear in the output):

In [1]: from sklearn.preprocessing import label_binarize 
   ...: print(label_binarize([-1, 1, -1], [-1, 1], pos_label=2019, neg_label=2018))                                                                                                                                                                                                        
[[2018]
 [2019]
 [2018]]

And:

In [1]: from sklearn.preprocessing import label_binarize 
   ...: print(label_binarize([-1, 1, -1], [1, -1], pos_label=2019, neg_label=2018))                                                                                                                                               
[[2019]
 [2018]
 [2019]]

Member:

Do we need the classes_ attribute in this PR? Otherwise it could be saved for #173

_tuple_size = 2 # number of points in a tuple, 2 for pairs
Member:

shouldn't we document the attribute threshold_ here in the general class rather than in the downstream classes?

Member Author:

I guess I should document it for both the abstract and the downstream classes, but I don't think I can document it only here: we can inherit docstrings, but I don't think we can inherit them paragraph-wise (e.g. _PairsClassifierMixin has an attribute threshold_ that we want to transmit to MMC, but MMC also has a title and other attributes like transformer_). I didn't find an option to do this in sphinx, and this limitation makes some sense: if we wrote two docstring paragraphs (below the title), would the result in the downstream class be the concatenation of both descriptions? It could give weird results (e.g. "This is the abstract class" and "This is the MMC class" would give "This is the abstract class. This is the MMC class"). However, there seem to be some packages that allow something like this. We could use them, but maybe later in the development if things get too duplicated: https://github.com/Chilipp/docrep https://github.com/meowklaski/custom_inherit
There's also the option of not inheriting the docstrings at all, but I think it's better for users to see all the attributes of their object without having to click through the docs of all parent classes (that's what is done in scikit-learn, for instance, I think).
Finally, there's another option, which is to declare the attributes as properties so that we can have more granular inheritance, but I think this adds too much complexity just for docstrings.
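
To make the idea concrete, here is a minimal hand-rolled sketch of sharing an Attributes entry between the mixin and a downstream class (the names _SHARED_ATTR_DOC and ExamplePairsClassifier are hypothetical illustrations, not code from this PR nor from docrep/custom_inherit):

# Shared "Attributes" entry that the mixin would like every downstream
# pairs classifier to expose in its own docstring.
_SHARED_ATTR_DOC = """
    threshold_ : `float`
        If the distance metric between two points is lower than this
        threshold, points will be classified as similar, otherwise as
        dissimilar.
"""

class ExamplePairsClassifier(object):
    # Build the full docstring at class-definition time by concatenating the
    # class-specific part with the shared snippet. This is exactly the kind of
    # manual stitching that a docstring-inheritance package would automate.
    __doc__ = """Example downstream class.

    Attributes
    ----------
    transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
        The learned linear transformation ``L``.
    """ + _SHARED_ATTR_DOC

print(ExamplePairsClassifier.__doc__)  # shows transformer_, then the shared threshold_ entry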

Member Author:

(done: I just added them to the base PairsClassifier class too)


def predict(self, pairs):
@@ -317,7 +318,8 @@ def predict(self, pairs):
y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,)
The predicted labels: +1 if a pair is predicted to be similar, -1 if it is predicted to be dissimilar.
"""
return self.decision_function(pairs)
check_is_fitted(self, ['threshold_', 'transformer_'])
return - 2 * (self.decision_function(pairs) > self.threshold_) + 1

def decision_function(self, pairs):
"""Returns the learned metric between input pairs.
@@ -369,6 +371,13 @@ def score(self, pairs, y):
"""
return roc_auc_score(y, self.decision_function(pairs))

def set_default_threshold(self, pairs, y):
Member:

to remove if indeed we choose the accuracy-calibrated threshold as default

"""Returns a threshold that is the mean between the similar metrics
mean, and the dissimilar metrics mean"""
similar_threshold = np.mean(self.decision_function(pairs[y==1]))
dissimilar_threshold = np.mean(self.decision_function(pairs[y==1]))
self.threshold_ = np.mean([similar_threshold, dissimilar_threshold])


class _QuadrupletsClassifierMixin(BaseMetricLearner):

@@ -393,6 +402,7 @@ def predict(self, quadruplets):
prediction : `numpy.ndarray` of floats, shape=(n_constraints,)
Predictions of the ordering of pairs, for each quadruplet.
"""
check_is_fitted(self, 'transformer_')
quadruplets = check_input(quadruplets, type_of_inputs='tuples',
preprocessor=self.preprocessor_,
estimator=self, tuple_size=self._tuple_size)
@@ -435,11 +445,22 @@ def score(self, quadruplets, y=None):
points, or 2D array of indices of quadruplets if the metric learner
uses a preprocessor.

y : Ignored, for scikit-learn compatibility.
y : array-like, shape=(n_constraints,) or `None`
Labels of constraints. y[i] should be 1 if
d(quadruplets[i, 0], quadruplets[i, 1]) is wanted to be smaller than
d(quadruplets[i, 2], quadruplets[i, 3]), and -1 if it is wanted to be
larger. If None, `y` will be set to `np.ones(quadruplets.shape[0])`,
i.e. we want the first two points to be closer than the last two points
in each quadruplet.

Returns
-------
score : float
The quadruplets score.
"""
return -np.mean(self.predict(quadruplets))
quadruplets = check_input(quadruplets, y, type_of_inputs='tuples',
Member Author:

Note that quadruplets will be checked twice here (once here, once in predict). This is because when I do y = np.ones(quadruplets.shape[0]), I want to be sure that I can call quadruplets.shape[0]; otherwise a clearer error message would already have been raised by the check_input method. I think that's fine, isn't it? I don't see any other way to do it. Note that I also check y at the same time, because I like the fact that the column_or_1d check will be called on y (it is not done in accuracy_score).
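
As a small illustration of the point about checking before calling .shape (a sketch with made-up data, not code from this PR):

import numpy as np

raw_quadruplets = [[0, 1, 2, 3], [4, 5, 6, 7]]  # e.g. indices given as a plain Python list
# raw_quadruplets.shape[0]  # would raise AttributeError: 'list' object has no attribute 'shape'

quadruplets = np.asarray(raw_quadruplets)  # check_input converts/validates to an ndarray first
y = np.ones(quadruplets.shape[0])          # now safe: default labels [1., 1.]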

preprocessor=self.preprocessor_,
estimator=self, tuple_size=self._tuple_size)
if y is None:
y = np.ones(quadruplets.shape[0])
return accuracy_score(y, self.predict(quadruplets))
14 changes: 13 additions & 1 deletion metric_learn/itml.py
@@ -148,6 +148,16 @@ class ITML(_BaseITML, _PairsClassifierMixin):
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
The linear transformation ``L`` deduced from the learned Mahalanobis
metric (See function `transformer_from_metric`.)

threshold_ : `float`
If the distance metric between two points is lower than this threshold,
points will be classified as similar, otherwise they will be
classified as dissimilar.

classes_ : `list`
The possible labels of the pairs `LSML` can fit on. `classes_ = [-1, 1]`,
Member:

ITML not LSML

Member Author:

Thanks, done

where -1 means points in a pair are dissimilar (negative label), and 1
means they are similar (positive label).
"""

def fit(self, pairs, y, bounds=None):
@@ -176,7 +186,9 @@ def fit(self, pairs, y, bounds=None):
self : object
Returns the instance.
"""
return self._fit(pairs, y, bounds=bounds)
self._fit(pairs, y, bounds=bounds)
self.threshold_ = np.mean(self.bounds_)
Member:

to remove

Member Author:

I guess if the answer to the above is yes, you mean removing this and replacing it with calibrate_threshold, like in all the other pairwise metric learners?

return self


class ITML_Supervised(_BaseITML, TransformerMixin):
24 changes: 17 additions & 7 deletions metric_learn/lsml.py
@@ -46,9 +46,15 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None, verbose=False,
super(_BaseLSML, self).__init__(preprocessor)

def _fit(self, quadruplets, y=None, weights=None):
quadruplets = self._prepare_inputs(quadruplets,
quadruplets = self._prepare_inputs(quadruplets, y,
type_of_inputs='tuples')

if y is None:
y = np.ones(quadruplets.shape[0])
# we swap the quadruplets where the label is -1 since they are not in
# the right order
quadruplets_to_swap = quadruplets[y == -1]
quadruplets[y == -1] = np.column_stack([quadruplets_to_swap[:, 2:],
quadruplets_to_swap[:, :2]])
# check to make sure that no two constrained vectors are identical
vab = quadruplets[:, 0, :] - quadruplets[:, 1, :]
vcd = quadruplets[:, 2, :] - quadruplets[:, 3, :]
@@ -144,18 +150,22 @@ class LSML(_BaseLSML, _QuadrupletsClassifierMixin):
metric (See function `transformer_from_metric`.)
"""

def fit(self, quadruplets, weights=None):
def fit(self, quadruplets, y=None, weights=None):
"""Learn the LSML model.

Parameters
----------
quadruplets : array-like, shape=(n_constraints, 4, n_features) or
(n_constraints, 4)
3D array-like of quadruplets of points or 2D array of quadruplets of
indicators. In order to supervise the algorithm in the right way, we
should have the four samples ordered in a way such that:
d(pairs[i, 0],X[i, 1]) < d(X[i, 2], X[i, 3]) for all 0 <= i <
n_constraints.
indicators.
y : array-like, shape=(n_constraints,) or `None`
Labels of constraints. y[i] should be 1 if
d(quadruplets[i, 0], quadruplets[i, 1]) is wanted to be smaller than
d(quadruplets[i, 2], quadruplets[i, 3]), and -1 if it is wanted to be
larger. If None, `y` will be set to `np.ones(quadruplets.shape[0])`,
i.e. we want to put the first two points closer than the last two
points in each quadruplet.
weights : (n_constraints,) array of floats, optional
scale factor for each constraint

14 changes: 13 additions & 1 deletion metric_learn/mmc.py
@@ -359,6 +359,16 @@ class MMC(_BaseMMC, _PairsClassifierMixin):
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
The linear transformation ``L`` deduced from the learned Mahalanobis
metric (See function `transformer_from_metric`.)

threshold_ : `float`
If the distance metric between two points is lower than this threshold,
points will be classified as similar, otherwise they will be
classified as dissimilar.

classes_ : `list`
The possible labels of the pairs `MMC` can fit on. `classes_ = [-1, 1]`,
where -1 means points in a pair are dissimilar (negative label), and 1
means they are similar (positive label).
"""

def fit(self, pairs, y):
@@ -379,7 +389,9 @@ def fit(self, pairs, y):
self : object
Returns the instance.
"""
return self._fit(pairs, y)
self._fit(pairs, y)
self.set_default_threshold(pairs, y)
return self


class MMC_Supervised(_BaseMMC, TransformerMixin):
14 changes: 13 additions & 1 deletion metric_learn/sdml.py
@@ -81,6 +81,16 @@ class SDML(_BaseSDML, _PairsClassifierMixin):
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
The linear transformation ``L`` deduced from the learned Mahalanobis
metric (See function `transformer_from_metric`.)

threshold_ : `float`
If the distance metric between two points is lower than this threshold,
points will be classified as similar, otherwise they will be
classified as dissimilar.

classes_ : `list`
The possible labels of the pairs `SDML` can fit on. `classes_ = [-1, 1]`,
where -1 means points in a pair are dissimilar (negative label), and 1
means they are similar (positive label).
"""

def fit(self, pairs, y):
@@ -101,7 +111,9 @@ def fit(self, pairs, y):
self : object
Returns the instance.
"""
return self._fit(pairs, y)
self._fit(pairs, y)
self.set_default_threshold(pairs, y)
return self


class SDML_Supervised(_BaseSDML, TransformerMixin):
65 changes: 65 additions & 0 deletions test/test_pairs_classifiers.py
@@ -0,0 +1,65 @@
import pytest
from sklearn.exceptions import NotFittedError
from sklearn.model_selection import train_test_split

from test.test_utils import pairs_learners, ids_pairs_learners
from sklearn.utils.testing import set_random_state
from sklearn import clone
import numpy as np


@pytest.mark.parametrize('with_preprocessor', [True, False])
@pytest.mark.parametrize('estimator, build_dataset', pairs_learners,
ids=ids_pairs_learners)
def test_predict_only_one_or_minus_one(estimator, build_dataset,
with_preprocessor):
"""Test that all predicted values are either +1 or -1"""
input_data, labels, preprocessor, _ = build_dataset(with_preprocessor)
estimator = clone(estimator)
estimator.set_params(preprocessor=preprocessor)
set_random_state(estimator)
pairs_train, pairs_test, y_train, y_test = train_test_split(input_data,
labels)
estimator.fit(pairs_train, y_train)
predictions = estimator.predict(pairs_test)
assert np.isin(predictions, [-1, 1]).all()


@pytest.mark.parametrize('with_preprocessor', [True, False])
@pytest.mark.parametrize('estimator, build_dataset', pairs_learners,
ids=ids_pairs_learners)
def test_predict_monotonous(estimator, build_dataset,
with_preprocessor):
"""Test that there is a threshold distance separating points labeled as
similar and points labeled as dissimilar """
input_data, labels, preprocessor, _ = build_dataset(with_preprocessor)
estimator = clone(estimator)
estimator.set_params(preprocessor=preprocessor)
set_random_state(estimator)
pairs_train, pairs_test, y_train, y_test = train_test_split(input_data,
labels)
estimator.fit(pairs_train, y_train)
distances = estimator.score_pairs(pairs_test)
predictions = estimator.predict(pairs_test)
min_dissimilar = np.min(distances[predictions == -1])
max_similar = np.max(distances[predictions == 1])
assert max_similar <= min_dissimilar
separator = np.mean([min_dissimilar, max_similar])
assert (predictions[distances > separator] == -1).all()
assert (predictions[distances < separator] == 1).all()


@pytest.mark.parametrize('with_preprocessor', [True, False])
@pytest.mark.parametrize('estimator, build_dataset', pairs_learners,
ids=ids_pairs_learners)
def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset,
with_preprocessor):
"""Test that a NotFittedError is raised if someone tries to predict and
the metric learner has not been fitted."""
input_data, labels, preprocessor, _ = build_dataset(with_preprocessor)
estimator = clone(estimator)
estimator.set_params(preprocessor=preprocessor)
set_random_state(estimator)
with pytest.raises(NotFittedError):
estimator.predict(input_data)

65 changes: 65 additions & 0 deletions test/test_quadruplets_classifiers.py
@@ -0,0 +1,65 @@
import pytest
from sklearn.exceptions import NotFittedError
from sklearn.model_selection import train_test_split

from test.test_utils import quadruplets_learners, ids_quadruplets_learners
from sklearn.utils.testing import set_random_state
from sklearn import clone
import numpy as np


@pytest.mark.parametrize('with_preprocessor', [True, False])
@pytest.mark.parametrize('estimator, build_dataset', quadruplets_learners,
ids=ids_quadruplets_learners)
def test_predict_only_one_or_minus_one(estimator, build_dataset,
with_preprocessor):
"""Test that all predicted values are either +1 or -1"""
input_data, labels, preprocessor, _ = build_dataset(with_preprocessor)
estimator = clone(estimator)
estimator.set_params(preprocessor=preprocessor)
set_random_state(estimator)
(quadruplets_train,
quadruplets_test, y_train, y_test) = train_test_split(input_data, labels)
estimator.fit(quadruplets_train, y_train)
predictions = estimator.predict(quadruplets_test)
assert np.isin(predictions, [-1, 1]).all()


@pytest.mark.parametrize('with_preprocessor', [True, False])
@pytest.mark.parametrize('estimator, build_dataset', quadruplets_learners,
ids=ids_quadruplets_learners)
def test_predict_monotonous(estimator, build_dataset,
with_preprocessor):
"""Test that there is a threshold distance separating points labeled as
similar and points labeled as dissimilar """
input_data, labels, preprocessor, _ = build_dataset(with_preprocessor)
estimator = clone(estimator)
estimator.set_params(preprocessor=preprocessor)
set_random_state(estimator)
(quadruplets_train,
quadruplets_test, y_train, y_test) = train_test_split(input_data, labels)
estimator.fit(quadruplets_train, y_train)
distances = estimator.score_quadruplets(quadruplets_test)
predictions = estimator.predict(quadruplets_test)
min_dissimilar = np.min(distances[predictions == -1])
max_similar = np.max(distances[predictions == 1])
assert max_similar <= min_dissimilar
separator = np.mean([min_dissimilar, max_similar])
assert (predictions[distances > separator] == -1).all()
assert (predictions[distances < separator] == 1).all()


@pytest.mark.parametrize('with_preprocessor', [True, False])
@pytest.mark.parametrize('estimator, build_dataset', quadruplets_learners,
ids=ids_quadruplets_learners)
def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset,
with_preprocessor):
"""Test that a NotFittedError is raised if someone tries to predict and
the metric learner has not been fitted."""
input_data, labels, preprocessor, _ = build_dataset(with_preprocessor)
estimator = clone(estimator)
estimator.set_params(preprocessor=preprocessor)
set_random_state(estimator)
with pytest.raises(NotFittedError):
estimator.predict(input_data)