
Can't get module gradient in autograd.Function's custom backward when DataParallel is used #33800

Open
jerrybai1995 opened this issue Feb 26, 2020 · 2 comments
Labels
module: data parallel · triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments


jerrybai1995 commented Feb 26, 2020

I noticed some strange behavior when using DataParallel with a custom backward pass in an autograd.Function. Here is an example:

import torch
import torch.nn as nn
from torch.autograd import Function

torch.set_default_tensor_type('torch.cuda.FloatTensor')

class Combo(nn.Module):
    def __init__(self):
        super(Combo, self).__init__()
        self.func = nn.Conv1d(3, 3, 3, padding=1)

    def forward(self, x):
        # Route the input through the custom Function, handing it the
        # conv module so that backward can call it again.
        z = Debug.apply(self.func, x)
        return z

class Debug(Function):
    @staticmethod
    def forward(ctx, f, z):
        ctx.save_for_backward(z)
        ctx.f = f                       # stash the module for use in backward
        return z

    @staticmethod
    def backward(ctx, grad):
        grad = grad.clone()
        f = ctx.f
        z, = ctx.saved_tensors
        z = z.clone().detach().requires_grad_()
        # Re-run f on z under enable_grad so we can backprop through it
        # inside this custom backward.
        with torch.enable_grad():
            y = f(z)
        y.backward(torch.randn(z.shape), retain_graph=False)
        print(f.weight.grad)                               # <------------------ HERE
        return None, grad

net = Combo()
para_net = nn.DataParallel(net)

xx = torch.randn(4, 3, 7).requires_grad_()    # batch size 4
yy = para_net(xx)
loss = yy.mean()
loss.backward()

I want to compute and update f.weight.grad inside the custom backward function (see "<----- HERE" in the code). When CUDA_VISIBLE_DEVICES=0 (i.e., only one GPU is used), this works fine; but with CUDA_VISIBLE_DEVICES=0,1,2,3, the printed f.weight.grad is None on every GPU.
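For reference, this is how I toggle the two cases. CUDA_VISIBLE_DEVICES has to be set before torch initializes CUDA, so I set it at the very top of the script (setting it on the command line works equally well):

import os

# Set before importing torch, so CUDA only sees the listed devices.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"          # 1 GPU: f.weight.grad is populated
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"  # 4 GPUs: f.weight.grad prints None

import torch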

My guess is that when multiple GPUs are used, DataParallel replicates f onto each device, so the ctx.f seen in backward is a per-device replica rather than the original module.
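One way to test this guess is to check whether a replica's parameters are still leaf tensors, since .backward() only populates .grad on leaf tensors. A minimal diagnostic sketch (it assumes at least two visible GPUs and uses torch.nn.parallel.replicate, which DataParallel calls internally on every forward pass):

import torch
import torch.nn as nn
from torch.nn.parallel import replicate

m = nn.Conv1d(3, 3, 3, padding=1).cuda()
print(m.weight.is_leaf)             # True: the original nn.Parameter

# Replicas receive their parameters through a differentiable broadcast,
# so those tensors are non-leaf and .backward() leaves their .grad as None.
replicas = replicate(m, [0, 1])     # assumes GPUs 0 and 1 are visible
print(replicas[1].weight.is_leaf)   # False: output of the broadcast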

The desired behavior is for each device to compute its own f.weight.grad and for these gradients to be summed when they are eventually collected on GPU 0. Is there any way to resolve this? One idea I have been sketching is shown below.
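The idea (not fully verified) is to pass the weight and bias to the Function as explicit tensor inputs and to return their gradients from backward: since the broadcast that DataParallel uses to replicate parameters is differentiable, the gradients returned on each replica should then be summed back onto the original parameters on the source device. A minimal sketch under that assumption (DebugV2 and ComboV2 are just names for this variant; I pass the tensors by attribute access because .parameters() can come back empty on a replica):

import torch
import torch.nn as nn
from torch.autograd import Function

class DebugV2(Function):
    @staticmethod
    def forward(ctx, z, f, weight, bias):
        # weight and bias are f's own tensors, listed as explicit inputs
        # so that autograd (and the broadcast behind DataParallel) tracks them.
        ctx.save_for_backward(z, weight, bias)
        ctx.f = f
        return z

    @staticmethod
    def backward(ctx, grad):
        f = ctx.f
        z, weight, bias = ctx.saved_tensors
        z = z.clone().detach().requires_grad_()
        with torch.enable_grad():
            y = f(z)
        # Compute the parameter gradients explicitly instead of reading
        # .grad fields, which stay None on the replicas' non-leaf tensors.
        w_grad, b_grad = torch.autograd.grad(
            y, (weight, bias), grad_outputs=torch.randn_like(y))
        # One gradient per forward input: z, f (a module, so None), weight, bias.
        return grad.clone(), None, w_grad, b_grad

class ComboV2(nn.Module):
    def __init__(self):
        super(ComboV2, self).__init__()
        self.func = nn.Conv1d(3, 3, 3, padding=1)

    def forward(self, x):
        # Attribute access (self.func.weight) resolves to the replica's
        # broadcast tensor under DataParallel, so the same code should
        # run on one GPU or several.
        return DebugV2.apply(x, self.func, self.func.weight, self.func.bias)

If this works as I expect, after loss.backward() the summed gradients land in net.func.weight.grad on GPU 0 instead of staying on the replicas, but I have not verified this beyond my own setup.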

Thanks a lot!

@agolynski added the module: data parallel and triaged labels on Feb 26, 2020
@ruoshiliu

I'm getting a similar error. Has this issue been resolved yet?

@amrhamedp

same here
