Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 54 additions & 16 deletions backends/qualcomm/builders/op_layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
from executorch.backends.qualcomm.utils.constants import (
QCOM_DATA,
QCOM_QUANT_ATTRS,
QCOM_ZERO_POINT,
)
from executorch.exir.dialects._ops import ops as exir_ops

from .node_visitor import NodeVisitor
from .node_visitor_manager import register_node_visitor
Expand All @@ -31,6 +36,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
) -> PyQnnManager.PyQnnOpWrapper:
# args of node : ['input', 'normalized_shape', 'weight', 'bias', 'eps']
input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
Expand All @@ -54,8 +60,25 @@ def define_node(
axis = [len(input_tensor.shape) - 1]
axis_shape = [len(axis)]

weight_node = self.get_node(node.args[2])
weight_tensor = get_parameter(weight_node, self.edge_program)
has_weight = len(node.args) > 2 and node.args[2] is not None
if has_weight:
weight_node = self.get_node(node.args[2])
weight_tensor = get_parameter(weight_node, self.edge_program)
else:
# elementwise_affine=False: use all-ones weight as identity
weight_tensor = torch.ones(normalized_shapes, dtype=torch.float32)
weight_node = torch.fx.Node(
node.graph,
node.name + "_runtime_weight",
"call_function",
exir_ops.edge.aten.tensor.default,
(),
{},
)
if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
quant_attrs = quant_attrs.copy()
quant_attrs[QCOM_ZERO_POINT] = 0
weight_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
weight_tensor_wrapper = self.define_tensor(
weight_node,
node,
Expand All @@ -64,21 +87,34 @@ def define_node(
nodes_to_wrappers,
)

layer_norm_input_tensors = [input_tensor_wrapper, weight_tensor_wrapper]

bias_node = self.get_node(node.args[3])
if bias_node is not None:
# Fake node: even when original bias is absent, QNN still needs it
has_bias = len(node.args) > 3 and node.args[3] is not None
if has_bias:
bias_node = self.get_node(node.args[3])
bias_tensor = get_parameter(bias_node, self.edge_program)
bias_tensor_wrapper = self.define_tensor(
bias_node,
node,
bias_tensor,
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
else:
bias_tensor = torch.zeros(normalized_shapes, dtype=torch.float32)
bias_node = torch.fx.Node(
node.graph,
node.name + "_runtime_bias",
"call_function",
exir_ops.edge.aten.tensor.default,
(),
{},
)
layer_norm_input_tensors.append(bias_tensor_wrapper)
if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
quant_attrs = quant_attrs.copy()
quant_attrs[QCOM_ZERO_POINT] = 0
bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
bias_tensor_wrapper = self.define_tensor(
bias_node,
node,
bias_tensor,
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
)

epsilon = node.args[4]
epsilon = node.args[4] if len(node.args) > 4 else 1e-05

output_tensor = self.get_tensor(node, node, 0)
output_tensor_wrapper = self.define_tensor(
Expand All @@ -94,7 +130,9 @@ def define_node(
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpLayerNorm.op_name,
)
layer_norm_op.AddInputTensors(layer_norm_input_tensors)
layer_norm_op.AddInputTensors(
[input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper]
)
layer_norm_op.AddOutputTensors([output_tensor_wrapper])
layer_norm_op.AddScalarParam(
OpLayerNorm.param_epsilon,
Expand Down
4 changes: 3 additions & 1 deletion backends/qualcomm/builders/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def is_parameter(

def get_parameter(
node: torch.fx.Node, edge_program: torch.export.ExportedProgram
) -> torch.Tensor:
) -> Optional[torch.Tensor]:
if node is None:
return None
param = None
if is_param(edge_program, node):
param = get_param(edge_program, node)
Expand Down
1 change: 1 addition & 0 deletions backends/qualcomm/export_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,7 @@ def build_executorch_binary(
with open(pte_name, "wb") as file:
exec_prog_mgr.write_to_file(file)

print(f"Successfully generated {pte_name}.")
if qnn_config.compile_only:
sys.exit(0)

Expand Down
3 changes: 2 additions & 1 deletion backends/qualcomm/quantizer/annotators/htp_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,8 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:


@register_annotator(
[torch.ops.aten.layer_norm.default], QnnConstants.OpLayerNorm.op_name
[torch.ops.aten.layer_norm.default, torch.ops.aten.native_layer_norm.default],
QnnConstants.OpLayerNorm.op_name,
)
class LayerNorm(GeneralOpDef):
@staticmethod
Expand Down
2 changes: 2 additions & 0 deletions backends/qualcomm/quantizer/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

def _mark_nodes_as_annotated(nodes: List[Node]):
for node in nodes:
if node is None:
continue
Comment on lines +32 to +33
if Q_ANNOTATION_KEY not in node.meta:
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation()
node.meta[Q_ANNOTATION_KEY]._annotated = True
Expand Down
20 changes: 20 additions & 0 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1388,6 +1388,26 @@ def forward(self, x):
return self.linear(self.layer_norm(x))


class NativeLayerNorm(torch.nn.Module):
    """Test module exercising ``torch.native_layer_norm`` directly.

    With ``affine=True`` the learned weight/bias parameters are passed;
    with ``affine=False`` the op is called with ``weight=None`` and
    ``bias=None`` to exercise the backend's no-parameter lowering path.
    """

    def __init__(self, affine=True):
        super().__init__()
        self.affine = affine
        # weight=ones / bias=zeros is the identity affine transform, so both
        # configurations produce numerically identical outputs — only the
        # lowered graph differs.
        self.weight = torch.nn.Parameter(torch.ones(768))
        self.bias = torch.nn.Parameter(torch.zeros(768))
        self.normalized_shape = [768]
        self.eps = 1e-6

    def forward(self, x):
        # Bug fix: the original if/else branches were identical, so the
        # affine=False configuration never actually exercised the
        # weight=None / bias=None path. Pass None explicitly instead.
        if self.affine:
            return torch.native_layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )[0]
        return torch.native_layer_norm(
            x, self.normalized_shape, None, None, self.eps
        )[0]


class LayerNormAdd(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
15 changes: 15 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1384,6 +1384,13 @@ def test_qnn_backend_layer_norm(self):
with self.subTest(i=i):
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_native_layer_norm(self):
    """Lower torch.native_layer_norm with and without affine parameters."""
    sample_input = (torch.randn(196, 768),)
    candidates = [NativeLayerNorm(), NativeLayerNorm(affine=False)]  # noqa: F405
    for idx, module in enumerate(candidates):
        with self.subTest(i=idx):
            self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_leaky_relu(self):
torch.manual_seed(8)
test_comb = [
Expand Down Expand Up @@ -3811,6 +3818,14 @@ def test_qnn_backend_layer_norm(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_native_layer_norm(self):
    """Quantized (QDQ) lowering of torch.native_layer_norm, affine and non-affine."""
    sample_input = (torch.randn(196, 768),)
    candidates = [NativeLayerNorm(), NativeLayerNorm(affine=False)]  # noqa: F405
    for idx, module in enumerate(candidates):
        with self.subTest(i=idx):
            qdq_module = self.get_qdq_module(module, sample_input)
            self.lower_module_and_test_output(qdq_module, sample_input)

def test_qnn_backend_leaky_relu(self):
test_comb = [
{
Expand Down
Loading