Skip to content

Add logging to TSDataset.make_future #555

Merged
merged 7 commits into from
Feb 24, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Update CONTRIBUTING.md ([#536](https://github.com/tinkoff-ai/etna/pull/536))
-
- Rename `_CatBoostModel`, `_HoltWintersModel`, `_SklearnModel` ([#543](https://github.com/tinkoff-ai/etna/pull/543))
-
- Add logging to TSDataset.make_future, log repr of transform instead of class name ([#555](https://github.com/tinkoff-ai/etna/pull/555))
- Rename `_SARIMAXModel` and `_ProphetModel`, make `SARIMAXModel` and `ProphetModel` inherit from `PerSegmentPredictionIntervalModel` ([#549](https://github.com/tinkoff-ai/etna/pull/549))
-
-
Expand Down
4 changes: 2 additions & 2 deletions etna/core/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ def __repr__(self):
continue
elif param.kind == param.VAR_KEYWORD:
for arg_, value in self.__dict__[arg].items():
args_str_representation += f"{arg_} = {value.__repr__()}, "
args_str_representation += f"{arg_} = {repr(value)}, "
else:
try:
value = self.__dict__[arg]
except KeyError as e:
value = None
warnings.warn(f"You haven't set all parameters inside class __init__ method: {e}")
args_str_representation += f"{arg} = {value.__repr__()}, "
args_str_representation += f"{arg} = {repr(value)}, "
return f"{self.__class__.__name__}({args_str_representation})"


Expand Down
5 changes: 3 additions & 2 deletions etna/datasets/tsdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def transform(self, transforms: Sequence["Transform"]):
self._check_endings(warning=True)
self.transforms = transforms
for transform in self.transforms:
tslogger.log(f"Transform {transform.__class__.__name__} is applied to dataset")
tslogger.log(f"Transform {repr(transform)} is applied to dataset")
columns_before = set(self.columns.get_level_values("feature"))
self.df = transform.transform(self.df)
columns_after = set(self.columns.get_level_values("feature"))
Expand All @@ -145,7 +145,7 @@ def fit_transform(self, transforms: Sequence["Transform"]):
self._check_endings(warning=True)
self.transforms = transforms
for transform in self.transforms:
tslogger.log(f"Transform {transform.__class__.__name__} is applied to dataset")
tslogger.log(f"Transform {repr(transform)} is applied to dataset")
columns_before = set(self.columns.get_level_values("feature"))
self.df = transform.fit_transform(self.df)
columns_after = set(self.columns.get_level_values("feature"))
Expand Down Expand Up @@ -288,6 +288,7 @@ def make_future(self, future_steps: int) -> "TSDataset":

if self.transforms is not None:
for transform in self.transforms:
tslogger.log(f"Transform {repr(transform)} is applied to dataset")
df = transform.transform(df)

future_dataset = df.tail(future_steps).copy(deep=True)
Expand Down
45 changes: 39 additions & 6 deletions tests/test_loggers/test_console_logger.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from tempfile import NamedTemporaryFile
from typing import Sequence

import pytest
from loguru import logger as _logger
Expand All @@ -16,20 +17,50 @@
from etna.transforms import AddConstTransform
from etna.transforms import DateFlagsTransform
from etna.transforms import LagTransform
from etna.transforms import Transform


def check_logged_transforms(log_file: str, transforms: Sequence[Transform]):
"""Check that transforms are logged into the file."""
with open(log_file, "r") as in_file:
lines = in_file.readlines()
assert len(lines) == len(transforms)
for line, transform in zip(lines, transforms):
assert transform.__class__.__name__ in line


def test_tsdataset_transform_logging(example_tsds: TSDataset):
"""Check working of logging inside `TSDataset.transform`."""
transforms = [LagTransform(lags=5, in_column="target"), AddConstTransform(value=5, in_column="target")]
file = NamedTemporaryFile()
_logger.add(file.name)
example_tsds.fit_transform(transforms=transforms)
idx = tslogger.add(ConsoleLogger())
julia-shenshina marked this conversation as resolved.
Show resolved Hide resolved
example_tsds.transform(transforms=example_tsds.transforms)
check_logged_transforms(log_file=file.name, transforms=transforms)
tslogger.remove(idx)


def test_tsdataset_fit_transform_logging(example_tsds: TSDataset):
"""Check working of logging inside fit_transform of TSDataset."""
"""Check working of logging inside `TSDataset.fit_transform`."""
transforms = [LagTransform(lags=5, in_column="target"), AddConstTransform(value=5, in_column="target")]
file = NamedTemporaryFile()
_logger.add(file.name)
idx = tslogger.add(ConsoleLogger())
example_tsds.fit_transform(transforms=transforms)
with open(file.name, "r") as in_file:
lines = in_file.readlines()
assert len(lines) == len(transforms)
for line, transform in zip(lines, transforms):
assert transform.__class__.__name__ in line
check_logged_transforms(log_file=file.name, transforms=transforms)
tslogger.remove(idx)


def test_tsdataset_make_future_logging(example_tsds: TSDataset):
"""Check working of logging inside `TSDataset.make_future`."""
transforms = [LagTransform(lags=5, in_column="target"), AddConstTransform(value=5, in_column="target")]
file = NamedTemporaryFile()
_logger.add(file.name)
example_tsds.fit_transform(transforms=transforms)
idx = tslogger.add(ConsoleLogger())
_ = example_tsds.make_future(5)
check_logged_transforms(log_file=file.name, transforms=transforms)
tslogger.remove(idx)


Expand Down Expand Up @@ -88,6 +119,8 @@ def test_model_logging(example_tsds, model):

with open(file.name, "r") as in_file:
lines = in_file.readlines()
# filter out logs related to transforms
lines = [line for line in lines if lags.__class__.__name__ not in line]
assert len(lines) == 2
assert "fit" in lines[0]
assert "forecast" in lines[1]
Expand Down