[FRONTEND][BACKEND] Implement `tl.device_assert` and rename `tl.printf` to `tl.device_print` #1143
Do you think we should do things the other way around to be more consistent with Python:
Also, maybe we could enable debug mode by default when the kernel contains a print?
For "assert", I think the answer is that we could do it this way. Unfortunately we don't fully support the @christopherhesse What's your suggestion? Long story short, option (1)
Yeah, we could. I haven't sorted out what features should be disabled in the debug mode though. Could we do it in another PR?
Hmmm. This may indeed be a little premature. Right now
Yes, we can do it in another PR. We just need to create a GitHub issue to make sure we don't lose track of that :)
@Jokeren I think we should go with
I'll have to defer to @ptillet here. I do think if you allow the use of the built-in
@ptillet @christopherhesse ready for review
The tests follow the pattern from triton-lang/triton#1143.
They take a fair amount of time to run (90s on my box). If we wanted to improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean up the number of lowerings we have at some point.

Generated code for `x[y]` with `x.shape == (3, 2, 4), y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```

With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4), y.ndim == (3, 3)`:

With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`):
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2")
tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`.

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

[ghstack-poisoned]
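One detail worth checking in the assertions quoted above is how the condition `0 <= tmp0 | ~xmask` groups. Under Python's grammar, `|` binds more tightly than comparison operators, so the expression reads as `0 <= (tmp0 | ~xmask)` rather than `(0 <= tmp0) | ~xmask`. The sketch below only inspects the parse tree with the standard `ast` module; it assumes the generated kernel source is parsed with Python's grammar, and `top_level_op` is a helper name made up for this illustration.

```python
import ast


def top_level_op(expr: str) -> str:
    """Return the class name of the outermost AST node of an expression."""
    return type(ast.parse(expr, mode="eval").body).__name__


# The condition as it appears in the generated code quoted above.
as_written = "0 <= tmp0 | ~xmask"
# The grouping that "in bounds, or lane is masked out" would need.
parenthesized = "(0 <= tmp0) | ~xmask"

print(top_level_op(as_written))     # Compare: the `|` is evaluated first
print(top_level_op(parenthesized))  # BinOp: the comparison is evaluated first
```

If the intent is "the index is in bounds, or the lane is masked out", the parenthesized form expresses it directly.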
…nd stores"

This PR also adds a way to CSE statements (not only assignments).

The tests follow the pattern from triton-lang/triton#1143.
They take a fair amount of time to run (90s on my box). If we wanted to improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean up the number of lowerings we have at some point...

Generated code for `x[y]` with `x.shape == (3, 2, 4), y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```

With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4), y.ndim == (3, 3)`:

With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`):
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2")
tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`.

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

Fixes #93538

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
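The constants in the static-shape kernels above (`8*tmp0`, `12*tmp2`, `tmp0 < 3`, `tmp2 < 2`) can be reproduced by hand. The following is a minimal sketch, assuming `x` is contiguous in row-major order; `contiguous_strides` and `folded_index` are hypothetical helper names for this illustration and are not part of Inductor or Triton.

```python
from math import prod


def contiguous_strides(shape):
    """Row-major strides, measured in elements."""
    return [prod(shape[i + 1:]) for i in range(len(shape))]


def folded_index(shape, n_indexed):
    """Stride multiplier and folded upper bound when one index variable
    is reused for the first `n_indexed` dimensions of `shape`."""
    strides = contiguous_strides(shape)
    multiplier = sum(strides[:n_indexed])  # the same variable scales each stride
    upper = min(shape[:n_indexed])         # the tightest bound wins
    return multiplier, upper


print(contiguous_strides((3, 2, 4)))  # [8, 4, 1]
print(folded_index((3, 2, 4), 1))     # (8, 3)   matches x[y]
print(folded_index((3, 2, 4), 2))     # (12, 2)  matches x[y+1, y+1]
```

With dynamic shapes the same arithmetic stays symbolic, which is consistent with the kernels showing `ks1*ks2*tmp0` and `min(ks2, ks1)` in place of the folded constants.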
ghstack-source-id: 5694804018dc9649217985b747849e92f97bf224
Pull Request resolved: #98590
…nd stores" This PR also adds a way to CSE statements (not only assignments). The tests follow the pattern from triton-lang/triton#1143 They take a fair amount of time to run (90s in my box). If we wanted to improve this, we could avoid testing the `ndim == 3` case. Changes like this one make me hope that we get to clean the amount of lowerings we have at some point... Generated code for `x[y]` with `x.shape == (3, 2, 4), y.ndim == 1`: With `dynamic=False`: ```python tmp0 = tl.load(in_ptr0 + (x1), xmask) tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0") tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3") tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask) ``` With `dynamic=True`: ```python tmp0 = tl.load(in_ptr0 + (x1), xmask) tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0") tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3") tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask) ``` Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4), y.ndim == (3, 3)`: With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2` ```python tmp0 = tl.load(in_ptr0 + (x1), xmask) tmp1 = 1 tmp2 = tmp0 + tmp1 tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2") tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2") tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask) ``` With `dynamic=True`: ```python tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)") ``` The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))` Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes ```python tmp0 = tl.load(in_ptr0 + (x1), xmask) tmp1 = tl.load(in_ptr1 + (x2), xmask) tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0") tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3") tl.store(out_ptr0 + (x0 + 
(ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask) ``` Fixes #93538 cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
…nd stores" This PR also adds a way to CSE statements (not only assignments). The tests follow the pattern from triton-lang/triton#1143 They take a fair amount of time to run (90s in my box). If we wanted to improve this, we could avoid testing the `ndim == 3` case. Changes like this one make me hope that we get to clean the amount of lowerings we have at some point... Generated code for `x[y]` with `x.shape == (3, 2, 4), y.ndim == 1`: With `dynamic=False`: ```python tmp0 = tl.load(in_ptr0 + (x1), xmask) tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0") tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3") tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask) ``` With `dynamic=True`: ```python tmp0 = tl.load(in_ptr0 + (x1), xmask) tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0") tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3") tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask) ``` Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4), y.ndim == (3, 3)`: With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2` ```python tmp0 = tl.load(in_ptr0 + (x1), xmask) tmp1 = 1 tmp2 = tmp0 + tmp1 tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2") tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2") tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask) ``` With `dynamic=True`: ```python tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)") ``` The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))` Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes ```python tmp0 = tl.load(in_ptr0 + (x1), xmask) tmp1 = tl.load(in_ptr1 + (x2), xmask) tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0") tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3") tl.store(out_ptr0 + (x0 + 
(ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask) ``` Fixes #93538 cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
This PR also adds a way to CSE statements (not only assignments). The tests follow the pattern from triton-lang/triton#1143 They take a fair amount of time to run (90s in my box). If we wanted to improve this, we could avoid testing the `ndim == 3` case. Changes like this one make me hope that we get to clean the amount of lowerings we have at some point... Generated code for `x[y]` with `x.shape == (3, 2, 4), y.ndim == 1`: With `dynamic=False`: ```python tmp0 = tl.load(in_ptr0 + (x1), xmask) tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0") tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3") tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask) ``` With `dynamic=True`: ```python tmp0 = tl.load(in_ptr0 + (x1), xmask) tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0") tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3") tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask) ``` Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4), y.ndim == (3, 3)`: With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2` ```python tmp0 = tl.load(in_ptr0 + (x1), xmask) tmp1 = 1 tmp2 = tmp0 + tmp1 tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2") tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2") tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask) ``` With `dynamic=True`: ```python tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)") ``` The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))` Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes ```python tmp0 = tl.load(in_ptr0 + (x1), xmask) tmp1 = tl.load(in_ptr1 + (x2), xmask) tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0") tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3") tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + 
tl.zeros([XBLOCK], tl.int32)), tmp1, xmask) ``` Fixes #93538 cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
…nd stores"

This PR also adds a way to CSE statements (not only assignments).

The tests follow the pattern from triton-lang/triton#1143. They take a fair amount of time to run (90s in my box). If we wanted to improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of lowerings we have at some point...

Generated code for `x[y]` with `x.shape == (3, 2, 4), y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```

With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4), y.ndim == (3, 3)`:

With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`):
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2")
tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`.

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

Fixes #93538

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
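The constants in the generated code above (`8*tmp0`, `12*tmp2`, and the folded bound `2`) appear consistent with row-major strides for a contiguous `(3, 2, 4)` tensor. The sketch below reproduces that arithmetic in plain Python. The helper names `contiguous_strides` and `fold_shared_index` are hypothetical and are not part of Inductor or Triton; contiguity of `x` is an assumption.

```python
from math import prod


def contiguous_strides(shape):
    """Row-major strides, in elements, for a contiguous tensor of `shape`."""
    return [prod(shape[i + 1:]) for i in range(len(shape))]


def fold_shared_index(shape, dims):
    """For one index variable used on every dimension in `dims`, return
    (combined stride coefficient, folded exclusive upper bound)."""
    strides = contiguous_strides(shape)
    coeff = sum(strides[d] for d in dims)   # e.g. 8*t + 4*t == 12*t
    bound = min(shape[d] for d in dims)     # index must be valid on every dim
    return coeff, bound


shape = (3, 2, 4)
print(contiguous_strides(shape))         # [8, 4, 1]
print(fold_shared_index(shape, [0]))     # (8, 3)   matches x[y]
print(fold_shared_index(shape, [0, 1]))  # (12, 2)  matches x[y+1, y+1]
```

With dynamic shapes the same quantities stay symbolic, which would explain the `ks1*ks2*tmp0` offset and the `min(ks2, ks1)` bound in the `dynamic=True` output.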
…nd stores"

This PR also adds a way to CSE statements (not only assignments).

The tests follow the pattern from triton-lang/triton#1143. They take a fair amount of time to run (90s in my box). If we wanted to improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of lowerings we have at some point...

Generated code for `x[y]` with `x.shape == (3, 2, 4), y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < 3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```

With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < ks3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4), y.ndim == (3, 3)`:

With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`):
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(((0 <= tmp2) & (tmp2 < 2)) | (~xmask), f"index out of bounds: 0 <= tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(((0 <= tmp2) & (tmp2 < min(ks2, ks1))) | (~xmask), f"index out of bounds: 0 <= tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`.

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < ks3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

Fixes #93538

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
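This revision parenthesizes each comparison, whereas the earlier one wrote `0 <= tmp0 | ~xmask`. In Python's grammar `|` binds more tightly than comparison operators, so the unparenthesized form groups as `0 <= (tmp0 | ~xmask)`. Whether that was the actual motivation for the change is an assumption; the sketch below only demonstrates the parsing difference with the standard `ast` module (the helper `top_level_op` is hypothetical).

```python
import ast


def top_level_op(expr):
    """Return the name of the outermost AST node of a Python expression."""
    node = ast.parse(expr, mode="eval").body
    return type(node).__name__


# Unparenthesized: the comparison is outermost, and `tmp0 | ~xmask` is its right operand.
print(top_level_op("0 <= tmp0 | ~xmask"))                     # Compare

# Parenthesized: the bitwise OR is outermost, as intended.
print(top_level_op("((0 <= tmp0) & (tmp0 < 3)) | (~xmask)"))  # BinOp
```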
The tests follow the pattern from triton-lang/triton#1143. They take a fair amount of time to run (90s in my box). If we wanted to improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of lowerings we have at some point

Generated code for `x[y]` with `x.shape == (3, 2, 4), y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```

With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4), y.ndim == (3, 3)`:

With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`):
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2")
tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`.

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

ghstack-source-id: 3c0fb0560c19e31765260ae897e9274067bfcb10
Pull Request resolved: #98590
This PR also adds a way to CSE statements (not only assignments).

The tests follow the pattern from triton-lang/triton#1143. They take a fair amount of time to run (90s in my box). If we wanted to improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of lowerings we have at some point...

Generated code for `x[y]` with `x.shape == (3, 2, 4), y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < 3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```

With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < ks3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4), y.ndim == (3, 3)`:

With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`):
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(((0 <= tmp2) & (tmp2 < 2)) | (~xmask), f"index out of bounds: 0 <= tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(((0 <= tmp2) & (tmp2 < min(ks2, ks1))) | (~xmask), f"index out of bounds: 0 <= tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`.

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < ks3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

Fixes #93538

Pull Request resolved: #98590
Approved by: https://github.com/ngimel
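The `| (~xmask)` term in these assertions reads as "in bounds, or this lane is masked out", so inactive lanes can never trip the check. The following is a plain-Python emulation of that per-lane predicate, offered only as a sketch: it uses lists in place of Triton blocks, and the function name `lanes_passing` is hypothetical.

```python
def lanes_passing(idx, size, mask):
    """Per lane: the index is in [0, size), or the lane is masked out."""
    return [((0 <= i) and (i < size)) or (not m) for i, m in zip(idx, mask)]


idx = [0, 2, 7, -1]
mask = [True, True, False, True]

# Lane 2 holds an out-of-range value but is masked out, so it passes.
# Lane 3 is active and negative, so it is the only failing lane.
print(lanes_passing(idx, 3, mask))  # [True, True, True, False]
```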
This print works as an on-device print op, which is executed by a warp of 32 threads. That is not necessary. Could we instead develop an LLVM print op for Triton IR that just prints during the tile-mapping stage and evaluates the result on the CPU for verification purposes?
…f` to `tl.device_print` (triton-lang#1143)

Note that `tl.device_print` and `print` accept different arguments than the normal `print`. The first argument must be a string, followed by variables.

Device side:
- `tl.device_print`
- `tl.device_assert`
- `print`
- `assert`

Compilation time:
- `tl.static_assert`
- `tl.static_print`

Usage example:

1.
```Python
tl.device_assert(x == 0, "x != 0")
```
Output:
```Python
...
python/test/unit/language/assert_helper.py:18: kernel: block: [0,0,0], thread: [33,0,0] Assertion `x != 0` failed.
...
```

2.
```Python
tl.device_print("hello ", x)
```
Output:
```Python
...
hello 1
...
```

The environment variable `TRITON_DEBUG` sets the default debugging flag; if it's false, `tl.device_assert` and `assert` are skipped.
…ang#1193)

Prevents `StoreOp`s from being rewritten if they are not suitable for a 2D block store.

Addresses Issue: triton-lang#1143

---------

Signed-off-by: Maxime France-Pillois <maxime.francepillois@codeplay.com>
Note that `tl.device_print` and `print` accept different arguments than the normal `print`. The first argument must be a string, followed by variables.

Device side:
- `tl.device_print`
- `tl.device_assert`
- `print`
- `assert`

Compilation time:
- `tl.static_assert`
- `tl.static_print`