
Binary not operator causes crash when JIT module is executed on different device #17970

Open
lorenwel opened this issue Mar 13, 2019 · 7 comments

@lorenwel

lorenwel commented Mar 13, 2019

🐛 Bug

When the binary not operator ~ is used in a module which is then JIT traced, the resulting JIT module crashes when it is moved to and executed on a different device type (e.g. CPU -> GPU).
The crash only occurs in the JIT module, not in eager mode.
Negating the ByteTensor with 1 - mask instead of ~ does not show the same issue in JIT.

To Reproduce

import torch
import torch.nn as nn

class OneMinus(nn.Module):
  def forward(self, inp):
    mask = inp > 0.5
    return inp[1-mask]

class Not(nn.Module):
  def forward(self, inp):
    mask = inp > 0.5
    return inp[~mask]

inp = torch.rand((8,))

# Eager
out = OneMinus()(inp)   # Works
out = Not()(inp)    # Works

out = OneMinus().cuda()(inp.cuda())   # Works
out = Not().cuda()(inp.cuda())  # Works


# Jit
oneminus_jit = torch.jit.trace(OneMinus(), inp)
not_jit = torch.jit.trace(Not(), inp)

oneminus_jit(inp)   # Works
not_jit(inp)    # Works

oneminus_jit.cuda()(inp.cuda())   # Works
not_jit.cuda()(inp.cuda())  # Fails


# Jit Trace with Cuda
not_jit_cuda = torch.jit.trace(Not().cuda(), inp.cuda())
not_jit_cuda(inp.cuda())  # Works
not_jit_cuda.cpu()(inp)  # Fails

The error console output is:

Traceback (most recent call last):
  File "issue.py", line 32, in <module>
    not_jit.cuda()(inp.cuda())  # Fails
  File "/home/lorenwel/venv/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/lorenwel/venv/pytorch/lib/python3.6/site-packages/torch/jit/__init__.py", line 1347, in forward
    return self._get_method('forward')(*args, **kwargs)
RuntimeError: 
expected type CPUByteType but got CUDAByteType (compute_types at /home/lorenwel/git/pytorch/aten/src/ATen/native/TensorIterator.cpp:134)
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x6c (0x7f6df917b02c in /home/lorenwel/git/pytorch/torch/lib/libc10.so)
frame #1: at::TensorIterator::compute_types() + 0xcb1 (0x7f6dd0db09b1 in /home/lorenwel/git/pytorch/torch/lib/libcaffe2.so)
frame #2: at::TensorIterator::Builder::build() + 0x5c (0x7f6dd0db679c in /home/lorenwel/git/pytorch/torch/lib/libcaffe2.so)
frame #3: at::TensorIterator::binary_op(at::Tensor&, at::Tensor const&, at::Tensor const&) + 0x30a (0x7f6dd0db762a in /home/lorenwel/git/pytorch/torch/lib/libcaffe2.so)
frame #4: at::native::sub_out(at::Tensor&, at::Tensor const&, at::Tensor const&, c10::Scalar) + 0xb2 (0x7f6dd0c31af2 in /home/lorenwel/git/pytorch/torch/lib/libcaffe2.so)
frame #5: at::TypeDefault::sub_(at::Tensor&, at::Tensor const&, c10::Scalar) const + 0x8d (0x7f6dd0fb771d in /home/lorenwel/git/pytorch/torch/lib/libcaffe2.so)
frame #6: torch::autograd::VariableType::sub_(at::Tensor&, at::Tensor const&, c10::Scalar) const + 0x306 (0x7f6dd3894b06 in /home/lorenwel/git/pytorch/torch/lib/libtorch.so.1)
frame #7: <unknown function> + 0x5f20a0 (0x7f6dd3b5a0a0 in /home/lorenwel/git/pytorch/torch/lib/libtorch.so.1)
frame #8: <unknown function> + 0x626ab5 (0x7f6dd3b8eab5 in /home/lorenwel/git/pytorch/torch/lib/libtorch.so.1)
frame #9: torch::jit::InterpreterState::run(std::vector<c10::IValue, std::allocator<c10::IValue> >&) + 0x31 (0x7f6dd3b88ed1 in /home/lorenwel/git/pytorch/torch/lib/libtorch.so.1)
frame #10: <unknown function> + 0x60b0d3 (0x7f6dd3b730d3 in /home/lorenwel/git/pytorch/torch/lib/libtorch.so.1)
frame #11: <unknown function> + 0x3cc9c8 (0x7f6dffa789c8 in /home/lorenwel/git/pytorch/torch/lib/libtorch_python.so)
frame #12: <unknown function> + 0x3adc76 (0x7f6dffa59c76 in /home/lorenwel/git/pytorch/torch/lib/libtorch_python.so)
frame #13: <unknown function> + 0x10eb46 (0x7f6dff7bab46 in /home/lorenwel/git/pytorch/torch/lib/libtorch_python.so)
<omitting python frames>
frame #17: python() [0x5381b4]
frame #20: python() [0x574417]
frame #25: python() [0x574417]
frame #29: python() [0x5381b4]
frame #31: python() [0x57cb45]
frame #33: python() [0x574417]
frame #35: python() [0x5e8ba2]
frame #40: __libc_start_main + 0xe7 (0x7f6e09502b97 in /lib/x86_64-linux-gnu/libc.so.6)
:
operation failed in interpreter:
issue.py(12): forward
/home/lorenwel/venv/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py(477): _slow_forward
/home/lorenwel/venv/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py(487): __call__
/home/lorenwel/venv/pytorch/lib/python3.6/site-packages/torch/jit/__init__.py(636): trace
issue.py(26): <module>

Expected behavior

It should not crash.

Environment

Output of the environment collection script:

Collecting environment information...
PyTorch version: 1.0.0a0+743fdbd
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.10.2

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce RTX 2080
Nvidia driver version: 410.78
cuDNN version: Probably one of the following:
/usr/local/MATLAB/R2018a/bin/glnxa64/libcudnn.so.7.0.3
/usr/local/cuda-10.0/targets/x86_64-linux/lib/libcudnn.so.7.4.2
/usr/local/cuda-10.0/targets/x86_64-linux/lib/libcudnn_static.a

Versions of relevant libraries:
[pip] Could not collect
[conda] Could not collect
  • PyTorch Version 1.0.1
  • Ubuntu 18.04
  • Compiled from source with python setup.py install
  • Python 3.6.8
  • CUDA 10.0, cuDNN 7.4.2
  • Nvidia RTX 2080

Additional context

The same issue also occurs when the failing JIT module is loaded into libtorch and executed there.

cc @suo

@facebook-github-bot added the oncall: jit label Mar 13, 2019
@eellison
Contributor

Thanks for the report! We are looking into this.

@Krovatkin
Contributor

@eellison I'll take a look!

@Krovatkin
Contributor

@lorenwel

We recommend tracing for CPU and GPU and keeping both versions:

Q: I would like to train a model on GPU and do inference on CPU. What are the best practices?

https://pytorch.org/docs/master/jit.html#frequently-asked-questions

OneMinus only avoids the failure by a lucky coincidence: it never allocates temporary storage on the CPU, whereas Not does exactly that.
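A minimal sketch of that recommendation, reusing the Not module and input from the repro above; the select_model helper is only illustrative, not a PyTorch API:

import torch

inp = torch.rand((8,))

# Trace one version per device and keep both, as the FAQ suggests.
not_cpu = torch.jit.trace(Not(), inp)
not_gpu = torch.jit.trace(Not().cuda(), inp.cuda())

def select_model(x):
    # Illustrative helper: dispatch on the device of the incoming tensor.
    return not_gpu if x.is_cuda else not_cpu

out = select_model(inp)(inp)                # runs the CPU trace
out = select_model(inp.cuda())(inp.cuda())  # runs the GPU trace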

@lorenwel
Author

Thanks for tracing the issue.

Is this something that might be fixed in the future or is it a more fundamental issue?

Our specific use case is that we train a model on GPU with PyTorch and would then like to deploy it on robots using libtorch. The latter we achieve through JIT tracing, as suggested in the docs. The problem is that we don't know a priori whether a GPU is installed, so we need to be compatible with both CPU and GPU. Shipping and maintaining two copies of the same model is not ideal for this.

@Krovatkin
Contributor

Krovatkin commented Apr 18, 2019

@lorenwel would it be possible for you guys to convert your model to TorchScript with @torch.jit.script or @torch.jit.script_method? You might still be able to trace chunks of your model that are device-agnostic, although that may take some trial and error.
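A minimal sketch of that suggestion, assuming the Not module from the repro. It uses the torch.jit.ScriptModule / @torch.jit.script_method style named above (newer releases also accept torch.jit.script(Not()) directly); depending on the PyTorch version, the indexing line may need to be rewritten, e.g. with torch.masked_select, if the script compiler rejects it:

import torch

class NotScript(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, inp):
        # Same computation as Not(), but compiled by the script compiler
        # instead of being traced, so it stays device-agnostic.
        mask = inp > 0.5
        return inp[~mask]

inp = torch.rand((8,))
not_script = NotScript()
not_script(inp)                # CPU
not_script.cuda()(inp.cuda())  # GPU, no retracing required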

@suo Michael could you please weigh in on this question?

Is this something that might be fixed in the future or is it a more fundamental issue?

@suo
Member

suo commented Apr 19, 2019

We can probably fix this particular case when tracing, but you shouldn't expect .cuda() or .cpu() to work well with traced models. Using scripting is our recommended solution for this kind of thing. It lets us recover more information about the model, so we can transparently support things like moving module hierarchies between cpu and gpu.
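For the deployment concern above, a scripted module only needs to be exported once and can then be moved to whichever device is available at load time. A sketch of that workflow, reusing not_script from the previous comment (the file name is arbitrary); the libtorch side would follow the same pattern with torch::jit::load:

import torch

# Export the scripted module once, independent of the target device.
not_script.save("not_model.pt")

# At deployment time, pick the device that is actually present.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.jit.load("not_model.pt", map_location=device)
out = model(torch.rand((8,), device=device))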

@lorenwel
Author

Alright, thank you both.
I will try to use TorchScript whenever I cannot work around the tracing issue.

I guess we can close this issue then, since the JIT tracing behavior is described in the FAQ link you posted?

@wanchaol added the triaged label Dec 2, 2019