Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/qualcomm/builders/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ def define_tensor(
tensor_source_node, target_build_node
)
dtype = self.get_data_type(tensor, quant_configs)
print(f"tensor_name: {tensor_name}, tensor_type: {tensor_type}, dtype: {dtype}")
if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
tensor_wrapper = PyQnnWrapper.TensorWrapper(
tensor_name,
Expand Down
53 changes: 46 additions & 7 deletions backends/qualcomm/builders/op_eq.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,22 @@
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch
from executorch.exir.dialects._ops import ops as exir_ops

from .node_visitor import NodeVisitor
from .node_visitor_manager import register_node_visitor
from .qnn_constants import OpElementWiseEqual, QNN_OP_PACKAGE_NAME_QTI_AISW

from executorch.backends.qualcomm.utils.constants import (
QCOM_QUANT_ATTRS,
QCOM_QUANT_MAX,
QCOM_QUANT_MIN,
QCOM_SCALE,
QCOM_ZERO_POINT,
)

@register_node_visitor
class Equal(NodeVisitor):
target = ["aten.eq.Tensor"]
target = ["aten.eq.Tensor", "aten.eq.Scalar"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand All @@ -37,11 +44,43 @@ def define_node(
output_tensors = [output_tensor_wrapper]

input_tensors = []
for index in range(2):
input_node = self.get_node(node.args[index])
input_tensor = self.get_tensor(input_node, node)
tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE

for index, arg in enumerate(node.args):
if isinstance(arg, torch.fx.Node):
# Normal tensor input
input_node = self.get_node(arg)
input_tensor = self.get_tensor(input_node, node)
tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
else:
assert index == 1, f"eq op arg at index 1 has to be int, but the type is {type(arg)}"
assert isinstance(arg, int), f"eq op arg {arg} has to be int , but the type is {type(arg)}"
print(f"arg is {arg}, type is {type(arg)}")
# Handle scalar input (e.g., int or float)
scalar = arg
scalar_value = float(scalar)
input_tensor = torch.tensor(
scalar_value, dtype=torch.int32
)
input_node = torch.fx.Node(
node.graph,
node.name + "_runtime_scalar",
"call_function",
exir_ops.edge.aten.scalar_tensor.default,
(), # args
{}, # kwargs
)
tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
if quant_attrs := node.args[index ^ 1].meta.get(QCOM_QUANT_ATTRS):
quant_attrs = quant_attrs.copy()
quant_range = (
quant_attrs[QCOM_QUANT_MAX] - quant_attrs[QCOM_QUANT_MIN]
)
quant_attrs[QCOM_ZERO_POINT] = (
0 if scalar >= 0 else quant_attrs[QCOM_QUANT_MAX]
)
quant_attrs[QCOM_SCALE] = (
scalar / quant_range if scalar >= 0 else -scalar / quant_range
)
input_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
input_tensor_wrapper = self.define_tensor(
input_node,
node,
Expand Down
11 changes: 11 additions & 0 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,17 @@ def forward(self, x):
return x == self.constant


class EqualFromInplaceCopyDecomp(torch.nn.Module):
def __init__(self, hidden_size=4):
super().__init__()
# a small state tensor
self.register_buffer("h", torch.zeros((1, hidden_size)))

def forward(self, x):
self.h[0] = x
return self.h[0]


class ExpandCopy(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
33 changes: 33 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,19 @@ def test_qnn_backend_equal(self):
test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
)

def test_qnn_backend_equal_debug(self):
test_comb = [
{
QCOM_MODULE: EqualFromInplaceCopyDecomp(), # noqa: F405
QCOM_SAMPLE_INPUTS: (torch.tensor([1.0, 2.0, 3.0, 4.0]), ),
},
]
for i, test in enumerate(test_comb):
with self.subTest(i=i):
self.lower_module_and_test_output(
test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
)

def test_qnn_backend_expand(self):
modules = [ExpandAs(), ExpandCopy()] # noqa: F405
sample_inputs = [
Expand Down Expand Up @@ -2842,6 +2855,26 @@ def test_qnn_backend_equal(self):
)
self.lower_module_and_test_output(module, test[QCOM_SAMPLE_INPUTS])

def test_qnn_backend_equal_debug(self):
test_comb = [
{
QCOM_MODULE: EqualFromInplaceCopyDecomp(), # noqa: F405
QCOM_SAMPLE_INPUTS: (torch.tensor([1.0, 2.0, 3.0, 4.0]), ),
},
]
for i, test in enumerate(test_comb):
with self.subTest(i=i):
module = self.get_qdq_module(
test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
)

print("quantized module")
module.print_readable()

self.lower_module_and_test_output(
module, test[QCOM_SAMPLE_INPUTS]
)

def test_qnn_backend_expand(self):
modules = [ExpandAs(), ExpandCopy()] # noqa: F405
sample_inputs = [
Expand Down
Loading