Skip to content

Conversation

ColinPeppler
Copy link
Contributor

@ColinPeppler ColinPeppler commented May 16, 2025

PR

There are a few cases that my previous PR (#153220) didn't cover.

  1. The LHS/RHS matters. Today, if you do torch._check(lhs == rhs) then it will show up as a deferred runtime assert with Eq(lhs, rhs).
  2. There can be transitive replacements. For example, expr1 -> expr2 -> u0. test_size_with_unbacked_add_expr_transitive tests for this.
  3. An unbacked symint expr may not have a replacement that's purely a symbol, for instance, it could be another expression. test_size_with_unbacked_add_and_mul_expr tests for this.

Device assertion msg

/tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [4,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed.
...
/tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [6,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed.

Autotuning code setup

This is the autotuning code for a concat kernel which takes input tensors (in_buf) and writes them to the (out_buf).

It's important to note the size of in_buf0 is the same as in_buf1 don't match along dim=0. This is bad because all concat inputs must share the same size for each dim except for the concat dim (here that's dim=1).

in_buf0 = generate_example_value(size=(u1 + s0, 256))   # concrete size is (17900, 256)
in_buf1 = generate_example_value(size=(u0, 10))         # concrete size is (8192, 10)
...
out_buf = generate_example_value(size=(u1 + s0, 266))   # concrete size is (17900, 256+10)
triton_poi_fused_cat_1.run(in_buf0, in_buf1, ..., out_buf, xnumel=(u1 + s0) * 266 ...)

If we look into the kernel code, you'll see that tmp9 loads in_buf1 (our incorrectly shaped input tensor). There is also a mask to prevent OOB loads.

  • tmp6 makes sure we're only loading with the xindex from 256 to 264.
  • xmask makes sure we're only loading with the xindex within xnumel.
  • tmp6 & xmask together is essentially checking 0 ≤ x0 < u1 + s0 and 256 ≤ x1 < 264.

The mask logic is correct, however, in_buf1 has the shape [8192, 10] this means any load where 8192 ≤ x0 < u1 + s0 will be an OOB load.

def triton_poi_fused_cat_1(in_buf0, in_buf1, ... out_buf, xnumel, XBLOCK):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel
    x0 = (xindex % 264)
    x1 = xindex // 264
    ...
    tmp6 = x0 >= tl.full([1], value=256)
    tmp9 = tl.load(in_buf1 + (x1), tmp6 & xmask)
    # device assertion is thrown here
    tl.device_assert(((0 <= tl.broadcast_to(tmp13, [XBLOCK])) & (tl.broadcast_to(tmp13, [XBLOCK]) < ks0)) | ~(xmask & tmp6), "index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0")

Stack from ghstack (oldest at bottom):

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

Copy link

pytorch-bot bot commented May 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153768

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit 18dd7c3 with merge base 8ac82a1 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

…ally_apply_size_hint"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
ColinPeppler added a commit that referenced this pull request May 16, 2025
…size_hint

ghstack-source-id: f4396c4
Pull Request resolved: #153768
@jingsh
Copy link
Contributor

jingsh commented May 17, 2025

can you provide some context on the issue?

…ally_apply_size_hint"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
…ally_apply_size_hint"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
continue

lhs, rhs = assertion.expr.lhs, assertion.expr.rhs
l2r = lhs.compare(rhs) == 1 # see sympy.Basic.compare
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want lhs.compare(rhs) == 1 to make sure expressions are on the LHS and symbols are on the RHS. If both sides are expressions, then there's tie breakers listed below.

https://github.com/sympy/sympy/blob/2e7baea39cd0d891433bdc2ef2222e445c381ca3/sympy/core/basic.py#L364-L370

Comment on lines +670 to +675
existing_replacement = self.unbacked_replacements.get(src, None)
if existing_replacement and isinstance(
existing_replacement, sympy.Symbol
):
# Prefer to keep replacements with symbols.
continue
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we see torch._check(expr1, expr2) and torch._check(expr2, symbol), then make sure to prioritize the replacement expr2: symbol over expr2: expr1.

Comment on lines 687 to 693
def _sub_unbacked_exprs(expr: Expr) -> Expr:
replacements = self._get_unbacked_replacements()
while True:
new_expr = expr.subs(replacements)
if new_expr == expr:
return new_expr
expr = sympy.factor(new_expr)
Copy link
Contributor Author

@ColinPeppler ColinPeppler May 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

iteratively make replacements until the expression doesn't change.

I tried using _xreplace but it can only replacesympy objects. It's more pickier, so I went with subs.

https://github.com/sympy/sympy/blob/2e7baea39cd0d891433bdc2ef2222e445c381ca3/sympy/core/basic.py#L1322-L1323

@ColinPeppler
Copy link
Contributor Author

ColinPeppler commented May 20, 2025

inductor-rocm / rocm-py3.10-inductor / test (inductor, 1, 2, linux.rocm.gpu.2)

keeps failing on MI210 due to "Code: 9, Messsage: invalid configuration argument" which is probably related to the kernel launch parameters.

scaling down the unbacked symint fallback might resolve this

…ally_apply_size_hint"

## PR
There are a few cases that my previous PR (#153220) didn't cover.
1. The LHS/RHS matters. Today, if you do `torch._check(lhs == rhs)` then it will show up as a deferred runtime assert with `Eq(lhs, rhs)`.
2. There can be transitive replacements. For example, expr1 -> expr2 -> u0. `test_size_with_unbacked_add_expr_transitive` tests for this.
3. An unbacked symint expr may not have a replacement that's purely a symbol, for instance, it could be another expression. `test_size_with_unbacked_add_and_mul_expr` tests for this.

## Device assertion msg

```
/tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [4,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed.
...
/tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [6,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed.
```

## Autotuning code setup
This is the autotuning code for a concat kernel which takes input tensors (`in_buf`) and writes them to the (`out_buf`).

It's important to note the size of `in_buf0` is the same as `in_buf1` don't match along dim=0. This is bad because all concat inputs must share the same size for each dim except for the concat dim (here that's dim=1).
```
in_buf0 = generate_example_value(size=(u1 + s0, 256))   # concrete size is (17900, 256)
in_buf1 = generate_example_value(size=(u0, 10))         # concrete size is (8192, 10)
...
out_buf = generate_example_value(size=(u1 + s0, 266))   # concrete size is (17900, 256+10)
triton_poi_fused_cat_1.run(in_buf0, in_buf1, ..., out_buf, xnumel=(u1 + s0) * 266 ...)
```

If we look into the kernel code, you'll see that `tmp9` loads `in_buf1` (our incorrectly shaped input tensor). There is also a mask to prevent OOB loads.
- `tmp6`  makes sure we're only loading with the `xindex` from 256 to 264.
- `xmask` makes sure we're only loading with the `xindex` within `xnumel`.
- `tmp6 & xmask` together is essentially checking `0 ≤ x0 < u1 + s0` and `256 ≤ x1 < 264`.

The mask logic is correct, however, `in_buf1` has the shape `[8192, 10]` this means any load where `8192 ≤ x0 < u1 + s0` will be an OOB load.
```
def triton_poi_fused_cat_1(in_buf0, in_buf1, ... out_buf, xnumel, XBLOCK):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel
    x0 = (xindex % 264)
    x1 = xindex // 264
    ...
    tmp6 = x0 >= tl.full([1], value=256)
    tmp9 = tl.load(in_buf1 + (x1), tmp6 & xmask)
    # device assertion is thrown here
    tl.device_assert(((0 <= tl.broadcast_to(tmp13, [XBLOCK])) & (tl.broadcast_to(tmp13, [XBLOCK]) < ks0)) | ~(xmask & tmp6), "index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0")
```




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov 

[ghstack-poisoned]
@laithsakka laithsakka requested review from eellison and embg May 20, 2025 04:26
@ColinPeppler ColinPeppler requested a review from chenyang78 May 20, 2025 18:02

def _sub_unbacked_exprs(expr: Expr) -> Expr:
replacements = self._get_unbacked_replacements()
while True:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we sure this will always converge?

Copy link
Contributor Author

@ColinPeppler ColinPeppler May 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. Yes, this will always converge.

The only scenario where this wouldn't converge if there was a cycle in the unbacked_replacements. We can guarantee there's never going to be a cycle due to lhs.compare(rhs).

Suppose, there was a cycle.

a -> b
b -> c
c -> a
  • lhs.compare(rhs) == 1 is the same as lhs > rhs
  • If a -> b then a > b.
  • If b -> c then b > c.
  • If c -> a then c > a.
  • Put it all together, we have a > b > c > a which is impossible.
  • Therefore, there is no cycle if we adhere to lhs.compare(rhs).

expr = sympy.factor(expr).subs(unbacked_replacements)
if has_free_unbacked_symbols(expr):

def _sub_unbacked_exprs(expr: Expr) -> Expr:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function could be expensive. We should cache it.

…ally_apply_size_hint"

## PR
There are a few cases that my previous PR (#153220) didn't cover.
1. The LHS/RHS matters. Today, if you do `torch._check(lhs == rhs)` then it will show up as a deferred runtime assert with `Eq(lhs, rhs)`.
2. There can be transitive replacements. For example, expr1 -> expr2 -> u0. `test_size_with_unbacked_add_expr_transitive` tests for this.
3. An unbacked symint expr may not have a replacement that's purely a symbol, for instance, it could be another expression. `test_size_with_unbacked_add_and_mul_expr` tests for this.

## Device assertion msg

```
/tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [4,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed.
...
/tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [6,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed.
```

## Autotuning code setup
This is the autotuning code for a concat kernel which takes input tensors (`in_buf`) and writes them to the (`out_buf`).

It's important to note the size of `in_buf0` is the same as `in_buf1` don't match along dim=0. This is bad because all concat inputs must share the same size for each dim except for the concat dim (here that's dim=1).
```
in_buf0 = generate_example_value(size=(u1 + s0, 256))   # concrete size is (17900, 256)
in_buf1 = generate_example_value(size=(u0, 10))         # concrete size is (8192, 10)
...
out_buf = generate_example_value(size=(u1 + s0, 266))   # concrete size is (17900, 256+10)
triton_poi_fused_cat_1.run(in_buf0, in_buf1, ..., out_buf, xnumel=(u1 + s0) * 266 ...)
```

If we look into the kernel code, you'll see that `tmp9` loads `in_buf1` (our incorrectly shaped input tensor). There is also a mask to prevent OOB loads.
- `tmp6`  makes sure we're only loading with the `xindex` from 256 to 264.
- `xmask` makes sure we're only loading with the `xindex` within `xnumel`.
- `tmp6 & xmask` together is essentially checking `0 ≤ x0 < u1 + s0` and `256 ≤ x1 < 264`.

The mask logic is correct, however, `in_buf1` has the shape `[8192, 10]` this means any load where `8192 ≤ x0 < u1 + s0` will be an OOB load.
```
def triton_poi_fused_cat_1(in_buf0, in_buf1, ... out_buf, xnumel, XBLOCK):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel
    x0 = (xindex % 264)
    x1 = xindex // 264
    ...
    tmp6 = x0 >= tl.full([1], value=256)
    tmp9 = tl.load(in_buf1 + (x1), tmp6 & xmask)
    # device assertion is thrown here
    tl.device_assert(((0 <= tl.broadcast_to(tmp13, [XBLOCK])) & (tl.broadcast_to(tmp13, [XBLOCK]) < ks0)) | ~(xmask & tmp6), "index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0")
```




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov 

[ghstack-poisoned]
…ally_apply_size_hint"

## PR
There are a few cases that my previous PR (#153220) didn't cover.
1. The LHS/RHS matters. Today, if you do `torch._check(lhs == rhs)` then it will show up as a deferred runtime assert with `Eq(lhs, rhs)`.
2. There can be transitive replacements. For example, expr1 -> expr2 -> u0. `test_size_with_unbacked_add_expr_transitive` tests for this.
3. An unbacked symint expr may not have a replacement that's purely a symbol, for instance, it could be another expression. `test_size_with_unbacked_add_and_mul_expr` tests for this.

## Device assertion msg

```
/tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [4,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed.
...
/tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [6,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed.
```

## Autotuning code setup
This is the autotuning code for a concat kernel which takes input tensors (`in_buf`) and writes them to the (`out_buf`).

It's important to note the size of `in_buf0` is the same as `in_buf1` don't match along dim=0. This is bad because all concat inputs must share the same size for each dim except for the concat dim (here that's dim=1).
```
in_buf0 = generate_example_value(size=(u1 + s0, 256))   # concrete size is (17900, 256)
in_buf1 = generate_example_value(size=(u0, 10))         # concrete size is (8192, 10)
...
out_buf = generate_example_value(size=(u1 + s0, 266))   # concrete size is (17900, 256+10)
triton_poi_fused_cat_1.run(in_buf0, in_buf1, ..., out_buf, xnumel=(u1 + s0) * 266 ...)
```

If we look into the kernel code, you'll see that `tmp9` loads `in_buf1` (our incorrectly shaped input tensor). There is also a mask to prevent OOB loads.
- `tmp6`  makes sure we're only loading with the `xindex` from 256 to 264.
- `xmask` makes sure we're only loading with the `xindex` within `xnumel`.
- `tmp6 & xmask` together is essentially checking `0 ≤ x0 < u1 + s0` and `256 ≤ x1 < 264`.

The mask logic is correct, however, `in_buf1` has the shape `[8192, 10]` this means any load where `8192 ≤ x0 < u1 + s0` will be an OOB load.
```
def triton_poi_fused_cat_1(in_buf0, in_buf1, ... out_buf, xnumel, XBLOCK):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel
    x0 = (xindex % 264)
    x1 = xindex // 264
    ...
    tmp6 = x0 >= tl.full([1], value=256)
    tmp9 = tl.load(in_buf1 + (x1), tmp6 & xmask)
    # device assertion is thrown here
    tl.device_assert(((0 <= tl.broadcast_to(tmp13, [XBLOCK])) & (tl.broadcast_to(tmp13, [XBLOCK]) < ks0)) | ~(xmask & tmp6), "index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0")
```




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov 

[ghstack-poisoned]
ColinPeppler added a commit that referenced this pull request May 20, 2025
…size_hint

ghstack-source-id: 7d1f669
Pull Request resolved: #153768
@ColinPeppler ColinPeppler requested a review from desertfire May 21, 2025 14:57
@ColinPeppler
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 21, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@github-actions github-actions bot deleted the gh/ColinPeppler/70/head branch June 21, 2025 02:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants