-
Notifications
You must be signed in to change notification settings - Fork 61
[Autotune] Filter bad config with accuracy check #655
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
75e8bf2 to
86c3843
Compare
86c3843 to
7bb7176
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hrm, is the source of the bad configs here Triton bugs? I'd like to maintain the invariant that every config produces the same result -- is there some underlying bug we could fix?
We can't find a way to solve it I'm ok landing this, but I think we should try to fix the real issue.
@jansel Yes I believe it's a triton bug (triton-lang/triton#8259) i.e. their pipeline transform doesn't take dependent load and store into consideration. Helion can have a compiler pass to pattern-match this, but I feel that it would be hard to maintain and there could be other incorrect-result configs that we haven't discovered yet. Since silently wrong results are pretty scary (it would take a lot of work from user to verify the result and then manually ban those configs), I feel that this PR would be a low-cost catch-all way to ensure that autotuning filters out the bad configs and the best config that user gets is always numerically correct. |
|
My initial response to that would be to disable |
Yes I was also a bit worried that there is legitimately good configs from |
|
If we can't be confident in the correctness of a compiler pass we should not use it. |
|
@yf225 based on the comments on triton-lang/triton#8259 it seems like this might be a bug in one of kernels. |
So I think doing an output accurate check in autotuning would be a low-cost way to filter out numerically-bad configs and ensure the autotuned config always produces numerically correct kernel, while keeping our compiler passes simple and the maintenance cost low.
Here is a detailed example:
I noticed that autotuning can produce Triton kernels that are runnable but produce wrong numerical output due to issues like read-before-write, e.g.:
My understanding of why
pipeline_kernel_failfails accuracy check:tl.store/tl.loadonly enqueues a transaction; it doesn’t wait for it to finish.Because of this, the
tl.load(addr)can complete before thetl.store(addr, values)for the same iteration has reached memory, so it reads stale data and the accuracy check fails.I think in general we could write compiler passes that detect "num_stages > 1 and there is
tl.loadright aftertl.store" and not use num_stages > 1 in that case. But it's unclear if we have other cases where Triton codegen can produce runnable but wrong-output kernels. So I think doing an output accurate check in autotuning would be a low-cost way to filter out those bad configs and ensure the autotuned config always produces numerically correct kernel, while keeping our compiler passes simple and the maintenance cost low.