Open
Labels
actionable
module: molly-guard (Features which help prevent users from committing common mistakes)
module: python frontend (For issues relating to PyTorch's Python frontend)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🐛 The error message bug
Writing a custom backward pass by subclassing torch.autograd.Function and using a (torch.Tensor,) tuple where a torch.Tensor is expected while computing the gradient for the received grad_output (for example, by assigning ctx.saved_tensors without unpacking it) results in
TypeError: only integer tensors of a single element can be converted to an index
with a traceback that points at the backward() call.
This was also previously reported in torchdiffeq.
The problem is not that using (torch.Tensor,) in place of a tensor fails, but that the error message is grossly misleading. I believe the message can be traced all the way back to the implicit conversion of the tensor to an integer index.
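The wording likely comes from Python's operator fallback rather than from autograd itself. The torch-free sketch below models it with a hypothetical `FakeTensor` class (not PyTorch code), assuming, consistent with the reported message, that the tensor's `__rmul__` declines a tuple operand, after which Python treats `tuple * x` as sequence repetition and asks `x` for an integer through `__index__`.

```python
class FakeTensor:
    """Hypothetical stand-in for a float torch.Tensor; not PyTorch code."""

    def __rmul__(self, other):
        # Assumed behaviour: decline non-tensor operands so Python tries
        # the left operand's own fallback instead.
        if not isinstance(other, FakeTensor):
            return NotImplemented
        return FakeTensor()

    def __index__(self):
        # Mimics the message a float (or multi-element) tensor raises when
        # Python asks it to act as an integer.
        raise TypeError(
            "only integer tensors of a single element can be converted to an index"
        )


def reproduce() -> str:
    saved = (FakeTensor(),)  # what ctx.saved_tensors hands back: a 1-tuple
    grad_output = FakeTensor()
    try:
        saved * grad_output  # falls back to tuple repetition, which calls __index__
    except TypeError as exc:
        return str(exc)
    return "no error"


print(reproduce())
# prints: only integer tensors of a single element can be converted to an index
```

The message never mentions tuples because the failing step is "convert the right operand to a repeat count", which only sees the tensor.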
Expected behaviour
Wording the TypeError differently in this specific case would likely save people hours of debugging, for example:
TypeError: cannot convert an object of type tuple to torch.Tensor
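One possible way to produce the proposed wording, sketched with a hypothetical `ClearerTensor` class rather than real PyTorch internals: the right-hand multiplication hook recognizes a tuple operand and raises the clearer message, but still declines when the tensor is a single-element integer, so legitimate tuple repetition keeps working. Where such a check would actually live in PyTorch is an assumption, not something this issue establishes.

```python
from typing import Optional


class ClearerTensor:
    """Hypothetical tensor stand-in (not PyTorch code) with a clearer __rmul__."""

    def __init__(self, int_value: Optional[int] = None):
        # int_value models a single-element integer tensor; None models any
        # other tensor (float, or more than one element).
        self.int_value = int_value

    def __rmul__(self, other):
        if isinstance(other, tuple) and self.int_value is None:
            # Hypothetical improved message, as proposed in this issue.
            raise TypeError(
                f"cannot convert an object of type {type(other).__name__} to torch.Tensor"
            )
        # Otherwise decline, so Python falls back to tuple repetition.
        return NotImplemented

    def __index__(self):
        if self.int_value is not None:
            return self.int_value
        raise TypeError(
            "only integer tensors of a single element can be converted to an index"
        )


def message_for(saved) -> str:
    """Return the TypeError text raised by `saved * ClearerTensor()`."""
    try:
        saved * ClearerTensor()
    except TypeError as exc:
        return str(exc)
    return "no error"


print(message_for((ClearerTensor(),)))
# prints: cannot convert an object of type tuple to torch.Tensor
print((1,) * ClearerTensor(2))
# prints: (1, 1)
```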
Example code
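import torch
from torch.autograd import Function

class MyFunction(Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor):
        ctx.save_for_backward(input)
        return input

    @staticmethod
    def backward(ctx, grad_output):
        input = ctx.saved_tensors  # a 1-tuple (torch.Tensor,), not a torch.Tensor!
        return input * grad_output  # raises the misleading TypeError, surfacing from loss.backward()

# Minimal driver that triggers the backward pass
x = torch.randn(3, requires_grad=True)
loss = MyFunction.apply(x).sum()
loss.backward()  # TypeError: only integer tensors of a single element can be converted to an index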
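For comparison, a minimal sketch of the same function with the tuple unpacked before use. `MyFunctionFixed` is a hypothetical name chosen here; the only change from the repro is `(input,) = ctx.saved_tensors`. As in the original, the backward formula is kept as-is for illustration and is not meant to be the true gradient of the identity forward.

```python
import torch
from torch.autograd import Function


class MyFunctionFixed(Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor):
        ctx.save_for_backward(input)
        return input

    @staticmethod
    def backward(ctx, grad_output):
        # ctx.saved_tensors is a tuple; unpack its single element first.
        (input,) = ctx.saved_tensors
        return input * grad_output


x = torch.randn(3, requires_grad=True)
loss = MyFunctionFixed.apply(x).sum()
loss.backward()  # no TypeError now
print(torch.equal(x.grad, x.detach()))
# prints: True
```

Because loss is a sum, grad_output is all ones, so x.grad ends up equal to x.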
Versions
[pip3] numpy==1.26.2
[pip3] pytorch-msssim==1.0.0
[pip3] torch==2.1.2
[pip3] torchaudio==2.1.1
[pip3] torchvision==0.16.2
[pip3] triton==2.1.0
[conda] numpy 1.26.2 pypi_0 pypi
[conda] pytorch-msssim 1.0.0 pypi_0 pypi
[conda] torch 2.1.2 pypi_0 pypi
[conda] torchaudio 2.1.1 pypi_0 pypi
[conda] torchvision 0.16.2 pypi_0 pypi
[conda] triton 2.1.0 pypi_0 pypi
cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7