29 changes: 26 additions & 3 deletions backends/xnnpack/partition/config/gemm_configs.py
@@ -93,6 +93,22 @@ def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:

        return ConfigPrecisionType.STATIC_QUANT

    def _overwrite_precision(self, node: torch.fx.Node):
        precision = self._detect_precision(node)
        if precision not in self.enabled_precision_types:
            # detected precision is not enabled, let's try to partition it as fp32
            if self.enabled_precision_types == [ConfigPrecisionType.FP32]:
                # if only fp32 is enabled, then we can still partition fp32 gemms
                # even within a quantized graph
                if precision in [
                    ConfigPrecisionType.STATIC_QUANT,
                    ConfigPrecisionType.DYNAMIC_QUANT,
                ]:
                    precision = ConfigPrecisionType.FP32
                    logging.info(f"Overwriting precision, partitioning {node} as FP32")
                    return True, precision
        return False, precision

    def get_deps(
        self,
        node: torch.fx.Node,
@@ -107,7 +123,7 @@ def get_deps(
        if precision not in self.supported_precision_types():
            # detected precision but it is either disabled or not supported
            return (False, [])
-
+        _, precision = self._overwrite_precision(node)
        valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
        valid_weight, weight_deps = self._get_weight_deps(node, ep, precision)
        valid_act, act_deps = self._get_act_deps(node, ep, precision)
@@ -193,7 +209,7 @@ def _get_bias_deps(
        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
    ) -> Tuple[bool, List[torch.fx.Node]]:
        gemm_deps = []
-        if len(node.all_input_nodes) > 2 and self.bias_idx:
+        if len(node.all_input_nodes) > 2 and self.bias_idx is not None:
            bias_node = get_input_node(node, self.bias_idx)
            if bias_node:
                if not is_param_node(ep, bias_node):
@@ -266,7 +282,14 @@ def _get_weight_deps(
        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
    ) -> Tuple[bool, List[torch.fx.Node]]:
        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+            # if force_fp32_dynamic_linear is enabled, then we
            # do not partition the weight node
            return (True, [])

        # Since we are in Linear, we may assume that the weights are indeed static.
        overwritten_linear_precision, new_precision = self._overwrite_precision(node)
        if new_precision == ConfigPrecisionType.FP32 and overwritten_linear_precision:
            # if overwriting quantized precision to fp32, then we
            # do not partition the weight node
            return (True, [])

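For illustration, here is a minimal standalone sketch of the decision _overwrite_precision encodes. It is not part of the diff: the enum and function below are stand-ins for the real ConfigPrecisionType and the GEMM config object, which carry more state.

from enum import Enum
from typing import List, Tuple


class ConfigPrecisionType(Enum):
    # stand-in for executorch's ConfigPrecisionType
    FP32 = 0
    STATIC_QUANT = 1
    DYNAMIC_QUANT = 2


def overwrite_precision(
    detected: ConfigPrecisionType,
    enabled: List[ConfigPrecisionType],
) -> Tuple[bool, ConfigPrecisionType]:
    # If the detected precision is not enabled but FP32 is the only enabled
    # precision, a quantized gemm can still be partitioned as an fp32 gemm.
    if detected not in enabled and enabled == [ConfigPrecisionType.FP32]:
        if detected in (
            ConfigPrecisionType.STATIC_QUANT,
            ConfigPrecisionType.DYNAMIC_QUANT,
        ):
            return True, ConfigPrecisionType.FP32
    return False, detected


# A statically quantized gemm seen by an FP32-only partitioner is overwritten to FP32 ...
assert overwrite_precision(
    ConfigPrecisionType.STATIC_QUANT, [ConfigPrecisionType.FP32]
) == (True, ConfigPrecisionType.FP32)
# ... while a precision that is already enabled is left untouched.
assert overwrite_precision(
    ConfigPrecisionType.STATIC_QUANT,
    [ConfigPrecisionType.STATIC_QUANT, ConfigPrecisionType.FP32],
) == (False, ConfigPrecisionType.STATIC_QUANT)

The net effect is that a quantized gemm encountered by an FP32-only partitioner is partitioned as an fp32 gemm instead of being skipped entirely.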
155 changes: 152 additions & 3 deletions backends/xnnpack/test/ops/test_linear.py
@@ -8,14 +8,16 @@
import unittest

from itertools import product
-from typing import Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple

import torch
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
)
-from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
-
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
+    XnnpackFloatingPointPartitioner,
+    XnnpackPartitioner,
+)
from executorch.backends.xnnpack.test.tester import Quantize, Tester
from executorch.backends.xnnpack.test.tester.tester import (
    Partition,
@@ -672,3 +674,150 @@ def _test_groupwise_dq_linear(
            .serialize()
            .run_method_and_compare_outputs(atol=atol, rtol=rtol)
        )

    def _test_linear_overwrite_precision(
        self,
        make_module: Callable[[int, int], torch.nn.Module],
        uses_bias: bool,
        quant_type: str,
        quant_node_checks: Dict[str, int],
        atol: float = 1e-03,
    ):
        """
        Tests overwriting the precision of a quantized linear op: the quantized
        linear model should be partitioned, lowered, and run as an fp32 linear op.
        When using legacy_mode, we instead verify that [add]mm is NOT partitioned, because:
        (1) We can't assume that weights are always static (non-param).
        (2) Alternatively, when lowering [add]mm to xnn::bmm we can't support bias.
            (2)(a) Lowering only the non-bias [add]mm, which is exposed only on the
                   legacy path, was deemed low ROI.
        """

        in_sizes = [3, 4, 4]
        input_sizes = [4, 37, 17]
        output_sizes = [4, 17, 37]

        assert quant_type in ["per_tensor", "per_channel", "per_channel_dynamic"]
        per_channel = "per_channel" in quant_type
        dynamic = "dynamic" in quant_type
        quant_config = get_symmetric_quantization_config(
            is_per_channel=per_channel,
            is_dynamic=dynamic,
        )
        # Using FP32 partitioner for this quantized graph
        partitioner = XnnpackFloatingPointPartitioner()

        def get_qnode_checks(quant_node_checks, dialect):
            d = {}
            assert dialect in ["aten", "edge"]
            if dialect == "aten":
                d = {
                    f"torch.ops.quantized_decomposed.{op}": count
                    for op, count in quant_node_checks.items()
                }
            elif dialect == "edge":
                d = {
                    f"executorch.exir.dialects.edge._ops.quantized_decomposed.{op}".replace(
                        ".", "_"
                    ): count
                    for op, count in quant_node_checks.items()
                }
            assert len(d) == len(quant_node_checks)
            return d

        for i, _ in enumerate(in_sizes):
            torch._dynamo.reset()
            in_size = int(in_sizes[i])
            input_size = int(input_sizes[i])
            output_size = int(output_sizes[i])
            input_shape = [in_size] + [input_size]
            module = make_module(input_size, output_size).eval()
            inputs = (torch.randn(input_shape),)

            addmm_op_str = (
                "executorch_exir_dialects_edge__ops_aten_addmm_default"
                if uses_bias
                else "executorch_exir_dialects_edge__ops_aten_mm_default"
            )
            linear_op_str = "executorch_exir_dialects_edge__ops_aten_linear_default"

            for legacy_mode in (True, False):
                tester = (
                    Tester(module, inputs)
                    .quantize(Quantize(quantization_config=quant_config))
                    .export()
                    .dump_artifact()
                    .check_count(get_qnode_checks(quant_node_checks, "aten"))
                )

                if legacy_mode:
                    tester.to_edge()
                    tester.partition(Partition(partitioner=partitioner))
                    # We don't expect [add]mm to be partitioned
                    tester.check([addmm_op_str])
                else:
                    tester.to_edge_transform_and_lower(
                        ToEdgeTransformAndLower(partitioners=[partitioner])
                    )
                    # We do expect linear to be partitioned
                    tester.check_not([linear_op_str])

                # For legacy mode, fp32 permute_copy gets partitioned (just a side effect).
                # For the new mode, fp32 linear gets partitioned.
                tester.check_count(
                    {"torch.ops.higher_order.executorch_call_delegate": 1}
                )

                # Normally these quantized ops would be consumed by the delegate;
                # here we don't partition them, so they remain in the edge graph.
                tester.check_count(get_qnode_checks(quant_node_checks, "edge"))

                # TODO: Need to figure out how to load quantized ops in pybindings.
                # tester.to_executorch()
                # tester.serialize()
                # tester.run_method_and_compare_outputs(
                #     qtol=bool(quant_config), atol=atol
                # )

    def test_qs8_as_fp32(self):
        for use_bias in (True, False):
            self._test_linear_overwrite_precision(
                lambda in_size, out_size: torch.nn.Linear(
                    in_size, out_size, bias=use_bias  # noqa
                ),
                use_bias,
                "per_tensor",
                quant_node_checks={
                    "quantize_per_tensor.default": 2,  # 1: act, 1: output
                    "dequantize_per_tensor.default": 3,  # 1: act, 1: weight, 1: output
                },
            )

    def test_qc8_as_fp32(self):
        for use_bias in (True, False):
            self._test_linear_overwrite_precision(
                lambda in_size, out_size: torch.nn.Linear(
                    in_size, out_size, bias=use_bias  # noqa
                ),
                use_bias,
                "per_channel",
                quant_node_checks={
                    "quantize_per_tensor.default": 2,  # 1: act, 1: output
                    "dequantize_per_tensor.default": 2,  # 1: act, 1: output
                    "dequantize_per_channel.default": 1,  # 1: weight
                },
            )

    def test_qd8_as_fp32(self):
        for use_bias in (True, False):
            self._test_linear_overwrite_precision(
                lambda in_size, out_size: torch.nn.Linear(
                    in_size, out_size, bias=use_bias  # noqa
                ),
                use_bias,
                "per_channel_dynamic",
                quant_node_checks={
                    "quantize_per_tensor.tensor": 1,  # 1: act
                    "dequantize_per_tensor.tensor": 1,  # 1: act
                    "dequantize_per_channel.default": 1,  # 1: weight
                },
            )
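As an aside on how these tests build their node-count expectations: get_qnode_checks qualifies each short quantized_decomposed op name for the dialect being checked. A minimal standalone sketch of that mapping follows; qnode_check_names is a hypothetical helper written for illustration, not anything from the PR.

from typing import Dict


def qnode_check_names(quant_node_checks: Dict[str, int], dialect: str) -> Dict[str, int]:
    # Qualify short quantized_decomposed op names for either the aten or edge dialect,
    # mirroring the f-strings used by get_qnode_checks above.
    assert dialect in ("aten", "edge")
    if dialect == "aten":
        return {
            f"torch.ops.quantized_decomposed.{op}": count
            for op, count in quant_node_checks.items()
        }
    return {
        f"executorch.exir.dialects.edge._ops.quantized_decomposed.{op}".replace(
            ".", "_"
        ): count
        for op, count in quant_node_checks.items()
    }


checks = qnode_check_names({"quantize_per_tensor.default": 2}, "edge")
# checks == {
#     "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2
# }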