GRU model learns very slowly when using DataParallel with multiple GPUs #33238
Comments
I did the same test with an LSTM and a vanilla RNN instead of a GRU, and the same issue occurred in both cases, so the problem is not specific to GRUs but concerns RNNs in general. Also, when running the test with the vanilla RNN I got a segmentation fault fairly often; it happens at different times on different runs (even though the seed is fixed). Example printouts below. Torch version 1.4.0
I did some debugging, and the issue seems to be related to the caching of the flattened weights in the RNN introduced in PR #27399. I made a test where I essentially rolled back the changes relating to the flattened weights. Below is a diff of the changes I made to fix this issue. I do not really understand how these flattened weights relate to other parts of the code, so I'm not confident enough to turn this directly into a pull request. Hopefully this helps in solving the issue anyway!
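(The diff itself is not reproduced here. For readers unfamiliar with what "flattened weights" refers to: `nn.RNN`/`nn.GRU`/`nn.LSTM` modules expose `flatten_parameters()`, which copies the per-layer weight tensors into one contiguous buffer that cuDNN consumes directly — this is the buffer whose caching the PR above changed. A minimal illustration; the layer sizes are arbitrary and this is not the patch described in the comment:)

```python
import torch
import torch.nn as nn

gru = nn.GRU(input_size=8, hidden_size=16, num_layers=2, batch_first=True)

# Re-fuse the individual weight tensors (weight_ih_l0, weight_hh_l0, ...)
# into one contiguous cuDNN-friendly buffer. On CPU this is effectively a
# no-op; on GPU it avoids a warning and an extra copy on every forward pass.
gru.flatten_parameters()

out, h = gru(torch.randn(4, 5, 8))  # (batch=4, seq_len=5, input_size=8)
print(out.shape, h.shape)
```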
Thanks @ousou, I was running into the same problem and this patch does seem to work so far. This is a surprisingly significant problem for a main release and needs some attention.
This has been fixed in version 1.5.0; closing the issue.
🐛 Bug
When using DataParallel on a GRU model with multiple GPUs, the model seems to learn very slowly during training compared to when running on a single GPU. The issue is present in PyTorch 1.4.0 but not in PyTorch 1.3.0.
To Reproduce
Run the following script on a multi-GPU machine (slightly modified from the PyTorch RNN tutorial):
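(The original script is not reproduced here. The sketch below shows the kind of setup it describes — a GRU classifier wrapped in `nn.DataParallel` with one training step; the class name, layer sizes, and hyperparameters are assumptions, not the issue author's exact code:)

```python
import torch
import torch.nn as nn

class GRUClassifier(nn.Module):
    """Rows of a 28x28 image treated as a length-28 sequence of 28 features."""
    def __init__(self, input_size=28, hidden_size=128, num_layers=2, num_classes=10):
        super().__init__()
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):                  # x: (batch, seq_len, input_size)
        out, _ = self.gru(x)
        return self.fc(out[:, -1, :])      # classify from the last time step

model = GRUClassifier()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)         # replicate across all visible GPUs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# One dummy training step with random data in place of MNIST batches.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
images = torch.randn(4, 28, 28, device=device)
labels = torch.randint(0, 10, (4,), device=device)
loss = criterion(model(images), labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())
```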
Output
Torch version 1.4.0
Training starts
Epoch [1/15], Step [100/600], Loss: 2.2840
Epoch [1/15], Step [200/600], Loss: 2.2771
Epoch [1/15], Step [300/600], Loss: 2.2441
Epoch [1/15], Step [400/600], Loss: 2.2385
Epoch [1/15], Step [500/600], Loss: 2.2250
Epoch [1/15], Step [600/600], Loss: 2.2158
Epoch [1/15], Duration 23.8333 s, Epoch average loss 2.2584
Epoch [2/15], Step [100/600], Loss: 2.1954
Epoch [2/15], Step [200/600], Loss: 2.1744
Epoch [2/15], Step [300/600], Loss: 2.1570
Epoch [2/15], Step [400/600], Loss: 2.1537
Epoch [2/15], Step [500/600], Loss: 2.1515
Epoch [2/15], Step [600/600], Loss: 2.1356
Epoch [2/15], Duration 9.2600 s, Epoch average loss 2.1780
Epoch [3/15], Step [100/600], Loss: 2.1289
Epoch [3/15], Step [200/600], Loss: 2.1365
Epoch [3/15], Step [300/600], Loss: 2.1135
Epoch [3/15], Step [400/600], Loss: 2.1153
Epoch [3/15], Step [500/600], Loss: 2.0827
Epoch [3/15], Step [600/600], Loss: 2.0771
Epoch [3/15], Duration 9.2692 s, Epoch average loss 2.1162
Epoch [4/15], Step [100/600], Loss: 2.1000
Epoch [4/15], Step [200/600], Loss: 2.1413
Epoch [4/15], Step [300/600], Loss: 2.0644
Epoch [4/15], Step [400/600], Loss: 2.0573
Epoch [4/15], Step [500/600], Loss: 2.0968
Epoch [4/15], Step [600/600], Loss: 2.0494
Epoch [4/15], Duration 9.2563 s, Epoch average loss 2.0668
Epoch [5/15], Step [100/600], Loss: 2.0678
Epoch [5/15], Step [200/600], Loss: 2.0399
Epoch [5/15], Step [300/600], Loss: 2.0628
Epoch [5/15], Step [400/600], Loss: 1.9648
Epoch [5/15], Step [500/600], Loss: 1.9510
Epoch [5/15], Step [600/600], Loss: 1.9990
Epoch [5/15], Duration 9.2674 s, Epoch average loss 2.0261
Epoch [5/15], Test Accuracy of the model on the 10000 test images: 37.12 %
Expected behavior
When running the same script using PyTorch 1.3.0 and torchvision 0.4.1 the model learns normally:
Torch version 1.3.0
Training starts
Epoch [1/15], Step [100/600], Loss: 0.8091
Epoch [1/15], Step [200/600], Loss: 0.3172
Epoch [1/15], Step [300/600], Loss: 0.3350
Epoch [1/15], Step [400/600], Loss: 0.2331
Epoch [1/15], Step [500/600], Loss: 0.1132
Epoch [1/15], Step [600/600], Loss: 0.3318
Epoch [1/15], Duration 27.2189 s, Epoch average loss 0.4798
Epoch [2/15], Step [100/600], Loss: 0.1276
Epoch [2/15], Step [200/600], Loss: 0.0696
Epoch [2/15], Step [300/600], Loss: 0.1202
Epoch [2/15], Step [400/600], Loss: 0.0390
Epoch [2/15], Step [500/600], Loss: 0.0975
Epoch [2/15], Step [600/600], Loss: 0.0764
Epoch [2/15], Duration 9.0211 s, Epoch average loss 0.1134
Epoch [3/15], Step [100/600], Loss: 0.0369
Epoch [3/15], Step [200/600], Loss: 0.0832
Epoch [3/15], Step [300/600], Loss: 0.0255
Epoch [3/15], Step [400/600], Loss: 0.1506
Epoch [3/15], Step [500/600], Loss: 0.2035
Epoch [3/15], Step [600/600], Loss: 0.0542
Epoch [3/15], Duration 9.0659 s, Epoch average loss 0.0693
Epoch [4/15], Step [100/600], Loss: 0.0173
Epoch [4/15], Step [200/600], Loss: 0.0687
Epoch [4/15], Step [300/600], Loss: 0.0878
Epoch [4/15], Step [400/600], Loss: 0.0255
Epoch [4/15], Step [500/600], Loss: 0.0944
Epoch [4/15], Step [600/600], Loss: 0.0198
Epoch [4/15], Duration 9.0609 s, Epoch average loss 0.0523
Epoch [5/15], Step [100/600], Loss: 0.0432
Epoch [5/15], Step [200/600], Loss: 0.1001
Epoch [5/15], Step [300/600], Loss: 0.0589
Epoch [5/15], Step [400/600], Loss: 0.1240
Epoch [5/15], Step [500/600], Loss: 0.0341
Epoch [5/15], Step [600/600], Loss: 0.0303
Epoch [5/15], Duration 9.0712 s, Epoch average loss 0.0408
Epoch [5/15], Test Accuracy of the model on the 10000 test images: 98.61 %
Also, when using PyTorch 1.4.0 with just one GPU (without DataParallel), the model learns as it should (results are the same as above). With PyTorch 1.3.0 it doesn't matter whether one GPU (without DataParallel) or multiple GPUs (with DataParallel) are used - the results are the same.
Environment
The tests were run on an AWS g4dn.12xlarge instance.
Additional context
The problem does not seem to be related to torchvision even though the example uses it. We've noticed similar issues in our actual models that use GRUs but do not use torchvision at all.
Possible related issue: #33081