Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions backends/cadence/aot/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,26 @@
- arg_meta: null
kernel_name: impl::reference::quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out

- func: cadence::quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::reference::quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out

- func: cadence::quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::reference::quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out

- func: cadence::quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::reference::quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out

- func: cadence::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::reference::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out

- func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
Expand Down
20 changes: 20 additions & 0 deletions backends/cadence/aot/functions_hifi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,26 @@
- arg_meta: null
kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out

- func: cadence::quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: cadence::impl::HiFi::quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out

- func: cadence::quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: cadence::impl::HiFi::quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out

- func: cadence::quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out

- func: cadence::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out

- func: cadence::quantized_layer_norm.out(Tensor input, Tensor in_scale, Tensor in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
Expand Down
200 changes: 200 additions & 0 deletions backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,30 @@
lib.define(
"quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
)
lib.define(
"quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
)
lib.define(
"quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
)
lib.define(
"quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
)
lib.define(
"quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_matmul_asym8uxasym8u_asym8u(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)"
)
Expand Down Expand Up @@ -1165,6 +1189,182 @@ def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_meta(
return input.new_empty(output_size, dtype=input.dtype)


@register_fake("cadence::quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor")
def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_meta(
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: Tuple[int],
padding: Tuple[int],
dilation: Tuple[int],
groups: int,
in_zero_point: int,
weight_zero_point: int,
bias_scale: float,
output_scale: float,
output_zero_point: int,
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
out_channels, _, *kernel_size = weight.shape

in_size = input.shape
# Assert that the input tensor has at least 3 dimensions, and at most 6
assert len(in_size) > 2
assert len(in_size) < 6

# Compute the output tensor size
output_size = (
get_conv1d_output_size(
in_size,
out_channels,
stride[1],
padding[1],
dilation[1],
kernel_size[0],
False,
)
if len(in_size) == 3
else get_conv2d_output_size(
in_size, out_channels, stride, padding, dilation, kernel_size, False
)
)

return input.new_empty(output_size, dtype=input.dtype)


@register_fake("cadence::quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor")
def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_meta(
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: Tuple[int],
padding: Tuple[int],
dilation: Tuple[int],
groups: int,
in_zero_point: int,
weight_zero_point: int,
bias_scale: float,
output_scale: float,
output_zero_point: int,
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
out_channels, _, *kernel_size = weight.shape

in_size = input.shape
# Assert that the input tensor has at least 3 dimensions, and at most 6
assert len(in_size) > 2
assert len(in_size) < 6

# Compute the output tensor size
output_size = (
get_conv1d_output_size(
in_size,
out_channels,
stride[1],
padding[1],
dilation[1],
kernel_size[0],
False,
)
if len(in_size) == 3
else get_conv2d_output_size(
in_size, out_channels, stride, padding, dilation, kernel_size, False
)
)

return input.new_empty(output_size, dtype=input.dtype)


@register_fake("cadence::quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor")
def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_meta(
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: Tuple[int],
padding: Tuple[int],
dilation: Tuple[int],
groups: int,
in_zero_point: int,
weight_zero_point: int,
bias_scale: float,
output_scale: float,
output_zero_point: int,
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
out_channels, *kernel_size, _ = weight.shape

in_size = input.shape
# Assert that the input tensor has at least 3 dimensions, and at most 6
assert len(in_size) > 2
assert len(in_size) < 6

# Compute the output tensor size
output_size = (
get_conv1d_output_size(
in_size,
out_channels,
stride[1],
padding[1],
dilation[1],
kernel_size[0],
True,
)
if len(in_size) == 3
else get_conv2d_output_size(
in_size, out_channels, stride, padding, dilation, kernel_size, True
)
)

return input.new_empty(output_size, dtype=input.dtype)


@register_fake("cadence::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor")
def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_meta(
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: Tuple[int],
padding: Tuple[int],
dilation: Tuple[int],
groups: int,
in_zero_point: int,
weight_zero_point: int,
bias_scale: float,
output_scale: float,
output_zero_point: int,
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
out_channels, *kernel_size, _ = weight.shape

in_size = input.shape
# Assert that the input tensor has at least 3 dimensions, and at most 6
assert len(in_size) > 2
assert len(in_size) < 6

# Compute the output tensor size
output_size = (
get_conv1d_output_size(
in_size,
out_channels,
stride[1],
padding[1],
dilation[1],
kernel_size[0],
True,
)
if len(in_size) == 3
else get_conv2d_output_size(
in_size, out_channels, stride, padding, dilation, kernel_size, True
)
)

return input.new_empty(output_size, dtype=input.dtype)


@register_fake("cadence::quantized_layer_norm")
def quantized_layer_norm_meta(
input: torch.Tensor,
Expand Down
Loading
Loading