diff --git a/docs/source/optim.rst b/docs/source/optim.rst
index 695f0a2a03f6..62a293dec5ec 100644
--- a/docs/source/optim.rst
+++ b/docs/source/optim.rst
@@ -214,6 +214,8 @@ algorithms.
     lr_scheduler.LinearLR
     lr_scheduler.ExponentialLR
     lr_scheduler.CosineAnnealingLR
+    lr_scheduler.ChainedScheduler
+    lr_scheduler.SequentialLR
     lr_scheduler.ReduceLROnPlateau
     lr_scheduler.CyclicLR
     lr_scheduler.OneCycleLR
diff --git a/test/test_optim.py b/test/test_optim.py
index d69e9351d33a..2d88d6f4bdab 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -11,7 +11,7 @@
 from torch.optim import SGD
 from torch.autograd import Variable
 from torch import sparse
-from torch.optim.lr_scheduler import LambdaLR, MultiplicativeLR, StepLR, \
+from torch.optim.lr_scheduler import LambdaLR, MultiplicativeLR, SequentialLR, StepLR, \
     MultiStepLR, ConstantLR, LinearLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, \
     _LRScheduler, CyclicLR, CosineAnnealingWarmRestarts, OneCycleLR, ChainedScheduler
 from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
@@ -1255,6 +1255,41 @@ def test_reduce_lr_on_plateau8(self):
                                       threshold=0.1, patience=5, cooldown=5)
         self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
 
+    def test_sequentiallr1(self):
+        epochs = 19
+        schedulers = [None] * 2
+        targets = [[0.05, 0.04, 0.032] + [0.05 for x in range(4)] +
+                   [0.05 * 0.1 for x in range(4)] +
+                   [0.05 * 0.01 for x in range(4)] +
+                   [0.05 * 0.001 for x in range(4)]]
+        milestones = [3]
+        schedulers[0] = ExponentialLR(self.opt, gamma=0.8)
+        schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=4)
+        scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
+        self._test(scheduler, targets, epochs)
+
+    def test_sequentiallr2(self):
+        epochs = 13
+        schedulers = [None] * 2
+        targets = [[0.005, 0.005, 0.005] + [0.05 * 0.9 ** x for x in range(10)]]
+        milestones = [3]
+        schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3)
+        schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
+        scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
+        self._test(scheduler, targets, epochs)
+
+    def test_sequentiallr3(self):
+        epochs = 12
+        schedulers = [None] * 3
+        targets = [[0.005, 0.005, 0.005] + [0.05, 0.04, 0.032] +
+                   [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005]]
+        milestones = [3, 6]
+        schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3)
+        schedulers[1] = ExponentialLR(self.opt, gamma=0.8)
+        schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=2)
+        scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
+        self._test(scheduler, targets, epochs)
+
     def test_chained_lr1(self):
         epochs = 10
         schedulers = [None] * 1
diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py
index 42f7b511c54a..204340187666 100644
--- a/torch/optim/lr_scheduler.py
+++ b/torch/optim/lr_scheduler.py
@@ -582,6 +582,57 @@ def _get_closed_form_lr(self):
                 for base_lr in self.base_lrs]
 
 
+class SequentialLR(_LRScheduler):
+    """Receives the list of schedulers that are expected to be called sequentially during
+    the optimization process and the milestone points that provide the exact intervals
+    reflecting which scheduler is supposed to be called at a given epoch.
+
+    Args:
+        schedulers (list): List of chained schedulers.
+        milestones (list): List of integers that reflect milestone points.
+
+    Example:
+        >>> # Assuming optimizer uses lr = 1. for all groups
+        >>> # lr = 0.1 if epoch == 0
+        >>> # lr = 0.1 if epoch == 1
+        >>> # lr = 0.9 if epoch == 2
+        >>> # lr = 0.81 if epoch == 3
+        >>> # lr = 0.729 if epoch == 4
+        >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
+        >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
+        >>> scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[2])
+        >>> for epoch in range(100):
+        >>>     train(...)
+        >>>     validate(...)
+        >>>     scheduler.step()
+    """
+
+    def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False):
+        for scheduler_idx in range(1, len(schedulers)):
+            if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
+                raise ValueError(
+                    "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
+                    "got schedulers at index {} and {} to be different".format(0, scheduler_idx)
+                )
+        if (len(milestones) != len(schedulers) - 1):
+            raise ValueError(
+                "Sequential Schedulers expects the number of schedulers provided to be one "
+                "more than the number of milestone points, but got {} schedulers and "
+                "{} milestones".format(len(schedulers), len(milestones))
+            )
+        self._schedulers = schedulers
+        self._milestones = milestones
+        self.last_epoch = last_epoch + 1
+
+    def step(self):
+        self.last_epoch += 1
+        idx = bisect_right(self._milestones, self.last_epoch)
+        if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
+            self._schedulers[idx].step(0)
+        else:
+            self._schedulers[idx].step()
+
+
 class CosineAnnealingLR(_LRScheduler):
     r"""Set the learning rate of each parameter group using a cosine annealing
     schedule, where :math:`\eta_{max}` is set to the initial lr and
diff --git a/torch/optim/lr_scheduler.pyi b/torch/optim/lr_scheduler.pyi
index 9b1b8ea63eed..9552e8e248b1 100644
--- a/torch/optim/lr_scheduler.pyi
+++ b/torch/optim/lr_scheduler.pyi
@@ -27,6 +27,12 @@ class LinearLR(_LRScheduler):
 class ExponentialLR(_LRScheduler):
     def __init__(self, optimizer: Optimizer, gamma: float, last_epoch: int=...) -> None: ...
 
+class ChainedScheduler(_LRScheduler):
+    def __init__(self, schedulers: List[_LRScheduler]) -> None: ...
+
+class SequentialLR(_LRScheduler):
+    def __init__(self, optimizer: Optimizer, schedulers: List[_LRScheduler], milestones: List[int], last_epoch: int=...) -> None: ...
+
 class CosineAnnealingLR(_LRScheduler):
     def __init__(self, optimizer: Optimizer, T_max: int, eta_min: float=..., last_epoch: int=...) -> None: ...
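
As a quick sanity check of the new API outside the patch itself, the short standalone sketch below mirrors test_sequentiallr2 from the diff: ConstantLR holds the learning rate at 0.005 for the first three epochs, then ExponentialLR takes over at the milestone. The model and optimizer setup (torch.nn.Linear, SGD with lr=0.05) are illustrative assumptions, not part of the patch.

# Hypothetical usage sketch (not part of the patch); mirrors test_sequentiallr2.
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ConstantLR, ExponentialLR, SequentialLR

model = torch.nn.Linear(10, 2)                 # illustrative model
optimizer = SGD(model.parameters(), lr=0.05)   # base lr matches the test setup

# ConstantLR keeps lr at 0.05 * 0.1 until the milestone; from epoch 3 onward
# ExponentialLR decays the base lr by gamma=0.9 each epoch.
scheduler = SequentialLR(
    optimizer,
    schedulers=[
        ConstantLR(optimizer, factor=0.1, total_iters=3),
        ExponentialLR(optimizer, gamma=0.9),
    ],
    milestones=[3],
)

for epoch in range(6):
    # train(...) / validate(...) would run here with the current lr
    print(epoch, [group["lr"] for group in optimizer.param_groups])
    scheduler.step()
# Expected per test_sequentiallr2: 0.005 for epochs 0-2, then 0.05, 0.045, 0.0405, ...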