From 3a97268470fb598050c077a8bfe9a4f30b34cf5a Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Fri, 14 Feb 2025 11:51:45 -0800 Subject: [PATCH] patch qnn silu lowering (#8494) Summary: There are some edge lowering case missing. Change the logic to look for the op first because it's not decomposed Reviewed By: billmguo Differential Revision: D69636086 --- backends/qualcomm/_passes/decompose_silu.py | 22 ++++++++++----------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/backends/qualcomm/_passes/decompose_silu.py b/backends/qualcomm/_passes/decompose_silu.py index 890d965dec2..ca1a566be1e 100644 --- a/backends/qualcomm/_passes/decompose_silu.py +++ b/backends/qualcomm/_passes/decompose_silu.py @@ -7,7 +7,6 @@ import torch from executorch.exir.pass_base import ExportPass, PassResult -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions class DecomposeSilu(ExportPass): @@ -22,24 +21,23 @@ def _copy_meta(self, meta: Dict): def call(self, graph_module: torch.fx.GraphModule): graph = graph_module.graph - partitions = get_source_partitions( - graph, [torch.nn.functional.silu, torch.ops.aten.silu.default] - ) - for _, src_partitions in partitions.items(): - for src_partition in src_partitions: - - inputs = src_partition.input_nodes - silu_node = src_partition.nodes[0] - with graph_module.graph.inserting_after(inputs[0]): + for node in graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.silu.default + ): + silu_node = node + silu_node_input = node.args[0] + with graph_module.graph.inserting_after(silu_node_input): sigmoid_node = graph.create_node( - "call_function", torch.ops.aten.sigmoid, (inputs[0],) + "call_function", torch.ops.aten.sigmoid, (silu_node_input,) ) sigmoid_node.meta = self._copy_meta(silu_node.meta) with graph_module.graph.inserting_after(sigmoid_node): mul_node = graph.create_node( "call_function", torch.ops.aten.mul, - (inputs[0], sigmoid_node), + (silu_node_input, sigmoid_node), ) mul_node.meta = self._copy_meta(silu_node.meta) for user in silu_node.users.copy():