
CUDA_KERNEL_LOOP: prevent int overflow in loop increment. #24818


Closed · wants to merge 4 commits

Conversation

skrah
Contributor

@skrah skrah commented Aug 18, 2019

Fixes #24309.

@pytorchbot pytorchbot added the module: cuda (Related to torch.cuda, and CUDA support in general) and module: operators labels Aug 18, 2019
@skrah
Contributor Author

skrah commented Aug 18, 2019

The issue

Even for a valid array size, the loop variable in CUDA_KERNEL_LOOP() can overflow. For some reason the overflow is only visible with CUDA_LAUNCH_BLOCKING=1.
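
For reference, the macro on master looks roughly like this (a paraphrased sketch, not an exact copy of the source). The index and the grid-stride increment are both 32-bit quantities, so the increment in the final iteration can wrap:

// Sketch of the current macro: i is a plain int, and the final
// i += blockDim.x * gridDim.x can push it past INT_MAX.
#define CUDA_KERNEL_LOOP(i, n)                                  \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);  \
       i += blockDim.x * gridDim.x)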

This is a minimal reproducer:

import torch

x = torch.randn(1, 1, 1, 1073741825, dtype=torch.float16, device="cuda:0")
torch.functional.F.avg_pool2d(x, kernel_size=[1, 1], stride=[1, 1], padding=[0, 0], count_include_pad=True)

master passes plain cuda-memcheck:

$ cuda-memcheck /home/stefan/rel-master/bin/python3 overflow.py 
========= CUDA-MEMCHECK
========= ERROR SUMMARY: 0 errors

master fails with CUDA_LAUNCH_BLOCKING=1 (invalid write):

$ CUDA_LAUNCH_BLOCKING=1 /home/stefan/rel-master/bin/python3 overflow.py 
Traceback (most recent call last):
  File "overflow.py", line 4, in <module>
    torch.functional.F.avg_pool2d(x, kernel_size=[1, 1], stride=[1, 1], padding=[0, 0], count_include_pad=True)
RuntimeError: avg_pool2d_out_cuda_frame failed with error code 0

Performance

@ngimel has recommended caution when using int64_t in kernels for indexing, but I'm not sure if that only applies to the actual indexing step or also to loop variables.

In this benchmark, the timings are the same before and after:

import torch

x = torch.randn(1, 1, 1, 1073741824, dtype=torch.float16, device="cuda:0")

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    for i in range(10):
        torch.functional.F.avg_pool2d(x, kernel_size=[1, 1], stride=[1, 1], padding=[0, 0], count_include_pad=True)

print(prof)

@skrah skrah added the open source and triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels Aug 18, 2019
@skrah
Contributor Author

skrah commented Aug 18, 2019

Incidentally, dim3 members are unsigned int, but everyone uses int or int64_t for indices:

struct __device_builtin__ dim3
{
    unsigned int x, y, z;
#if defined(__cplusplus)
    __host__ __device__ dim3(unsigned int vx = 1, unsigned int vy = 1, unsigned int vz = 1) : x(vx), y(vy), z(vz) {}
    __host__ __device__ dim3(uint3 v) : x(v.x), y(v.y), z(v.z) {}
    __host__ __device__ operator uint3(void) { uint3 t; t.x = x; t.y = y; t.z = z; return t; }
#endif /* __cplusplus */
};
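
A tiny device-code sketch of the consequence (my own illustration, not code from the tree): the index expression is evaluated entirely in unsigned int and only afterwards converted to whatever signed type the loop variable happens to have.

__global__ void index_types_example(int64_t* out)
{
    // blockIdx/blockDim/threadIdx have unsigned int components, so the whole
    // expression below is computed in unsigned int arithmetic ...
    unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;
    // ... and only then converted to the signed index type:
    int     i32 = tid;   // implementation-defined result once tid > INT_MAX
    int64_t i64 = tid;   // always exact: every unsigned int value fits in int64_t
    out[0] = i64 - i32;  // keep both values live
}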

@skrah skrah changed the title from "[WIP] CUDA_KERNEL_LOOP: prevent int overflow in loop increment." to "CUDA_KERNEL_LOOP: prevent int overflow in loop increment." Aug 18, 2019
facebook-github-bot pushed a commit that referenced this pull request Aug 19, 2019
Summary:
Spin-off from #24818.
Pull Request resolved: #24820

Differential Revision: D16890917

Pulled By: ezyang

fbshipit-source-id: 88df6d3ba98600acc810eda406daa1d850ed3320
@skrah
Contributor Author

skrah commented Aug 19, 2019

@pytorchbot retest this please.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Aug 19, 2019
Summary:
Spin-off from pytorch/pytorch#24818.
Pull Request resolved: pytorch/pytorch#24820

Differential Revision: D16890917

Pulled By: ezyang

fbshipit-source-id: 88df6d3ba98600acc810eda406daa1d850ed3320
@skrah skrah requested a review from ezyang August 19, 2019 19:17
@skrah
Contributor Author

skrah commented Aug 19, 2019

@ezyang I guess this is ready for review; the CI failures should be unrelated. Thanks for merging #24820 so quickly!

@ezyang
Contributor

ezyang commented Aug 19, 2019

I can see how the change you made avoids overflow, but I don't see how you get a correct result afterwards. If you always cast the index to int, then you are going to truncate it one way or another. If I'm right, a test comparing CPU and CUDA behavior should be sufficient evidence (which we don't seem to have in this PR).

I think the traditional way we've solved this problem is to have two versions of the kernel, one with 32-bit indexing and one with 64-bit indexing, and switch to the slower 64-bit kernel when applicable. Something like canUse32BitIndexMath.
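
A rough sketch of that dispatch pattern follows; the kernel, its body, and the launch helper are hypothetical, and only canUse32BitIndexMath is a real ATen helper (its exact header location may differ):

#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>  // at::cuda::detail::canUse32BitIndexMath

// Grid-stride loop whose index type is chosen at the call site. Note that with
// index_t = int32_t the final increment can still wrap near INT_MAX, which is
// exactly the overflow this PR guards against; the sketch only illustrates the
// 32-bit/64-bit dispatch.
template <typename index_t>
__global__ void copy_like_kernel(const float* in, float* out, index_t n) {
  const index_t stride = static_cast<index_t>(blockDim.x) * gridDim.x;
  for (index_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += stride) {
    out[i] = in[i];  // placeholder body
  }
}

void launch(const at::Tensor& input, at::Tensor& output) {
  const int64_t n = input.numel();
  const unsigned int threads = 256;
  const unsigned int blocks = static_cast<unsigned int>((n + threads - 1) / threads);
  if (at::cuda::detail::canUse32BitIndexMath(input)) {
    // Fast path: every index fits in 32 bits.
    copy_like_kernel<int32_t><<<blocks, threads>>>(
        input.data_ptr<float>(), output.data_ptr<float>(), static_cast<int32_t>(n));
  } else {
    // Slow path: 64-bit indexing for very large tensors.
    copy_like_kernel<int64_t><<<blocks, threads>>>(
        input.data_ptr<float>(), output.data_ptr<float>(), n);
  }
}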

Contributor

@ezyang ezyang left a comment


I don't see how this works.

@skrah
Contributor Author

skrah commented Aug 19, 2019

It works because t.numel() fits in int, which is checked for:

const int count = safe_downcast<int, int64_t>(input.numel());

The overflow only matters in the last iteration, where e.g. index += 2**30 wraps to a negative value. The loop condition is then still true, and the loop continues with a negative index.

With int64_t the loop condition is false and the value that overflows in the cast is not used.
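
Concretely, the idea is roughly the following (a reconstruction for illustration, since the diff itself is not shown in this thread): do the loop arithmetic in int64_t so the increment cannot wrap, and re-derive the int-typed i that existing kernel bodies expect.

// Reconstruction of the idea, not the literal diff. The narrowing assignment
// to i can only produce an out-of-range value on the step where the loop
// condition is already false, so that value is never read.
#define CUDA_KERNEL_LOOP(i, n)                                \
  int64_t _i_n_d_e_x = blockIdx.x * blockDim.x + threadIdx.x; \
  for (int i = _i_n_d_e_x; _i_n_d_e_x < (n);                  \
       _i_n_d_e_x += blockDim.x * gridDim.x, i = _i_n_d_e_x)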

@skrah
Contributor Author

skrah commented Aug 19, 2019

If the macro weren't used, the cast could simply go inside the loop:

for (int64_t _i = blockIdx.x * blockDim.x + threadIdx.x; _i < (n); _i += blockDim.x * gridDim.x) {
    int i = (int)_i;
    ...
}

This is currently not possible because of the way the macro is used. The explicit form would also be safer in case nvcc applies aggressive optimizations based on signed-overflow UB, even when the overflowed int is never used afterwards.

In that case we'd need to write out the loop in full and get rid of the macro.

@skrah
Contributor Author

skrah commented Aug 19, 2019

Correction: of course signed casts are not UB, but are implementation-defined, so it should be fine.

@skrah
Contributor Author

skrah commented Aug 20, 2019

@pytorchbot retest this please.

Contributor

@facebook-github-bot facebook-github-bot left a comment


@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Aug 29, 2019
Summary:
Fixes pytorch/pytorch#24309.
Pull Request resolved: pytorch/pytorch#24818

Differential Revision: D16927215

Pulled By: ezyang

fbshipit-source-id: aeab5226fec6045941399693479975db4542c79e
@facebook-github-bot
Contributor

@ezyang merged this pull request in c845984.
