
torch.compile dynamo fails indexing into array from internal mutable state #123535

Open
MatthewCaseres opened this issue Apr 8, 2024 · 1 comment
Labels
high priority module: dynamic shapes module: dynamo oncall: pt2 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@MatthewCaseres (Contributor) commented Apr 8, 2024

🐛 Describe the bug

The following code reproduces the issue:

import torch
import torch.nn as nn
import logging

class CompiledClass(nn.Module):
    def __init__(self):
        super().__init__()
        self.nums = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
        self.t = 5
    
    def forward(self):
        self.num = self.nums[self.t//12]
        self.t += 1
        return self.num
    
m = CompiledClass()
m = torch.compile(m, backend="eager")

torch._logging.set_logs(dynamo = logging.DEBUG)
torch._dynamo.config.verbose = True

# the first call works
m()
# the second call causes a failure
m()
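For reference, the index used by `forward` is the same on both calls, so eager execution is fine both times; the failure is entirely on the tracing side, where the mutated `self.t` is promoted to a symbolic integer (`s0`) on the second call (see the `create_symbol s0 = 6` line in the logs below). A quick plain-Python sanity check:

```python
# Plain Python, no torch: the index expression self.t // 12 evaluates
# to 0 for both the first call (t = 5) and the second call (t = 6),
# so eager execution returns nums[0] == 1 both times.
indices = [t // 12 for t in (5, 6)]
print(indices)  # [0, 0]
```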

Error logs

...
[2024-04-08 00:21:47,967] [0/0_1] torch._dynamo.symbolic_convert: [DEBUG] RETURN_VALUE triggered compile
[2024-04-08 00:21:47,967] [0/0_1] torch._dynamo.output_graph: [DEBUG] COMPILING GRAPH due to GraphCompileReason(reason='return_value', user_stack=[<FrameSummary file /workspaces/torch-indexing-reproduction/main.py, line 14 in forward>], graph_break=False)
[2024-04-08 00:21:47,968] [0/0_1] torch._dynamo.output_graph.__graph_code: [DEBUG] TRACED GRAPH
[2024-04-08 00:21:47,968] [0/0_1] torch._dynamo.output_graph.__graph_code: [DEBUG]  ===== __compiled_fn_0 =====
[2024-04-08 00:21:47,968] [0/0_1] torch._dynamo.output_graph.__graph_code: [DEBUG]  <eval_with_key>.0 class GraphModule(torch.nn.Module):
[2024-04-08 00:21:47,968] [0/0_1] torch._dynamo.output_graph.__graph_code: [DEBUG]     def forward(self, L_self_nums : torch.Tensor):
[2024-04-08 00:21:47,968] [0/0_1] torch._dynamo.output_graph.__graph_code: [DEBUG]         l_self_nums = L_self_nums
[2024-04-08 00:21:47,968] [0/0_1] torch._dynamo.output_graph.__graph_code: [DEBUG]         
[2024-04-08 00:21:47,968] [0/0_1] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: /workspaces/torch-indexing-reproduction/main.py:12, code: self.num = self.nums[self.t//12]
[2024-04-08 00:21:47,968] [0/0_1] torch._dynamo.output_graph.__graph_code: [DEBUG]         getitem = l_self_nums[0];  l_self_nums = None
[2024-04-08 00:21:47,968] [0/0_1] torch._dynamo.output_graph.__graph_code: [DEBUG]         return (getitem,)
[2024-04-08 00:21:47,968] [0/0_1] torch._dynamo.output_graph.__graph_code: [DEBUG]         
[2024-04-08 00:21:47,968] [0/0_1] torch._dynamo.output_graph.__graph_code: [DEBUG] 
[2024-04-08 00:21:47,969] [0/0_1] torch._dynamo.output_graph.__graph: [DEBUG] Tabulate module missing, please install tabulate to log the graph in tabular format, logging code instead:
[2024-04-08 00:21:47,969] [0/0_1] torch._dynamo.output_graph.__graph: [DEBUG] TRACED GRAPH
[2024-04-08 00:21:47,969] [0/0_1] torch._dynamo.output_graph.__graph: [DEBUG]  ===== __compiled_fn_0 =====
[2024-04-08 00:21:47,969] [0/0_1] torch._dynamo.output_graph.__graph: [DEBUG]  <eval_with_key>.0 class GraphModule(torch.nn.Module):
[2024-04-08 00:21:47,969] [0/0_1] torch._dynamo.output_graph.__graph: [DEBUG]     def forward(self, L_self_nums : torch.Tensor):
[2024-04-08 00:21:47,969] [0/0_1] torch._dynamo.output_graph.__graph: [DEBUG]         l_self_nums = L_self_nums
[2024-04-08 00:21:47,969] [0/0_1] torch._dynamo.output_graph.__graph: [DEBUG]         
[2024-04-08 00:21:47,969] [0/0_1] torch._dynamo.output_graph.__graph: [DEBUG]         # File: /workspaces/torch-indexing-reproduction/main.py:12, code: self.num = self.nums[self.t//12]
[2024-04-08 00:21:47,969] [0/0_1] torch._dynamo.output_graph.__graph: [DEBUG]         getitem = l_self_nums[0];  l_self_nums = None
[2024-04-08 00:21:47,969] [0/0_1] torch._dynamo.output_graph.__graph: [DEBUG]         return (getitem,)
[2024-04-08 00:21:47,969] [0/0_1] torch._dynamo.output_graph.__graph: [DEBUG]         
[2024-04-08 00:21:47,969] [0/0_1] torch._dynamo.output_graph.__graph: [DEBUG] 
[2024-04-08 00:21:47,969] [0/0_1] torch._dynamo.output_graph.__graph_sizes: [DEBUG] TRACED GRAPH TENSOR SIZES
[2024-04-08 00:21:47,969] [0/0_1] torch._dynamo.output_graph.__graph_sizes: [DEBUG] ===== __compiled_fn_0 =====
[2024-04-08 00:21:47,969] [0/0_1] torch._dynamo.output_graph.__graph_sizes: [DEBUG] l_self_nums: (10,)
[2024-04-08 00:21:47,969] [0/0_1] torch._dynamo.output_graph.__graph_sizes: [DEBUG] getitem: ()
[2024-04-08 00:21:47,969] [0/0_1] torch._dynamo.output_graph.__graph_sizes: [DEBUG] 
[2024-04-08 00:21:47,970] [0/0_1] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function eager
[2024-04-08 00:21:47,970] [0/0_1] torch._dynamo.output_graph: [INFO] Step 2: done compiler function eager
[2024-04-08 00:21:48,034] [0/0_1] torch.fx.experimental.symbolic_shapes: [INFO] produce_guards
[2024-04-08 00:21:48,034] [0/0_1] torch.fx.experimental.symbolic_shapes: [DEBUG] track_symint L['self'].nums.size()[0] 10 None
[2024-04-08 00:21:48,034] [0/0_1] torch.fx.experimental.symbolic_shapes: [DEBUG] track_symint L['self'].nums.stride()[0] 1 None
[2024-04-08 00:21:48,034] [0/0_1] torch.fx.experimental.symbolic_shapes: [DEBUG] track_symint L['self'].nums.storage_offset() 0 None
[2024-04-08 00:21:48,034] [0/0_1] torch.fx.experimental.symbolic_shapes: [DEBUG] Skipping guard L['self'].nums.size()[0] == 10
[2024-04-08 00:21:48,034] [0/0_1] torch.fx.experimental.symbolic_shapes: [DEBUG] Skipping guard L['self'].nums.stride()[0] == 1
[2024-04-08 00:21:48,035] [0/0_1] torch.fx.experimental.symbolic_shapes: [DEBUG] Skipping guard L['self'].nums.storage_offset() == 0
[2024-04-08 00:21:48,035] [0/0_1] torch._dynamo.guards.__guards: [DEBUG] GUARDS:
[2024-04-08 00:21:48,035] [0/0_1] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['self'], 94819259526288)                   # self.num = self.nums[self.t//12]  # orkspaces/torch-indexing-reproduction/main.py:12 in forward
[2024-04-08 00:21:48,035] [0/0_1] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['self'].t, 140484762481376)                # self.num = self.nums[self.t//12]  # orkspaces/torch-indexing-reproduction/main.py:12 in forward
[2024-04-08 00:21:48,036] [0/0_1] torch._dynamo.guards.__guards: [DEBUG] L['self'].t == 5                                              # self.num = self.nums[self.t//12]  # orkspaces/torch-indexing-reproduction/main.py:12 in forward
[2024-04-08 00:21:48,036] [0/0_1] torch._dynamo.guards.__guards: [DEBUG] hasattr(L['self'].nums, '_dynamo_dynamic_indices') == False   # self.num = self.nums[self.t//12]  # orkspaces/torch-indexing-reproduction/main.py:12 in forward
[2024-04-08 00:21:48,036] [0/0_1] torch._dynamo.guards.__guards: [DEBUG] utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:379 in init_ambient_guards
[2024-04-08 00:21:48,037] [0/0_1] torch._dynamo.guards.__guards: [DEBUG] (___skip_backend_check() or ___current_backend() == ___lookup_backend(140481868239056))  # _dynamo/output_graph.py:385 in init_ambient_guards
[2024-04-08 00:21:48,037] [0/0_1] torch._dynamo.guards.__guards: [DEBUG] ___compile_config_hash() == '88a14d47e62622e2d97d70c8d06ad8bd'  # _dynamo/output_graph.py:387 in init_ambient_guards
[2024-04-08 00:21:48,037] [0/0_1] torch._dynamo.guards.__guards: [DEBUG] check_tensor(L['self'].nums, Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.int64, device=None, requires_grad=False, size=[10], stride=[1])  # self.num = self.nums[self.t//12]  # orkspaces/torch-indexing-reproduction/main.py:12 in forward
[2024-04-08 00:21:48,043] torch._dynamo.eval_frame: [DEBUG] skipping: _fn (reason: in skipfiles, file: /home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py)
[2024-04-08 00:21:48,043] torch._dynamo.eval_frame: [DEBUG] skipping: is_tracing (reason: in skipfiles, file: /home/vscode/.local/lib/python3.11/site-packages/torch/jit/_trace.py)
[2024-04-08 00:21:48,044] torch._dynamo.eval_frame: [DEBUG] skipping: is_scripting (reason: in skipfiles, file: /home/vscode/.local/lib/python3.11/site-packages/torch/_jit_internal.py)
[2024-04-08 00:21:48,044] torch._dynamo.eval_frame: [DEBUG] skipping: nothing (reason: in skipfiles, file: /home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py)
[2024-04-08 00:21:48,044] torch._dynamo.eval_frame: [DEBUG] skipping: __exit__ (reason: in skipfiles, file: /usr/local/lib/python3.11/contextlib.py)
[2024-04-08 00:21:48,044] torch._dynamo.eval_frame: [DEBUG] skipping: __exit__ (reason: in skipfiles, file: /usr/local/lib/python3.11/contextlib.py)
[2024-04-08 00:21:48,046] torch._dynamo.eval_frame: [DEBUG] skipping: __setattr__ (reason: in skipfiles, file: /home/vscode/.local/lib/python3.11/site-packages/torch/nn/modules/module.py)
[2024-04-08 00:21:48,046] torch._dynamo.eval_frame: [DEBUG] skipping: __instancecheck__ (reason: in skipfiles, file: /home/vscode/.local/lib/python3.11/site-packages/torch/nn/parameter.py)
[2024-04-08 00:21:48,046] torch._dynamo.eval_frame: [DEBUG] Unsetting top-level compile config hash: 88a14d47e62622e2d97d70c8d06ad8bd
[2024-04-08 00:21:48,046] torch._dynamo.eval_frame: [DEBUG] Setting top-level compile config hash: 88a14d47e62622e2d97d70c8d06ad8bd
[2024-04-08 00:21:48,047] [0/1] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward /workspaces/torch-indexing-reproduction/main.py:11
[2024-04-08 00:21:48,048] [0/1] torch.fx.experimental.symbolic_shapes: [INFO] create_env
[2024-04-08 00:21:48,048] [0/1] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /workspaces/torch-indexing-reproduction/main.py:11 in forward (CompiledClass.forward)
[2024-04-08 00:21:48,048] [0/1] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]         def forward(self):
[2024-04-08 00:21:48,048] [0/1] torch._dynamo.symbolic_convert: [DEBUG] TRACE RESUME 0 []
[2024-04-08 00:21:48,049] [0/1] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /workspaces/torch-indexing-reproduction/main.py:12 in forward (CompiledClass.forward)
[2024-04-08 00:21:48,049] [0/1] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]             self.num = self.nums[self.t//12]
[2024-04-08 00:21:48,049] [0/1] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST self []
[2024-04-08 00:21:48,049] [0/1] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_ATTR nums [LazyVariableTracker()]
[2024-04-08 00:21:48,050] [0/1] torch._dynamo.output_graph: [DEBUG] create_graph_input L_self_nums L['self'].nums
[2024-04-08 00:21:48,051] [0/1] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['self'].nums (10,) [<DimDynamic.STATIC: 2>] [None]
[2024-04-08 00:21:48,052] [0/1] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST self [TensorVariable()]
[2024-04-08 00:21:48,052] [0/1] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_ATTR t [TensorVariable(), UnspecializedNNModuleVariable()]
[2024-04-08 00:21:48,052] [0/1] torch._dynamo.variables.builder: [DEBUG] automatic dynamic int L['self'].t val 6 != 5
[2024-04-08 00:21:48,052] [0/1] torch.fx.experimental.symbolic_shapes: [INFO] create_symbol s0 = 6 for L['self'].t [-9223372036854775808, 9223372036854775807]
[2024-04-08 00:21:48,053] [0/1] torch._dynamo.output_graph: [DEBUG] create_graph_input L_self_t L['self'].t
[2024-04-08 00:21:48,053] [0/1] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CONST 12 [TensorVariable(), SymNodeVariable()]
[2024-04-08 00:21:48,053] [0/1] torch._dynamo.symbolic_convert: [DEBUG] TRACE BINARY_OP 2 [TensorVariable(), SymNodeVariable(), ConstantVariable(int)]
[2024-04-08 00:21:48,053] [0/1] torch._dynamo.output_graph.__trace_call: [DEBUG] TRACE FX call floordiv from /workspaces/torch-indexing-reproduction/main.py:12 in forward (CompiledClass.forward)
[2024-04-08 00:21:48,053] [0/1] torch._dynamo.output_graph.__trace_call: [DEBUG]         self.num = self.nums[self.t//12]
[2024-04-08 00:21:48,053] [0/1] torch._dynamo.output_graph.__trace_call: [DEBUG]                              ~~~~~~^^~~
[2024-04-08 00:21:48,055] [0/1] torch._dynamo.symbolic_convert: [DEBUG] TRACE BINARY_SUBSCR None [TensorVariable(), SymNodeVariable()]
[2024-04-08 00:21:48,055] [0/1] torch._dynamo.output_graph.__trace_call: [DEBUG] TRACE FX call select from /workspaces/torch-indexing-reproduction/main.py:12 in forward (CompiledClass.forward)
[2024-04-08 00:21:48,055] [0/1] torch._dynamo.output_graph.__trace_call: [DEBUG]         self.num = self.nums[self.t//12]
[2024-04-08 00:21:48,055] [0/1] torch._dynamo.output_graph.__trace_call: [DEBUG]                    ~~~~~~~~~^^^^^^^^^^^^
[2024-04-08 00:21:48,099] [0/1] torch.fx.experimental.symbolic_shapes: [INFO] eval -(s0//12) <= 10 [guard added] at orkspaces/torch-indexing-reproduction/main.py:12 in forward (_meta_registrations.py:4831 in meta_select)
[2024-04-08 00:21:48,101] [0/1] torch.fx.experimental.symbolic_shapes: [DEBUG] eval (s0//12) >= 10 == False [statically known]
[2024-04-08 00:21:48,102] [0/1] torch.fx.experimental.symbolic_shapes: [DEBUG] eval (s0//12) >= 0 == False [statically known]
[2024-04-08 00:21:48,103] torch._dynamo.eval_frame: [DEBUG] Unsetting top-level compile config hash: 88a14d47e62622e2d97d70c8d06ad8bd
Traceback (most recent call last):
  File "/workspaces/torch-indexing-reproduction/main.py", line 25, in <module>
    m()
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 655, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 727, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
    compiled_product = _compile(
                       ^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 646, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 562, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 151, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 527, in transform
    tracer.run()
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2128, in run
    super().run()
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
        ^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 249, in impl
    self.push(fn_var.call_function(self, self.popn(nargs), {}))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/variables/builtin.py", line 594, in call_function
    return wrap_fx_proxy(tx, proxy)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 1314, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 1399, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1525, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1486, in get_fake_value
    ret_val = wrap_fake_exception(
              ^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1027, in wrap_fake_exception
    return fn()
           ^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1487, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1592, in run_node
    raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1571, in run_node
    return node.target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1392, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1712, in dispatch
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_ops.py", line 513, in __call__
    return self._op(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/_meta_registrations.py", line 4836, in meta_select
    index = index if index >= 0 else index + size
                     ^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/__init__.py", line 365, in __bool__
    return self.node.bool_()
           ^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/fx/experimental/sym_node.py", line 392, in bool_
    return self.guard_bool("", 0)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/fx/experimental/sym_node.py", line 358, in guard_bool
    r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/fx/experimental/recording.py", line 226, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 3575, in evaluate_expr
    assert static_expr == hint, f"{static_expr} != {hint}"
           ^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in method select of type object at 0x7fc52646b8a0>(*(FakeTensor(..., size=(10,), dtype=torch.int64), 0, (s0//12)), **{}):
False != True

from user code:
   File "/workspaces/torch-indexing-reproduction/main.py", line 12, in forward
    self.num = self.nums[self.t//12]


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

[2024-04-08 00:21:48,118] torch._dynamo.utils: [INFO] TorchDynamo compilation metrics:
[2024-04-08 00:21:48,118] torch._dynamo.utils: [INFO] Function, Runtimes (s)
[2024-04-08 00:21:48,118] torch._dynamo.utils: [INFO] _compile.<locals>.compile_inner, 0.0934
[2024-04-08 00:21:48,118] torch._dynamo.utils: [INFO] OutputGraph.call_user_compiler, 0.0002

Minified repro

torch._dynamo.debug_utils did not produce a minified repro

Versions

Collecting environment information...
PyTorch version: 2.2.2+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 12 (bookworm) (x86_64)
GCC version: (Debian 12.2.0-14) 12.2.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.36

Python version: 3.11.8 (main, Mar 12 2024, 11:41:52) [GCC 12.2.0] (64-bit runtime)
Python platform: Linux-6.2.0-1019-azure-x86_64-with-glibc2.36
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 2
On-line CPU(s) list: 0,1
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7763 64-Core Processor
CPU family: 25
Model: 1
Thread(s) per core: 2
Core(s) per socket: 1
Socket(s): 1
Stepping: 1
BogoMIPS: 4890.84
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext invpcid_single vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves clzero xsaveerptr rdpru arat npt nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload umip vaes vpclmulqdq rdpid fsrm
Virtualization: AMD-V
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 32 KiB (1 instance)
L1i cache: 32 KiB (1 instance)
L2 cache: 512 KiB (1 instance)
L3 cache: 32 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0,1
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET, no microcode
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.2.2
[pip3] triton==2.2.0
[conda] Could not collect

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78

@shunting314 added the module: dynamo and triaged labels on Apr 8, 2024
@ezyang (Contributor) commented Apr 9, 2024

I don't know what your real code looks like, but as a workaround you can pass dynamic=False. However, we should fix this bug.
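A minimal sketch of that workaround applied to the repro above (assuming the same eager backend as in the report): `dynamic=False` disables automatic dynamic shapes, so Dynamo specializes on the current value of `self.t` and recompiles when the guard fails, rather than promoting `self.t` to a symbolic integer.

```python
import torch
import torch.nn as nn

class CompiledClass(nn.Module):
    def __init__(self):
        super().__init__()
        self.nums = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
        self.t = 5

    def forward(self):
        self.num = self.nums[self.t // 12]
        self.t += 1
        return self.num

m = CompiledClass()
# dynamic=False: self.t stays a specialized constant; the second call
# just fails the `L['self'].t == 5` guard and recompiles, instead of
# creating a symbolic int that trips the meta_select sign check.
m = torch.compile(m, backend="eager", dynamic=False)

first = m()   # t == 5 -> nums[5 // 12] == nums[0]
second = m()  # t == 6 -> nums[6 // 12] == nums[0], no failure
```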
