
Key error in index_propagation when looking up dynamic shape vr #127677

Closed
ColinPeppler opened this issue Jun 1, 2024 · 12 comments
Assignees
ColinPeppler
Labels
module: dynamic shapes, oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@ColinPeppler
Contributor

ColinPeppler commented Jun 1, 2024

🐛 Describe the bug

No response

Error logs

  ...
    so_path = torch._inductor.aot_compile(
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/__init__.py", line 104, in aot_compile
    return compile_fx_aot(
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/compile_fx.py", line 1122, in compile_fx_aot
    compiled_lib_path = compile_fx(
  File "/usr/local/fbcode/platform010/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/compile_fx.py", line 1228, in compile_fx
    return compile_fx(
  File "/usr/local/fbcode/platform010/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/compile_fx.py", line 1262, in compile_fx
    return compile_fx(
  File "/usr/local/fbcode/platform010/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/compile_fx.py", line 1466, in compile_fx
    return inference_compiler(unlifted_gm, example_inputs_)
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_dynamo/utils.py", line 218, in time_wrapper
    r = func(*args, **kwargs)
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/compile_fx.py", line 1376, in fw_compiler_base
    return inner_compile(
  File "/usr/local/fbcode/platform010/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_dynamo/repro/after_aot.py", line 83, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/debug.py", line 304, in inner
    return fn(*args, **kwargs)
  File "/usr/local/fbcode/platform010/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_dynamo/utils.py", line 218, in time_wrapper
    r = func(*args, **kwargs)
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/compile_fx.py", line 497, in compile_fx_inner
    compiled_graph = fx_codegen_and_compile(
  File "/usr/local/fbcode/platform010/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/compile_fx.py", line 793, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/graph.py", line 1727, in compile_to_fn
    code, linemap = self.codegen_with_cpp_wrapper()
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/graph.py", line 1563, in codegen_with_cpp_wrapper
    compiled = self.compile_to_module().call
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_dynamo/utils.py", line 218, in time_wrapper
    r = func(*args, **kwargs)
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/graph.py", line 1689, in compile_to_module
    self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/graph.py", line 1645, in codegen
    self.scheduler = Scheduler(self.buffers)
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_dynamo/utils.py", line 218, in time_wrapper
    r = func(*args, **kwargs)
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/scheduler.py", line 1409, in __init__
    self.nodes = [self.create_scheduler_node(n) for n in nodes]
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/scheduler.py", line 1409, in <listcomp>
    self.nodes = [self.create_scheduler_node(n) for n in nodes]
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/scheduler.py", line 1507, in create_scheduler_node
    return SchedulerNode(self, node)
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/scheduler.py", line 781, in __init__
    self._compute_attrs()
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/scheduler.py", line 788, in _compute_attrs
    self._sizes, self._body = self.node.simplify_and_reorder(
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/ir.py", line 3382, in simplify_and_reorder
    ) = self.get_default_sizes_body()
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/utils.py", line 409, in wrapper
    setattr(self, key, fn(self))
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/ir.py", line 3339, in get_default_sizes_body
    body = LoopBody(
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/ir.py", line 7981, in __init__
    self.root_block = LoopBodyBlock(self, fn, args)
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/ir.py", line 8211, in __init__
    ops.output(fn(*args))
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/ir.py", line 1618, in store_reduction
    values = [inner_fn(idx) for inner_fn in self.inner_fns]
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/ir.py", line 1618, in <listcomp>
    values = [inner_fn(idx) for inner_fn in self.inner_fns]
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/lowering.py", line 2914, in fn
    ops.indirect_indexing(
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/virtualized.py", line 283, in indirect_indexing
    return _ops.indirect_indexing(index, size, check)
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/index_propagation.py", line 339, in indirect_indexing
    can_prove_upper = self.statically_true(expr < size)
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/_inductor/index_propagation.py", line 307, in statically_true
    evaluated = self.shape_env._maybe_evaluate_static(
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/fx/experimental/symbolic_shapes.py", line 1446, in wrapper
    return fn_cache(self, *args, **kwargs)
  File "/dev/shm/uid-99/0a1a759c-seed-nspid4026547201_cgpid19114956-ns-4026547025/torch/fx/experimental/symbolic_shapes.py", line 4414, in _maybe_evaluate_static
    vr = var_ranges[k]
KeyError: indirect0
Rethrown from:
KeyError(indirect0)

Minified repro

No response

Versions

n/a

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang

@ColinPeppler
Contributor Author

This error comes from an internal model, so I'm working on getting a minimal repro.

@ColinPeppler
Contributor Author

>>> expr
indirect0
>>> index
IndexPropVar(value=TypedExpr(expr=indirect0, dtype=torch.int64), is_symbolic=True)
>>> index.value
TypedExpr(expr=indirect0, dtype=torch.int64)

@ezyang
Contributor

ezyang commented Jun 1, 2024

@lezcano this is a perfect example of people allocating symbols willy-nilly and not keeping track of what their value ranges are. A bandaid solution that would definitely work is to make maybe_evaluate_static just use an unknown range when it can't find the symbol in var_to_range. We could also study the allocation of INDIRECT symbols in torch/_inductor/ir.py, see if we know anything about their range, and then try to maintain those ranges and pass them as an extra var_to_range argument to maybe_evaluate_static.
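For illustration, a minimal sketch of that bandaid, assuming the ValueRanges type from torch.utils._sympy.value_ranges; the helper and its call site are hypothetical, not the actual _maybe_evaluate_static code:

```
# Hypothetical fallback: if a symbol (e.g. indirect0) has no entry in var_ranges,
# treat its range as unknown instead of raising KeyError.
from torch.utils._sympy.value_ranges import ValueRanges

def lookup_var_range(var_ranges, sym):
    vr = var_ranges.get(sym)
    if vr is None:
        # The symbol was allocated without a tracked range (e.g. an indirect symbol);
        # default to the unbounded (-oo, oo) range so static evaluation can still proceed.
        vr = ValueRanges.unknown()
    return vr
```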

@lezcano
Collaborator

lezcano commented Jun 1, 2024

This one's a bit odd. You shouldn't get an indirect symbol inside a TypedExpr. The whole point of TypedExpr is that it can be expressed in terms of already-known expressions.

ColinPeppler added a commit to ColinPeppler/pytorch that referenced this issue Jun 1, 2024
…nge (pytorch#127681)

Summary:

Issue: pytorch#127677

Test Plan:
ci

---

Differential Revision: D58048558
@ColinPeppler
Contributor Author

This one's a bit odd. You shouldn't get an indirect symbol inside a TypedExpr. The whole point of TypedExpr is that they can be expressed in terms of already known expressions.

I see this; should q0 be replaced by s0 in this case?

>>> body.var_ranges
{q0: s0}

@lezcano
Collaborator

lezcano commented Jun 1, 2024

That is unrelated. The point is that the symbol indirect0 should never have ended up inside a TypedExpr.

@ColinPeppler ColinPeppler self-assigned this Jun 1, 2024
@peterbell10
Collaborator

the symbol indirect0 should have never got to be inside TypedExpr.

This happens when you have something like:

a = ops.indirect_indexing(...)
b = ops.index_expr(a, ...)
c = ops.indirect_indexing(b, ...)

This is a bit weird, but it is valid IR, so we should handle it gracefully.
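For context, a hedged eager-mode sketch (tensor contents and names are made up; this mirrors the minified repro discussed later rather than the internal model) of a program that lowers to that nested pattern:

```
# Sketch only: the result of one data-dependent gather is used to index another
# tensor, which yields indirect_indexing of an index_expr built from a previous
# indirect_indexing when Inductor lowers it.
import torch

def f(repeats, table):
    iota = torch.arange(repeats.numel(), device=repeats.device)
    idx = torch.repeat_interleave(repeats)   # data-dependent output length
    gathered = iota[idx]                     # first indirect indexing
    return table[gathered]                   # second indirect indexing reuses the result

repeats = torch.tensor([1, 2, 3])
table = torch.randn(3, 4)
out = torch.compile(f, dynamic=True)(repeats, table)
```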

@peterbell10
Collaborator

And yeah it would be good if we could use the bound of [0, size) that gets asserted.
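A hedged sketch of what using that bound could look like: record a concrete range when the indirect symbol is created, so later queries such as statically_true(expr < size) have something to work with. The helper below is illustrative only; the merged fix ultimately registers [-size, size - 1], since negative indices are still legal before they get wrapped:

```
# Illustrative only: attach a range to a freshly created indirect symbol.
import sympy
from torch.utils._sympy.value_ranges import ValueRanges

def new_indirect_symbol(name, size, var_to_range):
    sym = sympy.Symbol(name, integer=True)
    # [-size, size - 1]: negative indices are wrapped later, and the device-side
    # assert enforces 0 <= idx < size after wrapping.
    var_to_range[sym] = ValueRanges(-size, size - 1)
    return sym
```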

pytorchmergebot pushed a commit that referenced this issue Jun 2, 2024
#127681)

Purpose of this PR is to get around this error: #127677

Differential Revision: D58048558

Pull Request resolved: #127681
Approved by: https://github.com/lezcano
@zou3519 zou3519 added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Jun 3, 2024
petrex pushed a commit to petrex/pytorch that referenced this issue Jun 5, 2024
pytorch#127681)

Purpose of this PR is to get around this error: pytorch#127677

Differential Revision: D58048558

Pull Request resolved: pytorch#127681
Approved by: https://github.com/lezcano
@bhack
Copy link
Contributor

bhack commented Jun 6, 2024

Compiling this encoder (#121637) with a nightly build, I got:

W0606 13:31:53.266000 136005565806400 torch/fx/experimental/symbolic_shapes.py:4417] [0/0] q0 is not in var_ranges, defaulting to unknown range.
W0606 13:31:53.267000 136005565806400 torch/fx/experimental/symbolic_shapes.py:4417] [0/0] q1 is not in var_ranges, defaulting to unknown range.
W0606 13:31:53.277000 136005565806400 torch/fx/experimental/symbolic_shapes.py:4417] [0/0] q0 is not in var_ranges, defaulting to unknown range.
W0606 13:31:53.277000 136005565806400 torch/fx/experimental/symbolic_shapes.py:4417] [0/0] q1 is not in var_ranges, defaulting to unknown range.
W0606 13:31:53.610000 136005565806400 torch/fx/experimental/symbolic_shapes.py:4417] [0/0] z0 is not in var_ranges, defaulting to unknown range.
W0606 13:31:53.610000 136005565806400 torch/fx/experimental/symbolic_shapes.py:4417] [0/0] z1 is not in var_ranges, defaulting to unknown range.
W0606 13:31:53.620000 136005565806400 torch/fx/experimental/symbolic_shapes.py:4417] [0/0] z0 is not in var_ranges, defaulting to unknown range.
W0606 13:31:53.620000 136005565806400 torch/fx/experimental/symbolic_shapes.py:4417] [0/0] z1 is not in var_ranges, defaulting to unknown range.
W0606 13:31:54.979000 136005565806400 torch/fx/experimental/symbolic_shapes.py:4417] [0/0] q0 is not in var_ranges, defaulting to unknown range.
W0606 13:31:55.114000 136005565806400 torch/fx/experimental/symbolic_shapes.py:4417] [0/0] z0 is not in var_ranges, defaulting to unknown range.
W0606 13:31:57.485000 136005565806400 torch/fx/experimental/symbolic_shapes.py:4417] [0/0] q0 is not in var_ranges, defaulting to unknown range.
W0606 13:31:57.617000 136005565806400 torch/fx/experimental/symbolic_shapes.py:4417] [0/0] z0 is not in var_ranges, defaulting to unknown range.
W0606 13:31:59.574000 136005565806400 torch/fx/experimental/symbolic_shapes.py:4417] [0/0] q0 is not in var_ranges, defaulting to unknown range.
W0606 13:31:59.709000 136005565806400 torch/fx/experimental/symbolic_shapes.py:4417] [0/0] z0 is not in var_ranges, defaulting to unknown range.
W0606 13:32:06.453000 136005565806400 torch/fx/experimental/symbolic_shapes.py:4417] [0/0] y1 is not in var_ranges, defaulting to unknown range.
W0606 13:32:06.454000 136005565806400 torch/fx/experimental/symbolic_shapes.py:4417] [0/0] y0 is not in var_ranges, defaulting to unknown range.
W0606 13:32:06.464000 136005565806400 torch/fx/experimental/symbolic_shapes.py:4417] [0/0] y1 is not in var_ranges, defaulting to unknown range.
W0606 13:32:06.464000 136005565806400 torch/fx/experimental/symbolic_shapes.py:4417] [0/0] y0 is not in var_ranges, defaulting to unknown range.
W0606 13:32:07.437000 136005565806400 torch/fx/experimental/symbolic_shapes.py:4417] [0/0] x0 is not in var_ranges, defaulting to unknown range.
W0606 13:32:08.775000 136005565806400 torch/fx/experimental/symbolic_shapes.py:4417] [0/0] x0 is not in var_ranges, defaulting to unknown range.
W0606 13:32:10.533000 136005565806400 torch/fx/experimental/symbolic_shapes.py:4417] [0/0] x0 is not in var_ranges, defaulting to unknown range.

@ezyang
Copy link
Contributor

ezyang commented Jun 6, 2024

Yeah, this being a warning is too chatty; we should reduce the verbosity.

@ColinPeppler
Copy link
Contributor Author

Was able to get a minimal repro for this; will share something soon.

@bhack
Copy link
Contributor

bhack commented Jun 8, 2024

Also, compiling the mentioned SwinB encoder with the latest nightly is taking forever. I've been waiting for more than 30-40 minutes and the compilation never finishes, but it keeps continuously printing these:

W0608 10:19:56.462000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] q1 is not in var_ranges, defaulting to unknown range.
W0608 10:19:56.463000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] q0 is not in var_ranges, defaulting to unknown range.
W0608 10:19:57.250000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] z0 is not in var_ranges, defaulting to unknown range.
W0608 10:19:57.251000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] z1 is not in var_ranges, defaulting to unknown range.
W0608 10:20:00.017000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] q0 is not in var_ranges, defaulting to unknown range.
W0608 10:20:00.388000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] q0 is not in var_ranges, defaulting to unknown range.
W0608 10:20:02.362000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] q0 is not in var_ranges, defaulting to unknown range.
W0608 10:20:02.638000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] q0 is not in var_ranges, defaulting to unknown range.
W0608 10:20:02.891000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] q0 is not in var_ranges, defaulting to unknown range.
W0608 10:20:03.145000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] q0 is not in var_ranges, defaulting to unknown range.
W0608 10:20:04.216000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] z0 is not in var_ranges, defaulting to unknown range.
W0608 10:20:04.562000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] z0 is not in var_ranges, defaulting to unknown range.
W0608 10:20:06.352000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] z0 is not in var_ranges, defaulting to unknown range.
W0608 10:20:06.641000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] z0 is not in var_ranges, defaulting to unknown range.
W0608 10:20:06.858000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] z0 is not in var_ranges, defaulting to unknown range.
W0608 10:20:07.080000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] z0 is not in var_ranges, defaulting to unknown range.
W0608 10:20:12.044000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] ps2 is not in var_ranges, defaulting to unknown range.
W0608 10:20:14.859000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] x1 is not in var_ranges, defaulting to unknown range.
W0608 10:20:14.859000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] x2 is not in var_ranges, defaulting to unknown range.
W0608 10:20:15.997000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] ps9 is not in var_ranges, defaulting to unknown range.
W0608 10:20:21.599000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] x0 is not in var_ranges, defaulting to unknown range.
W0608 10:20:22.175000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] ps11 is not in var_ranges, defaulting to unknown range.
W0608 10:20:22.176000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] x0 is not in var_ranges, defaulting to unknown range.
W0608 10:20:22.680000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] x0 is not in var_ranges, defaulting to unknown range.
W0608 10:20:24.548000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] x0 is not in var_ranges, defaulting to unknown range.
W0608 10:20:24.871000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] x0 is not in var_ranges, defaulting to unknown range.
W0608 10:20:25.147000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] x0 is not in var_ranges, defaulting to unknown range.
W0608 10:20:25.335000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] x0 is not in var_ranges, defaulting to unknown range.
W0608 10:20:27.103000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] ps11 is not in var_ranges, defaulting to unknown range.
W0608 10:20:27.103000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] x0 is not in var_ranges, defaulting to unknown range.
W0608 10:20:27.550000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] ps11 is not in var_ranges, defaulting to unknown range.
W0608 10:20:27.550000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] x0 is not in var_ranges, defaulting to unknown range.
W0608 10:20:30.250000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] x1 is not in var_ranges, defaulting to unknown range.
W0608 10:20:31.050000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] x1 is not in var_ranges, defaulting to unknown range.
W0608 10:20:31.051000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] ps11 is not in var_ranges, defaulting to unknown range.
W0608 10:20:31.564000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] x1 is not in var_ranges, defaulting to unknown range.
W0608 10:20:33.395000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] x1 is not in var_ranges, defaulting to unknown range.
W0608 10:20:33.628000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] x1 is not in var_ranges, defaulting to unknown range.
W0608 10:20:33.843000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] x1 is not in var_ranges, defaulting to unknown range.
W0608 10:20:34.354000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] x1 is not in var_ranges, defaulting to unknown range.
W0608 10:20:37.129000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] x1 is not in var_ranges, defaulting to unknown range.
W0608 10:20:37.130000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] ps11 is not in var_ranges, defaulting to unknown range.
W0608 10:20:37.495000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] x1 is not in var_ranges, defaulting to unknown range.
W0608 10:20:37.496000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/0] ps11 is not in var_ranges, defaulting to unknown range.
V0608 10:21:56.286000 137538496919360 torch/_dynamo/guards.py:2610] [5/1] [__recompiles] Recompiling function forward in /workspace/networks/encoders/swin/swin_transformer.py:425
V0608 10:21:56.286000 137538496919360 torch/_dynamo/guards.py:2610] [5/1] [__recompiles]     triggered by the following guard failure(s):
V0608 10:21:56.286000 137538496919360 torch/_dynamo/guards.py:2610] [5/1] [__recompiles]     - Eq(IntTrueDiv(L['H'], L['self'].window_size), 38.8571428571429)  # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 10:21:58.570000 137538496919360 torch/_dynamo/guards.py:2610] [6/1] [__recompiles] Recompiling function torch_dynamo_resume_in_forward_at_433 in /workspace/networks/encoders/swin/swin_transformer.py:433
V0608 10:21:58.570000 137538496919360 torch/_dynamo/guards.py:2610] [6/1] [__recompiles]     triggered by the following guard failure(s):
V0608 10:21:58.570000 137538496919360 torch/_dynamo/guards.py:2610] [6/1] [__recompiles]     - Eq(IntTrueDiv(L['W'], L['self'].window_size), 38.8571428571429)  # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 10:21:58.746000 137538496919360 torch/_dynamo/guards.py:2610] [7/1] [__recompiles] Recompiling function torch_dynamo_resume_in_forward_at_434 in /workspace/networks/encoders/swin/swin_transformer.py:434
V0608 10:21:58.746000 137538496919360 torch/_dynamo/guards.py:2610] [7/1] [__recompiles]     triggered by the following guard failure(s):
V0608 10:21:58.746000 137538496919360 torch/_dynamo/guards.py:2610] [7/1] [__recompiles]     - tensor 'L['x']' dtype mismatch. expected Float, actual Half
W0608 10:25:12.267000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] q1 is not in var_ranges, defaulting to unknown range.
W0608 10:25:12.267000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] q0 is not in var_ranges, defaulting to unknown range.
W0608 10:25:13.055000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] z0 is not in var_ranges, defaulting to unknown range.
W0608 10:25:13.055000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] z1 is not in var_ranges, defaulting to unknown range.
W0608 10:25:15.955000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] q0 is not in var_ranges, defaulting to unknown range.
W0608 10:25:16.348000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] q0 is not in var_ranges, defaulting to unknown range.
W0608 10:25:18.323000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] q0 is not in var_ranges, defaulting to unknown range.
W0608 10:25:18.626000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] q0 is not in var_ranges, defaulting to unknown range.
W0608 10:25:18.887000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] q0 is not in var_ranges, defaulting to unknown range.
W0608 10:25:19.153000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] q0 is not in var_ranges, defaulting to unknown range.
W0608 10:25:20.253000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] z0 is not in var_ranges, defaulting to unknown range.
W0608 10:25:20.597000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] z0 is not in var_ranges, defaulting to unknown range.
W0608 10:25:22.368000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] z0 is not in var_ranges, defaulting to unknown range.
W0608 10:25:22.635000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] z0 is not in var_ranges, defaulting to unknown range.
W0608 10:25:22.854000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] z0 is not in var_ranges, defaulting to unknown range.
W0608 10:25:23.088000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] z0 is not in var_ranges, defaulting to unknown range.
W0608 10:25:28.195000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] ps2 is not in var_ranges, defaulting to unknown range.
W0608 10:25:30.976000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] x1 is not in var_ranges, defaulting to unknown range.
W0608 10:25:30.976000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] x2 is not in var_ranges, defaulting to unknown range.
W0608 10:25:31.979000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] ps9 is not in var_ranges, defaulting to unknown range.
W0608 10:25:37.977000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] x0 is not in var_ranges, defaulting to unknown range.
W0608 10:25:38.572000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] ps11 is not in var_ranges, defaulting to unknown range.
W0608 10:25:38.573000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] x0 is not in var_ranges, defaulting to unknown range.
W0608 10:25:39.106000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] x0 is not in var_ranges, defaulting to unknown range.
W0608 10:25:41.024000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] x0 is not in var_ranges, defaulting to unknown range.
W0608 10:25:41.367000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] x0 is not in var_ranges, defaulting to unknown range.
W0608 10:25:41.601000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] x0 is not in var_ranges, defaulting to unknown range.
W0608 10:25:41.793000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] x0 is not in var_ranges, defaulting to unknown range.
W0608 10:25:43.550000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] ps11 is not in var_ranges, defaulting to unknown range.
W0608 10:25:43.550000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] x0 is not in var_ranges, defaulting to unknown range.
W0608 10:25:43.988000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] ps11 is not in var_ranges, defaulting to unknown range.
W0608 10:25:43.989000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] x0 is not in var_ranges, defaulting to unknown range.
W0608 10:25:46.799000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] x1 is not in var_ranges, defaulting to unknown range.
W0608 10:25:47.611000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] x1 is not in var_ranges, defaulting to unknown range.
W0608 10:25:47.611000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] ps11 is not in var_ranges, defaulting to unknown range.
W0608 10:25:48.141000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] x1 is not in var_ranges, defaulting to unknown range.
W0608 10:25:50.047000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] x1 is not in var_ranges, defaulting to unknown range.
W0608 10:25:50.294000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] x1 is not in var_ranges, defaulting to unknown range.
W0608 10:25:50.510000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] x1 is not in var_ranges, defaulting to unknown range.
W0608 10:25:50.703000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] x1 is not in var_ranges, defaulting to unknown range.
W0608 10:25:53.414000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] x1 is not in var_ranges, defaulting to unknown range.
W0608 10:25:53.414000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] ps11 is not in var_ranges, defaulting to unknown range.
W0608 10:25:53.771000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] x1 is not in var_ranges, defaulting to unknown range.
W0608 10:25:53.771000 137538496919360 torch/fx/experimental/symbolic_shapes.py:4451] [7/1] ps11 is not in var_ranges, defaulting to unknown range.
V0608 10:27:12.925000 137538496919360 torch/_dynamo/guards.py:2610] [5/2] [__recompiles] Recompiling function forward in /workspace/networks/encoders/swin/swin_transformer.py:425
V0608 10:27:12.925000 137538496919360 torch/_dynamo/guards.py:2610] [5/2] [__recompiles]     triggered by the following guard failure(s):
V0608 10:27:12.925000 137538496919360 torch/_dynamo/guards.py:2610] [5/2] [__recompiles]     - Eq(IntTrueDiv(L['H'], L['self'].window_size), 19.4285714285714)  # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 10:27:12.925000 137538496919360 torch/_dynamo/guards.py:2610] [5/2] [__recompiles]     - Eq(IntTrueDiv(L['H'], L['self'].window_size), 38.8571428571429)  # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 10:27:15.318000 137538496919360 torch/_dynamo/guards.py:2610] [6/2] [__recompiles] Recompiling function torch_dynamo_resume_in_forward_at_433 in /workspace/networks/encoders/swin/swin_transformer.py:433
V0608 10:27:15.318000 137538496919360 torch/_dynamo/guards.py:2610] [6/2] [__recompiles]     triggered by the following guard failure(s):
V0608 10:27:15.318000 137538496919360 torch/_dynamo/guards.py:2610] [6/2] [__recompiles]     - Eq(IntTrueDiv(L['W'], L['self'].window_size), 19.4285714285714)  # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 10:27:15.318000 137538496919360 torch/_dynamo/guards.py:2610] [6/2] [__recompiles]     - Eq(IntTrueDiv(L['W'], L['self'].window_size), 38.8571428571429)  # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 10:27:15.399000 137538496919360 torch/_dynamo/guards.py:2610] [7/2] [__recompiles] Recompiling function torch_dynamo_resume_in_forward_at_434 in /workspace/networks/encoders/swin/swin_transformer.py:434
V0608 10:27:15.399000 137538496919360 torch/_dynamo/guards.py:2610] [7/2] [__recompiles]     triggered by the following guard failure(s):
V0608 10:27:15.399000 137538496919360 torch/_dynamo/guards.py:2610] [7/2] [__recompiles]     - tensor 'L['x']' stride mismatch at index 1. expected 256, actual 512
V0608 10:27:15.399000 137538496919360 torch/_dynamo/guards.py:2610] [7/2] [__recompiles]     - tensor 'L['x']' dtype mismatch. expected Float, actual Half

ColinPeppler added a commit that referenced this issue Jun 12, 2024
…pagation"

Tries to fix #127677.

# Context

Just as peterbell10 pointed out, we have the following scenario:
```
a = ops.indirect_indexing(...)
b = ops.index_expr(a, ...)
c = ops.indirect_indexing(b, ...)
```

We can repro this as:
```
def forward(self, arg0_1, arg1_1, arg2_1):
    iota = torch.ops.prims.iota.default(arg0_1, start = 0, step = 1, index=0);
    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(arg1_1);
    index = torch.ops.aten.index.Tensor(iota, [repeat_interleave]);
    index_1 = torch.ops.aten.index.Tensor(arg2_1, [index]);
    return (index_1,)
```

which should generate a JIT py file like this:
```
def triton_poi_fused_index_select_0(in_ptr0, in_ptr1, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    ...
    tmp0 = tl.load(in_ptr0 + (x1), xmask, eviction_policy='evict_last')
    tmp1 = ks0
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 < 0
    tmp4 = tl.where(tmp3, tmp2, tmp0)
    # check_bounds()
    tl.device_assert(((0 <= tmp4) & (tmp4 < ks0)) | ~(xmask), "index out of bounds: 0 <= tmp4 < ks0")

def call():
  arg0_1, arg1_1, arg2_1 = args
  buf1 = aten.repeat_interleave.Tensor(arg1_1)
  buf4 = empty_strided_cuda((u0, 64), (64, 1))
  triton_poi_fused_index_select_0.run(
    buf1, arg2_1, buf4, s0, 
    triton_poi_fused_index_select_0_xnumel, 
    grid=grid(triton_poi_fused_index_select_0_xnumel), 
    stream=stream0)
```

# Issue
In our `IndexPropagation.indirect_indexing()` call we have `expr=indirect0` which is spawned in `LoopBodyBlock.indirect_indexing()`.
https://github.com/pytorch/pytorch/blob/3b555ba47713d489975a9bb6cb6c31975f805e3f/torch/_inductor/ir.py#L8154-L8160

When we try to see if we can prove its bounds, we fail because `indirect0` isn't in `var_ranges`.

# Approach
We can probably skip the check to prove `indirect0`'s bounds because its purpose is to add a bounds check as a device side assert. Thankfully, we already do this in the codegen pass.
https://github.com/pytorch/pytorch/blob/3b555ba47713d489975a9bb6cb6c31975f805e3f/torch/_inductor/codegen/common.py#L1730-L1733





cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 amjames desertfire chauhang

[ghstack-poisoned]
ColinPeppler added a commit that referenced this issue Jun 12, 2024
…pagation"

Tries to fix #127677.

# Context

Just as peterbell10 pointed out, we have the following scenario:
```
a = ops.indirect_indexing(...)
b = ops.index_expr(a, ...)
c = ops.indirect_indexing(b, ...)
```

We can repro this as:
```
def forward(self, arg0_1, arg1_1, arg2_1):
    iota = torch.ops.prims.iota.default(arg0_1, start = 0, step = 1, index=0);
    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(arg1_1);
    index = torch.ops.aten.index.Tensor(iota, [repeat_interleave]);
    index_1 = torch.ops.aten.index.Tensor(arg2_1, [index]);
    return (index_1,)
```

which should generate a JIT py file like this:
```
def triton_poi_fused_index_select_0(in_ptr0, in_ptr1, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    ...
    tmp0 = tl.load(in_ptr0 + (x1), xmask, eviction_policy='evict_last')
    tmp1 = ks0
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 < 0
    tmp4 = tl.where(tmp3, tmp2, tmp0)
    # check_bounds()
    tl.device_assert(((0 <= tmp4) & (tmp4 < ks0)) | ~(xmask), "index out of bounds: 0 <= tmp4 < ks0")

def call():
  arg0_1, arg1_1, arg2_1 = args
  buf1 = aten.repeat_interleave.Tensor(arg1_1)
  buf4 = empty_strided_cuda((u0, 64), (64, 1))
  triton_poi_fused_index_select_0.run(
    buf1, arg2_1, buf4, s0, 
    triton_poi_fused_index_select_0_xnumel, 
    grid=grid(triton_poi_fused_index_select_0_xnumel), 
    stream=stream0)
```

# Issue
In our `IndexPropagation.indirect_indexing()` call we have `expr=indirect0` which is spawned in `LoopBodyBlock.indirect_indexing()`.
https://github.com/pytorch/pytorch/blob/3b555ba47713d489975a9bb6cb6c31975f805e3f/torch/_inductor/ir.py#L8154-L8160

When we try to see if we can prove its bounds, we fail because `indirect0` isn't in `var_ranges`.

# Approach
Add `indirectX` symbols with a default range of (-inf, +inf) to `self.var_to_range` to avoid a lookup error on `indirectX`.




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 amjames desertfire chauhang

[ghstack-poisoned]
ColinPeppler added a commit that referenced this issue Jun 13, 2024
…pagation"

Tries to fix #127677.

# Context

Just as peterbell10 pointed out, we have the following scenario:
```
a = ops.indirect_indexing(...)
b = ops.index_expr(a, ...)
c = ops.indirect_indexing(b, ...)
```

We can repro this as:
```
def forward(self, arg0_1, arg1_1, arg2_1):
    iota = torch.ops.prims.iota.default(arg0_1, start = 0, step = 1, index=0);
    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(arg1_1);
    index = torch.ops.aten.index.Tensor(iota, [repeat_interleave]);
    index_1 = torch.ops.aten.index.Tensor(arg2_1, [index]);
    return (index_1,)
```

which should generate a JIT py file like this:
```
def triton_poi_fused_index_select_0(in_ptr0, in_ptr1, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    ...
    tmp0 = tl.load(in_ptr0 + (x1), xmask, eviction_policy='evict_last')
    tmp1 = ks0
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 < 0
    tmp4 = tl.where(tmp3, tmp2, tmp0)
    # check_bounds()
    tl.device_assert(((0 <= tmp4) & (tmp4 < ks0)) | ~(xmask), "index out of bounds: 0 <= tmp4 < ks0")

def call():
  arg0_1, arg1_1, arg2_1 = args
  buf1 = aten.repeat_interleave.Tensor(arg1_1)
  buf4 = empty_strided_cuda((u0, 64), (64, 1))
  triton_poi_fused_index_select_0.run(
    buf1, arg2_1, buf4, s0, 
    triton_poi_fused_index_select_0_xnumel, 
    grid=grid(triton_poi_fused_index_select_0_xnumel), 
    stream=stream0)
```

# Issue
In our `IndexPropagation.indirect_indexing()` call we have `expr=indirect0` which is spawned in `LoopBodyBlock.indirect_indexing()`.
https://github.com/pytorch/pytorch/blob/3b555ba47713d489975a9bb6cb6c31975f805e3f/torch/_inductor/ir.py#L8154-L8160

When we try to see if we can prove its bounds, we fail because `indirect0` isn't in `var_ranges`.

# Approach
When creating `indirect` symbols from the fallback, specify their range as `[-size, size - 1]` to avoid a lookup error on `indirectX`.




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 amjames desertfire chauhang

[ghstack-poisoned]
ignaciobartol pushed a commit to ignaciobartol/pytorch that referenced this issue Jun 14, 2024
…ytorch#128378)

Tries to fix pytorch#127677.

# Context

Just as @peterbell10 pointed out, we have the following scenario:
```
a = ops.indirect_indexing(...)
b = ops.index_expr(a, ...)
c = ops.indirect_indexing(b, ...)
```

We can repro this as:
```
def forward(self, arg0_1, arg1_1, arg2_1):
    iota = torch.ops.prims.iota.default(arg0_1, start = 0, step = 1, index=0);
    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(arg1_1);
    index = torch.ops.aten.index.Tensor(iota, [repeat_interleave]);
    index_1 = torch.ops.aten.index.Tensor(arg2_1, [index]);
    return (index_1,)
```

which should generate a JIT py file like this:
```
def triton_poi_fused_index_select_0(in_ptr0, in_ptr1, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    ...
    tmp0 = tl.load(in_ptr0 + (x1), xmask, eviction_policy='evict_last')
    tmp1 = ks0
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 < 0
    tmp4 = tl.where(tmp3, tmp2, tmp0)
    # check_bounds()
    tl.device_assert(((0 <= tmp4) & (tmp4 < ks0)) | ~(xmask), "index out of bounds: 0 <= tmp4 < ks0")

def call():
  arg0_1, arg1_1, arg2_1 = args
  buf1 = aten.repeat_interleave.Tensor(arg1_1)
  buf4 = empty_strided_cuda((u0, 64), (64, 1))
  triton_poi_fused_index_select_0.run(
    buf1, arg2_1, buf4, s0,
    triton_poi_fused_index_select_0_xnumel,
    grid=grid(triton_poi_fused_index_select_0_xnumel),
    stream=stream0)
```

# Issue
In our `IndexPropagation.indirect_indexing()` call we have `expr=indirect0` which is spawned in `LoopBodyBlock.indirect_indexing()`.
https://github.com/pytorch/pytorch/blob/3b555ba47713d489975a9bb6cb6c31975f805e3f/torch/_inductor/ir.py#L8154-L8160

When we try to see if we can prove its bounds, we fail because `indirect0` isn't in `var_ranges`.

# Approach
When creating `indirect` symbols from the fallback, specify their range as `[-size, size - 1]` to avoid a lookup error on `indirectX`.

Pull Request resolved: pytorch#128378
Approved by: https://github.com/lezcano, https://github.com/peterbell10