torch.multinomial chooses elements with zero weight #13867

Closed
jcjohnson opened this issue Nov 12, 2018 · 16 comments

@jcjohnson (Contributor) commented Nov 12, 2018

🐛 Bug

torch.multinomial occasionally samples elements with zero weight. This should never happen.

To Reproduce

I've been unable to reproduce this issue with randomly generated weights, so I've included a particular value of weights from my application that triggers this behavior:

 wget https://cs.stanford.edu/people/jcjohns/weights.pt

These weights are all nonnegative (but contain a lot of zeros), have a nonzero sum, and contain no NaNs or Infs.
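For completeness, a quick check of those properties (illustrative only, not part of the original script) before the repro below:

import torch

# Illustrative check: the weights should be nonnegative, finite, and sum to a positive value.
weights = torch.load('weights.pt')
assert (weights >= 0).all()
assert torch.isfinite(weights).all()  # no NaNs or Infs
assert weights.sum().item() > 0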

import torch

torch.manual_seed(1)
weights = torch.load('weights.pt')
N, S = weights.shape[0], 4096
num_trials = 100
for trial in range(1, num_trials + 1):
  print('Starting trial %d / %d' % (trial, num_trials))
  weights[weights < 0] = 0.0
  samples = weights.multinomial(S, replacement=True)
  sampled_weights = weights[samples]
  assert sampled_weights.min() > 0

I fail the assertion on trial 6.

Environment

PyTorch version: 1.0.0.dev20181112
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 16.04.4 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609
CMake version: version 3.5.1

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100

Nvidia driver version: 396.51
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] Could not collect
[conda] pytorch 0.4.1 py37_py36_py35_py27__9.0.176_7.1.2_2 pytorch
[conda] pytorch-nightly 1.0.0.dev20181112 py3.7_cuda9.0.176_cudnn7.1.2_0 pytorch
[conda] torchvision 0.2.1
[conda] torchvision 0.2.1 py37_1 pytorch

@zou3519 (Contributor) commented Nov 12, 2018

@jcjohnson can you confirm that you are running the latest PyTorch when running this script (print(torch.__version__))? I think we fixed an identical bug a while ago, but it looks like that fix wasn't enough.

@jcjohnson (Contributor, Author) commented:

@zou3519 I just reinstalled from the nightly build, version 1.0.0.dev20181112. Can you point me to the earlier bugfix?

@zou3519 (Contributor) commented Nov 12, 2018

My bad, it looks like we fixed this for CUDA but we did not test on CPU: #4858. We'll look into it and get it fixed, thank you for the report :)

@jcjohnson (Contributor, Author) commented:

That's weird -- I'm seeing this issue only on CUDA, and it works properly when I cast weights to CPU.
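For reference, the CPU cross-check looks roughly like this (an illustrative sketch reusing weights and S from the repro script, not the exact code that was run):

weights_cpu = weights.cpu()
samples = weights_cpu.multinomial(S, replacement=True)
assert weights_cpu[samples].min() > 0  # passes on CPU, per the observation above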

@zou3519 (Contributor) commented Nov 12, 2018

Got it, I didn't realize your weights were on CUDA. I can reproduce the assertion failure using your weights, so something is indeed wrong with the multinomial implementation.

@zou3519 (Contributor) commented Nov 12, 2018

I'm wondering if floating-point error could be to blame. One interesting thing to note is that weights < 0 returns False for element 0:

(Pdb) weights
tensor([1.6399e-05, 1.1493e-05, 1.0797e-05,  ..., 0.0000e+00, 0.0000e+00,
        0.0000e+00], device='cuda:0')
(Pdb) weights < 0
tensor([0, 0, 0,  ..., 0, 0, 0], device='cuda:0', dtype=torch.uint8)
(Pdb) weights[weights < 0] = 0
(Pdb) weights[0]
tensor(1.6399e-05, device='cuda:0')

@jcjohnson (Contributor, Author) commented:

Isn't that correct? 1.6399e-05 is small but positive.

However, many of the weights are quite small (and will become even smaller if multinomial internally renormalizes them to sum to one), so I wouldn't be surprised if some floating-point error were to blame.
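To illustrate the scale involved (an illustrative example, not taken from the report): float32 can only resolve steps of roughly 6e-8 just below 1.0, so a renormalized weight much smaller than that contributes nothing once the running cumulative sum is close to 1.

import torch

eps = torch.finfo(torch.float32).eps               # ~1.19e-07
c = torch.tensor(0.9999999, dtype=torch.float32)   # a cumulative sum near 1
w = torch.tensor(1e-8, dtype=torch.float32)        # a tiny renormalized weight
print((c + w) == c)                                # True: w is absorbed entirely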

@zou3519 (Contributor) commented Nov 12, 2018

Of course -- my apologies, I was reading that too quickly.

@jcjohnson (Contributor, Author) commented:

No worries, I'm grateful for the fast response =)

@syed-ahmed (Contributor) commented Nov 13, 2018

@jcjohnson @zou3519 I think the problem is more with how we are seeding the Mersenne Twister engine. I recently learned that the Mersenne Twister's 19937-bit state is very prone to getting into a bad state when the engine is seeded with a number that has many zero bits ("all zeros causes it to not work at all, whereas lots of zero bits are merely bad" - http://www.pcg-random.org/posts/cpp-seeding-surprises.html). I ran your script with seed = 10, and it breaks the assertion at trial 17.

Your script passes with my current PR #13070 (the PR is almost done and is waiting on some builds). I have changed the CUDA generator engine for multinomial to the Philox engine, and I suppose the script passes because the Philox engine doesn't carry as much state as a Mersenne Twister engine and we are seeding it properly with a 64-bit number.

@D-X-Y commented Dec 11, 2018

In https://pytorch.org/docs/stable/torch.html?highlight=multinomial#torch.multinomial, the example is:

>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights
>>> torch.multinomial(weights, 4)
tensor([ 1,  2,  0,  0])

Why does torch.multinomial output [1, 2, 0, 0]? Since replacement=False, it should not generate the same index twice.
Is this a bug, or can anyone help explain it?
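One hedged reading of that example (not an official answer): only two of the four categories have nonzero weight, so with replacement=False a request for four samples ends up returning zero-weight indices on this version, whereas requesting at most two is well defined:

import torch

weights = torch.tensor([0, 10, 3, 0], dtype=torch.float)
print(torch.multinomial(weights, 2))  # expected: indices 1 and 2 in some order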

@jcjohnson (Contributor, Author) commented:

Is there any update on this? It has been two months.

@syed-ahmed (Contributor) commented:

Hi @jcjohnson. Apologies for the super long delay! The PR I referred to above became huge for review, so I'm currently breaking it up into two parts. I promise to push both parts by the end of this week.

@jcjohnson (Contributor, Author) commented:

Thanks! Your PR looks pretty nontrivial indeed, so I'm not surprised it has taken a while to get sorted out. I'm looking forward to it!

@t-vi (Collaborator) commented Jan 16, 2019

So, in terms of a minimal fix:
The cumsum result (which I'm unfortunately not able to see by calling cumsum manually, but obtained with cuda-gdb) seems to include 0.99997884, 0.999978781 in that order at the critical positions, i.e. it is not monotonically non-decreasing.
Our logic to avoid zero-probability items essentially checks for cumdist[n-1] == cumdist[n], but that doesn't work here.

I think the main options for a minimal fix ("1.0.1") are

  • write a cumsum replacement that returns tensors with non-decreasing entries for non-negative inputs,
  • pass the non-cumulated distribution to the sampling/bisection and check that for 0 in the above line.

I would expect the second to be the least risky fix because it seems to add the least logic.
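For illustration, a CPU-side Python sketch of the second option (a sketch of the idea only; the actual fix belongs in the CUDA binary-search kernel, and torch.searchsorted stands in here for that bisection):

import torch

def multinomial_sketch(weights, num_samples):
    # weights: 1-D nonnegative tensor with a strictly positive sum
    cumdist = weights.cumsum(0)
    total = cumdist[-1]
    u = torch.rand(num_samples) * total
    # Bisection against the cumulative distribution, as before.
    idx = torch.searchsorted(cumdist, u, right=True).clamp(max=weights.numel() - 1)
    # Post-processing per the second option: consult the non-cumulated
    # distribution and step back past any zero-weight category the
    # bisection may have landed on.
    for i in range(num_samples):
        j = int(idx[i])
        while j > 0 and weights[j].item() == 0:
            j -= 1
        idx[i] = j
    return idx

This mirrors the eventual change described in the commit message below: the binary-search post-processing consults the (non-cumulated) distribution instead of relying on the cumsum alone.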

@t-vi (Collaborator) commented Jan 16, 2019

I seem to have a simpler repro:

# test corner case from Issue #13867
torch.cuda.manual_seed(33)
probs = torch.randn(1_000_000, device='cuda').clamp(min=0) * 3e-5
samples = probs.multinomial(1_000_000, replacement=True)
assert probs[samples].min().item() > 0

I'll have the PR in a few moments.

t-vi added a commit to t-vi/pytorch that referenced this issue Jan 16, 2019
The cumsum over the probabilities may fail to be monotonically
non-decreasing. Thus it is hard to detect zero-probability
classes using just the cumsum.
This changes the binary search post-processing to use the
(non-cumulated) distribution instead.

Thank you, @jcjohnson, for the bug report with a
reproducing case.

Fixes: pytorch#13867
soumith pushed a commit that referenced this issue Feb 4, 2019
Summary:
The cumsum over the probabilities may fail to be monotonically
non-decreasing. Thus it is hard to detect zero-probability
classes using just the cumsum.
This changes the binary search post-processing to use the
(non-cumulated) distribution instead.

Thank you, jcjohnson, for the bug report with a
reproducing case.

Fixes: #13867
Pull Request resolved: #16075

Differential Revision: D13695565

Pulled By: soumith

fbshipit-source-id: 02c4d6f868f0050c1ae7d333f4317c5610e49cd9