55
66# pyre-unsafe
77
8- from typing import cast
98
109import torch
1110from executorch .backends .arm ._passes .arm_pass_utils import (
1211 create_node ,
1312 get_first_fake_tensor ,
14- insert_q_dq_pair ,
1513)
16- from executorch .backends .arm .tosa_quant_utils import dq_op , q_op
1714from executorch .backends .arm .tosa_utils import is_consumer_node_depthwise_conv2d
1815from executorch .exir .dialects ._ops import ops as exir_ops
1916from executorch .exir .pass_base import ExportPass , PassResult
@@ -59,20 +56,10 @@ class AnnotateChannelsLastDimOrder(ExportPass):
5956
6057 def is_weight_node_for_depthwise_conv2d (self , node : torch .fx .Node ):
6158 """
62- returns True for dq and w in the following sequences ;
59+ returns True for w in the following sequence ;
6360 w -> depthwise_conv2d -> ...
64- w -> dq -> depthwise_conv2d -> ...
6561 """
66- if node .op == "call_function" :
67- if node .target != dq_op :
68- return False
69- prev_node = node .args [0 ]
70- if cast (torch .fx .Node , prev_node ).op != "placeholder" :
71- return False
72- if is_consumer_node_depthwise_conv2d (node ):
73- consumer_node = list (node .users )[0 ]
74- return consumer_node .args [1 ] == node
75- elif node .op == "placeholder" :
62+ if node .op == "placeholder" :
7663 # node is an input, weight or bias node
7764 consumer_node = list (node .users )[0 ]
7865 if self .is_weight_node_for_depthwise_conv2d (consumer_node ):
@@ -129,8 +116,6 @@ def is_channel_reshape(input_shape, output_shape):
129116
130117 @staticmethod
131118 def insert_input_transpose (node , input_node , graph_module ):
132- quantize = input_node .target == dq_op
133- q_params = input_node .args [1 :] if quantize else None
134119 with graph_module .graph .inserting_before (node ):
135120 permute_node = create_node (
136121 graph_module .graph ,
@@ -143,8 +128,6 @@ def insert_input_transpose(node, input_node, graph_module):
143128 else AnnotateChannelsLastDimOrder .NHWC_inverse_order
144129 ),
145130 ),
146- quantize = quantize ,
147- q_params = q_params ,
148131 )
149132 node .replace_input_with (input_node , permute_node )
150133
@@ -185,11 +168,6 @@ def insert_output_transpose(node, graph_module):
185168 for user in users :
186169 user .replace_input_with (node , permute_node )
187170
188- quantize = node .args [0 ] == q_op
189- if quantize :
190- q_params = node .args [0 ].args [1 :]
191- insert_q_dq_pair (graph_module .graph , node , q_params )
192-
193171 @staticmethod
194172 def _insert_view_transpose (
195173 input_shape , output_shape , node , input_node , graph_module
0 commit comments