diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index 6c497d5bec4..765ddcd581d 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -24,6 +24,7 @@ from executorch.backends.cadence.aot.quantizer.quantizer import (
     CadenceDefaultQuantizer,
     CadenceQuantizer,
+    CadenceW8A32MixedQuantizer,
 )
 from executorch.backends.cadence.aot.utils import (
     get_default_memory_config,
@@ -59,6 +60,7 @@ def trace(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
     dump_graphs: bool = False,
+    quantizer: Optional[CadenceQuantizer] = None,
 ) -> ExportedProgram:
     """
     Trace the model with export and return an ExportedProgram.
@@ -73,6 +75,12 @@
         torch.ops.aten.rms_norm.default,
     ]
 
+    if isinstance(quantizer, CadenceW8A32MixedQuantizer):
+        ops_to_keep += [
+            torch.ops.aten.gru.input,
+            torch.ops.aten.gru.data,
+        ]
+
     program = trace_fn(
         model, inputs, is_qat=False, strict=True, ops_to_keep=ops_to_keep
     )
@@ -99,7 +107,7 @@
     Returns a GraphModule with the prepared model.
     """
 
-    traced_program = trace(model, inputs, dump_graphs=dump_graphs)
+    traced_program = trace(model, inputs, dump_graphs=dump_graphs, quantizer=quantizer)
     prepared_program = prepare_traced_pt2(
         traced_program, quantizer, dump_graphs=dump_graphs
     )
@@ -184,7 +192,7 @@ def get_fake_quant_model(
     # Make the model inference mode by calling model.eval()
     model.eval()
 
-    program = trace(model, inputs, dump_graphs=dump_graphs)
+    program = trace(model, inputs, dump_graphs=dump_graphs, quantizer=quantizer)
 
     if dump_graphs:
         logging.info("Graph after trace:")
diff --git a/backends/cadence/aot/functions_hifi.yaml b/backends/cadence/aot/functions_hifi.yaml
index c1cef01c1e8..3bdbb33d59b 100644
--- a/backends/cadence/aot/functions_hifi.yaml
+++ b/backends/cadence/aot/functions_hifi.yaml
@@ -558,3 +558,8 @@
   kernels:
     - arg_meta: null
       kernel_name: impl::HiFi::quantized_w8a32_conv_out
+
+- func: cadence::quantized_w8a32_gru.out(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::HiFi::quantized_w8a32_gru_out
diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index a0527618bcf..f827488adfb 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -578,6 +578,15 @@
     "quantized_w8a32_conv.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)"
 )
 
+lib.define(
+    "quantized_w8a32_gru(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale) -> Tensor"
+)
+
+lib.define(
+    "quantized_w8a32_gru.out(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
+
 # Custom ops with aten namespace. Need to specify the lib var as FRAGMENT type as aten library is already defined
 aten_lib = Library("aten", "FRAGMENT")
 aten_lib.define(
@@ -2646,3 +2655,19 @@ def quantized_w8a32_conv_meta(
         channel_last=False,
     )
     return src.new_empty(output_size, dtype=src.dtype)
+
+
+@register_fake("cadence::quantized_w8a32_gru")
+def quantized_w8a32_gru_meta(
+    inputs: torch.Tensor,
+    hidden: torch.Tensor,
+    weights_inputs: torch.Tensor,
+    w_i_scale: float,
+    weights_hidden: torch.Tensor,
+    w_h_scale: float,
+    bias_inputs: torch.Tensor,
+    b_i_scale: float,
+    bias_hidden: torch.Tensor,
+    b_h_scale: float,
+) -> torch.Tensor:
+    return inputs.new_empty((2, hidden.shape[-1]), dtype=inputs.dtype)
diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py
index c8bfa5cbac7..2fa0f794e3c 100644
--- a/backends/cadence/aot/quantizer/fusion_pass.py
+++ b/backends/cadence/aot/quantizer/fusion_pass.py
@@ -25,6 +25,7 @@
     LinearPattern,
     MatmulPattern,
     MixedW8A32ConvPattern,
+    MixedW8A32GruPattern,
     MixedW8A32LinearPattern,
     ReluPattern0,
     ReluPattern1,
@@ -528,6 +529,41 @@ def get_args_and_kwargs_mixed_w8a32_conv(
     return args, kwargs
 
 
+def get_args_and_kwargs_mixed_w8a32_gru(
+    graph_module: GraphModule,
+    other_inputs: List[fx.Node],
+    weights_inputs: List[fx.Node],
+    dequants_weights: List[fx.Node],
+    bias_inputs: List[fx.Node],
+    dequants_biases: List[fx.Node],
+    op_node: fx.Node,
+) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
+    # The per-tensor scales are read off the dequantize nodes feeding the weights and biases
+
+    assert len(dequants_weights) == 2
+    assert len(dequants_biases) == 2
+    w_i_scale = dequants_weights[0].args[1]
+    w_h_scale = dequants_weights[1].args[1]
+    b_i_scale = dequants_biases[0].args[1]
+    b_h_scale = dequants_biases[1].args[1]
+
+    args = (
+        other_inputs[0],
+        other_inputs[1],
+        weights_inputs[0],
+        w_i_scale,
+        weights_inputs[1],
+        w_h_scale,
+        bias_inputs[0],
+        b_i_scale,
+        bias_inputs[1],
+        b_h_scale,
+    )
+    kwargs = {}
+
+    return args, kwargs
+
+
 class QuantFusion(ExportPass):
     # pyre-ignore[2]: Parameter `patterns` has no type specified
     def __init__(self, patterns) -> None:
@@ -707,6 +743,16 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                     dequants_biases,
                     op_node,
                 )
+            elif isinstance(pattern, MixedW8A32GruPattern):
+                args, kwargs = get_args_and_kwargs_mixed_w8a32_gru(
+                    graph_module,
+                    other_inputs,
+                    weights_inputs,
+                    dequants_weights,
+                    bias_inputs,
+                    dequants_biases,
+                    op_node,
+                )
 
             fused = graph_module.graph.call_function(
                 pattern.replacement_op(),
diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py
index 65389aaad37..2452cfdcfea 100644
--- a/backends/cadence/aot/quantizer/patterns.py
+++ b/backends/cadence/aot/quantizer/patterns.py
@@ -661,3 +661,60 @@ def get_anchors(
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_w8a32_conv.default
+
+
+class MixedW8A32GruPattern(QuantizationPattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.gru.input]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> Tuple[PartitionAnchors, fx.Node]:
+        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
+        gru_layer = fused_partition[0].nodes[-1]
+        if len(gru_layer.kwargs) > 0:
+            return (
+                PartitionAnchors(
+                    empty=True,
+                ),
+                gru_layer,
+            )
+
+        # Bail if input or states are not multiple of 4 (SIMD)
+        if gru_layer.args[0].meta["tensor_meta"].shape[-1] % 4 != 0:
+            return (
+                PartitionAnchors(
+                    empty=True,
+                ),
+                gru_layer,
+            )
+        if gru_layer.args[1].meta["tensor_meta"].shape[-1] % 4 != 0:
+            return (
+                PartitionAnchors(
+                    empty=True,
+                ),
+                gru_layer,
+            )
+
+        class Wrapper:  # noqa: B903
+            def __init__(self, args, meta):
+                self.args = args
+                self.meta = meta
+
+        wrapper = Wrapper(tuple(gru_layer.args[2]), gru_layer.meta)
+
+        return (
+            PartitionAnchors(
+                inputs=[],
+                # pyre-fixme[6]: Expected `List[Tuple[Node, int]]` but got `List[Tuple[Wrapper, int]]`.
+                weights=[(wrapper, 0), (wrapper, 1)],
+                # pyre-fixme[6]: Expected `List[Union[Tuple[Node, int], Tuple[Node, int, DerivedQuantizationSpec]]]` but got `List[Tuple[Wrapper, int]]`.
+                biases=[(wrapper, 2), (wrapper, 3)],
+                output=[],
+                others=[(gru_layer, 0), (gru_layer, 1)],
+            ),
+            gru_layer,
+        )
+
+    def replacement_op(self) -> OpOverload:
+        return torch.ops.cadence.quantized_w8a32_gru.default
diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py
index f824ef874c4..d4af074c475 100644
--- a/backends/cadence/aot/quantizer/quantizer.py
+++ b/backends/cadence/aot/quantizer/quantizer.py
@@ -25,6 +25,7 @@
     LinearPattern,
     MatmulPattern,
     MixedW8A32ConvPattern,
+    MixedW8A32GruPattern,
     MixedW8A32LinearPattern,
     QuantizationPattern,
     ReluPattern0,
@@ -325,6 +326,9 @@ def __init__(self) -> None:
         quantizers.append(
             CadenceAtenQuantizer(MixedW8A32ConvPattern(), qconfig_A32W8sym)
         )
+        quantizers.append(
+            CadenceAtenQuantizer(MixedW8A32GruPattern(), qconfig_A32W8sym)
+        )
         super().__init__(quantizers)
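Usage sketch (editor's note, not part of the patch): a minimal example of how the new GRU path might be exercised end to end. It assumes the prepare_pt2(model, inputs, quantizer) entry point whose body is shown in the compiler.py hunk above (the exact signature may carry extra keyword arguments); the SmallGruModel module and the tensor shapes are hypothetical. Per MixedW8A32GruPattern.get_anchors, the last dimension of both the input and the hidden state must be a multiple of 4, otherwise the pattern returns empty anchors and no fusion happens.

import torch

from executorch.backends.cadence.aot.compiler import prepare_pt2
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceW8A32MixedQuantizer,
)


class SmallGruModel(torch.nn.Module):  # hypothetical example module
    def __init__(self) -> None:
        super().__init__()
        # input_size and hidden_size are multiples of 4, matching the
        # SIMD constraint checked in MixedW8A32GruPattern.get_anchors.
        self.gru = torch.nn.GRU(input_size=8, hidden_size=8)

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        out, _ = self.gru(x, h)
        return out


model = SmallGruModel().eval()
# (seq_len, batch, input_size) input and (num_layers, batch, hidden_size) state.
inputs = (torch.randn(4, 1, 8), torch.randn(1, 1, 8))

quantizer = CadenceW8A32MixedQuantizer()
# Passing the quantizer through to trace() keeps aten.gru.input / aten.gru.data
# from being decomposed, so MixedW8A32GruPattern can match the GRU and
# QuantFusion can later replace it with cadence::quantized_w8a32_gru.
prepared = prepare_pt2(model, inputs, quantizer)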