Skip to content

Commit

Permalink
[MRG+1] MAINT Parametrize common estimator tests with pytest (#11063)
Browse files Browse the repository at this point in the history
  • Loading branch information
rth authored and glemaitre committed May 7, 2018
1 parent aaf9cf0 commit 67cc975
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 23 deletions.
18 changes: 16 additions & 2 deletions doc/developers/tips.rst
Expand Up @@ -64,8 +64,22 @@ will be displayed as a color background behind the line number.
Useful pytest aliases and flags
-------------------------------

We recommend using pytest to run unit tests. When a unit test fails, the
following tricks can make debugging easier:
The full test suite takes fairly long to run. For faster iterations,
it is possible to select a subset of tests using pytest selectors.
In particular, one can run a `single test based on its node ID
<https://docs.pytest.org/en/latest/example/markers.html#selecting-tests-based-on-their-node-id>`_::

pytest -v sklearn/linear_model/tests/test_logistic.py::test_sparsify

or use the `-k pytest parameter
<https://docs.pytest.org/en/latest/example/markers.html#using-k-expr-to-select-tests-based-on-their-name>`_
to select tests based on their name. For instance::

pytest sklearn/tests/test_common.py -v -k LogisticRegression

will run all :term:`common tests` for the ``LogisticRegression`` estimator.

When a unit test fails, the following tricks can make debugging easier:

1. The command line argument ``pytest -l`` instructs pytest to print the local
variables when a failure occurs.
Expand Down
69 changes: 48 additions & 21 deletions sklearn/tests/test_common.py
Expand Up @@ -13,6 +13,8 @@
import re
import pkgutil

import pytest

from sklearn.utils.testing import assert_false, clean_warning_registry
from sklearn.utils.testing import all_estimators
from sklearn.utils.testing import assert_equal
Expand Down Expand Up @@ -41,34 +43,57 @@ def test_all_estimator_no_base_class():


def test_all_estimators():
    # Every estimator must be default-constructible, cloneable and have
    # a working repr; each check is yielded as a nose-style sub-test.
    estimators = all_estimators(include_meta_estimators=True)

    # Sanity check: the estimator introspection itself must work and
    # discover at least one estimator class.
    assert_greater(len(estimators), 0)

    for est_name, est_cls in estimators:
        # Some estimators cannot sensibly be default-constructed; the
        # check handles those cases itself.
        yield check_parameters_default_constructible, est_name, est_cls

@pytest.mark.parametrize('name, Estimator',
                         all_estimators(include_meta_estimators=True))
def test_parameters_default_constructible(name, Estimator):
    """Check that the estimator is constructible with default parameters."""
    check_parameters_default_constructible(name, Estimator)

def test_non_meta_estimators():
# input validation etc for non-meta estimators
estimators = all_estimators()
for name, Estimator in estimators:

def _tested_non_meta_estimators():
for name, Estimator in all_estimators():
if issubclass(Estimator, BiclusterMixin):
continue
if name.startswith("_"):
continue
yield name, Estimator


def _generate_checks_per_estimator(check_generator, estimators):
for name, Estimator in estimators:
estimator = Estimator()
# check this on class
yield check_no_attributes_set_in_init, name, estimator
for check in check_generator(name, estimator):
yield name, Estimator, check

for check in _yield_all_checks(name, estimator):
set_checking_parameters(estimator)
yield check, name, estimator

@pytest.mark.parametrize(
    "name, Estimator, check",
    _generate_checks_per_estimator(_yield_all_checks,
                                   _tested_non_meta_estimators())
)
def test_non_meta_estimators(name, Estimator, check):
    """Run one common check against one non-meta estimator."""
    est = Estimator()
    set_checking_parameters(est)
    check(name, est)


@pytest.mark.parametrize("name, Estimator",
                         _tested_non_meta_estimators())
def test_no_attributes_set_in_init(name, Estimator):
    # Delegate to the common check that inspects attributes set by the
    # estimator's ``__init__``.
    check_no_attributes_set_in_init(name, Estimator())


def test_configure():
Expand All @@ -95,19 +120,21 @@ def test_configure():
os.chdir(cwd)


def _tested_linear_classifiers():
    """Yield (name, class) for linear classifiers exposing ``class_weight``."""
    classifiers = all_estimators(type_filter='classifier')

    clean_warning_registry()
    # Instantiating every classifier below may emit deprecation or
    # convergence warnings; silence them during introspection.
    with warnings.catch_warnings(record=True):
        for name, clazz in classifiers:
            # Only linear classifiers with a ``class_weight`` parameter
            # are subject to the balanced-class-weight check.
            if ('class_weight' in clazz().get_params().keys() and
                    issubclass(clazz, LinearClassifierMixin)):
                yield name, clazz


@pytest.mark.parametrize("name, Classifier",
                         _tested_linear_classifiers())
def test_class_weight_balanced_linear_classifiers(name, Classifier):
    """Check ``class_weight='balanced'`` behavior of one linear classifier."""
    check_class_weight_balanced_linear_classifier(name, Classifier)


@ignore_warnings
Expand Down

0 comments on commit 67cc975

Please sign in to comment.