Segmentation fault during ONNX export #49959
Grabbing for myself, as I have a symbolicated backtrace:
Summary: Apply a little bit of defensive programming: `type->cast<TensorType>()` returns an optional pointer, so dereferencing it can lead to a hard crash. Fixes SIGSEGV reported in #49959
Pull Request resolved: #50237
Reviewed By: walterddr
Differential Revision: D25839675
Pulled By: malfet
fbshipit-source-id: 403d6df5e2392dd6adc308b1de48057f2f9d77ab
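The patch itself is in the C++ exporter; as a language-neutral illustration of the same defensive pattern, here is a minimal Python sketch (hypothetical names throughout, not the actual fix) of guarding an optional cast result before use:

```python
from typing import Optional

class TensorInfo:
    """Hypothetical stand-in for a JIT TensorType."""
    def __init__(self, sizes: tuple) -> None:
        self.sizes = sizes

def cast_tensor_info(value: object) -> Optional[TensorInfo]:
    # Mirrors type->cast<TensorType>(): returns None (a "null pointer")
    # when the value is not actually a tensor type.
    return value if isinstance(value, TensorInfo) else None

def tensor_rank(value: object) -> int:
    info = cast_tensor_info(value)
    if info is None:
        # The defensive check: bail out instead of dereferencing null.
        return 0
    return len(info.sizes)
```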
Fixed the crash, but the underlying transformation of tensor lists (i.e. …
@SplitInfinity, @BowenBao do you know if the underlying issue has been fixed? And should we remove the high-pri label, as it is no longer causing a crash?
Looking into it. This is a trickier case of #81386.
The generic

```python
output = []
for i in range(prediction.shape[0]):
    output.append(torch.zeros((0, 6), device=prediction.device))
```

In ONNX, each operator must have a type, while a … As the crash doesn't happen anymore, I have removed the high-pri label and will close this issue.
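For context, a self-contained toy repro of that list-building pattern (hypothetical module name and input shape, distilled from the YOLOv5 post-processing rather than copied from it) might look like:

```python
import torch
import torch.nn as nn

class ListBuilder(nn.Module):
    # Builds a Python list of empty (0, 6) tensors, one per batch element --
    # the pattern from YOLOv5's post-processing that triggered the crash.
    def forward(self, prediction: torch.Tensor):
        output = []
        for i in range(prediction.shape[0]):
            output.append(torch.zeros((0, 6), device=prediction.device))
        return output

torch.onnx.export(ListBuilder().eval(), torch.rand(4, 10, 85),
                  "/tmp/list_builder.onnx", opset_version=12)
```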
🐛 Bug
I ran into a segmentation fault during ONNX export.
Segmentation fault call stack:
To Reproduce
I'm sorry, but I haven't found easier steps to reproduce this segfault. Steps to reproduce the behavior:
```
git clone https://github.com/ultralytics/yolov5.git
cd yolov5
git checkout 73cf75faa848cb
```

Put `test_script.py` into the `models` directory, then run:

```
python3 test_script.py --weights /path/to/yolov5.pt --test-image /path/to/a/640x640/color/image.jpg --output /tmp/foo.onnx
```
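The attached test_script.py is not reproduced in this issue; as a rough, hypothetical sketch of the general shape of such an export driver (stand-in model only, since the real script loads the YOLOv5 .pt weights):

```python
import torch
import torch.nn as nn

class StandIn(nn.Module):
    # Placeholder for the YOLOv5 model the real script loads from .pt.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 6, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

model = StandIn().eval()
dummy = torch.zeros(1, 3, 640, 640)  # a 640x640 color image batch
torch.onnx.export(model, dummy, "/tmp/foo.onnx", opset_version=12)
```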
Expected behavior
No segfault.
Environment
Collecting environment information...
PyTorch version: 1.7.1
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.1 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.16.3
Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce GTX 1060 6GB
Nvidia driver version: 450.66
cuDNN version: /usr/local/cuda-10.0/lib64/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.19.4
[pip3] torch==1.7.1
[pip3] torchvision==0.8.2
[conda] Could not collect
Additional context
`test_script.py`

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @houseroad @spandantiwari @lara-hdr @BowenBao @neginraoof