diff --git a/ignite/contrib/handlers/trains_logger.py b/ignite/contrib/handlers/trains_logger.py index a4f39029938d..3bb490fb213e 100644 --- a/ignite/contrib/handlers/trains_logger.py +++ b/ignite/contrib/handlers/trains_logger.py @@ -577,7 +577,17 @@ def _create_opt_params_handler(self, *args, **kwargs): class TrainsSaver(DiskSaver): - """Handler that saves input checkpoint as Trains artifacts + """ + Handler that saves input checkpoint as Trains artifacts + + Args: + logger (TrainsLogger, optional): An instance of :class:`~ignite.contrib.handlers.TrainsLogger`, ensuring a valid + Trains ``Task`` has been initialized. If not provided, and a Trains Task has not been manually initialized, + a runtime error will be raised. + output_uri (str, optional): The default location for output models and other artifacts uploaded by Trains. For + more information, see ``trains.Task.init``. + dirname (str, optional): Directory path where the checkpoint will be saved. If not provided, a temporary + directory will be created. Examples: @@ -595,7 +605,7 @@ class TrainsSaver(DiskSaver): handler = Checkpoint( to_save, - TrainsSaver(trains_logger), + TrainsSaver(), n_saved=1, score_function=lambda e: 123, score_name="acc", @@ -607,10 +617,7 @@ class TrainsSaver(DiskSaver): """ - def __init__(self, logger: TrainsLogger, output_uri: str = None, dirname: str = None, *args, **kwargs): - if not isinstance(logger, TrainsLogger): - raise TypeError("logger must be an instance of TrainsLogger") - + def __init__(self, logger: TrainsLogger = None, output_uri: str = None, dirname: str = None, *args, **kwargs): try: from trains import Task except ImportError: @@ -619,6 +626,16 @@ def __init__(self, logger: TrainsLogger, output_uri: str = None, dirname: str = "You may install trains using: \n pip install trains \n" ) + if logger and not isinstance(logger, TrainsLogger): + raise TypeError("logger must be an instance of TrainsLogger") + + self.task = Task.current_task() + if not self.task: + raise RuntimeError( + "TrainsSaver requires a Trains Task to be initialized." + "Please use the `logger` argument or call `trains.Task.init()`." + ) + if not dirname: dirname = tempfile.mkdtemp( prefix="ignite_checkpoints_{}".format(datetime.now().strftime("%Y_%m_%d_%H_%M_%S_")) @@ -627,9 +644,6 @@ def __init__(self, logger: TrainsLogger, output_uri: str = None, dirname: str = super(TrainsSaver, self).__init__(dirname=dirname, *args, **kwargs) - self.logger = logger - self.task = Task.current_task() - if output_uri: self.task.output_uri = output_uri diff --git a/tests/ignite/contrib/handlers/test_trains_logger.py b/tests/ignite/contrib/handlers/test_trains_logger.py index 0b9be046eae5..aa5bbe82eaf7 100644 --- a/tests/ignite/contrib/handlers/test_trains_logger.py +++ b/tests/ignite/contrib/handlers/test_trains_logger.py @@ -1,5 +1,5 @@ import math -from unittest.mock import ANY, MagicMock, call +from unittest.mock import ANY, MagicMock, Mock, call import pytest import torch @@ -623,6 +623,7 @@ def test_trains_disk_saver_integration(): to_save_serializable = {"model": model} mock_logger = MagicMock(spec=TrainsLogger) + trains.Task.current_task = Mock(return_value=object()) trains_saver = TrainsSaver(mock_logger) trains.binding.frameworks.WeightsFileHandler.create_output_model = MagicMock() @@ -634,3 +635,22 @@ def test_trains_disk_saver_integration(): trainer.state.iteration = 1 checkpoint(trainer) assert trains.binding.frameworks.WeightsFileHandler.create_output_model.call_count == 2 + + +def test_trains_disk_saver_integration_no_logger(): + model = torch.nn.Module() + to_save_serializable = {"model": model} + + trains.Task.current_task = Mock(return_value=object()) + trains_saver = TrainsSaver() + + trains.binding.frameworks.WeightsFileHandler.create_output_model = MagicMock() + + checkpoint = Checkpoint(to_save=to_save_serializable, save_handler=trains_saver, n_saved=1) + + trainer = Engine(lambda e, b: None) + trainer.state = State(epoch=0, iteration=0) + checkpoint(trainer) + trainer.state.iteration = 1 + checkpoint(trainer) + assert trains.binding.frameworks.WeightsFileHandler.create_output_model.call_count == 2