
Sometimes, offsets are hoisted out of loops #1155

@v0i0

Description


Describe the bug
It looks like the expression containing the `min` is being hoisted out of the loop even though it depends on the loop induction variable. Could the special-case handling for attention hoisting be coming back to bite us?
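
To illustrate why hoisting is invalid here, a minimal pure-Python sketch (the concrete shapes `seqlen=4096` and `chunk_size=256` are taken from the repro below; the plain `for` loop is only a stand-in for the Helion tile loop):

```python
# t_i_last depends on the induction variable, so it takes a different
# value on every chunk and cannot legally be computed before the loop.
seqlen, chunk_size = 4096, 256
lasts = [min(t_begin + chunk_size - 1, seqlen - 1)
         for t_begin in range(0, seqlen, chunk_size)]
print(lasts[:3], lasts[-1])  # [255, 511, 767] 4095 -- loop-variant
```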

To Reproduce

On top of #1119:

$ git diff
diff --git a/examples/gdn_fwd_h.py b/examples/gdn_fwd_h.py
index 45e11f0..5512366 100644
--- a/examples/gdn_fwd_h.py
+++ b/examples/gdn_fwd_h.py
@@ -63,7 +63,7 @@ def helion_gdn_fwd_h(
             p_v = u[tile_b.begin, t_i, tile_h.begin, tile_v].to(acc_dtype)
             b_v = p_v - b_v
             m_t = t_i.index < seqlen
-            t_i_last = min(t_i.begin + chunk_size, seqlen) - 1
+            t_i_last = min(t_i.begin + chunk_size - 1, seqlen - 1)
             b_g_last = g[tile_b.begin, t_i_last, tile_h.begin].to(acc_dtype)
             b_g = g[tile_b.begin, t_i, tile_h.begin].to(acc_dtype)
             b_v *= torch.where(m_t, torch.exp(b_g_last - b_g), 0)[:, None]
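
Note that both sides of the diff compute the same value, since `min(a, b) - 1 == min(a - 1, b - 1)` for integers, so the rewrite only changes how the index expression is built, which seems to be enough to trigger the hoist. Quick sanity check:

```python
# min distributes over subtraction of a constant
for a, b in [(256, 4096), (4096, 256), (512, 512)]:
    assert min(a, b) - 1 == min(a - 1, b - 1)
```
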
$ python examples/gdn_fwd_h.py
Testing helion correctness...
[0s] Autotune random seed: 2727137739
Traceback (most recent call last):
  File "/data/users/mhoehnerbach/projects/linear-attention/helion/helion/autotuner/base_search.py", line 188, in _compute_baseline
    baseline_output = self.kernel.compile_config(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/torchinductor_mhoehnerbach/pj/cpjaap23wmi72e2hnjsd4ekh2kqwxu4glx5tiikiw7mr2brcnvae.py", line 121, in helion_gdn_fwd_h
    _launcher(_helion_helion_gdn_fwd_h, (8 * 80 * triton.cdiv(128, _BLOCK_SIZE_0),), h, w, u, g, k, _BLOCK_SIZE_0, _RDIM_SIZE_4, _BLOCK_SIZE_3, num_warps=4, num_stages=1)
  File "/data/users/mhoehnerbach/projects/linear-attention/helion/helion/runtime/__init__.py", line 86, in default_launcher
    return triton_kernel.run(
           ^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/linear-attention/venv-new-tilelang/lib/python3.12/site-packages/triton/runtime/jit.py", line 733, in run
    kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/linear-attention/venv-new-tilelang/lib/python3.12/site-packages/triton/runtime/jit.py", line 861, in _do_compile
    kernel = self.compile(src, target=target, options=options.__dict__)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/linear-attention/venv-new-tilelang/lib/python3.12/site-packages/triton/compiler/compiler.py", line 300, in compile
    module = src.make_ir(target, options, codegen_fns, module_map, context)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/linear-attention/venv-new-tilelang/lib/python3.12/site-packages/triton/compiler/compiler.py", line 80, in make_ir
    return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
triton.compiler.errors.CompilationError: at 18:38:
    pid_0 = tl.program_id(0) % num_blocks_0
    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
    offset_1 = pid_0
    offset_2 = pid_1
    offset_0 = pid_2 * _BLOCK_SIZE_0
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    indices_5 = tl.arange(0, _RDIM_SIZE_4).to(tl.int32)
    # src[gdn_fwd_h.py:57]: b_h = hl.zeros([dhead, tile_v], dtype=acc_dtype)
    b_h = tl.full([64, _BLOCK_SIZE_0], 0.0, tl.float32)
    # src[gdn_fwd_h.py:67]: b_g_last = g[tile_b.begin, t_i_last, tile_h.begin].to(acc_dtype)
    symnode_0 = 4095 * (4095 <= 255 + offset_4) + (255 + offset_4) * (255 + offset_4 < 4095)
                                      ^
NameError('offset_4 is not defined')

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/data/users/mhoehnerbach/projects/linear-attention/helion/examples/gdn_fwd_h.py", line 210, in <module>
    main()
  File "/data/users/mhoehnerbach/projects/linear-attention/helion/examples/gdn_fwd_h.py", line 206, in main
    test(8, 80, 4096, 256, 64, 128)
  File "/data/users/mhoehnerbach/projects/linear-attention/helion/examples/gdn_fwd_h.py", line 196, in test
    run_example(helion_gdn_fwd_h, ref_gdn_fwd_h, args)
  File "/data/users/mhoehnerbach/projects/linear-attention/helion/helion/_testing.py", line 607, in run_example
    func(*args).to(torch.float32),
    ^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/linear-attention/helion/helion/runtime/kernel.py", line 330, in __call__
    return self.bind(args)(*args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/linear-attention/helion/helion/runtime/kernel.py", line 696, in __call__
    self.autotune(args, force=False)
  File "/data/users/mhoehnerbach/projects/linear-attention/helion/helion/runtime/kernel.py", line 574, in autotune
    config = self.settings.autotuner_fn(self, args, **kwargs).autotune(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/linear-attention/helion/helion/runtime/settings.py", line 253, in default_autotuner_fn
    return cache_cls(autotuner_cls(bound_kernel, args, **kwargs))
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/linear-attention/helion/helion/autotuner/pattern_search.py", line 45, in __init__
    super().__init__(kernel, args)
  File "/data/users/mhoehnerbach/projects/linear-attention/helion/helion/autotuner/base_search.py", line 707, in __init__
    super().__init__(kernel, args)
  File "/data/users/mhoehnerbach/projects/linear-attention/helion/helion/autotuner/base_search.py", line 136, in __init__
    ) = self._compute_baseline()
        ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/linear-attention/helion/helion/autotuner/base_search.py", line 203, in _compute_baseline
    raise exc.InvalidConfig(
helion.exc.InvalidConfig: Default config failed while computing baseline.
Default config: @helion.kernel(config=helion.Config(block_sizes=[32], indexing=['pointer', 'pointer', 'pointer', 'pointer', 'pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=['', '', '', '', ''], loop_orders=[[0, 1, 2]], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[], range_unroll_factors=[0, 0], range_warp_specializes=[]), static_shapes=True)
Enable HELION_AUTOTUNE_LOG_LEVEL=DEBUG to log generated Triton code.
To work around this error, you could set `@helion.kernel(autotune_baseline_fn=...)` to provide a custom baseline function (e.g. PyTorch eager implementation of your kernel).
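
For context: given the static shapes in the repro (`seqlen=4096`, `chunk_size=256`), `symnode_0` in the generated code above looks like the branchless lowering of `min(offset_4 + 255, 4095)`, i.e. `t_i_last`, but it is emitted before the loop that defines `offset_4`, which matches the hoisting described above. A sketch of the select identity it uses (the name `select_min` is mine, for illustration only):

```python
def select_min(x: int, c: int) -> int:
    # Branchless integer min as emitted: c * (c <= x) + x * (x < c) == min(x, c)
    return c * (c <= x) + x * (x < c)

offset_4 = 768  # hypothetical chunk offset; undefined at the emission point
assert select_min(255 + offset_4, 4095) == min(255 + offset_4, 4095) == 1023
```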
