[reland][quant][pt2e][xnnpack_quantizer] Add support for mul and mul_relu (#107930)

Summary: As titled; adds mul and mul + relu quantization annotation support to XNNPACKQuantizer for the PT2E flow.

Test Plan: buck2 run executorch/examples/quantization:example -- -m=mv3 --verify

Differential Revision: D48588121

ghstack-source-id: 7a044b21dcb0b891bd8c39e7fa3ff11c49b86535
Pull Request resolved: #107992
jerryzh168 committed Aug 26, 2023
1 parent f92f69d commit bbee0d1
Showing 7 changed files with 204 additions and 32 deletions.
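
For context, a minimal end-to-end sketch of how the new mul / mul_relu annotators would be exercised through the PT2E flow. The module, input shapes, and the capture_pre_autograd_graph export entry point are illustrative assumptions (the export API varies across PyTorch versions); the quantizer setup mirrors the new tests in this commit.

import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class MulRelu(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x, y):
        # mul followed by relu should be matched by the new "mul_relu" annotator
        return self.relu(x * y)

example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5))
m = MulRelu().eval()
# Assumption: capture_pre_autograd_graph was the PT2E export entry point around
# this commit; newer releases use a torch.export-based capture instead.
m = torch._export.capture_pre_autograd_graph(m, example_inputs)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
m = prepare_pt2e(m, quantizer)   # inserts observers based on the annotations
m(*example_inputs)               # calibrate
m = convert_pt2e(m)              # lowers observers to quantize/dequantize ops

With the annotators registered in this commit, the converted graph should contain quantize/dequantize ops around both standalone mul and the fused mul + relu pattern, which is what the new tests below check.
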
2 changes: 2 additions & 0 deletions test/inductor/test_mkldnn_pattern_matcher.py
@@ -21,6 +21,7 @@
)
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CPU
import unittest

# The dict value is match_nodes(computation_op+unary_op)

@@ -399,6 +400,7 @@ def forward(self, x):

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @unittest.skip("TODO(leslie): some numbers changed due to quant flow update, re-enable the test")
    def test_qconv2d_binary(self):
        class M(torch.nn.Module):
            def __init__(
65 changes: 65 additions & 0 deletions test/quantization/pt2e/test_quantize_pt2e.py
@@ -227,6 +227,17 @@ def forward(self, indices):
            conv_out = torch.squeeze(conv_out, dim=0)
            return self.linear(conv_out)

    class AddInplaceAdd(torch.nn.Module):
        def forward(self, x, y):
            x = x + y
            x += y
            return x

    class MulInplaceMul(torch.nn.Module):
        def forward(self, x, y):
            x = x * y
            x *= y
            return x

class PT2EQuantizationTestCase(QuantizationTestCase):
    """
@@ -1221,6 +1232,60 @@ def validate(self, model: torch.fx.GraphModule) -> None:
            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
        )

    def test_add_and_inplace_add(self):
        quantizer = XNNPACKQuantizer()
        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(quantization_config)
        example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5),)
        node_occurrence = {
            # two inputs and one output for the first add, and one output for the second add
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
        }
        node_list = [
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.add.Tensor,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.add_.Tensor,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
        ]
        self._test_quantizer(
            TestHelperModules.AddInplaceAdd(),
            example_inputs,
            quantizer,
            node_occurrence,
            node_list,
        )

    def test_mul_and_inplace_mul(self):
        quantizer = XNNPACKQuantizer()
        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(quantization_config)
        example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5),)
        node_occurrence = {
            # two inputs and one output for the first mul, and one output for the second mul
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
        }
        node_list = [
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.mul.Tensor,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.mul_.Tensor,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
        ]
        self._test_quantizer(
            TestHelperModules.MulInplaceMul(),
            example_inputs,
            quantizer,
            node_occurrence,
            node_list,
        )

    def test_xnnpack_quantizer_conv(self):
        quantizer = XNNPACKQuantizer()
        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
8 changes: 4 additions & 4 deletions test/quantization/pt2e/test_x86inductor_quantizer.py
@@ -290,8 +290,8 @@ def test_conv2d_binary_with_quantizer_api(self):
            # one for output for the add
            # 2 conv will share same input quant/dequant
            # one for extra input node of add
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
            torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
        }
@@ -344,8 +344,8 @@ def test_conv2d_binary_unary_with_quantizer_api(self):
            # one for output for the relu
            # 2 conv will share same input quant/dequant
            # one for extra input node of add
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
            torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
        }
1 change: 1 addition & 0 deletions torch/ao/quantization/pt2e/graph_utils.py
@@ -24,6 +24,7 @@
    {torch.nn.BatchNorm2d, torch.nn.functional.batch_norm},
    {torch.nn.Hardtanh, torch.nn.functional.hardtanh, torch.nn.functional.hardtanh_},
    {torch.add, operator.add, operator.iadd, "add", "add_"},
    {torch.mul, operator.mul, operator.imul},
]


33 changes: 30 additions & 3 deletions torch/ao/quantization/pt2e/prepare.py
@@ -88,11 +88,38 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
        new_arg = arg
        obs_or_fq_map[(observed_arg, node)] = arg_as_input_act_obs_or_fq
    else:
        # skip inserting new observers if there is an observer inserted for the arg before
        # that has the same dtype that we want to insert here
        # alternatively we could have a dedup pass after we insert all observers to deduplicate
        # observers
        # Example:
        # arg -> existing_obs -> conv1
        #                    \ -> conv2
        #
        # instead of inserting new observers we will have:
        # arg -> existing_obs -> conv1
        #                    \ -> conv2
        existing_obs_node = None
        for maybe_obs_node in arg.users.keys():
            if maybe_obs_node.op == 'call_module':
                maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
                if (
                    type(maybe_obs_mod) == type(arg_as_input_act_obs_or_fq) and
                    maybe_obs_mod.dtype == arg_as_input_target_dtype
                ):
                    arg_as_input_act_obs_or_fq = maybe_obs_mod  # type: ignore[assignment]
                    existing_obs_node = maybe_obs_node
                    break

        assert arg_as_input_act_obs_or_fq is not None
        new_obs_node = _insert_obs_or_fq(
            arg, arg_as_input_act_obs_or_fq, model, named_modules, model.graph)  # type: ignore[arg-type]
        new_arg = new_obs_node
        obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq
        if existing_obs_node is None:
            new_obs_node = _insert_obs_or_fq(
                arg, arg_as_input_act_obs_or_fq, model, named_modules, model.graph)
            # override this arg to be the observed arg
            new_arg = new_obs_node
        else:
            new_arg = existing_obs_node

    return new_arg

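To illustrate the deduplication behavior added here, a hedged sketch: after prepare_pt2e, an input feeding two quantized consumers should end up with a single shared observer rather than one per consumer. The model, shapes, export call, and the placeholder/observer inspection are assumptions for illustration only, not part of this change.

import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class TwoConsumers(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 3, 3)
        self.conv2 = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        # `x` feeds both convs: arg -> existing_obs -> conv1 / conv2 after dedup
        return self.conv1(x) + self.conv2(x)

example_inputs = (torch.randn(1, 3, 8, 8),)
m = TwoConsumers().eval()
# Assumption: version-dependent PT2E export entry point, as noted earlier.
m = torch._export.capture_pre_autograd_graph(m, example_inputs)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
m = prepare_pt2e(m, quantizer)

# Count observer modules attached directly to the placeholder `x`; with the
# change above this should be 1, since both convs reuse the same observer.
placeholder = next(n for n in m.graph.nodes if n.op == "placeholder")
obs_users = [n for n in placeholder.users if n.op == "call_module"]
print(len(obs_users))
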
2 changes: 2 additions & 0 deletions torch/ao/quantization/quantizer/xnnpack_quantizer.py
@@ -329,6 +329,8 @@ def _annotate_all_patterns(
        self._annotate_conv2d_patterns(model, config, filter_fn)
        self._annotate_max_pool2d(model, config, filter_fn)
        self._annotate_add_patterns(model, config, filter_fn)
        OP_TO_ANNOTATOR["mul_relu"](model, config, filter_fn)
        OP_TO_ANNOTATOR["mul"](model, config, filter_fn)
        self._annotate_adaptive_avg_pool2d(model, config, filter_fn)
        self._annotate_gru_io_only(model, config, filter_fn)
        return model
125 changes: 100 additions & 25 deletions torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
@@ -521,20 +521,28 @@ def _annotate_max_pool2d(
        )


def _annotate_input_out_obs_sharing_op(
    op: Callable,
def _annotate_adaptive_avg_pool2d(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> None:
    module_partitions = get_source_partitions(gm.graph, [op], filter_fn)
    """Always annotate adaptive_avg_pool2d op"""
    module_partitions = get_source_partitions(
        gm.graph, [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], filter_fn
    )
    partitions = list(itertools.chain(*module_partitions.values()))
    for partition in partitions:
        io_obs_sharing_node = partition.output_nodes[0]
        if _is_annotated([io_obs_sharing_node]):
        pool_node = partition.output_nodes[0]
        if (
            pool_node.op != "call_function"
            or pool_node.target != torch.ops.aten.adaptive_avg_pool2d.default
        ):
            raise ValueError(f"{pool_node} is not an aten adaptive_avg_pool2d operator")

        if _is_annotated([pool_node]):
            continue

        input_act = io_obs_sharing_node.args[0]
        input_act = pool_node.args[0]
        assert isinstance(input_act, Node)

        # only annotate input output sharing operator
@@ -544,31 +552,21 @@ def _annotate_input_out_obs_sharing_op(
            or not input_act.meta["quantization_annotation"]._annotated
            or input_act.meta["quantization_annotation"].output_qspec is None
        ):
            continue
            input_act_qspec = get_input_act_qspec(quantization_config)
        else:
            input_act_qspec = SharedQuantizationSpec(input_act)

        act_qspec = SharedQuantizationSpec(input_act)
        io_obs_sharing_node.meta["quantization_annotation"] = QuantizationAnnotation(
        # output sharing with input
        output_act_qspec = SharedQuantizationSpec((input_act, pool_node))
        pool_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={
                input_act: act_qspec,
                input_act: input_act_qspec,
            },
            output_qspec=act_qspec,
            output_qspec=output_act_qspec,
            _annotated=True,
        )


def _annotate_adaptive_avg_pool2d(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> None:
    _annotate_input_out_obs_sharing_op(
        torch.nn.AdaptiveAvgPool2d, gm, quantization_config, filter_fn
    )
    _annotate_input_out_obs_sharing_op(
        F.adaptive_avg_pool2d, gm, quantization_config, filter_fn
    )


def _annotate_add_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
@@ -617,7 +615,7 @@ def _annotate_add(
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> None:
    add_partitions = get_source_partitions(
        gm.graph, [operator.add, torch.add], filter_fn
        gm.graph, [operator.add, torch.add, operator.iadd], filter_fn
    )
    add_partitions = list(itertools.chain(*add_partitions.values()))
    for add_partition in add_partitions:
@@ -644,6 +642,81 @@ def _annotate_add(
        )


def _annotate_mul_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> None:
    fused_partitions = find_sequential_partitions(
        gm, [torch.mul, torch.nn.ReLU], filter_fn
    )
    for fused_partition in fused_partitions:
        mul_partition, relu_partition = fused_partition
        if len(relu_partition.output_nodes) > 1:
            raise ValueError("Relu partition has more than one output node")
        relu_node = relu_partition.output_nodes[0]
        if len(mul_partition.output_nodes) > 1:
            raise ValueError("mul partition has more than one output node")
        mul_node = mul_partition.output_nodes[0]

        if _is_annotated([relu_node, mul_node]):
            continue

        input_act_qspec = get_input_act_qspec(quantization_config)
        output_act_qspec = get_output_act_qspec(quantization_config)

        input_qspec_map = {}
        input_act0 = mul_node.args[0]
        if isinstance(input_act0, Node):
            input_qspec_map[input_act0] = input_act_qspec

        input_act1 = mul_node.args[1]
        if isinstance(input_act1, Node):
            input_qspec_map[input_act1] = input_act_qspec

        mul_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            _annotated=True,
        )
        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
            output_qspec=output_act_qspec,
            _annotated=True,
        )


def _annotate_mul(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> None:
    mul_partitions = get_source_partitions(
        gm.graph, [operator.mul, torch.mul, operator.imul], filter_fn
    )
    mul_partitions = list(itertools.chain(*mul_partitions.values()))
    for mul_partition in mul_partitions:
        mul_node = mul_partition.output_nodes[0]
        if _is_annotated([mul_node]):
            continue

        input_act_qspec = get_input_act_qspec(quantization_config)
        output_act_qspec = get_output_act_qspec(quantization_config)

        input_qspec_map = {}
        input_act0 = mul_node.args[0]
        if isinstance(input_act0, Node):
            input_qspec_map[input_act0] = input_act_qspec

        input_act1 = mul_node.args[1]
        if isinstance(input_act1, Node):
            input_qspec_map[input_act1] = input_act_qspec

        mul_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=output_act_qspec,
            _annotated=True,
        )

OP_TO_ANNOTATOR = {
    "linear": _annotate_linear,
    "conv2d": _annotate_conv2d,
@@ -653,6 +726,8 @@ def _annotate_add(
    "max_pool2d": _annotate_max_pool2d,
    "add": _annotate_add,
    "add_relu": _annotate_add_relu,
    "mul": _annotate_mul,
    "mul_relu": _annotate_mul_relu,
    "adaptive_avg_pool2d": _annotate_adaptive_avg_pool2d,
    # input output only gru
    "gru_io_only": _annotate_gru_io_only,
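For readers writing their own annotators, a hedged sketch of the output-shares-with-input pattern used by the reworked _annotate_adaptive_avg_pool2d above. The import path for the annotation classes is assumed from this module's own imports; gm is a captured torch.fx.GraphModule and input_act_qspec is a quantization spec supplied by the caller; the helper name is illustrative, not part of this commit.

import torch
from torch.fx import GraphModule, Node
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    SharedQuantizationSpec,
)

def share_pool_output_with_input(gm: GraphModule, input_act_qspec) -> None:
    """Annotate every aten adaptive_avg_pool2d node so that its output reuses
    the quantization parameters of its input, mirroring the
    SharedQuantizationSpec((input_act, pool_node)) pattern shown above."""
    for pool_node in gm.graph.nodes:
        if pool_node.target != torch.ops.aten.adaptive_avg_pool2d.default:
            continue
        input_act = pool_node.args[0]
        assert isinstance(input_act, Node)
        pool_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={input_act: input_act_qspec},
            # The (input_act, pool_node) edge means: reuse whatever observer
            # ends up on that edge, so input and output share the same scale.
            output_qspec=SharedQuantizationSpec((input_act, pool_node)),
            _annotated=True,
        )
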
