Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions ignite/contrib/handlers/trains_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,17 @@ def _create_opt_params_handler(self, *args, **kwargs):


class TrainsSaver(DiskSaver):
"""Handler that saves input checkpoint as Trains artifacts
"""
Handler that saves input checkpoint as Trains artifacts

Args:
logger (TrainsLogger, optional): An instance of :class:`~ignite.contrib.handlers.TrainsLogger`, ensuring a valid
Trains ``Task`` has been initialized. If not provided, and a Trains Task has not been manually initialized,
a runtime error will be raised.
output_uri (str, optional): The default location for output models and other artifacts uploaded by Trains. For
more information, see ``trains.Task.init``.
dirname (str, optional): Directory path where the checkpoint will be saved. If not provided, a temporary
directory will be created.

Examples:

Expand All @@ -595,7 +605,7 @@ class TrainsSaver(DiskSaver):

handler = Checkpoint(
to_save,
TrainsSaver(trains_logger),
TrainsSaver(),
n_saved=1,
score_function=lambda e: 123,
score_name="acc",
Expand All @@ -607,10 +617,7 @@ class TrainsSaver(DiskSaver):

"""

def __init__(self, logger: TrainsLogger, output_uri: str = None, dirname: str = None, *args, **kwargs):
if not isinstance(logger, TrainsLogger):
raise TypeError("logger must be an instance of TrainsLogger")

def __init__(self, logger: TrainsLogger = None, output_uri: str = None, dirname: str = None, *args, **kwargs):
try:
from trains import Task
except ImportError:
Expand All @@ -619,6 +626,16 @@ def __init__(self, logger: TrainsLogger, output_uri: str = None, dirname: str =
"You may install trains using: \n pip install trains \n"
)

if logger and not isinstance(logger, TrainsLogger):
raise TypeError("logger must be an instance of TrainsLogger")

self.task = Task.current_task()
if not self.task:
raise RuntimeError(
"TrainsSaver requires a Trains Task to be initialized."
"Please use the `logger` argument or call `trains.Task.init()`."
)

if not dirname:
dirname = tempfile.mkdtemp(
prefix="ignite_checkpoints_{}".format(datetime.now().strftime("%Y_%m_%d_%H_%M_%S_"))
Expand All @@ -627,9 +644,6 @@ def __init__(self, logger: TrainsLogger, output_uri: str = None, dirname: str =

super(TrainsSaver, self).__init__(dirname=dirname, *args, **kwargs)

self.logger = logger
self.task = Task.current_task()

if output_uri:
self.task.output_uri = output_uri

Expand Down
22 changes: 21 additions & 1 deletion tests/ignite/contrib/handlers/test_trains_logger.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from unittest.mock import ANY, MagicMock, call
from unittest.mock import ANY, MagicMock, Mock, call

import pytest
import torch
Expand Down Expand Up @@ -623,6 +623,7 @@ def test_trains_disk_saver_integration():
to_save_serializable = {"model": model}

mock_logger = MagicMock(spec=TrainsLogger)
trains.Task.current_task = Mock(return_value=object())
trains_saver = TrainsSaver(mock_logger)
trains.binding.frameworks.WeightsFileHandler.create_output_model = MagicMock()

Expand All @@ -634,3 +635,22 @@ def test_trains_disk_saver_integration():
trainer.state.iteration = 1
checkpoint(trainer)
assert trains.binding.frameworks.WeightsFileHandler.create_output_model.call_count == 2


def test_trains_disk_saver_integration_no_logger():
model = torch.nn.Module()
to_save_serializable = {"model": model}

trains.Task.current_task = Mock(return_value=object())
trains_saver = TrainsSaver()

trains.binding.frameworks.WeightsFileHandler.create_output_model = MagicMock()

checkpoint = Checkpoint(to_save=to_save_serializable, save_handler=trains_saver, n_saved=1)

trainer = Engine(lambda e, b: None)
trainer.state = State(epoch=0, iteration=0)
checkpoint(trainer)
trainer.state.iteration = 1
checkpoint(trainer)
assert trains.binding.frameworks.WeightsFileHandler.create_output_model.call_count == 2