nn.LSTM gives nondeterministic results with dropout and multiple layers #18110

Closed
freewym opened this issue Mar 17, 2019 · 14 comments
Assignees
Labels
high priority
module: cudnn (Related to torch.backends.cudnn, and CuDNN support)
module: dependency bug (Problem is not caused by us, but caused by an upstream library we use)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@freewym

freewym commented Mar 17, 2019

🐛 Bug

I get non-deterministic results when I run my model with nn.LSTM (dropout > 0) on a GPU, even though I seeded everything and set torch.backends.cudnn.deterministic = True. If I set torch.backends.cudnn.enabled = False instead, the results are deterministic.

To Reproduce

Steps to reproduce the behavior:

  1. torch.backends.cudnn.deterministic = True
    random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    np.random.seed(1)

  2. define a module as:
    nn.LSTM(input_size=256,
    hidden_size=256,
    num_layers=3,
    dropout=0.1,
    bidirectional=True,
    )

  3. train with the defined module multiple times (see the self-contained sketch below)
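
Putting the three steps together, here is a minimal self-contained sketch of the setup (the toy input, squared-error loss, and SGD optimizer are illustrative placeholders, not the actual training code):

    import random
    import numpy as np
    import torch
    import torch.nn as nn

    # Step 1: seed everything and request deterministic cuDNN behavior
    torch.backends.cudnn.deterministic = True
    random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    np.random.seed(1)

    # Step 2: multi-layer bidirectional LSTM with dropout
    lstm = nn.LSTM(input_size=256, hidden_size=256, num_layers=3,
                   dropout=0.1, bidirectional=True).cuda()

    # Step 3: a toy training loop; run the whole script several times and
    # compare the printed losses across runs
    opt = torch.optim.SGD(lstm.parameters(), lr=0.1)
    x = torch.randn(20, 8, 256, device='cuda')  # (seq_len, batch, input_size)
    for step in range(5):
        opt.zero_grad()
        out, _ = lstm(x)
        loss = out.pow(2).mean()
        loss.backward()
        opt.step()
        print('loss at step {}: {}'.format(step, loss.item()))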

Expected behavior

The training should be deterministic across different runs

Environment

PyTorch version: 1.0.1.post2
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Debian GNU/Linux 9.4 (stretch)
GCC version: (Debian 6.3.0-18+deb9u1) 6.3.0 20170516
CMake version: version 3.7.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: Tesla K80
GPU 1: Tesla K80
GPU 2: Tesla K80
GPU 3: Tesla K80

Nvidia driver version: 387.26
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] numpy==1.12.1
[conda] blas 1.0 mkl
[conda] mkl 2019.1 144
[conda] mkl-service 1.1.2 py37h90e4bf4_5
[conda] mkl_fft 1.0.10 py37ha843d7b_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch 1.0.1 py3.7_cuda9.0.176_cudnn7.4.2_2 pytorch
[conda] torchvision 0.2.2 py_3 pytorch

Additional context

@ezyang ezyang added the module: cudnn label Mar 18, 2019
@ezyang
Copy link
Contributor

ezyang commented Mar 18, 2019

cuDNN's seed comes from us, but it's possible that the backward pass is non-deterministic for RNNs. CC @ngimel

You wouldn't happen to have a self-contained script we could run to see the nondeterminism, would you?

@freewym
Author

freewym commented Mar 19, 2019

I just pushed a toy example of how to reproduce the issue. Please check out the following code:

https://github.com/freewym/pytorch_test_code/blob/master/train.py

and type the command:

python3 train.py --gpuid <gpu-id> --dropout <dropout-value> --bidirectional <true|false>

It will print the loss for each batch, like:

loss at batch 0: 4.609390735626221
loss at batch 1: 4.607876777648926
loss at batch 2: 4.6088972091674805
loss at batch 3: 4.605157852172852
loss at batch 4: 4.603259086608887

If the command is:
python3 train.py --gpuid <gpu-id> --dropout 0.5 --bidirectional true

The loss on each batch is not consistent across different runs (the difference is very small in this toy example, but it is significant in my real experiments). However, with --dropout 0, with --bidirectional false, or with line 50 of the code uncommented (i.e., cudnn disabled), the printed losses are exactly the same across different runs.

To sum up, this inconsistency only occurs when 1) cuDNN is enabled; 2) nn.LSTM's argument bidirectional=True; 3) nn.LSTM's argument dropout > 0; and 4) nn.LSTM's argument num_layers > 1.
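
To make that checklist concrete, here is a sketch of a run-to-run comparison (this is not the linked train.py; the small layer sizes, random input, and squared-error loss are assumptions for illustration):

    import random
    import numpy as np
    import torch
    import torch.nn as nn

    def run_once(dropout=0.5, bidirectional=True, num_layers=2, use_cudnn=True):
        # Reseed, rebuild the model, train a few batches, and return the losses.
        torch.backends.cudnn.enabled = use_cudnn
        torch.backends.cudnn.deterministic = True
        random.seed(1)
        np.random.seed(1)
        torch.manual_seed(1)
        torch.cuda.manual_seed_all(1)
        lstm = nn.LSTM(64, 64, num_layers=num_layers, dropout=dropout,
                       bidirectional=bidirectional).cuda()
        opt = torch.optim.SGD(lstm.parameters(), lr=0.1)
        x = torch.randn(10, 4, 64, device='cuda')
        losses = []
        for _ in range(5):
            opt.zero_grad()
            out, _ = lstm(x)
            loss = out.pow(2).mean()
            loss.backward()
            opt.step()
            losses.append(loss.item())
        return losses

    # Expected to print False on affected cuDNN versions (all four conditions met)
    # and True when cuDNN is disabled; flipping any one of the other conditions
    # (dropout=0, bidirectional=False, num_layers=1) should also print True.
    print(run_once() == run_once())
    print(run_once(use_cudnn=False) == run_once(use_cudnn=False))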

@mruberry
Collaborator

mruberry commented Apr 5, 2019

We are investigating this issue.

@ezyang ezyang added the module: dependency bug, triaged, and high priority labels and removed the triage review label Apr 9, 2019
@bentrevett

bentrevett commented Apr 10, 2019

If it's any use, I also have this issue. Here is my environment:

Environment

CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.10.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 9.1.85
GPU models and configuration: GPU 0: GeForce GTX 1080 Ti
Nvidia driver version: 415.27
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] numpy==1.16.2
[conda] blas 1.0 mkl
[conda] mkl 2019.1 144
[conda] mkl_fft 1.0.6 py37hd81dba3_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch 1.0.1 py3.7_cuda9.0.176_cudnn7.4.2_2 pytorch
[conda] pytorch-ignite 0.1.2
[conda] pytorch-pretrained-bert 0.6.1
[conda] torchtext 0.3.1
[conda] torchvision 0.2.2 py_3 pytorch

@skurzhanskyi

The same problem with PyTorch 0.4 🙁

@mruberry
Collaborator

Yes. Unfortunately, I think we should expect this issue with all versions of PyTorch. The issue is in cuDNN, not PyTorch.

@freewym
Author

freewym commented Apr 12, 2019

Is there a way to report to Nvidia about this issue?

@mruberry
Collaborator

Yes, I work at NVIDIA :P

@mhn226

mhn226 commented Apr 26, 2019

I'm having the same problem here. I've been trying for a while to figure out the possible cause. Thanks for this report!

@gchanan
Contributor

gchanan commented Aug 22, 2019

@nairbv to check if turning off a nondeterministic algorithm fixed this.

@freewym
Author

freewym commented Aug 22, 2019 via email

@nairbv
Collaborator

nairbv commented Aug 28, 2019

The semi-related issue was a case where non-deterministic results were incorrect, so setting that operation to always be deterministic fixed the issue.

In this issue we see non-deterministic results even when deterministic=True. I'm not sure there's a way to fix this path in PyTorch without a fix in cuDNN, though maybe we could add a warning.
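
For anyone affected in the meantime, the workaround already used earlier in this thread is to bypass cuDNN so nn.LSTM falls back to the native implementation, which is deterministic here at some speed cost. A minimal sketch (the model sizes are illustrative):

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(256, 256, num_layers=3, dropout=0.1, bidirectional=True).cuda()
    x = torch.randn(20, 8, 256, device='cuda')

    # Option 1 (as in the reproduction script above): disable cuDNN globally.
    # torch.backends.cudnn.enabled = False

    # Option 2: bypass cuDNN only around the RNN calls, assuming the
    # torch.backends.cudnn.flags context manager is available in your version.
    with torch.backends.cudnn.flags(enabled=False):
        out, _ = lstm(x)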

@ngimel
Collaborator

ngimel commented Aug 28, 2019

This should have been fixed in cuDNN 7.6.2. @jjsjann123, can you please check the nvbug?

@jjsjann123
Collaborator

Closed and fixed in cuDNN 7.6.1 @ngimel
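
For readers checking whether their build is affected, the cuDNN version PyTorch was built against can be queried directly (the fixed-version numbers 7.6.1/7.6.2 come from the two comments above):

    import torch

    # Prints an integer such as 7601 for cuDNN 7.6.1; per the comments above,
    # builds at or above that should no longer show this nondeterminism.
    print(torch.backends.cudnn.version())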
