Conversation

@yf225 (Contributor) commented Sep 22, 2025

So I think doing an output accuracy check in autotuning would be a low-cost way to filter out numerically-bad configs and ensure the autotuned config always produces a numerically correct kernel, while keeping our compiler passes simple and the maintenance cost low.

Here is a detailed example:

I noticed that autotuning can produce Triton kernels that are runnable but produce wrong numerical output, due to issues like a load being reordered before the store it depends on (a violated read-after-write dependency), e.g.:

import torch
import triton
import triton.language as tl

"""
NOTE: this store-then-load pattern can be generated by applying
`@helion.kernel(config=helion.Config(block_sizes=[128], indexing='tensor_descriptor', num_stages=5, num_warps=1, pid_type='persistent_interleaved', range_flattens=[None, True], range_multi_buffers=[None, True], range_num_stages=[2, 3], range_unroll_factors=[1, 2]), static_shapes=True)`
to `kl_div_forward` kernel in examples/kl_div.py.
Full generated code: https://gist.github.com/yf225/cf7f0e30e6b8c97a08cbaa9470fc41c5
"""
@triton.jit
def pipeline_kernel_fail(ptr, diff):
    # num_stages=2 on both loops lets the compiler overlap iterations, which
    # can reorder the load below ahead of the completion of its matching store.
    for pid in tl.range(tl.program_id(0), 4096, loop_unroll_factor=1, num_stages=2):
        for offset in tl.range(0, 4096, 128, loop_unroll_factor=2, num_stages=2, flatten=True):
            idx0 = offset + tl.arange(0, 128)
            addr = ptr + pid * 4096 + idx0
            values = (idx0 + pid).to(tl.float32)
            tl.store(addr, values)  # write values ...
            loaded = tl.load(addr)  # ... then read them back from the same address
            tl.store(diff + pid * 4096 + idx0, loaded - values)  # should be all zeros

@triton.jit
def safe_kernel(ptr, diff):
    # Identical body, but num_stages=1 disables software pipelining, so the
    # store-then-load pair executes in program order.
    for pid in tl.range(tl.program_id(0), 4096, loop_unroll_factor=1, num_stages=1):
        for offset in tl.range(0, 4096, 128, loop_unroll_factor=2, num_stages=1, flatten=False):
            idx0 = offset + tl.arange(0, 128)
            addr = ptr + pid * 4096 + idx0
            values = (idx0 + pid).to(tl.float32)
            tl.store(addr, values)
            loaded = tl.load(addr)
            tl.store(diff + pid * 4096 + idx0, loaded - values)


def run(kernel, name):
    ptr = torch.zeros((4096 * 4096,), device='cuda', dtype=torch.float32)
    diff = torch.empty_like(ptr)
    kernel[(4096,)](ptr, diff)
    max_abs = diff.abs().max().item()
    nonzero = (diff != 0).sum().item()
    print(f"{name}: max diff={max_abs}, nonzero={nonzero}")
    print('  sample', diff[:10])


def main():
    run(pipeline_kernel_fail, 'pipeline')
    run(safe_kernel, 'safe')


if __name__ == '__main__':
    main()

"""
pipeline: max diff=4095.0, nonzero=4095
  sample tensor([ 0., -1., -2., -3., -4., -5., -6., -7., -8., -9.], device='cuda:0')
safe: max diff=0.0, nonzero=0
  sample tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0')
"""

My understanding of why pipeline_kernel_fail fails the accuracy check:

  • Every tl.store / tl.load only enqueues a memory transaction; it doesn't wait for it to finish.
  • When we tell Triton to pipeline a loop (num_stages = 2, 3, …), the compiler is free to overlap several iterations so the GPU can hide memory latency.
  • The compiler does not track memory-address dependencies between a store and a subsequent load inside the same pipelined loop.

Because of this, the tl.load(addr) can complete before the tl.store(addr, values) for the same iteration has reached memory, so it reads stale data and the accuracy check fails.
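
To make the reordering concrete, here is a rough schematic of a two-stage schedule (an illustration of the hazard only, not Triton's actual lowering):

# Schematic only -- not real Triton output.
# One logical iteration of the inner loop:
#   tl.store(addr_i, values_i)   # enqueue the write
#   loaded = tl.load(addr_i)     # should observe values_i
#
# With num_stages=2 the compiler may overlap iterations roughly like:
#   issue store(addr_0)   # write in flight, not yet visible in memory
#   issue load(addr_0)    # can read the old contents of addr_0 (the initial zeros)
#   issue store(addr_1)
#   issue load(addr_1)    # same hazard on every iteration
#
# Nothing orders load(addr_i) after the *completion* of store(addr_i), so
# loaded - values becomes -(idx0 + pid) instead of 0, matching the
# "max diff=4095.0" output above.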

I think in general we could write compiler passes that detect "num_stages > 1 and there is a tl.load right after a tl.store" and avoid num_stages > 1 in that case. But it's unclear whether there are other cases where Triton codegen can produce runnable but wrong-output kernels. So I think doing an output accuracy check in autotuning would be a low-cost way to filter out those bad configs and ensure the autotuned config always produces numerically correct kernels, while keeping our compiler passes simple and the maintenance cost low.
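
For illustration, a minimal sketch of such a check (hypothetical helper names; Helion's actual autotuner hooks may differ): run the kernel once with a known-safe baseline config, then reject any candidate config whose output diverges beyond a tolerance.

import torch

def passes_accuracy_check(run_with_config, baseline_config, candidate_config,
                          args, rtol=1e-4, atol=1e-4):
    # `run_with_config(config, args)` is assumed to compile and run the
    # kernel under the given config and return its output tensor.
    expected = run_with_config(baseline_config, args)
    actual = run_with_config(candidate_config, args)
    return torch.allclose(actual, expected, rtol=rtol, atol=atol)

# During the search, numerically-bad configs are simply dropped:
#   candidates = [c for c in candidates
#                 if passes_accuracy_check(run_with_config, baseline, c, args)]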

@yf225 requested review from jansel and oulgen on September 22, 2025 07:56
@meta-cla bot added the CLA Signed label on Sep 22, 2025
@yf225 force-pushed the autotune_accuracy_check branch from 75e8bf2 to 86c3843 on September 22, 2025 08:23
@yf225 force-pushed the autotune_accuracy_check branch from 86c3843 to 7bb7176 on September 22, 2025 08:25
@jansel (Contributor) left a comment

Hrm, is the source of the bad configs here Triton bugs? I'd like to maintain the invariant that every config produces the same result -- is there some underlying bug we could fix?

If we can't find a way to solve it, I'm ok landing this, but I think we should try to fix the real issue.

@yf225 (Contributor, Author) commented Sep 23, 2025

> Hrm, is the source of the bad configs here Triton bugs? I'd like to maintain the invariant that every config produces the same result -- is there some underlying bug we could fix?
>
> If we can't find a way to solve it, I'm ok landing this, but I think we should try to fix the real issue.

@jansel Yes, I believe it's a Triton bug (triton-lang/triton#8259), i.e. their pipeline transform doesn't take dependent loads and stores into consideration. Helion could have a compiler pass to pattern-match this, but I feel it would be hard to maintain, and there could be other incorrect-result configs that we haven't discovered yet.

Since silently wrong results are pretty scary (it would take a lot of work from the user to verify the result and then manually ban those configs), I feel this PR would be a low-cost, catch-all way to ensure that autotuning filters out the bad configs and that the best config the user gets is always numerically correct.

@yf225 requested a review from jansel on September 23, 2025 06:57
@jansel (Contributor) commented Sep 23, 2025

My initial response to that would be to disable range_pipeline until it is fixed.
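
For comparison, pinning the range-level pipelining to a single stage could be expressed through the same helion.Config interface shown in the NOTE above (a sketch only; the exact semantics of each field are Helion's, and only the relevant knobs are shown):

import helion

# Sketch: same knobs as the failing config in the NOTE above, but with
# range-level pipelining pinned to one stage (an assumption about what
# range_num_stages controls) so the store-then-load pair cannot be reordered.
pinned_config = helion.Config(
    block_sizes=[128],
    indexing='tensor_descriptor',
    num_warps=1,
    pid_type='persistent_interleaved',
    range_num_stages=[1, 1],  # was [2, 3] in the failing config
)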

@yf225 (Contributor, Author) commented Sep 23, 2025

> My initial response to that would be to disable range_pipeline until it is fixed.

Yes, I was also worried about that: there are legitimately good configs from range_pipeline, and disabling it would make Helion-generated Triton kernels not competitive with hand-written Triton kernels that use it. I'll land this as a safeguard so that we can still search over the range_pipeline configs.

@yf225 merged commit 5ccf6f4 into main on Sep 23, 2025 (13 checks passed)
@jansel (Contributor) commented Sep 23, 2025

If we can't be confident in the correctness of a compiler pass we should not use it.

@jansel (Contributor) commented Sep 26, 2025

@yf225 based on the comments on triton-lang/triton#8259, it seems like this might be a bug in one of the kernels.
