Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Disallow access to globals that are not tl.constexpr. #3762

Merged
merged 1 commit into from
Apr 25, 2024

Conversation

jlebar
Copy link
Collaborator

@jlebar jlebar commented Apr 25, 2024

[Frontend] Disallow access to globals that are not tl.constexpr.

We've found that users often accidentally access global vars from inside a
Triton kernel that they don't mean to. Moreover, they sometimes change these
global vars after they first run the kernel, but Triton does not notice this
(the initial value of the global is baked into the compiled kernel). As a
result, the following runs have a stale value for the global.

We hope that making users annotate their globals with tl.constexpr will remind
them that they are kernel arguments that they should not change.

In a separate PR, we will additionally explore the possibility of raising an
error when the value of one of these globals changes.

This is a breaking change, so we allow you to opt out of it by setting
TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1. But we do not promise to support this
envvar forever.

PR chain

  1. 👉 [Frontend] Disallow access to globals that are not tl.constexpr. #3762 👈 YOU ARE HERE

@jlebar jlebar requested a review from ptillet as a code owner April 25, 2024 16:53
@jlebar jlebar marked this pull request as draft April 25, 2024 16:55
@jlebar jlebar force-pushed the dev-jlebar/disallow-non-constexpr-global branch from 3192153 to 21336e5 Compare April 25, 2024 18:41
We've found that users often accidentally access global vars from inside a
Triton kernel that they don't mean to.  Moreover, they sometimes *change* these
global vars after they first run the kernel, but Triton does not notice this
(the initial value of the global is baked into the compiled kernel). As a
result, the following runs have a stale value for the global.

We hope that making users annotate their globals with tl.constexpr will remind
them that they are kernel arguments that they should not change.

In a separate PR, we will additionally explore the possibility of raising an
error when the value of one of these globals changes.

This is a breaking change, so we allow you to opt out of it by setting
TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1.  But we do not promise to support this
envvar forever.

GPC: disallow-non-constexpr-global
@jlebar jlebar force-pushed the dev-jlebar/disallow-non-constexpr-global branch from 21336e5 to 22e975b Compare April 25, 2024 19:05
@jlebar jlebar marked this pull request as ready for review April 25, 2024 20:09
@jlebar jlebar enabled auto-merge (squash) April 25, 2024 20:10
Copy link
Collaborator

@ThomasRaoux ThomasRaoux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me! I'll give a bit of time to Phil to review as I'm not an expert of this area

@jlebar jlebar merged commit 1e994b8 into main Apr 25, 2024
5 checks passed
@jlebar jlebar deleted the dev-jlebar/disallow-non-constexpr-global branch April 25, 2024 20:32
amjames added a commit to amjames/pytorch that referenced this pull request May 16, 2024
…ser defined kernel source

[Triton pytorch#3762](triton-lang/triton#3762)
disallows access to globals which are not `tl.constexpr`

Triton has always treated captured globals this way, but they now
require it be explicit in user code.

Updated codegen to make sure these variables are defined before writing
the kernel source when compiling a user defined triton kernel.

ghstack-source-id: 5fc2a44f6d39a699d704ca1f311eb84770af8647
Pull Request resolved: pytorch#126195
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jun 6, 2024
… globals correctly. (#126195)

[Triton #3762](triton-lang/triton#3762)
disallows access to globals which are not `tl.constexpr`

Triton has always treated captured globals this way, but they now
require it be explicit in user code.

Updated codegen to make sure these variables are defined before writing
the kernel source when compiling a user defined triton kernel.

Pull Request resolved: #126195
Approved by: https://github.com/alexbaden, https://github.com/bertmaher
TharinduRusira pushed a commit to TharinduRusira/pytorch that referenced this pull request Jun 14, 2024
… globals correctly. (pytorch#126195)

[Triton pytorch#3762](triton-lang/triton#3762)
disallows access to globals which are not `tl.constexpr`

Triton has always treated captured globals this way, but they now
require it be explicit in user code.

Updated codegen to make sure these variables are defined before writing
the kernel source when compiling a user defined triton kernel.

Pull Request resolved: pytorch#126195
Approved by: https://github.com/alexbaden, https://github.com/bertmaher
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants