Inconsistency between constants as arguments and captured globals #3924

Open
amjames opened this issue May 15, 2024 · 3 comments

Comments

@amjames
Contributor

amjames commented May 15, 2024

TLDR: After #3762, global variables captured by a kernel must be tl.constexpr or annotated as such. It is surprising to me that, while CodeGenerator.visit is running, an annotated kernel argument is actually an object of type constexpr, but an annotated captured global is not. Either that should be fixed, or the suggestion in the error message should recommend only that globals be defined as VAR = tl.constexpr(<value>).

Details

I had some code that looks like this (the actual original is from the PyTorch tests):

import triton
import triton.language as tl

STRING_CONSTANT_C = 'value'

@triton.jit
def kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr", CONSTANT_NAME: "tl.constexpr"):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr + offsets, mask=mask)
    # Compile-time branch on a string constant
    if CONSTANT_NAME.value == STRING_CONSTANT_C:
        output = 2 * x
    else:
        # Fallback so that output is always defined
        output = x
    tl.store(out_ptr + offsets, output, mask=mask)

After getting the new error about globals needing to be tl.constexpr, I tried defining STRING_CONSTANT_C globally like this:

STRING_CONSTANT_C = tl.constexpr('value')

Compilation fails with "if conditionals can only accept values of type {int, NoneType, bool}, not objects of type NotImplementedType". Digging into that a bit, I realized that the comparison is evaluated as str.__eq__ between a string and a tl.constexpr object, which returns NotImplemented. So I modified the kernel source so that the conditional uses STRING_CONSTANT_C.value, which works.
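The NotImplementedType in that error message is what you get when a comparison dunder method is called directly instead of through the == operator. Below is a plain-Python sketch of the difference, assuming (as the later comment about _apply_binary_method suggests) that the code generator calls the left operand's __eq__ directly. FakeConstexpr is a hypothetical stand-in, not Triton's real tl.constexpr.

```python
class FakeConstexpr:
    """Hypothetical stand-in for tl.constexpr: wraps a compile-time value."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # Unwrap the other operand if it is also wrapped, then compare raw values.
        other = other.value if isinstance(other, FakeConstexpr) else other
        return FakeConstexpr(self.value == other)


wrapped = FakeConstexpr("value")

# Calling str.__eq__ directly knows nothing about FakeConstexpr,
# so it returns the NotImplemented sentinel instead of a bool.
direct = "value".__eq__(wrapped)
print(direct is NotImplemented)  # True

# The == operator falls back to the reflected FakeConstexpr.__eq__,
# which returns a wrapped True.
via_operator = "value" == wrapped
print(via_operator.value)  # True
```

A direct dunder call has no reflected fallback, so the NotImplemented sentinel leaks into the if condition.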

The other recommendation in the error message introduced by #3762 is to annotate the captured global. Trying that out:

STRING_CONSTANT_C: tl.constexpr = 'value'

This fails with the modified conditional but works with the original source.

Proposal

Why not translate these captured variables to always be tl.constexpr instances the way that arguments with the annotation are handled?
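A minimal sketch of that proposal in plain Python, assuming a lookup step that sees the kernel's module globals. FakeConstexpr and lookup_global are hypothetical names, not Triton's API, and the annotation check is modeled with an __annotations__ mapping stored inside the globals dict.

```python
class FakeConstexpr:
    """Hypothetical stand-in for tl.constexpr."""

    def __init__(self, value):
        self.value = value


def lookup_global(name, fn_globals):
    """Hypothetical lookup that normalizes captured globals to FakeConstexpr.

    Accepts a global if it is already wrapped, or if it is annotated as
    constexpr; anything else is rejected, as after #3762.
    """
    value = fn_globals[name]
    if isinstance(value, FakeConstexpr):
        return value
    annotation = fn_globals.get("__annotations__", {}).get(name)
    if annotation in (FakeConstexpr, "FakeConstexpr"):
        # Wrap at lookup time, mirroring how annotated kernel arguments are wrapped.
        return FakeConstexpr(value)
    raise NameError(f"global {name!r} must be constexpr or annotated as such")


module_globals = {
    "__annotations__": {"ANNOTATED": FakeConstexpr},
    "ANNOTATED": 1,             # like GLOBAL: tl.constexpr = 1
    "WRAPPED": FakeConstexpr(2),  # like GLOBAL = tl.constexpr(2)
    "PLAIN": 3,                 # neither; lookup raises NameError
}
print(lookup_global("ANNOTATED", module_globals).value)
print(lookup_global("WRAPPED", module_globals).value)
```

With this normalization, both spellings of a constexpr global reach the comparison as wrapped objects, so .value is never needed on either side.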

Reproducer script: https://gist.github.com/amjames/973b378f7c0fa8c92b6c92d05d90547b

@amjames
Contributor Author

amjames commented May 15, 2024

NB: Also working on fixing inductor's codegen for this situation pytorch-126195

@alexbaden
Contributor

I think I have a fix for this, but I am not sure the semantics of tl.constexpr require it. If you explicitly "dereference" the constexpr with .value, then this problem arises as you said, because _apply_binary_method (https://github.com/triton-lang/triton/blob/main/python/triton/compiler/code_generator.py#L541) picks up the __eq__ method of the raw value instead of the constexpr __eq__ method. I added a conditional check plus a constexpr unwrap for the case where the rhs is a constexpr and the lhs is not. However, when writing tests I couldn't think of a situation where one should explicitly dereference the constexpr type: if _apply_binary_method is being called, then I think both the lhs and rhs should be constexpr by definition. The problem only occurs because some globals are treated as constexpr without being explicitly marked constexpr; once they are marked constexpr, the .value "dereference" (is that the right word?) is unnecessary. So perhaps the .value is really the bug and should be removed?
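A plain-Python sketch of the kind of check described here: a hypothetical apply_binary_method that calls the left operand's dunder method directly and, when the rhs is a constexpr but the lhs is not, unwraps the rhs first. FakeConstexpr and apply_binary_method are illustrative names; this is not the code in the linked branch.

```python
class FakeConstexpr:
    """Hypothetical stand-in for tl.constexpr."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        other = other.value if isinstance(other, FakeConstexpr) else other
        return FakeConstexpr(self.value == other)


def apply_binary_method(method_name, lhs, rhs):
    """Call lhs.<method_name>(rhs) directly, unwrapping a constexpr rhs when lhs is raw."""
    if isinstance(rhs, FakeConstexpr) and not isinstance(lhs, FakeConstexpr):
        # Without this unwrap, str.__eq__(FakeConstexpr) returns NotImplemented.
        rhs = rhs.value
    return getattr(lhs, method_name)(rhs)


# Raw lhs (e.g. the result of CONSTANT_NAME.value) against a wrapped global.
result = apply_binary_method("__eq__", "value", FakeConstexpr("value"))
print(result)  # True
```

Note that the result type depends on operand order (a raw bool here, a wrapped value when the lhs is a constexpr), which is part of why removing the .value may be the cleaner fix.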

I have a branch https://github.com/alexbaden/triton/tree/alex/fix_globals_constexpr that implements this fix, which I would be happy to PR if this is indeed a Triton bug and not an Inductor bug / invalid dereference of tl.constexpr. I ran your repro script on my branch and it passed:

# Global defines:
STRING_CONSTANT_C: tl.constexpr = '...'
STRING_CONSTANT_OBJ_C = tl.constexpr('...')
# kernel_signature: 
def kernel(..., CONSTANT_NAME: tl.constexpr):
With expr if CONSTANT_NAME == STRING_CONSTANT_C: called w/ CONSTANT_NAME='...' --> WORKS
With expr if CONSTANT_NAME == STRING_CONSTANT_OBJ_C: called w/ CONSTANT_NAME='...' --> WORKS
With expr if CONSTANT_NAME == STRING_COSNTANT_OBJ_C: called w/ CONSTANT_NAME='...' --> WORKS

@amjames
Contributor Author

amjames commented May 16, 2024

@alexbaden that patch would fix the reproducer I wrote, but that was more of a demonstration than an exhaustive set of tests. My real issue is the asymmetry between how (locals, kernel arguments) and (captured globals) are handled with respect to the annotation. The former group is intercepted at assignment, and constexpr(value) is stored in the local scope dictionary; the latter is not (the annotation is only checked when the name is dereferenced, to see whether we are 'allowed' to access the global). Consider this:

import triton
import triton.language as tl

GLOBAL: tl.constexpr = 1

@triton.jit
def kernel(arg: tl.constexpr):
    local_var: tl.constexpr = 1

    if local_var == arg:
        ...
    if local_var == GLOBAL:
        ...
    if arg == local_var:
        ...
    if arg == GLOBAL:
        ...
    if GLOBAL == local_var:  # compile error
        ...
    if GLOBAL == arg:  # compile error
        ...

The last two clauses trigger a compile error and you need to use .value on the RHS to work around it. I think that is wrong.
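To make the asymmetry concrete, here is a pure-Python model (not Triton's code generator) of the six clauses in the kernel above. It assumes arg and local_var are stored wrapped, GLOBAL is looked up as its raw value, and each comparison calls lhs.__eq__(rhs) directly. FakeConstexpr is a hypothetical stand-in for tl.constexpr.

```python
class FakeConstexpr:
    """Hypothetical stand-in for tl.constexpr."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        other = other.value if isinstance(other, FakeConstexpr) else other
        return FakeConstexpr(self.value == other)


# Annotated arguments and locals are wrapped when bound; the annotated
# captured global is returned as its raw Python value.
scope = {
    "arg": FakeConstexpr(1),
    "local_var": FakeConstexpr(1),
    "GLOBAL": 1,
}

clauses = [
    ("local_var", "arg"),
    ("local_var", "GLOBAL"),
    ("arg", "local_var"),
    ("arg", "GLOBAL"),
    ("GLOBAL", "local_var"),
    ("GLOBAL", "arg"),
]

results = {}
for lhs_name, rhs_name in clauses:
    # Direct dunder call: no reflected fallback as with the == operator.
    outcome = scope[lhs_name].__eq__(scope[rhs_name])
    ok = outcome is not NotImplemented
    results[(lhs_name, rhs_name)] = ok
    print(f"{lhs_name} == {rhs_name}: {'ok' if ok else 'NotImplemented'}")
```

In this model only the clauses with GLOBAL on the left hit int.__eq__ and produce NotImplemented, matching the two failing clauses.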

Even worse, accessing a captured global requires that it either be of type tl.constexpr or carry the annotation. However, in the latter case there is no guard against assigning to it.

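import triton
import triton.language as tl

GLOBAL: tl.constexpr = 1

@triton.jit(debug=True)
def kernel(arg: tl.constexpr):
    local_var: tl.constexpr = 1
    GLOBAL = 7  # rebinds the annotated global; no error is raised
    tl.device_assert(local_var == GLOBAL)  # runtime failure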
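One possible shape for the missing guard, as a hypothetical sketch rather than existing Triton behavior: before binding an assignment target, check whether the name would shadow a captured global annotated as constexpr. FakeConstexpr and rebinds_constexpr_global are illustrative names, and annotations are again modeled as an __annotations__ mapping inside the globals dict.

```python
class FakeConstexpr:
    """Hypothetical stand-in for tl.constexpr."""

    def __init__(self, value):
        self.value = value


def rebinds_constexpr_global(target_name, local_scope, fn_globals):
    """Return True if assigning to target_name would shadow an annotated constexpr global.

    Hypothetical check: a code generator could raise a compile error when
    this returns True instead of silently creating a new local binding.
    """
    if target_name in local_scope:
        # Already a local (e.g. a kernel argument); the assignment stays local.
        return False
    annotations = fn_globals.get("__annotations__", {})
    return annotations.get(target_name) in (FakeConstexpr, "FakeConstexpr")


module_globals = {"__annotations__": {"GLOBAL": FakeConstexpr}, "GLOBAL": 1}
print(rebinds_constexpr_global("GLOBAL", {"arg": FakeConstexpr(1)}, module_globals))
print(rebinds_constexpr_global("local_var", {}, module_globals))
```

Whether such a rebinding should be a hard error or simply create a new local is a design question; the sketch only shows that the information needed to detect it is available at assignment time.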
