-
Notifications
You must be signed in to change notification settings - Fork 685
support qnn mean (dim=None) #14675
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
support qnn mean (dim=None) #14675
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/14675
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
self.lower_module_and_test_output(module, sample_input) | ||
|
||
def test_qnn_backend_mean(self): | ||
modules = [Mean(), Mean()] # noqa: F405 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you need more configurations here? like Mean(dim=0)
, Mean(dim=0, keepdim=True)
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated
Summary: Address mean op lower failure. When dim is not specified, it will take mean across all axes. For QNN, we need to get axes based on input shape Differential Revision: D83520776
4e25030
to
6948323
Compare
Summary: Address mean op lower failure. When dim is not specified, it will take mean across all axes. For QNN, we need to get axes based on input shape Differential Revision: D83520776
6948323
to
c50ec69
Compare
Summary: Address mean op lower failure. When dim is not specified, it will take mean across all axes. For QNN, we need to get axes based on input shape Differential Revision: D83520776
c50ec69
to
2b6dc7f
Compare
Does it look good? I think it can fix 5 failing op tests |
2b6dc7f
to
b39d9fb
Compare
Summary: Address mean op lower failure. When dim is not specified, it will take mean across all axes. For QNN, we need to get axes based on input shape Differential Revision: D83520776
# Scalar case | ||
{ | ||
QCOM_MODULE: Mean(), | ||
QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should it be torch.tensor(5.0) if you want to test scalar?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, good catch
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually this is indeed a good test. ReduceMean doesn't support 0d tensor, I think we need to have a pass to convert 0d to 1d tensor if the input of mean is 0d. I have some sketch code here. What do you think? For now, I comment out the test case and think we can follow up on this
import torch
from executorch.exir.pass_base import ExportPass, PassResult
class Rank0ToRank1(ExportPass):
"""
For selected ops and selected input positions, if the input is rank-0 (scalar),
insert a reshape to [1] before the op.
"""
def __init__(self, op_input_map=None) -> None:
super().__init__()
# key is the op, value is the input indices to be reshaped
self.op_input_map = {
torch.ops.aten.mean.dim: [0],
}
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph = graph_module.graph
changed = False
for node in list(graph.nodes):
if node.op == "call_function" and node.target in self.op_input_map:
input_indices = self.op_input_map[node.target]
new_args = list(node.args)
for idx in input_indices:
if idx < len(new_args):
inp = new_args[idx]
if hasattr(inp, "meta"):
val = inp.meta.get("val", None)
if val is not None and hasattr(val, "shape") and val.shape == ():
# Insert reshape right before the op
with graph.inserting_before(node):
reshape_node = graph.call_function(
torch.ops.aten.reshape.default,
args=(inp, (1,))
)
reshape_node.meta["val"] = val.reshape(1,)
# Replace arg idx with reshape_node
new_args[idx] = reshape_node
changed = True
# update node args if modified
node.args = tuple(new_args)
if changed:
graph_module.recompile()
return PassResult(graph_module, changed)
b39d9fb
to
cd343ee
Compare
Summary: Address mean op lower failure. When dim is not specified, it will take mean across all axes. For QNN, we need to get axes based on input shape Differential Revision: D83520776
Summary: Address mean op lower failure. When dim is not specified, it will take mean across all axes. For QNN, we need to get axes based on input shape Differential Revision: D83520776
cd343ee
to
70f1009
Compare
70f1009
to
255c5c9
Compare
Summary: Address mean op lower failure. When dim is not specified, it will take mean across all axes. For QNN, we need to get axes based on input shape Differential Revision: D83520776
Summary: Address mean op lower failure. When dim is not specified, it will take mean across all axes. For QNN, we need to get axes based on input shape Differential Revision: D83520776
255c5c9
to
1ffcbd9
Compare
Summary: Address mean op lower failure. When dim is not specified, it will take mean across all axes. For QNN, we need to get axes based on input shape Differential Revision: D83520776
1ffcbd9
to
0a200df
Compare
@pytorchbot cherry-pick --onto release/1.0 -c regression |
Summary: Address mean op lower failure. When dim is not specified, it will take mean across all axes. For QNN, we need to get axes based on input shape Differential Revision: D83520776 (cherry picked from commit 9ab5592)
Cherry picking #14675The cherry pick PR is at #14755 and it is recommended to link a regression cherry pick PR with an issue. The following tracker issues are updated: Details for Dev Infra teamRaised by workflow job |
Summary: Address mean op lower failure. When dim is not specified, it will take mean across all axes. For QNN, we need to get axes based on input shape
Differential Revision: D83520776