Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformers refactor: summarizer and base class #1663

Merged
merged 11 commits into from Dec 7, 2021
1 change: 1 addition & 0 deletions sktime/tests/test_all_estimators.py
Expand Up @@ -348,6 +348,7 @@ def test_fit_returns_self(estimator_instance):
def test_raises_not_fitted_error(estimator_instance):
"""Check that we raise appropriate error for unfitted estimators."""
estimator = estimator_instance

# call methods without prior fitting and check that they raise our
# NotFittedError
for method in NON_STATE_CHANGING_METHODS:
Expand Down
28 changes: 13 additions & 15 deletions sktime/transformations/base.py
Expand Up @@ -318,11 +318,8 @@ def transform(self, X, y=None, Z=None):
"""
X = _handle_alias(X, Z)

# check whether is fitted, unless fit-in-transform is true
if self.get_tag("fit-in-transform"):
self.fit(X=X, y=y, Z=Z)
else:
self.check_is_fitted()
# check whether is fitted
self.check_is_fitted()

# input checks and minor coercions on X, y
###########################################
Expand Down Expand Up @@ -400,16 +397,14 @@ def transform(self, X, y=None, Z=None):
)
X = convert_to(X, to_type="df-list", as_scitype="Panel")

if self.get_tag("fit-in-transform"):
Xt = [clone(self).transform(Xi) for Xi in X]
# these are the transformers-per-instanced, fitted in fit
transformers = self.transformers_
if len(transformers) != len(X):
raise RuntimeError(
"found different number of instances in transform than in fit"
)
else:
transformers = self.transformers_
if len(transformers) != len(X):
raise RuntimeError(
"found different number of instances in transform than in fit"
)
else:
Xt = [transformers[i].transform(X[i]) for i in range(len(X))]
Xt = [transformers[i].transform(X[i]) for i in range(len(X))]
# now we have a list of transformed instances

# if the output is Series, Xt is a Panel and we convert back
Expand All @@ -420,7 +415,7 @@ def transform(self, X, y=None, Z=None):
# we concatenate those and overwrite the index with that of X
elif output_scitype == "Primitives":
Xt = pd.concat(Xt)
Xt.index = X.index
Xt = Xt.reset_index(drop=True)
return Xt

# convert X/y to supported inner type, if necessary
Expand Down Expand Up @@ -476,6 +471,9 @@ def transform(self, X, y=None, Z=None):
)
elif output_scitype == "Primitives":
# we "abuse" the Series converter to ensure df output
# & reset index to have integers for instances
if isinstance(Xt, (pd.DataFrame, pd.Series)):
Xt = Xt.reset_index(drop=True)
Xt = convert_to(
Xt,
to_type="pd.DataFrame",
Expand Down
21 changes: 1 addition & 20 deletions sktime/transformations/series/summarize.py
Expand Up @@ -10,7 +10,6 @@

from sktime.transformations.base import _SeriesToPrimitivesTransformer
from sktime.utils._maint import deprecated
from sktime.utils.validation.series import check_series

ALLOWED_SUM_FUNCS = [
"mean",
Expand Down Expand Up @@ -133,6 +132,7 @@ class SummaryTransformer(_SeriesToPrimitivesTransformer):

_tags = {
"fit-in-transform": True,
"X_inner_mtype": ["pd.DataFrame", "pd.Series"],
}

def __init__(
Expand Down Expand Up @@ -176,25 +176,6 @@ def _transform(self, Z, X=None):

return summary_value.T

def transform(self, Z, X=None):
"""Transform series.

Parameters
----------
Z : pd.Series or pd.DataFrame
The series to transform.

Returns
-------
summary_value : pd.DataFrame
DataFrame where series are instances (rows) and calculated summary
values are treated as features (columns).
"""
self.check_is_fitted()
Z = check_series(Z)
summary_value = self._transform(Z, X=X)
return summary_value


class MeanTransformer(SummaryTransformer):
"""Calculate mean value of a time series.
Expand Down