Can't get module gradient in autograd.Function's custom backward when DataParallel is used #33800
Labels
module: data parallel
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
I noticed some strange behavior when using `DataParallel` with a custom backward pass in an `autograd.Function`. Here is an example:
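The snippet below is a minimal sketch of the kind of setup described: a module whose linear layer `f` is driven by a custom `autograd.Function` that fills in `f.weight.grad` inside its `backward`. The names (`LinearWithManualGrad`, `Model`, `f`) and the exact way the gradient is written are assumptions, not necessarily the original code.

```python
import torch
import torch.nn as nn


class LinearWithManualGrad(torch.autograd.Function):
    """Wraps a linear layer; in backward, the layer's weight gradient is
    written by a nested autograd call instead of being returned."""

    @staticmethod
    def forward(ctx, x, layer):
        ctx.layer = layer
        ctx.save_for_backward(x)
        return layer(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        layer = ctx.layer
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():
            out = layer(x)
        # Accumulate into the leaves of this small graph, i.e. update
        # layer.weight.grad (and x.grad).                   <----- HERE
        torch.autograd.backward(out, grad_output)
        print(layer.weight.grad)  # prints None on each replica with >1 GPU
        return x.grad, None


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.f = nn.Linear(8, 4, bias=False)

    def forward(self, x):
        return LinearWithManualGrad.apply(x, self.f)


model = Model().cuda()
net = nn.DataParallel(model)
loss = net(torch.randn(16, 8).cuda()).sum()
loss.backward()
print(model.f.weight.grad)  # populated with 1 GPU, stays None with several
```

Run with `CUDA_VISIBLE_DEVICES=0` versus `CUDA_VISIBLE_DEVICES=0,1,2,3` to compare the two behaviors described below.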
I want to compute and update `f.weight.grad` within the custom `backward` function (see "<----- HERE" in the code). I found that when `CUDA_VISIBLE_DEVICES=0` (i.e., only 1 GPU is used), this works fine; but if I use `CUDA_VISIBLE_DEVICES=0,1,2,3`, the printed `f.weight.grad` will be `None` on each GPU device.

My guess is that when multiple GPUs are used, each device stores its own copy of `f`, which causes the problem. The desired behavior is for each device to compute its own `f.weight.grad` and for those gradients to be added together when they are eventually collected on GPU 0. Is there any way to resolve this? Thanks a lot!