To add SequentialLR to PyTorch Core Schedulers (#64037)
Summary:
Partially resolves pytorch/vision#4281

In this PR we propose a new scheduler, SequentialLR, which enables a list of different schedulers to be called during different periods of the training process.

The main motivation for this scheduler is the recently gained popularity of a warm-up phase at training time. It has been shown that taking small steps in the early stages of training can speed up convergence.

With SequentialLR, we mainly enable a small constant (or linearly increasing) learning rate phase followed by the actual target learning rate scheduler.

```python
scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
scheduler2 = ExponentialLR(optimizer, gamma=0.9)
scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[5])

for epoch in range(100):
    train(...)
    validate(...)
    scheduler.step()
```

This code snippet calls `ConstantLR` for the first 5 epochs and follows with `ExponentialLR` for the remaining epochs.
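
For the "linearly increasing" warm-up variant mentioned above, a minimal sketch (assuming the same placeholder `optimizer`, `train`, and `validate` names as in the snippet) could pair `LinearLR` with `CosineAnnealingLR` as the target scheduler:

```python
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

scheduler1 = LinearLR(optimizer, start_factor=0.1, total_iters=5)  # ramp lr from 10% to 100% of its base value
scheduler2 = CosineAnnealingLR(optimizer, T_max=95)                # anneal over the remaining epochs
scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[5])

for epoch in range(100):
    train(...)
    validate(...)
    scheduler.step()
```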

This scheduler can be used to run any group of schedulers one after another. The main consideration is that every time we switch to a new scheduler, we assume it starts from the beginning, i.e., the zeroth epoch.
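
As a standalone illustration of this switching rule, the following sketch mirrors the milestone lookup used in `SequentialLR.step()` (see the diff below): `bisect_right` picks the active scheduler, and an epoch that lands exactly on a milestone is where the next scheduler is restarted from its zeroth epoch.

```python
from bisect import bisect_right

milestones = [3, 6]  # e.g. three schedulers covering epochs 0-2, 3-5, and 6+
for epoch in range(9):
    idx = bisect_right(milestones, epoch)               # index of the scheduler in charge this epoch
    restart = idx > 0 and milestones[idx - 1] == epoch  # first epoch under a newly activated scheduler
    print(f"epoch {epoch}: scheduler {idx}" + ("  (restarted at its epoch 0)" if restart else ""))
```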

We also add ChainedScheduler to the `optim.rst` and `lr_scheduler.pyi` files here.
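
For contrast, `ChainedScheduler` does not switch at milestones; it applies every wrapped scheduler's step at every epoch, so their effects compose. A minimal sketch, again reusing the `optimizer` placeholder:

```python
from torch.optim.lr_scheduler import ChainedScheduler, ConstantLR, ExponentialLR

scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
scheduler2 = ExponentialLR(optimizer, gamma=0.9)
scheduler = ChainedScheduler([scheduler1, scheduler2])  # every wrapped step() runs each epoch
```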

Pull Request resolved: #64037

Reviewed By: albanD

Differential Revision: D30841099

Pulled By: iramazanli

fbshipit-source-id: 94f7d352066ee108eef8cda5f0dcb07f4d371751
iramazanli authored and facebook-github-bot committed Sep 9, 2021
1 parent c3203ef commit 2b41bf4
Showing 4 changed files with 95 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/source/optim.rst
@@ -214,6 +214,8 @@ algorithms.
lr_scheduler.LinearLR
lr_scheduler.ExponentialLR
lr_scheduler.CosineAnnealingLR
lr_scheduler.ChainedScheduler
lr_scheduler.SequentialLR
lr_scheduler.ReduceLROnPlateau
lr_scheduler.CyclicLR
lr_scheduler.OneCycleLR
37 changes: 36 additions & 1 deletion test/test_optim.py
@@ -11,7 +11,7 @@
from torch.optim import SGD
from torch.autograd import Variable
from torch import sparse
from torch.optim.lr_scheduler import LambdaLR, MultiplicativeLR, StepLR, \
from torch.optim.lr_scheduler import LambdaLR, MultiplicativeLR, SequentialLR, StepLR, \
MultiStepLR, ConstantLR, LinearLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, \
_LRScheduler, CyclicLR, CosineAnnealingWarmRestarts, OneCycleLR, ChainedScheduler
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
@@ -1255,6 +1255,41 @@ def test_reduce_lr_on_plateau8(self):
threshold=0.1, patience=5, cooldown=5)
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

def test_sequentiallr1(self):
epochs = 19
schedulers = [None] * 2
targets = [[0.05, 0.04, 0.032] + [0.05 for x in range(4)]
+ [0.05 * 0.1 for x in range(4)]
+ [0.05 * 0.01 for x in range(4)]
+ [0.05 * 0.001 for x in range(4)]]
milestones = [3]
schedulers[0] = ExponentialLR(self.opt, gamma=0.8)
schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=4)
scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
self._test(scheduler, targets, epochs)

def test_sequentiallr2(self):
epochs = 13
schedulers = [None] * 2
targets = [[0.005, 0.005, 0.005] + [0.05 * 0.9 ** x for x in range(10)]]
milestones = [3]
schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3)
schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
self._test(scheduler, targets, epochs)

def test_sequentiallr3(self):
epochs = 12
schedulers = [None] * 3
targets = [[0.005, 0.005, 0.005] + [0.05, 0.04, 0.032]
+ [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005]]
milestones = [3, 6]
schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3)
schedulers[1] = ExponentialLR(self.opt, gamma=0.8)
schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=2)
scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
self._test(scheduler, targets, epochs)

def test_chained_lr1(self):
epochs = 10
schedulers = [None] * 1
51 changes: 51 additions & 0 deletions torch/optim/lr_scheduler.py
@@ -582,6 +582,57 @@ def _get_closed_form_lr(self):
for base_lr in self.base_lrs]


class SequentialLR(_LRScheduler):
"""Receives the list of schedulers that is expected to be called sequentially during
optimization process and milestone points that provides exact intervals to reflect
which scheduler is supposed to be called at a given epoch.
Args:
schedulers (list): List of chained schedulers.
milestones (list): List of integers that reflects milestone points.
Example:
>>> # Assuming optimizer uses lr = 1. for all groups
>>> # lr = 0.1 if epoch == 0
>>> # lr = 0.1 if epoch == 1
>>> # lr = 0.9 if epoch == 2
>>> # lr = 0.81 if epoch == 3
>>> # lr = 0.729 if epoch == 4
>>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
>>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
>>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2])
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
"""

def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False):
for scheduler_idx in range(1, len(schedulers)):
if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
raise ValueError(
"Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
"got schedulers at index {} and {} to be different".format(0, scheduler_idx)
)
if (len(milestones) != len(schedulers) - 1):
raise ValueError(
"Sequential Schedulers expects number of schedulers provided to be one more "
"than the number of milestone points, but got number of schedulers {} and the "
"number of milestones to be equal to {}".format(len(schedulers), len(milestones))
)
self._schedulers = schedulers
self._milestones = milestones
self.last_epoch = last_epoch + 1

def step(self):
self.last_epoch += 1
idx = bisect_right(self._milestones, self.last_epoch)
if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
self._schedulers[idx].step(0)
else:
self._schedulers[idx].step()


class CosineAnnealingLR(_LRScheduler):
r"""Set the learning rate of each parameter group using a cosine annealing
schedule, where :math:`\eta_{max}` is set to the initial lr and
6 changes: 6 additions & 0 deletions torch/optim/lr_scheduler.pyi
@@ -27,6 +27,12 @@ class LinearLR(_LRScheduler):
class ExponentialLR(_LRScheduler):
def __init__(self, optimizer: Optimizer, gamma: float, last_epoch: int=...) -> None: ...

class ChainedScheduler(_LRScheduler):
def __init__(self, schedulers: List[_LRScheduler]) -> None: ...

class SequentialLR(_LRScheduler):
def __init__(self, schedulers: List[_LRScheduler], milestones: List[int], last_epoch: int=...) -> None: ...

class CosineAnnealingLR(_LRScheduler):
def __init__(self, optimizer: Optimizer, T_max: int, eta_min: float=..., last_epoch: int=...) -> None: ...

