Fix the chainability at epoch zero for some schedulers (#63457)
Summary:
As discussed in #60836 (comment), we have observed an obstacle when chaining some types of learning rate schedulers. In particular, we observed:

* Some of the learning rate schedulers return their initial learning rates at epoch 0 as
```
       return self.base_lrs
```

* This can be a problem when two schedulers are chained and stepped as

```
     scheduler1.step()
     scheduler2.step()
```

In that case, the effect of scheduler1 at epoch 0 is completely ignored. This would not be an issue if scheduler1 had no effect at epoch 0, as is the case for many schedulers; however, for schedulers such as warm-up schedulers, whose multiplicative factor at epoch 0 is smaller than 1, this can lead to undesired behavior.

The following code snippet illustrates the problem:

## Reproducing the bug

```python
import torch
from torch.nn import Parameter
from torch.optim import SGD
from torch.optim.lr_scheduler import WarmUpLR, ExponentialLR

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 1.0)
scheduler1 = WarmUpLR(optimizer, warmup_factor=0.1, warmup_iters=5, warmup_method="constant")
scheduler2 = ExponentialLR(optimizer, gamma=0.9)

for epoch in range(10):
    print(epoch, scheduler2.get_last_lr()[0])
    optimizer.step()
    scheduler1.step()
    scheduler2.step()
```

### Current Result

```
0 1.0
1 0.9
2 0.81
3 0.7290000000000001
4 0.6561000000000001
5 5.904900000000001
6 5.314410000000001
7 4.782969000000001
8 4.304672100000001
9 3.874204890000001
```

### Expected Result

```
0 1.0
1 0.9
2 0.81
3 0.7290000000000001
4 0.6561000000000001
5 0.5904900000000001
6 0.5314410000000001
7 0.4782969000000001
8 0.4304672100000001
9 0.3874204890000001
```
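
The fix changes how the affected schedulers compute their value at epoch 0: instead of returning `self.base_lrs`, they read the learning rate currently stored in the optimizer's param groups, so any factor that an earlier scheduler in the chain has already applied is preserved. Below is a minimal sketch of the idea using a simplified toy scheduler (a hypothetical class, not the actual `torch.optim.lr_scheduler` code; see the diff below for the real change):

```python
# Minimal sketch of the chainability fix with a simplified, hypothetical
# scheduler; the real change lives in torch/optim/lr_scheduler.py.
class ToyExponentialLR:
    def __init__(self, optimizer, gamma):
        self.optimizer = optimizer
        self.gamma = gamma
        self.last_epoch = -1
        self.base_lrs = [group['lr'] for group in optimizer.param_groups]

    def get_lr(self):
        if self.last_epoch == 0:
            # Before the fix: `return self.base_lrs`, which silently discards
            # any factor an earlier chained scheduler applied at epoch 0.
            return [group['lr'] for group in self.optimizer.param_groups]
        return [group['lr'] * self.gamma for group in self.optimizer.param_groups]

    def step(self):
        self.last_epoch += 1
        for group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            group['lr'] = lr
```

With this change, scheduler2 no longer overwrites the warm-up factor that scheduler1 wrote into the param groups at epoch 0, and the chained run above produces the expected sequence.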

Pull Request resolved: #63457

Reviewed By: datumbox

Differential Revision: D30424160

Pulled By: iramazanli

fbshipit-source-id: 3e15af8d278c872cd6f53406b55f4d3ce5002867
iramazanli authored and facebook-github-bot committed Aug 19, 2021
1 parent 2d5b19f commit e7c4988
Showing 2 changed files with 9 additions and 9 deletions.
test/test_optim.py (6 additions, 6 deletions)
```diff
@@ -440,8 +440,8 @@ def test_adam(self):
         )
         self._test_basic_cases(
             lambda weight, bias: optimizer([weight, bias], lr=1e-3, amsgrad=True),
-            [lambda opt: ExponentialLR(opt, gamma=0.9),
-             lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="constant")]
+            [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="constant"),
+             lambda opt: ExponentialLR(opt, gamma=0.9)]
         )
         self._test_basic_cases(
             lambda weight, bias: optimizer([weight, bias], lr=1e-3, amsgrad=True),
@@ -1294,8 +1294,8 @@ def test_compound_exp_and_linear_warmup_lr(self):
         for i in range(iters):
             single_targets[i] *= factor + i / iters * (1 - factor)
         targets = [single_targets, [x * epochs for x in single_targets]]
-        schedulers[0] = ExponentialLR(self.opt, gamma=0.9)
-        schedulers[1] = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear")
+        schedulers[0] = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear")
+        schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
         self._test(schedulers, targets, epochs)
 
     def test_compound_step_and_constant_warmup(self):
@@ -1361,8 +1361,8 @@ def test_compound_cosanneal_and_linear_warmup_lr(self):
         for i in range(iters):
             single_targets[i] *= factor + i / iters * (1 - factor)
         targets = [single_targets, [x * epochs for x in single_targets]]
-        schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
-        schedulers[1] = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear")
+        schedulers[0] = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear")
+        schedulers[1] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
         self._test(schedulers, targets, epochs)
 
     def test_compound_cosanneal_and_exp_lr(self):
```
torch/optim/lr_scheduler.py (3 additions, 3 deletions)
```diff
@@ -328,7 +328,7 @@ def get_lr(self):
             return [group['lr'] * lmbda(self.last_epoch)
                     for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)]
         else:
-            return list(self.base_lrs)
+            return [group['lr'] for group in self.optimizer.param_groups]
 
 
 class StepLR(_LRScheduler):
@@ -526,7 +526,7 @@ def get_lr(self):
                           "please use `get_last_lr()`.", UserWarning)
 
         if self.last_epoch == 0:
-            return self.base_lrs
+            return [group['lr'] for group in self.optimizer.param_groups]
         return [group['lr'] * self.gamma
                 for group in self.optimizer.param_groups]
 
@@ -586,7 +586,7 @@ def get_lr(self):
                           "please use `get_last_lr()`.", UserWarning)
 
         if self.last_epoch == 0:
-            return self.base_lrs
+            return [group['lr'] for group in self.optimizer.param_groups]
         elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
             return [group['lr'] + (base_lr - self.eta_min) *
                     (1 - math.cos(math.pi / self.T_max)) / 2
```
