
TorchScript does not allow accessing methods of nested tensors #156544

@PierreGtch

Description


🐛 Describe the bug

TorchScript fails when code accesses methods specific to nested tensors, such as `offsets()`. Both `torch.jit.script` and `torch.jit.trace` raise errors:

Minimal example

```python
import torch

def f(x):
    return x.offsets()

nt = torch.nested.as_nested_tensor(torch.randn(2, 3), layout=torch.jagged)

f(nt)  # tensor([0, 3, 6])

torch.jit.script(f)  # ERROR 1

torch.jit.trace(f, example_inputs=nt)  # ERROR 2
```

ERROR 1 (script)

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[16], line 1
----> 1 torch.jit.script(f)

File site-packages/torch/jit/_script.py:1443, in script(obj, optimize, _frames_up, _rcb, example_inputs)
   1441 prev = _TOPLEVEL
   1442 _TOPLEVEL = False
-> 1443 ret = _script_impl(
   1444     obj=obj,
   1445     optimize=optimize,
   1446     _frames_up=_frames_up + 1,
   1447     _rcb=_rcb,
   1448     example_inputs=example_inputs,
   1449 )
   1451 if prev:
   1452     log_torchscript_usage("script", model_id=_get_model_id(ret))

File site-packages/torch/jit/_script.py:1214, in _script_impl(obj, optimize, _frames_up, _rcb, example_inputs)
   1212 if _rcb is None:
   1213     _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
-> 1214 fn = torch._C._jit_script_compile(
   1215     qualified_name, ast, _rcb, get_default_args(obj)
   1216 )
   1217 # Forward docstrings
   1218 fn.__doc__ = obj.__doc__

RuntimeError:
'Tensor (inferred)' object has no attribute or method 'offsets'.:
  File "<ipython-input-12-39aeca0b7bc5>", line 2
def f(x):
    return x.offsets(), x
           ~~~~~~~~~ <--- HERE

ERROR 2 (trace)

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[17], line 1
----> 1 torch.jit.trace(f, example_inputs=nt)

File site-packages/torch/jit/_trace.py:1002, in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_kwarg_inputs, _store_inputs)
    989     warnings.warn(
    990         "`optimize` is deprecated and has no effect. "
    991         "Use `with torch.jit.optimized_execution()` instead",
    992         FutureWarning,
    993         stacklevel=2,
    994     )
    996 from torch._utils_internal import (
    997     check_if_torch_exportable,
    998     log_torch_jit_trace_exportability,
    999     log_torchscript_usage,
   1000 )
-> 1002 traced_func = _trace_impl(
   1003     func,
   1004     example_inputs,
   1005     optimize,
   1006     check_trace,
   1007     check_inputs,
   1008     check_tolerance,
   1009     strict,
   1010     _force_outplace,
   1011     _module_class,
   1012     _compilation_unit,
   1013     example_kwarg_inputs,
   1014     _store_inputs,
   1015 )
   1016 log_torchscript_usage("trace", model_id=_get_model_id(traced_func))
   1018 if check_if_torch_exportable():

File site-packages/torch/jit/_trace.py:764, in _trace_impl(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_kwarg_inputs, _store_inputs)
    754     traced = torch._C._create_function_from_trace_with_dict(
    755         name,
    756         func,
   (...)
    761         get_callable_argument_names(func),
    762     )
    763 else:
--> 764     traced = torch._C._create_function_from_trace(
    765         name,
    766         func,
    767         example_inputs,
    768         var_lookup_fn,
    769         strict,
    770         _force_outplace,
    771         get_callable_argument_names(func),
    772     )
    774 # Check the trace against new traces created from user-specified inputs
    775 if check_trace:

RuntimeError: output 1 ( 0
 3
 6
[ CPULongType{3} ]) of traced region did not have observable data dependence with trace inputs; this probably indicates your program cannot be understood by the tracer.

Versions

Collecting environment information...
PyTorch version: 2.7.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.5 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.0.13.5)
CMake version: version 3.22.4
Libc version: N/A

Python version: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:34:54) [Clang 16.0.6 ] (64-bit runtime)
Python platform: macOS-15.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pytorch-lightning==2.2.2
[pip3] pytorch-metric-learning==2.8.1
[pip3] torch==2.7.1
[pip3] torchaudio==2.7.1
[pip3] torchdata==0.7.1
[pip3] torchinfo==1.8.0
[pip3] torchmetrics==1.3.2
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] pytorch-lightning         2.2.2                    pypi_0    pypi
[conda] pytorch-metric-learning   2.8.1                    pypi_0    pypi
[conda] torch                     2.7.1                    pypi_0    pypi
[conda] torchaudio                2.7.1                    pypi_0    pypi
[conda] torchdata                 0.7.1                    pypi_0    pypi
[conda] torchinfo                 1.8.0                    pypi_0    pypi
[conda] torchmetrics              1.3.2                    pypi_0    pypi

cc @EikanWang @jgong5 @wenzhe-nrv @sanchitintel @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ

Labels: module: nestedtensor, oncall: jit
