Inconsistency between constants as arguments and captured globals #3924
Comments
NB: Also working on fixing inductor's codegen for this situation: pytorch-126195
I think I have a fix for this, but I am not sure of the intended semantics. I have a branch https://github.com/alexbaden/triton/tree/alex/fix_globals_constexpr that implements this fix, which I would be happy to PR, if indeed this is a Triton bug and not an inductor bug / invalid dereference.
@alexbaden that patch would fix the reproducer I wrote, but that was more of a demonstration than an exhaustive set of tests. The issue I have is really with the asymmetry between the handling of locals and kernel arguments on the one hand and captured globals on the other with respect to the annotation. The former group is intercepted at the assignment:

```python
GLOBAL: tl.constexpr = 1

@triton.jit
def kernel(arg: tl.constexpr):
    local_var: tl.constexpr = 1
    if local_var == arg:
        ...
    if local_var == GLOBAL:
        ...
    if arg == local_var:
        ...
    if arg == GLOBAL:
        ...
    if GLOBAL == local_var:
        ...
    if GLOBAL == arg:
        ...
```

The last two clauses trigger a compile error. Even worse, accessing a captured global requires that it either carry the `tl.constexpr` type or be annotated as such:

```python
GLOBAL: tl.constexpr = 1

@triton.jit(debug=True)
def kernel(arg: tl.constexpr):
    GLOBAL = 7
    tl.device_assert(local_var == GLOBAL)  # runtime failure
```
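For comparison outside Triton, the symmetric behavior the comment asks for can be sketched in plain Python with a minimal stand-in wrapper (the `Constexpr` class below is a hypothetical simplification, not Triton's actual implementation): once every operand is wrapped, all six orderings above behave identically.

```python
# Minimal stand-in for tl.constexpr (hypothetical; for illustration only).
class Constexpr:
    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # Unwrap the other side whether or not it is wrapped, so the
        # comparison works regardless of operand order.
        other = other.value if isinstance(other, Constexpr) else other
        return self.value == other

# Wrapping every name at its definition, as the error message suggests
# for globals, removes the asymmetry:
GLOBAL = Constexpr(1)
local_var = Constexpr(1)
arg = Constexpr(1)

# All six orderings from the example above now dispatch to
# Constexpr.__eq__ and behave identically:
assert local_var == arg
assert local_var == GLOBAL
assert arg == local_var
assert arg == GLOBAL
assert GLOBAL == local_var
assert GLOBAL == arg
```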
TL;DR: After #3762, global variables which are captured by a kernel must be `tl.constexpr` or annotated as such. It is surprising to me that a kernel argument with the annotation is actually an object of type `constexpr` when `CodeGenerator.visit` is running, but the captured global is not. Either that should be fixed, or the suggestion in the error message should only recommend globals be defined as `VAR = tl.constexpr(<value>)`.

Details
I had some code that looks like this (the actual original is from the pytorch tests):

After getting the new error about globals needing to be `tl.constexpr`, I tried defining `STRING_CONSTANT_C` as a global `tl.constexpr`. Compilation of the kernel then fails with:

`if conditionals can only accept values of type {int, NoneType, bool}, not objects of type NotImplementedType`

Digging into that a bit, I realized this is parsing as `str.__eq__` comparing a string to a `tl.constexpr` object. So I modified the kernel source so the conditional uses `STRING_CONSTANT_C.value`, which works.

The other recommendation from the error message introduced by #3762 is to use an annotation on the captured global. Trying that out, it fails with the modified conditional and works with the original source.
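The `NotImplementedType` error above can be modeled in plain Python: calling `str.__eq__` directly, as a code generator that resolves comparison methods by operand type might, bypasses Python's reflected-operand fallback. This is a sketch under that assumption; `Constexpr` and the string value are hypothetical stand-ins for the real `tl.constexpr` and the contents of `STRING_CONSTANT_C`.

```python
# Hypothetical stand-in for tl.constexpr; the real class defines more
# behavior, but the failure mode only needs a wrapped value.
class Constexpr:
    def __init__(self, value):
        self.value = value

STRING_CONSTANT_C = Constexpr("CONSTANT_C")

# Ordinary `==` lets Python fall back to the reflected operand's
# __eq__, but dispatching through str.__eq__ directly returns the
# NotImplemented sentinel, which an `if` lowering then rejects:
result = str.__eq__("CONSTANT_C", STRING_CONSTANT_C)
print(result)  # NotImplemented

# Unwrapping with .value sidesteps the issue entirely, matching the
# workaround described above:
print("CONSTANT_C" == STRING_CONSTANT_C.value)  # True
```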
Proposal

Why not translate these captured variables to always be `tl.constexpr` instances, the way that arguments with the annotation are handled?

Reproducer script: https://gist.github.com/amjames/973b378f7c0fa8c92b6c92d05d90547b
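A minimal sketch of what that translation could look like, assuming a hypothetical `resolve_global` hook in the JIT's capture path (not Triton's actual API): plain captured values get wrapped on resolution, so globals reach codegen with the same type as annotated arguments.

```python
# Hypothetical stand-in for tl.constexpr (illustration only).
class Constexpr:
    def __init__(self, value):
        # Avoid double-wrapping when the global is already a constexpr.
        self.value = value.value if isinstance(value, Constexpr) else value

def resolve_global(value):
    """Sketch of the proposal: wrap a captured global the same way
    annotated kernel arguments are wrapped, so codegen always sees
    a Constexpr regardless of how the global was defined."""
    return value if isinstance(value, Constexpr) else Constexpr(value)

# A bare global and a pre-wrapped global end up with the same type
# and the same unwrapped value:
assert isinstance(resolve_global(7), Constexpr)
assert isinstance(resolve_global(Constexpr(7)), Constexpr)
assert resolve_global(Constexpr(7)).value == 7
```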