Add checkpoint saving and loading functionality to training loop #123

Merged
74 commits merged on Dec 7, 2020
Commits
dc55230
Add checkpoint saving and loading functionality to training loop
lvermue Nov 1, 2020
9452a01
Flake8 and parameters defaults
lvermue Nov 1, 2020
21ab40e
Flake8
lvermue Nov 1, 2020
c05b38b
Refactor internal completed epoch tracking
lvermue Nov 1, 2020
3a0f420
Correct epoch number handling in training loop
lvermue Nov 2, 2020
326cf57
Add automatic loading of early stoppers from training loop checkpoints
lvermue Nov 2, 2020
12f5d9a
Add training loop checkpoint support for pipelines
lvermue Nov 2, 2020
0af8dbe
Fix indentation
lvermue Nov 2, 2020
aaa7eb2
Change training loop checksum creation from adler32 to md5
lvermue Nov 2, 2020
7ec0d44
Fix missing random_seed failure, if no training config file is provided
lvermue Nov 2, 2020
7b84f70
Flake8, refactor code and add CheckpointMismatchError
lvermue Nov 2, 2020
21a982c
Fix flake8
lvermue Nov 2, 2020
11e7c4e
Update exception
cthoyt Nov 2, 2020
5f11c15
Add outline for checkpoint tutorial
cthoyt Nov 2, 2020
fc9a5a9
Correct random state recovery
lvermue Nov 2, 2020
2bf5b98
Merge branch 'allow_training_checkpoints' of https://github.com/pykee…
lvermue Nov 2, 2020
256beaa
Remove pipeline checkpoint helper file help function
lvermue Nov 22, 2020
1e170f3
Remove torch save helper file
lvermue Nov 22, 2020
3ba5288
Fix function argument for torch save
lvermue Nov 22, 2020
cfd0598
Add units tests for training loop checkpoint
lvermue Nov 22, 2020
8d2740f
Merge branch 'master' into allow_training_checkpoints
cthoyt Nov 23, 2020
07c3d42
Add unit tests for checkpoints
lvermue Nov 23, 2020
37b777b
Fix flake8
lvermue Nov 23, 2020
8af806f
Add failure fallback checkpoints to training loop
lvermue Nov 23, 2020
bc20d35
Correct checkpoint root dir handling
lvermue Nov 23, 2020
17cdfcc
Add implicit random seed handling from checkpoints in pipeline
lvermue Nov 24, 2020
b036b9d
Code cleanup
cthoyt Nov 24, 2020
5d2e47b
More refactoring
cthoyt Nov 24, 2020
583d494
Add CPU/GPU random state differentiation
lvermue Nov 24, 2020
604bc2f
Merge branch 'allow_training_checkpoints' of https://github.com/pykee…
lvermue Nov 24, 2020
26e0b39
Workaround for CUDA rng state
cthoyt Nov 24, 2020
a564fef
Refactor tests
cthoyt Nov 24, 2020
67776c8
Remove unnecessary stuff
cthoyt Nov 24, 2020
578fd1e
Unnest logic
cthoyt Nov 24, 2020
09d03e3
Improve typing and safety
cthoyt Nov 24, 2020
9ca40dc
Fix testing
lvermue Nov 24, 2020
5d64d6c
Fix pipeline checkpoint unit tests
lvermue Nov 24, 2020
bf82e79
Merge branch 'master' into allow_training_checkpoints
lvermue Nov 24, 2020
661eed3
Fix usage of forbidden characters for Windows in filepaths
lvermue Nov 24, 2020
675bd59
Refactor loading of states for the training loop and stoppers
lvermue Nov 25, 2020
f618a90
Fix flake8
lvermue Nov 25, 2020
14e14b7
Fix handling of stopper state dictionaries
lvermue Nov 25, 2020
cd64e7e
Merge branch 'master' into allow_training_checkpoints
lvermue Nov 25, 2020
d9abffd
Fix flake8
lvermue Nov 25, 2020
f226b0a
Merge branch 'master' into allow_training_checkpoints
lvermue Nov 29, 2020
240140b
Fix unit tests
lvermue Nov 29, 2020
3cf1f69
Refactor pipeline unit tests
lvermue Nov 29, 2020
29780d0
Refactor training loop unit tests
lvermue Nov 29, 2020
bf6190b
Fix flake8
lvermue Nov 29, 2020
04b31c9
Merge branch 'master' into allow_training_checkpoints
lvermue Dec 1, 2020
1cabf5f
Add saving of checkpoints after successful training
lvermue Dec 1, 2020
1fa0f18
Add usage of temporary directories for unit tests
lvermue Dec 1, 2020
46bb380
Add checkpoint documentation and correct failure checkpoint handling
lvermue Dec 1, 2020
2b32fbb
Trigger CI
PyKEEN-bot Dec 1, 2020
566e9ab
Add missing variable default value
cthoyt Dec 1, 2020
778ce48
Get rid of tqdms
cthoyt Dec 1, 2020
317f639
Pass flake8
cthoyt Dec 1, 2020
563e5a4
Use class teardown for handling temporary directory
cthoyt Dec 1, 2020
7fba750
Update docs
cthoyt Dec 2, 2020
65cefa7
Update argument names and type hints
cthoyt Dec 2, 2020
52656ec
Trigger CI
PyKEEN-bot Dec 2, 2020
a745853
Add datetime formatting
lvermue Dec 2, 2020
12e3fed
Add docs for checkpoint_on_failure_file_path
lvermue Dec 2, 2020
86749bf
Merge branch 'master' into allow_training_checkpoints
cthoyt Dec 4, 2020
37069be
Update constants
cthoyt Dec 4, 2020
ffd8e55
Change temp dir creation and teardown during unit tests
lvermue Dec 5, 2020
d285704
Merge branch 'allow_training_checkpoints' of https://github.com/pykee…
lvermue Dec 5, 2020
b5fb84f
Update the checkpoint tutorial
lvermue Dec 5, 2020
92ed682
Trigger CI
PyKEEN-bot Dec 5, 2020
33f0776
Fix temp dir name handling
lvermue Dec 5, 2020
cc36bde
Trigger CI
PyKEEN-bot Dec 5, 2020
ed6b98e
Small fixes in docs
cthoyt Dec 7, 2020
6231d23
Merge branch 'master' into allow_training_checkpoints
lvermue Dec 7, 2020
605b76e
Trigger CI
PyKEEN-bot Dec 7, 2020
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -11,6 +11,7 @@ PyKEEN
tutorial/first_steps
tutorial/understanding_evaluation
tutorial/translational_toy_example
tutorial/checkpoints
tutorial/running_hpo
tutorial/running_ablation
tutorial/byod
147 changes: 147 additions & 0 deletions docs/source/tutorial/checkpoints.rst
@@ -0,0 +1,147 @@
Using Checkpoints
=================
Training can take days to weeks in extreme cases when using models with many parameters or big datasets. This exposes
the process to a wide range of possible failures, e.g. session timeouts or server restarts, which would lead to a
complete loss of all progress made so far. To avoid this, the :class:`pykeen.training.TrainingLoop` supports built-in
checkpoints that allow the current training loop state to be saved in a straightforward way and training to be
resumed from a saved checkpoint.

How to do it
------------
To show how checkpoints are used in PyKEEN, let's look at a simple example of how a model is set up.
For possible errors and safety fallbacks, please also have a look at :ref:`word_of_caution`.

[Review thread]
Member: wouldn't these tutorials be more useful for users if they started by being centered on the pipeline and then
at the end gave some insight into the underlying implementation?

Member Author: The reason it is kept right now is due to the fact that the checkpoint functionality is only a true
training loop functionality, because even though the pipeline supports using training loop checkpoints, it is not a
true pipeline checkpoint.

Member: That's true, but I don't think the beginning of a tutorial section of the documentation benefits from being
pedagogical. The technical part can be in the reference, or at the end of the tutorial to help users who understand
how to use the simple parts and want to understand how it works. I think one place this worked really well was the
First Steps tutorial, which ended with the Beyond the Pipeline section.
Could you elaborate on the difference you mean between a training loop checkpoint vs a pipeline checkpoint?

.. code-block:: python

    from pykeen.datasets import Nations
    from pykeen.models import TransE
    from pykeen.training import SLCWATrainingLoop
    from pykeen.triples import TriplesFactory
    from torch.optim import Adam

    triples_factory = Nations().training
    model = TransE(
        triples_factory=triples_factory,
        random_seed=123,
    )

    optimizer = Adam(params=model.get_grad_params())
    training_loop = SLCWATrainingLoop(model=model, optimizer=optimizer)

At this point we have a model, dataset, and optimizer all set up in a training loop and are ready to train the model
with the ``training_loop``'s method :func:`pykeen.training.TrainingLoop.train`. To enable checkpoints, all you have to
do is set the argument ``checkpoint_name`` to the file name you would like the checkpoint to have.
Optionally, you can choose where the checkpoints are saved by setting the ``checkpoint_directory`` argument to a
string or a :class:`pathlib.Path` object containing your desired root path. If you don't set the
``checkpoint_directory`` argument, your checkpoints will be saved below the ``PYKEEN_HOME`` directory defined in
:mod:`pykeen.constants`, i.e. in a subdirectory of your home directory, e.g. ``~/.pykeen/checkpoints``.
Furthermore, you can set the checkpoint frequency, i.e. how often checkpoints should be saved, given in minutes, by
setting the argument ``checkpoint_frequency`` to an integer. The default frequency is 30 minutes, and setting it to
``0`` will cause the training loop to save a checkpoint after each epoch.

Here is an example:

.. code-block:: python

    losses = training_loop.train(
        num_epochs=1000,
        checkpoint_name='my_checkpoint.pt',
        checkpoint_frequency=5,
    )

With this code we have started the training loop with the KGEM defined above. The training loop will save a checkpoint
to the file ``my_checkpoint.pt`` in the ``~/.pykeen/checkpoints/`` directory, since we haven't set the argument
``checkpoint_directory``.
A checkpoint is written once at least 5 minutes have passed since the start of the training loop or since the last
checkpoint was saved *and* the current epoch finishes, i.e. if one epoch takes 10 minutes, a checkpoint will be saved
every 10 minutes.
In addition, checkpoints are always saved when the early stopper stops the training loop or when the last epoch has
finished.
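
For instance, keeping everything else the same, you can save a checkpoint after every single epoch by setting the
frequency to zero, as described above:

.. code-block:: python

    losses = training_loop.train(
        num_epochs=1000,
        checkpoint_name='my_checkpoint.pt',
        checkpoint_frequency=0,  # save a checkpoint after each epoch
    )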

Let's assume you were foresighted enough to save checkpoints and your training loop crashed after 200 epochs.
Now you would like to resume from the last checkpoint. All you have to do is rerun the **exact same code** as above,
and PyKEEN will seamlessly pick up from the saved checkpoint. Since PyKEEN stores all random states as well as the
states of the model, optimizer, and early stopper, the results will be exactly the same as if the training loop had
run uninterrupted. Of course, PyKEEN will also continue saving new checkpoints when resuming from a previous
checkpoint.
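
Concretely, resuming after such a crash is just the identical call from before; nothing needs to change:

.. code-block:: python

    # Re-running the exact same call picks up from the last saved checkpoint (around epoch 200 in this example)
    losses = training_loop.train(
        num_epochs=1000,
        checkpoint_name='my_checkpoint.pt',
        checkpoint_frequency=5,
    )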

On top of resuming interrupted training loops, you can also resume training loops that finished successfully.
For example, say the above training loop finished successfully after 1000 epochs, but you would like to continue
training the same model from that state up to 2000 epochs. All you have to do is change the argument ``num_epochs``
in the above code to:

.. code-block:: python

    losses = training_loop.train(
        num_epochs=2000,
        checkpoint_name='my_checkpoint.pt',
        checkpoint_frequency=5,
    )

and now the training loop will resume from the state at 1000 epochs and continue training until epoch 2000.

Another nice feature is that the checkpoint functionality integrates with the pipeline. This means you can simply
define a pipeline like this:

.. code-block:: python

    from pykeen.pipeline import pipeline

    pipeline_result = pipeline(
        dataset='Nations',
        model='TransE',
        optimizer='Adam',
        training_kwargs=dict(num_epochs=1000, checkpoint_name='my_checkpoint.pt', checkpoint_frequency=5),
    )

Again, assuming that this pipeline crashes after, e.g., 200 epochs, you can simply execute **the same code**, and the
pipeline will load the last state from the checkpoint file and continue training as if nothing had happened.
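
If you prefer to keep checkpoints outside of ``~/.pykeen/checkpoints``, the ``checkpoint_directory`` argument described
above can be passed through ``training_kwargs`` as well. A minimal sketch (the directory name is just a placeholder):

.. code-block:: python

    from pykeen.pipeline import pipeline

    pipeline_result = pipeline(
        dataset='Nations',
        model='TransE',
        optimizer='Adam',
        training_kwargs=dict(
            num_epochs=1000,
            checkpoint_name='my_checkpoint.pt',
            checkpoint_directory='my_checkpoint_dir',  # placeholder; any string or pathlib.Path works
            checkpoint_frequency=5,
        ),
    )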

.. todo:: Tutorial on recovery from hpo_pipeline.
[Review thread]
Member: Maybe for a later PR?

Member Author (@lvermue, Dec 2, 2020): Yes. Basically the hpo_pipeline supports saving checkpoints through the
pipeline, but that is specific to the training loop itself. Supporting to resume a cancelled hpo_pipeline would be an
entirely different story.

Member: I wonder if this .. todo directive breaks the build

Checkpoints on Failure
----------------------
In cases where you would only like to save a checkpoint when the training loop fails, you can use the argument
``checkpoint_on_failure=True``, like:

.. code-block:: python

    losses = training_loop.train(
        num_epochs=2000,
        checkpoint_on_failure=True,
    )

This option differs from ordinary checkpoints, since ordinary checkpoints are only saved after a successful epoch.
When a checkpoint is saved because the training loop failed, there is no guarantee that all random states can be
recovered correctly, which might cause problems with regard to the reproducibility of that specific training loop.
Therefore, these checkpoints are saved under a distinct name, ``PyKEEN_just_saved_my_day_{datetime}.pt``, in the given
``checkpoint_directory``, even when you also opted to use ordinary checkpoints as described above, e.g. with this code:

.. code-block:: python

    losses = training_loop.train(
        num_epochs=2000,
        checkpoint_name='my_checkpoint.pt',
        checkpoint_frequency=5,
        checkpoint_on_failure=True,
    )

Note: Use this argument with caution, since every failed training loop will create a distinct checkpoint file.

.. _word_of_caution:

Word of Caution and Possible Errors
-----------------------------------
When using checkpoints and trying out several configurations, which in turn result in multiple different checkpoints,
there is an inherent risk of overwriting checkpoints. This would naturally happen when you change the configuration of
the KGEM but don't change the ``checkpoint_name`` argument.
To prevent this from happening, PyKEEN compares a hash of the checkpoint's configuration to a hash of the current
configuration at hand. When these don't match, PyKEEN won't accept the checkpoint and will raise an error.

In case you want to overwrite the previous checkpoint file with a new configuration, you have to delete it explicitly;
a short sketch follows after the list below. The reason for this behavior is three-fold:

1. This allows a very easy and user-friendly way of resuming an interrupted training loop by simply re-running
   the exact same code.
2. By requiring the checkpoint files to be named explicitly, the user stays in control of the file names and can more
   easily keep an overview.
3. Creating a new checkpoint file for each run would lead most users to inadvertently spam their file systems with
   unused checkpoints that can easily add up to hundreds of GB when running many experiments.
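
As a minimal sketch of the explicit deletion mentioned above, assuming the default checkpoint directory and the
checkpoint name used throughout this tutorial:

.. code-block:: python

    import pathlib

    # Remove the stale checkpoint so that a run with a new configuration can reuse the name.
    checkpoint_path = pathlib.Path.home() / '.pykeen' / 'checkpoints' / 'my_checkpoint.pt'
    if checkpoint_path.is_file():
        checkpoint_path.unlink()
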
6 changes: 6 additions & 0 deletions src/pykeen/constants.py
@@ -3,9 +3,15 @@
"""Constants for PyKEEN."""

import os
import pathlib

__all__ = [
    'PYKEEN_HOME',
    'PYKEEN_DEFAULT_CHECKPOINT_DIR',
]

PYKEEN_HOME = os.environ.get('PYKEEN_HOME') or os.path.join(os.path.expanduser('~'), '.pykeen')
PYKEEN_DEFAULT_CHECKPOINT = "PyKEEN_just_saved_my_day.pt"
[Review thread]
Member: perfect name

PYKEEN_DEFAULT_CHECKPOINT_DIR = pathlib.Path(PYKEEN_HOME).joinpath("checkpoints")
PYKEEN_DEFAULT_CHECKPOINT_DIR.mkdir(exist_ok=True, parents=True)
2 changes: 2 additions & 0 deletions src/pykeen/models/base.py
@@ -260,8 +260,10 @@ def __init__(
        # Random seeds have to set before the embeddings are initialized
        if random_seed is None:
            logger.warning('No random seed is specified. This may lead to non-reproducible results.')
            self._random_seed = None
        elif random_seed is not NoRandomSeedNecessary:
            set_random_seed(random_seed)
            self._random_seed = random_seed

        # Loss
        if loss is None:
28 changes: 24 additions & 4 deletions src/pykeen/pipeline.py
@@ -168,6 +168,7 @@
import json
import logging
import os
import pathlib
import time
from dataclasses import dataclass, field
from typing import Any, Collection, Dict, Iterable, List, Mapping, Optional, Set, Type, Union
@@ -176,6 +177,7 @@
import torch
from torch.optim.optimizer import Optimizer

from .constants import PYKEEN_DEFAULT_CHECKPOINT_DIR
from .datasets import get_dataset
from .datasets.base import DataSet
from .evaluation import Evaluator, MetricResults, get_evaluator_cls
@@ -824,7 +826,28 @@ def pipeline( # noqa: C901
    :param use_testing_data:
        If true, use the testing triples. Otherwise, use the validation triples. Defaults to true - use testing triples.
    """
    if random_seed is None:
    if training_kwargs is None:
        training_kwargs = {}

    # To allow resuming training from a checkpoint when using a pipeline, the pipeline needs to obtain the
    # used random_seed to ensure reproducible results
    checkpoint_name = training_kwargs.get('checkpoint_name')
    if checkpoint_name is not None:
        checkpoint_directory = pathlib.Path(training_kwargs.get('checkpoint_directory', PYKEEN_DEFAULT_CHECKPOINT_DIR))
        checkpoint_directory.mkdir(parents=True, exist_ok=True)
        checkpoint_path = checkpoint_directory / checkpoint_name
        if checkpoint_path.is_file():
            checkpoint_dict = torch.load(checkpoint_path)
            random_seed = checkpoint_dict['random_seed']
            logger.info('loaded random seed %s from checkpoint.', random_seed)
            # We have to set clear optimizer to False since training should be continued
            clear_optimizer = False
        else:
            logger.info(f"=> no training loop checkpoint file found at '{checkpoint_path}'. Creating a new file.")
            if random_seed is None:
                random_seed = random_non_negative_int()
                logger.warning(f'No random seed is specified. Setting to {random_seed}.')
    elif random_seed is None:
        random_seed = random_non_negative_int()
        logger.warning(f'No random seed is specified. Setting to {random_seed}.')
    set_random_seed(random_seed)
@@ -947,9 +970,6 @@ def pipeline( # noqa: C901
    if evaluation_kwargs is None:
        evaluation_kwargs = {}

    if training_kwargs is None:
        training_kwargs = {}

    # Stopping
    if 'stopper' in training_kwargs and stopper is not None:
        raise ValueError('Specified stopper in training_kwargs and as stopper')
23 changes: 23 additions & 0 deletions src/pykeen/stoppers/early_stopping.py
@@ -191,3 +191,26 @@ def get_summary_dict(self) -> Mapping[str, Any]:
            best_epoch=self.best_epoch,
            best_metric=self.best_metric,
        )

    def _write_from_summary_dict(
        self,
        frequency: int,
        patience: int,
        relative_delta: float,
        metric: str,
        larger_is_better: bool,
        results: List[float],
        stopped: bool,
        best_epoch: int,
        best_metric: float,
    ) -> None:
        """Write attributes to stopper from a summary dict."""
        self.frequency = frequency
        self.patience = patience
        self.relative_delta = relative_delta
        self.metric = metric
        self.larger_is_better = larger_is_better
        self.results = results
        self.stopped = stopped
        self.best_epoch = best_epoch
        self.best_metric = best_metric
34 changes: 34 additions & 0 deletions src/pykeen/stoppers/stopper.py
@@ -2,13 +2,20 @@

"""Basic stoppers."""

import logging
import pathlib
from abc import ABC, abstractmethod
from typing import Any, Mapping, Union

import torch

__all__ = [
    'Stopper',
    'NopStopper',
]

logger = logging.getLogger(__name__)


class Stopper(ABC):
"""A harness for stopping training."""
@@ -25,6 +32,29 @@ def should_stop(self, epoch: int) -> bool:
        """Validate on validation set and check for termination condition."""
        raise NotImplementedError

    @abstractmethod
    def get_summary_dict(self) -> Mapping[str, Any]:
        """Get a summary dict."""
        raise NotImplementedError

    def _write_from_summary_dict(self, **kwargs):
        pass

    @staticmethod
    def load_summary_dict_from_training_loop_checkpoint(path: Union[str, pathlib.Path]) -> Mapping[str, Any]:
        """Load the summary dict from a training loop checkpoint.

        :param path:
            Path of the file where to store the state in.

        :return:
            The summary dict of the stopper at the time of saving the checkpoint.
        """
        logger.info(f"=> loading stopper summary dict from training loop checkpoint in '{path}'")
        checkpoint = torch.load(path)
        logger.info(f"=> loaded stopper summary dictionary from checkpoint in '{path}'")
        return checkpoint['stopper_dict']


class NopStopper(Stopper):
"""A stopper that does nothing."""
@@ -36,3 +66,7 @@ def should_evaluate(self, epoch: int) -> bool:
    def should_stop(self, epoch: int) -> bool:
        """Return false; should never stop."""
        return False

    def get_summary_dict(self) -> Mapping[str, Any]:
        """Return empty mapping, doesn't have any attributes."""
        return dict()
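
For orientation, the two stopper additions above could be combined roughly as follows to restore an early stopper from
a training loop checkpoint. This is an illustrative sketch, not code from this pull request: the ``stopper`` and
``checkpoint_path`` objects are assumed to exist already, and the summary dict's keys are assumed to match the early
stopper's attributes.

.. code-block:: python

    # Read the stopper state that the training loop stored in its checkpoint under 'stopper_dict' ...
    summary = Stopper.load_summary_dict_from_training_loop_checkpoint(path=checkpoint_path)

    # ... and write those attributes back onto an already-constructed early stopper.
    stopper._write_from_summary_dict(**summary)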