3 changes: 2 additions & 1 deletion backends/qualcomm/_passes/__init__.py
@@ -33,6 +33,7 @@
from .i64_to_i32 import I64toI32
from .insert_io_qdq import InsertIOQDQ
from .insert_requantize import InsertRequantize
from .insert_reshape_for_reduce_ops import InsertReshapeForReduceOps
from .layout_transform import LayoutTransform
from .lift_constant_scalar_operands import LiftConstantScalarOperands
from .recompose_pixel_unshuffle import RecomposePixelUnshuffle
@@ -45,7 +46,6 @@
from .seq_mse import SeqMSE
from .tag_quant_io import TagQuantIO


__all__ = [
    AnnotateAdaptiveAvgPool1D,
    AnnotateQuantAttrs,
@@ -75,6 +75,7 @@
    FuseConsecutiveTranspose,
    I64toI32,
    InsertIOQDQ,
    InsertReshapeForReduceOps,
    InsertRequantize,
    LayoutTransform,
    LiftConstantScalarOperands,
59 changes: 59 additions & 0 deletions backends/qualcomm/_passes/insert_reshape_for_reduce_ops.py
@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass


class InsertReshapeForReduceOps(ExportPass):
    """
    Rewrite `aten.argmax.default` and `aten.argmin.default` with `dim=None`
    into a reshape-to-1D followed by the same reduce op with `dim=0`.

    PyTorch semantics:
        torch.argmax(x, dim=None) -> flatten(x), then argmax along axis 0
    QNN requires an explicit axis, so we insert the reshape.
    """

    def __init__(self):
        super().__init__()
        self.op_map = {torch.ops.aten.argmax.default, torch.ops.aten.argmin.default}

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        modified = False

        for n in graph.nodes:
            if n.target in self.op_map:
                dim_arg = None if len(n.args) == 1 else n.args[1]

                if dim_arg is None:
                    inp = n.args[0]

                    # Insert reshape before argmax/argmin
                    with graph.inserting_before(n):
                        reshape_node = graph.create_node(
                            "call_function",
                            torch.ops.aten.reshape.default,
                            (inp, [-1]),
                            {},
                        )
                        reshape_node.meta = dict(inp.meta)
                        if "val" in inp.meta:
                            reshape_node.meta["val"] = inp.meta["val"].reshape(-1)

                    # Rewrite the reduce op: take reshape_node as input, set dim=0
                    n.args = (reshape_node, 0, *n.args[2:])

                    modified = True

        if modified:
            graph_module.recompile()
            dead_code_elimination_pass(graph_module)

        return PassResult(graph_module, modified)
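The pass rests on a PyTorch identity: an argmax/argmin with `dim=None` equals the same reduce with `dim=0` over the flattened input. A quick standalone sanity check of that assumption in plain PyTorch (not part of the PR):

import torch

x = torch.randn(16, 3, 4, 4)
# dim=None flattens the input and returns an index into the flattened view,
# which reshape(-1) followed by a dim=0 reduce reproduces exactly.
assert torch.equal(torch.argmax(x), torch.argmax(x.reshape(-1), dim=0))
assert torch.equal(torch.argmin(x), torch.argmin(x.reshape(-1), dim=0))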
3 changes: 3 additions & 0 deletions backends/qualcomm/_passes/qnn_pass_manager.py
@@ -38,6 +38,7 @@
    I64toI32,
    InsertIOQDQ,
    InsertRequantize,
    InsertReshapeForReduceOps,
    LayoutTransform,
    LiftConstantScalarOperands,
    RecomposePixelUnshuffle,
@@ -209,6 +210,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
        self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
        self.add_pass(ReplaceInfValues())
        self.add_pass(LiftConstantScalarOperands())
        self.add_pass(InsertReshapeForReduceOps())
        return self._transform(graph_module)

    def transform_for_export_pipeline(
@@ -229,6 +231,7 @@ def transform_for_export_pipeline(
        self.add_pass(ConvertLinearToConv2d(exported_program))
        self.add_pass(ConvertSquareToPow())
        self.add_pass(LiftConstantScalarOperands())
        self.add_pass(InsertReshapeForReduceOps())
        self._transform(exported_program.graph_module)
        ep = lift_constant_tensor_pass(exported_program)
        return ep
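For orientation, a minimal sketch of invoking the annotation pipeline on an exported module — assuming QnnPassManager can be constructed with no arguments, which the diff suggests but does not show:

import torch
from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager

class ArgmaxModule(torch.nn.Module):
    def forward(self, x):
        return torch.argmax(x)  # dim=None, so InsertReshapeForReduceOps rewrites it

ep = torch.export.export(ArgmaxModule(), (torch.randn(4, 4),))
# Hypothetical invocation; the exact QnnPassManager construction may differ.
QnnPassManager().transform_for_annotation_pipeline(ep.graph_module)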
1 change: 1 addition & 0 deletions backends/qualcomm/partition/common_defs.py
@@ -17,6 +17,7 @@
to_be_implemented_operator = [
    exir_ops.edge.aten._adaptive_avg_pool3d.default,
    exir_ops.edge.aten.adaptive_max_pool2d.default,
    exir_ops.edge.aten.adaptive_max_pool3d.default,
Contributor: Is this PR just about argmax? I see maxpool added here. If it is about reduce ops, then please update the title.

Contributor Author: Yeah, updated.

    exir_ops.edge.aten.avg_pool3d.default,
    exir_ops.edge.aten.div.Tensor_mode,
    exir_ops.edge.aten.log10.default,
14 changes: 14 additions & 0 deletions backends/qualcomm/tests/TARGETS
@@ -47,3 +47,17 @@ runtime.python_library(
        ":test_qnn_delegate"
    ]
)

runtime.python_test(
    name = "test_passes",
    srcs = [
        "test_passes.py",
    ],
    deps = [
        "fbsource//third-party/pypi/expecttest:expecttest", # @manual
        "//caffe2:torch",
        "//executorch/exir:lib",
        "//executorch/backends/qualcomm/_passes:passes",
        "//executorch/backends/qualcomm/builders:builders",
    ],
)
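Assuming a standard Buck2 setup in this tree, the new target should then presumably be runnable with `buck2 test //executorch/backends/qualcomm/tests:test_passes`.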
14 changes: 8 additions & 6 deletions backends/qualcomm/tests/models.py
@@ -171,21 +171,23 @@ def forward(self, y):


class Argmax(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, dim: Optional[int] = None, keepdim: bool = False):
        super().__init__()
+        self.dim = dim
+        self.keepdim = keepdim

    def forward(self, x):
-        x = torch.argmax(x, dim=0, keepdim=True)
-        return x
+        return torch.argmax(x, dim=self.dim, keepdim=self.keepdim)


class Argmin(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, dim: Optional[int] = None, keepdim: bool = False):
        super().__init__()
+        self.dim = dim
+        self.keepdim = keepdim

    def forward(self, x):
-        x = torch.argmin(x, dim=0, keepdim=True)
-        return x
+        return torch.argmin(x, dim=self.dim, keepdim=self.keepdim)


class ArgminViewSqueezeConv2D(torch.nn.Module):
54 changes: 54 additions & 0 deletions backends/qualcomm/tests/test_passes.py
@@ -0,0 +1,54 @@
import unittest

import torch
from executorch.backends.qualcomm._passes import InsertReshapeForReduceOps


class TestPasses(unittest.TestCase):
Collaborator: Thank you for adding this file, it's helpful for us to test all the passes thoroughly.

    def test_insert_reshape_for_argmax(self):
        class ArgmaxModule(torch.nn.Module):
            def forward(self, x):
                return torch.argmax(x, dim=None)

        mod = ArgmaxModule()

        x = torch.tensor([[1.0, 5.0], [3.0, 2.0]])
        ep = torch.export.export(mod, (x,))
        # Run original module for reference
        ref = mod(x)

        reshape_nodes = [
            n for n in ep.graph.nodes if n.target == torch.ops.aten.reshape.default
        ]
        argmax_nodes = [
            n for n in ep.graph.nodes if n.target == torch.ops.aten.argmax.default
        ]
        self.assertTrue(len(reshape_nodes) == 0, "Unexpected reshape node before pass")
        self.assertTrue(len(argmax_nodes) == 1, "Argmax node missing")

        InsertReshapeForReduceOps()(ep.graph_module)

        # Check graph structure: argmax should take a reshape as input
        reshape_nodes = [
            n for n in ep.graph.nodes if n.target == torch.ops.aten.reshape.default
        ]
        argmax_nodes = [
            n for n in ep.graph.nodes if n.target == torch.ops.aten.argmax.default
        ]
        self.assertTrue(len(reshape_nodes) == 1, "Reshape node should be inserted")
        self.assertTrue(len(argmax_nodes) == 1, "Argmax node missing")

        argmax_node = argmax_nodes[0]
        self.assertEqual(argmax_node.args[1], 0, "Argmax dim not set to 0")

        # Execute new graph and compare with reference
        out = ep.graph_module(x)
        self.assertTrue(
            torch.equal(*out, ref), f"Output mismatch: got {out}, expected {ref}"
        )


if __name__ == "__main__":
    unittest.main()
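The pass's op_map also covers aten.argmin.default, which this file does not exercise; a companion test inside TestPasses could mirror the argmax case — a sketch only, reusing the same structure:

    def test_insert_reshape_for_argmin(self):
        class ArgminModule(torch.nn.Module):
            def forward(self, x):
                return torch.argmin(x, dim=None)

        mod = ArgminModule()
        x = torch.tensor([[1.0, 5.0], [3.0, 2.0]])
        ep = torch.export.export(mod, (x,))
        ref = mod(x)

        InsertReshapeForReduceOps()(ep.graph_module)

        argmin_nodes = [
            n for n in ep.graph.nodes if n.target == torch.ops.aten.argmin.default
        ]
        self.assertTrue(len(argmin_nodes) == 1, "Argmin node missing")
        self.assertEqual(argmin_nodes[0].args[1], 0, "Argmin dim not set to 0")

        out = ep.graph_module(x)
        self.assertTrue(
            torch.equal(*out, ref), f"Output mismatch: got {out}, expected {ref}"
        )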
128 changes: 114 additions & 14 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -173,14 +173,64 @@ def test_qnn_backend_arange(self):
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_argmax(self):
-        module = Argmax() # noqa: F405
-        sample_input = (torch.randn(16, 3, 4, 4),)
-        self.lower_module_and_test_output(module, sample_input)
+        test_cases = [
+            {
+                QCOM_MODULE: Argmax(), # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
+            },
+            {
+                QCOM_MODULE: Argmax(dim=0, keepdim=True), # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
+            },
+            {
+                QCOM_MODULE: Argmax(dim=1, keepdim=False), # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(8, 5),),
+            },
+            {
+                QCOM_MODULE: Argmax(dim=None, keepdim=False), # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),),
+            },
+            {
+                QCOM_MODULE: Argmax(dim=2, keepdim=True), # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4),),
+            },
+        ]
+
+        for i, case in enumerate(test_cases):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(
+                    case[QCOM_MODULE], case[QCOM_SAMPLE_INPUTS]
+                )

    def test_qnn_backend_argmin(self):
-        module = Argmin() # noqa: F405
-        sample_input = (torch.rand(3, 4),)
-        self.lower_module_and_test_output(module, sample_input)
+        test_cases = [
+            {
+                QCOM_MODULE: Argmin(), # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
+            },
+            {
+                QCOM_MODULE: Argmin(dim=0, keepdim=True), # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
+            },
+            {
+                QCOM_MODULE: Argmin(dim=1, keepdim=False), # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(8, 5),),
+            },
+            {
+                QCOM_MODULE: Argmin(dim=None, keepdim=False), # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),),
+            },
+            {
+                QCOM_MODULE: Argmin(dim=2, keepdim=True), # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4),),
+            },
+        ]
+
+        for i, case in enumerate(test_cases):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(
+                    case[QCOM_MODULE], case[QCOM_SAMPLE_INPUTS]
+                )

    @unittest.expectedFailure
    def test_qnn_backend_asin(self):
@@ -1797,16 +1847,66 @@ def test_qnn_backend_arange(self):
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_argmax(self):
-        module = Argmax() # noqa: F405
-        sample_input = (torch.randn(16, 3, 4, 4),)
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
+        test_cases = [
+            {
+                QCOM_MODULE: Argmax(), # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
+            },
+            {
+                QCOM_MODULE: Argmax(dim=0, keepdim=True), # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
+            },
+            {
+                QCOM_MODULE: Argmax(dim=1, keepdim=False), # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(8, 5),),
+            },
+            {
+                QCOM_MODULE: Argmax(dim=None, keepdim=False), # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),),
+            },
+            {
+                QCOM_MODULE: Argmax(dim=2, keepdim=True), # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4),),
+            },
+        ]
+
+        for i, case in enumerate(test_cases):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(
+                    case[QCOM_MODULE], case[QCOM_SAMPLE_INPUTS]
+                )
+                self.lower_module_and_test_output(module, case[QCOM_SAMPLE_INPUTS])

    def test_qnn_backend_argmin(self):
-        module = Argmin() # noqa: F405
-        sample_input = (torch.randn(16, 3, 4, 4),)
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
+        test_cases = [
+            {
+                QCOM_MODULE: Argmin(), # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
+            },
+            {
+                QCOM_MODULE: Argmin(dim=0, keepdim=True), # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
+            },
+            {
+                QCOM_MODULE: Argmin(dim=1, keepdim=False), # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(8, 5),),
+            },
+            {
+                QCOM_MODULE: Argmin(dim=None, keepdim=False), # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),),
+            },
+            {
+                QCOM_MODULE: Argmin(dim=2, keepdim=True), # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4),),
+            },
+        ]
+
+        for i, case in enumerate(test_cases):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(
+                    case[QCOM_MODULE], case[QCOM_SAMPLE_INPUTS]
+                )
+                self.lower_module_and_test_output(module, case[QCOM_SAMPLE_INPUTS])

    def test_qnn_backend_asin(self):
        module = Asin() # noqa: F405