diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS
index d077169022a..cd9c5ae080d 100644
--- a/backends/cadence/aot/TARGETS
+++ b/backends/cadence/aot/TARGETS
@@ -60,6 +60,18 @@ python_library(
     ],
 )
 
+python_library(
+    name = "ops_registrations",
+    srcs = [
+        "ops_registrations.py",
+    ],
+    deps = [
+        ":utils",
+        "//caffe2:torch",
+        "//executorch/exir:scalar_type",
+    ],
+)
+
 export_file(name = "functions.yaml")
 
 executorch_generated_lib(
diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index a4d856ebed2..df8b98d5b5f 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 from math import prod
 from typing import Optional, Tuple
 
@@ -74,8 +76,8 @@ def quantize_per_tensor_meta(
     zero_point: int,
     quant_min: int,
     quant_max: int,
-    dtype: ScalarType,
-):
+    dtype: torch.dtype,
+) -> torch.Tensor:
     return input.new_empty(input.size(), dtype=dtype)
 
 
@@ -86,8 +88,8 @@ def dequantize_per_tensor_meta(
     zero_point: int,
     quant_min: int,
     quant_max: int,
-    dtype: ScalarType,
-):
+    dtype: torch.dtype,
+) -> torch.Tensor:
     return input.new_empty(input.size(), dtype=torch.float)
 
 
@@ -102,7 +104,7 @@ def quantized_linear_meta(
     out_shift: torch.Tensor,
     out_zero_point: int,
     offset: Optional[torch.Tensor],
-):
+) -> torch.Tensor:
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
@@ -162,7 +164,7 @@ def quantized_layer_norm_meta(
     eps: float,
     output_scale: float,
     output_zero_point: int,
-):
+) -> torch.Tensor:
     return input.new_empty(input.size(), dtype=torch.uint8)
 
 
@@ -173,7 +175,7 @@ def quantized_relu_meta(
     X: torch.Tensor,
     out_zero_point: int,
     out_multiplier: torch.Tensor,
     out_shift: torch.Tensor,
-):
+) -> torch.Tensor:
     return X.new_empty(X.size(), dtype=torch.uint8)
 
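
For context, the `*_meta` functions touched above are shape/dtype propagation ("meta") kernels: they return an empty tensor with the output's shape and dtype so the exporter can trace the op without executing it, which is why annotating `dtype: torch.dtype` and `-> torch.Tensor` (instead of leaving them loose) matters under `# pyre-strict`. The snippet below is a minimal, self-contained sketch of that pattern using `torch.library`; the `example_ns` namespace and the op schema are invented for illustration, since this hunk does not show the actual namespace or registration decorators used by `ops_registrations.py`:

```python
import torch
from torch.library import Library, impl

# Hypothetical namespace and schema, for illustration only.
_lib = Library("example_ns", "DEF")
_lib.define(
    "quantize_per_tensor(Tensor input, float scale, int zero_point, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor"
)


@impl(_lib, "quantize_per_tensor", "Meta")
def quantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    # Meta kernel: only shape and dtype are propagated, no data is produced.
    return input.new_empty(input.size(), dtype=dtype)


# Shape/dtype propagation on meta tensors (no real computation happens).
x = torch.empty(2, 8, device="meta")
y = torch.ops.example_ns.quantize_per_tensor(x, 0.1, 0, -128, 127, torch.int8)
assert y.shape == (2, 8) and y.dtype == torch.int8
```

Because the meta kernel is the only place the output dtype is decided, a precise `torch.dtype` annotation (rather than the unresolved `ScalarType` name the old signatures referenced) lets a strict checker such as Pyre verify these functions end to end.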