[Quant] Support lowering of channel shuffle in FX #83731
Conversation
Dr. CI: ✅ No failures (0 pending) as of commit 464c7fd.
Force-pushed 051843b to 845a466
import torch

class ChannelShuffle(torch.nn.ChannelShuffle):
    def __init__(self, groups: int):
Is this constructor needed? Since it just passes the arguments through, the inherited base-class constructor would be sufficient. And maybe the same question about the forward impl.
Thanks. They are removed.
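For reference, a minimal sketch of the reviewer's point (the class name mirrors the PR; the example values are illustrative only): a subclass that defines no `__init__` or `forward` of its own inherits both from the base class, so a pass-through constructor is redundant.

```python
import torch

# A subclass with no __init__ or forward of its own inherits both
# from torch.nn.ChannelShuffle, so a pass-through constructor is redundant.
class ChannelShuffle(torch.nn.ChannelShuffle):
    pass

m = ChannelShuffle(2)  # uses the inherited torch.nn.ChannelShuffle.__init__
x = torch.arange(4.0).reshape(1, 4, 1, 1)
y = m(x)               # uses the inherited forward
print(y.flatten().tolist())  # → [0.0, 2.0, 1.0, 3.0]
```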
Force-pushed 7cbcfc8 to 22533a5
@vadimkantorov Please review again. Thanks.
@Xia-Weiwen I'm not a member of the PyTorch team, so a review from someone else is required :)
I see. Thank you anyway.
Hi @jerryzh168 Could you please review this PR? Or could you ask someone else to review? Thanks!
torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d,
torch.nn.ChannelShuffle,
why do we have two channel shuffle here?
Sorry for the typo. Now they are removed.
@@ -0,0 +1,16 @@
import torch

class ChannelShuffle(torch.nn.ChannelShuffle):
if the float version works for quantized inputs as well, I don't think we will need to define an extra quantized module
it should be a part of this list in that case I think: https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/_lower_to_native_backend.py#L122
OK. I have removed them.
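To illustrate the reviewer's rationale (a sketch, assuming integer inputs stand in for quantized data): channel shuffle only permutes data, so the float op already handles non-float tensors and no dedicated quantized module is required.

```python
import torch

# Channel shuffle only rearranges channels; it never does arithmetic on the
# values, so the same op works for integer (and quantized) tensors.
x = torch.arange(8, dtype=torch.int8).reshape(1, 8, 1, 1)
y = torch.channel_shuffle(x, 2)
print(y.flatten().tolist())  # groups interleaved: [0, 4, 1, 5, 2, 6, 3, 7]
```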
@@ -353,6 +353,7 @@ def _get_default_op_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPat
torch.nn.LayerNorm,
torch.nn.Dropout,
torch.nn.PReLU,
torch.nn.ChannelShuffle,
Should add this to _get_share_qparams_op_backend_config, I think.
I have added channel shuffle to copy node list. Is this still needed?
Channel shuffle should be categorized as a copy node, I think, so we don't need to add new quantized modules and configs. The lowering needs to change as well; please see inline comments for changes.
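A small illustration of the "copy node" idea (a sketch, not the PR's code; the qparam values are arbitrary): since the op never changes values, only their positions, its output can share the input's quantization parameters, so no separate observer or quantized module config is needed.

```python
import torch

# A "copy node" preserves values, so the output keeps the same scale and
# zero_point as its quantized input and no new observer is needed.
x = torch.randn(1, 4, 2, 2)
xq = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)
yq = torch.channel_shuffle(xq, 2)
print(yq.q_scale(), yq.q_zero_point())  # same qparams as the input
```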
Force-pushed 21caf8c to 8371853
Thanks for reviewing. I have removed most of the previous changes. Now I only add it to the is_copy_node list.
Force-pushed ec6c344 to 874872f
models = [M1().eval(), M2().eval(), M3().eval()]
# torch.channel_shuffle is torch.nn.functional.channel_shuffle
expected_nodes = [
    ns.call_module(torch.nn.ChannelShuffle),
    ns.call_function(torch.channel_shuffle),
    ns.call_function(torch.channel_shuffle)
]
nit: might be better to just write a list of tuples of (model, expected_node) so that it's clearer
OK. It's done. Please take a look.
Looks good, please add the expected channel shuffle op to node_occurrence as well.
Force-pushed 2fc90cc to 0bae9de
Hi @jerryzh168 I made the changes per your comments. All checks have passed. Do you have more comments? Thanks.
Force-pushed 0bae9de to a6a6e30
Force-pushed a6a6e30 to 2a92764
Hi @jerryzh168 Do you think it's ok to land this?
LG, please add a comment in torch/ao/quantization/fx/_lower_to_native_backend.py about why we only need torch.channel_shuffle and not F.channel_shuffle.
Thanks. I will add that comment and land it after all CI checks pass.
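The reason only one spelling needs matching is noted in the test code above: the functional alias is the same object. A quick check of that identity:

```python
import torch
import torch.nn.functional as F

# F.channel_shuffle is torch.channel_shuffle re-exported (with a docstring
# added), so matching torch.channel_shuffle in the lowering covers both.
print(F.channel_shuffle is torch.channel_shuffle)  # → True
```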
…tQuantizeFx.test_channel_shuffle_lowering
Force-pushed 2a92764 to 9f66a0f
…wer_to_native_backend.py
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
## Description

Support lowering of channel shuffle in FX by adding its module and functional op to the `is_copy_node` list in `torch/ao/quantization/fx/_lower_to_native_backend.py`.

## Validation

UTs added to test:
- correctness of the quantized `ChannelShuffle` module.
- FX lowering of the `ChannelShuffle` module and functional `channel_shuffle`.

Pull Request resolved: pytorch#83731
Approved by: https://github.com/jerryzh168
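A quick check (not the PR's unit test; shapes and groups chosen arbitrarily) that the module and the functional form compute the same result, which is why both lower the same way:

```python
import torch

# The module form simply calls the functional op, so both produce
# identical outputs and can share one lowering path.
x = torch.randn(1, 6, 3, 3)
m = torch.nn.ChannelShuffle(3)
out_module = m(x)
out_functional = torch.channel_shuffle(x, 3)
print(torch.equal(out_module, out_functional))  # → True
```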
cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @leslie-fang-intel @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @ezyang @SherlockNoMad @soumith @EikanWang @wenzhe-nrv