Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FRONTEND][BACKEND] Implement tl.device_assert and rename tl.printf to tl.device_print #1143

Merged
merged 28 commits into from
Mar 4, 2023

Conversation

Jokeren
Copy link
Contributor

@Jokeren Jokeren commented Feb 3, 2023

Note that tl.device_print and print accepts different arguments than the normal print. The first argument must be a string, following by variables.

Device side:

  • tl.device_print
  • tl.device_assert
  • print
  • assert

Compilation time:

  • tl.static_assert
  • tl.static_print

Usage example:

tl.device_assert(x == 0, "x != 0")

Output:

...
python/test/unit/language/assert_helper.py:18: kernel: block: [0,0,0], thread: [33,0,0] Assertion `x != 0` failed.
...
tl.device_print("hello ", x)

Output:

...
hello 1
...

@ptillet
Copy link
Collaborator

ptillet commented Feb 3, 2023

Do you think we should do things the other way around to be more consistent with Python:

  • assert is evaluated at runtime; tl.static_assert is evaluated at compile-time.
  • print is evaluated at runtime; tl.static_print is evaluated at compile-time.

Also, maybe we could enable debug mode by default when there kernel contains a print?

@Jokeren
Copy link
Contributor Author

Jokeren commented Feb 3, 2023

assert is evaluated at runtime; tl.static_assert is evaluated at compile-time.
print is evaluated at runtime; tl.static_print is evaluated at compile-time.

For "assert", I think the answer is that we could do it this way.

Unfortunately we don't fully support the print semantic yet. Right now tl.device_print only accepts a str prefix with all args appended to the end. If you want to print something like print(f"msg {k}: {v}"), $k and $v will be evaluated at compile time I think.

@christopherhesse What's your suggestion? Long story short, option (1) device_assert evaluates dynamic values, assert evaluates static (i.e., compile-time) values; option (2) assert evaluates dynamic values, static_assert evaluates static values.

Also, maybe we could enable debug mode by default when there kernel contains a print?

Yeah, we could. I haven't sorted out what features should be disabled in the debug mode though. Could we do it in another PR?

@ptillet
Copy link
Collaborator

ptillet commented Feb 3, 2023

Unfortunately we don't fully support the print semantic yet.

Hmmm. This may be indeed be a little premature. Right now tl.printf is not very useful and would confuse users anyway as it prints once per CUDA thread instead of once per program.

Yeah, we could. I haven't sorted out what features should be disabled in the debug mode though. Could we do it in another PR?

Yes, we can do it in another PR. We just need to create a github issue to make sure we don't lose track of that :)

@ptillet
Copy link
Collaborator

ptillet commented Feb 3, 2023

@Jokeren I think we should go with assert being evaluated at runtime and static_assert being evaluated statically. This also matches the current semantic of static_range vs range.

@christopherhesse
Copy link
Contributor

I'll have to defer to @ptillet here. I do think if you allow the use of the built-in assert and print it could be weird if they behave differently than the normal versions of those. Here's an example from jax where they have to explain how assert only runs at trace time: jax-ml/jax#2273 (comment) (though at least that is consistent with print)

@Jokeren Jokeren marked this pull request as ready for review February 23, 2023 21:50
@Jokeren
Copy link
Contributor Author

Jokeren commented Feb 23, 2023

@ptillet @christopherhesse ready for review

@Jokeren Jokeren deleted the keren/assert branch March 4, 2023 16:08
lezcano added a commit to pytorch/pytorch that referenced this pull request Apr 7, 2023
The tests follow the pattern from triton-lang/triton#1143
They take a fair amount of time to run (90s in my box). If we wanted to
improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of
lowerings we have at some point

Generated code for `x[y]` with `x.shape == (3, 2, 4),  y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```
With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4),  y.ndim == (3, 3)`:
With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2")
tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

[ghstack-poisoned]
lezcano added a commit to pytorch/pytorch that referenced this pull request Apr 7, 2023
…nd stores"

The tests follow the pattern from triton-lang/triton#1143
They take a fair amount of time to run (90s in my box). If we wanted to
improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of
lowerings we have at some point

Generated code for `x[y]` with `x.shape == (3, 2, 4),  y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```
With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4),  y.ndim == (3, 3)`:
With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2")
tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
lezcano added a commit to pytorch/pytorch that referenced this pull request Apr 7, 2023
The tests follow the pattern from triton-lang/triton#1143
They take a fair amount of time to run (90s in my box). If we wanted to
improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of
lowerings we have at some point

Generated code for `x[y]` with `x.shape == (3, 2, 4),  y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```
With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4),  y.ndim == (3, 3)`:
With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2")
tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
lezcano added a commit to pytorch/pytorch that referenced this pull request Apr 8, 2023
…nd stores"

The tests follow the pattern from triton-lang/triton#1143
They take a fair amount of time to run (90s in my box). If we wanted to
improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of
lowerings we have at some point

Generated code for `x[y]` with `x.shape == (3, 2, 4),  y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```
With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4),  y.ndim == (3, 3)`:
With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2")
tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
lezcano added a commit to pytorch/pytorch that referenced this pull request Apr 8, 2023
The tests follow the pattern from triton-lang/triton#1143
They take a fair amount of time to run (90s in my box). If we wanted to
improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of
lowerings we have at some point

Generated code for `x[y]` with `x.shape == (3, 2, 4),  y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```
With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4),  y.ndim == (3, 3)`:
With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2")
tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
lezcano added a commit to pytorch/pytorch that referenced this pull request Apr 13, 2023
…nd stores"


This PR also adds a way to CSE statements (not only assignments).

The tests follow the pattern from triton-lang/triton#1143
They take a fair amount of time to run (90s in my box). If we wanted to
improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of
lowerings we have at some point...

Generated code for `x[y]` with `x.shape == (3, 2, 4),  y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```
With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4),  y.ndim == (3, 3)`:
With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2")
tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

Fixes #93538

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
lezcano added a commit to pytorch/pytorch that referenced this pull request Apr 13, 2023
This PR also adds a way to CSE statements (not only assignments).

The tests follow the pattern from triton-lang/triton#1143
They take a fair amount of time to run (90s in my box). If we wanted to
improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of
lowerings we have at some point...

Generated code for `x[y]` with `x.shape == (3, 2, 4),  y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```
With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4),  y.ndim == (3, 3)`:
With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2")
tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

Fixes #93538

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
lezcano added a commit to pytorch/pytorch that referenced this pull request Apr 13, 2023
…nd stores"


This PR also adds a way to CSE statements (not only assignments).

The tests follow the pattern from triton-lang/triton#1143
They take a fair amount of time to run (90s in my box). If we wanted to
improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of
lowerings we have at some point...

Generated code for `x[y]` with `x.shape == (3, 2, 4),  y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```
With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4),  y.ndim == (3, 3)`:
With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2")
tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

Fixes #93538

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
lezcano added a commit to pytorch/pytorch that referenced this pull request Apr 13, 2023
This PR also adds a way to CSE statements (not only assignments).

The tests follow the pattern from triton-lang/triton#1143
They take a fair amount of time to run (90s in my box). If we wanted to
improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of
lowerings we have at some point...

Generated code for `x[y]` with `x.shape == (3, 2, 4),  y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```
With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4),  y.ndim == (3, 3)`:
With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2")
tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

Fixes #93538

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
lezcano added a commit to pytorch/pytorch that referenced this pull request Apr 13, 2023
The tests follow the pattern from triton-lang/triton#1143
They take a fair amount of time to run (90s in my box). If we wanted to
improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of
lowerings we have at some point

Generated code for `x[y]` with `x.shape == (3, 2, 4),  y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```
With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4),  y.ndim == (3, 3)`:
With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2")
tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

ghstack-source-id: 5694804018dc9649217985b747849e92f97bf224
Pull Request resolved: #98590
lezcano added a commit to pytorch/pytorch that referenced this pull request Apr 13, 2023
The tests follow the pattern from triton-lang/triton#1143
They take a fair amount of time to run (90s in my box). If we wanted to
improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of
lowerings we have at some point

Generated code for `x[y]` with `x.shape == (3, 2, 4),  y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```
With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4),  y.ndim == (3, 3)`:
With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2")
tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

ghstack-source-id: 5694804018dc9649217985b747849e92f97bf224
Pull Request resolved: #98590
lezcano added a commit to pytorch/pytorch that referenced this pull request Apr 13, 2023
…nd stores"


This PR also adds a way to CSE statements (not only assignments).

The tests follow the pattern from triton-lang/triton#1143
They take a fair amount of time to run (90s in my box). If we wanted to
improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of
lowerings we have at some point...

Generated code for `x[y]` with `x.shape == (3, 2, 4),  y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```
With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4),  y.ndim == (3, 3)`:
With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2")
tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

Fixes #93538

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
lezcano added a commit to pytorch/pytorch that referenced this pull request Apr 13, 2023
…nd stores"


This PR also adds a way to CSE statements (not only assignments).

The tests follow the pattern from triton-lang/triton#1143
They take a fair amount of time to run (90s in my box). If we wanted to
improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of
lowerings we have at some point...

Generated code for `x[y]` with `x.shape == (3, 2, 4),  y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```
With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4),  y.ndim == (3, 3)`:
With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2")
tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

Fixes #93538

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
lezcano added a commit to pytorch/pytorch that referenced this pull request Apr 13, 2023
This PR also adds a way to CSE statements (not only assignments).

The tests follow the pattern from triton-lang/triton#1143
They take a fair amount of time to run (90s in my box). If we wanted to
improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of
lowerings we have at some point...

Generated code for `x[y]` with `x.shape == (3, 2, 4),  y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```
With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4),  y.ndim == (3, 3)`:
With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2")
tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

Fixes #93538

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
lezcano added a commit to pytorch/pytorch that referenced this pull request Apr 13, 2023
…nd stores"


This PR also adds a way to CSE statements (not only assignments).

The tests follow the pattern from triton-lang/triton#1143
They take a fair amount of time to run (90s in my box). If we wanted to
improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of
lowerings we have at some point...

Generated code for `x[y]` with `x.shape == (3, 2, 4),  y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```
With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4),  y.ndim == (3, 3)`:
With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2")
tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

Fixes #93538

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
lezcano added a commit to pytorch/pytorch that referenced this pull request Apr 13, 2023
This PR also adds a way to CSE statements (not only assignments).

The tests follow the pattern from triton-lang/triton#1143
They take a fair amount of time to run (90s in my box). If we wanted to
improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of
lowerings we have at some point...

Generated code for `x[y]` with `x.shape == (3, 2, 4),  y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```
With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4),  y.ndim == (3, 3)`:
With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2")
tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

Fixes #93538

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
lezcano added a commit to pytorch/pytorch that referenced this pull request Apr 13, 2023
…nd stores"


This PR also adds a way to CSE statements (not only assignments).

The tests follow the pattern from triton-lang/triton#1143
They take a fair amount of time to run (90s in my box). If we wanted to
improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of
lowerings we have at some point...

Generated code for `x[y]` with `x.shape == (3, 2, 4),  y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < 3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```
With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < ks3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4),  y.ndim == (3, 3)`:
With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(((0 <= tmp2) & (tmp2 < 2)) | (~xmask), f"index out of bounds: 0 <= tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(((0 <= tmp2) & (tmp2 < min(ks2, k1))) | (~xmask), f"index out of bounds: 0 <= tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < ks3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

Fixes #93538

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
lezcano added a commit to pytorch/pytorch that referenced this pull request Apr 13, 2023
This PR also adds a way to CSE statements (not only assignments).

The tests follow the pattern from triton-lang/triton#1143
They take a fair amount of time to run (90s in my box). If we wanted to
improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of
lowerings we have at some point...

Generated code for `x[y]` with `x.shape == (3, 2, 4),  y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < 3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```
With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < ks3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4),  y.ndim == (3, 3)`:
With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(((0 <= tmp2) & (tmp2 < 2)) | (~xmask), f"index out of bounds: 0 <= tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(((0 <= tmp2) & (tmp2 < min(ks2, k1))) | (~xmask), f"index out of bounds: 0 <= tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < ks3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

Fixes #93538

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
lezcano added a commit to pytorch/pytorch that referenced this pull request Apr 19, 2023
…nd stores"


This PR also adds a way to CSE statements (not only assignments).

The tests follow the pattern from triton-lang/triton#1143
They take a fair amount of time to run (90s in my box). If we wanted to
improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of
lowerings we have at some point...

Generated code for `x[y]` with `x.shape == (3, 2, 4),  y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < 3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```
With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < ks3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4),  y.ndim == (3, 3)`:
With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(((0 <= tmp2) & (tmp2 < 2)) | (~xmask), f"index out of bounds: 0 <= tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(((0 <= tmp2) & (tmp2 < min(ks2, k1))) | (~xmask), f"index out of bounds: 0 <= tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < ks3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

Fixes #93538

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
lezcano added a commit to pytorch/pytorch that referenced this pull request Apr 19, 2023
This PR also adds a way to CSE statements (not only assignments).

The tests follow the pattern from triton-lang/triton#1143
They take a fair amount of time to run (90s in my box). If we wanted to
improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of
lowerings we have at some point...

Generated code for `x[y]` with `x.shape == (3, 2, 4),  y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < 3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```
With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < ks3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4),  y.ndim == (3, 3)`:
With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(((0 <= tmp2) & (tmp2 < 2)) | (~xmask), f"index out of bounds: 0 <= tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(((0 <= tmp2) & (tmp2 < min(ks2, k1))) | (~xmask), f"index out of bounds: 0 <= tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < ks3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

Fixes #93538

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
lezcano added a commit to pytorch/pytorch that referenced this pull request Apr 19, 2023
The tests follow the pattern from triton-lang/triton#1143
They take a fair amount of time to run (90s in my box). If we wanted to
improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of
lowerings we have at some point

Generated code for `x[y]` with `x.shape == (3, 2, 4),  y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < 3 | ~xmask, f"index out of bounds: tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```
With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4),  y.ndim == (3, 3)`:
With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(0 <= tmp2 | ~xmask, f"index out of bounds: 0 <= tmp2")
tl.device_assert(tmp2 < 2 | ~xmask, f"index out of bounds: tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(tmp2 < min(ks2, ks1) | ~xmask, f"index out of bounds: tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(0 <= tmp0 | ~xmask, f"index out of bounds: 0 <= tmp0")
tl.device_assert(tmp0 < ks3 | ~xmask, f"index out of bounds: tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

ghstack-source-id: 3c0fb0560c19e31765260ae897e9274067bfcb10
Pull Request resolved: #98590
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Apr 19, 2023
This PR also adds a way to CSE statements (not only assignments).

The tests follow the pattern from triton-lang/triton#1143
They take a fair amount of time to run (90s in my box). If we wanted to
improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean the amount of
lowerings we have at some point...

Generated code for `x[y]` with `x.shape == (3, 2, 4),  y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < 3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```
With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < ks3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4),  y.ndim == (3, 3)`:
With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(((0 <= tmp2) & (tmp2 < 2)) | (~xmask), f"index out of bounds: 0 <= tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(((0 <= tmp2) & (tmp2 < min(ks2, k1))) | (~xmask), f"index out of bounds: 0 <= tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`

Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < ks3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

Fixes #93538

Pull Request resolved: #98590
Approved by: https://github.com/ngimel
@yiakwy-xpu-ml-framework-team
Copy link

yiakwy-xpu-ml-framework-team commented Mar 21, 2024

I

I'll have to defer to @ptillet here. I do think if you allow the use of the built-in assert and print it could be weird if they behave differently than the normal versions of those. Here's an example from jax where they have to explain how assert only runs at trace time: google/jax#2273 (comment) (though at least that is consistent with print)

This print works as on device print op which will be executed by a warps of 32 threads. This is not necessary.

Could we instead develop a LLVM printop for TritonIR: just print during tiles mapping stage and evaluate the result with CPU for verification purpose ?

pingzhuu pushed a commit to siliconflow/triton that referenced this pull request Apr 2, 2024
…f` to `tl.device_print` (triton-lang#1143)

Note that `tl.device_print` and `print` accepts different arguments than
the normal `print`. The first argument must be a string, following by
variables.

Device side:

- `tl.device_print`
- `tl.device_assert`
- `print`
- `assert`

Compilation time:

- `tl.static_assert`
- `tl.static_print`

Usage example:

1.
```Python
tl.device_assert(x == 0, "x != 0")
```

Output:

```Python
...
python/test/unit/language/assert_helper.py:18: kernel: block: [0,0,0], thread: [33,0,0] Assertion `x != 0` failed.
...
```

2.
```Python
tl.device_print("hello ", x)
```

Output:

```Python
...
hello 1
...
```

The environment variable `TRITON_DEBUG` sets the default debugging flag; if it's true, `tl.device_assert` or `assert` will be skipped.
ZzEeKkAa pushed a commit to ZzEeKkAa/triton that referenced this pull request Aug 5, 2024
…ang#1193)

Prevents `StoreOp`s from being rewritten if it is not suitable for 2D
block store.

Addresses Issue: triton-lang#1143

---------

Signed-off-by: Maxime France-Pillois <maxime.francepillois@codeplay.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants