
[DataParallel] flatten_parameters doesn't work under torch.no_grad #21108

@apsdehal

🐛 Bug

When a model is wrapped in DataParallel and flatten_parameters is called inside its forward under torch.no_grad, the following error is thrown:

RuntimeError: set_storage is not allowed on Tensor created from .data or .detach()

It works fine otherwise. This behavior only occurs on 1.1.0; it worked fine on 1.0.1.post2.
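For context, flatten_parameters repoints the RNN weights at a single contiguous buffer via Tensor.set_. Under DataParallel the replicas appear to hold weights created through .data/.detach(), and starting with 1.1.0 such tensors reject metadata-changing in-place ops. Below is a minimal sketch of that restriction, independent of DataParallel; the names are illustrative, not the actual failing code path:

import torch

# Assumption: this mirrors the failing path inside flatten_parameters.
# Since 1.1.0, a tensor obtained via .data or .detach() disallows
# metadata-changing in-place ops such as Tensor.set_.
w = torch.nn.Parameter(torch.rand(4))
alias = w.detach()         # replica weights seem to end up as tensors like this
alias.set_(torch.rand(4))  # raises the same class of RuntimeError as above

If that reading is right, the regression comes from the new metadata-change guard rather than from flatten_parameters itself.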

To Reproduce

Run the code below on 1.1.0 to reproduce the behavior:

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = torch.nn.LSTM(300, 1024, 1, batch_first=True, bidirectional=True)
    def forward(self, x):
        self.rnn.flatten_parameters()
        return self.rnn(x)  # N * T * hidden_dim


model = torch.nn.DataParallel(Model().to('cuda'))

with torch.no_grad():
    x = model(torch.rand(2, 4, 300))

Expected behavior

flatten_parameters should work under torch.no_grad just as it does without DataParallel.
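Until this is fixed, a possible workaround (my own sketch, not an official fix) is to skip the flatten_parameters call when grad is disabled; under torch.no_grad the un-flattened weights should only trigger cuDNN's non-contiguous-weights performance warning rather than an error:

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = torch.nn.LSTM(300, 1024, 1, batch_first=True, bidirectional=True)

    def forward(self, x):
        # Only flatten while autograd is enabled; under torch.no_grad() the
        # DataParallel replica weights reject set_, and skipping the call
        # merely forfeits the contiguous-weights speedup.
        if torch.is_grad_enabled():
            self.rnn.flatten_parameters()
        return self.rnn(x)  # N * T * hidden_dim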

Environment

Collecting environment information...
PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.9.4

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100

Nvidia driver version: 410.79
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] msgpack-numpy==0.4.1
[pip] numpy==1.16.4
[pip] numpydoc==0.7.0
[pip] pytorch-nlp==0.3.5
[pip] pytorch-pretrained-bert==0.3.0
[pip] torch==1.1.0
[pip] torchfile==0.1.0
[pip] torchtext==0.2.3
[pip] torchvision==0.2.0
[conda] cuda90 1.0 h6433d27_0 pytorch
[conda] faiss-cpu 1.2.1 py36_cuda9.0.176_1 pytorch
[conda] faiss-gpu 1.2.1 py36_cuda9.0.176_1 pytorch
[conda] magma-cuda90 2.3.0 1 pytorch
[conda] mkl 2018.0.1 h19d6760_4 anaconda
[conda] mkl-fft 1.0.0
[conda] mkl-include 2018.0.3 1
[conda] mkl-random 1.0.1
[conda] mkl-service 1.1.2 py36h17a0993_4
[conda] mkl_fft 1.0.2 np114py36_intel_0 [intel] intel
[conda] mkl_random 1.0.1 np114py36_intel_0 [intel] intel
[conda] mkldnn 0.14.0 0 mingfeima
[conda] nccl2 1.0 0 pytorch
[conda] pytorch-nlp 0.3.5
[conda] pytorch-pretrained-bert 0.3.0
[conda] torch 1.1.0
[conda] torchfile 0.1.0
[conda] torchtext 0.2.3
[conda] torchvision 0.2.0

Labels

oncall: distributed, triaged
