Added parameter range checks for all optimizers #6000
Conversation
torch/optim/sgd.py
Outdated
if not 0.0 <= lr:
    raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= momentum <= 1.0:
    raise ValueError("Invalid momentum value: {}".format(momentum))
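The diff above follows a fail-fast pattern: reject out-of-range hyperparameters in the constructor, before any training step runs. A minimal standalone sketch of the same idea (the helper name is hypothetical, not actual PyTorch code):

```python
def validate_sgd_args(lr, momentum=0.0, weight_decay=0.0):
    """Fail fast on out-of-range SGD-style hyperparameters (illustrative sketch)."""
    if not 0.0 <= lr:
        raise ValueError("Invalid learning rate: {}".format(lr))
    if not 0.0 <= momentum <= 1.0:
        raise ValueError("Invalid momentum value: {}".format(momentum))
    if not 0.0 <= weight_decay:
        raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

# Valid arguments pass silently; invalid ones raise immediately.
validate_sgd_args(lr=0.1, momentum=0.9)
try:
    validate_sgd_args(lr=-0.1)
except ValueError as e:
    print(e)  # Invalid learning rate: -0.1
```

Raising at construction time turns a silent divergence many epochs later into an immediate, self-explanatory error.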
I have reviewed the rest of the ranges. Looks good to me.
@pytorchbot test this please
@@ -343,6 +349,8 @@ def test_adagrad(self):
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-1)
        )
        with self.assertRaisesRegex(ValueError, "Invalid lr_decay value: -0.5"):
            optim.Adagrad(None, lr=1e-2, lr_decay=-0.5)

    def test_adagrad_sparse(self):
        self._test_rosenbrock_sparse(
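The test hunk relies on `assertRaisesRegex`, which checks both that the constructor raises and that the message matches a pattern, so the test pins down *which* validation fired. A torch-free sketch of the same pattern, using a hypothetical stand-in class rather than the real `optim.Adagrad`:

```python
import unittest


class FakeAdagrad:
    """Hypothetical stand-in for optim.Adagrad, used only to illustrate
    the assertRaisesRegex testing pattern from the diff above."""

    def __init__(self, params, lr=1e-2, lr_decay=0.0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= lr_decay:
            raise ValueError("Invalid lr_decay value: {}".format(lr_decay))


class TestRangeChecks(unittest.TestCase):
    def test_invalid_lr_decay(self):
        # The pattern must match the message raised by the constructor.
        with self.assertRaisesRegex(ValueError, "Invalid lr_decay value: -0.5"):
            FakeAdagrad(None, lr=1e-2, lr_decay=-0.5)

    def test_valid_args_do_not_raise(self):
        FakeAdagrad(None, lr=1e-2, lr_decay=0.5)
```

Passing `None` for `params` works here (as in the PR's test) because the range checks run before the parameter list is ever touched.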
@pytorchbot test this please
Can this be merged?
Thank you @lazypanda1!
This PR adds parameter range checks to all optimizers, so that end-users cannot unknowingly pass invalid values to an optimizer and then be confused by the output when there is no actual problem with their model.
For example, running the following program produces NaNs in the output, due to an invalid value of rho (> 1.0). Output:
I tried adding constraints for all the parameters whose valid ranges I could infer from the corresponding articles, but I am still missing some. Please feel free to suggest bounds for the ones that are missing.
This is similar to the bounds check I added earlier for the Adam optimizer.
I can also add tests if needed.
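The repro program and its output are not preserved in this page, but the failure mode is easy to reconstruct in plain Python, assuming the standard Adadelta accumulator update E[g²] ← rho·E[g²] + (1 − rho)·g²:

```python
import math

# With rho > 1.0 the (1 - rho) term is negative, so the "running average"
# of squared gradients can go negative, and sqrt(E[g^2] + eps) is then
# undefined -- which surfaces as NaN in the parameter updates.
rho, eps = 1.1, 1e-6
avg = 0.0
grad = 1.0  # a constant gradient is enough to show the effect
for _ in range(5):
    avg = rho * avg + (1 - rho) * grad * grad

rms = math.sqrt(avg + eps) if avg + eps >= 0 else float("nan")
print(avg, rms)
```

Since the accumulator is negative after the very first step, a range check rejecting rho outside [0, 1] at construction time catches the mistake before any NaN can propagate.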