
backward implicit conversion from tuple to torch.Tensor results in an indexing error message #116533

@pbelcak

🐛 The error message bug

When writing a custom backward pass by subclassing torch.autograd.Function, using a (torch.Tensor,) tuple where a single torch.Tensor is expected (for example, multiplying the tuple from ctx.saved_tensors by grad_output when computing the returned gradient) results in

TypeError: only integer tensors of a single element can be converted to an index

reported with a call stack that points to the backward() call.

This was also previously reported in torchdiffeq.

The issue is not that using (torch.Tensor,) in place of torch.Tensor fails, but that the error message is grossly misleading. I believe this error message can be traced all the way back to the implicit tensor type conversion.
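One plausible reading of why the message mentions an index at all: when the left operand of `*` is a tuple, Python can fall back to sequence repetition and ask the right operand for an integer via `__index__`. The sketch below reproduces that path in pure Python. `FakeTensor` and `IntLike` are hypothetical stand-ins invented for this illustration, not PyTorch classes, and the assumption that this is the exact path taken inside the backward pass is not confirmed in this report.

```python
class IntLike:
    """Hypothetical object that can legitimately act as a repeat count."""

    def __index__(self):
        return 2


class FakeTensor:
    """Hypothetical stand-in for a non-integer tensor (not a PyTorch class)."""

    def __index__(self):
        # Mimics the message quoted in this report.
        raise TypeError(
            "only integer tensors of a single element can be converted to an index"
        )


saved = ("saved value",)  # plays the role of the one-element tuple

# tuple * int-like object: Python repeats the tuple.
repeated = saved * IntLike()

# tuple * tensor-like object: Python asks for an index, which fails.
try:
    saved * FakeTensor()
except TypeError as exc:
    message = str(exc)
```

Under this reading, the TypeError describes a failed attempt to repeat the tuple, which is why it gives no hint that a tuple was used where a tensor was intended.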

Expected behaviour
It would save people hours of debugging time if, in this specific case, the TypeError were worded differently, for example:

TypeError: cannot convert an object of type tuple to torch.Tensor

Example code

import torch
from torch.autograd import Function

class MyFunction(Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor):
        ctx.save_for_backward(input)
        return input

    @staticmethod
    def backward(ctx, grad_output):
        input = ctx.saved_tensors  # returns (torch.Tensor,) and not torch.Tensor!
        return input * grad_output  # tuple * tensor: causes the TypeError in the backward pass

# Minimal trigger for the backward pass
x = torch.randn(3, requires_grad=True)
loss = MyFunction.apply(x).sum()
loss.backward()  # the TypeError surfaces here
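For comparison, here is a minimal sketch of a variant that avoids the error, assuming the intent was to multiply the saved input by the incoming gradient. The name `MyFunctionFixed` is hypothetical, and `forward` returns a clone here only to keep the sketch self-contained; neither comes from the original report. The only relevant change is unpacking the one-element tuple returned by `ctx.saved_tensors`.

```python
import torch
from torch.autograd import Function


class MyFunctionFixed(Function):  # hypothetical name for this sketch
    @staticmethod
    def forward(ctx, input: torch.Tensor):
        ctx.save_for_backward(input)
        # Clone so the output is a distinct tensor (a choice made for this sketch).
        return input.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # saved_tensors is a tuple; unpack its single element before using it.
        (input,) = ctx.saved_tensors
        return input * grad_output


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
loss = MyFunctionFixed.apply(x).sum()
loss.backward()  # completes without raising
```

With the tuple unpacked, `input * grad_output` is an ordinary tensor multiplication, so the gradient stored in `x.grad` equals the saved input multiplied elementwise by the incoming gradient.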

Versions

[pip3] numpy==1.26.2
[pip3] pytorch-msssim==1.0.0
[pip3] torch==2.1.2
[pip3] torchaudio==2.1.1
[pip3] torchvision==0.16.2
[pip3] triton==2.1.0
[conda] numpy                     1.26.2                   pypi_0    pypi
[conda] pytorch-msssim            1.0.0                    pypi_0    pypi
[conda] torch                     2.1.2                    pypi_0    pypi
[conda] torchaudio                2.1.1                    pypi_0    pypi
[conda] torchvision               0.16.2                   pypi_0    pypi
[conda] triton                    2.1.0                    pypi_0    pypi

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7

Labels

actionable
module: molly-guard (Features which help prevent users from committing common mistakes)
module: python frontend (For issues relating to PyTorch's Python frontend)
triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)