46 changes: 39 additions & 7 deletions backends/cadence/aot/quantizer/fusion_pass.py
@@ -15,14 +15,19 @@
BmmPattern,
CatPattern,
Conv1dPattern,
Conv1dReluPattern0,
Conv1dReluPattern1,
Conv2dPattern,
Conv2dReluPattern0,
Conv2dReluPattern1,
LayerNormPattern,
LinearPattern,
MatmulPattern,
ReluPattern0,
ReluPattern1,
)
from executorch.backends.cadence.aot.quantizer.utils import (
check_out_zero_point_is_min_range,
create_zero_bias_int32,
find_sequential_partitions_aten,
get_conv_args,
@@ -41,6 +46,13 @@

# Use this part for patterns with multiple aten ops
ReluPatterns = (ReluPattern0, ReluPattern1)
ConvPatterns = (Conv1dPattern, Conv2dPattern)
ConvReluPatterns = (
Conv1dReluPattern0,
Conv1dReluPattern1,
Conv2dReluPattern0,
Conv2dReluPattern1,
)


def get_args_and_kwargs_add(
@@ -432,12 +444,12 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
other_inputs = [node.args[idx] for node, idx in anchors.others]

# The node is the first index of the list and first of the tuple
- op_node = anchors.output[0][0]
+ anchor_output_node = anchors.output[0][0]

- assert len(op_node.users) == 1
- quant_node = list(op_node.users.keys())[0]
+ assert len(anchor_output_node.users) == 1
+ quant_node = list(anchor_output_node.users.keys())[0]

- with graph_module.graph.inserting_after(op_node):
+ with graph_module.graph.inserting_after(anchor_output_node):
args = tuple(
inputs_inputs + weights_inputs + other_inputs + bias_inputs
)
@@ -451,9 +463,29 @@
)
elif isinstance(pattern, CatPattern):
args, kwargs = get_args_and_kwargs_cat(
- inputs_inputs, other_inputs, op_node
+ inputs_inputs, other_inputs, anchor_output_node
)
elif isinstance(pattern, ConvReluPatterns):
# For Conv+ReLU fusion, the op we build the replacement
# args and kwargs for is the *conv* op, which is the
# anchor input, not the anchor output (the ReLU).
check_out_zero_point_is_min_range(
quant_node.args[2], quant_node.args[5]
)
anchor_input_node = anchors.inputs[0][0]
args, kwargs = get_args_and_kwargs_conv(
graph_module,
inputs_inputs,
dequants_inputs,
weights_inputs,
dequants_weights,
bias_inputs,
quant_node,
anchor_input_node,
)
- elif isinstance(pattern, (Conv1dPattern, Conv2dPattern)):
+ elif isinstance(pattern, ConvPatterns):
args, kwargs = get_args_and_kwargs_conv(
graph_module,
inputs_inputs,
@@ -462,7 +494,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
dequants_weights,
bias_inputs,
quant_node,
- op_node,
+ anchor_output_node,
)
elif isinstance(pattern, LinearPattern):
args, kwargs = get_args_and_kwargs_linear(
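The ConvReluPatterns branch above differs from the plain conv branch only in which anchor it rewrites. Below is a minimal stand-alone sketch of that distinction, using a NamedTuple stub and string node names in place of the real PartitionAnchors and fx.Node objects (both names are illustrative, not the actual classes):

from typing import List, NamedTuple, Tuple


class PartitionAnchorsStub(NamedTuple):
    # Shapes mirror the fields used by the pass: inputs are (node, arg_index)
    # pairs anchored on the conv node; output is a one-element tuple holding
    # the node whose user is the quant node.
    inputs: List[Tuple[str, int]]
    output: List[Tuple[str, ...]]


# For a fused conv+relu partition, the output anchor is the relu node, so the
# pass must take the *input* anchor to find the conv whose args get rebuilt.
anchors = PartitionAnchorsStub(inputs=[("conv1d_node", 0)], output=[("relu_node",)])
anchor_input_node = anchors.inputs[0][0]   # -> "conv1d_node" (op that gets rewritten)
anchor_output_node = anchors.output[0][0]  # -> "relu_node" (node feeding the quant node)
print(anchor_input_node, anchor_output_node)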
68 changes: 68 additions & 0 deletions backends/cadence/aot/quantizer/patterns.py
@@ -417,3 +417,71 @@ def partition_types(self) -> List[OpOverload]:
class ReluPattern1(ReluBasePattern):
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.relu_.default]


# This is a base class for Conv+ReLU fusion, since it can be used with two different relu aten ops
class ConvReluBasePattern(QuantizationPattern):
@abstractmethod
def partition_types(self) -> List[OpOverload]:
pass

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
# The first partition ends in the conv node, the second in the relu node
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
conv_node = fused_partition[0].nodes[-1]  # Last node of the conv partition
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
relu_node = fused_partition[1].nodes[-1]  # Last node of the relu partition

bias_qspec = DerivedQuantizationSpec(
derived_from=[
(conv_node.args[0], conv_node),
(conv_node.args[1], conv_node),
],
derive_qparams_fn=get_bias_qparams,
dtype=torch.int32,
quant_min=-(2**31),
quant_max=2**31 - 1,
qscheme=torch.per_tensor_affine,
)

# Keep bias empty if not supplied
bias = []
if len(conv_node.args) > 2 and conv_node.args[2] is not None:
bias = [(conv_node, 2, bias_qspec)]

return PartitionAnchors(
inputs=[(conv_node, 0)],
weights=[(conv_node, 1)],
# pyre-fixme[6]: Incompatible parameter type
biases=bias,
output=[(relu_node,)], # Output is from the relu node
)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_conv_nchw.default


# Conv1d + regular relu op fusion
class Conv1dReluPattern0(ConvReluBasePattern):
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.conv1d.default, torch.ops.aten.relu.default]


# Conv1d + alternate relu op fusion
class Conv1dReluPattern1(ConvReluBasePattern):
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.conv1d.default, torch.ops.aten.relu_.default]


# Conv2d + regular relu op fusion
class Conv2dReluPattern0(ConvReluBasePattern):
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.conv2d.default, torch.ops.aten.relu.default]


# Conv2d + alternate relu op fusion
class Conv2dReluPattern1(ConvReluBasePattern):
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.conv2d.default, torch.ops.aten.relu_.default]
23 changes: 23 additions & 0 deletions backends/cadence/aot/quantizer/quantizer.py
@@ -16,7 +16,11 @@
BmmPattern,
CatPattern,
Conv1dPattern,
Conv1dReluPattern0,
Conv1dReluPattern1,
Conv2dPattern,
Conv2dReluPattern0,
Conv2dReluPattern1,
LayerNormPattern,
LinearPattern,
MatmulPattern,
@@ -260,3 +264,22 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8))
quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
super().__init__(quantizers)


class CadenceFusedConvReluQuantizer(CadenceQuantizer):
"""
Quantizer using fused conv+relu patterns, and including add and cat
"""

def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
if quantizers is None:
quantizers = []
# Order matters here: match the fused conv+relu patterns before the default ones
quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern0(), qconfig_A8W8sym))
quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern1(), qconfig_A8W8sym))
quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern0(), qconfig_A8W8sym))
quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern1(), qconfig_A8W8sym))
quantizers = quantizers + get_cadence_default_quantizers()
quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8))
quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
super().__init__(quantizers)
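A rough end-to-end sketch of how the new quantizer could be driven through the standard PT2E flow follows. The import paths match the files in this diff, but the export/prepare/convert calls and the single-batch calibration are assumptions about the surrounding pipeline, not part of this change:

import torch
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceFusedConvReluQuantizer,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


def quantize_with_fused_conv_relu(model: torch.nn.Module, example_inputs: tuple):
    """Annotate, calibrate, and convert a model using the fused conv+relu patterns."""
    gm = torch.export.export_for_training(model, example_inputs).module()
    quantizer = CadenceFusedConvReluQuantizer()
    prepared = prepare_pt2e(gm, quantizer)
    prepared(*example_inputs)  # one calibration pass, purely for illustration
    return convert_pt2e(prepared)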
16 changes: 16 additions & 0 deletions backends/cadence/aot/quantizer/utils.py
@@ -234,3 +234,19 @@ def find_sequential_partitions_aten(
if _partitions_sequential(candidate):
fused_partitions.append(candidate)
return fused_partitions


def check_out_zero_point_is_min_range(
out_zero_point: int,
out_dtype: torch.dtype,
) -> bool:
"""
Checks if the out_zero_point is the minimum range of the quant type.
"""
if out_dtype == torch.int8:
return out_zero_point == -128
elif out_dtype == torch.int16:
return out_zero_point == -32768
elif out_dtype in (torch.uint8, torch.uint16):
return out_zero_point == 0
return False
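For reference, a couple of spot checks of the helper's intended behavior (run in an environment where this module is importable). The conv+relu fusion expects the output zero point to sit at the bottom of the quantized range so that clamping the quantized conv output reproduces the relu:

import torch

from executorch.backends.cadence.aot.quantizer.utils import (
    check_out_zero_point_is_min_range,
)

# Signed dtypes: the minimum of the range is -2**(bits - 1).
assert check_out_zero_point_is_min_range(-128, torch.int8)
assert not check_out_zero_point_is_min_range(0, torch.int8)

# Unsigned dtypes bottom out at 0.
assert check_out_zero_point_is_min_range(0, torch.uint8)
assert check_out_zero_point_is_min_range(0, torch.uint16)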