Skip to content

Make in_column the first argument in every transform #247

Merged
merged 5 commits into from
Nov 2, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed
- Add possibility to set custom in_column for ConfidenceIntervalOutliersTransform ([#240](https://github.com/tinkoff-ai/etna-ts/pull/240))
- Make `in_column` the first argument in every transform ([#247](https://github.com/tinkoff-ai/etna-ts/pull/247))

### Fixed
- Fixed broken links in docs command section ([#223](https://github.com/tinkoff-ai/etna-ts/pull/223))
Expand Down
18 changes: 9 additions & 9 deletions etna/transforms/add_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,21 @@


class _OneSegmentAddConstTransform(Transform):
def __init__(self, value: float, in_column: str, inplace: bool = True):
def __init__(self, in_column: str, value: float, inplace: bool = True):
"""
Init _OneSegmentAddConstTransform.

Parameters
----------
value:
value that should be added to the series
in_column:
column to apply transform
value:
value that should be added to the series
inplace:
if True, apply add constant transformation inplace to in_column, if False, add column {in_column}_add_{value} to dataset
"""
self.value = value
self.in_column = in_column
self.value = value
self.inplace = inplace
self.out_column = self.in_column if self.inplace else f"{self.in_column}_add_{self.value}"

Expand Down Expand Up @@ -72,24 +72,24 @@ def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
class AddConstTransform(PerSegmentWrapper):
"""AddConstTransform add constant for given series."""

def __init__(self, value: float, in_column: str, inplace: bool = True):
def __init__(self, in_column: str, value: float, inplace: bool = True):
"""
Init AddConstTransform.

Parameters
----------
value:
value that should be added to the series
in_column:
column to apply transform
value:
value that should be added to the series
inplace:
if True, apply add constant transformation inplace to in_column, if False, add column {in_column}_add_{value} to dataset
"""
self.value = value
self.in_column = in_column
self.value = value
self.inplace = inplace
super().__init__(
transform=_OneSegmentAddConstTransform(value=self.value, in_column=self.in_column, inplace=self.inplace)
transform=_OneSegmentAddConstTransform(in_column=self.in_column, value=self.value, inplace=self.inplace)
)


Expand Down
15 changes: 7 additions & 8 deletions etna/transforms/lags.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@


class _OneSegmentLagFeature(Transform):
def __init__(self, lags: Union[List[int], int], in_column: str):
def __init__(self, in_column: str, lags: Union[List[int], int]):
self.in_column = in_column
if isinstance(lags, int):
if lags < 1:
raise ValueError(f"{type(self).__name__} works only with positive lags values, {lags} given")
Expand All @@ -18,7 +19,6 @@ def __init__(self, lags: Union[List[int], int], in_column: str):
raise ValueError(f"{type(self).__name__} works only with positive lags values")
self.lags = lags

self.in_column = in_column
self.out_postfix = "_lag"
self.out_prefix = "regressor_"

Expand All @@ -35,21 +35,20 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
class LagTransform(PerSegmentWrapper):
"""Generates series of lags from given dataframe. Creates columns 'regressor_<column>_lag_<number>'."""

def __init__(self, lags: Union[List[int], int], in_column: str):
def __init__(self, in_column: str, lags: Union[List[int], int]):
"""Create instance of LagTransform.

Parameters
----------
lags:
int value or list of values for lags computation; if int, generate range of lags from 1 to given value
in_column:
name of processed column

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we delete this space?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we shouldn't)

lags:
int value or list of values for lags computation; if int, generate range of lags from 1 to given value
Raises
------
ValueError:
if lags value contains non-positive values
"""
self.lags = lags
self.in_column = in_column
super().__init__(transform=_OneSegmentLagFeature(lags=self.lags, in_column=self.in_column))
self.lags = lags
super().__init__(transform=_OneSegmentLagFeature(in_column=self.in_column, lags=self.lags))
2 changes: 1 addition & 1 deletion etna/transforms/scalers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def __init__(
if incorrect mode given
"""
super().__init__(
transformer=StandardScaler(with_mean=with_mean, with_std=with_std, copy=True),
in_column=in_column,
transformer=StandardScaler(with_mean=with_mean, with_std=with_std, copy=True),
inplace=inplace,
mode=mode,
)
Expand Down
6 changes: 3 additions & 3 deletions etna/transforms/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class SklearnTransform(Transform):

def __init__(
self,
in_column: Union[str, List[str]],
transformer: TransformerMixin,
in_column: Optional[Union[str, List[str]]] = None,
inplace: bool = True,
mode: Union[TransformMode, str] = "per-segment",
):
Expand All @@ -32,10 +32,10 @@ def __init__(

Parameters
----------
transformer:
sklearn.base.TransformerMixin instance.
in_column:
columns to be transformed, if None - all columns will be scaled.
transformer:
sklearn.base.TransformerMixin instance.
inplace:
features are changed by transformed.
mode:
Expand Down
18 changes: 12 additions & 6 deletions etna/transforms/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ class WindowStatisticsTransform(Transform, ABC):

def __init__(
self,
window: int,
in_column: str,
window: int,
seasonality: int = 1,
min_periods: int = 1,
out_postfix: Optional[str] = None,
Expand All @@ -27,6 +27,8 @@ def __init__(

Parameters
----------
in_column: str
name of processed column
window: int
size of window to aggregate
seasonality: int
Expand All @@ -39,6 +41,7 @@ def __init__(
fillna: float
value to fill results NaNs with
"""
self.in_column = in_column
self.window = window
self.seasonality = seasonality
self.min_periods = min_periods
Expand All @@ -47,7 +50,6 @@ def __init__(
self.kwargs = kwargs
self.min_required_len = max(self.min_periods - 1, 0) * self.seasonality + 1
self.history = self.window * self.seasonality
self.in_column = in_column

def fit(self, *args) -> "WindowStatisticsTransform":
"""Fits transform."""
Expand Down Expand Up @@ -111,8 +113,8 @@ class MeanTransform(WindowStatisticsTransform):

def __init__(
self,
window: int,
in_column: str,
window: int,
seasonality: int = 1,
alpha: float = 1,
min_periods: int = 1,
Expand All @@ -123,6 +125,8 @@ def __init__(

Parameters
----------
in_column: str
name of processed column
window: int
size of window to aggregate
seasonality: int
Expand All @@ -138,8 +142,8 @@ def __init__(
value to fill results NaNs with
"""
super().__init__(
window=window,
in_column=in_column,
window=window,
seasonality=seasonality,
min_periods=min_periods,
out_postfix=out_postfix,
Expand Down Expand Up @@ -196,9 +200,9 @@ class QuantileTransform(WindowStatisticsTransform):

def __init__(
self,
in_column: str,
quantile: float,
window: int,
in_column: str,
seasonality: int = 1,
min_periods: int = 1,
out_postfix: Optional[str] = None,
Expand All @@ -208,6 +212,8 @@ def __init__(

Parameters
----------
in_column: str
name of processed column
quantile: float
quantile to calculate
window: int
Expand All @@ -224,8 +230,8 @@ def __init__(
"""
self.quantile = quantile
super().__init__(
window=window,
in_column=in_column,
window=window,
seasonality=seasonality,
min_periods=min_periods,
out_postfix=out_postfix,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_transforms/test_lag_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_repr():
lags = list(range(8, 24, 1))
transform = LagTransform(lags=lags, in_column="target")
transform_repr = transform.__repr__()
true_repr = f"{transform_class_repr}(lags = {lags}, in_column = 'target', )"
true_repr = f"{transform_class_repr}(in_column = 'target', lags = {lags}, )"
assert transform_repr == true_repr


Expand Down