Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformers refactor: summarizer and base class #1663

Merged
merged 11 commits into from Dec 7, 2021
5 changes: 5 additions & 0 deletions sktime/tests/test_all_estimators.py
Expand Up @@ -348,9 +348,14 @@ def test_fit_returns_self(estimator_instance):
def test_raises_not_fitted_error(estimator_instance):
"""Check that we raise appropriate error for unfitted estimators."""
estimator = estimator_instance

# call methods without prior fitting and check that they raise our
# NotFittedError
for method in NON_STATE_CHANGING_METHODS:
# if has fit-in-transform tag (empty fit), no error needs to be raised
if method == "transform" and estimator.get_tag("fit-in-transform"):
mloning marked this conversation as resolved.
Show resolved Hide resolved
return None
# otherwise check if an error is raised
if hasattr(estimator, method):
args = _make_args(estimator, method)
with pytest.raises(NotFittedError, match=r"has not been fitted"):
Expand Down
5 changes: 4 additions & 1 deletion sktime/transformations/base.py
Expand Up @@ -420,7 +420,7 @@ def transform(self, X, y=None, Z=None):
# we concatenate those and overwrite the index with that of X
elif output_scitype == "Primitives":
Xt = pd.concat(Xt)
Xt.index = X.index
Xt = Xt.reset_index(drop=True)
return Xt

# convert X/y to supported inner type, if necessary
Expand Down Expand Up @@ -476,6 +476,9 @@ def transform(self, X, y=None, Z=None):
)
elif output_scitype == "Primitives":
# we "abuse" the Series converter to ensure df output
# & reset index to have integers for instances
if isinstance(Xt, (pd.DataFrame, pd.Series)):
Xt = Xt.reset_index(drop=True)
Xt = convert_to(
Xt,
to_type="pd.DataFrame",
Expand Down
21 changes: 1 addition & 20 deletions sktime/transformations/series/summarize.py
Expand Up @@ -10,7 +10,6 @@

from sktime.transformations.base import _SeriesToPrimitivesTransformer
from sktime.utils._maint import deprecated
from sktime.utils.validation.series import check_series

ALLOWED_SUM_FUNCS = [
"mean",
Expand Down Expand Up @@ -133,6 +132,7 @@ class SummaryTransformer(_SeriesToPrimitivesTransformer):

_tags = {
"fit-in-transform": True,
"X_inner_mtype": ["pd.DataFrame", "pd.Series"],
}

def __init__(
Expand Down Expand Up @@ -176,25 +176,6 @@ def _transform(self, Z, X=None):

return summary_value.T

def transform(self, Z, X=None):
"""Transform series.

Parameters
----------
Z : pd.Series or pd.DataFrame
The series to transform.

Returns
-------
summary_value : pd.DataFrame
DataFrame where series are instances (rows) and calculated summary
values are treated as features (columns).
"""
self.check_is_fitted()
Z = check_series(Z)
summary_value = self._transform(Z, X=X)
return summary_value


class MeanTransformer(SummaryTransformer):
"""Calculate mean value of a time series.
Expand Down