-
Notifications
You must be signed in to change notification settings - Fork 232
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG+2] Update repo to work with both new and old scikit-learn #313
Changes from 5 commits
3760c60
777f8d6
e5c240d
b754930
195da0a
0c780b0
bfea742
61755dd
4dac77b
ee856ad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,71 +4,170 @@ | |
import metric_learn | ||
import numpy as np | ||
from sklearn import clone | ||
from sklearn.utils.testing import set_random_state | ||
import sklearn | ||
from packaging import version | ||
from test.test_utils import ids_metric_learners, metric_learners, remove_y | ||
skversion = version.parse(sklearn.__version__) | ||
if skversion >= version.parse('0.22.0'): | ||
from sklearn.utils._testing import set_random_state | ||
else: | ||
from sklearn.utils.testing import set_random_state | ||
|
||
|
||
def remove_spaces(s): | ||
return re.sub(r'\s+', '', s) | ||
|
||
|
||
def sk_repr_kwargs(def_kwargs, nndef_kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here I thought it could be good to test the str repr for both the old (<=0.22) sklearn version and the newer ones, so I made a string representation that depends on the sklearn version |
||
"""Given the non-default arguments, and the default | ||
keywords arguments, build the string that will appear | ||
in the __repr__ of the estimator, depending on the | ||
version of scikit-learn. | ||
""" | ||
if skversion >= version.parse('0.22.0'): | ||
def_kwargs = "" | ||
nndef_kwargs = eval(f"dict({nndef_kwargs})") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rather than using eval, could we pass actual dicts here instead? The values are all simple literals, so it should be easier that way. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes you're right, will do There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
def_kwargs = eval(f"dict({def_kwargs})") | ||
def_kwargs.update(nndef_kwargs) | ||
args_str = ",".join(f"{key}={strify(value)}" | ||
for key, value in def_kwargs.items()) | ||
return args_str | ||
|
||
|
||
def strify(obj): | ||
"""Function to add the quotation marks if the object | ||
is a string""" | ||
if type(obj) is str: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if isinstance(obj, str) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Depending on the inputs we expect to pass, it might be simpler to call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's right, thanks! Will do There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. actually there was a problem for booleans ( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
return f"'{obj}'" | ||
else: | ||
return str(obj) | ||
|
||
|
||
class TestStringRepr(unittest.TestCase): | ||
|
||
def test_covariance(self): | ||
def_kwargs = "preprocessor=None" | ||
nndef_kwargs = "" | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual(remove_spaces(str(metric_learn.Covariance())), | ||
remove_spaces("Covariance()")) | ||
remove_spaces(f"Covariance({merged_kwargs})")) | ||
|
||
def test_lmnn(self): | ||
def_kwargs = """convergence_tol=0.001, init='auto', k=3, | ||
learn_rate=1e-07, max_iter=1000, min_iter=50, n_components=None, | ||
preprocessor=None, random_state=None, regularization=0.5, | ||
verbose=False""" | ||
nndef_kwargs = "convergence_tol=0.01, k=6" | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual( | ||
remove_spaces(str(metric_learn.LMNN(convergence_tol=0.01, k=6))), | ||
remove_spaces("LMNN(convergence_tol=0.01, k=6)")) | ||
remove_spaces(f"LMNN({merged_kwargs})")) | ||
|
||
def test_nca(self): | ||
def_kwargs = """init='auto', max_iter=100, n_components=None, | ||
preprocessor=None, random_state=None, tol=None, verbose=False""" | ||
nndef_kwargs = "max_iter=42" | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual(remove_spaces(str(metric_learn.NCA(max_iter=42))), | ||
remove_spaces("NCA(max_iter=42)")) | ||
remove_spaces(f"NCA({merged_kwargs})")) | ||
|
||
def test_lfda(self): | ||
def_kwargs = """embedding_type='weighted', k=None, | ||
n_components=None,preprocessor=None""" | ||
nndef_kwargs = "k=2" | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual(remove_spaces(str(metric_learn.LFDA(k=2))), | ||
remove_spaces("LFDA(k=2)")) | ||
remove_spaces(f"LFDA({merged_kwargs})")) | ||
|
||
def test_itml(self): | ||
def_kwargs = """convergence_threshold=0.001, gamma=1.0, | ||
max_iter=1000, preprocessor=None, prior='identity', random_state=None, | ||
verbose=False""" | ||
nndef_kwargs = "gamma=0.5" | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual(remove_spaces(str(metric_learn.ITML(gamma=0.5))), | ||
remove_spaces("ITML(gamma=0.5)")) | ||
remove_spaces(f"ITML({merged_kwargs})")) | ||
def_kwargs = """convergence_threshold=0.001, gamma=1.0, | ||
max_iter=1000, num_constraints=None,preprocessor=None, | ||
prior='identity', random_state=None, verbose=False""" | ||
nndef_kwargs = "num_constraints=7" | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual( | ||
remove_spaces(str(metric_learn.ITML_Supervised(num_constraints=7))), | ||
remove_spaces("ITML_Supervised(num_constraints=7)")) | ||
remove_spaces(f"ITML_Supervised({merged_kwargs})")) | ||
|
||
def test_lsml(self): | ||
def_kwargs = """max_iter=1000, preprocessor=None, prior='identity', | ||
random_state=None, tol=0.001, verbose=False""" | ||
nndef_kwargs = "tol=0.1" | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual(remove_spaces(str(metric_learn.LSML(tol=0.1))), | ||
remove_spaces("LSML(tol=0.1)")) | ||
remove_spaces(f"LSML({merged_kwargs})")) | ||
def_kwargs = """max_iter=1000, num_constraints=None, preprocessor=None, | ||
prior='identity', random_state=None, tol=0.001, verbose=False, | ||
weights=None""" | ||
nndef_kwargs = "verbose=True" | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual( | ||
remove_spaces(str(metric_learn.LSML_Supervised(verbose=True))), | ||
remove_spaces("LSML_Supervised(verbose=True)")) | ||
remove_spaces(f"LSML_Supervised({merged_kwargs})")) | ||
|
||
def test_sdml(self): | ||
def_kwargs = """ | ||
balance_param=0.5, preprocessor=None, prior='identity', random_state=None, | ||
sparsity_param=0.01, verbose=False""" | ||
nndef_kwargs = "verbose=True" | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual(remove_spaces(str(metric_learn.SDML(verbose=True))), | ||
remove_spaces("SDML(verbose=True)")) | ||
remove_spaces(f"SDML({merged_kwargs})")) | ||
def_kwargs = """balance_param=0.5, num_constraints=None, | ||
preprocessor=None, prior='identity', random_state=None, | ||
sparsity_param=0.01, verbose=False""" | ||
nndef_kwargs = "sparsity_param=0.5" | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual( | ||
remove_spaces(str(metric_learn.SDML_Supervised(sparsity_param=0.5))), | ||
remove_spaces("SDML_Supervised(sparsity_param=0.5)")) | ||
remove_spaces(f"SDML_Supervised({merged_kwargs})")) | ||
|
||
def test_rca(self): | ||
def_kwargs = "n_components=None, preprocessor=None" | ||
nndef_kwargs = "n_components=3" | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual(remove_spaces(str(metric_learn.RCA(n_components=3))), | ||
remove_spaces("RCA(n_components=3)")) | ||
remove_spaces(f"RCA({merged_kwargs})")) | ||
def_kwargs = """chunk_size=2, n_components=None, num_chunks=100, | ||
preprocessor=None, random_state=None""" | ||
nndef_kwargs = "num_chunks=5" | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual( | ||
remove_spaces(str(metric_learn.RCA_Supervised(num_chunks=5))), | ||
remove_spaces("RCA_Supervised(num_chunks=5)")) | ||
remove_spaces(f"RCA_Supervised({merged_kwargs})")) | ||
|
||
def test_mlkr(self): | ||
def_kwargs = """init='auto', max_iter=1000, n_components=None, | ||
preprocessor=None, random_state=None, tol=None, verbose=False""" | ||
nndef_kwargs = "max_iter=777" | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual(remove_spaces(str(metric_learn.MLKR(max_iter=777))), | ||
remove_spaces("MLKR(max_iter=777)")) | ||
remove_spaces(f"MLKR({merged_kwargs})")) | ||
|
||
def test_mmc(self): | ||
def_kwargs = """convergence_threshold=0.001, diagonal=False, | ||
diagonal_c=1.0, init='identity', max_iter=100, max_proj=10000, | ||
preprocessor=None, random_state=None, verbose=False""" | ||
nndef_kwargs = "diagonal=True" | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual(remove_spaces(str(metric_learn.MMC(diagonal=True))), | ||
remove_spaces("MMC(diagonal=True)")) | ||
remove_spaces(f"MMC({merged_kwargs})")) | ||
def_kwargs = """convergence_threshold=1e-06, diagonal=False, | ||
diagonal_c=1.0, init='identity', max_iter=100, max_proj=10000, | ||
num_constraints=None, preprocessor=None, random_state=None, | ||
verbose=False""" | ||
nndef_kwargs = "max_iter=1" | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual( | ||
remove_spaces(str(metric_learn.MMC_Supervised(max_iter=1))), | ||
remove_spaces("MMC_Supervised(max_iter=1)")) | ||
remove_spaces(f"MMC_Supervised({merged_kwargs})")) | ||
|
||
|
||
@pytest.mark.parametrize('estimator, build_dataset', metric_learners, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,12 @@ | |
from sklearn.datasets import make_spd_matrix, make_blobs | ||
from sklearn.utils import check_random_state, shuffle | ||
from sklearn.utils.multiclass import type_of_target | ||
from sklearn.utils.testing import set_random_state | ||
import sklearn | ||
from packaging import version | ||
if version.parse(sklearn.__version__) >= version.parse('0.22.0'): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems like something we could do once during module initialization, then have a global constant like Or make a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, will do There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
from sklearn.utils._testing import set_random_state | ||
else: | ||
from sklearn.utils.testing import set_random_state | ||
|
||
from metric_learn._util import make_context, _initialize_metric_mahalanobis | ||
from metric_learn.base_metric import (_QuadrupletsClassifierMixin, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,9 +4,17 @@ | |
from sklearn.base import TransformerMixin | ||
from sklearn.pipeline import make_pipeline | ||
from sklearn.utils import check_random_state | ||
from sklearn.utils.estimator_checks import is_public_parameter | ||
from sklearn.utils.testing import (assert_allclose_dense_sparse, | ||
set_random_state) | ||
import sklearn | ||
from packaging import version | ||
if version.parse(sklearn.__version__) >= version.parse('0.22.0'): | ||
from sklearn.utils._testing import (assert_allclose_dense_sparse, | ||
set_random_state, _get_args) | ||
from sklearn.utils.estimator_checks import (_is_public_parameter | ||
as is_public_parameter) | ||
else: | ||
from sklearn.utils.testing import (assert_allclose_dense_sparse, | ||
set_random_state, _get_args) | ||
from sklearn.utils.estimator_checks import is_public_parameter | ||
|
||
from metric_learn import (Covariance, LFDA, LMNN, MLKR, NCA, | ||
ITML_Supervised, LSML_Supervised, | ||
|
@@ -16,8 +24,10 @@ | |
import numpy as np | ||
from sklearn.model_selection import (cross_val_score, cross_val_predict, | ||
train_test_split, KFold) | ||
from sklearn.metrics.scorer import get_scorer | ||
from sklearn.utils.testing import _get_args | ||
if version.parse(sklearn.__version__) >= version.parse('0.22.0'): | ||
from sklearn.metrics._scorer import get_scorer | ||
else: | ||
from sklearn.metrics.scorer import get_scorer | ||
from test.test_utils import (metric_learners, ids_metric_learners, | ||
mock_preprocessor, tuples_learners, | ||
ids_tuples_learners, pairs_learners, | ||
|
@@ -52,37 +62,37 @@ def __init__(self, sparsity_param=0.01, | |
|
||
class TestSklearnCompat(unittest.TestCase): | ||
def test_covariance(self): | ||
check_estimator(Covariance) | ||
check_estimator(Covariance()) | ||
|
||
def test_lmnn(self): | ||
check_estimator(LMNN) | ||
check_estimator(LMNN()) | ||
|
||
def test_lfda(self): | ||
check_estimator(LFDA) | ||
check_estimator(LFDA()) | ||
|
||
def test_mlkr(self): | ||
check_estimator(MLKR) | ||
check_estimator(MLKR()) | ||
|
||
def test_nca(self): | ||
check_estimator(NCA) | ||
check_estimator(NCA()) | ||
|
||
def test_lsml(self): | ||
check_estimator(LSML_Supervised) | ||
check_estimator(LSML_Supervised()) | ||
|
||
def test_itml(self): | ||
check_estimator(ITML_Supervised) | ||
check_estimator(ITML_Supervised()) | ||
|
||
def test_mmc(self): | ||
check_estimator(MMC_Supervised) | ||
check_estimator(MMC_Supervised()) | ||
|
||
def test_sdml(self): | ||
check_estimator(Stable_SDML_Supervised) | ||
check_estimator(Stable_SDML_Supervised()) | ||
|
||
def test_rca(self): | ||
check_estimator(Stable_RCA_Supervised) | ||
check_estimator(Stable_RCA_Supervised()) | ||
|
||
def test_scml(self): | ||
check_estimator(SCML_Supervised) | ||
check_estimator(SCML_Supervised()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here scikit-learn had return an error saying that checks should be run on estimators instances, not classes |
||
|
||
|
||
RNG = check_random_state(0) | ||
|
@@ -121,7 +131,8 @@ def test_array_like_inputs(estimator, build_dataset, with_preprocessor): | |
|
||
# we subsample the data for the test to be more efficient | ||
input_data, _, labels, _ = train_test_split(input_data, labels, | ||
train_size=20) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Test was failing here because of classes with too few labels in LMNN (see comments above), so in the end I created a toy example with a bit more samples (which I guess makes sense because the role of this particular test is not to test edge cases, but rather the fact that array-like objects work with our estimators), |
||
train_size=40, | ||
random_state=42) | ||
X = X[:10] | ||
|
||
estimator = clone(estimator) | ||
|
@@ -160,7 +171,7 @@ def test_various_scoring_on_tuples_learners(estimator, build_dataset, | |
with_preprocessor): | ||
"""Tests that scikit-learn's scoring returns something finite, | ||
for other scoring than default scoring. (List of scikit-learn's scores can be | ||
found in sklearn.metrics.scorer). For each type of output (predict, | ||
found in sklearn.metrics._scorer). For each type of output (predict, | ||
predict_proba, decision_function), we test a bunch of scores. | ||
We only test on pairs learners because quadruplets don't have a y argument. | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps add a note here to clarify this additional test's purpose.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, done