From 83fd30b2bf652921475c37505d8c87fde6df49de Mon Sep 17 00:00:00 2001
From: ssjia
Date: Mon, 10 Nov 2025 13:32:26 -0800
Subject: [PATCH] [ET-VK][ez] Apply quantize op replacement to all argument
 nodes

Pull Request resolved: https://github.com/pytorch/executorch/pull/15702

Title says it all! With the way the pass is currently written only the
first arg will be inspected for q/dq node replacement. As a consequence,
the second arg for i.e. binary ops may not have the quantized op be
replaced.

ghstack-source-id: 322214453
@exported-using-ghexport

Differential Revision: [D86674169](https://our.internmc.facebook.com/intern/diff/D86674169/)
---
 backends/vulkan/_passes/replace_qdq.py | 33 +++++++++++++-----------------
 1 file changed, 16 insertions(+), 17 deletions(-)

diff --git a/backends/vulkan/_passes/replace_qdq.py b/backends/vulkan/_passes/replace_qdq.py
index fcfcdfc4c18..2c5331eb213 100644
--- a/backends/vulkan/_passes/replace_qdq.py
+++ b/backends/vulkan/_passes/replace_qdq.py
@@ -32,24 +32,23 @@ def call(self, graph_module: torch.fx.GraphModule):
                 exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to_dw.default,
                 exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default,
             ]:
-                # Replace quantize op feeding into conv2d (first argument is the quantized input)
-                quantized_input_node = node.args[0]
-                if isinstance(
-                    quantized_input_node, torch.fx.Node
-                ) and utils.is_quant_node(quantized_input_node):
-                    # Get the arguments from the original quantize node
-                    input_tensor = quantized_input_node.args[0]
-                    scale = quantized_input_node.args[1]
-                    zero_point = quantized_input_node.args[2]
+                for quantized_input_node in node.args:
+                    if isinstance(
+                        quantized_input_node, torch.fx.Node
+                    ) and utils.is_quant_node(quantized_input_node):
+                        # Get the arguments from the original quantize node
+                        input_tensor = quantized_input_node.args[0]
+                        scale = quantized_input_node.args[1]
+                        zero_point = quantized_input_node.args[2]
 
-                    nodes_to_replace.append(
-                        {
-                            "old_node": quantized_input_node,
-                            "new_target": exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default,
-                            "args": (input_tensor, scale, zero_point),
-                            "node_type": "quantize_input",
-                        }
-                    )
+                        nodes_to_replace.append(
+                            {
+                                "old_node": quantized_input_node,
+                                "new_target": exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default,
+                                "args": (input_tensor, scale, zero_point),
+                                "node_type": "quantize_input",
+                            }
+                        )
 
         # Find dequantize ops that consume the output of this conv2d
         for user in node.users: