Add checkpoint saving and loading functionality to training loop #123

Merged
merged 74 commits into master from allow_training_checkpoints on Dec 7, 2020
Changes from 9 commits

Commits (74)
dc55230
Add checkpoint saving and loading functionality to training loop
lvermue Nov 1, 2020
9452a01
Flake8 and parameters defaults
lvermue Nov 1, 2020
21ab40e
Flake8
lvermue Nov 1, 2020
c05b38b
Refactor internal completed epoch tracking
lvermue Nov 1, 2020
3a0f420
Correct epoch number handling in training loop
lvermue Nov 2, 2020
326cf57
Add automatic loading of early stoppers from training loop checkpoints
lvermue Nov 2, 2020
12f5d9a
Add training loop checkpoint support for pipelines
lvermue Nov 2, 2020
0af8dbe
Fix indentation
lvermue Nov 2, 2020
aaa7eb2
Change training loop checksum creation from adler32 to md5
lvermue Nov 2, 2020
7ec0d44
Fix missing random_seed failure, if no training config file is provided
lvermue Nov 2, 2020
7b84f70
Flake8, refactor code and add CheckpointMismatchError
lvermue Nov 2, 2020
21a982c
Fix flake8
lvermue Nov 2, 2020
11e7c4e
Update exception
cthoyt Nov 2, 2020
5f11c15
Add outline for checkpoint tutorial
cthoyt Nov 2, 2020
fc9a5a9
Correct random state recovery
lvermue Nov 2, 2020
2bf5b98
Merge branch 'allow_training_checkpoints' of https://github.com/pykee…
lvermue Nov 2, 2020
256beaa
Remove pipeline checkpoint helper file help function
lvermue Nov 22, 2020
1e170f3
Remove torch save helper file
lvermue Nov 22, 2020
3ba5288
Fix function argument for torch save
lvermue Nov 22, 2020
cfd0598
Add units tests for training loop checkpoint
lvermue Nov 22, 2020
8d2740f
Merge branch 'master' into allow_training_checkpoints
cthoyt Nov 23, 2020
07c3d42
Add unit tests for checkpoints
lvermue Nov 23, 2020
37b777b
Fix flake8
lvermue Nov 23, 2020
8af806f
Add failure fallback checkpoints to training loop
lvermue Nov 23, 2020
bc20d35
Correct checkpoint root dir handling
lvermue Nov 23, 2020
17cdfcc
Add implicit random seed handling from checkpoints in pipeline
lvermue Nov 24, 2020
b036b9d
Code cleanup
cthoyt Nov 24, 2020
5d2e47b
More refactoring
cthoyt Nov 24, 2020
583d494
Add CPU/GPU random state differentiation
lvermue Nov 24, 2020
604bc2f
Merge branch 'allow_training_checkpoints' of https://github.com/pykee…
lvermue Nov 24, 2020
26e0b39
Workaround for CUDA rng state
cthoyt Nov 24, 2020
a564fef
Refactor tests
cthoyt Nov 24, 2020
67776c8
Remove unnecessary stuff
cthoyt Nov 24, 2020
578fd1e
Unnest logic
cthoyt Nov 24, 2020
09d03e3
Improve typing and safety
cthoyt Nov 24, 2020
9ca40dc
Fix testing
lvermue Nov 24, 2020
5d64d6c
Fix pipeline checkpoint unit tests
lvermue Nov 24, 2020
bf82e79
Merge branch 'master' into allow_training_checkpoints
lvermue Nov 24, 2020
661eed3
Fix usage of forbidden characters for Windows in filepaths
lvermue Nov 24, 2020
675bd59
Refactor loading of states for the training loop and stoppers
lvermue Nov 25, 2020
f618a90
Fix flake8
lvermue Nov 25, 2020
14e14b7
Fix handling of stopper state dictionaries
lvermue Nov 25, 2020
cd64e7e
Merge branch 'master' into allow_training_checkpoints
lvermue Nov 25, 2020
d9abffd
Fix flake8
lvermue Nov 25, 2020
f226b0a
Merge branch 'master' into allow_training_checkpoints
lvermue Nov 29, 2020
240140b
Fix unit tests
lvermue Nov 29, 2020
3cf1f69
Refactor pipeline unit tests
lvermue Nov 29, 2020
29780d0
Refactor training loop unit tests
lvermue Nov 29, 2020
bf6190b
Fix flake8
lvermue Nov 29, 2020
04b31c9
Merge branch 'master' into allow_training_checkpoints
lvermue Dec 1, 2020
1cabf5f
Add saving of checkpoints after successful training
lvermue Dec 1, 2020
1fa0f18
Add usage of temporary directories for unit tests
lvermue Dec 1, 2020
46bb380
Add checkpoint documentation and correct failure checkpoint handling
lvermue Dec 1, 2020
2b32fbb
Trigger CI
PyKEEN-bot Dec 1, 2020
566e9ab
Add missing variable default value
cthoyt Dec 1, 2020
778ce48
Get rid of tqdms
cthoyt Dec 1, 2020
317f639
Pass flake8
cthoyt Dec 1, 2020
563e5a4
Use class teardown for handling temporary directory
cthoyt Dec 1, 2020
7fba750
Update docs
cthoyt Dec 2, 2020
65cefa7
Update argument names and type hints
cthoyt Dec 2, 2020
52656ec
Trigger CI
PyKEEN-bot Dec 2, 2020
a745853
Add datetime formatting
lvermue Dec 2, 2020
12e3fed
Add docs for checkpoint_on_failure_file_path
lvermue Dec 2, 2020
86749bf
Merge branch 'master' into allow_training_checkpoints
cthoyt Dec 4, 2020
37069be
Update constants
cthoyt Dec 4, 2020
ffd8e55
Change temp dir creation and teardown during unit tests
lvermue Dec 5, 2020
d285704
Merge branch 'allow_training_checkpoints' of https://github.com/pykee…
lvermue Dec 5, 2020
b5fb84f
Update the checkpoint tutorial
lvermue Dec 5, 2020
92ed682
Trigger CI
PyKEEN-bot Dec 5, 2020
33f0776
Fix temp dir name handling
lvermue Dec 5, 2020
cc36bde
Trigger CI
PyKEEN-bot Dec 5, 2020
ed6b98e
Small fixes in docs
cthoyt Dec 7, 2020
6231d23
Merge branch 'master' into allow_training_checkpoints
lvermue Dec 7, 2020
605b76e
Trigger CI
PyKEEN-bot Dec 7, 2020
47 changes: 44 additions & 3 deletions src/pykeen/pipeline.py
@@ -699,6 +699,34 @@ def pipeline_from_config(
)


def save_pipeline_checkpoint_helper_file(path: str, random_seed: int) -> None:
"""Save the pipeline checkpoint helper file.

:param path:
The path of the file in which to store the pipeline checkpoint helper state.
:param random_seed:
The random_seed that was used for the pipeline.
"""
torch.save(
{
'random_seed': random_seed,
},
path,
)


def load_pipeline_checkpoint_helper_file(path: str) -> Mapping[str, Any]:
"""Load the pipeline checkpoint helper file.

:param path:
The path of the file from which to load the pipeline checkpoint helper state.

:return:
The dictionary loaded from the pipeline checkpoint helper file.
"""
return torch.load(path)


def pipeline( # noqa: C901
*,
# 1. Dataset
@@ -823,9 +851,22 @@ def pipeline(  # noqa: C901
:param use_testing_data:
If true, use the testing triples. Otherwise, use the validation triples. Defaults to true - use testing triples.
"""
if random_seed is None:
random_seed = random_non_negative_int()
logger.warning(f'No random seed is specified. Setting to {random_seed}.')
# To allow resuming training from a checkpoint when using a pipeline, the pipeline needs to store a helper file
# containing the used random_seed to ensure reproducible results
if training_kwargs.get('checkpoint_file'):
checkpoint_file = training_kwargs.get('checkpoint_file')
pipeline_checkpoint_helper_file = f"{checkpoint_file}_pipeline_helper_file"
Member: let's make a better naming convention for this

Member Author: Any ideas?

if os.path.isfile(pipeline_checkpoint_helper_file):
pipeline_checkpoint_helper_dict = load_pipeline_checkpoint_helper_file(pipeline_checkpoint_helper_file)
random_seed = pipeline_checkpoint_helper_dict['random_seed']
Member: since the other function sets the states directly, is saving it in a file even necessary now?

Member Author: It would at least still be relevant for the pipeline to correctly return the used random seed. Aside from that, it gives us the guarantee that whatever random method is involved before we resume the training loop also uses the correct random seed.

logger.info(f'Loaded random seed {random_seed} from checkpoint.')
else:
logger.info(f"=> no pipeline checkpoint helper file found at '{checkpoint_file}'. Creating a new file.")
if random_seed is None:
random_seed = random_non_negative_int()
logger.warning(f'No random seed is specified. Setting to {random_seed}.')
save_pipeline_checkpoint_helper_file(path=pipeline_checkpoint_helper_file, random_seed=random_seed)

set_random_seed(random_seed)

result_tracker_cls: Type[ResultTracker] = get_result_tracker_cls(result_tracker)
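For orientation, a minimal usage sketch that is not part of this diff (the dataset, model, and file path are placeholder assumptions): resuming a pipeline run works by passing the same checkpoint_file through training_kwargs, so the '<checkpoint_file>_pipeline_helper_file' written above can restore the stored random seed on the next call.

from pykeen.pipeline import pipeline

# The first call creates 'my_checkpoint.pt' plus 'my_checkpoint.pt_pipeline_helper_file'.
# Re-running the identical call after an interruption reloads the stored random seed
# and resumes training from the last checkpointed epoch.
result = pipeline(
    dataset='Nations',  # placeholder dataset
    model='TransE',  # placeholder model
    training_kwargs=dict(
        num_epochs=1000,
        checkpoint_file='my_checkpoint.pt',  # placeholder path
    ),
)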
23 changes: 23 additions & 0 deletions src/pykeen/stoppers/early_stopping.py
@@ -191,3 +191,26 @@ def get_summary_dict(self) -> Mapping[str, Any]:
best_epoch=self.best_epoch,
best_metric=self.best_metric,
)

def _write_from_summary_dict(
self,
frequency: int,
patience: int,
relative_delta: float,
metric: str,
larger_is_better: bool,
results: List[float],
stopped: bool,
best_epoch: int,
best_metric: float,
) -> None:
"""Write attributes to stopper from a summary dict."""
self.frequency = frequency
self.patience = patience
self.relative_delta = relative_delta
self.metric = metric
self.larger_is_better = larger_is_better
self.results = results
self.stopped = stopped
self.best_epoch = best_epoch
self.best_metric = best_metric
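A minimal sketch of the intended round trip (assuming an existing EarlyStopper instance named stopper): the summary dict that the training loop stores in a checkpoint is written back onto the stopper when training resumes, mirroring the call to _write_from_summary_dict in TrainingLoop.train() further below.

summary = stopper.get_summary_dict()  # serialized into the checkpoint under 'stopper_dict'
# ... checkpoint is written, training is interrupted, a new run starts ...
stopper._write_from_summary_dict(**summary)  # restores frequency, patience, results, stopped, ...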
10 changes: 10 additions & 0 deletions src/pykeen/stoppers/stopper.py
@@ -3,6 +3,7 @@
"""Basic stoppers."""

from abc import ABC, abstractmethod
from typing import Any, Mapping

__all__ = [
'Stopper',
@@ -25,6 +26,11 @@ def should_stop(self, epoch: int) -> bool:
"""Validate on validation set and check for termination condition."""
raise NotImplementedError

@abstractmethod
def get_summary_dict(self) -> Mapping[str, Any]:
"""Get a summary dict."""
raise NotImplementedError


class NopStopper(Stopper):
"""A stopper that does nothing."""
@@ -36,3 +42,7 @@ def should_evaluate(self, epoch: int) -> bool:
def should_stop(self, epoch: int) -> bool:
"""Return false; should never stop."""
return False

def get_summary_dict(self) -> Mapping[str, Any]:
"""Return empty mapping, doesn't have any attributes."""
return dict()
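Since get_summary_dict is now abstract on Stopper, any custom stopper defined outside this PR must implement it as well; a hypothetical minimal subclass could look like this (class name and import path are assumptions).

from typing import Any, Mapping

from pykeen.stoppers import Stopper  # assumed import path


class EveryTenEpochsStopper(Stopper):
    """Hypothetical stopper that evaluates every ten epochs and never stops."""

    def should_evaluate(self, epoch: int) -> bool:
        return epoch % 10 == 0

    def should_stop(self, epoch: int) -> bool:
        return False

    def get_summary_dict(self) -> Mapping[str, Any]:
        # Nothing needs to be persisted in a checkpoint for this stopper.
        return dict()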
162 changes: 144 additions & 18 deletions src/pykeen/training/training_loop.py
@@ -4,7 +4,10 @@

import gc
import logging
import os
import time
from abc import ABC, abstractmethod
from hashlib import md5
from typing import Any, List, Mapping, Optional, Tuple, Type, Union

import torch
@@ -103,6 +106,10 @@ def __init__(
else:
self._loss_helper = self._label_loss_helper

# The internal epoch state tracks the last finished epoch of the training loop to allow for
# seamless loading and saving of training checkpoints
self._epoch = 0

@classmethod
def get_normalized_name(cls) -> str:
"""Get the normalized name of the training loop."""
@@ -118,6 +125,14 @@ def device(self):  # noqa: D401
"""The device used by the model."""
return self.model.device

@property
def checksum(self) -> str: # noqa: D401
"""The checksum of the model and optimizer the training loop was configured with."""
h = md5()
h.update(str(self.model).encode('utf-8'))
h.update(str(self.optimizer).encode('utf-8'))
return h.hexdigest()

def train(
self,
num_epochs: int = 1,
@@ -135,6 +150,8 @@ def train(
sub_batch_size: Optional[int] = None,
num_workers: Optional[int] = None,
clear_optimizer: bool = False,
checkpoint_file: Optional[str] = None,
checkpoint_frequency: Optional[int] = None,
) -> List[float]:
"""Train the KGE model.

@@ -168,6 +185,11 @@
:param clear_optimizer:
Whether to delete the optimizer instance after training (as the optimizer might have additional memory
consumption due to e.g. moments in Adam).
:param checkpoint_file:
The filename for saving checkpoints. If the given filename exists already, that file will be loaded and used
to continue training.
:param checkpoint_frequency:
The frequency of saving checkpoints in minutes.

:return:
The losses per epoch.
@@ -179,22 +201,39 @@
# In some cases, e.g. using Optuna for HPO, the cuda cache from a previous run is not cleared
torch.cuda.empty_cache()

result = self._train(
num_epochs=num_epochs,
batch_size=batch_size,
slice_size=slice_size,
label_smoothing=label_smoothing,
sampler=sampler,
continue_training=continue_training,
only_size_probing=only_size_probing,
use_tqdm=use_tqdm,
use_tqdm_batch=use_tqdm_batch,
tqdm_kwargs=tqdm_kwargs,
stopper=stopper,
result_tracker=result_tracker,
sub_batch_size=sub_batch_size,
num_workers=num_workers,
)
# If a checkpoint file is given, check whether it already exists and, if so, load it
if checkpoint_file:
if os.path.isfile(checkpoint_file):
stopper_dict = self._load_state(path=checkpoint_file)
# If the stopper dict has any keys, those are written back to the stopper
if stopper_dict:
stopper._write_from_summary_dict(**stopper_dict)
continue_training = True
else:
logger.info(f"=> no checkpoint found at '{checkpoint_file}'. Creating a new file.")

# If the stopper loaded from the training loop checkpoint stopped the training, we return those results
if getattr(stopper, 'stopped', False):
result = self.losses_per_epochs
else:
result = self._train(
num_epochs=num_epochs,
batch_size=batch_size,
slice_size=slice_size,
label_smoothing=label_smoothing,
sampler=sampler,
continue_training=continue_training,
only_size_probing=only_size_probing,
use_tqdm=use_tqdm,
use_tqdm_batch=use_tqdm_batch,
tqdm_kwargs=tqdm_kwargs,
stopper=stopper,
result_tracker=result_tracker,
sub_batch_size=sub_batch_size,
num_workers=num_workers,
checkpoint_file=checkpoint_file,
checkpoint_frequency=checkpoint_frequency,
)

# Ensure the release of memory
torch.cuda.empty_cache()
@@ -221,6 +260,8 @@ def _train(  # noqa: C901
result_tracker: Optional[ResultTracker] = None,
sub_batch_size: Optional[int] = None,
num_workers: Optional[int] = None,
checkpoint_file: Optional[str] = None,
checkpoint_frequency: Optional[int] = None,
) -> List[float]:
"""Train the KGE model.

@@ -255,6 +296,10 @@
If provided split each batch into sub-batches to avoid memory issues for large models / small GPUs.
:param num_workers:
The number of child CPU workers used for loading data. If None, data are loaded in the main process.
:param checkpoint_file:
The filename for saving checkpoints.
:param checkpoint_frequency:
The frequency of saving checkpoints in minutes. Setting it to 0 will save a checkpoint after every epoch.

:return:
The losses per epoch.
@@ -322,9 +367,11 @@ def _train(  # noqa: C901
_tqdm_kwargs = dict(desc=f'Training epochs on {self.device}', unit='epoch')
if tqdm_kwargs is not None:
_tqdm_kwargs.update(tqdm_kwargs)
epochs = trange(1, 1 + num_epochs, **_tqdm_kwargs)
else:
epochs = trange(self._epoch + 1, 1 + num_epochs, **_tqdm_kwargs, initial=self._epoch, total=num_epochs)
elif only_size_probing:
epochs = range(1, 1 + num_epochs)
else:
epochs = range(self._epoch + 1, 1 + num_epochs)

logger.debug(f'using stopper: {stopper}')

@@ -336,6 +383,11 @@
num_workers=num_workers,
)

# Save the time to track when the saved point was available
last_checkpoint = time.time()
if checkpoint_frequency is None:
checkpoint_frequency = 30

# Training Loop
for epoch in epochs:
# Enforce training mode
@@ -412,8 +464,23 @@
'prev_loss': self.losses_per_epochs[-2] if epoch > 2 else float('nan'),
})

# Track the last successfully finished epoch
self._epoch = epoch

if stopper is not None and stopper.should_evaluate(epoch) and stopper.should_stop(epoch):
# If a checkpoint file is given, we check whether it is time to save a checkpoint
if checkpoint_file:
minutes_since_last_checkpoint = (time.time() - last_checkpoint) // 60
if minutes_since_last_checkpoint >= checkpoint_frequency:
self._save_state(path=checkpoint_file, stopper=stopper)
return self.losses_per_epochs
else:
# If a checkpoint file is given, we check whether it is time to save a checkpoint
if checkpoint_file:
minutes_since_last_checkpoint = (time.time() - last_checkpoint) // 60
if minutes_since_last_checkpoint >= checkpoint_frequency:
self._save_state(path=checkpoint_file, stopper=stopper)
last_checkpoint = time.time()

return self.losses_per_epochs

@@ -673,3 +740,62 @@ def _free_graph_and_cache(self):
self.model.regularizer.reset()
# The cache of the previous run has to be freed to allow accurate memory availability estimates
torch.cuda.empty_cache()

def _save_state(self, path: str, stopper: Optional[Stopper] = None) -> None:
"""Save the state of the training loop.

:param path:
Path of the file in which to store the state.
:param stopper:
An instance of :class:`pykeen.stoppers.EarlyStopper` with settings for checking
if training should stop early.
"""
logger.debug("=> Saving checkpoint.")

if stopper is None:
stopper_dict = dict()
else:
stopper_dict = stopper.get_summary_dict()

torch.save(
{
'epoch': self._epoch,
'loss': self.losses_per_epochs,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'checksum': self.checksum,
'stopper_dict': stopper_dict,
},
path,
)
logger.info(f"=> Saved checkpoint after having finished epoch {self._epoch}.")

def _load_state(self, path: str) -> Mapping[str, Any]:
"""Load the state of the training loop from a checkpoint.

:param path:
Path of the file from which to load the state.

:return:
The summary dict of the stopper at the time of saving the checkpoint.

:raises FileExistsError:
If the given checkpoint file has a non-matching checksum, i.e. it was saved with a different configuration.
"""
logger.info(f"=> loading checkpoint '{path}'")
checkpoint = torch.load(path)
loaded_checksum = checkpoint['checksum']
if loaded_checksum == self.checksum:
self._epoch = checkpoint['epoch']
self.losses_per_epochs = checkpoint['loss']
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
stopper_dict = checkpoint['stopper_dict']
logger.info(f"=> loaded checkpoint '{path}' stopped after having finished epoch {checkpoint['epoch']}")
else:
raise FileExistsError(
f"The checkpoint file '{path}' that was provided already exists, but seems to be "
"from a different training loop setup.",
)

return stopper_dict
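Finally, a hedged sketch of how the new arguments are meant to be used directly on a training loop instance (the training_loop and early_stopper objects and the file path are assumptions; this mirrors what the pipeline does internally): if checkpoint_file already exists, _load_state restores the model, optimizer, loss history, and stopper state, and training continues from self._epoch + 1.

# Assumed pre-existing objects: `training_loop` is an instantiated TrainingLoop
# (e.g. an sLCWA training loop) and `early_stopper` is an optional EarlyStopper.
losses = training_loop.train(
    num_epochs=1000,
    stopper=early_stopper,
    checkpoint_file='checkpoint.pt',  # loaded automatically if the file already exists
    checkpoint_frequency=30,  # save at most every 30 minutes (also the default when unset)
)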