From dbaa95b046f289028d6df36558f8c77c80d2a1d8 Mon Sep 17 00:00:00 2001 From: Martin Gabdushev <33594071+martins0n@users.noreply.github.com> Date: Thu, 2 Dec 2021 15:59:52 +0300 Subject: [PATCH] Wandb Logger does not work unless pytorch is installed (#340) --- CHANGELOG.md | 1 + etna/loggers/__init__.py | 2 +- etna/loggers/wandb_logger.py | 17 ++++++++--------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 519e744d1..26e39e583 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Speed up inference for multisegment regression models ([#333](https://github.com/tinkoff-ai/etna/pull/333)) - Speed up Pipeline._get_backtest_forecasts ([#336](https://github.com/tinkoff-ai/etna/pull/336)) - Speed up SegmentEncoderTransform ([#331](https://github.com/tinkoff-ai/etna/pull/331)) +- Wandb Logger does not work unless pytorch is installed ([#340](https://github.com/tinkoff-ai/etna/pull/340)) ### Fixed - Get rid of lambda in DensityOutliersTransform and get_anomalies_density ([#341](https://github.com/tinkoff-ai/etna/pull/341)) diff --git a/etna/loggers/__init__.py b/etna/loggers/__init__.py index 36d590726..b597a5181 100644 --- a/etna/loggers/__init__.py +++ b/etna/loggers/__init__.py @@ -16,7 +16,7 @@ from etna.loggers.base import _Logger from etna.loggers.console_logger import ConsoleLogger -if SETTINGS.wandb_required and SETTINGS.torch_required: +if SETTINGS.wandb_required: from etna.loggers.wandb_logger import WandbLogger tslogger = _Logger() diff --git a/etna/loggers/wandb_logger.py b/etna/loggers/wandb_logger.py index 15ba060d7..7acbeba6e 100644 --- a/etna/loggers/wandb_logger.py +++ b/etna/loggers/wandb_logger.py @@ -14,16 +14,13 @@ from etna.loggers.base import BaseLogger if TYPE_CHECKING: + from pytorch_lightning.loggers import WandbLogger as PLWandbLogger + from etna.datasets import TSDataset if SETTINGS.wandb_required: import wandb -if SETTINGS.torch_required: - from pytorch_lightning.loggers import WandbLogger as PLWandbLogger -else: - PLWandbLogger = None # type: ignore - def percentile(n: int): """Percentile for pandas agg.""" @@ -50,6 +47,7 @@ def __init__( table: bool = True, name_prefix: str = "", config: Optional[Union[Dict, str, None]] = None, + log_model: bool = False, ): """ Create instance of WandbLogger. @@ -90,12 +88,13 @@ def __init__( self.group = group self.config = config self._experiment = None - self._pl_logger: Optional[PLWandbLogger] = None + self._pl_logger: Optional["PLWandbLogger"] = None self.job_type = job_type self.tags = tags self.plot = plot self.table = table self.name_prefix = name_prefix + self.log_model = log_model def log(self, msg: Union[str, Dict[str, Any]], **kwargs): """ @@ -211,7 +210,6 @@ def start_experiment(self, job_type: Optional[str] = None, group: Optional[str] self.job_type = job_type self.group = group self.reinit_experiment() - self._pl_logger = PLWandbLogger(experiment=self.experiment) def reinit_experiment(self): """Reinit experiment.""" @@ -234,7 +232,9 @@ def finish_experiment(self): @property def pl_logger(self): """Pytorch lightning loggers.""" - self._pl_logger = PLWandbLogger(experiment=self.experiment, log_model=True) + from pytorch_lightning.loggers import WandbLogger as PLWandbLogger + + self._pl_logger = PLWandbLogger(experiment=self.experiment, log_model=self.log_model) return self._pl_logger @property @@ -242,5 +242,4 @@ def experiment(self): """Init experiment.""" if self._experiment is None: self.reinit_experiment() - self._pl_logger = PLWandbLogger(experiment=self.experiment) return self._experiment