
"Parameters of a model after .cuda() will be different objects with those before the call." is wrong. #7844

Closed
freud14 opened this issue May 25, 2018 · 8 comments

@freud14 commented May 25, 2018

Hi,

In the documentation, it is written:

If you need to move a model to GPU via .cuda(), please do so before constructing optimizers for it. Parameters of a model after .cuda() will be different objects with those before the call.

In general, you should make sure that optimized parameters live in consistent locations when optimizers are constructed and used.

However, calling .cuda() after initializing the optimizer still works. This is because the Module class applies .cuda() like this:

param.data = fn(param.data)
if param._grad is not None:
    param._grad.data = fn(param._grad.data)

Thus, by modifying the .data attribute, it modifies the parameter tensors in-place.
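This can be checked without a GPU: .double() goes through the same Module._apply path as .cuda(), so a minimal sketch of the identity-preserving behavior is:

```python
import torch

l = torch.nn.Linear(3, 3)
p = next(l.parameters())

# .double() uses the same Module._apply machinery as .cuda();
# with the default behavior it rebinds param.data in place,
# leaving the Parameter objects themselves untouched.
l.double()

print(p is next(l.parameters()))  # the Parameter object is unchanged
print(p.dtype)                    # but its storage was converted to float64
```

Since the optimizer holds references to the same Parameter objects, updates after the conversion still reach the model's (converted) tensors.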

I therefore suggest removing this "warning" from the documentation, since I actually find it quite useful to be able to initialize the optimizer before calling .cuda().

Thank you.

Frédérik

@SsnL (Collaborator) commented May 25, 2018

The particular sentence you referenced is wrong, but in general you still can't initialize the optimizer before moving modules:

>>> l = torch.nn.Linear(3, 3)
>>> o = torch.optim.Adagrad(l.parameters())
>>> l.cuda()
Linear(in_features=3, out_features=3, bias=True)
>>> l(torch.randn(1,3,device='cuda')).sum().backward()
>>> o.step()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ssnl/sftp/pytorch/torch/optim/adagrad.py", line 92, in step
    state['sum'].addcmul_(1, grad, grad)
RuntimeError: Expected object of type torch.FloatTensor but found type torch.cuda.FloatTensor for argument #4 'tensor1'
@freud14 (Author) commented May 25, 2018

It seems to be a problem only for Adagrad; the other optimizers work as expected. The other optimizers check if len(state) == 0 in their step method, so their state is created lazily on the first step, from parameters that are already on the right device.
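To illustrate why the lazy pattern is robust, here is a hypothetical minimal optimizer (LazySGD is not a real PyTorch class; a momentum-SGD sketch) that defers state creation to the first step(), so the state tensor is allocated with whatever device/dtype the parameter has at that point:

```python
import torch

class LazySGD(torch.optim.Optimizer):
    """Hypothetical momentum-SGD sketch demonstrating lazy state creation."""

    def __init__(self, params, lr=0.01, momentum=0.9):
        super().__init__(params, dict(lr=lr, momentum=momentum))

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                if len(state) == 0:
                    # Created from p at step time, so the buffer lands on
                    # whatever device/dtype the parameter lives on *now*.
                    state['momentum'] = torch.zeros_like(p)
                state['momentum'].mul_(group['momentum']).add_(p.grad)
                p.add_(state['momentum'], alpha=-group['lr'])
```

Because the buffer is created at step time, converting the model between constructing the optimizer and the first step() still works; Adagrad failed precisely because it allocated its state['sum'] eagerly at construction.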

@zou3519 zou3519 added the todo label May 29, 2018
@gchanan (Contributor) commented Jun 5, 2019

CC @yf225.

@fmassa (Member) commented Jun 6, 2019

I'm not sure this is a behavior that we want to keep for the optimizers.

Moving tensors to CUDA is not an in-place operation, so the references to a tensor should not be the same once it has been moved to CUDA.

For me, if some optimizers worked after moving the tensors to the GPU, that was unintentional behavior and should not be relied upon.

@apaszke (Member) commented Jun 6, 2019

Yes, while some optimizers might be implemented in such a way that this works (there's nothing wrong with that), we absolutely do not guarantee that it will remain the case if you change them, or even if you download a newer version of PyTorch. I'd like us to keep this restriction.

@ezyang (Contributor) commented Jun 10, 2019

@yf225 can you comment here? This is very related to your Variable changes.

@ezyang (Contributor) commented Jun 10, 2019

@gchanan says this also depends on which optimizer you use.

@yf225 (Contributor) commented Jun 19, 2019

If you need to move a model to GPU via .cuda(), please do so before constructing optimizers for it. Parameters of a model after .cuda() will be different objects with those before the call.

In general, you should make sure that optimized parameters live in consistent locations when optimizers are constructed and used.

Starting from #21613, the behavior in future releases will be consistent with this warning: parameters of a model after dtype/device conversion functions such as .cuda()/.cpu()/.to()/.float()/.double() will be different objects from those before the call (you can enable the new behavior by setting torch.__future__.set_overwrite_module_params_on_conversion(True)). Hence we strongly recommend converting the model to the target device / dtype before constructing optimizers for it.
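Assuming a PyTorch build that ships the torch.__future__ flag described above, the new behavior can be previewed like this (a sketch; note the flag is process-global, so it is reset afterwards):

```python
import torch

torch.__future__.set_overwrite_module_params_on_conversion(True)

l = torch.nn.Linear(3, 3)
p = next(l.parameters())
l.double()

# With the flag on, conversion creates brand-new Parameter objects,
# so an optimizer built before the conversion would keep updating
# stale tensors that the module no longer uses.
print(p is next(l.parameters()))

torch.__future__.set_overwrite_module_params_on_conversion(False)
```

Under this setting, the documentation's warning is literally true for every optimizer, not just Adagrad.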

@yf225 yf225 closed this Jun 19, 2019
freud14 added a commit to GRAAL-Research/poutyne that referenced this issue Jun 22, 2019
freud14 added a commit to GRAAL-Research/poutyne that referenced this issue Jun 25, 2019
…imizer (#37)

See issue pytorch/pytorch#7844 and PR pytorch/pytorch#21613.