diff --git a/backends/vulkan/_passes/replace_qdq.py b/backends/vulkan/_passes/replace_qdq.py index fcfcdfc4c18..2c5331eb213 100644 --- a/backends/vulkan/_passes/replace_qdq.py +++ b/backends/vulkan/_passes/replace_qdq.py @@ -32,24 +32,23 @@ def call(self, graph_module: torch.fx.GraphModule): exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to_dw.default, exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default, ]: - # Replace quantize op feeding into conv2d (first argument is the quantized input) - quantized_input_node = node.args[0] - if isinstance( - quantized_input_node, torch.fx.Node - ) and utils.is_quant_node(quantized_input_node): - # Get the arguments from the original quantize node - input_tensor = quantized_input_node.args[0] - scale = quantized_input_node.args[1] - zero_point = quantized_input_node.args[2] + for quantized_input_node in node.args: + if isinstance( + quantized_input_node, torch.fx.Node + ) and utils.is_quant_node(quantized_input_node): + # Get the arguments from the original quantize node + input_tensor = quantized_input_node.args[0] + scale = quantized_input_node.args[1] + zero_point = quantized_input_node.args[2] - nodes_to_replace.append( - { - "old_node": quantized_input_node, - "new_target": exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default, - "args": (input_tensor, scale, zero_point), - "node_type": "quantize_input", - } - ) + nodes_to_replace.append( + { + "old_node": quantized_input_node, + "new_target": exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default, + "args": (input_tensor, scale, zero_point), + "node_type": "quantize_input", + } + ) # Find dequantize ops that consume the output of this conv2d for user in node.users: