diff --git a/CHANGELOG.md b/CHANGELOG.md index 73a194f5f..ca57b5d0a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,7 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Update CONTRIBUTING.md ([#536](https://github.com/tinkoff-ai/etna/pull/536)) - - Rename `_CatBoostModel`, `_HoltWintersModel`, `_SklearnModel` ([#543](https://github.com/tinkoff-ai/etna/pull/543)) -- +- Add logging to TSDataset.make_future, log repr of transform instead of class name ([#555](https://github.com/tinkoff-ai/etna/pull/555)) - Rename `_SARIMAXModel` and `_ProphetModel`, make `SARIMAXModel` and `ProphetModel` inherit from `PerSegmentPredictionIntervalModel` ([#549](https://github.com/tinkoff-ai/etna/pull/549)) - - diff --git a/etna/core/mixins.py b/etna/core/mixins.py index f8ac4e8e9..8f7c4a061 100644 --- a/etna/core/mixins.py +++ b/etna/core/mixins.py @@ -16,14 +16,14 @@ def __repr__(self): continue elif param.kind == param.VAR_KEYWORD: for arg_, value in self.__dict__[arg].items(): - args_str_representation += f"{arg_} = {value.__repr__()}, " + args_str_representation += f"{arg_} = {repr(value)}, " else: try: value = self.__dict__[arg] except KeyError as e: value = None warnings.warn(f"You haven't set all parameters inside class __init__ method: {e}") - args_str_representation += f"{arg} = {value.__repr__()}, " + args_str_representation += f"{arg} = {repr(value)}, " return f"{self.__class__.__name__}({args_str_representation})" diff --git a/etna/datasets/tsdataset.py b/etna/datasets/tsdataset.py index 4f87deac1..1abff2230 100644 --- a/etna/datasets/tsdataset.py +++ b/etna/datasets/tsdataset.py @@ -134,7 +134,7 @@ def transform(self, transforms: Sequence["Transform"]): self._check_endings(warning=True) self.transforms = transforms for transform in self.transforms: - tslogger.log(f"Transform {transform.__class__.__name__} is applied to dataset") + tslogger.log(f"Transform {repr(transform)} is applied to dataset") columns_before = set(self.columns.get_level_values("feature")) self.df = transform.transform(self.df) columns_after = set(self.columns.get_level_values("feature")) @@ -145,7 +145,7 @@ def fit_transform(self, transforms: Sequence["Transform"]): self._check_endings(warning=True) self.transforms = transforms for transform in self.transforms: - tslogger.log(f"Transform {transform.__class__.__name__} is applied to dataset") + tslogger.log(f"Transform {repr(transform)} is applied to dataset") columns_before = set(self.columns.get_level_values("feature")) self.df = transform.fit_transform(self.df) columns_after = set(self.columns.get_level_values("feature")) @@ -288,6 +288,7 @@ def make_future(self, future_steps: int) -> "TSDataset": if self.transforms is not None: for transform in self.transforms: + tslogger.log(f"Transform {repr(transform)} is applied to dataset") df = transform.transform(df) future_dataset = df.tail(future_steps).copy(deep=True) diff --git a/tests/test_loggers/test_console_logger.py b/tests/test_loggers/test_console_logger.py index 6b22127bd..cd739b69e 100644 --- a/tests/test_loggers/test_console_logger.py +++ b/tests/test_loggers/test_console_logger.py @@ -1,4 +1,5 @@ from tempfile import NamedTemporaryFile +from typing import Sequence import pytest from loguru import logger as _logger @@ -16,20 +17,50 @@ from etna.transforms import AddConstTransform from etna.transforms import DateFlagsTransform from etna.transforms import LagTransform +from etna.transforms import Transform + + +def check_logged_transforms(log_file: str, transforms: Sequence[Transform]): + """Check that transforms are logged into the file.""" + with open(log_file, "r") as in_file: + lines = in_file.readlines() + assert len(lines) == len(transforms) + for line, transform in zip(lines, transforms): + assert transform.__class__.__name__ in line + + +def test_tsdataset_transform_logging(example_tsds: TSDataset): + """Check working of logging inside `TSDataset.transform`.""" + transforms = [LagTransform(lags=5, in_column="target"), AddConstTransform(value=5, in_column="target")] + file = NamedTemporaryFile() + _logger.add(file.name) + example_tsds.fit_transform(transforms=transforms) + idx = tslogger.add(ConsoleLogger()) + example_tsds.transform(transforms=example_tsds.transforms) + check_logged_transforms(log_file=file.name, transforms=transforms) + tslogger.remove(idx) def test_tsdataset_fit_transform_logging(example_tsds: TSDataset): - """Check working of logging inside fit_transform of TSDataset.""" + """Check working of logging inside `TSDataset.fit_transform`.""" transforms = [LagTransform(lags=5, in_column="target"), AddConstTransform(value=5, in_column="target")] file = NamedTemporaryFile() _logger.add(file.name) idx = tslogger.add(ConsoleLogger()) example_tsds.fit_transform(transforms=transforms) - with open(file.name, "r") as in_file: - lines = in_file.readlines() - assert len(lines) == len(transforms) - for line, transform in zip(lines, transforms): - assert transform.__class__.__name__ in line + check_logged_transforms(log_file=file.name, transforms=transforms) + tslogger.remove(idx) + + +def test_tsdataset_make_future_logging(example_tsds: TSDataset): + """Check working of logging inside `TSDataset.make_future`.""" + transforms = [LagTransform(lags=5, in_column="target"), AddConstTransform(value=5, in_column="target")] + file = NamedTemporaryFile() + _logger.add(file.name) + example_tsds.fit_transform(transforms=transforms) + idx = tslogger.add(ConsoleLogger()) + _ = example_tsds.make_future(5) + check_logged_transforms(log_file=file.name, transforms=transforms) tslogger.remove(idx) @@ -88,6 +119,8 @@ def test_model_logging(example_tsds, model): with open(file.name, "r") as in_file: lines = in_file.readlines() + # filter out logs related to transforms + lines = [line for line in lines if lags.__class__.__name__ not in line] assert len(lines) == 2 assert "fit" in lines[0] assert "forecast" in lines[1]