
torch.tensordot has inconsistent signature with torch script #32695

Open
ferrine opened this issue Jan 28, 2020 · 1 comment
Labels
oncall: jit Add this issue/PR to JIT oncall triage queue triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments


ferrine commented Jan 28, 2020

🐛 Bug

Inconsistent API between TorchScript and non-scripted (eager) mode

To Reproduce

Steps to reproduce the behavior:

Working samples:

import torch

# Eager mode: contraction dims passed as a nested list.
def my_tensordot(a, b):
    return torch.tensordot(a, b, [[1], [1]])

# TorchScript mode: contraction dims passed as two flat lists.
@torch.jit.script
def my_tensordot(a, b):
    return torch.tensordot(a, b, [1], [1])
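For context, the eager-mode call contracts dimension 1 of `a` with dimension 1 of `b`. `numpy.tensordot` uses the same axes convention, so a minimal sketch of what the contraction computes (shapes chosen only for illustration) is:

```python
import numpy as np

# Mirror of torch.tensordot(a, b, [[1], [1]]) in eager mode:
# contract axis 1 of `a` with axis 1 of `b`.
a = np.arange(6.0).reshape(2, 3)
b = np.arange(12.0).reshape(4, 3)

out = np.tensordot(a, b, axes=([1], [1]))
print(out.shape)  # (2, 4) -- same result as a @ b.T for these shapes
```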

Broken one (using the same API as in eager Python):

@torch.jit.script
def my_tensordot(a, b):
    return torch.tensordot(a, b, [[1], [1]])

I get:

aten::tensordot(Tensor self, Tensor other, int[] dims_self, int[] dims_other) -> (Tensor):
Expected a value of type 'List[int]' for argument 'dims_self' but instead found type 'List[List[int]]'.

Removing the jit decorator from the working jit example does not help either:

def my_tensordot(a, b):
    return torch.tensordot(a, b, [1], [1])
TypeError: tensordot() takes from 2 to 3 positional arguments but 4 were given

Expected behavior

The same API should work in both JIT and non-JIT mode, as in:

@torch.jit.script
def my_tensordot(a, b):
    return torch.tensordot(a, b, [[1], [1]])

Environment

Output of the environment collection script:
PyTorch version: 1.3.1
Is debug build: No
CUDA used to build PyTorch: 10.1.243

OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609
CMake version: Could not collect

Python version: 3.6
Is CUDA available: No
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce GTX 950M
Nvidia driver version: 384.130
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] numpy==1.17.1
[pip3] torch==1.3.1
[pip3] torch-cluster==1.4.5
[pip3] torch-geometric==1.3.2
[pip3] torch-scatter==1.4.0
[pip3] torch-sparse==0.4.3
[pip3] torch-spline-conv==1.1.1
[pip3] torchvision==0.4.2
[conda] Could not collect

Additional context

I found this while running Python with PYTORCH_JIT=false.

Possible Solutions

  1. Overload both functions to handle both cases. No user code changes are required, and the errors disappear.
  2. JIT supports the Python API variant: few users are affected, but downstream code must check the PyTorch version at import time.
  3. The Python function supports the JIT API variant: many users are affected, and downstream code must check the PyTorch version at import time.
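Solution 1 could be sketched as a thin dispatch layer that accepts both calling conventions. The snippet below is only an illustration of that idea, not PyTorch's implementation; `tensordot_compat` is a hypothetical name, and `numpy.tensordot` stands in for the torch op since it shares the axes convention:

```python
import numpy as np

def tensordot_compat(a, b, dims_a, dims_b=None):
    """Accept both calling conventions from this issue:

    tensordot_compat(a, b, [[1], [1]])  # eager-style nested list
    tensordot_compat(a, b, [1], [1])    # jit-style flat lists
    """
    if dims_b is None:
        # Nested form: unpack [[dims_self], [dims_other]].
        dims_self, dims_other = dims_a
    else:
        dims_self, dims_other = dims_a, dims_b
    return np.tensordot(a, b, axes=(list(dims_self), list(dims_other)))

a = np.arange(6.0).reshape(2, 3)
b = np.arange(12.0).reshape(4, 3)
print(np.allclose(tensordot_compat(a, b, [[1], [1]]),
                  tensordot_compat(a, b, [1], [1])))  # True
```

Either spelling then reaches the same underlying contraction, which is the "no user code change required" property of solution 1.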

cc @suo

@ferrine ferrine mentioned this issue Jan 28, 2020
@mrshenli mrshenli added the oncall: jit Add this issue/PR to JIT oncall triage queue label Jan 28, 2020
@suo suo added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Feb 28, 2020

xionghuaidong commented Oct 21, 2021

The following is a workaround for this problem.

import torch
from typing import List

class OpHook(torch.nn.Module):
    # Eager-mode hook: no annotations, plain Python dispatch.
    def tensordot(self, a, b, dims_a, dims_b):
        return torch.tensordot(a, b, (dims_a, dims_b))

class OnlineOpHook(OpHook):
    # Script-mode hook: annotated signature matching the jit overload.
    def tensordot(self, a, b, dims_a: List[int], dims_b: List[int]):
        return torch.tensordot(a, b, dims_a, dims_b)

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._op = OpHook()

    def forward(self, x):
        x = self._op.tensordot(x, x, [0], [0])
        return x + 1

    def prepare(self):
        # Swap in the script-compatible hook before scripting.
        self._op = OnlineOpHook()

    def restore(self):
        self._op = OpHook()

mod = MyModule()
x = torch.tensor([3])
y = mod(x)
print('x = {}'.format(x))
print('y = {}'.format(y))
mod.prepare()
scm = torch.jit.script(mod)
mod.restore()
torch.jit.save(scm, './my_module.ptm')
