Bug in MultiStepLR lr scheduler #31828

@Steve-Tod

Description

🐛 Bug

Passing the epoch argument to the step() function of MultiStepLR leads to an incorrect learning rate.

To Reproduce

import torch
from torch import nn

net = nn.Linear(30, 10)
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
s = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 20, 30], gamma=0.1)
print(s.get_lr())
s.step(1)  # passing an explicit epoch to step() triggers the bug
print(s.get_lr())

Output

[0.001]
[1.0000000000000002e-06]

Expected behavior

[0.001]
[0.001]
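
For comparison, stepping without the explicit epoch argument behaves as expected (a minimal check, assuming only the epoch-passing path is affected):

import torch
from torch import nn
net = nn.Linear(30, 10)
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
s = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 20, 30], gamma=0.1)
s.step()           # no explicit epoch
print(s.get_lr())  # [0.001] -- unchanged before the first milestone at epoch 10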

Environment

PyTorch version: 1.4.0a0+d5bf51b
Is debug build: No
CUDA used to build PyTorch: 9.0

OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: version 3.14.0

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration:
GPU 0: TITAN Xp
GPU 1: TITAN Xp
GPU 2: TITAN Xp
GPU 3: TITAN Xp

Nvidia driver version: 430.26
cuDNN version: Could not collect

Versions of relevant libraries:

[pip] numpy==1.17.3
[pip] torch==1.4.0a0+d5bf51b
[conda] blas 1.0 mkl
[conda] magma-cuda90 2.5.0 1 pytorch
[conda] mkl 2019.4 243
[conda] mkl-include 2019.4 243
[conda] mkl-service 2.3.0 py36he904b0f_0
[conda] mkl_fft 1.0.15 py36ha843d7b_0
[conda] mkl_random 1.1.0 py36hd6b4f25_0
[conda] torch 1.4.0a0+d5bf51b pypi_0 pypi

Additional context

A possible cause is that the milestones attribute of MultiStepLR is a Counter rather than a list, which makes the bisect call in the get_lr function behave incorrectly.
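
A minimal sketch of this suspected failure mode (assuming get_lr bisects self.milestones directly with bisect_right): bisect probes its argument with integer positions, those positions are treated as Counter keys, and a Counter silently returns 0 for any missing key, so the comparison never succeeds and bisect walks to the right end:

from bisect import bisect_right
from collections import Counter

milestones = Counter([10, 20, 30])
# bisect_right probes milestones[mid]; the integer "indices" 1, 2, ...
# are looked up as Counter keys, which are missing, so Counter returns 0.
# "1 < 0" is always False, and bisect runs off to the right end.
print(bisect_right(milestones, 1))                     # 3 (wrong)
print(bisect_right(sorted(milestones.elements()), 1))  # 0 (correct)
# An exponent of 3 matches the observed lr: 0.001 * 0.1**3 == 1e-06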

cc @vincentqb

Labels

module: optimizer (Related to torch.optim)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
