Description
🐛 Bug
The `grad_fn` of `torch.where` returns the gradients of the wrong argument, rather than of the selected tensor, if the other tensor's gradients contain infs or nans.
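A plausible mechanism (my reading, not stated in the report): `torch.where`'s backward zeroes the upstream gradient for the non-selected branch, but that zero is then multiplied by the local gradient of the inner op (inf for `d/dx sqrt(x)` at `x = 0`), and `0 * inf` is nan under IEEE-754 rules, so the contaminated value leaks into the result:

```python
import torch

# Suspected mechanism, isolated: a zeroed upstream gradient (index 1 was
# masked out by where) multiplied by an inf local gradient gives nan,
# not the 0 one might expect.
upstream = torch.tensor([0.125, 0.0])       # gradient after where's mask
local = torch.tensor([1.0, float('inf')])   # local gradient of the inner op
contaminated = upstream * local
print(contaminated)  # tensor([0.1250, nan])
```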
To Reproduce
Run this code:
import torch

x = torch.tensor([16., 0.], requires_grad=True)
y = x / 2         # tensor([8., 0.], grad_fn=<DivBackward0>)
z = x.sqrt() + 1  # tensor([5., 1.], grad_fn=<SqrtBackward>)
# Calculate dy/dx and dz/dx
dydx = torch.autograd.grad(y.sum(), x, retain_graph=True)[0]  # tensor([0.5000, 0.5000])
dzdx = torch.autograd.grad(z.sum(), x, retain_graph=True)[0]  # tensor([0.1250, inf])
# Define w = [w0, w1] == [z0, y1]
w = torch.where(x == 0., y, z)  # tensor([5., 0.], grad_fn=<SWhereBackward>)
expected_dw_dx = torch.where(x == 0., dydx, dzdx)  # tensor([0.1250, 0.5000])
dwdx = torch.autograd.grad(w.sum(), x, retain_graph=True)[0]  # is actually tensor([0.1250, inf])
print("`torch.where` communicates gradients correctly:", torch.equal(expected_dw_dx, dwdx))
Expected behavior
I would expect `expected_dw_dx == dwdx` in the example above.
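Until this is fixed, a common workaround (not part of the report; the placeholder value `1.` and variable names below are my own) is the "double where" pattern: substitute a safe input at the positions that will be discarded, so the inner op's local gradient never becomes inf in the first place:

```python
import torch

# Workaround sketch: make sqrt's input safe where its gradient would be inf.
# The outer where still selects the same forward values as in the report.
x = torch.tensor([16., 0.], requires_grad=True)
y = x / 2
safe_x = torch.where(x == 0., torch.ones_like(x), x)  # 1. is an arbitrary safe value
z = safe_x.sqrt() + 1
w = torch.where(x == 0., y, z)  # same forward result: tensor([5., 0.])
dwdx = torch.autograd.grad(w.sum(), x)[0]
print(dwdx)  # tensor([0.1250, 0.5000])
```

Because the discarded branch now has a finite gradient, the zero mask applied in the backward pass cleanly removes its contribution instead of producing inf/nan.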
Environment
Please copy and paste the output from our environment collection script (or fill out the checklist below manually). You can get the script and run it with:
wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py
PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176
OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.10.2
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100
Nvidia driver version: 410.79
cuDNN version: Could not collect
Versions of relevant libraries:
[pip] numpy==1.16.4
[pip] torch==1.1.0
[pip] torchvision==0.3.0
[conda] torch 1.1.0
[conda] torchvision 0.3.0