
refactor: replace LambdaLR with PolynomialLR in segmentation training script #6405

Merged

Conversation

federicopozzi33
Contributor

Replace LambdaLR with PolynomialLR in segmentation training script.

Closes: #4438
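
In essence, the change swaps the hand-rolled polynomial decay expressed through LambdaLR for the built-in scheduler. A rough before/after sketch (variable names such as iters_per_epoch are illustrative; the exact code in train.py may differ):

# before: polynomial decay implemented via LambdaLR
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda x: (1 - x / (iters_per_epoch * epochs)) ** 0.9
)

# after: the equivalent built-in scheduler
lr_scheduler = torch.optim.lr_scheduler.PolynomialLR(
    optimizer, total_iters=iters_per_epoch * epochs, power=0.9
)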

@federicopozzi33
Contributor Author

A small code snippet comparing the old scheduler (LambdaLR) with the new one (PolynomialLR). The example does not include warmup.

import torch
from torch.optim.lr_scheduler import LambdaLR, PolynomialLR

lr = 0.001
epochs = 5
data_loader = 2  # iterations per epoch (stands in for len(data_loader))
power = 1.0

scheduler_new = PolynomialLR(
    torch.optim.SGD([torch.zeros(1)], lr=lr),
    total_iters=epochs * data_loader,
    power=power,
)
scheduler_old = LambdaLR(
    torch.optim.SGD([torch.zeros(1)], lr=lr),
    lambda x: (1 - x / (epochs * data_loader)) ** power,
)

for i in range(epochs):
    for j in range(data_loader):
        new_current_lr = scheduler_new.optimizer.param_groups[0]["lr"]
        old_current_lr = scheduler_old.optimizer.param_groups[0]["lr"]
        print(f"epoch: {i} step: {j} | new: {new_current_lr:5f} old: {old_current_lr:5f}")
        scheduler_new.step()
        scheduler_old.step()

epoch: 0 step: 0 | new: 0.001000 old: 0.001000
epoch: 0 step: 1 | new: 0.000900 old: 0.000900
epoch: 1 step: 0 | new: 0.000800 old: 0.000800
epoch: 1 step: 1 | new: 0.000700 old: 0.000700
epoch: 2 step: 0 | new: 0.000600 old: 0.000600
epoch: 2 step: 1 | new: 0.000500 old: 0.000500
epoch: 3 step: 0 | new: 0.000400 old: 0.000400
epoch: 3 step: 1 | new: 0.000300 old: 0.000300
epoch: 4 step: 0 | new: 0.000200 old: 0.000200
epoch: 4 step: 1 | new: 0.000100 old: 0.000100
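
The two traces match because, with power=1.0, both schedulers follow the same closed form: at iteration t the learning rate is lr * (1 - t / total_iters) ** power. A quick sanity check of that formula (a minimal sketch, not part of the PR):

lr, total_iters, power = 0.001, 10, 1.0
expected = [lr * (1 - t / total_iters) ** power for t in range(total_iters)]
print([f"{v:.6f}" for v in expected])
# 0.001000, 0.000900, ..., 0.000100 — matches both columns above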

@datumbox
Why is the scheduler stepped per iteration and not per epoch? Is there a specific reason?

@datumbox
Contributor

datumbox commented Aug 12, 2022

@federicopozzi33 That should do it.

Could you please provide a dummy output that showcases that it works as expected (before and after)? You can either use the script with the actual training removed, or just test the two schedulers to show they behave identically with the configuration you use.

Edit: We were posting comments at the same time again. 😄 Let me have a closer look.

@datumbox datumbox marked this pull request as ready for review August 12, 2022 09:23
Contributor

@datumbox datumbox left a comment


LGTM, thanks @federicopozzi33.

Why is the scheduler stepped per iteration and not per epoch? Is there a specific reason?

It allows you to decay the LR in shorter steps within the epoch rather than using the same LR for an entire epoch. This gives you finer control, especially when you combine it with warmup (i.e. you don't have to wait multiple epochs for the warmup to finish); see the sketch below.
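
For illustration, a minimal sketch of combining a per-iteration warmup with PolynomialLR via SequentialLR (hypothetical values and names, not the exact code of the training script):

import torch
from torch.optim.lr_scheduler import LinearLR, PolynomialLR, SequentialLR

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
iters_per_epoch = 100            # stands in for len(data_loader)
epochs, warmup_epochs = 30, 1

# warm up linearly over the first epoch's iterations, then decay polynomially
warmup = LinearLR(optimizer, start_factor=0.01, total_iters=iters_per_epoch * warmup_epochs)
main = PolynomialLR(optimizer, total_iters=iters_per_epoch * (epochs - warmup_epochs), power=0.9)
scheduler = SequentialLR(
    optimizer, schedulers=[warmup, main], milestones=[iters_per_epoch * warmup_epochs]
)

for epoch in range(epochs):
    for _ in range(iters_per_epoch):
        # forward / backward / optimizer.step() would go here
        scheduler.step()         # stepped per iteration, not per epoch

Because the scheduler is stepped every iteration, the warmup completes within the first epoch instead of stretching over several epochs.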

@datumbox datumbox merged commit 6e535db into pytorch:main Aug 12, 2022
@datumbox datumbox added this to In progress in Batteries Included - Phase 3 via automation Aug 22, 2022
@datumbox datumbox moved this from In progress to Done in Batteries Included - Phase 3 Aug 22, 2022
facebook-github-bot pushed a commit that referenced this pull request Aug 24, 2022
… training script (#6405)

Reviewed By: datumbox

Differential Revision: D38824250

fbshipit-source-id: b10950254c0ba0471e0443a7cddba42594324185
Development

Successfully merging this pull request may close these issues.

Investigate if lr_scheduler from segmentation can use PyTorch's schedulers
3 participants