[aoti] fix corner case in unbacked replacements for atomically_apply_size_hint #153768
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153768
Note: Links to docs will display an error until the docs builds have been completed.
⏳ No Failures, 1 Pending, as of commit 18dd7c3 with merge base 8ac82a1.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Can you provide some context on the issue?
```python
lhs, rhs = assertion.expr.lhs, assertion.expr.rhs
l2r = lhs.compare(rhs) == 1  # see sympy.Basic.compare
```
We want `lhs.compare(rhs) == 1` to make sure expressions are on the LHS and symbols are on the RHS. If both sides are expressions, then there are tie breakers listed below.
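For illustration only, here is a minimal standalone sympy sketch (not the PR's actual code) of how a deferred runtime assert `Eq(lhs, rhs)` could be oriented into a replacement via `sympy.Basic.compare`; the symbols `u0`, `u1`, `s0` and the orientation snippet are made up for the example.

```python
import sympy

# Hypothetical symbols standing in for unbacked/backed symints.
u0, u1, s0 = sympy.symbols("u0 u1 s0", integer=True, positive=True)

# torch._check(lhs == rhs) surfaces as a deferred runtime assert Eq(lhs, rhs).
assertion_expr = sympy.Eq(u1 + s0, u0)

lhs, rhs = assertion_expr.lhs, assertion_expr.rhs
# Basic.compare is a deterministic structural ordering (not a mathematical one);
# orient the replacement so the "greater" side becomes the key. Per the review
# comment above, this tends to put expressions on the LHS and symbols on the RHS.
if lhs.compare(rhs) == 1:
    replacement = {lhs: rhs}
else:
    replacement = {rhs: lhs}
print(replacement)
```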
```python
existing_replacement = self.unbacked_replacements.get(src, None)
if existing_replacement and isinstance(
    existing_replacement, sympy.Symbol
):
    # Prefer to keep replacements with symbols.
    continue
```
If we see `torch._check(expr1 == expr2)` and `torch._check(expr2 == symbol)`, then make sure to prioritize the replacement `expr2: symbol` over `expr2: expr1`.
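As a hedged illustration of that priority, a tiny standalone sketch; the `add_replacement` helper and the symbols are hypothetical, while the real bookkeeping uses `self.unbacked_replacements` in `torch/_inductor/sizevars.py`.

```python
import sympy

u0, u1, s0 = sympy.symbols("u0 u1 s0", integer=True, positive=True)
unbacked_replacements: dict = {}

def add_replacement(src, dst) -> None:
    # Keep an existing symbol-valued replacement: a plain symbol target is
    # strictly more useful than yet another expression target.
    existing = unbacked_replacements.get(src)
    if existing is not None and isinstance(existing, sympy.Symbol):
        return
    unbacked_replacements[src] = dst

add_replacement(u1 + s0, u0)           # from torch._check(expr2 == symbol)
add_replacement(u1 + s0, 2 * u1 - s0)  # from torch._check(expr2 == expr1): ignored
print(unbacked_replacements)           # expr2 still maps to the symbol u0
```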
torch/_inductor/sizevars.py (outdated)
```python
def _sub_unbacked_exprs(expr: Expr) -> Expr:
    replacements = self._get_unbacked_replacements()
    while True:
        new_expr = expr.subs(replacements)
        if new_expr == expr:
            return new_expr
        expr = sympy.factor(new_expr)
```
Iteratively make replacements until the expression doesn't change. I tried using `_xreplace`, but it can only replace sympy objects and is pickier, so I went with `subs`.
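To make the convergence behavior concrete, here is a small self-contained sketch (made-up symbols and replacement chain, not the PR's data) of transitive replacements expr1 -> expr2 -> u0 being folded down by repeated `subs` until the expression stops changing:

```python
import sympy

u0, u1, s0 = sympy.symbols("u0 u1 s0", integer=True, positive=True)

# Transitive chain: expr1 -> expr2 -> u0
replacements = {
    2 * u1 + s0: u1 + s0,  # expr1 -> expr2
    u1 + s0: u0,           # expr2 -> u0
}

def sub_until_fixed_point(expr: sympy.Expr) -> sympy.Expr:
    # A single subs() pass may only rewrite expr1 into expr2; keep substituting
    # until nothing changes so the chain bottoms out at the symbol u0.
    while True:
        new_expr = expr.subs(replacements)
        if new_expr == expr:
            return new_expr
        expr = sympy.factor(new_expr)

print(sub_until_fixed_point(2 * u1 + s0))  # expected to reduce to u0
```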
inductor-rocm / rocm-py3.10-inductor / test (inductor, 1, 2, linux.rocm.gpu.2) keeps failing on MI210 due to "Code: 9, Messsage: invalid configuration argument", which is probably related to the kernel launch parameters. Scaling down the unbacked symint fallback might resolve this.
torch/_inductor/sizevars.py (outdated)
```python
def _sub_unbacked_exprs(expr: Expr) -> Expr:
    replacements = self._get_unbacked_replacements()
    while True:
```
Are we sure this will always converge?
Good question. Yes, this will always converge.

The only scenario where this wouldn't converge is if there was a cycle in the `unbacked_replacements`. We can guarantee there's never going to be a cycle due to `lhs.compare(rhs)`.

Suppose there was a cycle:

- a -> b
- b -> c
- c -> a

`lhs.compare(rhs) == 1` is the same as `lhs > rhs` under sympy's ordering.

- If a -> b, then a > b.
- If b -> c, then b > c.
- If c -> a, then c > a.
- Put it all together and we have a > b > c > a, which is impossible.
- Therefore, there is no cycle if we adhere to `lhs.compare(rhs)`.
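For readers who want to sanity-check that invariant outside the compiler, here is a hypothetical helper (not part of the PR) that walks replacement chains and reports whether they cycle; with the compare()-based orientation argued above, it should always report False:

```python
import sympy

def has_cycle(replacements: dict) -> bool:
    # Follow src -> dst chains; if we ever revisit a node, the mapping cycles
    # and fixed-point substitution could loop forever.
    for start in replacements:
        seen = {start}
        cur = replacements[start]
        while cur in replacements:
            if cur in seen:
                return True
            seen.add(cur)
            cur = replacements[cur]
    return False

u0, u1, s0 = sympy.symbols("u0 u1 s0", integer=True, positive=True)
acyclic = {2 * u1 + s0: u1 + s0, u1 + s0: u0}
cyclic = {u0: u1, u1: u0}
print(has_cycle(acyclic), has_cycle(cyclic))  # False True
```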
torch/_inductor/sizevars.py (outdated)
```python
expr = sympy.factor(expr).subs(unbacked_replacements)
if has_free_unbacked_symbols(expr):
    ...

def _sub_unbacked_exprs(expr: Expr) -> Expr:
```
This function could be expensive. We should cache it.
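One lightweight way to act on this, sketched here with `functools.lru_cache` (the PR may cache differently); sympy expressions are hashable, so caching on the expression works as long as the replacement table doesn't change between calls:

```python
import functools
import sympy

u0, u1, s0 = sympy.symbols("u0 u1 s0", integer=True, positive=True)
replacements = {2 * u1 + s0: u1 + s0, u1 + s0: u0}

@functools.lru_cache(maxsize=None)
def sub_unbacked_exprs_cached(expr: sympy.Expr) -> sympy.Expr:
    # Same fixed-point substitution as above, but each distinct expression
    # is only resolved once per process.
    while True:
        new_expr = expr.subs(replacements)
        if new_expr == expr:
            return new_expr
        expr = sympy.factor(new_expr)

print(sub_unbacked_exprs_cached(2 * u1 + s0))  # u0 (computed)
print(sub_unbacked_exprs_cached(2 * u1 + s0))  # u0 (served from the cache)
```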
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
## PR

There are a few cases that my previous PR (#153220) didn't cover.

1. The LHS/RHS matters. Today, if you do `torch._check(lhs == rhs)`, then it will show up as a deferred runtime assert with `Eq(lhs, rhs)`.
2. There can be transitive replacements, for example expr1 -> expr2 -> u0. `test_size_with_unbacked_add_expr_transitive` tests for this.
3. An unbacked symint expr may not have a replacement that's purely a symbol; for instance, it could be another expression. `test_size_with_unbacked_add_and_mul_expr` tests for this.

## Device assertion msg

```
/tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [4,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed.
...
/tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [6,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed.
```

## Autotuning code setup

This is the autotuning code for a concat kernel which takes input tensors (`in_buf`) and writes them to the output buffer (`out_buf`). It's important to note that the sizes of `in_buf0` and `in_buf1` don't match along dim=0. This is bad because all concat inputs must share the same size for each dim except for the concat dim (here that's dim=1).

```
in_buf0 = generate_example_value(size=(u1 + s0, 256))  # concrete size is (17900, 256)
in_buf1 = generate_example_value(size=(u0, 10))        # concrete size is (8192, 10)
...
out_buf = generate_example_value(size=(u1 + s0, 266))  # concrete size is (17900, 256+10)

triton_poi_fused_cat_1.run(in_buf0, in_buf1, ..., out_buf, xnumel=(u1 + s0) * 266 ...)
```

If we look into the kernel code, you'll see that `tmp9` loads `in_buf1` (our incorrectly shaped input tensor). There is also a mask to prevent OOB loads.

- `tmp6` makes sure we're only loading with the `xindex` from 256 to 264.
- `xmask` makes sure we're only loading with the `xindex` within `xnumel`.
- `tmp6 & xmask` together is essentially checking `0 ≤ x0 < u1 + s0` and `256 ≤ x1 < 264`.

The mask logic is correct; however, `in_buf1` has the shape `[8192, 10]`, which means any load where `8192 ≤ x0 < u1 + s0` will be an OOB load.

```
def triton_poi_fused_cat_1(in_buf0, in_buf1, ... out_buf, xnumel, XBLOCK):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel
    x0 = (xindex % 264)
    x1 = xindex // 264
    ...
    tmp6 = x0 >= tl.full([1], value=256)
    tmp9 = tl.load(in_buf1 + (x1), tmp6 & xmask)  # device assertion is thrown here
    tl.device_assert(((0 <= tl.broadcast_to(tmp13, [XBLOCK])) & (tl.broadcast_to(tmp13, [XBLOCK]) < ks0)) | ~(xmask & tmp6), "index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0")
```

Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov