Skip to content

Commit

Permalink
[BUG] fix get_fitted_params in case of vectoriztion for forecasters…
Browse files Browse the repository at this point in the history
…, transformers (#4105)

This PR fixes an unreported bug that was noticed here: #4101 (reply in thread)

The bug causes only the diagonal combinations of rows/columns to be present as string keyable values in the fitted params dict in case of vectorized forecasters and transformers.

What should be present is the cartesian product, not the diagonal. This is solved by correctly using `product` instead of `zip` that was used accidentally.

Also fixes an issue (not necessarily bug) with the convention of constructing the strings - they should now always be a string that evaluates to a correct data frame `loc` access command, which previously was not the case for `str` keys (as that would eat away the quotes in `str` coercion).
  • Loading branch information
fkiraly committed Jan 23, 2023
1 parent af1ce18 commit 90e54f0
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
9 changes: 7 additions & 2 deletions sktime/forecasting/base/_base.py
Expand Up @@ -1219,10 +1219,15 @@ def get_fitted_params(self):
# return forecasters in the "forecasters" param
fitted_params["forecasters"] = forecasters

def _to_str(x):
if isinstance(x, str):
x = f"'{x}'"
return str(x)

# populate fitted_params with forecasters and their parameters
for ix, col in zip(forecasters.index, forecasters.columns):
for ix, col in product(forecasters.index, forecasters.columns):
fcst = forecasters.loc[ix, col]
fcst_key = f"forecasters.loc[{ix},{col}]"
fcst_key = f"forecasters.loc[{_to_str(ix)},{_to_str(col)}]"
fitted_params[fcst_key] = fcst
fcst_params = fcst.get_fitted_params()
for key, val in fcst_params.items():
Expand Down
21 changes: 13 additions & 8 deletions sktime/transformations/base.py
Expand Up @@ -708,14 +708,19 @@ def get_fitted_params(self):
# return forecasters in the "forecasters" param
fitted_params["transformers"] = transformers

# populate fitted_params with ftransformers and their parameters
for ix, col in zip(transformers.index, transformers.columns):
fcst = transformers.loc[ix, col]
fcst_key = f"transformers.loc[{ix},{col}]"
fitted_params[fcst_key] = fcst
fcst_params = fcst.get_fitted_params()
for key, val in fcst_params.items():
fitted_params[f"{fcst_key}__{key}"] = val
def _to_str(x):
if isinstance(x, str):
x = f"'{x}'"
return str(x)

# populate fitted_params with transformers and their parameters
for ix, col in product(transformers.index, transformers.columns):
trafo = transformers.loc[ix, col]
trafo_key = f"transformers.loc[{_to_str(ix)},{_to_str(col)}]"
fitted_params[trafo_key] = trafo
trafo_params = trafo.get_fitted_params()
for key, val in trafo_params.items():
fitted_params[f"{trafo_key}__{key}"] = val

return fitted_params

Expand Down

0 comments on commit 90e54f0

Please sign in to comment.