Assigning a parameter to an indexed tensor that was produced by DDP no longer works in torch nightly (1.7) #46242

Closed
david-macleod opened this issue Oct 13, 2020 · 9 comments

Comments

@david-macleod

david-macleod commented Oct 13, 2020

🐛 Bug

In torch 1.6, assigning a parameter to an indexed tensor successfully created a new tensor that was part of the forward graph; this behavior no longer works with the latest torch 1.7 nightly.

To Reproduce

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp


class Model(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 16)
        self.token = torch.nn.Parameter(torch.zeros(8))

    def forward(self, x):
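        # in-place boolean-mask assignment of the parameter into the input; the traceback below points here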
        x[[True, False]] = self.token
        x = self.layer(x)
        return x

def main(rank, world_size):

    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    model = Model().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    x = torch.randn(2, 8, device=rank)
    y = ddp_model(x).mean()
    y.backward()


    torch.cuda.synchronize()
    dist.barrier()
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(main, args=(2,), nprocs=2, join=True)

Running this returns the (truncated) output:

    x[[True, False]] = self.token
RuntimeError: diff_view_meta->creation_meta == CreationMeta::DEFAULT INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/autograd/variable.cpp":87, please report a bug to PyTorch.

Expected behavior

No error should occur, as in torch 1.6.

Environment

PyTorch version: 1.8.0.dev20201012
Is debug build: True
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.10.2

Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti
GPU 2: GeForce RTX 2080 Ti
GPU 3: GeForce RTX 2080 Ti
GPU 4: GeForce RTX 2080 Ti
GPU 5: GeForce RTX 2080 Ti
GPU 6: GeForce RTX 2080 Ti
GPU 7: GeForce RTX 2080 Ti
GPU 8: GeForce RTX 2080 Ti
GPU 9: GeForce RTX 2080 Ti

Nvidia driver version: 440.100
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.2
/usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudnn.so.8.0.4
/usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.0.4
/usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.0.4
/usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.0.4
/usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.0.4
/usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.0.4
/usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.0.4
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] torch==1.8.0.dev20201012
[pip3] torchvision==0.8.0.dev20201012
[conda] Could not collect

Additional context

When running the same code with torch 1.6, the grad_fn on x after assigning the parameter is grad_fn=<IndexPutBackward>.

cc @ezyang @gchanan @zou3519 @bdhirsh @albanD @gqchen @pearu @nikitaved @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @aazzolini @xush6528 @osalpekar @jiayisuse @agolynski @ejguan

@bdhirsh
Contributor

bdhirsh commented Oct 13, 2020

I was able to repro this as well.

@bdhirsh bdhirsh added high priority module: assert failure The issue involves an assert failure labels Oct 13, 2020
@ezyang ezyang added the module: autograd Related to torch.autograd, and the autograd engine in general label Oct 13, 2020
@ezyang
Contributor

ezyang commented Oct 13, 2020

@bdhirsh Don't forget to add a module label to bugs when you are triaging; this makes sure the correct person gets CC'ed.

@albanD albanD self-assigned this Oct 13, 2020
@bdhirsh bdhirsh added the module: error checking Bugs related to incorrect/lacking error checking label Oct 13, 2020
@gchanan gchanan added this to the 1.7.0 milestone Oct 13, 2020
@gchanan gchanan added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Oct 13, 2020
@mrshenli mrshenli added the module: ddp Issues/PRs related distributed data parallel training label Oct 13, 2020
@albanD
Collaborator

albanD commented Oct 13, 2020

Hi,

After looking into this, here are a few comments:

  • You should not modify the input to your net when you use DDP, as this input will sometimes be a view of the original input and sometimes not. You would end up modifying your dataset for some samples at each forward, which is most likely not what you want to do.
  • @mrshenli will look into removing the Scatter op that creates a view here, as it is not always needed. This will improve perf and remove this error (even though the warning in the previous point still applies).
  • I will send a patch to fix the internal assert error. It happens because we do an in-place op where the input being modified does not require grad but the other input does, so the in-place checks do not run: they only look at the first input's requires_grad to decide whether the function is differentiable. See the sketch below.
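
A minimal standalone sketch of the pattern in the last bullet (plain tensors, no DDP; the variable names are illustrative, not from the issue):

import torch

x = torch.randn(2, 8)                    # tensor modified in place; does not require grad
p = torch.zeros(8, requires_grad=True)   # value written into it; requires grad

x[[True, False]] = p                     # in-place index_put_ mixing the two
print(x.requires_grad, x.grad_fn)        # True, and grad_fn is an IndexPutBackward node

On a plain tensor this runs fine and gradients flow back to p; the internal assert only fires when x is additionally a view with restricted creation metadata, such as the Scatter output inside DDP.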

@david-macleod
Author

david-macleod commented Oct 14, 2020

Thanks for the quick response. Given the scenario where we want to replace some indices of an input tensor with a learnable embedding (determined by a boolean mask that differs between batches and is sometimes all false), is there a recommended approach for this behavior, or is it just fundamentally incompatible with DDP after torch 1.6?

@albanD
Collaborator

albanD commented Oct 14, 2020

Hi,

You just need to clone the input that was given to you before modifying it in place and all will be good.
In this case, adding an x = x.clone() at the beginning of the forward will make the error go away.
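
Applied to the repro above, the forward would look like this (the same Model as in the report, with only the suggested clone added):

    def forward(self, x):
        # Clone the input DDP hands to the module before modifying it in place,
        # so the assignment targets a fresh tensor rather than a Scatter view.
        x = x.clone()
        x[[True, False]] = self.token
        x = self.layer(x)
        return x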

@ezyang ezyang changed the title Assigning a parameter to an indexed tensor no longer works in torch nightly (1.7) Assigning a parameter to an indexed tensor that was produced by DDP no longer works in torch nightly (1.7) Oct 14, 2020
@albanD albanD removed this from the 1.7.0 milestone Oct 14, 2020
@albanD
Collaborator

albanD commented Oct 14, 2020

Removing the 1.7 milestone, as the fix is fairly dangerous and would only change the internal assert into a nice error message. The release notes will explicitly mention that this case raises an internal error, to reduce user confusion when hitting it.

Also note that the underlying autograd issue was already present in 1.6.

@gchanan gchanan added this to the 1.7.0 milestone Oct 14, 2020
@gchanan
Contributor

gchanan commented Oct 14, 2020

Putting back the 1.7.0 milestone while we investigate it from the DDP side.

mrshenli added a commit that referenced this issue Oct 14, 2020
#41567 changed the behavior
of chunk and split, and renamed the previous versions as unsafe_*.
As a result, comm.scatter outputs become views, which leads to
the regression reported in #46242.

This commit reverts to using the previous versions of split and chunk.

[ghstack-poisoned]
mrshenli added a commit that referenced this issue Oct 14, 2020
#41567 changed the behavior
of chunk and split, and renamed the previous versions as unsafe_*.
As a result, comm.scatter outputs become views, which leads to
the regression reported in #46242.

This commit reverts to using the previous versions of split and chunk.

ghstack-source-id: 1676545927922b917c0aa55a4c88c669a781a291
Pull Request resolved: #46361
facebook-github-bot pushed a commit that referenced this issue Oct 15, 2020
Summary:
As per title, a temporary mitigation for #46242, for which #46296 will be a proper fix.

Pull Request resolved: #46406

Reviewed By: malfet

Differential Revision: D24339689

Pulled By: albanD

fbshipit-source-id: 0726e5abe4608d8ffcd7846cbaaffbb8564b04ab
@gchanan gchanan removed this from the 1.7.0 milestone Oct 15, 2020
@rohan-varma
Member

@albanD is this fixed by #46406?

@albanD
Collaborator

albanD commented Nov 16, 2020

The linked PR #46296 has not landed yet.
Note that the "fix" will just be to raise a nice error in this case and ask the user not to do the in-place op (the error message contains a special note about DDP to help users in that case).
