Lr scheduler #1370

Merged
merged 60 commits into pytorch:master on May 25, 2017

Conversation

@Jiaming-Liu (Contributor) commented Apr 26, 2017

Providing a unified LR scheduler.

Currently supports (see the usage sketch after the list):

  • ReduceLROnPlateau (ported from Keras)
  • LambdaLR
  • StepLR
  • MultiStepLR
  • ExponentialLR
  • GroupLambdaLR (needs testing)
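
For orientation, here is a minimal usage sketch against the API as it was merged (a hedged example, not code from this PR; `model` is just a stand-in nn.Module, and the step-before-train ordering reflects the API of that era):

import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR  # available once this PR is merged

model = nn.Linear(10, 2)                      # stand-in for any real model
optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)

# Decay the learning rate by gamma every step_size epochs.
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

for epoch in range(100):
    scheduler.step()        # per-epoch call, as in the merged docstrings
    # train(...); validate(...)  # placeholders for the usual loop body

ReduceLROnPlateau is the exception: its step() takes the monitored metric (for example the validation loss) instead of following a fixed epoch schedule.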

@colesbury (Member) left a comment

This looks generally good. I've added some inline comments.

This needs some unit tests. Basically (see the sketch after this list):

  • create an optimizer and LR scheduler
  • for a few different values of 'epoch', call step on the scheduler and check that the LR of the optimizer is correct
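
A rough sketch of what such a test could look like, assuming the StepLR signature that was eventually merged (this is not the test that actually landed; the explicit epoch argument to step() matches the API of that era):

import unittest

import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR


class TestLRScheduler(unittest.TestCase):
    def test_step_lr(self):
        model = nn.Linear(10, 2)
        optimizer = SGD(model.parameters(), lr=0.1)
        scheduler = StepLR(optimizer, step_size=3, gamma=0.1)
        # Expected LR: 0.1 for epochs 0-2, 0.01 for 3-5, 0.001 for 6-8.
        expected = [0.1] * 3 + [0.01] * 3 + [0.001] * 3
        for epoch, lr in enumerate(expected):
            scheduler.step(epoch)
            for group in optimizer.param_groups:
                self.assertAlmostEqual(group['lr'], lr, places=6)


if __name__ == '__main__':
    unittest.main()

(As noted further down in the thread, the tests eventually went into test_optim.py.)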

        self.zip = zip(optimizer.param_groups, base_lrs, lr_lambdas)

    def step(self, epoch):
        for param_group, base_lr, lr_lambda in self.zip:

This comment was marked as off-topic.

        if self.mode not in ['min', 'max']:
            raise RuntimeError('Learning Rate Plateau Reducing mode %s is unknown!' % self.mode)
        if self.mode == 'min':
            self.monitor_op = lambda a, b: np.less(a, b - self.epsilon)

This comment was marked as off-topic.

    def _reset(self):
        """Resets wait counter and cooldown counter."""
        if self.mode not in ['min', 'max']:

This comment was marked as off-topic.

        self.wait = 0
        self.lr_epsilon = self.min_lr * 1e-4

    def reset(self):

This comment was marked as off-topic.

    def step(self, epoch, metrics):
        current = metrics
        if current is None:

This comment was marked as off-topic.

@Jiaming-Liu (Contributor, Author)

Not sure where to put the unit tests. I have put them here.

@apaszke (Contributor) left a comment

Looks good for the most part, but I think some parts could be simplified. Thanks for the PR!

            param_group['lr'] = self.base_lr * self.lr_lambda(epoch)


class GroupLambdaLR(object):

This comment was marked as off-topic.

from torch.optim.optimizer import Optimizer


class LambdaLR(object):

This comment was marked as off-topic.

    >>> validate(...)
    """

    def __init__(self, optimizer, base_lr=0.1, gamma=0.1, step_size=30):

This comment was marked as off-topic.

    >>> validate(...)
    """

    def __init__(self, optimizer, base_lr=0.1, gamma=0.1, milestones=(10, 20, 30)):

This comment was marked as off-topic.

            be reduced. new_lr = lr * factor
        patience: number of epochs with no improvement
            after which learning rate will be reduced.
        verbose: int. 0: quiet, 1: update messages.

This comment was marked as off-topic.

            raise RuntimeError('Learning Rate Plateau threshold mode %s is unknown!' % threshold_mode)
        if mode == 'min' and threshold_mode == 'rel':
            rel_epsilon = 1. - threshold
            self.monitor_op = lambda a, best: np.less(a, best * rel_epsilon)

This comment was marked as off-topic.

            self.cooldown_counter -= 1
            self.wait = 0

        if self.monitor_op(current, self.best):

This comment was marked as off-topic.

        self.best = self.monitor_op.worse
        self.cooldown_counter = 0
        self.wait = 0
        self.lr_epsilon = self.min_lr * 1e-4

This comment was marked as off-topic.

                new_lr = max(new_lr, self.min_lr)
                param_group['lr'] = new_lr
                if self.verbose > 0:
                    print('Epoch %05d: reducing learning rate of group %d to %s.' % (epoch, inx_group, new_lr))

This comment was marked as off-topic.

@apaszke (Contributor) commented Apr 28, 2017

Also, tests should go to test_optim.py

@soumith (Member) commented May 4, 2017

@Jiaming-Liu I think this is good to go. You should also add docstrings for LambdaLR and GroupLambdaLR, and add references in https://raw.githubusercontent.com/pytorch/pytorch/master/docs/source/optim.rst so that they show up in the documentation as well.
You can build and check the documentation locally like this:

cd docs
pip install -r requirements.txt
make clean && make html

The locally generated HTML documentation (similar to pytorch.org/docs/) will then be in docs/build/html.
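
For the docstring request above, a hypothetical sketch in the style the other schedulers use (the wording that actually landed may differ):

class LambdaLR(object):
    """Sets the learning rate of each parameter group to the initial lr
    times a given function of the epoch.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): A function which computes a multiplicative
            factor given an integer epoch, or a list of such functions, one per
            group in optimizer.param_groups.

    Example:
        >>> # Assuming optimizer has two groups.
        >>> lambda1 = lambda epoch: epoch // 30
        >>> lambda2 = lambda epoch: 0.95 ** epoch
        >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
        >>> for epoch in range(100):
        >>>     scheduler.step()
        >>>     train(...)
        >>>     validate(...)
    """

With an autoclass entry in docs/source/optim.rst, Sphinx should then render the docstring (including the Example block) on the generated page.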

@apaszke (Contributor) commented May 4, 2017

@pytorchbot test this please

@apaszke (Contributor) commented May 4, 2017

@pytorchbot test this please

1 similar comment

@colesbury (Member) left a comment

I think the extensive use of lambdas will prevent these classes from being pickled. (They could probably just be instance methods.)
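
To illustrate the concern with hypothetical stand-in classes (not code from this PR): an object that keeps a lambda in its __dict__ cannot be pickled, while the same comparison written as an instance method pickles cleanly, because only plain data is left in the instance state.

import pickle


class LambdaHolder(object):
    def __init__(self, mode='min'):
        # Storing the comparison as a lambda attribute blocks pickling.
        self.monitor_op = (lambda a, b: a < b) if mode == 'min' else (lambda a, b: a > b)


class MethodHolder(object):
    def __init__(self, mode='min'):
        self.mode = mode

    def monitor_op(self, a, b):
        # Same comparison as a regular instance method; the instance pickles.
        return a < b if self.mode == 'min' else a > b


try:
    pickle.dumps(LambdaHolder())
except (pickle.PicklingError, AttributeError, TypeError) as exc:
    print('lambda attribute is not picklable:', exc)

print(len(pickle.dumps(MethodHolder())) > 0)  # True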



class LambdaLR(object):
    def __init__(self, optimizer, lr_lambda):

This comment was marked as off-topic.

        self.lr_lambdas = list(lr_lambda)
        self.last_epoch = -1

    def step(self, epoch=None):

This comment was marked as off-topic.

@szagoruyko mentioned this pull request May 9, 2017
@apaszke (Contributor) commented May 10, 2017

Yes, get_lr would be a better name. Also, please remove base_lr and use the code I wrote in the commit comments.
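
For context, the get_lr pattern under discussion looks roughly like this condensed sketch (hypothetical class names; the real base class merged from this PR is torch.optim.lr_scheduler._LRScheduler):

class _SchedulerSketch(object):
    def __init__(self, optimizer, last_epoch=-1):
        self.optimizer = optimizer
        self.base_lrs = [group['lr'] for group in optimizer.param_groups]
        self.last_epoch = last_epoch
        self.step(last_epoch + 1)

    def get_lr(self):
        # Subclasses return one learning rate per param group.
        raise NotImplementedError

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr


class StepSketch(_SchedulerSketch):
    # StepLR-style decay: multiply the base rate by gamma every step_size epochs.
    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):
        self.step_size = step_size
        self.gamma = gamma
        super(StepSketch, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
                for base_lr in self.base_lrs]

The merged code follows this shape: subclasses only override get_lr(), and step() writes the computed rates back into optimizer.param_groups.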

@Jiaming-Liu (Contributor, Author)

This error is weird. Any idea? @apaszke

Running optim tests
..............F........
FAIL: test_adagrad_sparse (__main__.TestOptim)
Traceback (most recent call last):
  File "test_optim.py", line 285, in test_adagrad_sparse
    lambda params: optim.Adagrad(params, lr=1e-1)
  File "test_optim.py", line 103, in _test_rosenbrock_sparse
    self.assertLessEqual(params.data.dist(solution), initial_dist)
AssertionError: 0.7290316658655626 not less than or equal to 0.7071067811865476

@apaszke (Contributor) left a comment

Looks good now and should be ready to merge after these final fixes. Can you also add the schedulers to docs/source/optim.rst?

                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
        self.step(last_epoch + 1)
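
A short sketch of what the resume path above means in practice, assuming the merged StepLR API (a hedged example, not code from this PR):

import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

model = nn.Linear(10, 2)
optimizer = SGD(model.parameters(), lr=0.1)

# Fresh run: last_epoch defaults to -1, so each param group gets an
# 'initial_lr' entry recorded automatically.
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
print('initial_lr' in optimizer.param_groups[0])  # True

# Resuming: because 'initial_lr' is present, a scheduler can be re-created
# with the epoch the previous run stopped at. On a bare optimizer without
# 'initial_lr', this constructor raises the KeyError shown above.
resumed = StepLR(optimizer, step_size=30, gamma=0.1, last_epoch=59)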

This comment was marked as off-topic.

            param_group['lr'] = lr


class LambdaLR(_LRScheduler):

This comment was marked as off-topic.

class LambdaLR(_LRScheduler):
    def __init__(self, optimizer, lr_lambda, last_epoch=-1):
        self.optimizer = optimizer
        self.base_lrs = list(map(lambda group: group['lr'], optimizer.param_groups))

This comment was marked as off-topic.

@apaszke (Contributor) commented May 14, 2017

Also, can you try rebasing on top of master? The test might have been fixed in some other commit.

@Jiaming-Liu (Contributor, Author)

Now it seems like a good time to update the documentation. I will have it done within a week. Thanks for the reviews!

@Jiaming-Liu (Contributor, Author)

Rebasing doesn't help the error :(. Any ideas?

@thomasjpfan (Contributor)

The test looks like it was fixed in 368ecb4. Rebasing on top of master fixes the error.

@Jiaming-Liu (Contributor, Author) commented May 18, 2017

Umm... something went wrong while rebasing.

I will try to solve it tomorrow. Edit: solved, but 591ea75 is still here. Edit: solved.

@pytorchbot (Collaborator)

Can one of the admins verify this patch?

3 similar comments

@soumith (Member) commented May 25, 2017

@pytorchbot test this please

@soumith merged commit 630af4d into pytorch:master on May 25, 2017
@szagoruyko (Contributor)

As far as I can see, only one optimizer is kept, so on a learning-rate drop all the other optimizer parameters are kept as well. How would one add momentum resetting on learning-rate drops in SGD?
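
One possible approach, not something this PR provides: clear SGD's per-parameter momentum buffers from optimizer.state whenever the rate is dropped; they are re-created lazily on the next optimizer.step(). reset_momentum below is a hypothetical helper:

def reset_momentum(optimizer):
    # Drop SGD's momentum buffers so they restart from zero after an LR drop.
    for group in optimizer.param_groups:
        for p in group['params']:
            state = optimizer.state[p]
            if 'momentum_buffer' in state:
                del state['momentum_buffer']

The schedulers in this PR only touch param_group['lr'], so calling a helper like this after a reduction is left to the user.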

@Jiaming-Liu (Contributor, Author)

@soumith Kindly mention this PR in the release notes to increase visibility.

@soumith (Member) commented Aug 6, 2017

Hey Jiaming, I'm really sorry for missing this commit in the release notes. It looks like I missed 4 commits by mistake. I've updated the release notes now, and I've made a note for myself to check whether repeating the note about learning rate schedules would be appropriate for the next release as well (to increase visibility).

@jtoy commented Aug 7, 2017

So this is in the new 0.2.0 release? Great!

@FuriouslyCurious

@soumith My post on the PyTorch forum about LR schedules is still getting more likes every week, so I think people are not aware of this PR. You should consider megaphoning it in future release notes.

https://discuss.pytorch.org/t/adaptive-learning-rate/320/10

>>> lambda2 = lambda epoch: 0.95 ** epoch
>>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
>>> for epoch in range(100):
>>>     scheduler.step()

This comment was marked as off-topic.

eqy pushed a commit to eqy/pytorch that referenced this pull request Jan 20, 2022
* force segment un-connected graphs

* derive heuristic on empty groups

* add test

* lint

* handled aliased output in batchnorm

* empty tensor

* lint and comment

* clang format

* check reference tv available in pointwise scheduler

* comment

* cleanup test and check utils
hubertlu-tw pushed a commit to hubertlu-tw/pytorch that referenced this pull request Nov 1, 2022
* fix typo

* Update test_pipeline_parallel_fwd_bwd.py
jithunnair-amd pushed a commit that referenced this pull request Mar 18, 2024
* Triton build conditionalized on ROCM_VERSION

(cherry picked from commit 1a7e1fa)

* Update pinned commit for rocm6.1 conditionalisation

---------

Co-authored-by: Pruthvi Madugundu <pruthvigithub@gmail.com>