Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Grid/random search tag fix #1455

Merged
merged 3 commits into from Sep 26, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
42 changes: 14 additions & 28 deletions sktime/forecasting/model_selection/_tune.py
Expand Up @@ -56,6 +56,20 @@ def __init__(
self.return_n_best_forecasters = return_n_best_forecasters
super(BaseGridSearch, self).__init__()

tags_to_clone = [
"requires-fh-in-fit",
"capability:pred_int",
# "scitype:y", commented out until grid search works with multivariate
"ignores-exogeneous-X",
"handles-missing-data",
"y_inner_mtype",
"X_inner_mtype",
"X-y-must-have-same-index",
"enforce-index-type",
]

self.clone_tags(forecaster, tags_to_clone)

@if_delegate_has_method(delegate=("best_forecaster_", "forecaster"))
def _update(self, y, X=None, update_params=False):
"""Call predict on the forecaster with the best found parameters."""
Expand Down Expand Up @@ -453,20 +467,6 @@ def __init__(
)
self.param_grid = param_grid

tags_to_clone = [
"requires-fh-in-fit",
"capability:pred_int",
"scitype:y",
"univariate-only",
"handles-missing-data",
"y_inner_mtype",
"X_inner_mtype",
"X-y-must-have-same-index",
"enforce-index-type",
]

self.clone_tags(forecaster, tags_to_clone)

def _run_search(self, evaluate_candidates):
"""Search all candidates in param_grid."""
_check_param_grid(self.param_grid)
Expand Down Expand Up @@ -574,20 +574,6 @@ def __init__(
self.n_iter = n_iter
self.random_state = random_state

tags_to_clone = [
"requires-fh-in-fit",
"capability:pred_int",
"scitype:y",
"univariate-only",
"handles-missing-data",
"y_inner_mtype",
"X_inner_mtype",
"X-y-must-have-same-index",
"enforce-index-type",
]

self.clone_tags(forecaster, tags_to_clone)

def _run_search(self, evaluate_candidates):
"""Search n_iter candidates from param_distributions."""
return evaluate_candidates(
Expand Down