Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GRU model learns very slowly when using DataParallel with multiple GPUs #33238

Closed
ousou opened this issue Feb 12, 2020 · 5 comments
Closed

GRU model learns very slowly when using DataParallel with multiple GPUs #33238

ousou opened this issue Feb 12, 2020 · 5 comments
Labels
module: data parallel triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@ousou
Copy link

ousou commented Feb 12, 2020

🐛 Bug

When using DataParallel on a GRU model with multiple GPUs the model seems to learn very slowly during training, compared to when running on a single GPU. The issue is present in PyTorch 1.4.0 but not PyTorch 1.3.0.

To Reproduce

Run the following script on a multi-GPU machine (slightly modified from the PyTorch RNN tutorial):

import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import time

# Report the installed PyTorch version so runs on different versions can be told apart.
print('Torch version', torch.__version__)

# Device configuration
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
# Each image is reshaped to (sequence_length, input_size) before it is fed to the
# model, i.e. it is read as a sequence of rows.
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 15
learning_rate = 0.001
# Fixed seed, applied to both the default and the CUDA random number generators.
seed = 123
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# MNIST dataset
# Only the training split requests a download.
train_dataset = torchvision.datasets.MNIST(root='data/',
                                           train=True, 
                                           transform=transforms.ToTensor(),
                                           download=True)

# No download flag here; presumably relies on the files fetched for the
# training split above, since both use the same root directory.
test_dataset = torchvision.datasets.MNIST(root='data/',
                                          train=False, 
                                          transform=transforms.ToTensor())

# Data loader
# Training batches are shuffled; test batches keep the dataset order.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size, 
                                          shuffle=False)

# Recurrent neural network (many-to-one)
class RNN(nn.Module):
    """Many-to-one GRU classifier: a stacked GRU followed by a linear layer.

    Args:
        input_size: number of features per time step.
        hidden_size: size of the GRU hidden state.
        num_layers: number of stacked GRU layers.
        num_classes: number of output logits.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # batch_first=True: inputs and outputs are (batch, seq_length, features).
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes).

        Args:
            x: input of shape (batch, seq_length, input_size).
        """
        self.gru.flatten_parameters()

        # No initial hidden state is passed, so nn.GRU uses its default of
        # zeros. The previous code built explicit zero tensors (h0, c0) on the
        # module-level `device` but never used them; a GRU has no cell state,
        # and under DataParallel that device is wrong for every replica but
        # the first.
        out, _ = self.gru(x)  # out: tensor of shape (batch_size, seq_length, hidden_size)

        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])
        return out

model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)

# Wrap the model so each batch is split across all visible GPUs; results are
# gathered on device 0.
model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())), output_device=0)


# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


def evaluate(model, loader):
    """Return the top-1 accuracy of `model` over `loader`, in percent.

    Switches the model to eval mode and runs without gradient tracking. The
    caller is responsible for switching back to train mode afterwards.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.reshape(-1, sequence_length, input_size).to(device)
            labels = labels.to(device)
            outputs = model(images)
            # Index of the largest logit is the predicted class.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


print('Training starts')
# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    model.train()
    epoch_start_time = time.time()
    total_loss = 0
    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch+1, num_epochs, i+1, total_step, loss.item()))
        total_loss += loss.item()
    epoch_duration = time.time() - epoch_start_time
    # Average over the number of batches. Dividing by the last loop index `i`
    # (one less than the batch count) slightly overstated the average.
    epoch_loss = total_loss / total_step
    print('Epoch [{}/{}], Duration {:.4f} s, Epoch average loss {:.4f}'
          .format(epoch+1, num_epochs, epoch_duration, epoch_loss))

    # Periodic evaluation every fifth epoch; model.train() at the top of the
    # next epoch restores training mode.
    if (epoch+1) % 5 == 0:
        accuracy = evaluate(model, test_loader)
        print('Epoch [{}/{}], Test Accuracy of the model on the 10000 test images: {} %'.format(epoch+1,num_epochs, accuracy))

# Test the model
accuracy = evaluate(model, test_loader)
print('Test Accuracy of the model on the 10000 test images: {} %'.format(accuracy))

Output

Torch version 1.4.0
Training starts
Epoch [1/15], Step [100/600], Loss: 2.2840
Epoch [1/15], Step [200/600], Loss: 2.2771
Epoch [1/15], Step [300/600], Loss: 2.2441
Epoch [1/15], Step [400/600], Loss: 2.2385
Epoch [1/15], Step [500/600], Loss: 2.2250
Epoch [1/15], Step [600/600], Loss: 2.2158
Epoch [1/15], Duration 23.8333 s, Epoch average loss 2.2584
Epoch [2/15], Step [100/600], Loss: 2.1954
Epoch [2/15], Step [200/600], Loss: 2.1744
Epoch [2/15], Step [300/600], Loss: 2.1570
Epoch [2/15], Step [400/600], Loss: 2.1537
Epoch [2/15], Step [500/600], Loss: 2.1515
Epoch [2/15], Step [600/600], Loss: 2.1356
Epoch [2/15], Duration 9.2600 s, Epoch average loss 2.1780
Epoch [3/15], Step [100/600], Loss: 2.1289
Epoch [3/15], Step [200/600], Loss: 2.1365
Epoch [3/15], Step [300/600], Loss: 2.1135
Epoch [3/15], Step [400/600], Loss: 2.1153
Epoch [3/15], Step [500/600], Loss: 2.0827
Epoch [3/15], Step [600/600], Loss: 2.0771
Epoch [3/15], Duration 9.2692 s, Epoch average loss 2.1162
Epoch [4/15], Step [100/600], Loss: 2.1000
Epoch [4/15], Step [200/600], Loss: 2.1413
Epoch [4/15], Step [300/600], Loss: 2.0644
Epoch [4/15], Step [400/600], Loss: 2.0573
Epoch [4/15], Step [500/600], Loss: 2.0968
Epoch [4/15], Step [600/600], Loss: 2.0494
Epoch [4/15], Duration 9.2563 s, Epoch average loss 2.0668
Epoch [5/15], Step [100/600], Loss: 2.0678
Epoch [5/15], Step [200/600], Loss: 2.0399
Epoch [5/15], Step [300/600], Loss: 2.0628
Epoch [5/15], Step [400/600], Loss: 1.9648
Epoch [5/15], Step [500/600], Loss: 1.9510
Epoch [5/15], Step [600/600], Loss: 1.9990
Epoch [5/15], Duration 9.2674 s, Epoch average loss 2.0261
Epoch [5/15], Test Accuracy of the model on the 10000 test images: 37.12 %

Expected behavior

When running the same script using PyTorch 1.3.0 and torchvision 0.4.1 the model learns normally:

Torch version 1.3.0
Training starts
Epoch [1/15], Step [100/600], Loss: 0.8091
Epoch [1/15], Step [200/600], Loss: 0.3172
Epoch [1/15], Step [300/600], Loss: 0.3350
Epoch [1/15], Step [400/600], Loss: 0.2331
Epoch [1/15], Step [500/600], Loss: 0.1132
Epoch [1/15], Step [600/600], Loss: 0.3318
Epoch [1/15], Duration 27.2189 s, Epoch average loss 0.4798
Epoch [2/15], Step [100/600], Loss: 0.1276
Epoch [2/15], Step [200/600], Loss: 0.0696
Epoch [2/15], Step [300/600], Loss: 0.1202
Epoch [2/15], Step [400/600], Loss: 0.0390
Epoch [2/15], Step [500/600], Loss: 0.0975
Epoch [2/15], Step [600/600], Loss: 0.0764
Epoch [2/15], Duration 9.0211 s, Epoch average loss 0.1134
Epoch [3/15], Step [100/600], Loss: 0.0369
Epoch [3/15], Step [200/600], Loss: 0.0832
Epoch [3/15], Step [300/600], Loss: 0.0255
Epoch [3/15], Step [400/600], Loss: 0.1506
Epoch [3/15], Step [500/600], Loss: 0.2035
Epoch [3/15], Step [600/600], Loss: 0.0542
Epoch [3/15], Duration 9.0659 s, Epoch average loss 0.0693
Epoch [4/15], Step [100/600], Loss: 0.0173
Epoch [4/15], Step [200/600], Loss: 0.0687
Epoch [4/15], Step [300/600], Loss: 0.0878
Epoch [4/15], Step [400/600], Loss: 0.0255
Epoch [4/15], Step [500/600], Loss: 0.0944
Epoch [4/15], Step [600/600], Loss: 0.0198
Epoch [4/15], Duration 9.0609 s, Epoch average loss 0.0523
Epoch [5/15], Step [100/600], Loss: 0.0432
Epoch [5/15], Step [200/600], Loss: 0.1001
Epoch [5/15], Step [300/600], Loss: 0.0589
Epoch [5/15], Step [400/600], Loss: 0.1240
Epoch [5/15], Step [500/600], Loss: 0.0341
Epoch [5/15], Step [600/600], Loss: 0.0303
Epoch [5/15], Duration 9.0712 s, Epoch average loss 0.0408
Epoch [5/15], Test Accuracy of the model on the 10000 test images: 98.61 %

Also when using PyTorch 1.4.0 with just one GPU (without DataParallel) the model learns as it should (results are the same as above). With PyTorch 1.3.0 it doesn't matter whether one GPU (without DataParallel) or multiple GPUs (with DataParallel) is used - the results are the same.

Environment

The tests were run on an AWS g4dn.12xlarge instance.

PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609
CMake version: version 3.13.3

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration: 
GPU 0: Tesla T4
GPU 1: Tesla T4
GPU 2: Tesla T4
GPU 3: Tesla T4

Nvidia driver version: 418.87.00
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] numpy==1.17.2
[pip3] torch==1.4.0
[pip3] torchvision==0.5.0
[conda] Could not collect

Additional context

The problem does not seem to be related to torchvision even though the example uses it. We've noticed similar issues in our actual models that use GRUs but do not use torchvision at all.

Possible related issue: #33081

@zhangguanheng66 zhangguanheng66 added module: data parallel triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module labels Feb 12, 2020
@ousou
Copy link
Author

ousou commented Feb 13, 2020

I ran the same test with an LSTM and with a vanilla RNN instead of a GRU, and the same issue occurs in both cases. Thus the issue is not specific to GRUs but concerns RNNs in general.

Also, when running the test with a vanilla RNN I got a segmentation fault pretty often. The segmentation fault happens at different times on different runs (even though the seed is fixed). Example printouts below.

Torch version 1.4.0
Training starts
Epoch [1/15], Step [100/600], Loss: 2.2660
Epoch [1/15], Step [200/600], Loss: 2.2493
Epoch [1/15], Step [300/600], Loss: 2.2162
Epoch [1/15], Step [400/600], Loss: 2.1765
Epoch [1/15], Step [500/600], Loss: 2.1570
Epoch [1/15], Step [600/600], Loss: 2.1345
Epoch [1/15], Duration 21.5740 s, Epoch average loss 2.2182
Epoch [2/15], Step [100/600], Loss: 2.0875
Epoch [2/15], Step [200/600], Loss: 2.1005
Epoch [2/15], Step [300/600], Loss: 2.0381
Epoch [2/15], Step [400/600], Loss: 2.0286
Epoch [2/15], Step [500/600], Loss: 2.0638
Epoch [2/15], Step [600/600], Loss: 2.0215
Epoch [2/15], Duration 9.0928 s, Epoch average loss 2.0842
Epoch [3/15], Step [100/600], Loss: 1.9658
Epoch [3/15], Step [200/600], Loss: 1.9806
Epoch [3/15], Step [300/600], Loss: 2.0045
Epoch [3/15], Step [400/600], Loss: 1.9696
Epoch [3/15], Step [500/600], Loss: 1.9714
Epoch [3/15], Step [600/600], Loss: 1.9749
Epoch [3/15], Duration 9.0357 s, Epoch average loss 1.9898
Epoch [4/15], Step [100/600], Loss: 1.9287
Epoch [4/15], Step [200/600], Loss: 2.0148
Epoch [4/15], Step [300/600], Loss: 1.9006
Epoch [4/15], Step [400/600], Loss: 1.9341
Epoch [4/15], Step [500/600], Loss: 1.8464
Epoch [4/15], Step [600/600], Loss: 1.8892
Epoch [4/15], Duration 9.0434 s, Epoch average loss 1.9196
Epoch [5/15], Step [100/600], Loss: 2.0397
Epoch [5/15], Step [200/600], Loss: 1.8745
Epoch [5/15], Step [300/600], Loss: 1.8347
Epoch [5/15], Step [400/600], Loss: 1.8775
Epoch [5/15], Step [500/600], Loss: 1.8454
Epoch [5/15], Step [600/600], Loss: 1.8164
Epoch [5/15], Duration 9.0442 s, Epoch average loss 1.8647
Epoch [5/15], Test Accuracy of the model on the 10000 test images: 43.06 %
Epoch [6/15], Step [100/600], Loss: 1.8553
Epoch [6/15], Step [200/600], Loss: 1.8733
Epoch [6/15], Step [300/600], Loss: 1.8868
Epoch [6/15], Step [400/600], Loss: 1.7853
Epoch [6/15], Step [500/600], Loss: 1.8486
Epoch [6/15], Step [600/600], Loss: 1.8526
Epoch [6/15], Duration 9.0407 s, Epoch average loss 1.8207
Epoch [7/15], Step [100/600], Loss: 1.7770
Epoch [7/15], Step [200/600], Loss: 1.7661
Epoch [7/15], Step [300/600], Loss: 1.7560
Epoch [7/15], Step [400/600], Loss: 1.6892
Epoch [7/15], Step [500/600], Loss: 1.7210
Epoch [7/15], Step [600/600], Loss: 1.6954
Epoch [7/15], Duration 9.0591 s, Epoch average loss 1.7851
Epoch [8/15], Step [100/600], Loss: 1.7385
Epoch [8/15], Step [200/600], Loss: 1.7263
Epoch [8/15], Step [300/600], Loss: 1.6587
Segmentation fault (core dumped)

@ousou
Copy link
Author

ousou commented Feb 13, 2020

I did some debugging, and the issue seems to be related to the caching of the flattened weights in the RNN introduced in PR #27399. I made a test where I essentially rolled back the changes relating to the flattened weights that were made to torch/nn/modules/rnn.py in the commit titled "kill _parameter_list". After this change the training worked on multiple GPUs, i.e. the model learned in the same way as with one GPU and as with PyTorch 1.3.0.

Below is a diff of the changes I made to fix this issue. I do not really understand how these flattened weights relate to other parts of the code, so I'm not confident enough to make a pull request of this directly. Hopefully this helps in solving the issue anyway!

diff --git a/rnn.py b/rnn_fixed.py
index a9a0d0d..e787937 100644
--- a/rnn.py
+++ b/rnn_fixed.py
@@ -60,7 +60,6 @@ class RNNBase(Module):
         else:
             raise ValueError("Unrecognized RNN mode: " + mode)
 
-        self._flat_weights_names = []
         self._all_weights = []
         for layer in range(num_layers):
             for direction in range(num_directions):
@@ -82,10 +81,8 @@ class RNNBase(Module):
 
                 for name, param in zip(param_names, layer_params):
                     setattr(self, name, param)
-                self._flat_weights_names.extend(param_names)
                 self._all_weights.append(param_names)
 
-        self._flat_weights = [getattr(self, weight) for weight in self._flat_weights_names]
         self.flatten_parameters()
         self.reset_parameters()
 
@@ -134,7 +131,7 @@ class RNNBase(Module):
         # Resets _flat_weights
         # Note: be v. careful before removing this, as 3rd party device types
         # likely rely on this behavior to properly .to() modules like LSTM.
-        self._flat_weights = [getattr(self, weight) for weight in self._flat_weights_names]
+        # self._flat_weights = [getattr(self, weight) for weight in self._flat_weights_names]
 
         # Flattens params (on CUDA)
         self.flatten_parameters()
@@ -146,6 +143,9 @@ class RNNBase(Module):
         for weight in self.parameters():
             init.uniform_(weight, -stdv, stdv)
 
+    def _get_flat_weights_names(self): 
+        return [weight for weights in self._all_weights for weight in weights] 
+
     def check_input(self, input, batch_sizes):
         # type: (Tensor, Optional[Tensor]) -> None
         expected_input_dim = 2 if batch_sizes is not None else 3
@@ -264,17 +264,13 @@ class RNNBase(Module):
         self._flat_weights = [getattr(self, weight) for weight in self._flat_weights_names]
 
     @property
+    def _flat_weights(self):   
+        return [p for layerparams in self.all_weights for p in layerparams]
+
+    @property
     def all_weights(self):
         return [[getattr(self, weight) for weight in weights] for weights in self._all_weights]
 
-    def _replicate_for_data_parallel(self):
-        replica = super(RNNBase, self)._replicate_for_data_parallel()
-        # Need to copy these caches, otherwise the replica will share the same
-        # flat weights list.
-        replica._flat_weights = replica._flat_weights[:]
-        replica._flat_weights_names = replica._flat_weights_names[:]
-        return replica
-
 
 class RNN(RNNBase):
     r"""Applies a multi-layer Elman RNN with :math:`tanh` or :math:`ReLU` non-linearity to an

@algoterranean
Copy link

Thanks @ousou, I was running into the same problem and this patch does seem to work so far.

This is a surprisingly significant problem for a main release and needs some attention.

@ousou
Copy link
Author

ousou commented Mar 24, 2020

Another related issue is #33552, and it's possible that this problem has already been fixed by PR #33907.

@ousou
Copy link
Author

ousou commented Apr 22, 2020

This has been fixed in version 1.5.0, closing issue.

@ousou ousou closed this as completed Apr 22, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module: data parallel triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

No branches or pull requests

3 participants