
Add checkpoint saving and loading functionality to training loop #123

Merged 74 commits from allow_training_checkpoints into master on Dec 7, 2020. The diff below shows changes from 44 commits.

Commits (74)
dc55230
Add checkpoint saving and loading functionality to training loop
lvermue Nov 1, 2020
9452a01
Flake8 and parameters defaults
lvermue Nov 1, 2020
21ab40e
Flake8
lvermue Nov 1, 2020
c05b38b
Refactor internal completed epoch tracking
lvermue Nov 1, 2020
3a0f420
Correct epoch number handling in training loop
lvermue Nov 2, 2020
326cf57
Add automatic loading of early stoppers from training loop checkpoints
lvermue Nov 2, 2020
12f5d9a
Add training loop checkpoint support for pipelines
lvermue Nov 2, 2020
0af8dbe
Fix indentation
lvermue Nov 2, 2020
aaa7eb2
Change training loop checksum creation from adler32 to md5
lvermue Nov 2, 2020
7ec0d44
Fix missing random_seed failure, if no training config file is provided
lvermue Nov 2, 2020
7b84f70
Flake8, refactor code and add CheckpointMismatchError
lvermue Nov 2, 2020
21a982c
Fix flake8
lvermue Nov 2, 2020
11e7c4e
Update exception
cthoyt Nov 2, 2020
5f11c15
Add outline for checkpoint tutorial
cthoyt Nov 2, 2020
fc9a5a9
Correct random state recovery
lvermue Nov 2, 2020
2bf5b98
Merge branch 'allow_training_checkpoints' of https://github.com/pykee…
lvermue Nov 2, 2020
256beaa
Remove pipeline checkpoint helper file help function
lvermue Nov 22, 2020
1e170f3
Remove torch save helper file
lvermue Nov 22, 2020
3ba5288
Fix function argument for torch save
lvermue Nov 22, 2020
cfd0598
Add unit tests for training loop checkpoint
lvermue Nov 22, 2020
8d2740f
Merge branch 'master' into allow_training_checkpoints
cthoyt Nov 23, 2020
07c3d42
Add unit tests for checkpoints
lvermue Nov 23, 2020
37b777b
Fix flake8
lvermue Nov 23, 2020
8af806f
Add failure fallback checkpoints to training loop
lvermue Nov 23, 2020
bc20d35
Correct checkpoint root dir handling
lvermue Nov 23, 2020
17cdfcc
Add implicit random seed handling from checkpoints in pipeline
lvermue Nov 24, 2020
b036b9d
Code cleanup
cthoyt Nov 24, 2020
5d2e47b
More refactoring
cthoyt Nov 24, 2020
583d494
Add CPU/GPU random state differentiation
lvermue Nov 24, 2020
604bc2f
Merge branch 'allow_training_checkpoints' of https://github.com/pykee…
lvermue Nov 24, 2020
26e0b39
Workaround for CUDA rng state
cthoyt Nov 24, 2020
a564fef
Refactor tests
cthoyt Nov 24, 2020
67776c8
Remove unnecessary stuff
cthoyt Nov 24, 2020
578fd1e
Unnest logic
cthoyt Nov 24, 2020
09d03e3
Improve typing and safety
cthoyt Nov 24, 2020
9ca40dc
Fix testing
lvermue Nov 24, 2020
5d64d6c
Fix pipeline checkpoint unit tests
lvermue Nov 24, 2020
bf82e79
Merge branch 'master' into allow_training_checkpoints
lvermue Nov 24, 2020
661eed3
Fix usage of forbidden characters for Windows in filepaths
lvermue Nov 24, 2020
675bd59
Refactor loading of states for the training loop and stoppers
lvermue Nov 25, 2020
f618a90
Fix flake8
lvermue Nov 25, 2020
14e14b7
Fix handling of stopper state dictionaries
lvermue Nov 25, 2020
cd64e7e
Merge branch 'master' into allow_training_checkpoints
lvermue Nov 25, 2020
d9abffd
Fix flake8
lvermue Nov 25, 2020
f226b0a
Merge branch 'master' into allow_training_checkpoints
lvermue Nov 29, 2020
240140b
Fix unit tests
lvermue Nov 29, 2020
3cf1f69
Refactor pipeline unit tests
lvermue Nov 29, 2020
29780d0
Refactor training loop unit tests
lvermue Nov 29, 2020
bf6190b
Fix flake8
lvermue Nov 29, 2020
04b31c9
Merge branch 'master' into allow_training_checkpoints
lvermue Dec 1, 2020
1cabf5f
Add saving of checkpoints after successful training
lvermue Dec 1, 2020
1fa0f18
Add usage of temporary directories for unit tests
lvermue Dec 1, 2020
46bb380
Add checkpoint documentation and correct failure checkpoint handling
lvermue Dec 1, 2020
2b32fbb
Trigger CI
PyKEEN-bot Dec 1, 2020
566e9ab
Add missing variable default value
cthoyt Dec 1, 2020
778ce48
Get rid of tqdms
cthoyt Dec 1, 2020
317f639
Pass flake8
cthoyt Dec 1, 2020
563e5a4
Use class teardown for handling temporary directory
cthoyt Dec 1, 2020
7fba750
Update docs
cthoyt Dec 2, 2020
65cefa7
Update argument names and type hints
cthoyt Dec 2, 2020
52656ec
Trigger CI
PyKEEN-bot Dec 2, 2020
a745853
Add datetime formatting
lvermue Dec 2, 2020
12e3fed
Add docs for checkpoint_on_failure_file_path
lvermue Dec 2, 2020
86749bf
Merge branch 'master' into allow_training_checkpoints
cthoyt Dec 4, 2020
37069be
Update constants
cthoyt Dec 4, 2020
ffd8e55
Change temp dir creation and teardown during unit tests
lvermue Dec 5, 2020
d285704
Merge branch 'allow_training_checkpoints' of https://github.com/pykee…
lvermue Dec 5, 2020
b5fb84f
Update the checkpoint tutorial
lvermue Dec 5, 2020
92ed682
Trigger CI
PyKEEN-bot Dec 5, 2020
33f0776
Fix temp dir name handling
lvermue Dec 5, 2020
cc36bde
Trigger CI
PyKEEN-bot Dec 5, 2020
ed6b98e
Small fixes in docs
cthoyt Dec 7, 2020
6231d23
Merge branch 'master' into allow_training_checkpoints
lvermue Dec 7, 2020
605b76e
Trigger CI
PyKEEN-bot Dec 7, 2020
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -11,6 +11,7 @@ PyKEEN
    tutorial/first_steps
    tutorial/understanding_evaluation
    tutorial/translational_toy_example
+   tutorial/checkpoints
    tutorial/running_hpo
    tutorial/running_ablation
    tutorial/byod
9 changes: 9 additions & 0 deletions docs/source/tutorial/checkpoints.rst
@@ -0,0 +1,9 @@
Using Checkpoints
=================
Why would someone want to use checkpoints?
Review comment (Member): @lvermue write up this tutorial please


Give an example of a run that will obviously crash

How to recover when you were smart enough to keep checkpoints?

Where is this applicable? pipeline / hpo pipeline?
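A minimal sketch of what the finished tutorial could demonstrate, based on the checkpoint_file key that this PR adds to training_kwargs (the dataset and model choices are illustrative; unless checkpoint_root is also given, the file lands in the default checkpoint directory under ~/.pykeen/checkpoints):

from pykeen.pipeline import pipeline

# If this run crashes partway through (power loss, preemption, OOM), re-running
# the exact same call resumes from the last saved checkpoint instead of starting
# over, because the pipeline reloads the stored random seed and training state.
result = pipeline(
    dataset='Nations',  # illustrative dataset
    model='TransE',     # illustrative model
    training_kwargs=dict(
        num_epochs=1000,
        checkpoint_file='my_checkpoint.pt',
    ),
)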
6 changes: 6 additions & 0 deletions src/pykeen/constants.py
@@ -3,9 +3,15 @@
"""Constants for PyKEEN."""

import os
import pathlib

__all__ = [
    'PYKEEN_HOME',
    'PYKEEN_DEFAULT_CHECKPOINT_DIR',
]

PYKEEN_HOME = os.environ.get('PYKEEN_HOME') or os.path.join(os.path.expanduser('~'), '.pykeen')
PYKEEN_DEFAULT_CHECKPOINT = "PyKEEN_just_saved_my_day.pt"
Review comment (Member): perfect name


PYKEEN_DEFAULT_CHECKPOINT_DIR = pathlib.Path(PYKEEN_HOME).joinpath("checkpoints")
PYKEEN_DEFAULT_CHECKPOINT_DIR.mkdir(exist_ok=True, parents=True)
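The default checkpoint directory derives from PYKEEN_HOME, which may be overridden through the environment before PyKEEN is first imported, since the directory is resolved and created at import time. A small sketch of that resolution (the override path is illustrative):

import os

# The override must happen before the import below; constants.py resolves and
# creates the directory at import time (the path shown is illustrative).
os.environ['PYKEEN_HOME'] = '/tmp/my_pykeen_home'

from pykeen.constants import PYKEEN_DEFAULT_CHECKPOINT_DIR

print(PYKEEN_DEFAULT_CHECKPOINT_DIR)  # /tmp/my_pykeen_home/checkpoints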
2 changes: 2 additions & 0 deletions src/pykeen/models/base.py
@@ -265,8 +265,10 @@ def __init__(
         # Random seeds have to be set before the embeddings are initialized
         if random_seed is None:
             logger.warning('No random seed is specified. This may lead to non-reproducible results.')
+            self._random_seed = None
         elif random_seed is not NoRandomSeedNecessary:
             set_random_seed(random_seed)
+            self._random_seed = random_seed

         if automatic_memory_optimization is None:
             automatic_memory_optimization = True
28 changes: 24 additions & 4 deletions src/pykeen/pipeline.py
@@ -168,6 +168,7 @@
 import json
 import logging
 import os
+import pathlib
 import time
 from dataclasses import dataclass, field
 from typing import Any, Collection, Dict, Iterable, List, Mapping, Optional, Set, Type, Union
@@ -176,6 +177,7 @@
 import torch
 from torch.optim.optimizer import Optimizer

+from .constants import PYKEEN_DEFAULT_CHECKPOINT_DIR
 from .datasets import get_dataset
 from .datasets.base import DataSet
 from .evaluation import Evaluator, MetricResults, get_evaluator_cls
@@ -823,7 +825,28 @@ def pipeline(  # noqa: C901
     :param use_testing_data:
         If true, use the testing triples. Otherwise, use the validation triples. Defaults to true - use testing triples.
     """
-    if random_seed is None:
+    if training_kwargs is None:
+        training_kwargs = {}
+
+    # To allow resuming training from a checkpoint when using a pipeline, the pipeline needs to obtain the
+    # used random_seed to ensure reproducible results
+    checkpoint_file_name = training_kwargs.get('checkpoint_file')
+    if checkpoint_file_name is not None:
+        checkpoint_directory = pathlib.Path(training_kwargs.get('checkpoint_root', PYKEEN_DEFAULT_CHECKPOINT_DIR))
+        checkpoint_directory.mkdir(parents=True, exist_ok=True)
+        checkpoint_path = checkpoint_directory / checkpoint_file_name
+        if checkpoint_path.is_file():
+            checkpoint_dict = torch.load(checkpoint_path)
+            random_seed = checkpoint_dict['random_seed']
+            logger.info('loaded random seed %s from checkpoint.', random_seed)
+            # We have to set clear_optimizer to False since training should be continued
+            clear_optimizer = False
+        else:
+            logger.info(f"=> no training loop checkpoint file found at '{checkpoint_path}'. Creating a new file.")
+            if random_seed is None:
+                random_seed = random_non_negative_int()
+                logger.warning(f'No random seed is specified. Setting to {random_seed}.')
+    elif random_seed is None:
+        random_seed = random_non_negative_int()
+        logger.warning(f'No random seed is specified. Setting to {random_seed}.')
     set_random_seed(random_seed)
@@ -939,9 +962,6 @@ def pipeline(  # noqa: C901
     if evaluation_kwargs is None:
         evaluation_kwargs = {}

-    if training_kwargs is None:
-        training_kwargs = {}
-
     # Stopping
     if 'stopper' in training_kwargs and stopper is not None:
         raise ValueError('Specified stopper in training_kwargs and as stopper')
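Since the checkpoint is a plain dictionary written with torch.save, it can also be inspected directly. A sketch that assumes a checkpoint named my_checkpoint.pt already exists in the default directory; the two keys shown are the ones read in this PR ('random_seed' here in the pipeline, 'stopper_dict' in the stopper code below):

import torch

from pykeen.constants import PYKEEN_DEFAULT_CHECKPOINT_DIR

# Load the checkpoint dict exactly as the pipeline does.
checkpoint_path = PYKEEN_DEFAULT_CHECKPOINT_DIR / 'my_checkpoint.pt'  # illustrative file name
checkpoint_dict = torch.load(checkpoint_path)

print(checkpoint_dict['random_seed'])   # seed restored for reproducible resumption
print(checkpoint_dict['stopper_dict'])  # summary used to rehydrate an early stopper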
23 changes: 23 additions & 0 deletions src/pykeen/stoppers/early_stopping.py
@@ -191,3 +191,26 @@ def get_summary_dict(self) -> Mapping[str, Any]:
            best_epoch=self.best_epoch,
            best_metric=self.best_metric,
        )

    def _write_from_summary_dict(
        self,
        frequency: int,
        patience: int,
        relative_delta: float,
        metric: str,
        larger_is_better: bool,
        results: List[float],
        stopped: bool,
        best_epoch: int,
        best_metric: float,
    ) -> None:
        """Write attributes to stopper from a summary dict."""
        self.frequency = frequency
        self.patience = patience
        self.relative_delta = relative_delta
        self.metric = metric
        self.larger_is_better = larger_is_better
        self.results = results
        self.stopped = stopped
        self.best_epoch = best_epoch
        self.best_metric = best_metric
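A sketch of the round trip these attributes enable, assuming stopper is the EarlyStopper of an interrupted run and fresh_stopper is a newly constructed EarlyStopper for the resumed run (both instances are assumed to exist already; the training loop performs these steps internally):

# The summary dict is what gets embedded in the training loop checkpoint ...
summary = stopper.get_summary_dict()

# ... and on resumption, the fresh stopper is rehydrated from it.
fresh_stopper._write_from_summary_dict(**summary)
assert fresh_stopper.best_epoch == stopper.best_epoch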
34 changes: 34 additions & 0 deletions src/pykeen/stoppers/stopper.py
@@ -2,13 +2,20 @@

"""Basic stoppers."""

import logging
import pathlib
from abc import ABC, abstractmethod
from typing import Any, Mapping, Union

import torch

__all__ = [
    'Stopper',
    'NopStopper',
]

logger = logging.getLogger(__name__)


class Stopper(ABC):
"""A harness for stopping training."""
@@ -25,6 +32,29 @@ def should_stop(self, epoch: int) -> bool:
"""Validate on validation set and check for termination condition."""
raise NotImplementedError

@abstractmethod
def get_summary_dict(self) -> Mapping[str, Any]:
"""Get a summary dict."""
raise NotImplementedError

def _write_from_summary_dict(self, **kwargs):
pass

@staticmethod
def load_summary_dict_from_training_loop_checkpoint(path: Union[str, pathlib.Path]) -> Mapping[str, Any]:
"""Load the summary dict from a training loop checkpoint.

:param path:
Path of the file where to store the state in.

:return:
The summary dict of the stopper at the time of saving the checkpoint.
"""
logger.info(f"=> loading stopper summary dict from training loop checkpoint in '{path}'")
checkpoint = torch.load(path)
logger.info(f"=> loaded stopper summary dictionary from checkpoint in '{path}'")
return checkpoint['stopper_dict']


class NopStopper(Stopper):
"""A stopper that does nothing."""
@@ -36,3 +66,7 @@ def should_evaluate(self, epoch: int) -> bool:
    def should_stop(self, epoch: int) -> bool:
        """Return false; should never stop."""
        return False

    def get_summary_dict(self) -> Mapping[str, Any]:
        """Return an empty mapping; a NopStopper has no attributes to save."""
        return dict()
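Restoring a stopper's state from an existing checkpoint file then only needs the static helper added above. A sketch with an illustrative file name, assuming the checkpoint was written by a previous training run:

from pykeen.constants import PYKEEN_DEFAULT_CHECKPOINT_DIR
from pykeen.stoppers.stopper import Stopper

# Pull only the stopper's saved summary out of a training loop checkpoint.
path = PYKEEN_DEFAULT_CHECKPOINT_DIR / 'my_checkpoint.pt'  # illustrative file name
stopper_dict = Stopper.load_summary_dict_from_training_loop_checkpoint(path)
# stopper_dict holds the values from get_summary_dict, e.g. best_epoch and best_metric.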