_C.disabled_torch_function doesn't return NotImplemented for unrecognized types #64687

@ezyang

Reported by @pritamdamania

Repro:

import torch

class MyTensor:
    def __torch_function__(self, func, types, args=(), kwargs=None):
        # Accept every call: report that the override ran and return None.
        print('called!')

# MyTensor.__torch_function__ is called here.
t1 = MyTensor()
t2 = torch.nn.Parameter(torch.rand(2, 2))
torch.add(t2, t1)

# MyTensor.__torch_function__ is not called here, even though t1 is passed as the weight.
inp = torch.rand(10, 10)
torch.nn.functional.linear(inp, t1, t2)

I know what's going on.
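
For reference, here is a toy sketch of the behavior the title asks for. It is a simplified model of an override protocol, not PyTorch's actual dispatcher, and it may not mirror the real cause of this bug. Every name in it (`toy_dispatch`, `__toy_function__`, `Greedy`, `Polite`, `Custom`) is hypothetical. The point is only that a handler which returns `NotImplemented` for argument types it does not recognize gives the other argument's override a chance to run, while one that never defers swallows the call.

```python
# Toy model of an override protocol. NOT PyTorch's real dispatch code;
# all names here are hypothetical and exist only for illustration.

def toy_dispatch(func, args):
    """Ask each argument's handler in turn; the first result that is
    not NotImplemented wins."""
    types = tuple(type(a) for a in args)
    for arg in args:
        handler = getattr(type(arg), "__toy_function__", None)
        if handler is None:
            continue
        result = handler(arg, func, types, args)
        if result is not NotImplemented:
            return result
    raise TypeError("no handler accepted the call")


class Custom:
    """Stands in for a user-defined type with its own override."""
    def __toy_function__(self, func, types, args):
        return "handled by Custom"


class Greedy:
    """Handles the call without looking at the other argument types."""
    def __toy_function__(self, func, types, args):
        return "handled by Greedy"


class Polite:
    """Defers when it sees a type it does not recognize."""
    def __toy_function__(self, func, types, args):
        if not all(issubclass(t, Polite) for t in types):
            return NotImplemented
        return "handled by Polite"


# Greedy never defers, so Custom's handler is never consulted.
greedy_result = toy_dispatch("add", (Greedy(), Custom()))

# Polite returns NotImplemented for the unrecognized Custom type,
# so dispatch moves on and Custom's handler runs.
polite_result = toy_dispatch("add", (Polite(), Custom()))
```

In this model `Greedy` plays the role of a handler that does not check `types`, and `Polite` shows the deferring behavior requested in the title.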

cc @hameerabbasi @rgommers @peterbell10

Labels: module: __torch_function__, triaged
