3333from executorch .backends .cadence .aot .utils import get_edge_overload_packet
3434from executorch .backends .transforms .remove_clone_ops import RemoveCloneOpsTransform
3535from executorch .exir .dialects ._ops import ops as exir_ops
36- from executorch .exir .dialects .edge ._ops import EdgeOpOverload
36+ from executorch .exir .dialects .edge ._ops import EdgeOpOverload , EdgeOpOverloadPacket
3737from executorch .exir .pass_base import ExportPass , NodeMetadata , PassResult , ProxyValue
3838from executorch .exir .pass_manager import PassManager , PassType
3939from executorch .exir .passes import dead_code_elimination_pass
@@ -745,6 +745,68 @@ def permute_shape(
745745 return [shape [p ] for p in permute_dims ]
746746
747747
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveBranchedQuantDequant(ExportPass):
    """
    Bypass a quant/dequant pair when the first op of the pair fans out.

    The pass searches for a quantize node directly followed by a dequantize
    node (and, symmetrically, a dequantize node directly followed by a
    quantize node) whose arguments after the input tensor are identical, in
    the case where the first node also feeds other users. Per the original
    author's note, FuseQuantDequantToRequantizePass would remove such a pair
    if the first node had a single user. Here only the second node of the
    pair is bypassed: everything that consumed it is rewired to the input of
    the first node, and the first node stays in place for its other users.
    """

    # Overload packets treated as "quantize" ops.
    quantize_op_packets: set[EdgeOpOverloadPacket] = {
        exir_ops.edge.cadence.quantize_per_tensor,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor,
    }
    # Overload packets treated as "dequantize" ops.
    dequantize_op_packets: set[EdgeOpOverloadPacket] = {
        exir_ops.edge.cadence.dequantize_per_tensor,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor,
    }

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Run the rewrite in both directions, then clean up the graph.

        First quantize -> dequantize pairs are handled, then
        dequantize -> quantize pairs. Dead code elimination is run on the
        graph afterwards, and the result of the base class ``call`` on the
        updated module is returned.
        """
        directions = (
            (self.quantize_op_packets, self.dequantize_op_packets),
            (self.dequantize_op_packets, self.quantize_op_packets),
        )
        for producers, consumers in directions:
            self.remove_branched(graph_module, producers, consumers)

        graph_module.graph.eliminate_dead_code()
        return super().call(graph_module)

    def remove_branched(
        self,
        graph_module: torch.fx.GraphModule,
        producer_pkts: set[EdgeOpOverloadPacket],
        consumer_pkts: set[EdgeOpOverloadPacket],
    ) -> None:
        """
        Rewire users of matching consumers to the producer's input.

        Args:
            graph_module: Module whose graph is modified in place.
            producer_pkts: Overload packets accepted for the first node of a
                pair.
            consumer_pkts: Overload packets accepted for the second node of a
                pair.

        A bypassed consumer is not erased here; it is only disconnected from
        its users.
        """
        for producer in graph_module.graph.nodes:
            is_producer = (
                producer.op == "call_function"
                and isinstance(producer.target, EdgeOpOverload)
                and get_edge_overload_packet(producer.target) in producer_pkts
            )
            # Only producers that fan out to two or more users are handled.
            if not is_producer or len(producer.users) < 2:
                continue

            for consumer in producer.users:
                is_consumer = (
                    isinstance(consumer.target, EdgeOpOverload)
                    and get_edge_overload_packet(consumer.target) in consumer_pkts
                )
                # Everything after the input tensor (the qparams) must agree
                # between the two ops before the consumer can be bypassed.
                if is_consumer and producer.args[1:] == consumer.args[1:]:
                    consumer.replace_all_uses_with(producer.args[0])
809+
748810# The following class consolidates functions to remove ops that are redundant
749811# in Jarvis. Currently, each function in this class iterates over each node of
750812# the graph module once. In future, we could consolidate them into a monolithic
@@ -765,4 +827,5 @@ class CadenceRemoveNops:
765827 RemoveNopMulOpPass ,
766828 RemoveNopAddOpPass ,
767829 RemoveNopLinalgVectorNormOpPass ,
830+ RemoveBranchedQuantDequant ,
768831 ]
0 commit comments