4 changes: 3 additions & 1 deletion backends/apple/coreml/quantizer/coreml_quantizer.py
@@ -2,4 +2,6 @@
#
# Please refer to the license found in the LICENSE file in the root directory of the source tree.

from coremltools.optimize.torch.quantization._coreml_quantizer import CoreMLQuantizer
from coremltools.optimize.torch.quantization._coreml_quantizer import ( # noqa: FLAKE8 F401
CoreMLQuantizer,
)
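
The parenthesized form exists so the `# noqa: FLAKE8 F401` suppression can be attached: `CoreMLQuantizer` is only re-exported from this module, not used in it, so flake8 would otherwise flag the import as unused. A minimal sketch of the intended downstream import (the consumer module is hypothetical, not part of this diff):

```python
# Hypothetical consumer: the symbol is imported from the executorch wrapper
# module, which is why the re-export above must not be treated as dead code.
from executorch.backends.apple.coreml.quantizer.coreml_quantizer import CoreMLQuantizer
```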
15 changes: 10 additions & 5 deletions backends/apple/coreml/test/test_coreml_quantizer.py
@@ -2,21 +2,26 @@
#
# Please refer to the license found in the LICENSE file in the root directory of the source tree.

from typing import Tuple

import numpy as np
import pytest
from typing import Tuple

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e, prepare_qat_pt2e

from executorch.backends.apple.coreml.quantizer.coreml_quantizer import CoreMLQuantizer

from coremltools.optimize.torch.quantization.quantization_config import (
LinearQuantizerConfig,
QuantizationScheme,
)

from executorch.backends.apple.coreml.quantizer.coreml_quantizer import CoreMLQuantizer
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import (
convert_pt2e,
prepare_pt2e,
prepare_qat_pt2e,
)


class TestCoreMLQuantizer:
@staticmethod
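The import reshuffle above is purely organizational; the test still drives the standard PT2E flow. A rough sketch of that flow, assuming a toy model and an illustrative quantizer config (neither is taken from this diff):

```python
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from coremltools.optimize.torch.quantization.quantization_config import (
    LinearQuantizerConfig,
    QuantizationScheme,
)
from executorch.backends.apple.coreml.quantizer.coreml_quantizer import CoreMLQuantizer

# Toy model and inputs for illustration only.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 16, 16),)

# Illustrative config; the real test builds its own LinearQuantizerConfig.
quantizer = CoreMLQuantizer(
    LinearQuantizerConfig.from_dict(
        {
            "global_config": {
                "quantization_scheme": QuantizationScheme.symmetric,
                "activation_dtype": torch.quint8,
                "weight_dtype": torch.qint8,
                "weight_per_channel": True,
            }
        }
    )
)

exported = capture_pre_autograd_graph(model, example_inputs)  # trace
prepared = prepare_pt2e(exported, quantizer)                  # insert observers
prepared(*example_inputs)                                     # calibrate
converted = convert_pt2e(prepared)                            # fold into q/dq ops
```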
41 changes: 20 additions & 21 deletions backends/arm/arm_quantizer_utils.py
@@ -17,7 +17,6 @@
from torch.ao.quantization.quantizer import (
QuantizationAnnotation,
QuantizationSpec,
QuantizationSpecBase,
SharedQuantizationSpec,
)

@@ -416,6 +415,25 @@ def _annotate_conv_bn_relu(
return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True)


def _get_pattern(conv_fn: Callable, relu_is_inplace: bool, has_relu: bool):
def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
conv = conv_fn(x, conv_weight, conv_bias)
bn = F.batch_norm(conv, bn_rm, bn_rv, bn_weight, bn_bias, training=True)
if has_relu:
output = F.relu_(bn) if relu_is_inplace else F.relu(bn)
else:
output = bn
return output, {
"input": x,
"conv": conv,
"weight": conv_weight,
"bias": conv_bias,
"output": output,
}

return _WrapperModule(_conv_bn)


def _do_annotate_conv_bn(
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
@@ -430,24 +448,6 @@ def _do_annotate_conv_bn(
for the following names: "input", "conv", "weight", "bias", and "output".
"""

def get_pattern(conv_fn: Callable, relu_is_inplace: bool):
def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
conv = conv_fn(x, conv_weight, conv_bias)
bn = F.batch_norm(conv, bn_rm, bn_rv, bn_weight, bn_bias, training=True)
if has_relu:
output = F.relu_(bn) if relu_is_inplace else F.relu(bn)
else:
output = bn
return output, {
"input": x,
"conv": conv,
"weight": conv_weight,
"bias": conv_bias,
"output": output,
}

return _WrapperModule(_conv_bn)

# Needed for matching, otherwise the matches get filtered out due to unused
# nodes returned by batch norm
gm.graph.eliminate_dead_code()
Expand All @@ -468,7 +468,7 @@ def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):

# Match against all conv dimensions and cuda variants
for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:
pattern = get_pattern(conv_fn, relu_is_inplace)
pattern = _get_pattern(conv_fn, relu_is_inplace, has_relu)
pattern = get_aten_graph_module(pattern, example_inputs, is_cuda)
pattern.graph.eliminate_dead_code()
pattern.recompile()
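
Hoisting the closure to module level means `_get_pattern` can be exercised in isolation. A small, self-contained sketch (the import path, shapes, and values below are assumptions for illustration, not part of this diff) of what the wrapper returns when called eagerly:

```python
import torch
import torch.nn.functional as F

# Assumed import path for the extracted helper.
from executorch.backends.arm.arm_quantizer_utils import _get_pattern

# conv2d -> batch_norm -> relu pattern, out-of-place relu.
pattern = _get_pattern(F.conv2d, relu_is_inplace=False, has_relu=True)

# Dummy single-channel tensors sized to satisfy conv2d and batch_norm.
x = torch.randn(1, 1, 3, 3)
conv_weight = torch.randn(1, 1, 1, 1)
conv_bias = torch.randn(1)
bn_weight = torch.randn(1)
bn_bias = torch.randn(1)
bn_rm = torch.zeros(1)  # running mean
bn_rv = torch.ones(1)   # running variance

# The wrapper returns the pattern output plus the name -> value map that the
# subgraph matcher later uses to label "input", "conv", "weight", "bias",
# and "output".
output, name_map = pattern(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv)
assert set(name_map) == {"input", "conv", "weight", "bias", "output"}
```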
@@ -545,7 +545,6 @@ def _annotate_gru_io_only(
continue
# inside each GRU partition, we should be able to annotate each linear
# subgraph
input_qspec_map: Dict[Node, QuantizationSpecBase] = {}
input_act = input_nodes[0]
input_act_user = next(iter(input_act.users.keys()))
assert isinstance(input_act, Node)