
inductor LOAF throws an error on toy GPT model #141134

@vkuzo

🐛 Describe the bug

When I run the repro script from pytorch/ao#1297 with TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1, I see an error. Here is a simplified repro:

Code: https://gist.github.com/vkuzo/7dcf4f3e9f25055c9cad23ca87b74e7a
Command to reproduce: TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1 python ~/local/tmp/20241120_loaf_bug.py
Full stack trace of error: https://gist.github.com/vkuzo/2766bbd5386a8a00d1b43be171a6d29d
Relevant part of the stack trace:

  File "/data/users/vasiliy/pytorch/torch/_inductor/scheduler.py", line 3568, in _codegen
    self.get_backend(device).codegen_node(node)
  File "/data/users/vasiliy/pytorch/torch/_inductor/codegen/cuda_combined_scheduling.py", line 80, in codegen_node
    return self._triton_scheduling.codegen_node(node)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/vasiliy/pytorch/torch/_inductor/codegen/simd.py", line 1195, in codegen_node
    return self.codegen_node_schedule(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/vasiliy/pytorch/torch/_inductor/codegen/simd.py", line 1230, in codegen_node_schedule
    tiled_groups = self.select_tiling(
                   ^^^^^^^^^^^^^^^^^^^
  File "/data/users/vasiliy/pytorch/torch/_inductor/codegen/simd.py", line 1581, in select_tiling
    if len(cls.candidate_tilings(node)) > 0:
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/vasiliy/pytorch/torch/_inductor/codegen/simd.py", line 1505, in candidate_tilings
    assert len(rw.range_vars) == len(ranges), f"{rw.range_vars=} {ranges=}"
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: rw.range_vars=[d0, d1, d2] ranges=(12288, 1600)

With TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=0, the same script runs without errors.
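The on/off comparison can be scripted. The sketch below is a hypothetical helper, not part of the original report. It runs the gist script (its path is assumed to be passed on the command line) once with TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=0 and once with =1, each in a fresh Python process, and prints the exit codes. A separate process per run is a design choice based on the assumption that Inductor reads the variable from the environment at import time, so changing it inside an already-running interpreter may have no effect.

```python
import os
import subprocess
import sys

# Environment variable named in the report: "1" enables LOAF, "0" disables it.
LOAF_ENV_VAR = "TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION"


def env_with_loaf(enabled: bool) -> dict:
    """Return a copy of the current environment with LOAF switched on or off."""
    env = dict(os.environ)  # copy, so the caller's environment is untouched
    env[LOAF_ENV_VAR] = "1" if enabled else "0"
    return env


def run_repro(script_path: str, enabled: bool) -> int:
    """Run the repro script in a fresh interpreter and return its exit code."""
    result = subprocess.run(
        [sys.executable, script_path],
        env=env_with_loaf(enabled),
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        # Print the last few stderr lines, where the final exception message appears.
        print("\n".join(result.stderr.splitlines()[-5:]))
    return result.returncode


if __name__ == "__main__":
    # Hypothetical default path; pass the real repro script as the first argument.
    script = sys.argv[1] if len(sys.argv) > 1 else "20241120_loaf_bug.py"
    for enabled in (False, True):
        code = run_repro(script, enabled)
        print(f"{LOAF_ENV_VAR}={int(enabled)} -> exit code {code}")
```

For example, running the helper (file name hypothetical) as `python compare_loaf.py ~/local/tmp/20241120_loaf_bug.py` would be expected to report exit code 0 for the =0 run and a non-zero exit code for the =1 run while this bug reproduces.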

Versions

https://gist.github.com/vkuzo/d737f40af85d6b426d209f89692885c2

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @aakhundov

Labels: module: inductor, oncall: pt2, triaged