[fix] torch.multinomial : fix for 0 size dim #43775

Closed
Changes from 3 commits
1 change: 1 addition & 0 deletions aten/src/ATen/native/Distributions.cpp
@@ -454,6 +454,7 @@ Tensor& multinomial_out(Tensor& result, const Tensor& self, int64_t n_sample, bo
   if (self.dim() > 1) {
     int64_t n_dist = self.size(-2);
     result.resize_({n_dist, n_sample});
+    if (n_dist == 0) { return result; };
   } else {
     result.resize_({n_sample});
   }
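
For context, a minimal sketch of the behavior this early return produces on the CPU path (the shapes are illustrative; the key point is the zero-size distribution dimension):

import torch

# Illustrative sketch: a 2-D probability tensor with zero distributions
# (shape (0, 3) is arbitrary; any empty first dimension behaves the same).
probs = torch.ones(0, 3)

# With the early return above, multinomial just returns the resized,
# empty int64 result of shape (num_dist, num_samples) == (0, 5).
out = torch.multinomial(probs, 5, replacement=True)
print(out.shape)   # expected: torch.Size([0, 5])
print(out.dtype)   # torch.int64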
4 changes: 2 additions & 2 deletions test/test_torch.py
@@ -17797,8 +17797,8 @@ def test(probs, replacement):
         test(z, True)

     def test_multinomial_empty(self, device):
-        probs = torch.ones(0, 3)
-        num_samples = 1
+        probs = torch.ones(0, 128, device=device)
+        num_samples = 64
Collaborator Author:
Curiously, the test passed with

probs = torch.ones(0, 3, device=device)
num_samples = 1

Is it something to do with replacement, since it fails only when num_samples > 1?

Collaborator Author:

Possible. I haven't actually stepped through the exact kernel code where it is failing.

Collaborator:

The test passes with replacement=False, so you have to test the different replacement modes separately.

Collaborator Author:

>>> torch.multinomial(x, 1, True)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: CUDA error: invalid configuration argument

The above error comes from the n_sample == 1 path below: with numDist == 0 the launch grid's x-dimension ends up as 0, which CUDA rejects as an invalid configuration.

if (n_sample == 1 && maxShared >= requiredShared) {
  // Optimized allocation-free implementation
  // To exploit greater parallelism for the sampling, generate the
  // Uniform random samples in a separate kernel launch, into
  // temporarily allocated memory. The device RNG is thread-limited
  Tensor sampled = native::empty_cuda({numDist, n_sample}, self_v.options());
  at::native::uniform_(sampled, 0.0, 1.0, generator);
  dim3 block(numCategories < maxThreads ? numCategories : maxThreads);
  dim3 grid(numDist < numSM * 4 ? numDist : numSM * 4);
  sampleMultinomialOnce<scalar_t, accscalar_t>
      <<<grid, block,
         requiredShared,
         at::cuda::getCurrentCUDAStream()>>>(
      result.data_ptr<int64_t>(),
      numDist,
      numCategories,
      sampled.data_ptr<scalar_t>(),
      self_v.data_ptr<scalar_t>(),
      self_v.stride(0),
      self_v.stride(1)
  );

For num_samples > 1,

>>> torch.multinomial(x, 2, True)
Floating point exception (core dumped)

if (with_replacement) {
  // Binary search is warp divergent (so effectively we're running
  // with just a single thread), but for better utilization,
  // we need each block to have at least 4 warps.
  dim3 block(128);
  // Each block will generate a sample from one
  // distribution concurrently.
  int grid_y=std::min<int>(numDist, at::cuda::getCurrentDeviceProperties()->maxGridSize[1]);
  dim3 grid((n_sample-1)/block.x+1, grid_y);
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    // each thread generates a single sample for (numdist/numblocks.y) distributions, however, since we have to use
    // curand_uniform4 (See Note [Register spilling in curand call for CUDA < 10]),
    // offset is 4 times that.
    auto offset = ((numDist-1)/grid.y+1)*4;
    rng_engine_inputs = gen->philox_engine_inputs(offset);
  }

Here grid_y is 0 (since numDist is 0), so the line below raises a floating point exception due to integer division by zero:

auto offset = ((numDist-1)/grid.y+1)*4;

The fix takes care of both these cases.
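
For reference, a rough sketch of what the two previously failing calls should look like once the early return is in place, assuming x is an empty 2-D probability tensor on a CUDA device (matching the updated test):

import torch

# Assumption: x matches the repro in this thread, i.e. an empty 2-D
# probability tensor on CUDA, as in the updated test.
x = torch.ones(0, 128, device='cuda')

# Previously: "invalid configuration argument" on the n_sample == 1 path
# and a floating point exception on the n_sample > 1 path. With the early
# return, both calls should simply produce empty int64 outputs.
print(torch.multinomial(x, 1, True).shape)   # expected: torch.Size([0, 1])
print(torch.multinomial(x, 2, True).shape)   # expected: torch.Size([0, 2])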

         expected = torch.empty(0, num_samples, dtype=torch.int64)
         for replacement in (True, False):
             out = torch.multinomial(probs, num_samples=num_samples, replacement=replacement)
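
The tail of the test is collapsed in the diff above; as a self-contained sketch of how the full check might read (the final assertion and the device handling are assumptions for illustration, not necessarily the exact lines in the real test):

import torch

# Sketch of the updated test, assuming it compares each output against an
# empty expected tensor of shape (0, num_samples).
def test_multinomial_empty(device='cpu'):
    probs = torch.ones(0, 128, device=device)
    num_samples = 64
    expected = torch.empty(0, num_samples, dtype=torch.int64, device=device)
    for replacement in (True, False):
        out = torch.multinomial(probs, num_samples=num_samples, replacement=replacement)
        assert out.shape == expected.shape and out.dtype == expected.dtype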