
torch.distributions.Categorical.sample uses an unnecessarily huge amount of memory #34714

@jjabo

Description


🐛 Bug

When sampling from a torch.distributions.Categorical distribution initialized with a few distributions (e.g. 2 distributions, each with 2**16 categories), the sample() function allocates an unnecessarily huge amount of memory.

To Reproduce

import torch 
p = torch.ones([2, 2**16], device='cuda:0')
dist = torch.distributions.Categorical(p)
dist.sample([2**16])

Output:
Traceback (most recent call last):
File ".../lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3326, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "", line 4, in
dist.sample([2**16])
File ".../lib/python3.7/site-packages/torch/distributions/categorical.py", line 106, in sample
probs_2d = probs.reshape(-1, self._num_events)
RuntimeError: CUDA out of memory. Tried to allocate 32.00 GiB (GPU 0; 1.95 GiB total capacity; 1024.00 KiB already allocated; 1.15 GiB free; 2.00 MiB reserved in total by PyTorch)

Expected behavior

Looking into the torch.distributions.Categorical source, sample() is implemented as follows:

    def sample(self, sample_shape=torch.Size()):
        sample_shape = self._extended_shape(sample_shape)
        param_shape = sample_shape + torch.Size((self._num_events,))
        probs = self.probs.expand(param_shape)
        probs_2d = probs.reshape(-1, self._num_events)
        sample_2d = torch.multinomial(probs_2d, 1, True)
        return sample_2d.reshape(sample_shape)

The last three lines are where the memory goes: expand creates a view without allocating, but the subsequent reshape has to materialize the expanded tensor, so probs_2d becomes a potentially huge tensor consisting almost entirely of repeated rows.
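
For scale, a quick back-of-the-envelope calculation (assuming float32 probabilities, 4 bytes per element) reproduces the 32 GiB figure from the traceback:

    # Size of probs_2d for the repro above, assuming float32 probabilities.
    sample_numel = 2**16   # dist.sample([2**16])
    batch_size = 2         # two distributions in the batch
    num_events = 2**16     # categories per distribution
    size_gib = sample_numel * batch_size * num_events * 4 / 2**30
    print(size_gib)        # 32.0 -- matches "Tried to allocate 32.00 GiB"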
The following code snippet solves the issue and uses much less memory, but I'm not sure it covers all cases or whether the order of the dimensions comes out as expected; it is not thoroughly tested.

    def sample(self, sample_shape=torch.Size()):
        if not isinstance(sample_shape, torch.Size):
            sample_shape = torch.Size(sample_shape)
        samples = torch.multinomial(self.probs, sample_shape.numel(), True)
        return samples.reshape([-1] + list(sample_shape))
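
As a minimal sanity check of the dimension-ordering caveat (using small illustrative shapes that are assumptions, not values from the issue), the proposed approach puts the batch dimension first, whereas the current implementation puts the sample dimension first:

    import torch

    # Illustrative shapes (assumed): 3 distributions with 5 categories each, 4 draws.
    probs = torch.ones([3, 5])
    dist = torch.distributions.Categorical(probs)
    sample_shape = torch.Size([4])

    # Current implementation: sample dimension first.
    print(dist.sample(sample_shape).shape)    # torch.Size([4, 3])

    # Proposed approach: one multinomial call over the un-expanded probs.
    samples = torch.multinomial(probs, sample_shape.numel(), replacement=True)
    proposed = samples.reshape([-1] + list(sample_shape))
    print(proposed.shape)                     # torch.Size([3, 4])

The two versions agree on the number of draws per distribution; only the dimension order differs, so a final transpose/permute would be needed to match the documented output shape.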

cc @vincentqb @fritzo @neerajprad @alicanb @vishwakftw @VitalyFedyunin @ngimel
