Skip to content

Add logging to TSDataset.make_future #555

Merged
merged 7 commits into from
Feb 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Update CONTRIBUTING.md ([#536](https://github.com/tinkoff-ai/etna/pull/536))
-
- Rename `_CatBoostModel`, `_HoltWintersModel`, `_SklearnModel` ([#543](https://github.com/tinkoff-ai/etna/pull/543))
-
- Add logging to TSDataset.make_future, log repr of transform instead of class name ([#555](https://github.com/tinkoff-ai/etna/pull/555))
- Rename `_SARIMAXModel` and `_ProphetModel`, make `SARIMAXModel` and `ProphetModel` inherit from `PerSegmentPredictionIntervalModel` ([#549](https://github.com/tinkoff-ai/etna/pull/549))
-
-
Expand Down
4 changes: 2 additions & 2 deletions etna/core/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ def __repr__(self):
continue
elif param.kind == param.VAR_KEYWORD:
for arg_, value in self.__dict__[arg].items():
args_str_representation += f"{arg_} = {value.__repr__()}, "
args_str_representation += f"{arg_} = {repr(value)}, "
else:
try:
value = self.__dict__[arg]
except KeyError as e:
value = None
warnings.warn(f"You haven't set all parameters inside class __init__ method: {e}")
args_str_representation += f"{arg} = {value.__repr__()}, "
args_str_representation += f"{arg} = {repr(value)}, "
return f"{self.__class__.__name__}({args_str_representation})"


Expand Down
5 changes: 3 additions & 2 deletions etna/datasets/tsdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def transform(self, transforms: Sequence["Transform"]):
self._check_endings(warning=True)
self.transforms = transforms
for transform in self.transforms:
tslogger.log(f"Transform {transform.__class__.__name__} is applied to dataset")
tslogger.log(f"Transform {repr(transform)} is applied to dataset")
columns_before = set(self.columns.get_level_values("feature"))
self.df = transform.transform(self.df)
columns_after = set(self.columns.get_level_values("feature"))
Expand All @@ -145,7 +145,7 @@ def fit_transform(self, transforms: Sequence["Transform"]):
self._check_endings(warning=True)
self.transforms = transforms
for transform in self.transforms:
tslogger.log(f"Transform {transform.__class__.__name__} is applied to dataset")
tslogger.log(f"Transform {repr(transform)} is applied to dataset")
columns_before = set(self.columns.get_level_values("feature"))
self.df = transform.fit_transform(self.df)
columns_after = set(self.columns.get_level_values("feature"))
Expand Down Expand Up @@ -288,6 +288,7 @@ def make_future(self, future_steps: int) -> "TSDataset":

if self.transforms is not None:
for transform in self.transforms:
tslogger.log(f"Transform {repr(transform)} is applied to dataset")
df = transform.transform(df)

future_dataset = df.tail(future_steps).copy(deep=True)
Expand Down
45 changes: 39 additions & 6 deletions tests/test_loggers/test_console_logger.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from tempfile import NamedTemporaryFile
from typing import Sequence

import pytest
from loguru import logger as _logger
Expand All @@ -16,20 +17,50 @@
from etna.transforms import AddConstTransform
from etna.transforms import DateFlagsTransform
from etna.transforms import LagTransform
from etna.transforms import Transform


def check_logged_transforms(log_file: str, transforms: Sequence[Transform]):
"""Check that transforms are logged into the file."""
with open(log_file, "r") as in_file:
lines = in_file.readlines()
assert len(lines) == len(transforms)
for line, transform in zip(lines, transforms):
assert transform.__class__.__name__ in line


def test_tsdataset_transform_logging(example_tsds: TSDataset):
"""Check working of logging inside `TSDataset.transform`."""
transforms = [LagTransform(lags=5, in_column="target"), AddConstTransform(value=5, in_column="target")]
file = NamedTemporaryFile()
_logger.add(file.name)
example_tsds.fit_transform(transforms=transforms)
idx = tslogger.add(ConsoleLogger())
julia-shenshina marked this conversation as resolved.
Show resolved Hide resolved
example_tsds.transform(transforms=example_tsds.transforms)
check_logged_transforms(log_file=file.name, transforms=transforms)
tslogger.remove(idx)


def test_tsdataset_fit_transform_logging(example_tsds: TSDataset):
"""Check working of logging inside fit_transform of TSDataset."""
"""Check working of logging inside `TSDataset.fit_transform`."""
transforms = [LagTransform(lags=5, in_column="target"), AddConstTransform(value=5, in_column="target")]
file = NamedTemporaryFile()
_logger.add(file.name)
idx = tslogger.add(ConsoleLogger())
example_tsds.fit_transform(transforms=transforms)
with open(file.name, "r") as in_file:
lines = in_file.readlines()
assert len(lines) == len(transforms)
for line, transform in zip(lines, transforms):
assert transform.__class__.__name__ in line
check_logged_transforms(log_file=file.name, transforms=transforms)
tslogger.remove(idx)


def test_tsdataset_make_future_logging(example_tsds: TSDataset):
"""Check working of logging inside `TSDataset.make_future`."""
transforms = [LagTransform(lags=5, in_column="target"), AddConstTransform(value=5, in_column="target")]
file = NamedTemporaryFile()
_logger.add(file.name)
example_tsds.fit_transform(transforms=transforms)
idx = tslogger.add(ConsoleLogger())
_ = example_tsds.make_future(5)
check_logged_transforms(log_file=file.name, transforms=transforms)
tslogger.remove(idx)


Expand Down Expand Up @@ -88,6 +119,8 @@ def test_model_logging(example_tsds, model):

with open(file.name, "r") as in_file:
lines = in_file.readlines()
# filter out logs related to transforms
lines = [line for line in lines if lags.__class__.__name__ not in line]
assert len(lines) == 2
assert "fit" in lines[0]
assert "forecast" in lines[1]
Expand Down