Skip to content

Commit

Permalink
Improve Comet Logger pickled behavior (Lightning-AI#2553)
Browse files Browse the repository at this point in the history
* Improve Comet Logger pickled behavior

* Delay the creation of the actual experiment object for as long as we can.
* Save the experiment id in case an Experiment object is created so we can
  continue the same experiment in the sub-processes.
* Run pre-commit on the comet file.

* Handle review comment

Make most Comet Logger attribute protected as they might not reflect the final
Experiment attributes. Also fix the typo in the test name.

* Ensure that CometLogger.name and CometLogger.version always returns str

* Add new test for CometLogger.version behavior

* Add new tests for CometLogger.name and CometLogger.version

* Apply review suggestions

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Remove extraneous comments in Comet logger tests

* Fix lint issues

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
  • Loading branch information
3 people committed Sep 18, 2020
1 parent 580b04b commit e2af4f1
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 86 deletions.
137 changes: 87 additions & 50 deletions pytorch_lightning/loggers/comet.py
Expand Up @@ -17,6 +17,8 @@
-----
"""

import os

from argparse import Namespace
from typing import Optional, Dict, Union, Any

Expand All @@ -25,6 +27,8 @@
from comet_ml import ExistingExperiment as CometExistingExperiment
from comet_ml import OfflineExperiment as CometOfflineExperiment
from comet_ml import BaseExperiment as CometBaseExperiment
from comet_ml import generate_guid

try:
from comet_ml.api import API
except ImportError: # pragma: no-cover
Expand All @@ -37,11 +41,11 @@
CometOfflineExperiment = None
CometBaseExperiment = None
API = None
generate_guid = None
_COMET_AVAILABLE = False
else:
_COMET_AVAILABLE = True


import torch
from torch import is_tensor

Expand Down Expand Up @@ -112,20 +116,24 @@ class CometLogger(LightningLoggerBase):
file but still want to run offline experiments.
"""

def __init__(self,
api_key: Optional[str] = None,
save_dir: Optional[str] = None,
workspace: Optional[str] = None,
project_name: Optional[str] = None,
rest_api_key: Optional[str] = None,
experiment_name: Optional[str] = None,
experiment_key: Optional[str] = None,
offline: bool = False,
**kwargs):
def __init__(
self,
api_key: Optional[str] = None,
save_dir: Optional[str] = None,
workspace: Optional[str] = None,
project_name: Optional[str] = None,
rest_api_key: Optional[str] = None,
experiment_name: Optional[str] = None,
experiment_key: Optional[str] = None,
offline: bool = False,
**kwargs,
):

if not _COMET_AVAILABLE:
raise ImportError('You want to use `comet_ml` logger which is not installed yet,'
' install it with `pip install comet-ml`.')
raise ImportError(
"You want to use `comet_ml` logger which is not installed yet,"
" install it with `pip install comet-ml`."
)
super().__init__()
self._experiment = None

Expand All @@ -145,16 +153,16 @@ def __init__(self,
self._save_dir = save_dir
else:
# If neither api_key nor save_dir are passed as arguments, raise an exception
raise MisconfigurationException(
"CometLogger requires either api_key or save_dir during initialization."
)
raise MisconfigurationException("CometLogger requires either api_key or save_dir during initialization.")

log.info(f"CometLogger will be initialized in {self.mode} mode")

self.workspace = workspace
self.project_name = project_name
self.experiment_key = experiment_key
self._project_name = project_name
self._experiment_key = experiment_key
self._experiment_name = experiment_name
self._kwargs = kwargs
self._future_experiment_key = None

if rest_api_key is not None:
# Comet.ml rest API, used to determine version number
Expand All @@ -164,8 +172,6 @@ def __init__(self,
self.rest_api_key = None
self.comet_api = None

if experiment_name:
self.experiment.set_name(experiment_name)
self._kwargs = kwargs

@property
Expand All @@ -183,30 +189,37 @@ def experiment(self) -> CometBaseExperiment:
if self._experiment is not None:
return self._experiment

if self.mode == "online":
if self.experiment_key is None:
self._experiment = CometExperiment(
api_key=self.api_key,
workspace=self.workspace,
project_name=self.project_name,
**self._kwargs
)
self.experiment_key = self._experiment.get_key()
if self._future_experiment_key is not None:
os.environ["COMET_EXPERIMENT_KEY"] = self._future_experiment_key
self._future_experiment_key = None

try:
if self.mode == "online":
if self._experiment_key is None:
self._experiment = CometExperiment(
api_key=self.api_key, workspace=self.workspace, project_name=self._project_name, **self._kwargs
)
self._experiment_key = self._experiment.get_key()
else:
self._experiment = CometExistingExperiment(
api_key=self.api_key,
workspace=self.workspace,
project_name=self._project_name,
previous_experiment=self._experiment_key,
**self._kwargs,
)
else:
self._experiment = CometExistingExperiment(
api_key=self.api_key,
self._experiment = CometOfflineExperiment(
offline_directory=self.save_dir,
workspace=self.workspace,
project_name=self.project_name,
previous_experiment=self.experiment_key,
**self._kwargs
project_name=self._project_name,
**self._kwargs,
)
else:
self._experiment = CometOfflineExperiment(
offline_directory=self.save_dir,
workspace=self.workspace,
project_name=self.project_name,
**self._kwargs
)
finally:
os.environ.pop("COMET_EXPERIMENT_KEY", None)

if self._experiment_name:
self._experiment.set_name(self._experiment_name)

return self._experiment

Expand All @@ -217,13 +230,8 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
self.experiment.log_parameters(params)

@rank_zero_only
def log_metrics(
self,
metrics: Dict[str, Union[torch.Tensor, float]],
step: Optional[int] = None
) -> None:
assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'

def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Optional[int] = None) -> None:
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
# Comet.ml expects metrics to be a dictionary of detached tensors on CPU
for key, val in metrics.items():
if is_tensor(val):
Expand Down Expand Up @@ -257,13 +265,42 @@ def save_dir(self) -> Optional[str]:

@property
def name(self) -> str:
return str(self.experiment.project_name)
# Don't create an experiment if we don't have one
if self._experiment is not None and self._experiment.project_name is not None:
return self._experiment.project_name

if self._project_name is not None:
return self._project_name

return "comet-default"

@property
def version(self) -> str:
return self.experiment.id
# Don't create an experiment if we don't have one
if self._experiment is not None:
return self._experiment.id

if self._experiment_key is not None:
return self._experiment_key

if self._future_experiment_key is not None:
return self._future_experiment_key

# Pre-generate an experiment key
self._future_experiment_key = generate_guid()

return self._future_experiment_key

def __getstate__(self):
state = self.__dict__.copy()

# Save the experiment id in case an experiment object already exists,
# this way we could create an ExistingExperiment pointing to the same
# experiment
state["_experiment_key"] = self._experiment.id if self._experiment is not None else None

# Remove the experiment object as it contains hard to pickle objects
# (like network connections), the experiment object will be recreated if
# needed later
state["_experiment"] = None
return state
123 changes: 87 additions & 36 deletions tests/loggers/test_comet.py
Expand Up @@ -13,43 +13,23 @@ def test_comet_logger_online():
"""Test comet online with mocks."""
# Test api_key given
with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
logger = CometLogger(
api_key='key',
workspace='dummy-test',
project_name='general'
)
logger = CometLogger(api_key='key', workspace='dummy-test', project_name='general')

_ = logger.experiment

comet.assert_called_once_with(
api_key='key',
workspace='dummy-test',
project_name='general'
)
comet.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general')

# Test both given
with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
logger = CometLogger(
save_dir='test',
api_key='key',
workspace='dummy-test',
project_name='general'
)
logger = CometLogger(save_dir='test', api_key='key', workspace='dummy-test', project_name='general')

_ = logger.experiment

comet.assert_called_once_with(
api_key='key',
workspace='dummy-test',
project_name='general'
)
comet.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general')

# Test neither given
with pytest.raises(MisconfigurationException):
CometLogger(
workspace='dummy-test',
project_name='general'
)
CometLogger(workspace='dummy-test', project_name='general')

# Test already exists
with patch('pytorch_lightning.loggers.comet.CometExistingExperiment') as comet_existing:
Expand All @@ -58,36 +38,48 @@ def test_comet_logger_online():
experiment_name='experiment',
api_key='key',
workspace='dummy-test',
project_name='general'
project_name='general',
)

_ = logger.experiment

comet_existing.assert_called_once_with(
api_key='key',
workspace='dummy-test',
project_name='general',
previous_experiment='test'
api_key='key', workspace='dummy-test', project_name='general', previous_experiment='test'
)

comet_existing().set_name.assert_called_once_with('experiment')

with patch('pytorch_lightning.loggers.comet.API') as api:
CometLogger(
api_key='key',
workspace='dummy-test',
project_name='general',
rest_api_key='rest'
)
CometLogger(api_key='key', workspace='dummy-test', project_name='general', rest_api_key='rest')

api.assert_called_once_with('rest')


def test_comet_logger_experiment_name():
"""Test that Comet Logger experiment name works correctly."""

api_key = "key"
experiment_name = "My Name"

# Test api_key given
with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
logger = CometLogger(api_key=api_key, experiment_name=experiment_name,)

assert logger._experiment is None

_ = logger.experiment

comet.assert_called_once_with(api_key=api_key, project_name=None, workspace=None)

comet().set_name.assert_called_once_with(experiment_name)


def test_comet_logger_dirs_creation(tmpdir, monkeypatch):
""" Test that the logger creates the folders and files in the right place. """
# prevent comet logger from trying to print at exit, since
# pytest's stdout/stderr redirection breaks it
import atexit

monkeypatch.setattr(atexit, 'register', lambda _: None)

logger = CometLogger(project_name='test', save_dir=tmpdir)
Expand All @@ -107,9 +99,68 @@ def test_comet_logger_dirs_creation(tmpdir, monkeypatch):
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}


def test_comet_name_default():
""" Test that CometLogger.name don't create an Experiment and returns a default value. """

api_key = "key"

with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
logger = CometLogger(api_key=api_key)

assert logger._experiment is None

assert logger.name == "comet-default"

assert logger._experiment is None


def test_comet_name_project_name():
""" Test that CometLogger.name does not create an Experiment and returns project name if passed. """

api_key = "key"
project_name = "My Project Name"

with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
logger = CometLogger(api_key=api_key, project_name=project_name)

assert logger._experiment is None

assert logger.name == project_name

assert logger._experiment is None


def test_comet_version_without_experiment():
""" Test that CometLogger.version does not create an Experiment. """

api_key = "key"
experiment_name = "My Name"

with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
logger = CometLogger(api_key=api_key, experiment_name=experiment_name)

assert logger._experiment is None

first_version = logger.version
assert first_version is not None

assert logger.version == first_version

assert logger._experiment is None

_ = logger.experiment

logger.reset_experiment()

second_version = logger.version
assert second_version is not None
assert second_version != first_version


def test_comet_epoch_logging(tmpdir, monkeypatch):
""" Test that CometLogger removes the epoch key from the metrics dict and passes it as argument. """
import atexit

monkeypatch.setattr(atexit, "register", lambda _: None)
with patch("pytorch_lightning.loggers.comet.CometOfflineExperiment.log_metrics") as log_metrics:
logger = CometLogger(project_name="test", save_dir=tmpdir)
Expand Down

0 comments on commit e2af4f1

Please sign in to comment.