Description
🐛 Bug
When the model is wrapped in `DataParallel` and `flatten_parameters` is called inside the model's forward under `torch.no_grad`, it throws this error:
RuntimeError: set_storage is not allowed on Tensor created from .data or .detach()
It works fine otherwise. This behavior only occurs on 1.1.0; the same code worked on 1.0.1.post2.
To Reproduce
Run the code below on 1.1.0 to reproduce the behavior:
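```python
import torch


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = torch.nn.LSTM(300, 1024, 1, batch_first=True, bidirectional=True)

    def forward(self, x):
        # Compact the LSTM weights into one contiguous buffer for cuDNN.
        self.rnn.flatten_parameters()
        # Returns (output, (h_n, c_n)); output is N * T * (2 * hidden_dim).
        return self.rnn(x)


model = torch.nn.DataParallel(Model().to('cuda'))
with torch.no_grad():
    # On 1.1.0 this raises the RuntimeError shown above.
    x = model(torch.rand(2, 4, 300))
```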
Expected behavior
`flatten_parameters` should work as it does without `DataParallel`.
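Since the report says the failure only appears with `DataParallel` plus `torch.no_grad`, one possible workaround is to call `flatten_parameters()` only when autograd is enabled, checked with `torch.is_grad_enabled()`. This is a sketch, not a confirmed fix. It assumes `DataParallel`'s worker threads see the caller's grad mode, and skipping the flatten step may trigger cuDNN's non-contiguous-weights warning at some speed and memory cost. The CPU fallback exists only so the sketch also runs without a GPU.

```python
import torch


class Model(torch.nn.Module):
    # Same model as in the reproduction, with a guarded flatten step.
    # Assumption: the set_storage error only arises when flatten_parameters()
    # runs with autograd disabled inside a DataParallel replica.

    def __init__(self):
        super().__init__()
        self.rnn = torch.nn.LSTM(300, 1024, 1, batch_first=True, bidirectional=True)

    def forward(self, x):
        # Skip re-flattening under torch.no_grad(); the LSTM still runs,
        # possibly without the contiguous-weight fast path.
        if torch.is_grad_enabled():
            self.rnn.flatten_parameters()
        return self.rnn(x)  # (output, (h_n, c_n))


# Fall back to CPU so the sketch also runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.DataParallel(Model().to(device))

with torch.no_grad():
    output, (h_n, c_n) = model(torch.rand(2, 4, 300))

print(output.shape)
```

During training (grad enabled) the weights are still flattened on every forward call, so only the inference path changes.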
Environment
Collecting environment information...
PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176
OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.9.4
Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100
Nvidia driver version: 410.79
cuDNN version: Could not collect
Versions of relevant libraries:
[pip] msgpack-numpy==0.4.1
[pip] numpy==1.16.4
[pip] numpydoc==0.7.0
[pip] pytorch-nlp==0.3.5
[pip] pytorch-pretrained-bert==0.3.0
[pip] torch==1.1.0
[pip] torchfile==0.1.0
[pip] torchtext==0.2.3
[pip] torchvision==0.2.0
[conda] cuda90 1.0 h6433d27_0 pytorch
[conda] faiss-cpu 1.2.1 py36_cuda9.0.176_1 pytorch
[conda] faiss-gpu 1.2.1 py36_cuda9.0.176_1 pytorch
[conda] magma-cuda90 2.3.0 1 pytorch
[conda] mkl 2018.0.1 h19d6760_4 anaconda
[conda] mkl-fft 1.0.0
[conda] mkl-include 2018.0.3 1
[conda] mkl-random 1.0.1
[conda] mkl-service 1.1.2 py36h17a0993_4
[conda] mkl_fft 1.0.2 np114py36_intel_0 [intel] intel
[conda] mkl_random 1.0.1 np114py36_intel_0 [intel] intel
[conda] mkldnn 0.14.0 0 mingfeima
[conda] nccl2 1.0 0 pytorch
[conda] pytorch-nlp 0.3.5
[conda] pytorch-pretrained-bert 0.3.0
[conda] torch 1.1.0
[conda] torchfile 0.1.0
[conda] torchtext 0.2.3
[conda] torchvision 0.2.0