5 changes: 5 additions & 0 deletions backends/cadence/aot/functions_hifi.yaml
@@ -553,3 +553,8 @@
kernels:
- arg_meta: null
kernel_name: impl::HiFi::quantized_w8a32_linear_out

- func: cadence::quantized_w8a32_conv.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::HiFi::quantized_w8a32_conv_out
35 changes: 35 additions & 0 deletions backends/cadence/aot/ops_registrations.py
@@ -571,6 +571,12 @@
"quantized_w8a32_linear.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)"
)

lib.define(
"quantized_w8a32_conv(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale) -> Tensor"
)
lib.define(
"quantized_w8a32_conv.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)"
)

# Custom ops with aten namespace. Need to specify the lib var as FRAGMENT type as aten library is already defined
aten_lib = Library("aten", "FRAGMENT")
@@ -2589,3 +2595,32 @@ def quantized_w8a32_linear_meta(
assert src_shape[-1] == weight_shape[-1]
src_shape[-1] = weight_shape[0]
return src.new_empty(src_shape, dtype=src.dtype)


@register_fake("cadence::quantized_w8a32_conv")
def quantized_w8a32_conv_meta(
src: torch.Tensor,
weight: torch.Tensor,
w_scale: float,
bias: torch.Tensor,
b_scale: float,
) -> torch.Tensor:
# src comes in already permuted to [batch, in_length, in_channels] (NLC)
# weight comes in already permuted to [kernel_size, out_channels, in_channels]
# output is a new empty tensor of shape [batch, out_channels, in_length - kernel_size + 1]
assert len(src.shape) == 3

kernel_size, out_channels, in_channels = weight.shape
assert in_channels == src.shape[-1]

# Compute the output tensor size
output_size = get_conv1d_output_size(
src.permute(0, 2, 1).shape,
out_channels,
stride=1,
padding=0,
dilation=1,
kernel_size=kernel_size,
channel_last=False,
)
return src.new_empty(output_size, dtype=src.dtype)
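
For reference, a minimal sketch (assumed toy sizes, illustrative names only) of the layouts this fake kernel expects and the output shape it should report, mirroring the permutes applied later in the fusion pass:

import torch

# Assumed toy sizes; input already permuted to [batch, in_length, in_channels]
# and weight to [kernel_size, out_channels, in_channels].
batch, in_length, in_channels, out_channels, kernel_size = 1, 16, 8, 4, 3
src = torch.randn(batch, in_length, in_channels)
weight = torch.randn(kernel_size, out_channels, in_channels)
bias = torch.randn(out_channels)

# With stride=1, padding=0, dilation=1, the conv1d output length is
# in_length - kernel_size + 1, so the meta kernel should report:
expected_shape = (batch, out_channels, in_length - kernel_size + 1)
print(expected_shape)  # (1, 4, 14)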
57 changes: 57 additions & 0 deletions backends/cadence/aot/quantizer/fusion_pass.py
@@ -24,6 +24,7 @@
LayerNormPattern,
LinearPattern,
MatmulPattern,
MixedW8A32ConvPattern,
MixedW8A32LinearPattern,
ReluPattern0,
ReluPattern1,
@@ -478,6 +479,52 @@ def get_args_and_kwargs_softmax(
out_zero_point_tensor,
)
kwargs = {}

return args, kwargs


def get_args_and_kwargs_mixed_w8a32_conv(
graph_module: GraphModule,
other_inputs: List[fx.Node],
weights_inputs: List[fx.Node],
dequants_weights: List[fx.Node],
bias_inputs: List[fx.Node],
dequants_biases: List[fx.Node],
op_node: fx.Node,
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
# Stride, padding, dilation, groups not supported yet
if len(op_node.args) > 3:
assert op_node.args[3] == [1] # Stride
if len(op_node.args) > 4:
assert op_node.args[4] == [0] # Padding
if len(op_node.args) > 5:
assert op_node.args[5] == [1] # Dilation
if len(op_node.args) > 6:
assert op_node.args[6] == 1 # Groups

assert len(dequants_weights) == 1
assert len(dequants_biases) == 1
W_scale_ = dequants_weights[0].args[1]
B_scale_ = dequants_biases[0].args[1]

transposed_inputs = graph_module.graph.call_function(
torch.ops.aten.permute.default,
(other_inputs[0], [0, 2, 1]), # NCL -> NLC
)
transposed_weights = graph_module.graph.call_function(
torch.ops.aten.permute.default,
(weights_inputs[0], [2, 0, 1]), # [out_ch, in_ch, kernel] -> [kernel, out_ch, in_ch]
)

args = (
transposed_inputs,
transposed_weights,
W_scale_,
bias_inputs[0],
B_scale_,
)
kwargs = {}

return args, kwargs
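
As a quick illustration of the two permutes built above (toy shapes, names purely illustrative):

import torch

x = torch.randn(1, 8, 16)   # activation in [batch, in_ch, in_length] (NCL)
w = torch.randn(4, 8, 3)    # conv1d weight in [out_ch, in_ch, kernel]

x_nlc = x.permute(0, 2, 1)  # [0, 2, 1]: NCL -> NLC, shape [1, 16, 8]
w_koc = w.permute(2, 0, 1)  # [2, 0, 1]: -> [kernel, out_ch, in_ch], shape [3, 4, 8]

print(x_nlc.shape, w_koc.shape)  # torch.Size([1, 16, 8]) torch.Size([3, 4, 8])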


@@ -650,6 +697,16 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
bias_inputs,
dequants_biases,
)
elif isinstance(pattern, MixedW8A32ConvPattern):
args, kwargs = get_args_and_kwargs_mixed_w8a32_conv(
graph_module,
other_inputs,
weights_inputs,
dequants_weights,
bias_inputs,
dequants_biases,
op_node,
)

fused = graph_module.graph.call_function(
pattern.replacement_op(),
62 changes: 62 additions & 0 deletions backends/cadence/aot/quantizer/patterns.py
@@ -599,3 +599,65 @@ def get_anchors(

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_w8a32_linear.default


class MixedW8A32ConvPattern(QuantizationPattern):
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.conv1d.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> Tuple[PartitionAnchors, fx.Node]:
# pyre-ignore[29]
conv_layer = fused_partition[0].nodes[-1]

# Bail unless the call has exactly (input, weight, bias) and no kwargs:
# stride, padding, dilation and groups are not supported.
if len(conv_layer.args) != 3 or len(conv_layer.kwargs) > 0:
return (
PartitionAnchors(
empty=True,
),
conv_layer,
)

cnn_weights = conv_layer.args[1]
if "tensor_meta" in cnn_weights.meta:
cnn_weights_shape = cnn_weights.meta["tensor_meta"].shape
# Bail if the channel counts are not multiples of 4 (SIMD width)
if cnn_weights_shape[0] % 4 != 0:
return (
PartitionAnchors(
empty=True,
),
conv_layer,
)
if cnn_weights_shape[1] % 4 != 0:
return (
PartitionAnchors(
empty=True,
),
conv_layer,
)
# Bail if the kernel size is not 3
if cnn_weights_shape[2] != 3:
return (
PartitionAnchors(
empty=True,
),
conv_layer,
)

return (
PartitionAnchors(
inputs=[],
weights=[(conv_layer, 1)],
biases=[(conv_layer, 2)],
output=[],
others=[(conv_layer, 0)],
),
conv_layer,
)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_w8a32_conv.default
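
For context, a hedged sketch of conv1d layers whose weight shapes would and would not pass the checks in get_anchors above (out/in channels multiples of 4, kernel size 3); the layer must also be invoked with default stride, padding, dilation and groups. The names are illustrative:

import torch.nn as nn

# Weight shape [4, 8, 3]: out_channels and in_channels are multiples of 4
# and kernel_size == 3, so the shape checks pass.
conv_ok = nn.Conv1d(in_channels=8, out_channels=4, kernel_size=3)

# Weight shape [4, 6, 5]: in_channels is not a multiple of 4 and
# kernel_size != 3, so get_anchors returns empty anchors for it.
conv_skipped = nn.Conv1d(in_channels=6, out_channels=4, kernel_size=5)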
4 changes: 4 additions & 0 deletions backends/cadence/aot/quantizer/quantizer.py
@@ -24,6 +24,7 @@
LayerNormPattern,
LinearPattern,
MatmulPattern,
MixedW8A32ConvPattern,
MixedW8A32LinearPattern,
QuantizationPattern,
ReluPattern0,
@@ -321,6 +322,9 @@ def __init__(self) -> None:
quantizers.append(
CadenceAtenQuantizer(MixedW8A32LinearPattern(), qconfig_A32W8sym)
)
quantizers.append(
CadenceAtenQuantizer(MixedW8A32ConvPattern(), qconfig_A32W8sym)
)
super().__init__(quantizers)
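
The same pairing can also be constructed directly; a minimal sketch, assuming the import paths follow this PR's file locations and that CadenceAtenQuantizer and qconfig_A32W8sym are exposed by quantizer.py:

from executorch.backends.cadence.aot.quantizer.patterns import MixedW8A32ConvPattern
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceAtenQuantizer,
    qconfig_A32W8sym,
)

# Pairs the conv pattern with the A32W8 symmetric qconfig, mirroring the
# registration added to __init__ above.
conv_quantizer = CadenceAtenQuantizer(MixedW8A32ConvPattern(), qconfig_A32W8sym)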

