-
Notifications
You must be signed in to change notification settings - Fork 25.6k
[inductor] modify the output_stride of ConcatKernel #122761
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/122761
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure)As of commit 0877fe9 with merge base 19f5033 ( UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
[ghstack-poisoned]
Fix #121613. Modify the `output_stride` of `ConcatKernel`: If any input to `Concat` is `Pointwise`, check the layout of all inputs to `Pointwise`, if any of the inputs is in channels_last format, set channels_last strides for the `output_stride`. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
Fix #121613. Modify the `output_stride` of `ConcatKernel`: If any input to `Concat` is `Pointwise`, check the layout of all inputs to `Pointwise`, if any of the inputs is in channels_last format, set channels_last strides for the `output_stride`. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
Fix #121613. Modify the `output_stride` of `ConcatKernel`: If any input to `Concat` is `Pointwise`, check the layout of all inputs to `Pointwise`, if any of the inputs is in channels_last format, set channels_last strides for the `output_stride`. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
Fix #121613. Modify the `output_stride` of `ConcatKernel`: If any input to `Concat` is `Pointwise`, check the layout of all inputs to `Pointwise`, if any of the inputs is in channels_last format, set channels_last strides for the `output_stride`. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
torch/_inductor/ir.py
Outdated
|
||
output_stride = FlexibleLayout.contiguous_strides(new_size) | ||
# If any of the inputs is in CL format, use CL format for the output | ||
any_input_is_storage_and_layout = False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
any_input_is_storage_and_layout = any(is_storage_and_layout(x) for x in inputs)
looks clearer?
torch/_inductor/ir.py
Outdated
if isinstance(x, TensorBox): | ||
return is_pointwise_with_channels_last_inputs(x.data) | ||
if isinstance(x, StorageBox) and isinstance(x.data, Pointwise): | ||
all_reads = x.data.get_reads() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar as in #122760, we can get the buffer from read.name
. Will it be easy to check the buffer's layout?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do I understand correctly that we can share some logic from #122760 too? Both need to get layouts of the input buffers.
torch/_inductor/ir.py
Outdated
if any_input_is_storage_and_layout is False: | ||
if all( | ||
is_pointwise_with_channels_last_inputs(input) for input in inputs | ||
) and len(new_size) in [4, 5]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is_pointwise_with_channels_last_inputs
should already cover this check? Why check it again here?
torch/_inductor/ir.py
Outdated
return False | ||
|
||
return len(all_reads) >= 1 and all( | ||
type(read) is dependencies.MemoryDep and is_channel_last_index(read) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
type(read) is dependencies.MemoryDep and is_channel_last_index(read) | |
isinstance(read, dependencies.MemoryDep) and is_channels_last_index(read) |
torch/_inductor/ir.py
Outdated
output_stride = make_channels_last_strides_for(new_size) | ||
break | ||
if any_input_is_storage_and_layout is False: | ||
if all( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
explain why we requires "all" if none of the inputs is storage_and_layout instead of "any" on any input is storage_and_layout.
torch/_inductor/ir.py
Outdated
if isinstance(x, TensorBox): | ||
return is_pointwise_with_channels_last_inputs(x.data) | ||
if isinstance(x, StorageBox) and isinstance(x.data, Pointwise): | ||
all_reads = x.data.get_reads() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do I understand correctly that we can share some logic from #122760 too? Both need to get layouts of the input buffers.
Fix #121613. Modify the `output_stride` of `ConcatKernel`: If any input to `Concat` is `Pointwise`, check the layout of all inputs to `Pointwise`, if any of the inputs is in channels_last format, set channels_last strides for the `output_stride`. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
Fix #121613. Modify the `output_stride` of `ConcatKernel`: If any input to `Concat` is `Pointwise`, check the layout of all inputs to `Pointwise`, if any of the inputs is in channels_last format, set channels_last strides for the `output_stride`. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
Fix #121613. Modify the `output_stride` of `ConcatKernel`: If any input to `Concat` is `Pointwise`, check the layout of all inputs to `Pointwise`, if any of the inputs is in channels_last format, set channels_last strides for the `output_stride`. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
Fix #121613. Modify the `output_stride` of `ConcatKernel`: If any input to `Concat` is `Pointwise`, check the layout of all inputs to `Pointwise`, if any of the inputs is in channels_last format, set channels_last strides for the `output_stride`. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
LGTM, just a few more comments. |
fx_node_args = [ | ||
fx_node_args, | ||
] | ||
if any_input_is_storage_and_layout is False and any( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggest to add a note here explaining the heuristics we use here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added, thanks!
Fix #121613. Modify the `output_stride` of `ConcatKernel`: If any input to `Concat` is `Pointwise`, check the layout of all inputs to `Pointwise`, if any of the inputs is in channels_last format, set channels_last strides for the `output_stride`. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
Fix #121613. Modify the `output_stride` of `ConcatKernel`: If any input to `Concat` is `Pointwise`, check the layout of all inputs to `Pointwise`, if any of the inputs is in channels_last format, set channels_last strides for the `output_stride`. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
Fix #121613. Modify the `output_stride` of `ConcatKernel`: If any input to `Concat` is `Pointwise`, check the layout of all inputs to `Pointwise`, if any of the inputs is in channels_last format, set channels_last strides for the `output_stride`. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
@jansel , could you please review this PR? Thanks. |
@pytorchbot merge |
Merge failedReason: This PR needs a If not, please add the To add a label, you can comment to pytorchbot, for example For more information, see Details for Dev Infra teamRaised by workflow job |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Fix pytorch#121613. Modify the `output_stride` of `ConcatKernel`: If any input to `Concat` is `Pointwise`, check the layout of all inputs to `Pointwise`, if any of the inputs is in channels_last format, set channels_last strides for the `output_stride`. Pull Request resolved: pytorch#122761 Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
Fix pytorch#121613. Modify the `output_stride` of `ConcatKernel`: If any input to `Concat` is `Pointwise`, check the layout of all inputs to `Pointwise`, if any of the inputs is in channels_last format, set channels_last strides for the `output_stride`. Pull Request resolved: pytorch#122761 Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
Stack from ghstack (oldest at bottom):
Fix #121613.
Modify the
output_stride
ofConcatKernel
: If any input toConcat
isPointwise
, check the layout of all inputs toPointwise
, if any of the inputs is in channels_last format, set channels_last strides for theoutput_stride
.cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @chauhang