
[fix] torch.multinomial : fix for 0 size dim #43775

Closed

Conversation

kshitij12345
Collaborator

@kshitij12345 kshitij12345 commented Aug 28, 2020

Fixes #43768

TO-DO:

  • Add test

@dr-ci

dr-ci bot commented Aug 28, 2020

💊 CI failures summary and remediations

As of commit 6e1ed7d (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



probs = torch.ones(0, 3)
num_samples = 1
probs = torch.ones(0, 128, device=device)
num_samples = 64
Collaborator Author:

Curiously, the test passed with

probs = torch.ones(0, 3, device=device)
num_samples = 1

Is it something to do with replacement, as it fails only when num_samples > 1?

Collaborator Author:

Possible. I haven't actually stepped through the exact kernel code where it is failing.

Collaborator:

The test passes with replacement=False, so you have to test the replacement modes separately.

Collaborator Author:

>>> torch.multinomial(x, 1, True)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: CUDA error: invalid configuration argument

The above error comes from

if (n_sample == 1 && maxShared >= requiredShared) {
  // Optimized allocation-free implementation
  // To exploit greater parallelism for the sampling, generate the
  // Uniform random samples in a separate kernel launch, into
  // temporarily allocated memory. The device RNG is thread-limited
  Tensor sampled = native::empty_cuda({numDist, n_sample}, self_v.options());
  at::native::uniform_(sampled, 0.0, 1.0, generator);
  dim3 block(numCategories < maxThreads ? numCategories : maxThreads);
  dim3 grid(numDist < numSM * 4 ? numDist : numSM * 4);
  sampleMultinomialOnce<scalar_t, accscalar_t>
      <<<grid, block, requiredShared, at::cuda::getCurrentCUDAStream()>>>(
          result.data_ptr<int64_t>(),
          numDist,
          numCategories,
          sampled.data_ptr<scalar_t>(),
          self_v.data_ptr<scalar_t>(),
          self_v.stride(0),
          self_v.stride(1));
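Worked through in plain Python (a sketch: `launch_config_once` is a hypothetical helper mirroring the two `dim3` lines above, with illustrative values for the device properties `max_threads` and `num_sm`):

```python
def launch_config_once(num_dist, num_categories, max_threads=1024, num_sm=80):
    """Block/grid x-dimensions the n_sample == 1 path would launch with.

    max_threads and num_sm stand in for queried device properties; they
    are illustrative, not read from real hardware.
    """
    block_x = num_categories if num_categories < max_threads else max_threads
    grid_x = num_dist if num_dist < num_sm * 4 else num_sm * 4
    return block_x, grid_x

# For an empty (0, 128) probability tensor, numDist == 0, so the grid's
# x-dimension is 0. CUDA rejects any zero launch dimension with
# "invalid configuration argument".
print(launch_config_once(0, 128))  # -> (128, 0)
```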

For num_samples > 1,

>>> torch.multinomial(x, 2, True)
Floating point exception (core dumped)

if (with_replacement) {
  // Binary search is warp divergent (so effectively we're running
  // with just a single thread), but for better utilization,
  // we need each block to have at least 4 warps.
  dim3 block(128);
  // Each block will generate a sample from one
  // distribution concurrently.
  int grid_y = std::min<int>(numDist, at::cuda::getCurrentDeviceProperties()->maxGridSize[1]);
  dim3 grid((n_sample - 1) / block.x + 1, grid_y);
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    // Each thread generates a single sample for (numDist / numBlocks.y)
    // distributions; however, since we have to use curand_uniform4
    // (see Note [Register spilling in curand call for CUDA < 10]),
    // the offset is 4 times that.
    auto offset = ((numDist - 1) / grid.y + 1) * 4;
    rng_engine_inputs = gen->philox_engine_inputs(offset);
  }

Here grid_y is 0 (since numDist is 0), and thus in the line below we get a floating-point exception due to division by zero:

auto offset = ((numDist-1)/grid.y+1)*4;

The fix takes care of both these cases.
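The replacement-path crash can likewise be sketched in plain Python (`replacement_launch` is a hypothetical helper mirroring the grid/offset computation quoted above; the `num_dist == 0` guard stands in for the early-return fix):

```python
def replacement_launch(num_dist, n_sample, block_x=128, max_grid_y=65535):
    """Sketch of the grid/offset computation from the with_replacement path.

    Returns (grid_x, grid_y, philox_offset), or None when there are no
    distributions to sample from. max_grid_y stands in for the queried
    maxGridSize[1] device property.
    """
    # The guard standing in for the fix: bail out early for empty input
    # instead of computing a zero-sized grid.
    if num_dist == 0:
        return None  # nothing to sample; the result tensor is empty

    grid_y = min(num_dist, max_grid_y)
    grid_x = (n_sample - 1) // block_x + 1
    # Without the guard, num_dist == 0 makes grid_y == 0 and the next
    # line divides by zero (the SIGFPE / "core dumped" seen above).
    offset = ((num_dist - 1) // grid_y + 1) * 4
    return grid_x, grid_y, offset

print(replacement_launch(0, 2))  # -> None (guarded empty input)
print(replacement_launch(5, 2))  # -> (1, 5, 4)
```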

@kshitij12345
Collaborator Author

@ngimel Please review :)

@ngimel
Collaborator

ngimel commented Aug 28, 2020

Thanks for the fix. I think fixing the root cause would be better; otherwise the bug is probably triggered in other situations as well.

@malfet malfet added module: cuda Related to torch.cuda, and CUDA support in general triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Aug 28, 2020
@malfet malfet requested a review from ngimel August 28, 2020 20:51
@ngimel
Collaborator

ngimel commented Aug 28, 2020

It looks like the error is in multinomial_kernel_impl, so the fix is OK. Can you also please add the missing

AT_CUDA_CHECK(cudaGetLastError());

at line 396 of MultinomialKernel.cu?
And test replacement=True and replacement=False.
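The requested coverage might be parametrized along these lines (a hypothetical sketch, not the PR's actual test; the shapes are taken from the snippet reviewed above):

```python
def empty_multinomial_cases():
    """Cases covering both replacement modes for empty probability tensors."""
    shapes = [((0, 3), 1), ((0, 128), 64)]  # (probs shape, num_samples)
    return [(shape, n_sample, replacement)
            for shape, n_sample in shapes
            for replacement in (False, True)]

# Each case would then be exercised roughly as (requires PyTorch + CUDA):
#   probs = torch.ones(*shape, device="cuda")
#   out = torch.multinomial(probs, n_sample, replacement=replacement)
#   assert out.shape == (0, n_sample)
print(len(empty_multinomial_cases()))  # -> 4
```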

@codecov

codecov bot commented Aug 29, 2020

Codecov Report

Merging #43775 into master will not change coverage.
The diff coverage is n/a.


@@           Coverage Diff           @@
##           master   #43775   +/-   ##
=======================================
  Coverage   69.34%   69.34%           
=======================================
  Files         378      378           
  Lines       46698    46698           
=======================================
  Hits        32381    32381           
  Misses      14317    14317           


@kshitij12345
Collaborator Author

Gentle ping :)

@facebook-github-bot (Contributor) left a comment:

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@ngimel merged this pull request in 0394c5a.

@kshitij12345 kshitij12345 deleted the fix/multinomial/0-dist branch September 11, 2020 09:16
Labels
Merged module: cuda Related to torch.cuda, and CUDA support in general open source triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Development

Successfully merging this pull request may close these issues.

Multinomial distribution of empty tensor crashes when sampling with replacement on CUDA
7 participants