39 changes: 39 additions & 0 deletions backends/cadence/aot/ops_registrations.py
@@ -324,6 +324,19 @@
"rope.out(Tensor input, Tensor sin_tensor, Tensor cos_tensor, Tensor? pos, *, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
    "quantized_softmax(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point) -> (Tensor out)"
)
lib.define(
    "quantized_softmax.per_tensor(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point) -> (Tensor out)"
)
lib.define(
    "quantized_softmax.out(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_softmax.per_tensor_out(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
)

# Load/store with iDMA. These only exist before memory planning.
# Post memory planning, we check that outputs/inputs for the load/store are in
# DTCM and replace idma_load/idma_store with idma_copy.
@@ -2329,3 +2342,29 @@ def softmax_f32_f32_meta(
    half_to_float: Optional[bool] = None,
) -> torch.Tensor:
    return self.new_empty(self.size(), dtype=self.dtype)


@register_fake("cadence::quantized_softmax")
def quantized_softmax_meta(
    input: torch.Tensor,
    mask: torch.Tensor,
    dim: int,
    in_scale: torch.Tensor,
    in_zero_point: torch.Tensor,
    out_scale: torch.Tensor,
    out_zero_point: torch.Tensor,
) -> torch.Tensor:
    return input.new_empty(input.size(), dtype=input.dtype)


@register_fake("cadence::quantized_softmax.per_tensor")
def quantized_softmax_per_tensor_meta(
    input: torch.Tensor,
    mask: torch.Tensor,
    dim: int,
    in_scale: float,
    in_zero_point: int,
    out_scale: float,
    out_zero_point: int,
) -> torch.Tensor:
    return input.new_empty(input.size(), dtype=input.dtype)
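
Note: the two register_fake kernels above only encode shape and dtype propagation; no real kernel is needed to trace through the op. A minimal sketch of how this behaves under FakeTensorMode, assuming ops_registrations has been imported so the lib.define and register_fake calls above have run (shapes and quantization parameters below are made up for illustration):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Assumes backends/cadence/aot/ops_registrations.py was imported, which
# registers cadence::quantized_softmax and its fake (meta) kernel.
with FakeTensorMode():
    x = torch.empty(2, 8, 64, dtype=torch.int16)
    # Mask with last dim = input's last dim // 16, matching the fusion pass.
    mask = torch.empty(2, 8, 4, dtype=torch.int32)
    out = torch.ops.cadence.quantized_softmax.per_tensor(
        x, mask, -1, 0.05, 0, 1.0 / 32768.0, -32768
    )
    # The fake kernel returns input.new_empty(input.size(), dtype=input.dtype),
    # so the output mirrors the input's shape and dtype.
    assert out.shape == x.shape and out.dtype == x.dtype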
79 changes: 78 additions & 1 deletion backends/cadence/aot/quantizer/fusion_pass.py
@@ -6,9 +6,10 @@

# pyre-strict

from typing import Any, Dict, List, Tuple
from typing import Any, cast, Dict, List, Tuple

import torch
from executorch.backends.cadence.aot.compiler_utils import get_shape
from executorch.backends.cadence.aot.quantizer.patterns import (
    AddmmPattern,
    AddPattern,
@@ -25,6 +26,7 @@
    MatmulPattern,
    ReluPattern0,
    ReluPattern1,
    SoftmaxPattern,
)
from executorch.backends.cadence.aot.quantizer.utils import (
    check_out_zero_point_is_min_range,
@@ -388,6 +390,73 @@ def get_args_and_kwargs_relu(
    return args, kwargs


def get_args_and_kwargs_softmax(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    quant_node: fx.Node,
    op_node: fx.Node,
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
    # Make a dummy mask tensor
    mask_shape = get_shape(graph_module, cast(fx.Node, quant_node.args[0]))
    mask_shape = list(mask_shape) if mask_shape else []
    mask_shape[-1] = mask_shape[-1] // 16
    mask_tensor = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        (
            mask_shape,
            0.0,
        ),
        {"dtype": torch.int32},
    )
    # Make the scale and zero_point tensors
    in_scale_tensor = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        (
            [1],
            dequants_inputs[0].args[1],
        ),
        {"dtype": torch.float32},
    )
    in_zero_point_tensor = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        (
            [1],
            dequants_inputs[0].args[2],
        ),
        {"dtype": torch.int32},
    )
    out_scale_tensor = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        (
            [1],
            quant_node.args[1],
        ),
        {"dtype": torch.float32},
    )
    out_zero_point_tensor = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        (
            [1],
            quant_node.args[2],
        ),
        {"dtype": torch.int32},
    )

    # Make the args and kwargs for the replacement op
    args = (
        inputs_inputs[0],
        mask_tensor,
        op_node.args[1],
        in_scale_tensor,
        in_zero_point_tensor,
        out_scale_tensor,
        out_zero_point_tensor,
    )
    kwargs = {}
    return args, kwargs


class QuantFusion(ExportPass):
    # pyre-ignore[2]: Parameter `patterns` has no type specified
    def __init__(self, patterns) -> None:
@@ -543,6 +612,14 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
                    dequants_inputs,
                    quant_node,
                )
            elif isinstance(pattern, SoftmaxPattern):
                args, kwargs = get_args_and_kwargs_softmax(
                    graph_module,
                    inputs_inputs,
                    dequants_inputs,
                    quant_node,
                    anchor_output_node,
                )
            fused = graph_module.graph.call_function(
                pattern.replacement_op(),
                args,
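For context, the chain this pass collapses is dequantize -> aten._softmax -> quantize, replaced by a single cadence::quantized_softmax call whose scale/zero-point arguments are materialized as aten.full nodes. A hedged reference sketch of the presumed numerics follows (an illustration only, not the Cadence kernel; the clamp to the input dtype's range is an assumption):

import torch

def quantized_softmax_reference(
    x_int: torch.Tensor,
    dim: int,
    in_scale: float,
    in_zero_point: int,
    out_scale: float,
    out_zero_point: int,
) -> torch.Tensor:
    # Dequantize the integer input to float.
    x_fp = (x_int.to(torch.float32) - in_zero_point) * in_scale
    # Plain float softmax along the requested dim.
    y_fp = torch.softmax(x_fp, dim)
    # Requantize with the output scale/zero-point, clamped to the dtype range.
    info = torch.iinfo(x_int.dtype)
    y_int = torch.round(y_fp / out_scale) + out_zero_point
    return y_int.clamp(info.min, info.max).to(x_int.dtype)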
22 changes: 22 additions & 0 deletions backends/cadence/aot/quantizer/patterns.py
@@ -485,3 +485,25 @@ def partition_types(self) -> List[OpOverload]:
class Conv2dReluPattern1(ConvReluBasePattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.conv2d.default, torch.ops.aten.relu_.default]


class SoftmaxPattern(QuantizationPattern):

    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten._softmax.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        softmax_node = fused_partition[0].nodes[-1]

        return PartitionAnchors(
            inputs=[(softmax_node, 0)],
            weights=[],
            biases=[],
            output=[(softmax_node,)],
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_softmax.default
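
A small sketch of the graph this pattern targets: after export, torch.softmax lowers to aten._softmax.default, which partition_types() matches; since softmax has no weights or biases, only the input (arg 0) and the output are anchored. The module and shapes below are placeholders:

import torch

class TinySoftmax(torch.nn.Module):
    def forward(self, scores: torch.Tensor) -> torch.Tensor:
        return torch.softmax(scores, dim=-1)

# The exported graph should contain aten._softmax.default, the op
# SoftmaxPattern partitions on.
ep = torch.export.export(TinySoftmax(), (torch.randn(2, 8, 8),))
print(ep.graph_module.graph)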
29 changes: 29 additions & 0 deletions backends/cadence/aot/quantizer/quantizer.py
@@ -27,6 +27,7 @@
    QuantizationPattern,
    ReluPattern0,
    ReluPattern1,
    SoftmaxPattern,
)
from executorch.backends.cadence.aot.quantizer.utils import (
    find_sequential_partitions_aten,
@@ -58,6 +59,15 @@
    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)

act_qspec_asym16s = QuantizationSpec(
    dtype=torch.int16,
    quant_min=-32768,
    quant_max=32767,
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)

wgt_qspec_asym8s = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
@@ -92,6 +102,13 @@
    None,
)

qconfig_A16 = QuantizationConfig(
    act_qspec_asym16s,
    act_qspec_asym16s,
    wgt_qspec_asym8s,
    None,
)


class CadenceAtenQuantizer(Quantizer):
    def __init__(
@@ -283,3 +300,15 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
        quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8))
        quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
        super().__init__(quantizers)


class CadenceWithSoftmaxQuantizer(CadenceQuantizer):
    """
    Quantizer including A16 softmax
    """

    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
        if quantizers is None:
            quantizers = get_cadence_default_quantizers()
        quantizers.append(CadenceAtenQuantizer(SoftmaxPattern(), qconfig_A16))
        super().__init__(quantizers)
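
A possible end-to-end use of the new quantizer via the standard PT2E flow; the model, shapes, and calibration data are placeholders, and export_for_training is assumed as the capture step:

import torch
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceWithSoftmaxQuantizer,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

class TinySoftmax(torch.nn.Module):
    def forward(self, scores: torch.Tensor) -> torch.Tensor:
        return torch.softmax(scores, dim=-1)

example_inputs = (torch.randn(2, 8, 8),)
captured = torch.export.export_for_training(
    TinySoftmax().eval(), example_inputs
).module()

quantizer = CadenceWithSoftmaxQuantizer()
prepared = prepare_pt2e(captured, quantizer)
prepared(*example_inputs)  # calibrate on representative data
converted = convert_pt2e(prepared)
# Softmax activations are now annotated with the A16 config, so QuantFusion
# can later rewrite the dq -> _softmax -> q chain into cadence::quantized_softmax.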