Runtime error using nearest neighbour upsampling on tensor with channels-last memory layout #81665
Comments
Fall back to contiguous memory layout before upscaling, as a workaround for pytorch/pytorch#81665. The condition (batch dimension >= 32) works for my tests, but it might not be general enough. Further analysis required.
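A minimal sketch of that kind of fallback in user code, assuming a hypothetical wrapper function (the function name, shapes, and the exact batch-size condition are illustrative, not the referenced commit's code):

```python
import torch
import torch.nn.functional as F

def safe_upscale(x: torch.Tensor, scale_factor: float = 2.0) -> torch.Tensor:
    """Hypothetical wrapper: fall back to contiguous memory before upscaling.

    Works around pytorch/pytorch#81665, which triggers for channels_last
    CUDA tensors once the batch dimension reaches 32.
    """
    if x.size(0) >= 32 and x.is_contiguous(memory_format=torch.channels_last):
        x = x.contiguous()  # revert to the default (NCHW) memory layout
    return F.interpolate(x, scale_factor=scale_factor, mode="nearest")
```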
Thank you for raising this issue. We can probably add a fix here: pytorch/aten/src/ATen/native/cuda/UpSampleNearest2d.cu, lines 83 to 95 (commit 35d4a80).
Sounds great! Is there anything I can do to help?
Any updates on this? 👀
#81665 CC @ngimel @ptrblck Pull Request resolved: #87901 Approved by: https://github.com/ngimel
+1, have run into this issue as well!
Same here
Fixed by #87901 |
Same here.
@malfet I am running into the same issue with PyTorch 1.12, but with bilinear upsampling.
@gchhablani Update to the latest stable or nightly release, as the fix seems to be in 1.13.1+.
I seem to still have this issue on PyTorch 2.0.1, not sure how :D
Same here, but how does .contiguous() solve the issue?
Probably because it converts the format to BCHW internally. Also, the related pull request only fixed nearest neighbour, but doesn't fix bilinear etc. The workaround should be the same for them, though.
Still doesn't fix bilinear.
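For the bilinear case, a user-side workaround along the same lines would be to make the input contiguous before the call. A minimal sketch, assuming an illustrative tensor shape (not taken from the original report):

```python
import torch
import torch.nn.functional as F

x = torch.randn(32, 64, 512, 512, device="cuda").to(memory_format=torch.channels_last)

# On affected versions the channels_last input can fail; .contiguous() avoids it
# by converting back to the default (NCHW) memory layout first.
y = F.interpolate(x.contiguous(), scale_factor=2, mode="bilinear", align_corners=False)
```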
🐛 Describe the bug
torch.nn.functional.interpolate fails with a RuntimeError when the input is a CUDA tensor in channels_last memory format and the batch dimension is 32 or larger. With a batch dimension of 31 the call works fine, producing a tensor with the expected shape [31, 64, 1024, 1024]; with a batch dimension of 32 or larger it fails. If the memory layout is contiguous rather than channels last, it also works fine (a reproduction sketch follows below). The error is raised in aten/src/ATen/native/cuda/UpSampleNearest2d.cu. I'm not sure about the details, but I think a potential workaround could be to automatically revert to contiguous format rather than failing.
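A minimal reproduction sketch; the original code snippets were not preserved, so the input spatial size (512×512) and scale_factor=2 are assumptions chosen to match the reported output shape [31, 64, 1024, 1024]:

```python
import torch
import torch.nn.functional as F

# Batch dimension 31: works, output shape [31, 64, 1024, 1024]
x = torch.randn(31, 64, 512, 512, device="cuda").to(memory_format=torch.channels_last)
out = F.interpolate(x, scale_factor=2, mode="nearest")
print(out.shape)

# Batch dimension 32: raises a RuntimeError on affected versions
x = torch.randn(32, 64, 512, 512, device="cuda").to(memory_format=torch.channels_last)
out = F.interpolate(x, scale_factor=2, mode="nearest")

# Contiguous layout with the same batch dimension: works
out = F.interpolate(x.contiguous(), scale_factor=2, mode="nearest")
```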
Versions
Collecting environment information...
PyTorch version: 1.10.1+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.16.3
Libc version: glibc-2.31
Python version: 3.8.12 (default, Oct 12 2021, 13:49:34) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.0-91-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 11.4.48
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 2080 Ti
GPU 1: NVIDIA GeForce RTX 3090
Nvidia driver version: 470.42.01
cuDNN version: Probably one of the following:
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.2.2
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.20.3
[pip3] torch==1.10.1+cu113
[pip3] torchaudio==0.10.1+cu113
[pip3] torchvision==0.11.2+cu113
[conda] numpy 1.20.3 pypi_0 pypi
[conda] torch 1.10.1+cu113 pypi_0 pypi
[conda] torchaudio 0.10.1+cu113 pypi_0 pypi
[conda] torchvision 0.11.2+cu113 pypi_0 pypi
I was using the RTX 3090 in this test. I have observed the same behaviour using other cards (an RTX A6000, for instance) in other systems running different Python, PyTorch, and OS versions.
cc @ezyang @gchanan @zou3519 @ngimel @VitalyFedyunin @jamesr66a @csarofeen @ptrblck @xwang233