Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] fix lookup for specialized test classes #189

Merged
merged 24 commits into from Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 6 additions & 2 deletions skpro/distributions/tests/test_all_distrs.py
Expand Up @@ -10,7 +10,6 @@
from skbase.testing import QuickTester

from skpro.datatypes import check_is_mtype
from skpro.distributions.base import BaseDistribution
from skpro.tests.test_all_estimators import BaseFixtureGenerator, PackageConfig
from skpro.utils.index import random_ss_ix

Expand All @@ -27,7 +26,7 @@ class DistributionFixtureGenerator(BaseFixtureGenerator):
instances are generated by create_test_instance class method
"""

object_type_filter = BaseDistribution
object_type_filter = "distribution"


def _has_capability(distr, method):
Expand Down Expand Up @@ -60,6 +59,11 @@ def _has_capability(distr, method):
class TestAllDistributions(PackageConfig, DistributionFixtureGenerator, QuickTester):
"""Module level tests for all skpro parameter fitters."""

# TEMPORARY skip for CyclicBoosting and QPD classes
# due to silent failures on main, se #190
exclude_objects = ["QPD_S", "QPD_B", "QPD_U"]
# remove this when fixing failures to re-enable testing

@pytest.mark.parametrize("shuffled", [False, True])
def test_sample(self, object_instance, shuffled):
"""Test sample expected return."""
Expand Down
57 changes: 34 additions & 23 deletions skpro/registry/_lookup.py
Expand Up @@ -48,35 +48,46 @@ def all_objects(
Which kind of objects should be returned.
if None, no filter is applied and all objects are returned.
if str or list of str, strings define scitypes specified in search
only objects that are of (at least) one of the scitypes are returned
possible str values are entries of registry.BASE_CLASS_REGISTER (first col)
for instance 'classifier', 'regressor', 'transformer', 'forecaster'
only objects that are of (at least) one of the scitypes are returned
possible str values are entries of registry.BASE_CLASS_REGISTER (first col)
for instance 'regrssor_proba', 'distribution, 'metric'

return_names: bool, optional (default=True)
if True, object class name is included in the all_objects()
return in the order: name, object class, optional tags, either as
a tuple or as pandas.DataFrame columns
if False, object class name is removed from the all_objects()
return.

if True, estimator class name is included in the ``all_objects``
return in the order: name, estimator class, optional tags, either as
a tuple or as pandas.DataFrame columns

if False, estimator class name is removed from the ``all_objects`` return.

filter_tags: dict of (str or list of str), optional (default=None)
For a list of valid tag strings, use the registry.all_tags utility.
subsets the returned objects as follows:
each key/value pair is statement in "and"/conjunction
key is tag name to sub-set on
value str or list of string are tag values
condition is "key must be equal to value, or in set(value)"
exclude_objects: str, list of str, optional (default=None)
Names of objects to exclude.

``filter_tags`` subsets the returned estimators as follows:

* each key/value pair is statement in "and"/conjunction
* key is tag name to sub-set on
* value str or list of string are tag values
* condition is "key must be equal to value, or in set(value)"

exclude_estimators: str, list of str, optional (default=None)
Names of estimators to exclude.

as_dataframe: bool, optional (default=False)
if True, all_objects will return a pandas.DataFrame with named
columns for all of the attributes being returned.
if False, all_objects will return a list (either a list of
objects or a list of tuples, see Returns)

True: ``all_objects`` will return a pandas.DataFrame with named
columns for all of the attributes being returned.

False: ``all_objects`` will return a list (either a list of
estimators or a list of tuples, see Returns)

return_tags: str or list of str, optional (default=None)
Names of tags to fetch and return each object's value of.
Names of tags to fetch and return each estimator's value of.
For a list of valid tag strings, use the registry.all_tags utility.
if str or list of str,
the tag values named in return_tags will be fetched for each
object and will be appended as either columns or tuple entries.
the tag values named in return_tags will be fetched for each
estimator and will be appended as either columns or tuple entries.

suppress_import_stdout : bool, optional. Default=True
whether to suppress stdout printout upon import.

Expand All @@ -85,7 +96,7 @@ def all_objects(
all_objects will return one of the following:
1. list of objects, if return_names=False, and return_tags is None
2. list of tuples (optional object name, class, ~optional object
tags), if return_names=True or return_tags is not None.
tags), if return_names=True or return_tags is not None.
3. pandas.DataFrame if as_dataframe = True
if list of objects:
entries are objects matching the query,
Expand Down
11 changes: 8 additions & 3 deletions skpro/regression/tests/test_all_regressors.py
Expand Up @@ -5,7 +5,6 @@

from skpro.datatypes import check_is_mtype, check_raise
from skpro.distributions.base import BaseDistribution
from skpro.regression.base._base import BaseProbaRegressor
from skpro.tests.test_all_estimators import BaseFixtureGenerator, PackageConfig

TEST_ALPHAS = [0.05, [0.1], [0.25, 0.75], [0.3, 0.1, 0.9]]
Expand All @@ -17,8 +16,14 @@ class TestAllRegressors(PackageConfig, BaseFixtureGenerator, QuickTester):
# class variables which can be overridden by descendants
# ------------------------------------------------------

# which object types are generated; None=all, or class (passed to all_objects)
object_type_filter = BaseProbaRegressor
# which object types are generated; None=all, or scitype string
# passed to skpro.registry.all_objects as object_type
object_type_filter = "regressor_proba"

# TEMPORARY skip for CyclicBoosting and QPD classes
# due to silent failures on main, se #190
exclude_objects = ["CyclicBoosting"]
# remove this when fixing failures to re-enable testing

def test_input_output_contract(self, object_instance):
"""Tests that output of predict methods is as specified."""
Expand Down