Skip to content

Commit

Permalink
Wandb Logger does not work unless pytorch is installed (#340)
Browse files Browse the repository at this point in the history
  • Loading branch information
martins0n committed Dec 2, 2021
1 parent 7a66594 commit dbaa95b
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Speed up inference for multisegment regression models ([#333](https://github.com/tinkoff-ai/etna/pull/333))
- Speed up Pipeline._get_backtest_forecasts ([#336](https://github.com/tinkoff-ai/etna/pull/336))
- Speed up SegmentEncoderTransform ([#331](https://github.com/tinkoff-ai/etna/pull/331))
- Wandb Logger does not work unless pytorch is installed ([#340](https://github.com/tinkoff-ai/etna/pull/340))

### Fixed
- Get rid of lambda in DensityOutliersTransform and get_anomalies_density ([#341](https://github.com/tinkoff-ai/etna/pull/341))
Expand Down
2 changes: 1 addition & 1 deletion etna/loggers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from etna.loggers.base import _Logger
from etna.loggers.console_logger import ConsoleLogger

if SETTINGS.wandb_required and SETTINGS.torch_required:
if SETTINGS.wandb_required:
from etna.loggers.wandb_logger import WandbLogger

tslogger = _Logger()
17 changes: 8 additions & 9 deletions etna/loggers/wandb_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,13 @@
from etna.loggers.base import BaseLogger

if TYPE_CHECKING:
from pytorch_lightning.loggers import WandbLogger as PLWandbLogger

from etna.datasets import TSDataset

if SETTINGS.wandb_required:
import wandb

if SETTINGS.torch_required:
from pytorch_lightning.loggers import WandbLogger as PLWandbLogger
else:
PLWandbLogger = None # type: ignore


def percentile(n: int):
"""Percentile for pandas agg."""
Expand All @@ -50,6 +47,7 @@ def __init__(
table: bool = True,
name_prefix: str = "",
config: Optional[Union[Dict, str, None]] = None,
log_model: bool = False,
):
"""
Create instance of WandbLogger.
Expand Down Expand Up @@ -90,12 +88,13 @@ def __init__(
self.group = group
self.config = config
self._experiment = None
self._pl_logger: Optional[PLWandbLogger] = None
self._pl_logger: Optional["PLWandbLogger"] = None
self.job_type = job_type
self.tags = tags
self.plot = plot
self.table = table
self.name_prefix = name_prefix
self.log_model = log_model

def log(self, msg: Union[str, Dict[str, Any]], **kwargs):
"""
Expand Down Expand Up @@ -211,7 +210,6 @@ def start_experiment(self, job_type: Optional[str] = None, group: Optional[str]
self.job_type = job_type
self.group = group
self.reinit_experiment()
self._pl_logger = PLWandbLogger(experiment=self.experiment)

def reinit_experiment(self):
"""Reinit experiment."""
Expand All @@ -234,13 +232,14 @@ def finish_experiment(self):
@property
def pl_logger(self):
"""Pytorch lightning loggers."""
self._pl_logger = PLWandbLogger(experiment=self.experiment, log_model=True)
from pytorch_lightning.loggers import WandbLogger as PLWandbLogger

self._pl_logger = PLWandbLogger(experiment=self.experiment, log_model=self.log_model)
return self._pl_logger

@property
def experiment(self):
"""Init experiment."""
if self._experiment is None:
self.reinit_experiment()
self._pl_logger = PLWandbLogger(experiment=self.experiment)
return self._experiment

0 comments on commit dbaa95b

Please sign in to comment.