🐛 Bug
When sampling from a torch.distributions.Categorical initialized with a batch of distributions (e.g. 2 distributions, each with 2**16 categories), the sample() function allocates an unnecessarily huge amount of memory.
To Reproduce
import torch
p = torch.ones([2, 2**16], device='cuda:0')
dist = torch.distributions.Categorical(p)
dist.sample([2**16])
Output:
Traceback (most recent call last):
File ".../lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3326, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "", line 4, in
dist.sample([2**16])
File ".../lib/python3.7/site-packages/torch/distributions/categorical.py", line 106, in sample
probs_2d = probs.reshape(-1, self._num_events)
RuntimeError: CUDA out of memory. Tried to allocate 32.00 GiB (GPU 0; 1.95 GiB total capacity; 1024.00 KiB already allocated; 1.15 GiB free; 2.00 MiB reserved in total by PyTorch)
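For reference, the 32 GiB figure is exactly the flattened probability tensor that sample() tries to materialize for this call: 2**16 samples × 2 batch entries × 2**16 categories × 4 bytes (float32) = 32 GiB.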
Expected behavior
Looking into the torch.distributions.Categorical source shows:
def sample(self, sample_shape=torch.Size()):
    sample_shape = self._extended_shape(sample_shape)
    param_shape = sample_shape + torch.Size((self._num_events,))
    probs = self.probs.expand(param_shape)
    probs_2d = probs.reshape(-1, self._num_events)
    sample_2d = torch.multinomial(probs_2d, 1, True)
    return sample_2d.reshape(sample_shape)
The reshape is what allocates the new memory: expand is only a zero-copy view, but reshaping the expanded tensor forces a contiguous copy, so probs_2d becomes a potentially huge tensor of repeated rows.
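A small CPU-only sketch (with shrunken shapes, not part of the original report) illustrates that it is the reshape, not the expand, that materializes the data:

import torch

p = torch.ones([2, 2**16])
probs = p.expand([8, 2, 2**16])             # zero-copy view: the new leading dim has stride 0
print(probs.data_ptr() == p.data_ptr())     # True - no extra memory allocated yet
probs_2d = probs.reshape(-1, 2**16)         # the stride-0 dim cannot be viewed, so reshape copies
print(probs_2d.data_ptr() == p.data_ptr())  # False - 16 * 2**16 floats are materialized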
The following snippet avoids this and uses much less memory, but I'm not sure it covers all cases or whether the order of the output dimensions is as expected; it is not thoroughly tested.
def sample(self, sample_shape=torch.Size()):
    if not isinstance(sample_shape, torch.Size):
        sample_shape = torch.Size(sample_shape)
    samples = torch.multinomial(self.probs, sample_shape.numel(), True)
    return samples.reshape([-1] + list(sample_shape))
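As a rough check of the idea (not the patch itself), calling torch.multinomial directly on the batched probs keeps the allocation down to the output index tensor; the peak-memory figure below is indicative only:

import torch

p = torch.ones([2, 2**16], device='cuda:0')
dist = torch.distributions.Categorical(p)

# multinomial handles the batch dimension itself, so only the (2, 2**16) index
# tensor is allocated instead of a ~32 GiB probs_2d copy
samples = torch.multinomial(dist.probs, 2**16, replacement=True)
print(samples.shape)                              # torch.Size([2, 65536])
print(torch.cuda.max_memory_allocated() / 2**20)  # a few MiB rather than tens of GiB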
cc @vincentqb @fritzo @neerajprad @alicanb @vishwakftw @VitalyFedyunin @ngimel