⌛📋 Fix (Re-)Instantiation of LR-Schedule (#1386)
Fix #1384

Also, do not register an LR scheduler callback if we do not have a
learning rate scheduler, instead of just doing nothing in the callback.

---------

Co-authored-by: Charles Tapley Hoyt <cthoyt@gmail.com>
mberr and cthoyt committed May 17, 2024
1 parent ac85bef commit 98917df
Showing 3 changed files with 39 additions and 9 deletions.
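
For context, a minimal sketch (not from this commit) of the failure mode being fixed: deriving scheduler kwargs from state_dict() alone lets derived attributes leak into the constructor call. The sketch assumes torch's CosineAnnealingWarmRestarts as the scheduler.

import torch
from torch.optim import SGD, lr_scheduler

# dummy optimizer over a single parameter, just to construct a scheduler
optimizer = SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)

# previous derivation: start from state_dict() and drop a few known keys;
# derived attributes such as T_i and T_cur survive the filter ...
old_kwargs = {
    key: value
    for key, value in scheduler.state_dict().items()
    if not key.startswith("_") and key not in ["base_lrs", "last_epoch"]
}
# ... but CosineAnnealingWarmRestarts.__init__ does not accept them:
# scheduler.__class__(optimizer, **old_kwargs)  # raises TypeError (unexpected keyword argument)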
5 changes: 3 additions & 2 deletions src/pykeen/training/callbacks.py
@@ -429,8 +429,9 @@ class LearningRateSchedulerTrainingCallback(TrainingCallback):
 
     # docstr-coverage: inherited
     def post_epoch(self, epoch: int, epoch_loss: float, **kwargs: Any) -> None:  # noqa: D102
-        if self.training_loop.lr_scheduler is not None:
-            self.training_loop.lr_scheduler.step(epoch=epoch)
+        if self.training_loop.lr_scheduler is None:
+            raise ValueError(f"{self} can only be called when a learning rate schedule is used.")
+        self.training_loop.lr_scheduler.step(epoch=epoch)
 
 
 def _hasher(kwargs: Mapping[str, Any]) -> int:
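
In plain terms, the callback now fails fast when there is no scheduler instead of silently doing nothing. A hypothetical standalone mirror of that pattern (not pykeen's actual class):

class SchedulerStepCallback:
    """Hypothetical mirror of the fail-fast pattern above; not pykeen code."""

    def __init__(self, scheduler=None):
        self.scheduler = scheduler

    def post_epoch(self, epoch: int) -> None:
        # raise instead of silently skipping when no schedule is configured
        if self.scheduler is None:
            raise ValueError(f"{self} can only be called when a learning rate schedule is used.")
        self.scheduler.step(epoch=epoch)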
16 changes: 9 additions & 7 deletions src/pykeen/training/training_loop.py
@@ -3,6 +3,7 @@
 """Training loops for KGE models using multi-modal information."""
 
 import gc
+import inspect
 import logging
 import os
 import pathlib
@@ -115,13 +116,13 @@ def _get_optimizer_kwargs(optimizer: Optimizer) -> Mapping[str, Any]:
 
 
 def _get_lr_scheduler_kwargs(lr_scheduler: LRScheduler) -> Mapping[str, Any]:
-    lr_scheduler_kwargs = lr_scheduler.state_dict()
-    lr_scheduler_kwargs = {
+    # note: this seems to be a pretty unsafe method to derive __init__ kwargs...
+    init_parameters = inspect.signature(lr_scheduler.__init__).parameters
+    return {
         key: value
-        for key, value in lr_scheduler_kwargs.items()
-        if not key.startswith("_") and key not in ["base_lrs", "last_epoch"]
+        for key, value in lr_scheduler.state_dict().items()
+        if key not in {"last_epoch", "optimizer"} and key in init_parameters
     }
-    return lr_scheduler_kwargs
 
 
 class TrainingLoop(Generic[SampleType, BatchType], ABC):
@@ -646,7 +647,7 @@ def _train(  # noqa: C901
             if self.lr_scheduler is not None:
                 # Create a new lr scheduler and add the optimizer
                 lr_scheduler_kwargs = _get_lr_scheduler_kwargs(self.lr_scheduler)
-                self.lr_scheduler = self.lr_scheduler.__class__(self.optimizer, **lr_scheduler_kwargs)
+                self.lr_scheduler = self.lr_scheduler.__class__(optimizer=self.optimizer, **lr_scheduler_kwargs)
         elif not self.optimizer.state:
             raise ValueError("Cannot continue_training without being trained once.")
 
@@ -704,7 +705,8 @@
         callback.register_callback(
             OptimizerTrainingCallback(only_size_probing=only_size_probing, pre_step_callbacks=pre_step_callbacks)
         )
-        callback.register_callback(LearningRateSchedulerTrainingCallback())
+        if self.lr_scheduler is not None:
+            callback.register_callback(LearningRateSchedulerTrainingCallback())
 
         # Save the time to track when the saved point was available
         last_checkpoint = time.time()
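
The new derivation keeps only those state_dict() entries that are also __init__ parameters, so re-instantiation with the fresh optimizer succeeds. A minimal sketch of that step, again assuming torch's CosineAnnealingWarmRestarts (not pykeen code):

import inspect

import torch
from torch.optim import SGD, lr_scheduler

optimizer = SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)

# keep only constructor arguments (e.g. T_0, T_mult, eta_min), dropping
# derived state such as T_i, T_cur, base_lrs, and last_epoch
init_parameters = inspect.signature(scheduler.__init__).parameters
new_kwargs = {
    key: value
    for key, value in scheduler.state_dict().items()
    if key not in {"last_epoch", "optimizer"} and key in init_parameters
}
scheduler = scheduler.__class__(optimizer=optimizer, **new_kwargs)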
27 changes: 27 additions & 0 deletions tests/test_training/test_lr_scheduler.py
@@ -0,0 +1,27 @@
+"""Tests for LR schedulers."""
+
+import pytest
+from class_resolver import HintOrType, OptionalKwargs
+from torch.optim import lr_scheduler
+
+from pykeen.pipeline import pipeline
+
+
+@pytest.mark.parametrize(
+    "cls, kwargs",
+    [
+        (None, None),
+        ("CosineAnnealingWarmRestarts", None),
+        ("CosineAnnealingWarmRestarts", {"T_0": 10}),
+    ],
+)
+def test_lr_scheduler(cls: HintOrType[lr_scheduler.LRScheduler], kwargs: OptionalKwargs) -> None:
+    """Smoke-test for training with learning rate schedule."""
+    pipeline(
+        dataset="nations",
+        model="mure",
+        model_kwargs=dict(embedding_dim=2),
+        training_kwargs=dict(num_epochs=1),
+        lr_scheduler=cls,
+        lr_scheduler_kwargs=kwargs,
+    )
