Skip to content

Fix warning during creation of ResampleWithDistributionTransform #1230

Merged
merged 3 commits into from
Apr 19, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
-
-
- Fix warning during creation of `ResampleWithDistributionTransform` ([#1230](https://github.com/tinkoff-ai/etna/pull/1230))

## [2.0.0] - 2023-04-11
### Added
Expand Down
20 changes: 12 additions & 8 deletions etna/transforms/missing_values/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,26 +140,30 @@ def __init__(
self.in_column = in_column
self.distribution_column = distribution_column
self.inplace = inplace
self.out_column = self._get_out_column(out_column)
self.out_column = out_column
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
self.in_column_regressor: Optional[bool] = None

if self.inplace and out_column:
warnings.warn("Transformation will be applied inplace, out_column param will be ignored")

super().__init__(
transform=_OneSegmentResampleWithDistributionTransform(
in_column=in_column,
distribution_column=distribution_column,
inplace=inplace,
out_column=self.out_column,
out_column=self._get_column_name(),
),
required_features=[in_column, distribution_column],
)

def _get_out_column(self, out_column: Optional[str]) -> str:
def _get_column_name(
self,
) -> str:
"""Get the `out_column` depending on the transform's parameters."""
if self.inplace and out_column:
warnings.warn("Transformation will be applied inplace, out_column param will be ignored")
if self.inplace:
return self.in_column
if out_column:
return out_column
if self.out_column:
return self.out_column
return self.__repr__()

def get_regressors_info(self) -> List[str]:
Expand All @@ -168,7 +172,7 @@ def get_regressors_info(self) -> List[str]:
raise ValueError("Fit the transform to get the correct regressors info!")
if self.inplace:
return []
return [self.out_column] if self.in_column_regressor else []
return [self._get_column_name()] if self.in_column_regressor else []

def fit(self, ts: TSDataset) -> "ResampleWithDistributionTransform":
"""Fit the transform."""
Expand Down