
Segfault upon calling torch.Tensor on a GPU tensor #33899

Closed
ben-heil opened this issue Feb 27, 2020 · 7 comments
Labels
high priority · module: crash (Problem manifests as a hard crash, as opposed to a RuntimeError) · module: cuda (Related to torch.cuda, and CUDA support in general) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

ben-heil commented Feb 27, 2020

🐛 Bug

I accidentally called torch.Tensor on a Tensor object, thinking it was a numpy array. Instead of doing nothing or throwing a warning, the program segfaulted.

To Reproduce

Minimal code example:

import torch
import numpy as np
device = torch.device('cuda')
a = torch.Tensor(np.random.normal(size=10)).to(device)
torch.Tensor(a)

Upon running the final line, Python prints "Segmentation fault (core dumped)" and dies.

Expected behavior

I would have expected casting a Tensor to a Tensor to be a no-op (or at most raise a warning), not to crash the process.

Environment

Collecting environment information...
PyTorch version: 1.3.1
Is debug build: No
CUDA used to build PyTorch: 10.1.243

OS: Ubuntu 18.04.4 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: Could not collect

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce RTX 2060
Nvidia driver version: 430.50
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5

Versions of relevant libraries:
[pip] numpy==1.16.4
[pip] numpydoc==0.9.1
[pip] torch==1.3.1
[pip] torchfile==0.1.0
[pip] torchvision==0.4.2
[conda] blas 1.0 mkl
[conda] mkl 2019.4 243
[conda] mkl-service 2.0.2 py37h7b6447c_0
[conda] mkl_fft 1.0.12 py37ha843d7b_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] torch 1.3.1 pypi_0 pypi
[conda] torchfile 0.1.0 pypi_0 pypi
[conda] torchvision 0.4.2 pypi_0 pypi

Additional context

I was able to reproduce the issue on Google Colab, so I don't think it is specific to my machine.

cc @ezyang @gchanan @zou3519 @bdhirsh @heitorschueroff @ngimel

ptrblck (Collaborator) commented Feb 28, 2020

Backtrace for:

import torch

a = torch.tensor(1, device='cuda')
b = torch.Tensor(a)
print(b)
#0  __strlen_avx2 () at ../sysdeps/x86_64/multiarch/strlen-avx2.S:62
#1  0x00007ffff781fa45 in printf_positional (s=s@entry=0x7fffffffcff0, format=format@entry=0x7fffbf982d83 "expected %s (got %s)", readonly_format=readonly_format@entry=0, ap=ap@entry=0x7fffffffd5a8, 
    ap_savep=ap_savep@entry=0x7fffffffcb98, done=done@entry=9, nspecs_done=0, lead_str_end=<optimized out>, work_buffer=<optimized out>, save_errno=<optimized out>, grouping=<optimized out>, 
    thousands_sep=<optimized out>) at vfprintf.c:2023
#2  0x00007ffff7821cba in _IO_vfprintf_internal (s=s@entry=0x7fffffffcff0, format=format@entry=0x7fffbf982d83 "expected %s (got %s)", ap=ap@entry=0x7fffffffd5a8) at vfprintf.c:1688
#3  0x00007ffff78f7169 in ___vsnprintf_chk (s=0x7fffffffd170 "expected \321\377\377\377\177", maxlen=<optimized out>, flags=1, slen=<optimized out>, format=0x7fffbf982d83 "expected %s (got %s)", 
    args=0x7fffffffd5a8) at vsnprintf_chk.c:63
#4  0x00007fffbf34dc6f in torch::formatMessage(char const*, __va_list_tag*) () from /opt/conda/lib/python3.6/site-packages/torch/lib/libtorch_python.so
#5  0x00007fffbf34f5d1 in torch::TypeError::TypeError(char const*, ...) () from /opt/conda/lib/python3.6/site-packages/torch/lib/libtorch_python.so
#6  0x00007fffbf6406f6 in torch::utils::(anonymous namespace)::new_with_tensor(c10::DispatchKey, c10::ScalarType, at::Tensor const&) ()
   from /opt/conda/lib/python3.6/site-packages/torch/lib/libtorch_python.so
#7  0x00007fffbf645b38 in torch::utils::legacy_tensor_ctor(c10::DispatchKey, c10::ScalarType, _object*, _object*) () from /opt/conda/lib/python3.6/site-packages/torch/lib/libtorch_python.so
#8  0x00007fffbf47364e in THPVariable_pynew(_typeobject*, _object*, _object*) () from /opt/conda/lib/python3.6/site-packages/torch/lib/libtorch_python.so

It seems that the code tries to throw the TypeError from here, but segfaults in __strlen_avx2.

EDIT: seems to be related to #25518

vincentqb (Contributor) commented:

Have you tried the workaround you mentioned here?

vincentqb added the module: cuda and high priority labels on Feb 28, 2020
vincentqb added the triaged label and removed the triage review label on Feb 28, 2020
vincentqb (Contributor) commented:

Marking as high priority since it can be reproduced and leads to a segmentation fault.

peterbell10 self-assigned this on Feb 29, 2020
ttumiel pushed a commit to ttumiel/pytorch that referenced this issue Mar 4, 2020
…h#34019)

Summary:
Fixes pytorch#33899

In the issue, we have
```
TypeError("expected %s (got %s)", dispatch_key, toString(other.key_set()).c_str());
```
which results in `dispatch_key` being interpreted as a C string by `sprintf`. Adding `__attribute__((format))` to the `TypeError` constructor allows gcc and clang to detect the mismatch at compile time, and `-Werror=format` turns that diagnostic into a hard build error.
Pull Request resolved: pytorch#34019

Differential Revision: D20194842

Pulled By: ezyang

fbshipit-source-id: fa4448916c309d91e3d949fa65bb3aa7cca5c6a8
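
To illustrate the failure mode that commit message describes, here is a hypothetical sketch (assuming Linux and CPython; libc's printf stands in for torch's internal formatMessage) of what happens when a %s specifier receives a non-pointer argument:

```python
# Hypothetical demo of the crash class fixed above; Linux/CPython only.
import ctypes

libc = ctypes.CDLL(None)  # load the C library the interpreter links against

# Correct: each %s receives a real C string, so strlen() reads valid memory.
libc.printf(b"expected %s (got %s)\n", b"CPU", b"CUDA")

# Buggy (kept commented out on purpose): passing an integer enum value where
# %s expects a char* sends __strlen_avx2 to a garbage address -- the same
# segfault shown in the backtrace above.
# libc.printf(b"expected %s (got %s)\n", 1, b"CUDA")
```

Python cannot show the compile-time half of the fix: the `__attribute__((format))` annotation exists only in the C++ sources, where it lets the compiler reject such calls before they ever run.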
PaulZhangIsing commented:

Similar issue encountered here.

ChuanchuanZheng commented:

> (quoting ptrblck's backtrace comment above in full)

@ptrblck Could you show how to solve the segmentation fault with your test code?
I recently encountered the same fault when moving a model to the GPU (model.cuda()). Many thanks.

ngimel added the module: crash label on Oct 29, 2020
ptrblck (Collaborator) commented Oct 29, 2020

@ChuanchuanZheng Which PyTorch version are you using?
This error was reported for an older PyTorch release and should be fixed by now.
For 1.7 I get:

>>> torch.__version__
'1.7.0'
>>> a = torch.tensor(1, device='cuda')
>>> b = torch.Tensor(a)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: expected CPU (got CUDA)

Trying to work around this error yields:

>>> b = torch.Tensor(a, device='cuda')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: legacy constructor for device type: cpu was passed device type: cuda, but device type must be: cpu

Using the suggested cpu device instead yields a tensor filled with uninitialized data:

a = torch.tensor(1, device='cuda')
b = torch.Tensor(a, device='cpu')
print(b)
> tensor([8.9683e-44])
print(a)
> tensor(1, device='cuda:0')

While the original error can no longer be reproduced, we might be facing a new one.
My code snippet shows incorrect usage, but I'm afraid users might run into this issue nevertheless.
@ngimel Let me know if I should create a new issue or if we want to track and fix it here.

ngimel (Collaborator) commented Oct 29, 2020

Thanks for digging in, @ptrblck! Please open a new issue; this is different from the original.
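
For anyone landing here with ChuanchuanZheng's question, a minimal sketch (assuming a recent PyTorch release) of the supported ways to copy or re-wrap an existing tensor without going through the legacy torch.Tensor constructor:

```python
import torch

a = torch.tensor(1, device='cuda')

b = a.clone()           # copy on the same device, preserving dtype and value
c = torch.as_tensor(a)  # re-wrap without copying (returns `a` itself here)
d = a.to('cpu')         # explicit cross-device copy

print(b, c, d)  # tensor(1, device='cuda:0') tensor(1, device='cuda:0') tensor(1)
```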
