
[bug] learning rate schedulers behave unexpectedly with pytorch 2.0.0 #3202

Closed
sjfleming opened this issue Apr 24, 2023 · 0 comments
sjfleming commented Apr 24, 2023

Issue Description

The learning rate scheduler does not seem to behave as expected, and, importantly, the schedule it produces differs depending on whether you use pytorch 1.13.0 or 2.0.0.

Environment

  • macOS Monterey
  • Python 3.8
  • pyro dev version pyro-ppl 1.8.4+dd4e0f81
  • pytorch versions specified below

Code Snippet

import pyro
import torch

# dummy dataloader
dataset = torch.randn((100))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=20)

epochs = 10


def model(x):
    loc = pyro.param('loc', torch.tensor(1.))
    with pyro.plate('plate', x.shape[0]):
        pyro.sample('obs', pyro.distributions.Normal(loc, 1.0), obs=x)


def guide(x):
    pass


optimizer = torch.optim.Adam
scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, 'optim_args': {'lr': 0.1}, 'gamma': 0.9})
svi = pyro.infer.SVI(model, guide, scheduler, loss=pyro.infer.Trace_ELBO())

for i in range(epochs):
    for minibatch in dataloader:
        svi.step(minibatch)
    lr = list(scheduler.optim_objs.values())[0].get_last_lr()[0]
    print(f'[{i + 1:03d}]  lr = {lr:.3e}')
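    # advance the wrapped LR schedulers once per epoch (svi.optim is the scheduler passed to SVI)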
    svi.optim.step()

For torch 2.0.0, I get

[001]  lr = 5.905e-02
[002]  lr = 3.138e-02
[003]  lr = 1.668e-02
[004]  lr = 8.863e-03
[005]  lr = 4.710e-03
[006]  lr = 2.503e-03
[007]  lr = 1.330e-03
[008]  lr = 7.070e-04
[009]  lr = 3.757e-04
[010]  lr = 1.997e-04

accompanied by the warning

UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
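
As far as I can tell, that warning is not pyro-specific: it comes from torch.optim.lr_scheduler and fires whenever a scheduler's step() is called before the optimizer it wraps has ever stepped. A minimal sketch (with a throwaway parameter, unrelated to the model above) that reproduces just the warning:

import torch

p = torch.nn.Parameter(torch.tensor(1.0))
opt = torch.optim.Adam([p], lr=0.1)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.9)

# calling scheduler.step() before opt.step() has ever run triggers the UserWarning above
sched.step()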

For torch 1.13.0, by contrast, I get

[001]  lr = 1.000e-01
[002]  lr = 9.000e-02
[003]  lr = 8.100e-02
[004]  lr = 7.290e-02
[005]  lr = 6.561e-02
[006]  lr = 5.905e-02
[007]  lr = 5.314e-02
[008]  lr = 4.783e-02
[009]  lr = 4.305e-02
[010]  lr = 3.874e-02

The latter is much closer to what I'd expect, given that

0.1 * 0.9**10 ≈ 0.0349

I also don't get the warning with pytorch 1.13.0.
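
As a sanity check on that expectation: if the scheduler is stepped exactly once per epoch, and the learning rate is printed before that step (as in the snippets here), the schedule should be 0.1 * 0.9**(epoch - 1), which reproduces the pytorch 1.13.0 numbers above:

# expected schedule when the scheduler steps once per epoch (lr printed before the step)
for i in range(1, 11):
    print(f'[{i:03d}]  lr = {0.1 * 0.9 ** (i - 1):.3e}')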

If I do what I think should be the same thing in plain pytorch, without any pyro code, I see the following learning rate schedule, which agrees with the pyro + pytorch 1.13.0 result above:

import torch

# dummy dataloader
dataset = torch.randn((100))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=20)

epochs = 10

x = torch.nn.Parameter(torch.tensor(1.))
optimizer = torch.optim.Adam([x], lr=0.1)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
loss_fn = torch.nn.MSELoss()

for i in range(epochs):
    for minibatch in dataloader:
        optimizer.zero_grad()
        loss = loss_fn(minibatch, x)
        loss.backward()
        optimizer.step()
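    # read the lr before stepping the scheduler, to mirror the pyro snippet above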
    lr = scheduler.get_last_lr()[0]
    print(f'[{i + 1:03d}]  lr = {lr:.3e}')
    scheduler.step()

gives me, using pytorch 1.13.0,

[001]  lr = 1.000e-01
[002]  lr = 9.000e-02
[003]  lr = 8.100e-02
[004]  lr = 7.290e-02
[005]  lr = 6.561e-02
[006]  lr = 5.905e-02
[007]  lr = 5.314e-02
[008]  lr = 4.783e-02
[009]  lr = 4.305e-02
[010]  lr = 3.874e-02

and the above with pytorch 2.0.0 gives me the same thing:

[001]  lr = 1.000e-01
[002]  lr = 9.000e-02
[003]  lr = 8.100e-02
[004]  lr = 7.290e-02
[005]  lr = 6.561e-02
[006]  lr = 5.905e-02
[007]  lr = 5.314e-02
[008]  lr = 4.783e-02
[009]  lr = 4.305e-02
[010]  lr = 3.874e-02

So I think this is somehow related to the pyro + pytorch interface.
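
In case it helps with debugging, here is a small diagnostic that could be bolted onto the pyro snippet above (a sketch; it assumes the same scheduler object is still in scope). It prints both the learning rate the scheduler reports and the learning rate actually sitting in the underlying optimizer's param group, which should show which of the two diverges under 2.0.0:

# sketch: compare the scheduler's reported lr to the lr the optimizer will actually use
for torch_sched in scheduler.optim_objs.values():
    reported = torch_sched.get_last_lr()[0]                # the scheduler's view
    actual = torch_sched.optimizer.param_groups[0]['lr']   # what Adam will actually apply
    print(f'reported lr = {reported:.3e}   actual lr = {actual:.3e}')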
