From e55524ef689cef18b3c30522db7b77693f47c49e Mon Sep 17 00:00:00 2001
From: Eashan Garg
Date: Mon, 15 Sep 2025 20:19:36 -0700
Subject: [PATCH] Quantized Softmax Kernel (#14096)

Summary:
Generic implementation of quantized softmax; the DLA_V130 implementation is a dummy for now.

NOTE: The mask parameter is a no-op.

Reviewed By: mcremon-meta

Differential Revision: D78716203
---
 backends/cadence/aot/ops_registrations.py     | 39 +++++++++
 backends/cadence/aot/quantizer/fusion_pass.py | 79 ++++++++++++++++++-
 backends/cadence/aot/quantizer/patterns.py    | 22 ++++++
 backends/cadence/aot/quantizer/quantizer.py   | 29 +++++++
 4 files changed, 168 insertions(+), 1 deletion(-)

diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index efb22a9e7d6..bd2bf32834d 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -324,6 +324,19 @@
     "rope.out(Tensor input, Tensor sin_tensor, Tensor cos_tensor, Tensor? pos, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
+lib.define(
+    "quantized_softmax(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point) -> (Tensor out)"
+)
+lib.define(
+    "quantized_softmax.per_tensor(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point) -> (Tensor out)"
+)
+lib.define(
+    "quantized_softmax.out(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
+)
+lib.define(
+    "quantized_softmax.per_tensor_out(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
+)
+
 # Load/store with iDMA. These only exist before memory planning.
 # Post memory planning, we check that outputs/inputs for the load/store are in
 # DTCM and replace idma_load/idma_store with idma_copy.
@@ -2329,3 +2342,29 @@ def softmax_f32_f32_meta(
     half_to_float: Optional[bool] = None,
 ) -> torch.Tensor:
     return self.new_empty(self.size(), dtype=self.dtype)
+
+
+@register_fake("cadence::quantized_softmax")
+def quantized_softmax_meta(
+    input: torch.Tensor,
+    mask: torch.Tensor,
+    dim: int,
+    in_scale: torch.Tensor,
+    in_zero_point: torch.Tensor,
+    out_scale: torch.Tensor,
+    out_zero_point: torch.Tensor,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=input.dtype)
+
+
+@register_fake("cadence::quantized_softmax.per_tensor")
+def quantized_softmax_per_tensor_meta(
+    input: torch.Tensor,
+    mask: torch.Tensor,
+    dim: int,
+    in_scale: float,
+    in_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=input.dtype)
diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py
index 8f106a815ac..ed14574a8c8 100644
--- a/backends/cadence/aot/quantizer/fusion_pass.py
+++ b/backends/cadence/aot/quantizer/fusion_pass.py
@@ -6,9 +6,10 @@
 
 # pyre-strict
 
-from typing import Any, Dict, List, Tuple
+from typing import Any, cast, Dict, List, Tuple
 
 import torch
+from executorch.backends.cadence.aot.compiler_utils import get_shape
 from executorch.backends.cadence.aot.quantizer.patterns import (
     AddmmPattern,
     AddPattern,
@@ -25,6 +26,7 @@
     MatmulPattern,
     ReluPattern0,
     ReluPattern1,
+    SoftmaxPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     check_out_zero_point_is_min_range,
@@ -388,6 +390,73 @@ def get_args_and_kwargs_relu(
     return args, kwargs
 
 
+def get_args_and_kwargs_softmax(
+    graph_module: GraphModule,
+    inputs_inputs: List[fx.Node],
+    dequants_inputs: List[fx.Node],
+    quant_node: fx.Node,
+    op_node: fx.Node,
+) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
+    # Make a dummy mask tensor
+    mask_shape = get_shape(graph_module, cast(fx.Node, quant_node.args[0]))
+    mask_shape = list(mask_shape) if mask_shape else []
+    mask_shape[-1] = mask_shape[-1] // 16
+    mask_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            mask_shape,
+            0.0,
+        ),
+        {"dtype": torch.int32},
+    )
+    # Make the scale and zero_point tensors
+    in_scale_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            [1],
+            dequants_inputs[0].args[1],
+        ),
+        {"dtype": torch.float32},
+    )
+    in_zero_point_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            [1],
+            dequants_inputs[0].args[2],
+        ),
+        {"dtype": torch.int32},
+    )
+    out_scale_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            [1],
+            quant_node.args[1],
+        ),
+        {"dtype": torch.float32},
+    )
+    out_zero_point_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            [1],
+            quant_node.args[2],
+        ),
+        {"dtype": torch.int32},
+    )
+
+    # Make the args and kwargs for the replacement op
+    args = (
+        inputs_inputs[0],
+        mask_tensor,
+        op_node.args[1],
+        in_scale_tensor,
+        in_zero_point_tensor,
+        out_scale_tensor,
+        out_zero_point_tensor,
+    )
+    kwargs = {}
+    return args, kwargs
+
+
 class QuantFusion(ExportPass):
     # pyre-ignore[2]: Parameter `patterns` has no type specified
     def __init__(self, patterns) -> None:
@@ -543,6 +612,14 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                         dequants_inputs,
                         quant_node,
                     )
+                elif isinstance(pattern, SoftmaxPattern):
+                    args, kwargs = get_args_and_kwargs_softmax(
+                        graph_module,
+                        inputs_inputs,
+                        dequants_inputs,
+                        quant_node,
+                        anchor_output_node,
+                    )
                 fused = graph_module.graph.call_function(
                    pattern.replacement_op(),
                    args,
diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py
index b653be27e8f..33b476f5120 100644
--- a/backends/cadence/aot/quantizer/patterns.py
+++ b/backends/cadence/aot/quantizer/patterns.py
@@ -485,3 +485,25 @@ def partition_types(self) -> List[OpOverload]:
 class Conv2dReluPattern1(ConvReluBasePattern):
     def partition_types(self) -> List[OpOverload]:
         return [torch.ops.aten.conv2d.default, torch.ops.aten.relu_.default]
+
+
+class SoftmaxPattern(QuantizationPattern):
+
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten._softmax.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
+        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
+        softmax_node = fused_partition[0].nodes[-1]
+
+        return PartitionAnchors(
+            inputs=[(softmax_node, 0)],
+            weights=[],
+            biases=[],
+            output=[(softmax_node,)],
+        )
+
+    def replacement_op(self) -> OpOverload:
+        return torch.ops.cadence.quantized_softmax.default
diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py
index cce7c207a6b..ad5f935173e 100644
--- a/backends/cadence/aot/quantizer/quantizer.py
+++ b/backends/cadence/aot/quantizer/quantizer.py
@@ -27,6 +27,7 @@
     QuantizationPattern,
     ReluPattern0,
     ReluPattern1,
+    SoftmaxPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     find_sequential_partitions_aten,
@@ -58,6 +59,15 @@
     observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
 )
 
+act_qspec_asym16s = QuantizationSpec(
+    dtype=torch.int16,
+    quant_min=-32768,
+    quant_max=32767,
+    qscheme=torch.per_tensor_affine,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
+)
+
 wgt_qspec_asym8s = QuantizationSpec(
     dtype=torch.int8,
     quant_min=-128,
@@ -92,6 +102,13 @@
     None,
 )
 
+qconfig_A16 = QuantizationConfig(
+    act_qspec_asym16s,
+    act_qspec_asym16s,
+    wgt_qspec_asym8s,
+    None,
+)
+
 
 class CadenceAtenQuantizer(Quantizer):
     def __init__(
@@ -283,3 +300,15 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
         quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8))
         quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
         super().__init__(quantizers)
+
+
+class CadenceWithSoftmaxQuantizer(CadenceQuantizer):
+    """
+    Quantizer including A16 softmax
+    """
+
+    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+        if quantizers is None:
+            quantizers = get_cadence_default_quantizers()
+        quantizers.append(CadenceAtenQuantizer(SoftmaxPattern(), qconfig_A16))
+        super().__init__(quantizers)
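
Note: the fusion pass above replaces a dequantize -> aten._softmax -> quantize chain with cadence::quantized_softmax, so the op is expected to behave like "dequantize with the input scale/zero point, float softmax along dim, requantize with the output scale/zero point". The sketch below illustrates that assumed semantics for the per_tensor overload; it is not the kernel implementation, and the helper name is hypothetical. The mask argument is accepted but ignored, matching the "mask parameter is a no-op" note in the summary, and the clamp range comes from act_qspec_asym16s (quant_min=-32768, quant_max=32767). The default (tensor) overload that the fused graph actually calls carries the same values as the 1-element scale/zero-point tensors built by get_args_and_kwargs_softmax.

import torch


def quantized_softmax_per_tensor_reference(
    input: torch.Tensor,
    mask: torch.Tensor,  # currently a no-op, per the commit summary
    dim: int,
    in_scale: float,
    in_zero_point: int,
    out_scale: float,
    out_zero_point: int,
) -> torch.Tensor:
    # Dequantize the quantized activations to float with the input qparams.
    x = (input.to(torch.float32) - in_zero_point) * in_scale
    # Reference float softmax along the requested dimension.
    y = torch.softmax(x, dim=dim)
    # Requantize with the output qparams and clamp to the int16 range used by
    # act_qspec_asym16s.
    q = torch.round(y / out_scale) + out_zero_point
    return q.clamp(-32768, 32767).to(input.dtype)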
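
A minimal usage sketch for CadenceWithSoftmaxQuantizer, assuming the standard PT2E prepare/convert entry points in torch.ao.quantization.quantize_pt2e; the model, example inputs, and export step here are illustrative assumptions, and the exact flow in the Cadence AOT pipeline may differ.

import torch
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceWithSoftmaxQuantizer,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Lowers to aten._softmax, which SoftmaxPattern partitions on.
        return torch.nn.functional.softmax(x, dim=-1)


model = TinyModel().eval()
example_inputs = (torch.randn(1, 16, 32),)

# Export to an ATen-level graph, then annotate, calibrate, and convert with the
# A16 softmax quantizer added in this patch.
gm = torch.export.export(model, example_inputs).module()
quantizer = CadenceWithSoftmaxQuantizer()
prepared = prepare_pt2e(gm, quantizer)
prepared(*example_inputs)  # calibration
converted = convert_pt2e(prepared)

After convert, the QuantFusion pass (with SoftmaxPattern among its patterns) is what rewrites the resulting dequantize/softmax/quantize chain into cadence::quantized_softmax.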