
To add warm-up scheduler to optim #60836


Closed
iramazanli wants to merge 1 commit into master from warmup_scheduler

Conversation

@iramazanli (Contributor) commented Jun 27, 2021

Warm-up of learning rate scheduling was initially discussed by Priya Goyal et al. in the paper https://arxiv.org/pdf/1706.02677.pdf.

In Section 2.2 of that paper they propose warming up the learning rate in order to prevent large variance / noise in the early stages of training. The idea has since been discussed further in the following papers:

  • Akilesh Gotmare et al. https://arxiv.org/abs/1810.13243
  • Bernstein et al. http://proceedings.mlr.press/v80/bernstein18a/bernstein18a.pdf
  • Liyuan Liu et al. https://arxiv.org/pdf/1908.03265.pdf

There are two popularly used types of learning rate warm-up:

  • Constant warm-up (start with a very small constant learning rate)
  • Linear warm-up (start with a small learning rate and gradually increase it)

In this PR we add warm-up as a learning rate scheduler. Note that learning rate schedulers are chainable, which means that we can combine the warm-up scheduler with any other scheduler to build a more sophisticated learning rate schedule.

## Linear Warmup

Linear warm-up multiplies the learning rate by a pre-defined constant, warmup_factor, in the first epoch (epoch 0), and then increases this multiplicative constant linearly until it reaches one after warmup_iters epochs. Hence the multiplicative constant at the i-th step is:

                warmup_factor + (1 - warmup_factor) * i / warmup_iters

Moreover, dividing this quantity at step i by its value at step i-1 (after multiplying numerator and denominator by warmup_iters) gives

       1 + (1 - warmup_factor) / [warmup_iters * warmup_factor + (i-1) * (1 - warmup_factor)]

which is the per-step factor used in the get_lr() method of our implementation. Below we provide an example of how to use the linear warm-up scheduler and the schedule it produces.

```python
import torch
from torch.nn import Parameter
from torch.optim import SGD
from torch.optim.lr_scheduler import WarmUpLR

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler = WarmUpLR(optimizer, warmup_factor=0.1, warmup_iters=10, warmup_method="linear")

for epoch in range(15):
    print(epoch, scheduler.get_last_lr()[0])
    optimizer.step()
    scheduler.step()
```
```
0 0.010000000000000002
1 0.019000000000000003
2 0.028000000000000008
3 0.03700000000000001
4 0.04600000000000001
5 0.055000000000000014
6 0.06400000000000002
7 0.07300000000000002
8 0.08200000000000003
9 0.09100000000000004
10 0.10000000000000005
11 0.10000000000000005
12 0.10000000000000005
13 0.10000000000000005
14 0.10000000000000005
```
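As a sanity check, these values match the closed-form multiplier above; here is a minimal sketch recomputing the schedule directly (assuming the same base lr of 0.1, warmup_factor=0.1, warmup_iters=10):

```python
# Recompute the linear warm-up schedule from the closed-form multiplier
# warmup_factor + (1 - warmup_factor) * i / warmup_iters, capped at 1.0
# once i >= warmup_iters.
base_lr = 0.1
warmup_factor = 0.1
warmup_iters = 10

for i in range(15):
    factor = min(1.0, warmup_factor + (1 - warmup_factor) * i / warmup_iters)
    print(i, base_lr * factor)  # e.g. 0 -> 0.01, 5 -> 0.055, 10+ -> 0.1
```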

## Constant Warmup

Constant warm-up is based on a straightforward idea: multiply the learning rate by warmup_factor until epoch warmup_iters is reached, then leave it unchanged for the following epochs.

```python
import torch
from torch.nn import Parameter
from torch.optim import SGD
from torch.optim.lr_scheduler import WarmUpLR

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler = WarmUpLR(optimizer, warmup_factor=0.1, warmup_iters=5, warmup_method="constant")

for epoch in range(10):
    print(epoch, scheduler.get_last_lr()[0])
    optimizer.step()
    scheduler.step()
```
```
0 0.010000000000000002
1 0.010000000000000002
2 0.010000000000000002
3 0.010000000000000002
4 0.010000000000000002
5 0.10000000000000002
6 0.10000000000000002
7 0.10000000000000002
8 0.10000000000000002
9 0.10000000000000002
```
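Equivalently, a minimal sketch of what constant warm-up computes (same assumptions: base lr 0.1, warmup_factor=0.1, warmup_iters=5):

```python
# Constant warm-up: scale the base lr by warmup_factor for the first
# warmup_iters epochs, then use the base lr unchanged afterwards.
base_lr = 0.1
warmup_factor = 0.1
warmup_iters = 5

for i in range(10):
    lr = base_lr * warmup_factor if i < warmup_iters else base_lr
    print(i, lr)  # epochs 0-4 -> 0.01, epochs 5-9 -> 0.1
```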

@facebook-github-bot (Contributor) commented Jun 27, 2021

💊 CI failures summary and remediations

As of commit 6d4aade (more details on the Dr. CI page):


  • 2/2 failures possibly* introduced in this PR
    • 1/2 non-scanned failure(s)

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See GitHub Actions build win-vs2019-cuda10.1-py3 / test (default, 1, 1, windows.8xlarge.nvidia.gpu) (1/1)

Step: "Install Cuda" (full log | diagnosis details | 🔁 rerun)

2021-08-15T18:50:39.4479697Z ls: cannot access ...x64/nvToolsExt64_1.dll': No such file or directory
2021-08-15T18:47:47.4747508Z Compressed: 2565595880
2021-08-15T18:47:47.4880715Z + cd cuda_10.1.243_426.00_win10
2021-08-15T18:47:47.4884803Z + mkdir cuda_install_logs
2021-08-15T18:47:47.5313268Z + set +e
2021-08-15T18:47:47.5399061Z ++ pwd -W
2021-08-15T18:47:47.5409187Z + ./setup.exe -s nvcc_10.1 cuobjdump_10.1 nvprune_10.1 cupti_10.1 cublas_10.1 cublas_dev_10.1 cudart_10.1 cufft_10.1 cufft_dev_10.1 curand_10.1 curand_dev_10.1 cusolver_10.1 cusolver_dev_10.1 cusparse_10.1 cusparse_dev_10.1 nvgraph_10.1 nvgraph_dev_10.1 npp_10.1 npp_dev_10.1 nvrtc_10.1 nvrtc_dev_10.1 nvml_dev_10.1 -loglevel:6 -log:C:/actions-runner/_work/pytorch/pytorch/pytorch-1133061401/cuda_10.1.243_426.00_win10/cuda_install_logs
2021-08-15T18:50:39.1312275Z + set -e
2021-08-15T18:50:39.1313027Z + [[ 2019 == \2\0\1\7 ]]
2021-08-15T18:50:39.1321473Z + cp -r 'CUDAVisualStudioIntegration/extras/visual_studio_integration/MSBuildExtensions/CUDA 10.1.props' 'CUDAVisualStudioIntegration/extras/visual_studio_integration/MSBuildExtensions/CUDA 10.1.targets' 'CUDAVisualStudioIntegration/extras/visual_studio_integration/MSBuildExtensions/CUDA 10.1.xml' CUDAVisualStudioIntegration/extras/visual_studio_integration/MSBuildExtensions/Nvda.Build.CudaTasks.v10.1.dll 'C:/Program Files (x86)/Microsoft Visual Studio/2019/BuildTools/MSBuild/Microsoft/VC/v160/BuildCustomizations/'
2021-08-15T18:50:39.4005847Z + ls '/c/Program Files/NVIDIA Corporation/NvToolsExt/bin/x64/nvToolsExt64_1.dll'
2021-08-15T18:50:39.4479697Z ls: cannot access '/c/Program Files/NVIDIA Corporation/NvToolsExt/bin/x64/nvToolsExt64_1.dll': No such file or directory
2021-08-15T18:50:39.4484743Z + curl --retry 3 -kLO https://ossci-windows.s3.amazonaws.com/NvToolsExt.7z
```

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI.

@iramazanli force-pushed the warmup_scheduler branch 7 times, most recently from beeecf8 to 3463ced on June 28, 2021 13:14
@codecov bot commented Jun 28, 2021

Codecov Report

Merging #60836 (44afe75) into master (045c4cb) will increase coverage by 15.72%.
The diff coverage is 90.00%.

❗ Current head 44afe75 differs from pull request most recent head 6d4aade. Consider uploading reports for the commit 6d4aade to get more accurate results

```
@@             Coverage Diff             @@
##           master   #60836       +/-   ##
===========================================
+ Coverage   60.53%   76.26%   +15.72%     
===========================================
  Files         684     2062     +1378     
  Lines       88652   205622   +116970     
===========================================
+ Hits        53666   156812   +103146     
- Misses      34986    48810    +13824     
```

@iramazanli force-pushed the warmup_scheduler branch 20 times, most recently from 88df5ed to fa80a5b on June 29, 2021 09:53
@iramazanli force-pushed the warmup_scheduler branch 6 times, most recently from 94423e5 to fd66891 on June 29, 2021 23:24
@iramazanli requested review from fmassa and datumbox on June 29, 2021 23:39
@iramazanli force-pushed the warmup_scheduler branch 2 times, most recently from 825595a to 44afe75 on June 30, 2021 00:55
@facebook-github-bot (Contributor):
@iramazanli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@datumbox (Contributor) commented Jul 5, 2021

@iramazanli Thanks for the PR!

> Note that learning rates are chainable, which means that we can merge warmup scheduler with any other learning rate scheduler to make more sophisticated learning rate scheduler.

Do you think it's worth providing an example of how this is done in the documentation?

cc @fmassa any thoughts on this?

@iramazanli (Contributor, Author) commented:

> @iramazanli Thanks for the PR!
>
> > Note that learning rates are chainable, which means that we can merge warmup scheduler with any other learning rate scheduler to make more sophisticated learning rate scheduler.
>
> Do you think it's worth providing an example of how this is done in the documentation?
>
> cc @fmassa any thoughts on this?

Actually, here https://pytorch.org/docs/stable/optim.html we have an example for the general case of chaining.
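For reference, a minimal chaining sketch in the style of those docs, combining the warm-up scheduler from this PR with an exponential decay (illustrative values only; note the epoch-0 chaining caveat discussed later in this thread):

```python
import torch
from torch.nn import Parameter
from torch.optim import SGD
from torch.optim.lr_scheduler import WarmUpLR, ExponentialLR

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)

# Chainable schedulers: each step() folds its factor into the current lr,
# so warm-up and decay compose when both are stepped every epoch.
warmup = WarmUpLR(optimizer, warmup_factor=0.1, warmup_iters=5, warmup_method="linear")
decay = ExponentialLR(optimizer, gamma=0.9)

for epoch in range(10):
    optimizer.step()
    warmup.step()
    decay.step()
```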

@fmassa (Member) left a comment:

LGTM, thanks a lot Ilqar (and sorry for the delay in reviewing the PR).

As a side comment and for a follow-up PR, I think it would make it much easier for users to use WarmUpLR if we had an LRScheduler that combines multiple schedulers together.
This way, users wouldn't need to change their training loop to add warm-up.

Something along the lines of

```python
scheduler = CombinedScheduler([scheduler1, scheduler2])
```

so that instead of doing

```python
scheduler1.step()
scheduler2.step()
```

the user just needs to do

```python
scheduler.step()
```

Also, @datumbox once this PR is merged, could you update the torchvision reference to use PyTorch's WarmUpLR?

@facebook-github-bot (Contributor):
@iramazanli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@iramazanli (Contributor, Author) commented:

> LGTM, thanks a lot Ilqar (and sorry for the delay in reviewing the PR).
>
> As a side comment and for a follow-up PR, I think it would make it much easier for users to use WarmUpLR if we had an LRScheduler that combines multiple schedulers together.
> This way, users wouldn't need to change their training loop to add warm-up.
>
> Something along the lines of
>
> `scheduler = CombinedScheduler([scheduler1, scheduler2])`
>
> so that instead of doing
>
> `scheduler1.step()`
> `scheduler2.step()`
>
> the user just needs to do
>
> `scheduler.step()`
>
> Also, @datumbox once this PR is merged, could you update the torchvision reference to use PyTorch's WarmUpLR?

I completely agree with the idea of a combined scheduler. Thanks for the suggestion!

There will be a follow-up PR regarding a bug, discussed offline, with chaining schedulers in the general case. After that bug fix, using the warm-up scheduler in a chain will be safer.
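For illustration, a minimal sketch of what such a combined scheduler could look like (a hypothetical CombinedScheduler, not part of this PR; state_dict handling and validation omitted):

```python
class CombinedScheduler:
    """Hypothetical wrapper that steps several chainable schedulers at once."""

    def __init__(self, schedulers):
        self.schedulers = list(schedulers)

    def step(self):
        # Chainable schedulers each multiply their own factor into the
        # optimizer's current lr, so stepping them in order composes them.
        for scheduler in self.schedulers:
            scheduler.step()

# Usage: scheduler = CombinedScheduler([warmup, decay]); scheduler.step()
```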

@facebook-github-bot (Contributor):
@iramazanli merged this pull request in cec08e7.

alanwaketan pushed a commit that referenced this pull request Aug 17, 2021
Summary:
Warm-up of learning rate scheduling was initially discussed by Priya Goyal et al. in the paper https://arxiv.org/pdf/1706.02677.pdf.

In Section 2.2 of that paper they propose warming up the learning rate in order to prevent large variance / noise in the early stages of training. The idea has since been discussed further in the following papers:
  * Akilesh Gotmare et al. https://arxiv.org/abs/1810.13243
  * Bernstein et al  http://proceedings.mlr.press/v80/bernstein18a/bernstein18a.pdf
  * Liyuan Liu et al: https://arxiv.org/pdf/1908.03265.pdf

There are two popularly used types of learning rate warm-up:
  * Constant warm-up (start with a very small constant learning rate)
  * Linear warm-up (start with a small learning rate and gradually increase it)

In this PR we add warm-up as a learning rate scheduler. Note that learning rate schedulers are chainable, which means that we can combine the warm-up scheduler with any other scheduler to build a more sophisticated learning rate schedule.

## Linear Warmup

Linear warm-up multiplies the learning rate by a pre-defined constant, warmup_factor, in the first epoch (epoch 0), and then increases this multiplicative constant linearly until it reaches one after warmup_iters epochs. Hence the multiplicative constant at the i-th step is:

                    warmup_factor + (1-warmup_factor) * i /  warmup_iters

Moreover, dividing this quantity at step i by its value at step i-1 (after multiplying numerator and denominator by warmup_iters) gives

           1 + (1.0 - warmup_factor) / [warmup_iters*warmup_factor+(i-1)*(1-warmup_factor)]

which is the per-step factor used in the get_lr() method of our implementation. Below we provide an example of how to use the linear warm-up scheduler and the schedule it produces.

```python
import torch
from torch.nn import Parameter
from torch.optim import SGD
from torch.optim.lr_scheduler import WarmUpLR

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler = WarmUpLR(optimizer, warmup_factor=0.1, warmup_iters=10, warmup_method="linear")

for epoch in range(15):

    print(epoch, scheduler.get_last_lr()[0])

    optimizer.step()
    scheduler.step()
```

```
0 0.010000000000000002
1 0.019000000000000003
2 0.028000000000000008
3 0.03700000000000001
4 0.04600000000000001
5 0.055000000000000014
6 0.06400000000000002
7 0.07300000000000002
8 0.08200000000000003
9 0.09100000000000004
10 0.10000000000000005
11 0.10000000000000005
12 0.10000000000000005
13 0.10000000000000005
14 0.10000000000000005
```

## Constant Warmup

Constant warm-up is based on a straightforward idea: multiply the learning rate by warmup_factor until epoch warmup_iters is reached, then leave it unchanged for the following epochs.

```python
import torch
from torch.nn import Parameter
from torch.optim import SGD
from torch.optim.lr_scheduler import WarmUpLR

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler = WarmUpLR(optimizer, warmup_factor=0.1, warmup_iters=5, warmup_method="constant")

for epoch in range(10):

    print(epoch, scheduler.get_last_lr()[0])

    optimizer.step()
    scheduler.step()
```

```
0 0.010000000000000002
1 0.010000000000000002
2 0.010000000000000002
3 0.010000000000000002
4 0.010000000000000002
5 0.10000000000000002
6 0.10000000000000002
7 0.10000000000000002
8 0.10000000000000002
9 0.10000000000000002
```

Pull Request resolved: #60836

Reviewed By: saketh-are

Differential Revision: D29537615

Pulled By: iramazanli

fbshipit-source-id: d910946027acc52663b301f9c56ade686e62cb69
facebook-github-bot pushed a commit that referenced this pull request Aug 19, 2021
Summary:
As discussed in #60836 (comment), we have observed an obstacle when chaining certain types of learning rate schedulers. In particular, we observed that

* some of the learning rate schedulers return their initial learning rates at epoch 0 as

```
return self.base_lrs
```

* This can be a problem when two schedulers are chained, as in

```
     scheduler1.step()
     scheduler2.step()
```

in particular, the effect of scheduler1 at epoch 0 is completely ignored. This would not be an issue if scheduler1 were ineffective at epoch 0, as many schedulers are; however, for schedulers such as warm-up schedulers, whose multiplicative value at epoch 0 is smaller than 1, this can lead to undesired behavior.

The following code snippet illustrates the problem:

## Reproducing the bug

```python
import torch
from torch.nn import Parameter
from torch.optim import SGD
from torch.optim.lr_scheduler import WarmUpLR, ExponentialLR

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 1.0)
scheduler1 = WarmUpLR(optimizer, warmup_factor=0.1, warmup_iters=5, warmup_method="constant")
scheduler2 = ExponentialLR(optimizer, gamma=0.9)

for epoch in range(10):
     print(epoch, scheduler2.get_last_lr()[0])
     optimizer.step()
     scheduler1.step()
     scheduler2.step()
```

### Current Result

```
0 1.0
1 0.9
2 0.81
3 0.7290000000000001
4 0.6561000000000001
5 5.904900000000001
6 5.314410000000001
7 4.782969000000001
8 4.304672100000001
9 3.874204890000001
```
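Note that from epoch 5 onward each value is exactly 1/warmup_factor = 10x the expected one (e.g. 5.9049 = 0.59049 * 10): scheduler2's epoch-0 `return self.base_lrs` wipes out scheduler1's initial x0.1 scaling, yet at epoch warmup_iters scheduler1 still multiplies the learning rate by 1/warmup_factor to undo a warm-up that is no longer in effect.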

### Expected Result

```
0 1.0
1 0.9
2 0.81
3 0.7290000000000001
4 0.6561000000000001
5 0.5904900000000001
6 0.5314410000000001
7 0.4782969000000001
8 0.4304672100000001
9 0.3874204890000001
```

Pull Request resolved: #63457

Reviewed By: datumbox

Differential Revision: D30424160

Pulled By: iramazanli

fbshipit-source-id: 3e15af8d278c872cd6f53406b55f4d3ce5002867
facebook-github-bot pushed a commit that referenced this pull request Sep 7, 2021
Summary:
Partially unblocks pytorch/vision#4281

Previously, in PR #60836, we added warm-up schedulers to PyTorch core, with two modes of execution - linear and constant - depending on the warm-up function.

In this PR we change this interface to a more direct form, separating the linear and constant modes into two separate schedulers. In particular,

```Python
scheduler1 = WarmUpLR(optimizer, warmup_factor=0.1, warmup_iters=5, warmup_method="constant")
scheduler2 = WarmUpLR(optimizer, warmup_factor=0.1, warmup_iters=5, warmup_method="linear")
```

will look like

```Python
scheduler1 = ConstantLR(optimizer, warmup_factor=0.1, warmup_iters=5)
scheduler2 = LinearLR(optimizer, warmup_factor=0.1, warmup_iters=5)
```

correspondingly.
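For illustration, a training-loop sketch with one of the renamed schedulers (parameter names as written in this commit message; the final public API may differ):

```Python
import torch
from torch.nn import Parameter
from torch.optim import SGD
from torch.optim.lr_scheduler import LinearLR

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)

# Same behavior as WarmUpLR(..., warmup_method="linear"), under its new name.
scheduler = LinearLR(optimizer, warmup_factor=0.1, warmup_iters=5)

for epoch in range(10):
    optimizer.step()
    scheduler.step()
```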

Pull Request resolved: #64395

Reviewed By: datumbox

Differential Revision: D30753688

Pulled By: iramazanli

fbshipit-source-id: e47f86d12033f80982ddf1faf5b46873adb4f324