Skip to content

Commit

Permalink
add test to check for proper initialization of dl estimators
Browse files Browse the repository at this point in the history
  • Loading branch information
achieveordie committed Jan 10, 2023
1 parent 4e2f2c2 commit 1164017
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions sktime/tests/test_all_estimators.py
Expand Up @@ -1377,6 +1377,37 @@ def test_multiprocessing_idempotent(
err_msg="Results are not equal for n_jobs=1 and n_jobs=-1",
)

def test_dl_constructor_initializes_deeply(self, estimator_class):
"""Test DL estimators that they pass custom parameters to underlying Network."""
estimator = estimator_class

if not issubclass(estimator, (BaseDeepClassifier, BaseDeepRegressor)):
return None

if not hasattr(estimator, "get_test_params"):
return None

params = estimator.get_test_params()

if isinstance(params, list):
params = params[0]
if isinstance(params, dict):
pass
else:
raise TypeError(
f"`get_test_params()` of estimator: {estimator} returns "
f"an expected type: {type(params)}, acceptable formats: [list, dict]"
)

estimator = estimator(**params)

for key, value in params.items():
assert vars(estimator)[key] == value
# some keys are only relevant to the final model (eg: n_epochs)
# skip them for the underlying network
if vars(estimator._network).get(key) is not None:
assert vars(estimator._network)[key] == value

def _get_err_msg(estimator):
return (
f"Invalid estimator type: {type(estimator)}. Valid estimator types are: "
Expand Down

0 comments on commit 1164017

Please sign in to comment.