from typing import cast

import torch
from executorch.backends.arm._passes.arm_pass_utils import (
    create_node,
    get_first_fake_tensor,
)
from executorch.backends.arm.tosa_quant_utils import dq_op
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.library import impl, Library

# Define lib with passthrough operators. The operators have no real meaning in edge IR
# except for argument validation and a passthrough output. The operators will be used
# when lowering to TOSA, e.g. a passthrough_to_tosa._transpose will not affect
# the edge IR graph but will be lowered to a TOSA TRANSPOSE.
lib = Library("passthrough_to_tosa", "DEF")
# For operators that change the rank of the input, such as unsqueeze and squeeze, we may need
# to switch dim_order before the operation. Changing tosa_dim_order is not sufficient
# as we also need to transpose the data into the correct data format.
# By utilizing an edge IR passthrough operator we can keep the edge program in
# channels-first/contiguous and get the desired behavior in the TOSA lowering.
lib.define("_transpose(Tensor self, int[] dim_order) -> Tensor")


@impl(lib, "_transpose")
def _transpose_impl(*args, **kwargs):
    # Validate length of dim_order array
    dim = args[1]
    assert len(dim) <= 4
    # Pass-through in edge-IR
    return args[0]
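

# Illustrative usage sketch (not part of the pass itself): in edge IR the op simply
# returns its input; only the TOSA lowering emits a real TRANSPOSE. For example:
#   x = torch.randn(1, 3, 8, 8)
#   y = torch.ops.passthrough_to_tosa._transpose(x, [0, 2, 3, 1])  # y is x, unchanged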


class AnnotateChannelsLastDimOrder(ExportPass):
    """
    Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
    that in most cases will be (0, 2, 3, 1) for nodes with 4D shapes. The pass also inserts
    passthrough_to_tosa._transpose nodes when a transition between 3D and 4D tensors happens.
    The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape.
    """

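    # Example: a node with shape (N, C, H, W) and tosa_dim_order (0, 2, 3, 1) is
    # materialized as an (N, H, W, C) tensor in the TOSA graph.
    # NHWC_order permutes NCHW -> NHWC, NHWC_inverse_order maps NHWC back to NCHW,
    # and HWCM_order rearranges depthwise-conv weights into TOSA's (H, W, C, M) layout.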
    NHWC_order = (0, 2, 3, 1)
    NHWC_inverse_order = (0, 3, 1, 2)
    HWCM_order = (2, 3, 0, 1)

    def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
        """
        returns True for dq and w in the following sequences;
@@ -49,20 +79,56 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):

        return False

    def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue
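            # Squeeze: a 4D (NHWC-annotated) input loses a dim. Insert a transpose back
            # to contiguous (NCHW) order before the squeeze so the data layout matches
            # the lower-rank output.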
            if node.target == exir_ops.edge.aten.squeeze_copy.dims:
                input_node = node.args[0]
                if input_node.meta["val"].dim() == 4:
                    with graph_module.graph.inserting_before(node):
                        permute_node = create_node(
                            graph_module.graph,
                            torch.ops.passthrough_to_tosa._transpose,
                            args=(input_node, list(self.NHWC_inverse_order)),
                        )
                        permute_node.meta["tosa_dim_order"] = tuple(
                            range(len(input_node.meta["val"].size()))
                        )
                        node.replace_input_with(input_node, permute_node)

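            # Unsqueeze: a new 4D output is created in contiguous (NCHW) order. Insert a
            # transpose to NHWC after it and rewire all other users to the transposed result.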
            if node.target == exir_ops.edge.aten.unsqueeze_copy.default:
                if node.meta["val"].dim() == 4:
                    with graph_module.graph.inserting_after(node):
                        permute_node = create_node(
                            graph_module.graph,
                            torch.ops.passthrough_to_tosa._transpose,
                            args=(node, list(self.NHWC_order)),
                        )
                        permute_node.meta["tosa_dim_order"] = self.NHWC_order
                        node.meta["tosa_dim_order"] = (0, 1, 2, 3)
                        users = [user for user in node.users if user != permute_node]
                        for user in users:
                            user.replace_input_with(node, permute_node)

    def call(self, graph_module: torch.fx.GraphModule):
        for node in graph_module.graph.nodes:
            node_data = get_first_fake_tensor(node).data

            if node_data.dim() == 4:
                dim_order = self.NHWC_order
                if self.is_weight_node_for_depthwise_conv2d(node):
                    # The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to
                    # dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
                    dim_order = self.HWCM_order
            else:
                dim_order = tuple(range(node_data.dim()))
            node.meta["tosa_dim_order"] = dim_order
        # Take care of the rank-changing cases:
        #   4D (NHWC) -> 3D (NCH)   (squeeze)
        #   3D (NCH)  -> 4D (NHWC)  (unsqueeze)
        self.insert_tosa_transposes(graph_module)
        graph_module.recompile()
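        # Retrace through ExportPass so the modified graph is validated and the newly
        # inserted transpose nodes get consistent node metadata.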
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
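

# Rough usage sketch (assumed, not part of this file): the pass is normally scheduled by
# the Arm backend's pass manager, but it can also be applied directly to an edge
# GraphModule, e.g.
#   result = AnnotateChannelsLastDimOrder()(edge_graph_module)
#   annotated_gm = result.graph_module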