
Runtime error using nearest neighbour upsampling on tensor with channels-last memory layout #81665

Closed
pcuenca opened this issue Jul 18, 2022 · 13 comments
Labels
high priority · module: cuda · module: interpolation · module: memory format · triaged

Comments

pcuenca commented Jul 18, 2022

🐛 Describe the bug

torch.nn.functional.interpolate fails with a RuntimeError when the following conditions are met:

  • The input tensor uses the channels_last memory format.
  • The input is large enough that the upsampled output reaches a size threshold (INT_MAX elements, per the error message below).

The following code works fine, producing a tensor with the expected shape [31, 64, 1024, 1024]:

x = torch.rand((31, 64, 512, 512)).cuda().to(memory_format=torch.channels_last)
torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest').shape
torch.Size([31, 64, 1024, 1024])

However, when the input batch dimension is 32 or larger, it fails:

x = torch.rand((32, 64, 512, 512)).cuda().to(memory_format=torch.channels_last)
torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest').shape
RuntimeError: upsample_nearest_nhwc only supports output tensors with less than INT_MAX elements

If the memory layout is contiguous rather than channels last, it works fine too:

x = torch.rand((32, 64, 512, 512)).cuda()
torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest').shape
torch.Size([32, 64, 1024, 1024])

The error is raised here. I'm not sure about the details, but I think a potential workaround could be to automatically revert to contiguous format, rather than failing.
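A minimal sketch of what that automatic fallback could look like at the Python level (the wrapper name is hypothetical, and the INT_MAX bound is taken from the error message rather than from PyTorch internals):

import torch
import torch.nn.functional as F

INT_MAX = 2**31 - 1  # bound quoted in the error message

def interpolate_with_fallback(x, scale_factor, mode='nearest'):
    # Hypothetical helper: if the input uses channels_last and the upsampled
    # output would reach INT_MAX elements, revert to the default contiguous
    # (NCHW) layout instead of letting the NHWC kernel raise a RuntimeError.
    n, c, h, w = x.shape
    out_numel = n * c * int(h * scale_factor) * int(w * scale_factor)
    if x.is_contiguous(memory_format=torch.channels_last) and out_numel >= INT_MAX:
        x = x.contiguous()
    return F.interpolate(x, scale_factor=scale_factor, mode=mode)

x = torch.rand((32, 64, 512, 512)).cuda().to(memory_format=torch.channels_last)
interpolate_with_fallback(x, scale_factor=2).shape  # torch.Size([32, 64, 1024, 1024])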

Versions

Collecting environment information...
PyTorch version: 1.10.1+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.8.12 (default, Oct 12 2021, 13:49:34) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.0-91-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 11.4.48
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 2080 Ti
GPU 1: NVIDIA GeForce RTX 3090

Nvidia driver version: 470.42.01
cuDNN version: Probably one of the following:
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.2.2
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.20.3
[pip3] torch==1.10.1+cu113
[pip3] torchaudio==0.10.1+cu113
[pip3] torchvision==0.11.2+cu113
[conda] numpy 1.20.3 pypi_0 pypi
[conda] torch 1.10.1+cu113 pypi_0 pypi
[conda] torchaudio 0.10.1+cu113 pypi_0 pypi
[conda] torchvision 0.11.2+cu113 pypi_0 pypi


I was using the RTX 3090 in this test. I have observed the same behaviour with other cards (an RTX A6000, for instance) on other systems running different Python, PyTorch, and OS versions.

cc @ezyang @gchanan @zou3519 @ngimel @VitalyFedyunin @jamesr66a @csarofeen @ptrblck @xwang233

pcuenca added a commit to pcuenca/SwinIR that referenced this issue Jul 18, 2022
Fall back to contiguous memory layout before upscaling, as a workaround for
pytorch/pytorch#81665.

The condition (batch dimension >= 32) works for my tests, but it might
not be general enough. Further analysis required.
bdhirsh added the triaged and module: memory format labels on Jul 20, 2022
VitalyFedyunin added the module: cudnn and module: cuda labels on Jul 20, 2022
xwang233 added the module: interpolation label and removed the module: cudnn label on Jul 20, 2022
@xwang233 (Collaborator)

Thank you for raising this issue. We can probably add an index_t template parameter with int64_t indexing here for large tensors.

template <typename scalar_t, nn_compute_source_index_fn_t nn_compute_source_index_fn>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest2d_nhwc_out_frame(
    const scalar_t* idata,
    scalar_t* odata,
    const size_t channels,
    const size_t height1,
    const size_t width1,
    const size_t height2,
    const size_t width2,
    float height_scale,
    float width_scale,
    const size_t out_numel) {


pcuenca commented Jul 24, 2022

Sounds great! Is there anything I can do to help?

@NouamaneTazi

Any updates on this? 👀

rtaori commented Oct 29, 2022

+1 have run into this issue as well!

hadaev8 commented Oct 31, 2022

Same here

malfet (Contributor) commented Oct 31, 2022

Fixed by #87901.
Not a regression; if we tentatively do a 1.13.1 release, we should pick this fix into the branch.

@yyt-2378

same here

@gchhablani

@malfet I am running into the same issue with PyTorch 1.12 but with Bilinear upsampling.

ptrblck (Collaborator) commented Jul 21, 2023

@gchhablani Update to the latest stable or nightly release, as the fix seems to be in 1.13.1+.

@Parskatt

I seem to still have this issue on pytorch 2.0.1, not sure how :D
Also, I seem unable to reproduce it on other machines, so I just added a .contiguous() before the interpolate call as a workaround.

SMSD75 commented Aug 14, 2023

Same here, but how does .contiguous() solve the issue?

@Parskatt

Probably because it converts the format to bchw internally. Also, the related pull request only fixed nearest neighbour; it doesn't fix bilinear etc., though the fix should be the same for them.
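A minimal illustration of that, assuming the bilinear channels_last path hits the same element-count limit described above:

import torch

x = torch.rand((32, 64, 512, 512)).cuda().to(memory_format=torch.channels_last)
print(x.is_contiguous(memory_format=torch.channels_last))  # True: NHWC (channels_last) strides
print(x.contiguous().is_contiguous())                      # True: default NCHW strides, same data and shape
# The contiguous copy no longer dispatches to the channels_last kernel, so the
# interpolation runs without hitting the NHWC element-count check.
y = torch.nn.functional.interpolate(
    x.contiguous(), scale_factor=2, mode='bilinear', align_corners=False
)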

@cjissmart

Still doesn't fix bilinear.
