[Inductor] add contiguous layout optm for bmm input #122599
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/122599. Note: links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit 13c5c18 with merge base 03a05e7. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Force-pushed from f222739 to 466338b.
torch/_inductor/kernel/bmm.py
Outdated
if (
    isinstance(t.data, ir.View)
    and isinstance(t.data.data, ir.PermuteView)
    and t.data.data.dims == [0, 3, 1, 2]
):
    t = ir.Pointwise.create(
        device=t.get_device(),
        dtype=t.get_dtype(),
        inner_fn=t.make_loader(),
        ranges=t.get_size(),
        origin_node=t.get_origin_node(),
        traceback=t.get_traceback(),
    )
    t.realize()
    t.freeze_layout()
Several issues with this change:
1. It only handles a specific case where the tensor is 4D and permuted with a particular order. Can we make it general? Basically, we want a particular order of the last two dims?
2. Related to 1, bmm can actually handle non-contiguous cases and also transposed cases for the last two dims. It only requires one of the dims to be contiguous while the other can have a stride larger than the size of the former dim (see the sketch below). Is it too strict to always force contiguous here?
3. Perhaps we can call `require_stride_order` instead of implementing the `copy_input` logic again here?
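As a minimal sketch of the condition point 2 describes, with hypothetical names (not the actual Inductor API):

```python
def last_two_dims_bmm_friendly(sizes, strides):
    # bmm can consume the operand directly when one of the last two dims is
    # contiguous (stride 1) and the other dim's stride is at least the size
    # of that contiguous dim, i.e. plain or transposed row-major.
    m, n = sizes[-2], sizes[-1]
    sm, sn = strides[-2], strides[-1]
    return (sn == 1 and sm >= n) or (sm == 1 and sn >= m)
```

Only operands failing a check like this would actually need a copy.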
Thanks for your suggestions!
A new function called `is_pointwise_contiguous_or_transposed_after_perm` is added to deduce the layout of a pointwise node from its readers, since it has not been realized yet.
Force-pushed from 6300552 to 9e4b415.
Force-pushed from 8d8509e to c3fd623.
@leslie-fang-intel @jgong5 Hi, I made some code changes. Please review again :)
Force-pushed from e853e82 to 758f0b3.
test/inductor/test_torchinductor.py
Outdated
self.assertEqual(out_expected, out_actual)
self.assertEqual(out_expected.stride(), out_actual.stride())
For this particular case, does the test pass with or without the PR change?
Yes, it passes in both cases; this is just to check the accuracy with the PR. I will remove it.
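For reference, a standalone sketch of what this pair of assertions checks, with illustrative shapes and a plain `torch.compile` call instead of the test harness in `test_torchinductor.py`:

```python
import torch

def fn(a, b):
    return torch.bmm(a, b)

a = torch.randn(2, 8, 16)
b = torch.randn(2, 16, 4)

out_expected = fn(a, b)
out_actual = torch.compile(fn)(a, b)

# Compare both the values and the output strides of eager vs. compiled.
torch.testing.assert_close(out_expected, out_actual)
assert out_expected.stride() == out_actual.stride()
```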
torch/_inductor/kernel/bmm.py
Outdated
return t


if all(x.get_device().type == "cpu" for x in [mat1, mat2]):
    if not ir.is_storage_and_layout(mat1):
How about when `mat1` or `mat2` has a flexible layout? Shall we also apply similar logic here?
Thanks and modified!
torch/_inductor/kernel/bmm.py
Outdated
# Make the inputs of bmm contiguous
# because bmm cpu implementation does contiguous() if not
# this is to avoid additional copies in bmm
Is this comment accurate? If the input is transposed in the last two dims, would bmm still make the inputs contiguous?
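For context, eager `bmm` does accept an operand transposed in the last two dims without an explicit `contiguous()` call; whether the CPU kernel copies internally is the question raised here. A quick illustration (shapes are arbitrary):

```python
import torch

a = torch.randn(4, 16, 8)
b = torch.randn(4, 16, 8)

# a.transpose(1, 2) is non-contiguous (transposed in the last two dims),
# yet bmm consumes it directly and matches the explicitly contiguous result.
out = torch.bmm(a.transpose(1, 2), b)
ref = torch.bmm(a.transpose(1, 2).contiguous(), b)
torch.testing.assert_close(out, ref)
```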
Thanks and modified!
torch/_inductor/ir.py
Outdated
)


def is_contiguous_or_transposed(sizes, strides):
Transposition can happen in any pair of dims, not necessarily the last two dims. Also, you are checking `stride >= size`, not `stride == size`. The function name doesn't seem to match the implementation here. Perhaps it is clearer to just inline the implementation in the `tuned_bmm` code without factoring out a util function here.
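A small illustration of the `stride >= size` vs. `stride == size` distinction, using a plain tensor slice rather than Inductor IR:

```python
import torch

x = torch.randn(2, 8, 8)
y = x[:, :, :4]
# The last dim stays contiguous (stride 1), but the second-to-last stride (8)
# is larger than the last-dim size (4): `stride >= size` holds where
# `stride == size` would not, yet bmm needs no copy to consume this layout.
print(y.size(), y.stride())  # torch.Size([2, 8, 4]) (64, 8, 1)
```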
Thanks and modified!
torch/_inductor/kernel/bmm.py
Outdated
# Make the inputs of bmm contiguous
# because bmm cpu implementation does contiguous() if not
# this is to avoid additional copies in bmm
def do_bmm_input_contiguous(t, meta_t):
How about `may_require_contiguous`?
Thanks and applied!
torch/_inductor/kernel/bmm.py
Outdated
if not ir.is_storage_and_layout(t):
    return True
_, layout = ir.as_storage_and_layout(t, freeze=False)
return not isinstance(layout, ir.FixedLayout)
Why not check `ir.FlexibleLayout` directly?
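A minimal sketch of the direct check being suggested, reusing the `ir` helpers already called in this diff (the function name is hypothetical):

```python
from torch._inductor import ir

def has_flexible_layout(t):
    # Nodes without realized storage/layout are treated as flexible here;
    # otherwise check the layout class directly instead of negating FixedLayout.
    if not ir.is_storage_and_layout(t):
        return True
    _, layout = ir.as_storage_and_layout(t, freeze=False)
    return isinstance(layout, ir.FlexibleLayout)
```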
Modified to check `ir.FlexibleLayout`.
Force-pushed from bbd71a7 to 13c5c18.
@eellison Hi, please help review the PR. Thanks!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes #117743.

Add contiguous layout optimization for `bmm` input, to avoid additional copies.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @chauhang
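For reference, a hedged repro sketch of the kind of pattern this change targets: a permuted 4-D tensor flattened into a `bmm` operand. The shapes and graph below are illustrative, not taken from the linked issue.

```python
import torch

def fn(x, w):
    # (B, H, W, C) -> permute to (B, C, H, W) -> flatten to (B, C, H*W),
    # producing a bmm input whose layout depends on how the permute is lowered.
    y = x.permute(0, 3, 1, 2).reshape(x.size(0), x.size(3), -1)
    return torch.bmm(y, w)

x = torch.randn(2, 4, 4, 8)
w = torch.randn(2, 16, 3)

out_eager = fn(x, w)
out_compiled = torch.compile(fn)(x, w)
torch.testing.assert_close(out_eager, out_compiled)
```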