|
14 | 14 | register_node_visitor, |
15 | 15 | ) |
16 | 16 | from executorch.backends.arm.tosa_mapping import TosaArg |
17 | | -from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args |
| 17 | +from executorch.backends.arm.tosa_quant_utils import ( |
| 18 | + build_rescale, |
| 19 | + search_quant_arg_downstream, |
| 20 | + search_quant_arg_upstream, |
| 21 | +) |
18 | 22 |
|
19 | 23 | from executorch.backends.arm.tosa_utils import build_reshape |
20 | | -from executorch.exir.dialects._ops import ops as exir_ops |
21 | 24 | from serializer.tosa_serializer import TosaOp |
22 | 25 |
|
23 | 26 |
|
@@ -67,12 +70,7 @@ def define_node( |
67 | 70 | input_zp = 0 |
68 | 71 | if is_quant_node: |
69 | 72 | input_node = node.all_input_nodes[1] |
70 | | - # rank > 2 linear layer |
71 | | - if input_node.target == exir_ops.edge.aten.view_copy.default: |
72 | | - quant_node = input_node.all_input_nodes[0] |
73 | | - else: |
74 | | - quant_node = input_node |
75 | | - input_zp = get_quant_node_args(quant_node).zp |
| 73 | + input_zp = search_quant_arg_upstream(input_node).zp |
76 | 74 | attr.ConvAttribute( |
77 | 75 | pad=pad_attr, |
78 | 76 | stride=stride_attr, |
@@ -107,24 +105,16 @@ def define_node( |
107 | 105 | # Read inputs' parent nodes |
108 | 106 | _, input_node, weight_node = node.all_input_nodes |
109 | 107 |
|
110 | | - # rank > 2 linear layer |
111 | | - if input_node.target == exir_ops.edge.aten.view_copy.default: |
112 | | - quant_node = input_node.all_input_nodes[0] |
113 | | - input_scale = get_quant_node_args(quant_node).scale |
114 | | - consumer_node = list(node.users)[0] |
115 | | - consumer_consumer_node = list(consumer_node.users)[0] |
116 | | - quant_args = get_quant_node_args(consumer_consumer_node) |
117 | | - consumer_node_scale = quant_args.scale |
118 | | - consumer_node_node_zp = quant_args.zp |
119 | | - else: |
120 | | - input_scale = get_quant_node_args(input_node).scale |
121 | | - consumer_node = list(node.users)[0] |
122 | | - quant_args = get_quant_node_args(consumer_node) |
123 | | - consumer_node_scale = quant_args.scale |
124 | | - consumer_node_node_zp = quant_args.zp |
| 108 | + qargs = search_quant_arg_upstream(input_node) |
| 109 | + input_scale = qargs.scale |
| 110 | + consumer_node = list(node.users)[0] |
| 111 | + quant_args = search_quant_arg_downstream(consumer_node) |
| 112 | + |
| 113 | + consumer_node_scale = quant_args.scale |
| 114 | + consumer_node_node_zp = quant_args.zp |
125 | 115 |
|
126 | 116 | weight_node_q_node = weight_node.all_input_nodes[0] |
127 | | - weight_scale = get_quant_node_args(weight_node_q_node).scale |
| 117 | + weight_scale = search_quant_arg_upstream(weight_node_q_node).scale |
128 | 118 |
|
129 | 119 | output_rescale_scale = (input_scale * weight_scale) / consumer_node_scale |
130 | 120 |
|
|
0 commit comments