Incorrect gradients for torch.where when one of the target tensors contains inf/nan #23395

@egrefen

🐛 Bug

The grad_fn of torch.where returns the gradients of the wrong argument, rather than of the selected tensor, if the other tensor's gradients have infs or nans.

To Reproduce

Run this code:

import torch

x = torch.tensor([16., 0.], requires_grad=True)
y = x/2  # tensor([8., 0.], grad_fn=<DivBackward0>)
z = x.sqrt() + 1  # tensor([5., 1.], grad_fn=<SqrtBackward>)

# Calculate dy/dx, dz/dx
dydx = torch.autograd.grad(y.sum(), x, retain_graph=True)[0]  # tensor([0.5000, 0.5000])
dzdx = torch.autograd.grad(z.sum(), x, retain_graph=True)[0]  # tensor([0.1250,    inf])

# Define w = [w0, w1] == [z0, y1] (y is selected only where x == 0)
w = torch.where(x == 0., y, z)  # tensor([5., 0.], grad_fn=<SWhereBackward>)
expected_dw_dx = torch.where(x == 0., dydx, dzdx)  # tensor([0.1250, 0.5000])
dwdx = torch.autograd.grad(w.sum(), x, retain_graph=True)[0]  # is actually tensor([0.1250, nan])
print("`torch.where` communicates gradients correctly:", torch.equal(expected_dw_dx, dwdx))
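One plausible reading of this result, sketched by hand: if the backward pass of torch.where hands each branch the upstream gradient masked with zeros (an assumption about the implementation, not something confirmed in this report), then the chain rule multiplies that zero by the infinite local derivative of sqrt at 0, and the product is not finite. The names grad_y, grad_z, and manual below are illustrative only.

```python
import torch

x = torch.tensor([16., 0.], requires_grad=True)
cond = x == 0.

# Upstream gradient of w.sum() with respect to w is all ones.
grad_w = torch.ones_like(x)

# Assumed backward of torch.where: each branch receives the upstream
# gradient where it was selected and zero where it was not.
grad_y = torch.where(cond, grad_w, torch.zeros_like(grad_w))  # [0., 1.]
grad_z = torch.where(cond, torch.zeros_like(grad_w), grad_w)  # [1., 0.]

# Local derivatives of the two branches with respect to x.
dy_dx = torch.full_like(x, 0.5)     # d(x/2)/dx
dz_dx = 0.5 / x.detach().sqrt()     # d(sqrt(x)+1)/dx, infinite at x == 0

# Chain rule: the second entry is 1 * 0.5 + 0 * inf, and 0 * inf is nan
# under IEEE 754, so the masked-out branch still contaminates the sum.
manual = grad_y * dy_dx + grad_z * dz_dx

# Compare against what autograd actually returns for the same expression.
w = torch.where(cond, x / 2, x.sqrt() + 1)
auto = torch.autograd.grad(w.sum(), x)[0]
print(manual, auto)
```

If the two printed tensors agree, the zero mask is applied before the unselected branch's own backward runs, so it cannot shield x from a non-finite local derivative.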

Expected behavior

I would expect expected_dw_dx == dwdx in the example above.
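Until the behaviour changes, a common way to get the expected gradient is to keep the unselected branch from ever evaluating at the problematic input. The following is a minimal sketch of that idea, not a fix to torch.where itself; safe_x and the placeholder value 1.0 are arbitrary choices made for illustration.

```python
import torch

x = torch.tensor([16., 0.], requires_grad=True)
cond = x == 0.

# Swap the entries the sqrt branch should never see for a harmless value,
# so that sqrt has a finite derivative everywhere it is evaluated.
safe_x = torch.where(cond, torch.ones_like(x), x)

# Same selection as in the report, but the sqrt branch now uses safe_x.
w = torch.where(cond, x / 2, safe_x.sqrt() + 1)

dwdx = torch.autograd.grad(w.sum(), x)[0]
print(dwdx)  # tensor([0.1250, 0.5000])
```

The inner torch.where routes a zero gradient back to x at the masked position, and since the local derivative of sqrt at the placeholder is finite, that zero stays zero.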

Environment

PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.10.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100

Nvidia driver version: 410.79
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.16.4
[pip] torch==1.1.0
[pip] torchvision==0.3.0
[conda] torch 1.1.0
[conda] torchvision 0.3.0
