2 changes: 1 addition & 1 deletion backends/xnnpack/partition/xnnpack_partitioner.py
@@ -115,7 +115,7 @@ def generate_per_op_partitions(self, ep: ExportedProgram) -> List[Partition]:
class XnnpackDynamicallyQuantizedPartitioner(XnnpackPartitioner):
def __init__(self):
super().__init__(
-            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT, per_op_mode=True
+            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
)


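Note (not part of the diff): with per_op_mode no longer hard-coded to True, XnnpackDynamicallyQuantizedPartitioner can claim a shared dynamic-quant chain and all of its consumers as a single partition instead of one partition per op. A minimal construction sketch, assuming only the constructor shown above; the commented to_backend call is illustrative and presumes an EdgeProgramManager named edge_program:

    from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
        XnnpackDynamicallyQuantizedPartitioner,
    )

    # No arguments needed; dynamic-quant precision is set by the constructor above.
    partitioner = XnnpackDynamicallyQuantizedPartitioner()
    # edge_program.to_backend(partitioner)  # illustrative lowering entry point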
32 changes: 32 additions & 0 deletions backends/xnnpack/test/ops/test_linear.py
@@ -191,6 +191,21 @@ def forward(self, x, y):
return a + b


class SharedDQChain(torch.nn.Module):
def __init__(self, input_size, output_size):
super().__init__()
self.linear1_weight = torch.nn.Parameter(torch.rand(output_size, input_size))
self.linear1_bias = torch.nn.Parameter(torch.rand(output_size))

self.linear2_weight = torch.nn.Parameter(torch.rand(output_size, input_size))
self.linear2_bias = torch.nn.Parameter(torch.rand(output_size))

def forward(self, x):
a = torch.nn.functional.linear(x, self.linear1_weight, self.linear1_bias)
b = torch.nn.functional.linear(x, self.linear2_weight, self.linear2_bias)
return a + b


class TestLinear(unittest.TestCase):
"""
Test Class for XNNPACK Linear Operators.
@@ -520,6 +535,23 @@ def get_qnode_checks(quant_node_checks, dialect):
# qtol=bool(quant_config), atol=atol
# )

def test_qd8_f32_per_channel_shared_dq_chain(self):
for use_bias in (False, True):
module = SharedDQChain(
input_size=13,
output_size=17,
)
inputs = (torch.randn(1, 2, 13),)

self._test_dqlinear(
module,
inputs,
dynamic_shapes=None,
is_per_channel=True,
linear_count=2,
uses_bias=use_bias,
)

def _test_qd8_per_channel_linear(self, dtype: torch.dtype = torch.float):
for uses_bias in (False, True):
module = BaseLinear(
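Note (not part of the diff): the new SharedDQChain module targets the case where one dynamic quantize/dequantize chain on the activation feeds two linear ops. A small eager-mode sanity sketch using the module defined above in this file; shapes follow the test's input_size=13 / output_size=17:

    import torch

    # SharedDQChain is the module added to test_linear.py above.
    module = SharedDQChain(input_size=13, output_size=17)
    out = module(torch.randn(1, 2, 13))
    assert out.shape == (1, 2, 17)  # two linears on the same input, summed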
1 change: 1 addition & 0 deletions backends/xnnpack/test/tester/TARGETS
@@ -26,5 +26,6 @@ runtime.python_library(
"//executorch/exir/backend:partitioner",
"//executorch/exir/passes:spec_prop_pass",
"//executorch/extension/pybindings:portable_lib", # @manual
"//executorch/backends/transforms:duplicate_dynamic_quant_chain"
],
)
5 changes: 5 additions & 0 deletions backends/xnnpack/test/tester/tester.py
@@ -15,6 +15,9 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union

import torch
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
DuplicateDynamicQuantChainPass,
)
from executorch.backends.xnnpack._passes import XNNPACKPassManager
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
@@ -177,6 +180,8 @@ def run(
prepared(*inputs)

converted = convert_pt2e(prepared)
DuplicateDynamicQuantChainPass()(converted)

self.converted_graph = converted

@property
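Note (not part of the diff): the tester's Quantize stage now runs DuplicateDynamicQuantChainPass immediately after convert_pt2e, so each consumer of a shared dynamic-quant chain gets its own quantize/dequantize nodes before partitioning. A minimal sketch of the same flow outside the Tester, assuming the standard PT2E APIs; quantizer construction and export are elided:

    from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
        DuplicateDynamicQuantChainPass,
    )
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


    def quantize_for_xnnpack(graph_module, quantizer, example_inputs):
        prepared = prepare_pt2e(graph_module, quantizer)
        prepared(*example_inputs)  # calibration pass through the observers
        converted = convert_pt2e(prepared)
        # Clone shared dynamic-quant chains so every linear owns its own DQ nodes.
        DuplicateDynamicQuantChainPass()(converted)
        return converted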