diff --git a/backends/cadence/aot/functions_hifi.yaml b/backends/cadence/aot/functions_hifi.yaml
index 8c65e745c21..c1cef01c1e8 100644
--- a/backends/cadence/aot/functions_hifi.yaml
+++ b/backends/cadence/aot/functions_hifi.yaml
@@ -553,3 +553,8 @@
   kernels:
     - arg_meta: null
       kernel_name: impl::HiFi::quantized_w8a32_linear_out
+
+- func: cadence::quantized_w8a32_conv.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::HiFi::quantized_w8a32_conv_out
diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index 9266cc72970..38a6b08836c 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -571,6 +571,12 @@
     "quantized_w8a32_linear.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)"
 )
 
+lib.define(
+    "quantized_w8a32_conv(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale) -> Tensor"
+)
+lib.define(
+    "quantized_w8a32_conv.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)"
+)
 
 # Custom ops with aten namespace. Need to specify the lib var as FRAGMENT type as aten library is already defined
 aten_lib = Library("aten", "FRAGMENT")
@@ -2589,3 +2595,32 @@ def quantized_w8a32_linear_meta(
     assert src_shape[-1] == weight_shape[-1]
     src_shape[-1] = weight_shape[0]
     return src.new_empty(src_shape, dtype=src.dtype)
+
+
+@register_fake("cadence::quantized_w8a32_conv")
+def quantized_w8a32_conv_meta(
+    src: torch.Tensor,
+    weight: torch.Tensor,
+    w_scale: float,
+    bias: torch.Tensor,
+    b_scale: float,
+) -> torch.Tensor:
+    # After the permutes inserted by the fusion pass, src arrives in NLC
+    # layout [batch, in_length, in_ch] and weight in [kernel_dim, out_ch, in_ch];
+    # the output has shape [batch, out_ch, in_length - kernel_dim + 1].
+    assert len(src.shape) == 3
+
+    kernel_size, out_channels, in_channels = weight.shape
+    assert in_channels == src.shape[-1]
+
+    # Compute the output tensor size from the NCL view of the input
+    output_size = get_conv1d_output_size(
+        src.permute(0, 2, 1).shape,
+        out_channels,
+        stride=1,
+        padding=0,
+        dilation=1,
+        kernel_size=kernel_size,
+        channel_last=False,
+    )
+    return src.new_empty(output_size, dtype=src.dtype)
diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py
index cdadedff6cf..c8bfa5cbac7 100644
--- a/backends/cadence/aot/quantizer/fusion_pass.py
+++ b/backends/cadence/aot/quantizer/fusion_pass.py
@@ -24,6 +24,7 @@
     LayerNormPattern,
     LinearPattern,
     MatmulPattern,
+    MixedW8A32ConvPattern,
     MixedW8A32LinearPattern,
     ReluPattern0,
     ReluPattern1,
@@ -478,6 +479,52 @@ def get_args_and_kwargs_softmax(
         out_zero_point_tensor,
     )
     kwargs = {}
+
+    return args, kwargs
+
+
+def get_args_and_kwargs_mixed_w8a32_conv(
+    graph_module: GraphModule,
+    other_inputs: List[fx.Node],
+    weights_inputs: List[fx.Node],
+    dequants_weights: List[fx.Node],
+    bias_inputs: List[fx.Node],
+    dequants_biases: List[fx.Node],
+    op_node: fx.Node,
+) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
+    # Stride, padding, dilation and groups are not supported yet
+    if len(op_node.args) > 3:
+        assert op_node.args[3] == [1]  # Stride
+    if len(op_node.args) > 4:
+        assert op_node.args[4] == [0]  # Padding
+    if len(op_node.args) > 5:
+        assert op_node.args[5] == [1]  # Dilation
+    if len(op_node.args) > 6:
+        assert op_node.args[6] == 1  # Groups
+
+    assert len(dequants_weights) == 1
+    assert len(dequants_biases) == 1
+    W_scale_ = dequants_weights[0].args[1]
+    B_scale_ = dequants_biases[0].args[1]
+
+    transposed_inputs = graph_module.graph.call_function(
+        torch.ops.aten.permute.default,
+        (other_inputs[0], [0, 2, 1]),  # NCL -> NLC
+    )
+    transposed_weights = graph_module.graph.call_function(
+        torch.ops.aten.permute.default,
+        (weights_inputs[0], [2, 0, 1]),  # [out_ch, in_ch, kernel] -> [kernel, out_ch, in_ch]
+    )
+
+    args = (
+        transposed_inputs,
+        transposed_weights,
+        W_scale_,
+        bias_inputs[0],
+        B_scale_,
+    )
+    kwargs = {}
+
     return args, kwargs
@@ -650,6 +697,16 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                     bias_inputs,
                     dequants_biases,
                 )
+            elif isinstance(pattern, MixedW8A32ConvPattern):
+                args, kwargs = get_args_and_kwargs_mixed_w8a32_conv(
+                    graph_module,
+                    other_inputs,
+                    weights_inputs,
+                    dequants_weights,
+                    bias_inputs,
+                    dequants_biases,
+                    op_node,
+                )
 
             fused = graph_module.graph.call_function(
                 pattern.replacement_op(),
diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py
index 5ceb2ffdda3..65389aaad37 100644
--- a/backends/cadence/aot/quantizer/patterns.py
+++ b/backends/cadence/aot/quantizer/patterns.py
@@ -599,3 +599,65 @@ def get_anchors(
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_w8a32_linear.default
+
+
+class MixedW8A32ConvPattern(QuantizationPattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv1d.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> Tuple[PartitionAnchors, fx.Node]:
+        # pyre-ignore[29]
+        conv_layer = fused_partition[0].nodes[-1]
+
+        # Bail if the op carries more arguments than expected:
+        # stride, padding, dilation and groups are not supported
+        if len(conv_layer.args) != 3 or len(conv_layer.kwargs) > 0:
+            return (
+                PartitionAnchors(
+                    empty=True,
+                ),
+                conv_layer,
+            )
+
+        cnn_weights = conv_layer.args[1]
+        if "tensor_meta" in cnn_weights.meta:
+            cnn_weights_shape = cnn_weights.meta["tensor_meta"].shape
+            # Bail if the channel counts are not multiples of 4 (SIMD width)
+            if cnn_weights_shape[0] % 4 != 0:
+                return (
+                    PartitionAnchors(
+                        empty=True,
+                    ),
+                    conv_layer,
+                )
+            if cnn_weights_shape[1] % 4 != 0:
+                return (
+                    PartitionAnchors(
+                        empty=True,
+                    ),
+                    conv_layer,
+                )
+            # Bail if the kernel size is not 3
+            if cnn_weights_shape[2] != 3:
+                return (
+                    PartitionAnchors(
+                        empty=True,
+                    ),
+                    conv_layer,
+                )
+
+        return (
+            PartitionAnchors(
+                inputs=[],
+                weights=[(conv_layer, 1)],
+                biases=[(conv_layer, 2)],
+                output=[],
+                others=[(conv_layer, 0)],
+            ),
+            conv_layer,
+        )
+
+    def replacement_op(self) -> OpOverload:
+        return torch.ops.cadence.quantized_w8a32_conv.default
diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py
index 4df69df0779..f824ef874c4 100644
--- a/backends/cadence/aot/quantizer/quantizer.py
+++ b/backends/cadence/aot/quantizer/quantizer.py
@@ -24,6 +24,7 @@
     LayerNormPattern,
     LinearPattern,
     MatmulPattern,
+    MixedW8A32ConvPattern,
     MixedW8A32LinearPattern,
     QuantizationPattern,
     ReluPattern0,
@@ -321,6 +322,9 @@ def __init__(self) -> None:
         quantizers.append(
             CadenceAtenQuantizer(MixedW8A32LinearPattern(), qconfig_A32W8sym)
         )
+        quantizers.append(
+            CadenceAtenQuantizer(MixedW8A32ConvPattern(), qconfig_A32W8sym)
+        )
         super().__init__(quantizers)
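
For reviewers, here is a minimal pure-PyTorch sketch of the semantics the fused `cadence::quantized_w8a32_conv` op is expected to have, given the layouts set up by `get_args_and_kwargs_mixed_w8a32_conv` above (fp32 activations permuted to NLC, int8 weights permuted to `[kernel, out_ch, in_ch]`). This is a reference model, not the HiFi kernel: the helper name `quantized_w8a32_conv_ref`, the concrete shapes and scales, and the `bias * b_scale` dequantization scheme are illustrative assumptions based on the symmetric `qconfig_A32W8sym` config.

```python
import torch
import torch.nn.functional as F


def quantized_w8a32_conv_ref(
    src: torch.Tensor,     # fp32, [batch, in_length, in_ch] (NLC)
    weight: torch.Tensor,  # int8, [kernel, out_ch, in_ch]
    w_scale: float,
    bias: torch.Tensor,    # int8, [out_ch] (dequant scheme is an assumption)
    b_scale: float,
) -> torch.Tensor:
    # Undo the fusion-pass permutes so aten.conv1d can be reused:
    # NLC -> NCL, and [kernel, out_ch, in_ch] -> [out_ch, in_ch, kernel].
    x = src.permute(0, 2, 1)
    w = weight.permute(1, 2, 0).float() * w_scale  # dequantize weights
    b = bias.float() * b_scale                     # dequantize bias
    # stride=1, padding=0, dilation=1, groups=1 -- the only supported config
    return F.conv1d(x, w, b)  # [batch, out_ch, in_length - kernel + 1]


# Hypothetical shapes satisfying the pattern's constraints:
# in_ch and out_ch are multiples of 4, kernel size is 3.
x = torch.randn(2, 16, 8)                                  # batch=2, length=16, in_ch=8
w = torch.randint(-127, 128, (3, 4, 8), dtype=torch.int8)  # kernel=3, out_ch=4, in_ch=8
b = torch.randint(-127, 128, (4,), dtype=torch.int8)
out = quantized_w8a32_conv_ref(x, w, 0.02, b, 0.5)
print(out.shape)  # torch.Size([2, 4, 14]) == [batch, out_ch, 16 - 3 + 1]
```

The output shape matches what `quantized_w8a32_conv_meta` computes via `get_conv1d_output_size`. Doing the layout permutes once at graph-construction time, rather than inside the kernel, presumably lets the HiFi implementation walk the channel dimension contiguously; the `% 4` channel checks in `MixedW8A32ConvPattern` line up with 4-wide SIMD accesses.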