Skip to content

Commit

Permalink
[ENH] improved output type checking error messages in `BaseTransforme…
Browse files Browse the repository at this point in the history
…r.transform` (#5921)

Improves output type checking error messages in
`BaseTransformer.transform`, using idiomatic `check_is_error_msg`.

Related: #5867
  • Loading branch information
fkiraly committed Feb 23, 2024
1 parent ca8d8e6 commit d8e7e88
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions sktime/transformations/base.py
Expand Up @@ -1227,20 +1227,32 @@ def _convert_output(self, X, metadata, inverse=False):
else:
Xt_metadata_required = []

valid, msg, metadata = check_is_mtype(
ALLOWED_OUT_MTYPES = ["pd.DataFrame", "pd.Series", "np.ndarray"]
Xt_valid, Xt_msg, metadata = check_is_mtype(
Xt,
["pd.DataFrame", "pd.Series", "np.ndarray"],
ALLOWED_OUT_MTYPES,
msg_return_dict="list",
return_metadata=Xt_metadata_required,
)

if not valid:
raise TypeError(
if not Xt_valid:
Xtd = {k: v for k, v in Xt_msg.items() if k in ALLOWED_OUT_MTYPES}
msg_start = (
f"Type checking error in output of _transform of "
f"{self.__class__.__name__}, output"
)
msg_out = (
f"_transform output of {type(self)} does not comply "
"with sktime mtype specifications. See datatypes.MTYPE_REGISTER"
" for mtype specifications. Returned error message:"
f" {msg}. Returned object: {Xt}"
" for mtype specifications."
)
check_is_error_msg(
Xtd,
var_name=msg_start,
allowed_msg=msg_out,
raise_exception=True,
)

if X_input_mtype == "pd.Series" and not metadata["is_univariate"]:
X_output_mtype = "pd.DataFrame"
elif self.get_tags()["scitype:transform-input"] == "Panel":
Expand Down

0 comments on commit d8e7e88

Please sign in to comment.