Open
Labels: module: nestedtensor (NestedTensor tag, see issue #25032), oncall: jit
Description
🐛 Describe the bug
TorchScript seems to have trouble accessing the methods that are specific to nested tensors, such as offsets():
Minimal example
import torch

def f(x):
    return x.offsets()

nt = torch.nested.as_nested_tensor(torch.randn(2, 3), layout=torch.jagged)
f(nt)  # tensor([0, 3, 6])
torch.jit.script(f)  # ERROR 1
torch.jit.trace(f, example_inputs=nt)  # ERROR 2
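For context on the eager result tensor([0, 3, 6]): a jagged nested tensor stores its components back to back in one packed values buffer, and offsets() marks where each component starts and ends, so it has one more entry than there are components. The plain-Python sketch below reproduces that arithmetic; the helper name jagged_offsets is hypothetical and not part of the PyTorch API.

```python
from itertools import accumulate


def jagged_offsets(lengths):
    """Hypothetical helper: offsets for components with the given lengths.

    Component i occupies positions offsets[i]:offsets[i + 1] of the packed
    values buffer, so the result always starts at 0 and has len(lengths) + 1 entries.
    """
    return [0] + list(accumulate(lengths))


# torch.randn(2, 3) viewed as a jagged nested tensor has 2 components of length 3
print(jagged_offsets([3, 3]))     # [0, 3, 6]
print(jagged_offsets([2, 4, 1]))  # [0, 2, 6, 7]
```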
ERROR 1 (script)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[16], line 1
----> 1 torch.jit.script(f)
File site-packages/torch/jit/_script.py:1443, in script(obj, optimize, _frames_up, _rcb, example_inputs)
1441 prev = _TOPLEVEL
1442 _TOPLEVEL = False
-> 1443 ret = _script_impl(
1444 obj=obj,
1445 optimize=optimize,
1446 _frames_up=_frames_up + 1,
1447 _rcb=_rcb,
1448 example_inputs=example_inputs,
1449 )
1451 if prev:
1452 log_torchscript_usage("script", model_id=_get_model_id(ret))
File site-packages/torch/jit/_script.py:1214, in _script_impl(obj, optimize, _frames_up, _rcb, example_inputs)
1212 if _rcb is None:
1213 _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
-> 1214 fn = torch._C._jit_script_compile(
1215 qualified_name, ast, _rcb, get_default_args(obj)
1216 )
1217 # Forward docstrings
1218 fn.__doc__ = obj.__doc__
RuntimeError:
'Tensor (inferred)' object has no attribute or method 'offsets'.:
File "<ipython-input-12-39aeca0b7bc5>", line 2
def f(x):
return x.offsets(), x
~~~~~~~~~ <--- HERE
ERROR 2 (trace)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[17], line 1
----> 1 torch.jit.trace(f, example_inputs=nt)
File site-packages/torch/jit/_trace.py:1002, in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_kwarg_inputs, _store_inputs)
989 warnings.warn(
990 "`optimize` is deprecated and has no effect. "
991 "Use `with torch.jit.optimized_execution()` instead",
992 FutureWarning,
993 stacklevel=2,
994 )
996 from torch._utils_internal import (
997 check_if_torch_exportable,
998 log_torch_jit_trace_exportability,
999 log_torchscript_usage,
1000 )
-> 1002 traced_func = _trace_impl(
1003 func,
1004 example_inputs,
1005 optimize,
1006 check_trace,
1007 check_inputs,
1008 check_tolerance,
1009 strict,
1010 _force_outplace,
1011 _module_class,
1012 _compilation_unit,
1013 example_kwarg_inputs,
1014 _store_inputs,
1015 )
1016 log_torchscript_usage("trace", model_id=_get_model_id(traced_func))
1018 if check_if_torch_exportable():
File site-packages/torch/jit/_trace.py:764, in _trace_impl(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_kwarg_inputs, _store_inputs)
754 traced = torch._C._create_function_from_trace_with_dict(
755 name,
756 func,
(...)
761 get_callable_argument_names(func),
762 )
763 else:
--> 764 traced = torch._C._create_function_from_trace(
765 name,
766 func,
767 example_inputs,
768 var_lookup_fn,
769 strict,
770 _force_outplace,
771 get_callable_argument_names(func),
772 )
774 # Check the trace against new traces created from user-specified inputs
775 if check_trace:
RuntimeError: output 1 ( 0
3
6
[ CPULongType{3} ]) of traced region did not have observable data dependence with trace inputs; this probably indicates your program cannot be understood by the tracer.
Versions
Collecting environment information...
PyTorch version: 2.7.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.5 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.0.13.5)
CMake version: version 3.22.4
Libc version: N/A
Python version: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:34:54) [Clang 16.0.6 ] (64-bit runtime)
Python platform: macOS-15.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M1 Pro
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pytorch-lightning==2.2.2
[pip3] pytorch-metric-learning==2.8.1
[pip3] torch==2.7.1
[pip3] torchaudio==2.7.1
[pip3] torchdata==0.7.1
[pip3] torchinfo==1.8.0
[pip3] torchmetrics==1.3.2
[conda] numpy 1.26.4 pypi_0 pypi
[conda] pytorch-lightning 2.2.2 pypi_0 pypi
[conda] pytorch-metric-learning 2.8.1 pypi_0 pypi
[conda] torch 2.7.1 pypi_0 pypi
[conda] torchaudio 2.7.1 pypi_0 pypi
[conda] torchdata 0.7.1 pypi_0 pypi
[conda] torchinfo 1.8.0 pypi_0 pypi
[conda] torchmetrics 1.3.2 pypi_0 pypi
cc @EikanWang @jgong5 @wenzhe-nrv @sanchitintel @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ