From 8d9c9af076faea8bdf1ada319053febb3e462d11 Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Thu, 13 Nov 2025 17:43:31 -0800 Subject: [PATCH] Support eq.Scalar (#15792) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/15792 Differential Revision: D86891707 --- backends/qualcomm/builders/node_visitor.py | 1 + backends/qualcomm/builders/op_eq.py | 53 +++++++++++++++++--- backends/qualcomm/tests/models.py | 11 ++++ backends/qualcomm/tests/test_qnn_delegate.py | 33 ++++++++++++ 4 files changed, 91 insertions(+), 7 deletions(-) diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index f3dadb99129..c918ae0fcf7 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -466,6 +466,7 @@ def define_tensor( tensor_source_node, target_build_node ) dtype = self.get_data_type(tensor, quant_configs) + print(f"tensor_name: {tensor_name}, tensor_type: {tensor_type}, dtype: {dtype}") if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor): tensor_wrapper = PyQnnWrapper.TensorWrapper( tensor_name, diff --git a/backends/qualcomm/builders/op_eq.py b/backends/qualcomm/builders/op_eq.py index fcf3213d3a9..3dfe2c77159 100644 --- a/backends/qualcomm/builders/op_eq.py +++ b/backends/qualcomm/builders/op_eq.py @@ -8,15 +8,22 @@ import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper import torch +from executorch.exir.dialects._ops import ops as exir_ops from .node_visitor import NodeVisitor from .node_visitor_manager import register_node_visitor from .qnn_constants import OpElementWiseEqual, QNN_OP_PACKAGE_NAME_QTI_AISW - +from executorch.backends.qualcomm.utils.constants import ( + QCOM_QUANT_ATTRS, + QCOM_QUANT_MAX, + QCOM_QUANT_MIN, + QCOM_SCALE, + QCOM_ZERO_POINT, +) @register_node_visitor class Equal(NodeVisitor): - target = ["aten.eq.Tensor"] + target = ["aten.eq.Tensor", "aten.eq.Scalar"] def __init__(self, *args) -> None: super().__init__(*args) @@ -37,11 +44,43 @@ def define_node( output_tensors = [output_tensor_wrapper] input_tensors = [] - for index in range(2): - input_node = self.get_node(node.args[index]) - input_tensor = self.get_tensor(input_node, node) - tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE - + for index, arg in enumerate(node.args): + if isinstance(arg, torch.fx.Node): + # Normal tensor input + input_node = self.get_node(arg) + input_tensor = self.get_tensor(input_node, node) + tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + else: + assert index == 1, f"eq op arg at index 1 has to be int, but the type is {type(arg)}" + assert isinstance(arg, int), f"eq op arg {arg} has to be int , but the type is {type(arg)}" + print(f"arg is {arg}, type is {type(arg)}") + # Handle scalar input (e.g., int or float) + scalar = arg + scalar_value = float(scalar) + input_tensor = torch.tensor( + scalar_value, dtype=torch.int32 + ) + input_node = torch.fx.Node( + node.graph, + node.name + "_runtime_scalar", + "call_function", + exir_ops.edge.aten.scalar_tensor.default, + (), # args + {}, # kwargs + ) + tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC + if quant_attrs := node.args[index ^ 1].meta.get(QCOM_QUANT_ATTRS): + quant_attrs = quant_attrs.copy() + quant_range = ( + quant_attrs[QCOM_QUANT_MAX] - quant_attrs[QCOM_QUANT_MIN] + ) + quant_attrs[QCOM_ZERO_POINT] = ( + 0 if scalar >= 0 else quant_attrs[QCOM_QUANT_MAX] + ) + quant_attrs[QCOM_SCALE] = ( + scalar / quant_range if scalar >= 0 else -scalar / quant_range + ) + input_node.meta[QCOM_QUANT_ATTRS] = quant_attrs input_tensor_wrapper = self.define_tensor( input_node, node, diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index cdd0c194fe3..b72246bc123 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -944,6 +944,17 @@ def forward(self, x): return x == self.constant +class EqualFromInplaceCopyDecomp(torch.nn.Module): + def __init__(self, hidden_size=4): + super().__init__() + # a small state tensor + self.register_buffer("h", torch.zeros((1, hidden_size))) + + def forward(self, x): + self.h[0] = x + return self.h[0] + + class ExpandCopy(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index d168e25c81b..acb25dd89fb 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -765,6 +765,19 @@ def test_qnn_backend_equal(self): test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS] ) + def test_qnn_backend_equal_debug(self): + test_comb = [ + { + QCOM_MODULE: EqualFromInplaceCopyDecomp(), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.tensor([1.0, 2.0, 3.0, 4.0]), ), + }, + ] + for i, test in enumerate(test_comb): + with self.subTest(i=i): + self.lower_module_and_test_output( + test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS] + ) + def test_qnn_backend_expand(self): modules = [ExpandAs(), ExpandCopy()] # noqa: F405 sample_inputs = [ @@ -2842,6 +2855,26 @@ def test_qnn_backend_equal(self): ) self.lower_module_and_test_output(module, test[QCOM_SAMPLE_INPUTS]) + def test_qnn_backend_equal_debug(self): + test_comb = [ + { + QCOM_MODULE: EqualFromInplaceCopyDecomp(), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.tensor([1.0, 2.0, 3.0, 4.0]), ), + }, + ] + for i, test in enumerate(test_comb): + with self.subTest(i=i): + module = self.get_qdq_module( + test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS] + ) + + print("quantized module") + module.print_readable() + + self.lower_module_and_test_output( + module, test[QCOM_SAMPLE_INPUTS] + ) + def test_qnn_backend_expand(self): modules = [ExpandAs(), ExpandCopy()] # noqa: F405 sample_inputs = [