diff --git a/utils/loggers/__init__.py b/utils/loggers/__init__.py index 0750be6c8828..b698c3d2db45 100644 --- a/utils/loggers/__init__.py +++ b/utils/loggers/__init__.py @@ -3,9 +3,11 @@ Logging utils """ +import os import warnings from threading import Thread +import pkg_resources as pkg import torch from torch.utils.tensorboard import SummaryWriter @@ -15,11 +17,16 @@ from utils.torch_utils import de_parallel LOGGERS = ('csv', 'tb', 'wandb') # text-file, TensorBoard, Weights & Biases +RANK = int(os.getenv('RANK', -1)) try: import wandb assert hasattr(wandb, '__version__') # verify package import not local dir + if pkg.parse_version(wandb.__version__) >= pkg.parse_version('0.12.2') and RANK in [0, -1]: + wandb_login_success = wandb.login(timeout=30) + if not wandb_login_success: + wandb = None except (ImportError, AssertionError): wandb = None diff --git a/utils/loggers/wandb/wandb_utils.py b/utils/loggers/wandb/wandb_utils.py index 92fdd27bb004..39b802e9655e 100644 --- a/utils/loggers/wandb/wandb_utils.py +++ b/utils/loggers/wandb/wandb_utils.py @@ -20,16 +20,6 @@ from utils.general import check_dataset, check_file RANK = int(os.getenv('RANK', -1)) - -try: - import wandb - - assert hasattr(wandb, '__version__') # verify package import not local dir - if pkg.parse_version(wandb.__version__) >= pkg.parse_version('0.12.2') and RANK in [0, -1]: - wandb.login(timeout=30) -except (ImportError, AssertionError): - wandb = None - WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'