
torch.jit.freeze'd models cannot be moved to GPU with .to() #57569

Closed
Linux-cpp-lisp opened this issue May 4, 2021 · 3 comments
Labels: oncall: jit

Comments

Linux-cpp-lisp commented May 4, 2021

🐛 Bug

To Reproduce

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("twos", 2.0 * torch.ones(3))
    def forward(self, x):
        return self.twos * x

m = M().eval()
m(torch.ones(3))
ms = torch.jit.script(m)
ms = torch.jit.freeze(ms)
ms = ms.to("cuda")
ms(torch.ones(3))  # => tensor([2., 2., 2.])
ms(torch.ones(3).cuda())  # => RuntimeError (traceback below)

Gives:

---------------------------------------------------------------------------
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "<ipython-input-2-00f395df9a0d>", line 6, in forward
    def forward(self, x):
        return self.twos * x
               ~~~~~~~~~~~~~ <--- HERE
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Expected behavior

This works without torch.jit.freeze:

m = M().eval()
m(torch.ones(3))
ms = torch.jit.script(m)
ms = ms.to("cuda")
ms(torch.ones(3).cuda())  # => tensor([2., 2., 2.], device='cuda:0')

It also works if you freeze after moving the model to device:

m = M().eval()
m(torch.ones(3))
ms = torch.jit.script(m)
ms = ms.to("cuda")
ms = torch.jit.freeze(ms)
ms(torch.ones(3).cuda())  # => tensor([2., 2., 2.], device='cuda:0')

Environment

  • PyTorch Version: 1.8.1
  • OS: Linux
  • How you installed PyTorch: conda
  • Python version: 3.8.5
  • GPU models and configuration: V100

Additional context

Interestingly, this can be circumvented by using map_location in torch.jit.load instead of .to() to move the model:

m = M().eval()
ms = torch.jit.script(m)
ms = torch.jit.freeze(ms)
torch.jit.save(ms, "tmp.pth")
ms = torch.jit.load("tmp.pth", map_location="cuda")
ms(torch.ones(3).cuda())  # => tensor([2., 2., 2.], device='cuda:0')

cc @gmagogsfm

eellison (Contributor) commented May 4, 2021

@Linux-cpp-lisp this is currently expected behavior. We need to inline the attributes as constants in order to do anything useful when optimizing the frozen module, and nothing prevents the user from writing device-specific logic that also gets baked in:

    def forward(self, x):
        if self.twos.is_cuda:
            ...

Models might also contain a mix of CPU and GPU compute. However, as you've shown, there are many models where it is completely valid to remap devices after freezing.
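
To see why .to() has nothing to move here, a minimal sketch (reusing the M module from the repro above; the exact output is an assumption, not from the thread):

ms = torch.jit.freeze(torch.jit.script(M().eval()))
# Freezing inlines the "twos" buffer into the graph as a constant,
# so it is no longer a module buffer that .to() could move.
print(list(ms.named_buffers()))  # expected: []
print(ms.graph)                  # the constant CPU tensor appears inline in the graph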

Can I ask what the specific use case is?

Linux-cpp-lisp (Author) commented

That makes sense @eellison, thanks!

Our specific use case is compiling and freezing models for inference on both GPU and CPU. In our scientific ML context, users might have one or the other, and the hope was to have a single deployed, frozen TorchScript file that could be loaded and used for both GPU and CPU inference. For that, map_location seems to do the trick (see the sketch below); .to() was just the first thing I reached for to implement this.
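
A minimal loader sketch for that single-file deployment (the runtime device selection here is an assumption, not something stated above):

device = "cuda" if torch.cuda.is_available() else "cpu"
# Choose the target device at load time via map_location instead of
# calling .to() on the already-frozen module.
ms = torch.jit.load("tmp.pth", map_location=device)
ms(torch.ones(3, device=device))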

Given what you said, it does seem like this isn't really a bug but more a hole in the docs. It would be great if the torch.jit.freeze documentation mentioned that constants are fixed to their device at freeze time, that .to() is the wrong way to change that, and so on.

eellison (Contributor) commented

I updated the docs, closing.
