Skip to content

Commit

Permalink
Fixes #950 (#954)
Browse files Browse the repository at this point in the history
- Added missing BaseSaveHandler
  • Loading branch information
vfdev-5 committed Apr 21, 2020
1 parent ca9d08e commit 9a73813
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
3 changes: 2 additions & 1 deletion ignite/contrib/handlers/neptune_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch

import ignite
from ignite.handlers.checkpoint import BaseSaveHandler
from ignite.contrib.handlers.base_logger import (
BaseLogger,
BaseOptimizerParamsHandler,
Expand Down Expand Up @@ -466,7 +467,7 @@ def close(self):
self.stop()


class NeptuneSaver:
class NeptuneSaver(BaseSaveHandler):
"""Handler that saves input checkpoint to the Neptune server.
Args:
Expand Down
26 changes: 24 additions & 2 deletions tests/ignite/contrib/handlers/test_neptune_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch

from ignite.engine import Engine, Events, State
from ignite.handlers.checkpoint import Checkpoint
from ignite.contrib.handlers.neptune_logger import *


Expand Down Expand Up @@ -414,7 +415,7 @@ def dummy_handler(engine, logger, event_name):
trainer.run(data, max_epochs=n_epochs)


def test_neptune_saver_serializable(dummy_model_factory, dirname):
def test_neptune_saver_serializable():

mock_logger = MagicMock(spec=NeptuneLogger)
mock_logger.log_artifact = MagicMock()
Expand All @@ -428,7 +429,28 @@ def test_neptune_saver_serializable(dummy_model_factory, dirname):
assert mock_logger.log_artifact.call_count == 1


def test_neptune_saver_non_serializable(dirname):
def test_neptune_saver_integration():

model = torch.nn.Module()
to_save_serializable = {"model": model}

mock_logger = MagicMock(spec=NeptuneLogger)
mock_logger.log_artifact = MagicMock()
mock_logger.delete_artifacts = MagicMock()
saver = NeptuneSaver(mock_logger)

checkpoint = Checkpoint(to_save=to_save_serializable, save_handler=saver, n_saved=1)

trainer = Engine(lambda e, b: None)
trainer.state = State(epoch=0, iteration=0)
checkpoint(trainer)
trainer.state.iteration = 1
checkpoint(trainer)
assert mock_logger.log_artifact.call_count == 2
assert mock_logger.delete_artifacts.call_count == 1


def test_neptune_saver_non_serializable():

mock_logger = MagicMock(spec=NeptuneLogger)
mock_logger.log_artifact = MagicMock()
Expand Down

0 comments on commit 9a73813

Please sign in to comment.