
Commit

Add training loop checkpoint support for pipelines
lvermue committed Nov 2, 2020
1 parent 326cf57 commit 12f5d9a
Showing 2 changed files with 46 additions and 5 deletions.
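For illustration, here is a minimal sketch of how the added checkpoint support is meant to be used from the pipeline. The dataset, model, and file names are illustrative and not part of this commit; 'checkpoint_file' and 'checkpoint_frequency' are the training kwargs this change reads:

    from pykeen.pipeline import pipeline

    result = pipeline(
        dataset='Nations',
        model='TransE',
        training_kwargs=dict(
            num_epochs=100,
            checkpoint_file='my_checkpoint.pt',  # training loop checkpoints are written here
            checkpoint_frequency=0,  # 0 = save a checkpoint after every epoch
        ),
    )
    # Re-running the same call after an interruption picks up the stored random seed from
    # 'my_checkpoint.pt_pipeline_helper_file', so the resumed run stays reproducible.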
47 changes: 44 additions & 3 deletions src/pykeen/pipeline.py
@@ -699,6 +699,34 @@ def pipeline_from_config(
    )


def save_pipeline_checkpoint_helper_file(path: str, random_seed: int) -> None:
    """Save the pipeline checkpoint helper file.

    :param path:
        The path to which the pipeline checkpoint helper file is saved.
    :param random_seed:
        The random seed that was used for the pipeline.
    """
    torch.save(
        {
            'random_seed': random_seed,
        },
        path,
    )


def load_pipeline_checkpoint_helper_file(path: str) -> Mapping[str, Any]:
    """Load the pipeline checkpoint helper file.

    :param path:
        The path from which the pipeline checkpoint helper file is loaded.
    :return:
        The dictionary loaded from the pipeline checkpoint helper file.
    """
    return torch.load(path)
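The two helpers simply round-trip a small dictionary through torch.save and torch.load. A quick sketch of the intended round trip (the path is illustrative):

    # Persist the seed next to the training loop checkpoint, then recover it.
    save_pipeline_checkpoint_helper_file(path='my_checkpoint.pt_pipeline_helper_file', random_seed=42)
    state = load_pipeline_checkpoint_helper_file('my_checkpoint.pt_pipeline_helper_file')
    assert state['random_seed'] == 42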


def pipeline(  # noqa: C901
    *,
    # 1. Dataset
@@ -823,9 +851,22 @@ def pipeline(  # noqa: C901
    :param use_testing_data:
        If true, use the testing triples; otherwise, use the validation triples. Defaults to true.
    """
    if random_seed is None:
        random_seed = random_non_negative_int()
        logger.warning(f'No random seed is specified. Setting to {random_seed}.')
    # To allow resuming training from a checkpoint when using a pipeline, the pipeline needs to store a helper file
    # containing the random seed that was used, to ensure reproducible results.
    checkpoint_file = training_kwargs.get('checkpoint_file')
    if checkpoint_file:
        pipeline_checkpoint_helper_file = f"{checkpoint_file}_pipeline_helper_file"
        if os.path.isfile(pipeline_checkpoint_helper_file):
            pipeline_checkpoint_helper_dict = load_pipeline_checkpoint_helper_file(pipeline_checkpoint_helper_file)
            random_seed = pipeline_checkpoint_helper_dict['random_seed']
            logger.info(f'Loaded random seed {random_seed} from checkpoint.')
        else:
            logger.info(
                f"No pipeline checkpoint helper file found at '{pipeline_checkpoint_helper_file}'. "
                "Creating a new file.",
            )
            if random_seed is None:
                random_seed = random_non_negative_int()
                logger.warning(f'No random seed is specified. Setting to {random_seed}.')
            save_pipeline_checkpoint_helper_file(path=pipeline_checkpoint_helper_file, random_seed=random_seed)

    set_random_seed(random_seed)

    result_tracker_cls: Type[ResultTracker] = get_result_tracker_cls(result_tracker)
4 changes: 2 additions & 2 deletions src/pykeen/training/training_loop.py
@@ -153,7 +153,7 @@ def train(
        num_workers: Optional[int] = None,
        clear_optimizer: bool = False,
        checkpoint_file: Optional[str] = None,
        checkpoint_frequency: int = None,
        checkpoint_frequency: Optional[int] = None,
    ) -> List[float]:
        """Train the KGE model.
@@ -301,7 +301,7 @@ def _train(  # noqa: C901
        :param checkpoint_file:
            The filename for saving checkpoints.
        :param checkpoint_frequency:
            The frequency of saving checkpoints in minutes.
            The frequency of saving checkpoints in minutes. Setting it to 0 will save a checkpoint after every epoch.
        :return:
            A pair of the KGE model and the losses per epoch.
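For context, a sketch of driving the training loop directly with the parameters touched here (the training_loop instance and the num_epochs argument are assumptions, as they sit outside this diff):

    # checkpoint_frequency=0 writes a checkpoint after every epoch; larger values
    # throttle saving to at most once per that many minutes.
    losses = training_loop.train(
        num_epochs=100,
        checkpoint_file='trained_model.pt',
        checkpoint_frequency=0,
    )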
