torch.randint with out= kwarg: SymIntArrayRef expected to contain only concrete integers #121897

Closed
rahulvijayaraghavan opened this issue Mar 14, 2024 · 3 comments
Assignees
Labels
module: dynamic shapes oncall: pt2 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments


rahulvijayaraghavan commented Mar 14, 2024

🐛 Describe the bug

Error "SymIntArrayRef expected to contain only concrete integers" is observed for following repro. Not seen when dynamic=False is set.

Error logs

TorchRuntimeError                         Traceback (most recent call last)
<ipython-input-12-e46ea442705a> in <cell line: 13>()
     11 
     12 out2 = torch.empty(12)
---> 13 opt_model(17, (12, ), out2)
     14 print(out2)

28 frames
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py in _fn(*args, **kwargs)
    487                 dynamo_config_ctx.__enter__()
    488             try:
--> 489                 return fn(*args, **kwargs)
    490             finally:
    491                 set_eval_frame(prior)

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py in catch_errors(frame, cache_entry, frame_state)
    653 
    654         with compile_lock, _disable_current_modes():
--> 655             return callback(frame, cache_entry, hooks, frame_state)
    656 
    657     catch_errors._torchdynamo_orig_callable = callback  # type: ignore[attr-defined]

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in _convert_frame(frame, cache_entry, hooks, frame_state)
    725         counters["frames"]["total"] += 1
    726         try:
--> 727             result = inner_convert(frame, cache_entry, hooks, frame_state)
    728             counters["frames"]["ok"] += 1
    729             return result

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in _convert_frame_assert(frame, cache_entry, hooks, frame_state)
    381 
    382         with config.patch(_patch_config_if_changed()):
--> 383             compiled_product = _compile(
    384                 frame.f_code,
    385                 frame.f_globals,

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, export_constraints, hooks, cache_size, frame, frame_state, compile_id)
    644     with compile_context(CompileContext(compile_id)):
    645         try:
--> 646             guarded_code = compile_inner(code, one_graph, hooks, transform)
    647             return guarded_code
    648         except (

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py in time_wrapper(*args, **kwargs)
    242             with torch.profiler.record_function(f"{key} (dynamo_timed)"):
    243                 t0 = time.time()
--> 244                 r = func(*args, **kwargs)
    245                 time_spent = time.time() - t0
    246             compilation_time_metrics[key].append(time_spent)

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in compile_inner(code, one_graph, hooks, transform)
    560             CompileContext.get().attempt = attempt
    561             try:
--> 562                 out_code = transform_code_object(code, transform)
    563                 break
    564             except exc.RestartAnalysis as e:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py in transform_code_object(code, transformations, safe)
   1031     propagate_line_nums(instructions)
   1032 
-> 1033     transformations(instructions, code_options)
   1034     return clean_and_assemble_instructions(instructions, keys, code_options)[1]
   1035 

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in _fn(*args, **kwargs)
    149         cleanup = setup_compile_debug()
    150         try:
--> 151             return fn(*args, **kwargs)
    152         finally:
    153             cleanup.close()

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in transform(instructions, code_options)
    525         try:
    526             with tracing(tracer.output.tracing_context), tracer.set_current_tx():
--> 527                 tracer.run()
    528         except exc.UnspecializeRestartAnalysis:
    529             speculation_log.clear()

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in run(self)
   2126 
   2127     def run(self):
-> 2128         super().run()
   2129 
   2130     def match_nested_cell(self, name, cell):

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in run(self)
    816                     self.instruction_pointer is not None
    817                     and not self.output.should_exit
--> 818                     and self.step()
    819                 ):
    820                     pass

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in step(self)
    779                 self.f_code.co_filename, self.lineno, self.f_code.co_name
    780             )
--> 781             getattr(self, inst.opname)(inst)
    782 
    783             return inst.opname != "RETURN_VALUE"

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in wrapper(self, inst)
    468                     self.f_code.co_filename, self.lineno, self.f_code.co_name
    469                 )
--> 470                 return inner_fn(self, inst)
    471             except Unsupported as excp:
    472                 if self.should_compile_partial_graph() and self.has_backedge():

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in CALL_FUNCTION_KW(self, inst)
   1262         kwargs = dict(zip(argnames, kwargs_list))
   1263         assert len(kwargs) == len(argnames)
-> 1264         self.call_function(fn, args, kwargs)
   1265 
   1266     def LOAD_METHOD_SUPER(self, inst):

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in call_function(self, fn, args, kwargs)
    650         if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
    651             raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 652         self.push(fn.call_function(self, args, kwargs))
    653 
    654     def update_locals_and_stack(self, oldvar: VariableTracker, newvar: VariableTracker):

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/torch.py in call_function(self, tx, args, kwargs)
    540                     fn_ = torch._refs.tensor
    541 
--> 542             tensor_variable = wrap_fx_proxy(
    543                 tx=tx,
    544                 proxy=tx.output.create_proxy(

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py in wrap_fx_proxy(tx, proxy, example_value, subclass_type, **options)
   1312     }
   1313     if subclass_type is None:
-> 1314         return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
   1315     else:
   1316         result = wrap_fx_proxy_cls(target_cls=TensorWithTFOverrideVariable, **kwargs)

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py in wrap_fx_proxy_cls(target_cls, tx, proxy, example_value, subclass_type, **options)
   1397             # only allow_non_graph_fake in this instance because we handle the non-fake
   1398             # cases properly below.
-> 1399             example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
   1400 
   1401         # Handle recursive calls here

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py in get_fake_value(node, tx, allow_non_graph_fake)
   1523         elif isinstance(cause, ValueRangeError):
   1524             raise UserError(UserErrorType.CONSTRAINT_VIOLATION, e.args[0]) from e
-> 1525         raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
   1526 
   1527     if not allow_non_graph_fake:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py in get_fake_value(node, tx, allow_non_graph_fake)
   1484     try:
   1485         with tx.fake_mode, enable_python_dispatcher():
-> 1486             ret_val = wrap_fake_exception(
   1487                 lambda: run_node(tx.output, node, args, kwargs, nnmodule)
   1488             )

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py in wrap_fake_exception(fn)
   1025 def wrap_fake_exception(fn):
   1026     try:
-> 1027         return fn()
   1028     except UnsupportedFakeTensorException as e:
   1029         from .exc import unimplemented

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py in <lambda>()
   1485         with tx.fake_mode, enable_python_dispatcher():
   1486             ret_val = wrap_fake_exception(
-> 1487                 lambda: run_node(tx.output, node, args, kwargs, nnmodule)
   1488             )
   1489     except Unsupported:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py in run_node(tracer, node, args, kwargs, nnmodule)
   1590         except Exception as e:
   1591             fn_str = f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n"
-> 1592             raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
   1593 
   1594     raise AssertionError(op)

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py in run_node(tracer, node, args, kwargs, nnmodule)
   1569         try:
   1570             if op == "call_function":
-> 1571                 return node.target(*args, **kwargs)
   1572             elif op == "call_method":
   1573                 return getattr(args[0], node.target)(*args[1:], **kwargs)

/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py in wrapper(*args, **kwargs)
     18             simple_call_counter[fn.__qualname__] = 0
     19         simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1
---> 20         return fn(*args, **kwargs)
     21     return wrapper

/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py in __torch_dispatch__(self, func, types, args, kwargs)
   1390         ), func
   1391         try:
-> 1392             return self.dispatch(func, types, args, kwargs)
   1393         except TypeError:
   1394             log.exception("fake tensor raised TypeError")

/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py in dispatch(self, func, types, args, kwargs)
   1710         try:
   1711             with in_kernel_invocation_manager(self):
-> 1712                 r = func(*args, **kwargs)
   1713         except NotImplementedError as not_implemented_error:
   1714             return maybe_run_unsafe_fallback(not_implemented_error)

/usr/local/lib/python3.10/dist-packages/torch/_ops.py in __call__(self, *args, **kwargs)
    511 
    512     def __call__(self, *args, **kwargs):
--> 513         return self._op(*args, **(kwargs or {}))
    514 
    515     def __hash__(self):

TorchRuntimeError: Failed running call_function <built-in method randint of type object at 0x7da562a85840>(*(17, (s0,)), **{'out': FakeTensor(..., size=(s1,))}):
aten/src/ATen/RegisterCompositeExplicitAutograd.cpp:3470: SymIntArrayRef expected to contain only concrete integers

from user code:
   File "<ipython-input-12-e46ea442705a>", line 4, in randint_fn
    return torch.randint(high, size, out=out)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Minified repro

import torch

def randint_fn(high, size, out):
    return torch.randint(high, size, out=out)

opt_model = torch.compile(randint_fn)

# First call: compiles a graph specialized for size (10,).
out1 = torch.empty(10)
opt_model(17, (10, ), out1)
print(out1)

# Second call with a different size: automatic dynamic shapes kick in on
# recompilation, and the out= variant fails with the SymIntArrayRef error.
out2 = torch.empty(12)
opt_model(17, (12, ), out2)
print(out2)
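
As noted in the description, the error is not seen when dynamic=False is set. A minimal sketch of that workaround (names here are illustrative): forcing static shapes makes the second size trigger a fresh specialized compile instead of a dynamic one.

opt_model_static = torch.compile(randint_fn, dynamic=False)

out3 = torch.empty(12)
opt_model_static(17, (12, ), out3)  # specializes for size 12; no symbolic sizes involved
print(out3)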

Versions

Collecting environment information...
PyTorch version: 2.2.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.27.9
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.1.58+-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: 12.2.140
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.6
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang


ezyang commented Mar 15, 2024

This is out= suckage. I'm not sure what you're trying to do, but if you can avoid dynamic=True (use automatic dynamic instead), that can help. Or rewrite your code not to use out.

@rahulvijayaraghavan

@ezyang
The issue is observed in default mode (where dynamic=None).


ezyang commented Mar 15, 2024

Oops, try the "rewrite code not to use out" suggestion then.
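
A minimal sketch of that rewrite, assuming the caller can take the returned tensor instead of filling a preallocated buffer:

import torch

def randint_fn(high, size):
    # Let randint allocate its own result; this avoids the out= path that
    # trips the SymIntArrayRef check under dynamic shapes.
    # Note: without out=, the result dtype defaults to int64; pass dtype=...
    # explicitly if the original float buffer semantics mattered.
    return torch.randint(high, size)

opt_model = torch.compile(randint_fn)
out1 = opt_model(17, (10, ))
out2 = opt_model(17, (12, ))  # second size no longer goes through out=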

@anijain2305 anijain2305 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Mar 18, 2024
@ezyang ezyang changed the title from "torch.randint: SymIntArrayRef expected to contain only concrete integers" to "torch.randint with out= kwarg: SymIntArrayRef expected to contain only concrete integers" Mar 19, 2024
@ezyang ezyang self-assigned this Mar 21, 2024
ezyang added a commit that referenced this issue Mar 21, 2024
Fixes #121897

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 8594a8b8d2b21b308f7a7560459c1dc6ad7b478e
Pull Request resolved: #122375
pytorchmergebot pushed a commit that referenced this issue Mar 24, 2024
Fixes #121897

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 23a4702ed3929acfb7ed8e51613408f584952cb9
Pull Request resolved: #122375
pytorch-bot bot pushed a commit that referenced this issue Apr 22, 2024
Fixes #121897

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: #122375
Approved by: https://github.com/lezcano