
Unable to return Tensor or List[Tensor] as Any type from TorchScript function #42646

@yf225

🐛 Bug

I tried scripting the following module:

from typing import Any, List
import torch

class TestModule(torch.nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, input: torch.Tensor) -> Any:
    ret: List[torch.Tensor] = []
    if input.shape[0] == 1:
      return input
    else:
      ret.append(input)
      return ret

m = TestModule()
m_scripted = torch.jit.script(m)

and scripting it fails with:

Traceback (most recent call last):
  File "test_yf225.py", line 17, in <module>
    m_scripted = torch.jit.script(m)
  File "/data/miniconda3/envs/master_nightly/lib/python3.7/site-packages/torch/jit/__init__.py", line 1516, in script
    return torch.jit._recursive.create_script_module(obj, torch.jit._recursive.infer_methods_to_compile)
  File "/data/miniconda3/envs/master_nightly/lib/python3.7/site-packages/torch/jit/_recursive.py", line 318, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/data/miniconda3/envs/master_nightly/lib/python3.7/site-packages/torch/jit/_recursive.py", line 376, in create_script_module_impl
    create_methods_from_stubs(concrete_type, stubs)
  File "/data/miniconda3/envs/master_nightly/lib/python3.7/site-packages/torch/jit/_recursive.py", line 292, in create_methods_from_stubs
    concrete_type._create_methods(defs, rcbs, defaults)
RuntimeError: it != all_nodes.end() INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/jit/ir/ir.cpp":1806, please report a bug to PyTorch.

Expected behavior

As I understand it, a TorchScript function whose return type is annotated as Any should be able to return either a torch.Tensor or a List[torch.Tensor].

Environment

PyTorch nightly build

Additional context

One of the DPER3 modules (https://fburl.com/diffusion/bkbuoid2) used by the PyPer Ads ctr_mbl_feed model returns either a torch.Tensor or a List[torch.Tensor], depending on a constructor argument. Without Any support here, we would need to split the module into two modules (one per return type), which would cause a lot of confusion and downstream refactoring and slow down progress on scripting PyPer Ads models.
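In the meantime, one way to keep a single module is to give both configurations the same static return type. The sketch below is a hypothetical workaround, not something proposed in this report: it always returns List[torch.Tensor] and wraps the lone tensor in a one-element list when a constructor flag asks for a single output. The class name ListOutputModule and the flag single_output are illustrative assumptions standing in for the real DPER3 module and its constructor argument.

```python
from typing import List

import torch


class ListOutputModule(torch.nn.Module):
    # Hypothetical workaround: the return type is always List[torch.Tensor],
    # so TorchScript never has to unify Tensor with List[Tensor].
    # `single_output` is an illustrative stand-in for the real constructor argument.
    def __init__(self, single_output: bool):
        super().__init__()
        self.single_output = single_output

    def forward(self, input: torch.Tensor) -> List[torch.Tensor]:
        ret: List[torch.Tensor] = []
        if self.single_output:
            # Wrap the single tensor so both branches share one static type.
            ret.append(input)
        else:
            ret.append(input)
            ret.append(input * 2)
        return ret


single = torch.jit.script(ListOutputModule(single_output=True))
multi = torch.jit.script(ListOutputModule(single_output=False))

x = torch.ones(2)
print(len(single(x)), len(multi(x)))  # prints "1 2"
```

The trade-off is that callers of the single-output configuration must unwrap element 0 of the list, but the module scripts without any Any annotation.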

cc @SplitInfinity @wanchaol @suo

cc @suo @gmagogsfm

Labels: oncall: jit
