Skip to content

Commit

Permalink
MAINT Slight common tests cleanup (#14511)
Browse files Browse the repository at this point in the history
  • Loading branch information
amueller authored and thomasjpfan committed Jul 30, 2019
1 parent 7a87ac5 commit b162aca
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 23 deletions.
17 changes: 0 additions & 17 deletions sklearn/tests/test_common.py
Expand Up @@ -32,7 +32,6 @@
_safe_tags,
set_checking_parameters,
check_parameters_default_constructible,
check_no_attributes_set_in_init,
check_class_weight_balanced_linear_classifier)


Expand Down Expand Up @@ -111,22 +110,6 @@ def test_estimators(estimator, check):
check(name, estimator)


@pytest.mark.parametrize("name, estimator",
_tested_estimators())
def test_no_attributes_set_in_init(name, estimator):
# input validation etc for all estimators
with ignore_warnings(category=(DeprecationWarning, ConvergenceWarning,
UserWarning, FutureWarning)):
tags = _safe_tags(estimator)
if tags['_skip_test']:
warnings.warn("Explicit SKIP via _skip_test tag for "
"{}.".format(name),
SkipTestWarning)
return
# check this on class
check_no_attributes_set_in_init(name, estimator)


@ignore_warnings(category=DeprecationWarning)
# ignore deprecated open(.., 'U') in numpy distutils
def test_configure():
Expand Down
7 changes: 4 additions & 3 deletions sklearn/utils/estimator_checks.py
Expand Up @@ -72,6 +72,7 @@ def _safe_tags(estimator, key=None):

def _yield_checks(name, estimator):
tags = _safe_tags(estimator)
yield check_no_attributes_set_in_init
yield check_estimators_dtypes
yield check_fit_score_takes_y
yield check_sample_weights_pandas_series
Expand Down Expand Up @@ -288,7 +289,6 @@ def check_estimator(Estimator):
name = Estimator.__name__
estimator = Estimator()
check_parameters_default_constructible(name, Estimator)
check_no_attributes_set_in_init(name, estimator)
else:
# got an instance
estimator = Estimator
Expand Down Expand Up @@ -2056,9 +2056,10 @@ def check_estimators_overwrite_params(name, estimator_orig):
% (name, param_name, original_value, new_value))


def check_no_attributes_set_in_init(name, estimator):
@ignore_warnings(category=(DeprecationWarning, FutureWarning))
def check_no_attributes_set_in_init(name, estimator_orig):
"""Check setting during init. """

estimator = clone(estimator_orig)
if hasattr(type(estimator).__init__, "deprecated_original"):
return

Expand Down
6 changes: 3 additions & 3 deletions sklearn/utils/tests/test_estimator_checks.py
Expand Up @@ -414,7 +414,7 @@ def test_check_estimator():

# doesn't error on actual estimator
check_estimator(LogisticRegression)
check_estimator(LogisticRegression())
check_estimator(LogisticRegression(C=0.01))
check_estimator(MultiTaskElasticNet)
check_estimator(MultiTaskElasticNet())

Expand Down Expand Up @@ -483,11 +483,11 @@ def test_check_estimators_unfitted():


def test_check_no_attributes_set_in_init():
class NonConformantEstimatorPrivateSet:
class NonConformantEstimatorPrivateSet(BaseEstimator):
def __init__(self):
self.you_should_not_set_this_ = None

class NonConformantEstimatorNoParamSet:
class NonConformantEstimatorNoParamSet(BaseEstimator):
def __init__(self, you_should_set_this_=None):
pass

Expand Down

0 comments on commit b162aca

Please sign in to comment.