4 changes: 2 additions & 2 deletions backends/cadence/aot/TARGETS
@@ -153,8 +153,8 @@ executorch_generated_lib(
"//executorch/backends/cadence/generic/operators:dequantize_per_tensor",
"//executorch/backends/cadence/generic/operators:quantize_per_tensor",
"//executorch/backends/cadence/generic/operators:quantized_add_out",
-"//executorch/backends/cadence/generic/operators:quantized_conv_nchw_out",
-"//executorch/backends/cadence/generic/operators:quantized_conv_nhwc_out",
+"//executorch/backends/cadence/generic/operators:quantized_conv2d_nchw_out",
+"//executorch/backends/cadence/generic/operators:quantized_conv2d_nhwc_out",
"//executorch/backends/cadence/generic/operators:quantized_fully_connected_out",
"//executorch/backends/cadence/generic/operators:quantized_layer_norm",
"//executorch/backends/cadence/generic/operators:quantized_linear_out",
80 changes: 40 additions & 40 deletions backends/cadence/aot/functions.yaml

Large diffs are not rendered by default.

80 changes: 40 additions & 40 deletions backends/cadence/aot/functions_hifi.yaml

Large diffs are not rendered by default.

168 changes: 88 additions & 80 deletions backends/cadence/aot/ops_registrations.py

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions backends/cadence/aot/quantizer/patterns.py
@@ -247,7 +247,7 @@ def get_anchors(
)

def replacement_op(self) -> OpOverload:
-return torch.ops.cadence.quantized_conv_nchw.default
+return torch.ops.cadence.quantized_conv2d_nchw.default


class Conv2dPattern(QuantizationPattern):
Expand Down Expand Up @@ -286,7 +286,7 @@ def get_anchors(
)

def replacement_op(self) -> OpOverload:
-return torch.ops.cadence.quantized_conv_nchw.default
+return torch.ops.cadence.quantized_conv2d_nchw.default


class LayerNormPattern(QuantizationPattern):
Expand Down Expand Up @@ -460,7 +460,7 @@ def get_anchors(
)

def replacement_op(self) -> OpOverload:
-return torch.ops.cadence.quantized_conv_nchw.default
+return torch.ops.cadence.quantized_conv2d_nchw.default


# Conv1d + regular relu op fusion
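
The patterns.py hunks change only the overload returned by replacement_op; the pattern classes themselves are untouched. A minimal sketch of how a fusion pattern picks up the renamed op, assuming the QuantizationPattern interface with partition_types/replacement_op hooks seen above (the class body below is illustrative, not the upstream implementation):

import torch
from torch._ops import OpOverload


class Conv2dNchwPatternSketch:
    """Illustrative stand-in for a QuantizationPattern subclass."""

    def partition_types(self) -> list[OpOverload]:
        # Float op matched in the exported graph (assumed anchor for this sketch).
        return [torch.ops.aten.conv2d.default]

    def replacement_op(self) -> OpOverload:
        # After this PR, 2D convolutions lower to the explicit conv2d overload.
        return torch.ops.cadence.quantized_conv2d_nchw.default
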
84 changes: 46 additions & 38 deletions backends/cadence/aot/ref_implementations.py
@@ -623,8 +623,8 @@ def quantized_conv_per_tensor(
)


-@impl(m, "quantized_conv_nchw.per_tensor")
-def quantized_conv_nchw_per_tensor(
+@impl(m, "quantized_conv2d_nchw.per_tensor")
+def quantized_conv2d_nchw_per_tensor(
input_tensor: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
@@ -679,8 +679,8 @@ def quantized_conv_nchw_per_tensor(
)


-@impl(m, "quantized_conv_nhwc.per_tensor")
-def quantized_conv_nhwc_per_tensor(
+@impl(m, "quantized_conv2d_nhwc.per_tensor")
+def quantized_conv2d_nhwc_per_tensor(
input_tensor: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
@@ -800,7 +800,7 @@ def variant(
# Call the appropriate base function
match layout:
case "nchw":
-return quantized_conv_nchw_per_tensor(
+return quantized_conv2d_nchw_per_tensor(
input_tensor,
weight,
bias,
@@ -817,7 +817,7 @@ def variant(
out_shift,
)
case "nhwc":
-return quantized_conv_nhwc_per_tensor(
+return quantized_conv2d_nhwc_per_tensor(
input_tensor,
weight,
bias,
@@ -841,84 +841,92 @@ def variant(
return decorator


-@impl(m, "quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor")
+@impl(m, "quantized_conv2d_nchw_asym8sxsym8s_asym8s.per_tensor")
@quantized_conv_variant("nchw", torch.int8, torch.int8)
-def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor")
+@impl(m, "quantized_conv2d_nchw_asym8uxsym8u_asym8u.per_tensor")
@quantized_conv_variant("nchw", torch.uint8, torch.uint8)
-def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor")
+@impl(m, "quantized_conv2d_nhwc_asym8sxsym8s_asym8s.per_tensor")
@quantized_conv_variant("nhwc", torch.int8, torch.int8)
-def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor")
+@impl(m, "quantized_conv2d_nhwc_asym8uxsym8u_asym8u.per_tensor")
@quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
-def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor")
+@impl(m, "quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s.per_tensor")
@quantized_conv_variant("nchw", torch.int8, torch.int8)
-def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor")
+@impl(m, "quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u.per_tensor")
@quantized_conv_variant("nchw", torch.uint8, torch.uint8)
-def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor")
+@impl(m, "quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor")
@quantized_conv_variant("nhwc", torch.int8, torch.int8)
-def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor")
+@impl(m, "quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor")
@quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
-def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor")
+@impl(m, "quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor")
@quantized_conv_variant("nchw", torch.int8, torch.int8)
-def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor() -> (
+    torch.Tensor
+): ...


-@impl(m, "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor")
+@impl(m, "quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor")
@quantized_conv_variant("nchw", torch.uint8, torch.uint8)
-def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor() -> (
+    torch.Tensor
+): ...


-@impl(m, "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor")
+@impl(m, "quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor")
@quantized_conv_variant("nhwc", torch.int8, torch.int8)
-def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> (
+    torch.Tensor
+): ...


-@impl(m, "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor")
+@impl(m, "quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor")
@quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
-def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> (
+    torch.Tensor
+): ...


-@impl(m, "quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor")
+@impl(m, "quantized_conv1d_ncl_asym8sxsym8s_asym8s.per_tensor")
@quantized_conv_variant("nchw", torch.int8, torch.int8, is_1d=True)
-def quantized_conv1d_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+def quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor")
+@impl(m, "quantized_conv1d_ncl_asym8uxsym8u_asym8u.per_tensor")
@quantized_conv_variant("nchw", torch.uint8, torch.uint8, is_1d=True)
-def quantized_conv1d_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+def quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor")
+@impl(m, "quantized_conv1d_nlc_asym8sxsym8s_asym8s.per_tensor")
@quantized_conv_variant("nhwc", torch.int8, torch.int8, is_1d=True)
-def quantized_conv1d_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+def quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor")
+@impl(m, "quantized_conv1d_nlc_asym8uxsym8u_asym8u.per_tensor")
@quantized_conv_variant("nhwc", torch.uint8, torch.uint8, is_1d=True)
-def quantized_conv1d_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+def quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


def quantized_relu_common(
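
Each registration above pairs @impl with the quantized_conv_variant factory and an empty-bodied stub; the stub is never executed, since the decorator returns its inner variant function, which dispatches on the layout string to the renamed 2D base kernels (the 1D overloads likewise move from the 2D layout names nchw/nhwc to the 1D names ncl/nlc, i.e. channels-first vs. channels-last over the length dimension). A self-contained sketch of that decorator-factory shape, with simplified signatures and stand-in base kernels; everything except the layout strings is illustrative:

from typing import Callable

import torch


def _conv2d_nchw_base(x: torch.Tensor) -> torch.Tensor:
    return x  # stand-in for quantized_conv2d_nchw_per_tensor


def _conv2d_nhwc_base(x: torch.Tensor) -> torch.Tensor:
    return x  # stand-in for quantized_conv2d_nhwc_per_tensor


def conv_variant(layout: str) -> Callable:
    """Return a decorator that swaps an empty stub for a layout dispatcher."""

    def decorator(_stub: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
        def variant(*args: torch.Tensor) -> torch.Tensor:
            # Call the appropriate base function, mirroring the match in the diff.
            match layout:
                case "nchw":
                    return _conv2d_nchw_base(*args)
                case "nhwc":
                    return _conv2d_nhwc_base(*args)
                case _:
                    raise ValueError(f"unknown layout: {layout}")

        return variant

    return decorator


@conv_variant("nchw")
def conv2d_nchw_asym8s_sketch() -> torch.Tensor: ...  # body replaced by `variant`


print(conv2d_nchw_asym8s_sketch(torch.zeros(1, 3, 8, 8)).shape)  # torch.Size([1, 3, 8, 8])
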
32 changes: 16 additions & 16 deletions backends/cadence/aot/replace_ops.py
@@ -787,8 +787,8 @@ class ReplaceTrivialConvWithLinear(ExportPass):

trivial_conv_op_to_linear_op: Dict[EdgeOpOverload, EdgeOpOverload] = {
exir_ops.edge.cadence.convolution.default: exir_ops.edge.aten.linear.default,
-exir_ops.edge.cadence.quantized_conv_nchw.default: exir_ops.edge.cadence.quantized_linear.default,
-exir_ops.edge.cadence.quantized_conv_nhwc.default: exir_ops.edge.cadence.quantized_linear.default,
+exir_ops.edge.cadence.quantized_conv2d_nchw.default: exir_ops.edge.cadence.quantized_linear.default,
+exir_ops.edge.cadence.quantized_conv2d_nhwc.default: exir_ops.edge.cadence.quantized_linear.default,
}

def call_operator(self, op, args, kwargs, meta):
@@ -800,8 +800,8 @@ def call_operator(self, op, args, kwargs, meta):
# extra args holding at least the zero point and scale of input, weight, bias,
# and output tensor.
quantized_op = (
-op == exir_ops.edge.cadence.quantized_conv_nchw.default
-or op == exir_ops.edge.cadence.quantized_conv_nhwc.default
+op == exir_ops.edge.cadence.quantized_conv2d_nchw.default
+or op == exir_ops.edge.cadence.quantized_conv2d_nhwc.default
)
assert (len(args) == 8 and not quantized_op) or (
len(args) >= 12 and quantized_op
@@ -979,18 +979,18 @@ def call_operator(
) -> ProxyValue:
if op not in {
exir_ops.edge.cadence.convolution.default,
-exir_ops.edge.cadence.quantized_conv_nchw.default,
+exir_ops.edge.cadence.quantized_conv2d_nchw.default,
}:
return super().call_operator(op, args, kwargs, meta)

-quantized_op = op == exir_ops.edge.cadence.quantized_conv_nchw.default
+quantized_op = op == exir_ops.edge.cadence.quantized_conv2d_nchw.default

if not quantized_op and len(args) == 8 and args[-1] is True:
# Already in NHWC layout.
return super().call_operator(op, args, kwargs, meta)

new_op = (
-exir_ops.edge.cadence.quantized_conv_nhwc.default
+exir_ops.edge.cadence.quantized_conv2d_nhwc.default
if quantized_op
else exir_ops.edge.cadence.convolution.default
)
@@ -1067,8 +1067,8 @@ class ReplaceConvWithIm2RowAndLinear(ExportPass):
# decompose to.
conv_op_to_linear_op: Dict[EdgeOpOverload, EdgeOpOverload] = {
exir_ops.edge.cadence.convolution.default: exir_ops.edge.aten.linear.default,
-exir_ops.edge.cadence.quantized_conv_nchw.default: exir_ops.edge.cadence.quantized_linear.default,
-exir_ops.edge.cadence.quantized_conv_nhwc.default: exir_ops.edge.cadence.quantized_linear.default,
+exir_ops.edge.cadence.quantized_conv2d_nchw.default: exir_ops.edge.cadence.quantized_linear.default,
+exir_ops.edge.cadence.quantized_conv2d_nhwc.default: exir_ops.edge.cadence.quantized_linear.default,
}

def call_operator(self, op, args, kwargs, meta):
@@ -1077,8 +1077,8 @@ def call_operator(self, op, args, kwargs, meta):

# Get the relevant args from convolution node.
quantized_op = (
-op == exir_ops.edge.cadence.quantized_conv_nchw.default
-or op == exir_ops.edge.cadence.quantized_conv_nhwc.default
+op == exir_ops.edge.cadence.quantized_conv2d_nchw.default
+or op == exir_ops.edge.cadence.quantized_conv2d_nhwc.default
)
assert (len(args) == 8 and not quantized_op) or (
len(args) >= 12 and quantized_op
@@ -1110,7 +1110,7 @@ def call_operator(self, op, args, kwargs, meta):
# channel_last layout is specified by the channel_last arg of conv
# op, which is either the last argument (15th) or implicitly False
# if the op is quantized, or the last argument if not.
-channel_last = op == exir_ops.edge.cadence.quantized_conv_nhwc.default
+channel_last = op == exir_ops.edge.cadence.quantized_conv2d_nhwc.default
# The weight tensor is [out_channels, in_channels, X] for NCHW layout,
# and [out_channels, X, in_channels] for NHWC layout. Here, X is the
# kernel_width for conv1d, and X = kernel_height * kernel_width for
@@ -1622,12 +1622,12 @@ class ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass(ExportPass):
exir_ops.edge.cadence.quantized_add.per_tensor,
[1, 2, 4, 5],
),
-exir_ops.edge.cadence.quantized_conv_nchw: (
-exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
+exir_ops.edge.cadence.quantized_conv2d_nchw: (
+exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor,
[8, 9, 12, 13],
),
-exir_ops.edge.cadence.quantized_conv_nhwc: (
-exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
+exir_ops.edge.cadence.quantized_conv2d_nhwc: (
+exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor,
[8, 9, 12, 13],
),
exir_ops.edge.cadence.quantized_fully_connected: (
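
The replace_ops.py hunks touch only lookup tables and equality checks keyed by the op overload, so the passes keep their table-driven shape: look the overload up, assert the expected argument count (8 for float convolution, at least 12 for the quantized variants), and emit the mapped op. A rough sketch of that dispatch using stand-in string keys; the real passes use EdgeOpOverload objects such as exir_ops.edge.cadence.quantized_conv2d_nchw.default and also remap the quantization arguments (e.g. the tensor args at indices [8, 9, 12, 13] listed above):

from typing import Any, Callable, Dict, Tuple

# Stand-in keys; the real tables map EdgeOpOverload -> EdgeOpOverload.
CONV_TO_LINEAR: Dict[str, str] = {
    "cadence.quantized_conv2d_nchw": "cadence.quantized_linear",
    "cadence.quantized_conv2d_nhwc": "cadence.quantized_linear",
}


def call_operator_sketch(
    op: str, args: Tuple[Any, ...], emit: Callable[[str, Tuple[Any, ...]], Any]
) -> Any:
    if op not in CONV_TO_LINEAR:
        return emit(op, args)  # ops outside the table pass through unchanged
    quantized = op.startswith("cadence.quantized_")
    # Quantized conv carries extra zero-point/scale args; float conv has exactly 8.
    assert (len(args) == 8 and not quantized) or (len(args) >= 12 and quantized)
    return emit(CONV_TO_LINEAR[op], args)
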
28 changes: 14 additions & 14 deletions backends/cadence/aot/tests/test_ref_implementations.py
@@ -906,40 +906,40 @@ def test_quantized_conv_per_tensor(

convs = [
(
-torch.ops.cadence.quantized_conv_nchw.per_tensor
+torch.ops.cadence.quantized_conv2d_nchw.per_tensor
if memory_format == torch.contiguous_format
-else torch.ops.cadence.quantized_conv_nhwc.per_tensor
+else torch.ops.cadence.quantized_conv2d_nhwc.per_tensor
)
]

optimized_convs = []
if input_tensor.dtype == torch.int8 and weight.dtype == torch.int8:
if memory_format == torch.contiguous_format:
optimized_convs = [
-torch.ops.cadence.quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor,
-torch.ops.cadence.quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor,
-torch.ops.cadence.quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor,
+torch.ops.cadence.quantized_conv2d_nchw_asym8sxsym8s_asym8s.per_tensor,
+torch.ops.cadence.quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s.per_tensor,
+torch.ops.cadence.quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor,
]

else:
optimized_convs = [
-torch.ops.cadence.quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor,
-torch.ops.cadence.quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor,
-torch.ops.cadence.quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor,
+torch.ops.cadence.quantized_conv2d_nhwc_asym8sxsym8s_asym8s.per_tensor,
+torch.ops.cadence.quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor,
+torch.ops.cadence.quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor,
]
elif input_tensor.dtype == torch.uint8 and weight.dtype == torch.uint8:
if memory_format == torch.contiguous_format:
optimized_convs = [
-torch.ops.cadence.quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor,
-torch.ops.cadence.quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor,
-torch.ops.cadence.quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor,
+torch.ops.cadence.quantized_conv2d_nchw_asym8uxsym8u_asym8u.per_tensor,
+torch.ops.cadence.quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u.per_tensor,
+torch.ops.cadence.quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor,
]

else:
optimized_convs = [
-torch.ops.cadence.quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor,
-torch.ops.cadence.quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor,
-torch.ops.cadence.quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor,
+torch.ops.cadence.quantized_conv2d_nhwc_asym8uxsym8u_asym8u.per_tensor,
+torch.ops.cadence.quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor,
+torch.ops.cadence.quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor,
]

convs.extend(optimized_convs)
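
The visible hunk ends right after convs.extend(optimized_convs); the rest of the test is not shown. Presumably each collected overload is then invoked with the same per-tensor arguments and checked against the reference output, along the lines of this hedged sketch (the argument tuple and expected tensor are placeholders, not the actual test fixtures):

import torch


def assert_variants_match(
    convs: list, per_tensor_args: tuple, expected: torch.Tensor
) -> None:
    # Every overload gathered above should agree for the same arguments.
    for conv_op in convs:
        torch.testing.assert_close(conv_op(*per_tensor_args), expected)
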