
Added parameter range checks for all optimizers #6000

Merged: 4 commits merged into pytorch:master on Mar 28, 2018

Conversation

@lazypanda1 (Contributor) commented Mar 26, 2018

This PR adds parameter range checks to all optimizers, so that end users who pass invalid values get a clear error instead of being confused by bad output when there is no actual problem with their model.

For example, running the following program produces NaNs in the output due to an invalid value of rho (> 1.0):

import torch
from torch.autograd import Variable

N, D_in, H, D_out = 64, 1000, 100, 10

x = Variable(torch.randn(N, D_in))
y = Variable(torch.randn(N, D_out), requires_grad=False)

model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
loss_fn = torch.nn.MSELoss(size_average=False)

learning_rate = 1e-4
optimizer = torch.optim.Adadelta(model.parameters(), lr=learning_rate, rho=1.1)  # rho > 1.0 is invalid
for t in range(2):
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    print(t, loss.data[0])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Output:

0 651.8707885742188
1 nan

I tried adding constraints for all the parameters whose bounds I could infer from the corresponding papers, but I am still missing some; a sketch of the Adadelta case is included below. Please feel free to suggest what the bounds should be for the ones that are missing.
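
For concreteness, here is a minimal sketch of the kind of check being added, using the Adadelta example from above. The make_adadelta wrapper is purely illustrative (the PR itself places these checks at the top of each optimizer's __init__), and the exact bounds are the ones under discussion in this PR.

import torch.optim

# Illustrative wrapper only; the real change lives inside each optimizer's
# constructor rather than in a helper like this.
def make_adadelta(params, lr=1.0, rho=0.9, eps=1e-6, weight_decay=0):
    # rho is an exponential decay coefficient, so it is assumed to lie in [0, 1];
    # the remaining parameters only need to be non-negative.
    if not 0.0 <= lr:
        raise ValueError("Invalid learning rate: {}".format(lr))
    if not 0.0 <= rho <= 1.0:
        raise ValueError("Invalid rho value: {}".format(rho))
    if not 0.0 <= eps:
        raise ValueError("Invalid epsilon value: {}".format(eps))
    if not 0.0 <= weight_decay:
        raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
    return torch.optim.Adadelta(params, lr=lr, rho=rho, eps=eps,
                                weight_decay=weight_decay)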

This is similar to the bounds check that I added for the Adam optimizer.

I can also add tests if needed.

if not 0.0 <= lr:
    raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= momentum <= 1.0:
    raise ValueError("Invalid momentum value: {}".format(momentum))


@vishwakftw (Contributor) left a comment

I have reviewed the rest of the ranges. Looks good to me.

@ezyang (Contributor) commented Mar 26, 2018

@pytorchbot test this please

@apaszke (Contributor) commented Mar 26, 2018

@pytorchbot test this please

@@ -343,6 +349,8 @@ def test_adagrad(self):
                 self._build_params_dict(weight, bias, lr=1e-2),
                 lr=1e-1)
         )
+        with self.assertRaisesRegex(ValueError, "Invalid lr_decay value: -0.5"):
+            optim.Adagrad(None, lr=1e-2, lr_decay=-0.5)
 
     def test_adagrad_sparse(self):
         self._test_rosenbrock_sparse(
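
With checks like these in place, the NaN reproduction above fails fast at construction time instead. A rough usage sketch (the exact error message depends on the final wording in the diff):

import torch

model = torch.nn.Linear(10, 2)
try:
    # rho > 1.0 is rejected up front instead of silently producing NaNs later.
    torch.optim.Adadelta(model.parameters(), lr=1e-4, rho=1.1)
except ValueError as e:
    print(e)  # e.g. "Invalid rho value: 1.1"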


@ezyang (Contributor) commented Mar 27, 2018

@pytorchbot test this please

@lazypanda1 (Contributor, Author) commented

Can this be merged?

@apaszke merged commit 063946d into pytorch:master on Mar 28, 2018
@apaszke (Contributor) commented Mar 28, 2018

Thank you @lazypanda1!
