[fix] torch.multinomial : fix for 0 size dim #43775
Conversation
test/test_torch.py
-    probs = torch.ones(0, 3)
-    num_samples = 1
+    probs = torch.ones(0, 128, device=device)
+    num_samples = 64
Curiously, the test passed with
probs = torch.ones(0, 3, device=device)
num_samples = 1
Is it something to do with replacement, since it fails only when num_samples > 1?
Possible. I haven't actually stepped through the exact kernel code where it is failing.
The test passes with replacement=False, so the different replacement modes have to be tested separately.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
>>> torch.multinomial(x, 1, True)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: CUDA error: invalid configuration argument
The above error comes from pytorch/aten/src/ATen/native/cuda/MultinomialKernel.cu, lines 320 to 342 at commit 58148c8:
if (n_sample == 1 && maxShared >= requiredShared) {
  // Optimized allocation-free implementation
  // To exploit greater parallelism for the sampling, generate the
  // Uniform random samples in a separate kernel launch, into
  // temporarily allocated memory. The device RNG is thread-limited
  Tensor sampled = native::empty_cuda({numDist, n_sample}, self_v.options());
  at::native::uniform_(sampled, 0.0, 1.0, generator);
  dim3 block(numCategories < maxThreads ? numCategories : maxThreads);
  dim3 grid(numDist < numSM * 4 ? numDist : numSM * 4);
  sampleMultinomialOnce<scalar_t, accscalar_t>
      <<<grid, block,
         requiredShared,
         at::cuda::getCurrentCUDAStream()>>>(
      result.data_ptr<int64_t>(),
      numDist,
      numCategories,
      sampled.data_ptr<scalar_t>(),
      self_v.data_ptr<scalar_t>(),
      self_v.stride(0),
      self_v.stride(1));
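With numDist == 0 (the 0-size leading dim), dim3 grid(numDist < numSM * 4 ? numDist : numSM * 4) evaluates to dim3(0), and the CUDA runtime rejects any launch whose grid has a zero dimension. A minimal standalone sketch of that behavior (a hypothetical demo, not code from this PR; noop is a made-up kernel):

#include <cstdio>
#include <cuda_runtime.h>

// Made-up no-op kernel, only here so we can attempt a launch.
__global__ void noop() {}

int main() {
  // grid.x == 0 mirrors dim3 grid(numDist) with numDist == 0.
  noop<<<dim3(0), dim3(128)>>>();
  // Prints "invalid configuration argument".
  printf("%s\n", cudaGetErrorString(cudaGetLastError()));
  return 0;
}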
For num_samples > 1:
>>> torch.multinomial(x, 2, True)
Floating point exception (core dumped)
This crash comes from pytorch/aten/src/ATen/native/cuda/MultinomialKernel.cu, lines 364 to 383 at commit 58148c8:
if (with_replacement) {
  // Binary search is warp divergent (so effectively we're running
  // with just a single thread), but for better utilization,
  // we need each block to have at least 4 warps.
  dim3 block(128);
  // Each block will generate a sample from one
  // distribution concurrently.
  int grid_y = std::min<int>(numDist, at::cuda::getCurrentDeviceProperties()->maxGridSize[1]);
  dim3 grid((n_sample - 1) / block.x + 1, grid_y);
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    // each thread generates a single sample for (numdist/numblocks.y) distributions,
    // however, since we have to use curand_uniform4
    // (See Note [Register spilling in curand call for CUDA < 10]),
    // offset is 4 times that.
    auto offset = ((numDist - 1) / grid.y + 1) * 4;
    rng_engine_inputs = gen->philox_engine_inputs(offset);
  }
Here grid_y is 0 (because numDist is 0 for a 0-size leading dim), and thus in the line below we get a floating point exception due to division by zero:

auto offset = ((numDist-1)/grid.y+1)*4;
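To see why this dumps core rather than raising a Python error, here is a hypothetical standalone sketch (not PyTorch code) of the same arithmetic:

#include <cstdint>

int main() {
  int64_t numDist = 0;      // zero-size leading dim, as in torch.ones(0, 128)
  unsigned int grid_y = 0;  // std::min<int>(numDist, maxGridSize[1]) == 0
  // Integer division by zero raises SIGFPE on x86, which the shell
  // reports as "Floating point exception (core dumped)".
  auto offset = ((numDist - 1) / grid_y + 1) * 4;
  return static_cast<int>(offset);
}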
The fix takes care of both these cases.
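For readers skimming the thread, a minimal sketch of the kind of guard that covers both branches (hedged; the actual patch in this PR may be shaped differently):

// Sketch only, not the verbatim patch: bail out before configuring any
// launch when there are no distributions, so grid dimensions never become
// zero and the offset computation never divides by grid.y == 0.
if (numDist == 0) {
  return;  // the result tensor is already empty; there is nothing to sample
}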
@ngimel Please review :)
Thanks for the fix. I think that fixing the root cause would be better; otherwise it will probably be triggered in other situations.
It looks like the error is in multinomial_kernel_impl, so the fix is ok. Can you also please add missing
at line 396 of MultinomialKernel.cu?
Gentle Ping :)
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Fixes #43768
TO-DO: