diff --git a/backends/cadence/CMakeLists.txt b/backends/cadence/CMakeLists.txt index 47183bed21d..75a57531adf 100644 --- a/backends/cadence/CMakeLists.txt +++ b/backends/cadence/CMakeLists.txt @@ -88,6 +88,9 @@ elseif(EXECUTORCH_FUSION_G3_OPT) ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET_DIR}/third-party/nnlib ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10 ) +elseif(EXECUTORCH_VISION_OPT) + set(TARGET_DIR vision) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/${TARGET_DIR}/kernels) else() set(TARGET_DIR reference) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/${TARGET_DIR}/kernels) diff --git a/backends/cadence/aot/functions_vision.yaml b/backends/cadence/aot/functions_vision.yaml new file mode 100644 index 00000000000..8d9cdd16105 --- /dev/null +++ b/backends/cadence/aot/functions_vision.yaml @@ -0,0 +1,265 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This yaml file contains operators that are also defined by the ATen library. +# For lean mode: +# - Codegen'd target `executorch_generated_lib` will be reading all the information +# from this file, including operator schema and kernel metadata. +# - Selective build target `codegen:executorch_defined_ops` now is selecting all the +# operators in this file, by dumping all the op names into `selected_operators.yaml`. +# +# See the README.md file in executorch/kernels/portable for a description of the syntax used +# by this file. + + +# aten ops +- op: _to_copy.out + kernels: + - arg_meta: null + kernel_name: torch::executor::to_copy_out + +- op: _softmax.out + kernels: + - arg_meta: null + kernel_name: impl::vision::native::_softmax_out + +- op: add.out + kernels: + - arg_meta: null + kernel_name: impl::vision::native::add_out + +- op: bmm.out + kernels: + - arg_meta: null + kernel_name: torch::executor::bmm_out + +- op: cat.out + kernels: + - arg_meta: null + kernel_name: torch::executor::cat_out + +- op: clone.out + kernels: + - arg_meta: null + kernel_name: torch::executor::clone_out + +- op: div.out + kernels: + - arg_meta: null + kernel_name: torch::executor::div_out + +- op: div.out_mode + kernels: + - arg_meta: null + kernel_name: torch::executor::div_out_mode + +- op: embedding.out + kernels: + - arg_meta: null + kernel_name: impl::vision::native::embedding_out + +- op: empty.out + kernels: + - arg_meta: null + kernel_name: torch::executor::empty_out + +- op: expand_copy.out + kernels: + - arg_meta: null + kernel_name: torch::executor::expand_copy_out + +- op: full.out + kernels: + - arg_meta: null + kernel_name: impl::vision::native::full_out + +- op: gelu.out + kernels: + - arg_meta: null + kernel_name: torch::executor::gelu_out + +- op: hardtanh.out + kernels: + - arg_meta: null + kernel_name: torch::executor::hardtanh_out + +- op: max_pool2d_with_indices.out + kernels: + - arg_meta: null + kernel_name: torch::executor::max_pool2d_with_indices_out + +- op: mean.out + kernels: + - arg_meta: null + kernel_name: torch::executor::mean_dim_out + +- op: mul.out + kernels: + - arg_meta: null + kernel_name: torch::executor::mul_out + +- op: mul.Scalar_out + kernels: + - arg_meta: null + kernel_name: torch::executor::mul_scalar_out + +- op: permute_copy.out + kernels: + - arg_meta: null + kernel_name: torch::executor::permute_copy_out + +- op: rsqrt.out + kernels: + - arg_meta: null + kernel_name: torch::executor::rsqrt_out + +- op: sigmoid.out + kernels: + - arg_meta: null + kernel_name: torch::executor::sigmoid_out + +- op: slice_copy.Tensor_out + kernels: + - arg_meta: null + kernel_name: torch::executor::slice_copy_Tensor_out 
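+# Note on kernel_name resolution: each kernel_name above must point at an
+# out-variant function with the signature the generated registration code
+# expects. For example, impl::vision::native::add_out (defined in
+# backends/cadence/vision/operators/op_add.cpp in this change) has the form:
+#   Tensor& add_out(KernelRuntimeContext& ctx, const Tensor& a,
+#                   const Tensor& b, const Scalar& alpha, Tensor& out);
+# Entries that name torch::executor::* reuse the portable CPU kernels instead.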
+ +- op: split_with_sizes_copy.out + kernels: + - arg_meta: null + kernel_name: torch::executor::split_with_sizes_copy_out + +- op: sub.out + kernels: + - arg_meta: null + kernel_name: torch::executor::sub_out + +- op: view_copy.out + kernels: + - arg_meta: null + kernel_name: impl::vision::native::view_copy_out + +- op: where.self_out + kernels: + - arg_meta: null + kernel_name: torch::executor::where_out + +- op: transpose_copy.int_out + kernels: + - arg_meta: null + kernel_name: torch::executor::transpose_copy_int_out + +- op: eq.Scalar_out + kernels: + - arg_meta: null + kernel_name: torch::executor::eq_scalar_out + +- op: logical_not.out + kernels: + - arg_meta: null + kernel_name: torch::executor::logical_not_out + +- op: any.out + kernels: + - arg_meta: null + kernel_name: torch::executor::any_out + +- op: native_group_norm.out + kernels: + - arg_meta: null + kernel_name: torch::executor::native_group_norm_out + +- op: sum.IntList_out + kernels: + - arg_meta: null + kernel_name: torch::executor::sum_dim_out + +- op: select_copy.int_out + kernels: + - arg_meta: null + kernel_name: torch::executor::select_copy_int_out + +# custom ops +- func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::vision::native::quantize_per_tensor_out + +- func: cadence::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::vision::native::dequantize_per_tensor_out + +- func: cadence::quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::native::quantized_conv_out + +- func: cadence::quantized_layer_norm.out(Tensor input, Tensor in_scale, Tensor in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::native::quantized_layer_norm_out +- func: cadence::quantized_layer_norm.per_tensor_out(Tensor input, float in_scale, int in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::native::quantized_layer_norm_per_tensor_out + +- func: cadence::quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::native::quantized_linear_out + +- func: cadence::quantized_relu.out(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::native::quantized_relu_out + +- func: cadence::quantized_relu.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) 
out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::native::quantized_relu_per_tensor_out + +- func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::native::quantized_matmul_out + +- func: cadence::quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::native::quantized_linear_per_tensor_out + +- func: cadence::im2row.out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, Tensor in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::native::im2row_out + +- func: cadence::im2row.per_tensor_out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, int in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::native::im2row_per_tensor_out + +- func: cadence::quantized_conv.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::native::quantized_conv_per_tensor_out + +- func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::native::quantized_fully_connected_out + +- func: cadence::quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::native::quantized_fully_connected_per_tensor_out + +- func: cadence::requantize.out(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::native::requantize_out + +- func: cadence::requantize.per_tensor_out(Tensor input, float in_scale, int in_zero_point, float out_scale, int out_zero_point, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::native::requantize_per_tensor_out diff --git a/backends/cadence/build_cadence_vision.sh b/backends/cadence/build_cadence_vision.sh new file mode 100755 index 00000000000..7c2c6d68860 --- /dev/null +++ b/backends/cadence/build_cadence_vision.sh @@ -0,0 +1,83 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
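+# Build script for the Cadence vision (XRC_Vision_130_AO) target: syncs
+# submodules, installs requirements, then configures and builds ExecuTorch
+# with the Cadence toolchain file and EXECUTORCH_VISION_OPT=ON (either in a
+# single pass or step-wise), and finally exports the "add" model and runs it
+# under xt-run as a smoke test.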
+ +set -euo pipefail + +unset CMAKE_PREFIX_PATH +unset XTENSA_CORE +export XTENSA_CORE=XRC_Vision_130_AO +git submodule sync +git submodule update --init --recursive +./install_requirements.sh +./install_executorch.sh + +rm -rf cmake-out + +STEPWISE_BUILD=false + +if $STEPWISE_BUILD; then + echo "Building ExecuTorch" + CXXFLAGS="-fno-exceptions -fno-rtti" cmake -DCMAKE_INSTALL_PREFIX=cmake-out \ + -DCMAKE_TOOLCHAIN_FILE=./backends/cadence/cadence.cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DEXECUTORCH_ENABLE_EVENT_TRACER=OFF \ + -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ + -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=OFF \ + -DEXECUTORCH_BUILD_PTHREADPOOL=OFF \ + -DEXECUTORCH_BUILD_CPUINFO=OFF \ + -DEXECUTORCH_ENABLE_LOGGING=ON \ + -DEXECUTORCH_USE_DL=OFF \ + -DEXECUTORCH_BUILD_CADENCE=OFF \ + -Bcmake-out . + + echo "Building any Cadence-specific binaries on top" + CXXFLAGS="-fno-exceptions -fno-rtti" cmake -DBUCK2="$BUCK" \ + -DCMAKE_TOOLCHAIN_FILE=./backends/cadence/cadence.cmake \ + -DCMAKE_INSTALL_PREFIX=cmake-out \ + -DCMAKE_BUILD_TYPE=Release \ + -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \ + -DEXECUTORCH_BUILD_PTHREADPOOL=OFF \ + -DEXECUTORCH_BUILD_CADENCE=ON \ + -DEXECUTORCH_ENABLE_LOGGING=ON \ + -DEXECUTORCH_ENABLE_PROGRAM_VERIFICATION=ON \ + -DEXECUTORCH_USE_DL=OFF \ + -DEXECUTORCH_BUILD_PORTABLE_OPS=ON \ + -DEXECUTORCH_BUILD_KERNELS_LLM=OFF \ + -DPYTHON_EXECUTABLE=python3 \ + -DEXECUTORCH_VISION_OPT=ON \ + -DHAVE_FNMATCH_H=OFF \ + -Bcmake-out/backends/cadence \ + backends/cadence + cmake --build cmake-out/backends/cadence -j8 +else + echo "Building Cadence toolchain with ExecuTorch packages" + cmake_prefix_path="${PWD}/cmake-out/lib/cmake/ExecuTorch;${PWD}/cmake-out/third-party/gflags" + CXXFLAGS="-fno-exceptions -fno-rtti" cmake -DBUCK2="$BUCK" \ + -DCMAKE_PREFIX_PATH="${cmake_prefix_path}" \ + -DCMAKE_TOOLCHAIN_FILE=./backends/cadence/cadence.cmake \ + -DCMAKE_INSTALL_PREFIX=cmake-out \ + -DCMAKE_BUILD_TYPE=Release \ + -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \ + -DEXECUTORCH_BUILD_PTHREADPOOL=OFF \ + -DEXECUTORCH_BUILD_CPUINFO=OFF \ + -DEXECUTORCH_BUILD_CADENCE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ + -DEXECUTORCH_ENABLE_LOGGING=ON \ + -DEXECUTORCH_ENABLE_PROGRAM_VERIFICATION=ON \ + -DEXECUTORCH_USE_DL=OFF \ + -DEXECUTORCH_BUILD_PORTABLE_OPS=ON \ + -DEXECUTORCH_BUILD_KERNELS_LLM=OFF \ + -DPYTHON_EXECUTABLE=python3 \ + -DEXECUTORCH_VISION_OPT=ON \ + -DHAVE_FNMATCH_H=OFF \ + -Bcmake-out + cmake --build cmake-out --target install --config Release -j8 +fi + +echo "Run simple model to verify cmake build" +python3 -m examples.portable.scripts.export --model_name="add" +xt-run --turbo cmake-out/executor_runner --model_path=add.pte diff --git a/backends/cadence/vision/kernels/CMakeLists.txt b/backends/cadence/vision/kernels/CMakeLists.txt new file mode 100644 index 00000000000..fa7b2b5203b --- /dev/null +++ b/backends/cadence/vision/kernels/CMakeLists.txt @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
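+# cadence_kernels bundles the generic quantize/dequantize/requantize helpers
+# in kernels.cpp with the vision nnlib transpose and softmax primitives and
+# their expf/nanf/inff lookup tables from third-party/library.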
+ +# lint_cmake: -linelength +add_library( + cadence_kernels + kernels.cpp + ${EXECUTORCH_ROOT}/backends/cadence/vision/third-party/library/api/tensor_transposef.c + ${EXECUTORCH_ROOT}/backends/cadence/vision/third-party/library/api/vsoftmaxf.c + ${EXECUTORCH_ROOT}/backends/cadence/vision/third-party/library/tables/expf_tbl.c + ${EXECUTORCH_ROOT}/backends/cadence/vision/third-party/library/tables/nanf_tbl.c + ${EXECUTORCH_ROOT}/backends/cadence/vision/third-party/library/tables/inff_tbl.c +) + +# Let files say "include ". +set(_common_include_directories + ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10 +) + +target_include_directories( + cadence_kernels + PUBLIC . ${EXECUTORCH_ROOT}/backends/cadence/vision/third-party/include + ${EXECUTORCH_ROOT}/backends/cadence/vision/third-party/include_private + ${_common_include_directories} +) + +target_link_libraries(cadence_kernels PRIVATE idma) diff --git a/backends/cadence/vision/kernels/kernels.cpp b/backends/cadence/vision/kernels/kernels.cpp new file mode 100644 index 00000000000..70c811df741 --- /dev/null +++ b/backends/cadence/vision/kernels/kernels.cpp @@ -0,0 +1,198 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +namespace impl { +namespace vision { +namespace kernels { + +void* allocate_temp_memory(KernelRuntimeContext& ctx, size_t size) { + Result temp_mem_res = ctx.allocate_temp(size); + return temp_mem_res.ok() ? temp_mem_res.get() : nullptr; +} + +// Quantize a fp32 value to an int8_t/uint8_t value +template +T quantize(const float x, float scale, int32_t zero_point) { + constexpr float min_val = std::numeric_limits::min(); + constexpr float max_val = std::numeric_limits::max(); + float tmp = roundf(x * scale + zero_point); + return std::max(std::min(tmp, max_val), min_val); +} + +// Quantize an fp32 array to an int8_t/uint8_t array +template +void quantize( + T* __restrict__ y, + const float* __restrict__ x, + float inv_scale, + int32_t zero_point, + size_t size) { + for (size_t i = 0; i < size; ++i) { + y[i] = quantize(x[i], inv_scale, zero_point); + } +} + +// Dequantize an int8_t/uint8_t value to an fp32 value +template +float dequantize(const T x, float scale, int32_t zero_point) { + return scale * (x - zero_point); +} + +// Dequantize an int8_t/uint8_t/int16_t array to an fp32 array +template +void dequantize( + float* __restrict__ y, + const T* __restrict__ x, + float scale, + int32_t zero_point, + size_t size) { + for (size_t i = 0; i < size; ++i) { + y[i] = dequantize(x[i], scale, zero_point); + } +} + +// Requantize the int8_t/uint8_t in value to a uint8_t/int8_t out value. +// The scale and zero_point for requantization are in the args. +template +OT requantize( + const IT in, + float in_scale, + int32_t in_zero_point, + float inv_out_scale, + int32_t out_zero_point) { + float dequant = dequantize(in, in_scale, in_zero_point); + return quantize(dequant, inv_out_scale, out_zero_point); +} + +// Requantize the int8_t/uint8_t in array to a uint8_t/int8_t out array. +// The scale and zero_point for requantization are in the args. 
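+// Element-wise this is dequantize(in[i], in_scale, in_zero_point) followed by
+// quantize(., inv_out_scale, out_zero_point), i.e. the vector form of the
+// scalar requantize above.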
+template +void requantize( + OT* __restrict__ out, + const IT* __restrict__ in, + float in_scale, + int32_t in_zero_point, + float inv_out_scale, + int32_t out_zero_point, + size_t size) { + for (size_t i = 0; i < size; ++i) { + out[i] = requantize( + in[i], in_scale, in_zero_point, inv_out_scale, out_zero_point); + } +} + +// explicit template instantiation + +#define typed_quantize_val(dtype) \ + template dtype quantize(const float x, float inv_scale, int32_t zero_point); +typed_quantize_val(int8_t); +typed_quantize_val(uint8_t); +typed_quantize_val(int16_t); +typed_quantize_val(uint16_t); +typed_quantize_val(int32_t); +#undef typed_quantize_val + +#define typed_quantize_vec(dtype) \ + template void quantize( \ + dtype* __restrict__ y, \ + const float* __restrict__ x, \ + float inv_scale, \ + int32_t zero_point, \ + size_t size); +typed_quantize_vec(int8_t); +typed_quantize_vec(uint8_t); +typed_quantize_vec(int16_t); +typed_quantize_vec(uint16_t); +typed_quantize_vec(int32_t); +#undef typed_quantize_vec + +#define typed_dequantize_val(dtype) \ + template float dequantize(const dtype x, float scale, int32_t zero_point); +typed_dequantize_val(int8_t); +typed_dequantize_val(uint8_t); +typed_dequantize_val(int16_t); +typed_dequantize_val(uint16_t); +typed_dequantize_val(int32_t); +#undef typed_dequantize_val + +#define typed_dequantize_vec(dtype) \ + template void dequantize( \ + float* __restrict__ y, \ + const dtype* __restrict__ x, \ + float scale, \ + int32_t zero_point, \ + size_t size); +typed_dequantize_vec(int8_t); +typed_dequantize_vec(uint8_t); +typed_dequantize_vec(int16_t); +typed_dequantize_vec(uint16_t); +typed_dequantize_vec(int32_t); +#undef typed_dequantize_vec + +#define typed_requantize_val(itype, otype) \ + template otype requantize( \ + const itype in, \ + float in_scale, \ + int32_t in_zero_point, \ + float inv_out_scale, \ + int32_t out_zero_point); +typed_requantize_val(int8_t, int8_t); +typed_requantize_val(int8_t, uint8_t); +typed_requantize_val(int8_t, int16_t); +typed_requantize_val(int8_t, uint16_t); +typed_requantize_val(uint8_t, int8_t); +typed_requantize_val(uint8_t, uint8_t); +typed_requantize_val(uint8_t, int16_t); +typed_requantize_val(uint8_t, uint16_t); +typed_requantize_val(int16_t, int8_t); +typed_requantize_val(int16_t, uint8_t); +typed_requantize_val(int16_t, int16_t); +typed_requantize_val(int16_t, uint16_t); +typed_requantize_val(uint16_t, int8_t); +typed_requantize_val(uint16_t, uint8_t); +typed_requantize_val(uint16_t, int16_t); +typed_requantize_val(uint16_t, uint16_t); +#undef typed_requantize_val + +#define typed_requantize_vec(itype, otype) \ + template void requantize( \ + otype* __restrict__ out, \ + const itype* __restrict__ in, \ + float in_scale, \ + int32_t in_zero_point, \ + float inv_out_scale, \ + int32_t out_zero_point, \ + size_t size); +typed_requantize_vec(int8_t, int8_t); +typed_requantize_vec(int8_t, uint8_t); +typed_requantize_vec(int8_t, int16_t); +typed_requantize_vec(int8_t, uint16_t); +typed_requantize_vec(uint8_t, int8_t); +typed_requantize_vec(uint8_t, uint8_t); +typed_requantize_vec(uint8_t, int16_t); +typed_requantize_vec(uint8_t, uint16_t); +typed_requantize_vec(int16_t, int8_t); +typed_requantize_vec(int16_t, uint8_t); +typed_requantize_vec(int16_t, int16_t); +typed_requantize_vec(int16_t, uint16_t); +typed_requantize_vec(uint16_t, int8_t); +typed_requantize_vec(uint16_t, uint8_t); +typed_requantize_vec(uint16_t, int16_t); +typed_requantize_vec(uint16_t, uint16_t); +#undef typed_requantize_vec + +}; // namespace 
kernels +}; // namespace vision +}; // namespace impl diff --git a/backends/cadence/vision/kernels/kernels.h b/backends/cadence/vision/kernels/kernels.h new file mode 100644 index 00000000000..e86a36515ec --- /dev/null +++ b/backends/cadence/vision/kernels/kernels.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include "inttypes.h" +#include "stddef.h" + +using executorch::runtime::KernelRuntimeContext; +using executorch::runtime::Result; + +namespace impl { +namespace vision { +namespace kernels { + +void* allocate_temp_memory(KernelRuntimeContext& ctx, size_t size); + +template +T quantize(const float x, float scale, int32_t zero_point); + +template +float dequantize(const T x, float scale, int32_t zero_point); + +template +void quantize( + T* __restrict__ y, + const float* __restrict__ x, + float scale, + int32_t zero_point, + size_t size); + +// Deuantize an int8_t/uint8_t/int16_t array to an fp32 array +template +void dequantize( + float* __restrict__ y, + const T* __restrict__ x, + float scale, + int32_t zero_point, + size_t size); + +template +OT requantize( + const IT in, + float in_scale, + int32_t in_zero_point, + float inv_out_scale, + int32_t out_zero_point); + +template +void requantize( + OT* __restrict__ out, + const IT* __restrict__ in, + float in_scale, + int32_t in_zero_point, + float inv_out_scale, + int32_t out_zero_point, + size_t size); + +}; // namespace kernels +}; // namespace vision +}; // namespace impl diff --git a/backends/cadence/vision/kernels/targets.bzl b/backends/cadence/vision/kernels/targets.bzl new file mode 100644 index 00000000000..02136c872b3 --- /dev/null +++ b/backends/cadence/vision/kernels/targets.bzl @@ -0,0 +1,25 @@ +load("@fbsource//tools/build_defs:platform_defs.bzl", "CXX") +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + runtime.cxx_library( + name = "cadence_kernels", + srcs = ["kernels.cpp"], + exported_headers = [ + "kernels.h", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + platforms = CXX, + compatible_with = select({ + "DEFAULT": [], + "ovr_config//cpu:xtensa": ["ovr_config//cpu:xtensa"], + }), + define_static_target = True, + deps = [ + "//executorch/backends/cadence/vision/third-party:vision-nnlib", + "//executorch/runtime/kernel:kernel_includes", + ], + ) diff --git a/backends/cadence/vision/operators/CMakeLists.txt b/backends/cadence/vision/operators/CMakeLists.txt new file mode 100644 index 00000000000..76b784681be --- /dev/null +++ b/backends/cadence/vision/operators/CMakeLists.txt @@ -0,0 +1,120 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.19) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +if(NOT CMAKE_CXX_STANDARD) + set(CMAKE_CXX_STANDARD 17) +endif() + +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) +include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake) + +if(NOT PYTHON_EXECUTABLE) + resolve_python_executable() +endif() + +# ATen compliant ops that are needed to run this model. 
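+# The list below mixes the vision-specific implementations in this directory
+# (op_add.cpp, op_embedding.cpp, op_full.cpp, op_view_copy.cpp, op_softmax.cpp)
+# with portable CPU kernels and the util/pattern sources they depend on.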
+set(_aten_ops__srcs + "${CMAKE_CURRENT_SOURCE_DIR}/op_add.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_embedding.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_full.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_view_copy.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_softmax.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/activation_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/copy_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/broadcast_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/dtype_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/index_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/kernel_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/matmul_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/reduce_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/slice_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_cat.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_hardtanh.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_max_pool2d_with_indices.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mean.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mul.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_permute_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_rsqrt.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sigmoid.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_slice_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_split_with_sizes_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sub.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_to_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_where.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_expand_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_gelu.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_empty.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_transpose_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_eq.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_logical_not.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_any.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_native_group_norm.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sum.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_select_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/delinearize_index.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/dtype_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/normalization_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/select_copy_util.cpp" +) +add_library(aten_ops_cadence ${_aten_ops__srcs}) +target_link_libraries(aten_ops_cadence PUBLIC executorch) +target_link_libraries(aten_ops_cadence PRIVATE cadence_kernels) + +# Let files say "include ". +set(_common_include_directories + ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10 +) + +target_include_directories( + aten_ops_cadence PUBLIC ${ROOT_DIR}/.. ${CMAKE_BINARY_DIR} + ${_common_include_directories} +) + +# Custom ops that are needed to run the test model. 
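+# These provide the cadence:: quantize/dequantize/requantize, quantized
+# linear/conv/relu/layer_norm/matmul/fully_connected, and im2row out-variants
+# declared in functions_vision.yaml.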
+add_library( + custom_ops + "op_quantized_linear_out.cpp" + "op_quantized_conv_out.cpp" + "op_quantized_relu_out.cpp" + "op_quantized_layer_norm.cpp" + "op_quantize_per_tensor.cpp" + "op_quantized_fully_connected_out.cpp" + "op_dequantize_per_tensor.cpp" + "op_quantized_matmul_out.cpp" + "op_requantize_out.cpp" + "op_im2row_out.cpp" +) +target_include_directories( + custom_ops PUBLIC ${ROOT_DIR}/.. ${CMAKE_BINARY_DIR} + ${_common_include_directories} +) + +target_link_libraries(custom_ops PUBLIC executorch) +target_link_libraries(custom_ops PRIVATE cadence_kernels) + +# Generate C++ bindings to register kernels into both PyTorch (for AOT) and +# Executorch (for runtime). Here select all ops in functions_vision.yaml +gen_selected_ops( + LIB_NAME "cadence_ops_lib" OPS_SCHEMA_YAML + "${CMAKE_CURRENT_LIST_DIR}/../../aot/functions_vision.yaml" "" "" +) +generate_bindings_for_kernels( + LIB_NAME "cadence_ops_lib" OPS_SCHEMA_YAML FUNCTIONS_YAML + ${CMAKE_CURRENT_SOURCE_DIR}/../../aot/functions_vision.yaml +) +message("Generated cadence x86 files ${gen_command_sources}") + +gen_operators_lib( + LIB_NAME "cadence_ops_lib" KERNEL_LIBS custom_ops DEPS aten_ops_cadence +) diff --git a/backends/cadence/vision/operators/op_add.cpp b/backends/cadence/vision/operators/op_add.cpp new file mode 100644 index 00000000000..81014143275 --- /dev/null +++ b/backends/cadence/vision/operators/op_add.cpp @@ -0,0 +1,75 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +using executorch::aten::Scalar; +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using executorch::runtime::canCast; +using executorch::runtime::KernelRuntimeContext; +using executorch::runtime::promoteTypes; +using torch::executor::apply_binary_elementwise_fn; +using torch::executor::Error; +using torch::executor::native::utils::extract_scalar; + +namespace impl { +namespace vision { +namespace native { + +Tensor& add_out( + KernelRuntimeContext& ctx, + const Tensor& a, + const Tensor& b, + const Scalar& alpha, + Tensor& out) { + (void)ctx; + + using namespace torch::executor::native::utils; + + ScalarType a_type = a.scalar_type(); + ScalarType b_type = b.scalar_type(); + ScalarType common_type = promoteTypes(a_type, b_type); + ScalarType out_type = out.scalar_type(); + + ET_CHECK_MSG(a_type == ScalarType::Float, "Input tensor not a float.\n"); + ET_CHECK_MSG(b_type == ScalarType::Float, "Input tensor not a float.\n"); + ET_CHECK_MSG(out_type == ScalarType::Float, "Output tensor not a float.\n"); + + ET_CHECK(canCast(common_type, out_type)); + + using CTYPE_A = float; + using CTYPE_B = float; + using CTYPE_IN = float; + using CTYPE_OUT = float; + CTYPE_IN alpha_val; + ET_CHECK_MSG( + extract_scalar(alpha, &alpha_val), + "Could not be extracted: wrong type or out of range"); + + apply_binary_elementwise_fn( + [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) { + CTYPE_IN a_casted = static_cast(val_a); + CTYPE_IN b_casted = static_cast(val_b); + CTYPE_IN value = a_casted + alpha_val * b_casted; + + return static_cast(value); + }, + a, + b, + out); + + return out; +} + +} // namespace native +} // namespace vision +} // namespace impl diff --git a/backends/cadence/vision/operators/op_dequantize_per_tensor.cpp b/backends/cadence/vision/operators/op_dequantize_per_tensor.cpp new file mode 100644 index 
00000000000..833606fb651 --- /dev/null +++ b/backends/cadence/vision/operators/op_dequantize_per_tensor.cpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace impl { +namespace vision { +namespace native { + +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using executorch::runtime::KernelRuntimeContext; + +void dequantize_per_tensor_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + float* out_data = out.mutable_data_ptr(); + size_t numel = out.numel(); + + if (input.scalar_type() == ScalarType::Byte) { + const uint8_t* input_data = input.const_data_ptr(); + impl::vision::native::kernels::dequantize( + out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Char) { + const int8_t* input_data = input.const_data_ptr(); + impl::vision::native::kernels::dequantize( + out_data, input_data, scale, zero_point, numel); + } else if ( + input.scalar_type() == ScalarType::Bits16 || + input.scalar_type() == ScalarType::UInt16) { + const uint16_t* input_data = input.const_data_ptr(); + impl::vision::native::kernels::dequantize( + out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Short) { + const int16_t* input_data = input.const_data_ptr(); + impl::vision::native::kernels::dequantize( + out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Int) { + const int32_t* input_data = input.const_data_ptr(); + impl::vision::native::kernels::dequantize( + out_data, input_data, scale, zero_point, numel); + } else { + ET_CHECK_MSG( + false, + "Unhandled input dtype %hhd", + static_cast(input.scalar_type())); + } +} + +}; // namespace native +}; // namespace vision +}; // namespace impl diff --git a/backends/cadence/vision/operators/op_embedding.cpp b/backends/cadence/vision/operators/op_embedding.cpp new file mode 100644 index 00000000000..5273cb083e8 --- /dev/null +++ b/backends/cadence/vision/operators/op_embedding.cpp @@ -0,0 +1,41 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +using executorch::aten::Tensor; +using executorch::runtime::KernelRuntimeContext; + +namespace impl { +namespace vision { +namespace native { + +void embedding_out( + KernelRuntimeContext& ctx, + const Tensor& weight, + const Tensor& indices, + int64_t padding_idx, + bool scale_grad_by_freq, + bool sparse, + Tensor& out) { + int64_t nbytes_per_entry = weight.size(1) * weight.element_size(); + const char* w_data = weight.const_data_ptr(); + char* out_data = out.mutable_data_ptr(); + const int64_t* indices_ptr = indices.const_data_ptr(); + + for (int i = 0, e = indices.numel(); i < e; i++) { + // memcpy(dest, src, nbytes); + memcpy( + out_data, w_data + nbytes_per_entry * indices_ptr[i], nbytes_per_entry); + out_data += nbytes_per_entry; + } +} + +} // namespace native +} // namespace vision +} // namespace impl diff --git a/backends/cadence/vision/operators/op_full.cpp b/backends/cadence/vision/operators/op_full.cpp new file mode 100644 index 00000000000..afc29718a2b --- /dev/null +++ b/backends/cadence/vision/operators/op_full.cpp @@ -0,0 +1,58 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +using executorch::aten::IntArrayRef; +using executorch::aten::Scalar; +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using executorch::runtime::KernelRuntimeContext; +using torch::executor::Error; +using torch::executor::native::utils::extract_scalar; +using torch::executor::native::utils::get_scalar_dtype; + +namespace impl { +namespace vision { +namespace native { + +Tensor& full_out( + KernelRuntimeContext& ctx, + const IntArrayRef sizes, + const Scalar& fill_value, + Tensor& out) { + (void)ctx; + + ScalarType val_type = get_scalar_dtype(fill_value); + ScalarType out_type = out.scalar_type(); + + Error err = resize_tensor(out, sizes); + ET_CHECK_MSG(err == Error::Ok, "Could not resize out"); + + ET_SWITCH_REAL_TYPES_AND(Bool, val_type, ctx, "full", CTYPE_VAL, [&] { + CTYPE_VAL val; + ET_CHECK_MSG( + extract_scalar(fill_value, &val), + "Could not be extracted: wrong type or out of range"); + + ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "full", CTYPE_OUT, [&] { + CTYPE_OUT val_casted = static_cast(val); + auto data_out = out.mutable_data_ptr(); + for (size_t i = 0; i < out.numel(); ++i) { + data_out[i] = val_casted; + } + }); + }); + + return out; +} + +} // namespace native +} // namespace vision +} // namespace impl diff --git a/backends/cadence/vision/operators/op_im2row_out.cpp b/backends/cadence/vision/operators/op_im2row_out.cpp new file mode 100644 index 00000000000..501f8ce5376 --- /dev/null +++ b/backends/cadence/vision/operators/op_im2row_out.cpp @@ -0,0 +1,298 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include + +namespace impl { +namespace vision { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +template +__attribute__((always_inline)) void im2row_( + const T* __restrict__ data_im, + const int32_t in_zero_point, + /* input parameters*/ + const int32_t channels, + const int32_t height, + const int32_t width, + /* output parameters */ + const int32_t out_height, + const int32_t out_width, + /* convolution parameters */ + const int32_t kernel_h, + const int32_t kernel_w, + const int32_t pad_h, + const int32_t pad_w, + const int32_t stride_h, + const int32_t stride_w, + const int32_t dilation_h, + const int32_t dilation_w, + T* __restrict__ data_col, + bool channels_last) { + // Consider convolving the input image of dimensions channels * height * width + // (or height * width * channels for NHWC layout) with a filter of dimensions + // channels * kernels_h * kernels_w. Assume that this convolution will produce + // an output of dimensinos out_height x out_width. For each point the output, + // im2row takes the data from the input that is used in the computation of + // that output point, and flattens it into a vector of size channels_col = + // channels * kernel_h * kernel_w. The output of im2row will therefore be a 2D + // array of size (out_height * out_width) x channels_col + const int32_t channels_col = channels * kernel_h * kernel_w; + + // If the layout is NHWC, we can copy 'channels' worth of contiguous data + // points when performing im2row. + if (channels_last) { + // Iterate over the output domain + for (int _h = 0; _h < out_height; ++_h) { + for (int _w = 0; _w < out_width; ++_w) { + int32_t i_col = _h * out_width + _w; + // Each point in the output domain is the result of applying a filter of + // size kernel_h x kernel_w x channels on the input. But since channels + // is contiguous, we will not explicitly have a loop for it. + for (int _kh = 0; _kh < kernel_h; ++_kh) { + int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h; + for (int _kw = 0; _kw < kernel_w; ++_kw) { + int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w; + + // h_im and w_im are the actual height and width coordinates of the + // input tensor from where we need to copy 'channels' points. + const T* __restrict__ slice_im = + data_im + (h_im * width + w_im) * channels; + T* __restrict__ slice_col = data_col + i_col * channels_col + + (_kh * kernel_w + _kw) * channels; + // If the coordinates were within the input domain, we copy + // 'channels' contiguous values. Otherwise we will fill the output + // with 0's. + if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + std::memcpy(slice_col, slice_im, channels * sizeof(T)); + } else { + std::fill_n(slice_col, channels, T(in_zero_point)); + } + } + } + } + } + } else { + // Iterate over the output domain + for (int _h = 0; _h < out_height; ++_h) { + for (int _w = 0; _w < out_width; ++_w) { + int32_t i_col = _h * out_width + _w; + + // Each point in the output domain is the result of applying a filter + // of size chanenls * kernel_h x kernel_w on the input + for (int _c = 0; _c < channels; ++_c) { + for (int _kh = 0; _kh < kernel_h; ++_kh) { + for (int _kw = 0; _kw < kernel_w; ++_kw) { + // c_col is the linearized access in the channels_col vector. 
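+              // For NCHW input the column is laid out channel-major:
+              // channels_col = channels * kernel_h * kernel_w, and all
+              // (kh, kw) offsets of channel 0 come before those of channel 1.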
+ int32_t c_col = (_c * kernel_h + _kh) * kernel_w + _kw; + // h_im and w_im are the actual height and width coordinates of + // the input tensor that we need to copy to the output. + int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h; + int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w; + // If the current data access is within the input tensor, copy the + // value + data_col[i_col * channels_col + c_col] = + (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) + ? data_im[(_c * height + h_im) * width + w_im] + : static_cast(in_zero_point); + } + } + } + } + } + } +} + +void im2row_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride, + const Tensor& in_zero_point, + bool channel_last, + Tensor& out) { + // Compute the input tensor's dims + bool unit_height = input.dim() == 3; + const int32_t batch_size = input.size(0); + const int32_t in_c = + channel_last ? input.size(3 - unit_height) : input.size(1); + const int32_t in_h = + unit_height ? 1 : (channel_last ? input.size(1) : input.size(2)); + const int32_t in_w = + channel_last ? input.size(2 - unit_height) : input.size(3 - unit_height); + + // Get the kernel parameters + int32_t kernel_h = kernel_size[0]; + int32_t kernel_w = kernel_size[1]; + int32_t dilation_h = dilation[0]; + int32_t dilation_w = dilation[1]; + int32_t pad_h = padding[0]; + int32_t pad_w = padding[1]; + int32_t stride_h = stride[0]; + int32_t stride_w = stride[1]; + + // If we were to apply a convolution on the input tensor, compute the output + // height and width. + int32_t out_h = + (in_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h + 1; + int32_t out_w = + (in_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) / stride_w + 1; + + ET_DCHECK_MSG( + (out_h * out_w) == out.size(1), "dimension mismatch for output"); + ET_DCHECK_MSG( + (kernel_h * kernel_w * in_c) == out.size(2), + "dimension mismatch for output"); + + // Check if the input is per-tensor quantized or per-channel quantized. The + // zero point for each batch could differ for per-channel quantized input. + bool per_tensor_quantized = in_zero_point.numel() == 1; + +#define typed_im2row(dtype, ctype) \ + case ScalarType::dtype: { \ + const ctype* __restrict__ in_data = input.const_data_ptr(); \ + ctype* __restrict__ out_data = out.mutable_data_ptr(); \ + const int32_t* __restrict__ zero_point = \ + in_zero_point.const_data_ptr(); \ + int32_t in_plane = in_c * in_h * in_w; \ + int32_t out_plane = kernel_h * kernel_w * in_c * out_h * out_w; \ + for (size_t n = 0; n < batch_size; ++n) { \ + im2row_( \ + &in_data[n * in_plane], \ + per_tensor_quantized ? 
zero_point[0] : zero_point[n], \ + in_c, \ + in_h, \ + in_w, \ + out_h, \ + out_w, \ + kernel_h, \ + kernel_w, \ + pad_h, \ + pad_w, \ + stride_h, \ + stride_w, \ + dilation_h, \ + dilation_w, \ + &out_data[n * out_plane], \ + channel_last); \ + } \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + typed_im2row(Float, float); + typed_im2row(Byte, uint8_t); + typed_im2row(Char, int8_t); + default: + ET_DCHECK_MSG( + false, + "im2row not implemented for dtype %s", + torch::executor::toString(dtype)); + } +#undef typed_im2row +} + +void im2row_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride, + int64_t in_zero_point, + bool channel_last, + Tensor& out) { + // Compute the input tensor's dims + bool unit_height = input.dim() == 3; + const int32_t batch_size = input.size(0); + const int32_t in_c = + channel_last ? input.size(3 - unit_height) : input.size(1); + const int32_t in_h = + unit_height ? 1 : (channel_last ? input.size(1) : input.size(2)); + const int32_t in_w = + channel_last ? input.size(2 - unit_height) : input.size(3 - unit_height); + + // Get the kernel parameters + int32_t kernel_h = kernel_size[0]; + int32_t kernel_w = kernel_size[1]; + int32_t dilation_h = dilation[0]; + int32_t dilation_w = dilation[1]; + int32_t pad_h = padding[0]; + int32_t pad_w = padding[1]; + int32_t stride_h = stride[0]; + int32_t stride_w = stride[1]; + + // If we were to apply a convolution on the input tensor, compute the output + // height and width. + int32_t out_h = + (in_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h + 1; + int32_t out_w = + (in_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) / stride_w + 1; + + ET_DCHECK_MSG( + (out_h * out_w) == out.size(1), "dimension mismatch for output"); + ET_DCHECK_MSG( + (kernel_h * kernel_w * in_c) == out.size(2), + "dimension mismatch for output"); + +#define typed_im2row_per_tensor(dtype, ctype) \ + case ScalarType::dtype: { \ + const ctype* __restrict__ in_data = input.const_data_ptr(); \ + ctype* __restrict__ out_data = out.mutable_data_ptr(); \ + int32_t in_plane = in_c * in_h * in_w; \ + int32_t out_plane = kernel_h * kernel_w * in_c * out_h * out_w; \ + for (size_t n = 0; n < batch_size; ++n) { \ + im2row_( \ + &in_data[n * in_plane], \ + in_zero_point, \ + in_c, \ + in_h, \ + in_w, \ + out_h, \ + out_w, \ + kernel_h, \ + kernel_w, \ + pad_h, \ + pad_w, \ + stride_h, \ + stride_w, \ + dilation_h, \ + dilation_w, \ + &out_data[n * out_plane], \ + channel_last); \ + } \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + typed_im2row_per_tensor(Float, float); + typed_im2row_per_tensor(Byte, uint8_t); + typed_im2row_per_tensor(Char, int8_t); + default: + ET_DCHECK_MSG( + false, + "im2row.per_tensor not implemented for dtype %s", + torch::executor::toString(dtype)); + } +#undef typed_im2row_per_tensor +} + +} // namespace native +} // namespace vision +} // namespace impl diff --git a/backends/cadence/vision/operators/op_quantize_per_tensor.cpp b/backends/cadence/vision/operators/op_quantize_per_tensor.cpp new file mode 100644 index 00000000000..8d209af24b1 --- /dev/null +++ b/backends/cadence/vision/operators/op_quantize_per_tensor.cpp @@ -0,0 +1,66 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace impl { +namespace vision { +namespace native { + +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using executorch::runtime::KernelRuntimeContext; + +// Quantize the input tensor (PT2 version). Note that quant_ are not +// used in any computation. +void quantize_per_tensor_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + const float* input_data = input.const_data_ptr(); + size_t numel = out.numel(); + + if (out.scalar_type() == ScalarType::Byte) { + uint8_t* out_data = out.mutable_data_ptr(); + impl::vision::native::kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Char) { + int8_t* out_data = out.mutable_data_ptr(); + impl::vision::native::kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); + } else if ( + out.scalar_type() == ScalarType::Bits16 || + out.scalar_type() == ScalarType::UInt16) { + uint16_t* out_data = out.mutable_data_ptr(); + impl::vision::native::kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Short) { + int16_t* out_data = out.mutable_data_ptr(); + impl::vision::native::kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Int) { + int32_t* out_data = out.mutable_data_ptr(); + impl::vision::native::kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); + } else { + ET_CHECK_MSG( + false, + "Unhandled input dtype %hhd", + static_cast(out.scalar_type())); + } +} + +}; // namespace native +}; // namespace vision +}; // namespace impl diff --git a/backends/cadence/vision/operators/op_quantized_conv_out.cpp b/backends/cadence/vision/operators/op_quantized_conv_out.cpp new file mode 100644 index 00000000000..6ffb36aa836 --- /dev/null +++ b/backends/cadence/vision/operators/op_quantized_conv_out.cpp @@ -0,0 +1,535 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace impl { +namespace vision { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +// This implements a generic 2d conv kernel that operates on raw pointers. +// The version handles both quantized and fp32 convolutions. 
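+// In quantized mode the input and weight zero points are subtracted before the
+// multiply, the accumulator is kept in float, scaled by bias_scale, and the
+// result is requantized with 1/out_scale and out_zero_point.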
+// The input is of shape [n x c x h x w] +// The weight is of shape [oc x wc x wh x ww], where wc == c +// The output is of shape [n x oc x oh x ow] +// The bias is of shape [oc] +template < + typename IT = float, + typename WT = IT, + typename BT = IT, + typename OT = IT, + bool quantized = false> +__attribute__((noinline)) void conv2d_nchw_core_generic( + // All the arrays + const IT* __restrict__ p_in, + const WT* __restrict__ p_weight, + const BT* __restrict__ p_bias, + OT* __restrict__ p_out, + // The array sizes + int32_t n, + int32_t c, + int32_t h, + int32_t w, + int32_t oc, + int32_t wc, + int32_t wh, + int32_t ww, + int32_t oh, + int32_t ow, + // Stride + int16_t s0, + int16_t s1, + // Padding + int16_t p0, + int16_t p1, + // Dilation + int16_t d0, + int16_t d1, + // Group for depthwise conv + int16_t groups, + // Optional args that are only relevant for quantized convolution + // input zero point + IT in_zero_point = 0, + // weight zero point + int32_t weight_zero_point = 0, + float bias_scale = 1, + float out_scale = 1, + OT out_zero_point = 0) { + float inv_out_scale = 1. / out_scale; + bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0; + + // Compute the number of in and out channels per group + const int ocpg = oc / groups; + const int icpg = c / groups; + + // Iterate over all the output batches (i.e., n) + for (int _n = 0; _n < n; ++_n) { + const IT* in_batch = p_in + _n * c * h * w; + OT* out_batch = p_out + _n * oc * oh * ow; + // Compute separable convolution for each group + for (int _g = 0; _g < groups; ++_g) { + // Identify the input and output channels involved in the computation + // of this group + int sic = _g * icpg; + int soc = _g * ocpg; + // Populate all the output channels in the group + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + OT* out_plane = out_batch + _oc * oh * ow; + const WT* weight_batch = p_weight + _oc * wc * wh * ww; + // We compute one output channel at a time. The computation can be + // thought of as a stencil computation: we iterate over an input of size + // icpg x h x w, with a stencil of size icpg x wh x ww, to compute an + // output channel of size 1 x oh x ow. + for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { + for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { + float acc = p_bias[_oc]; + // Below is the stencil computation that performs the hadamard + // product+accumulation of each input channel (contributing to the + // output channel being computed) with the corresponding weight + // channel. + // If the padding is 0, and dilation is 1, then we can remove the + // unnecessary checks, and simplify the code so that it can be + // vectorized by Tensilica compiler. + if (zero_pad_unit_dilation) { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const IT* in_plane = in_batch + _ic * h * w; + const WT* weight_plane = weight_batch + (_ic - sic) * wh * ww; + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + int ioff = (_h + _wh) * w + (_w + _ww); + int woff = _wh * ww + _ww; + float lhs = in_plane[ioff] - in_zero_point; + float rhs = weight_plane[woff] - + (quantized ? 
weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } else { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const IT* in_plane = in_batch + _ic * h * w; + const WT* weight_plane = weight_batch + (_ic - sic) * wh * ww; + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + if (((_h + d0 * _wh - p0) >= 0) && + ((_h + d0 * _wh - p0) < h) && + ((_w + d1 * _ww - p1) >= 0) && + ((_w + d1 * _ww - p1) < w)) { + int ioff = + (_h + d0 * _wh - p0) * w + (_w + d1 * _ww - p1); + int woff = _wh * ww + _ww; + float lhs = in_plane[ioff] - in_zero_point; + float rhs = weight_plane[woff] - + (quantized ? weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } + } + if (quantized) { + float val = bias_scale * acc; + out_plane[_oh * ow + _ow] = + ::impl::vision::native::kernels::quantize( + val, inv_out_scale, out_zero_point); + } else { + out_plane[_oh * ow + _ow] = acc; + } + } + } + } + } + } +} + +template < + typename IT = float, + typename WT = IT, + typename BT = IT, + typename OT = IT, + bool quantized = false> +__attribute__((noinline)) void conv2d_nhwc_core_generic( + // All the arrays + const IT* __restrict__ p_in, + const WT* __restrict__ p_weight, + const BT* __restrict__ p_bias, + OT* __restrict__ p_out, + // The array sizes + int32_t n, + int32_t h, + int32_t w, + int32_t c, + int32_t oc, + int32_t wh, + int32_t ww, + int32_t wc, + int32_t oh, + int32_t ow, + // Stride + int16_t s0, + int16_t s1, + // Padding + int16_t p0, + int16_t p1, + // Dilation + int16_t d0, + int16_t d1, + // Group for depthwise conv + int16_t groups, + // Optional args that are only relevant for quantized convolution + // input zero point + IT in_zero_point = 0, + // weight zero point + int32_t weight_zero_point = 0, + float bias_scale = 1, + float out_scale = 1, + OT out_zero_point = 0) { + float inv_out_scale = 1. / out_scale; + bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0; + + // Compute the number of in and out channels per group + const int ocpg = oc / groups; + const int icpg = c / groups; + + // Iterate over all the output batches (i.e., n) + for (int _n = 0; _n < n; ++_n) { + const IT* in_batch = p_in + _n * h * w * c; + OT* out_batch = p_out + _n * oh * ow * oc; + for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { + for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { + OT* out_line = out_batch + (_oh * ow + _ow) * oc; + // Compute separable convolution for each group + for (int _g = 0; _g < groups; ++_g) { + // Identify the input and output channels involved in the computation + // of this group + int sic = _g * icpg; + int soc = _g * ocpg; + // Populate all the output channels in the group + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + const WT* weight_batch = p_weight + _oc * wh * ww * wc; + // We compute one output channel at a time. The computation can be + // thought of as a stencil computation: we iterate over an input of + // size h x w x icpg, with a stencil of size wh x ww x icpg, to + // compute an output channel of size oh x ow x 1. + float acc = p_bias[_oc]; + // Below is the stencil computation that performs the hadamard + // product+accumulation of each input channel (contributing to + // the output channel being computed) with the corresponding + // weight channel. 
If the padding is 0, and dilation is 1, then + // we can remove the unnecessary checks, and simplify the code + // so that it can be vectorized by Tensilica compiler.x`` + if (zero_pad_unit_dilation) { + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + const IT* in_line = + in_batch + (_h + _wh) * w * c + (_w + _ww) * c; + const WT* weight_line = + weight_batch + _wh * ww * wc + _ww * wc; + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + float lhs = in_line[_ic] - in_zero_point; + float rhs = weight_line[_ic - sic] - + (quantized ? weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } else { + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + if (((_h + d0 * _wh - p0) >= 0) && + ((_h + d0 * _wh - p0) < h) && + ((_w + d1 * _ww - p1) >= 0) && + ((_w + d1 * _ww - p1 < w))) { + const IT* in_line = in_batch + + (_h + d0 * _wh - p0) * w * c + (_w + d1 * _ww - p1) * c; + const WT* weight_line = + weight_batch + _wh * ww * wc + _ww * wc; + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + float lhs = in_line[_ic] - in_zero_point; + float rhs = weight_line[_ic - sic] - + (quantized ? weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } + } + if (quantized) { + float val = bias_scale * acc; + out_line[_oc] = ::impl::vision::native::kernels::quantize( + val, inv_out_scale, out_zero_point); + } else { + out_line[_oc] = acc; + } + } + } + } + } + } +} + +// The quantized convolution kernel. in_scale and weight_scale are implicit in +// bias_scale, since it is a product of the two. The kernel will branch to +// quantized::conv1d or quantized::conv2d based on the dimensionality of +// activation tensor. +void quantized_conv_nchw( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + bool conv1d = input.dim() == 3; + // input = [n, c, h, w] + const int n = input.size(0); + const int c = input.size(1); + const int h = conv1d ? 1 : input.size(2); + const int w = conv1d ? input.size(2) : input.size(3); + // weight = [oc, wc, wh, ww] + const int oc = weight.size(0); + const int wc = weight.size(1); + const int wh = conv1d ? 1 : weight.size(2); + const int ww = conv1d ? weight.size(2) : weight.size(3); + // output = [n, oc, oh, ow] + const int oh = conv1d ? 1 : out.size(2); + const int ow = conv1d ? 
out.size(2) : out.size(3); + +#define typed_quantized_conv2d_nchw(ctype, dtype) \ + case ScalarType::dtype: { \ + conv2d_nchw_core_generic( \ + input.const_data_ptr(), \ + weight.const_data_ptr(), \ + bias.const_data_ptr(), \ + out.mutable_data_ptr(), \ + n, \ + c, \ + h, \ + w, \ + oc, \ + wc, \ + wh, \ + ww, \ + oh, \ + ow, \ + stride[0], \ + stride[1], \ + padding[0], \ + padding[1], \ + dilation[0], \ + dilation[1], \ + groups, \ + in_zero_point, \ + weight_zero_point, \ + bias_scale, \ + output_scale, \ + (ctype)output_zero_point); \ + break; \ + } + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_conv2d_nchw); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_conv2d_nchw +} + +void quantized_conv_nhwc( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + bool conv1d = input.dim() == 3; + // input = [n, h, w, c] + const int n = input.size(0); + const int h = conv1d ? 1 : input.size(1); + const int w = conv1d ? input.size(1) : input.size(2); + const int c = conv1d ? input.size(2) : input.size(3); + // weight = [oc, wh, ww, wc] + const int oc = weight.size(0); + const int wh = conv1d ? 1 : weight.size(1); + const int ww = conv1d ? weight.size(1) : weight.size(2); + const int wc = conv1d ? weight.size(2) : weight.size(3); + // output = [n, oh, ow, oc] + const int oh = conv1d ? 1 : out.size(1); + const int ow = conv1d ? out.size(1) : out.size(2); + +#define typed_quantized_conv2d_nhwc(ctype, dtype) \ + case ScalarType::dtype: { \ + conv2d_nhwc_core_generic( \ + input.const_data_ptr(), \ + weight.const_data_ptr(), \ + bias.const_data_ptr(), \ + out.mutable_data_ptr(), \ + n, \ + h, \ + w, \ + c, \ + oc, \ + wh, \ + ww, \ + wc, \ + oh, \ + ow, \ + stride[0], \ + stride[1], \ + padding[0], \ + padding[1], \ + dilation[0], \ + dilation[1], \ + groups, \ + in_zero_point, \ + weight_zero_point, \ + bias_scale, \ + output_scale, \ + (ctype)output_zero_point); \ + break; \ + } + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_conv2d_nhwc); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_conv2d_nhwc +} + +void quantized_conv_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED const Tensor& out_multiplier, + __ET_UNUSED const Tensor& out_shift, + bool channel_last, + Tensor& out) { + const float bias_scale_float = bias_scale.const_data_ptr()[0]; + const int32_t weight_zero_point_int = + weight_zero_point.const_data_ptr()[0]; + if (channel_last) { + quantized_conv_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point_int, + bias_scale_float, + output_scale, + output_zero_point, + out); + } else { + quantized_conv_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point_int, + 
bias_scale_float, + output_scale, + output_zero_point, + out); + } +} + +void quantized_conv_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + bool channel_last, + Tensor& out) { + if (channel_last) { + quantized_conv_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + } else { + quantized_conv_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + } +} + +} // namespace native +} // namespace vision +} // namespace impl diff --git a/backends/cadence/vision/operators/op_quantized_fully_connected_out.cpp b/backends/cadence/vision/operators/op_quantized_fully_connected_out.cpp new file mode 100644 index 00000000000..29aa8906414 --- /dev/null +++ b/backends/cadence/vision/operators/op_quantized_fully_connected_out.cpp @@ -0,0 +1,97 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include + +namespace impl { +namespace vision { +namespace native { + +using ::executorch::aten::optional; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +void quantized_fully_connected_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& in, + const Tensor& weight, + const Tensor& bias, + int64_t in_zero_point, + const Tensor& weight_zero_point_t, + const Tensor& out_multiplier, + const Tensor& out_shift, + int64_t out_zero_point, + __ET_UNUSED const optional& offset, + Tensor& out) { +#define typed_quantized_linear(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_linear_( \ + in, \ + weight, \ + bias, \ + in_zero_point, \ + weight_zero_point_t, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_linear +} + +void quantized_fully_connected_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& in, + const Tensor& weight, + const Tensor& bias, + int64_t in_zero_point, + int64_t weight_zero_point, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + __ET_UNUSED const optional& offset, + Tensor& out) { +#define typed_quantized_linear(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_linear_per_tensor_( \ + in, \ + weight, \ + bias, \ + in_zero_point, \ + weight_zero_point, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_linear +} + +}; // 
namespace native +}; // namespace vision +}; // namespace impl diff --git a/backends/cadence/vision/operators/op_quantized_layer_norm.cpp b/backends/cadence/vision/operators/op_quantized_layer_norm.cpp new file mode 100644 index 00000000000..a9685eddedb --- /dev/null +++ b/backends/cadence/vision/operators/op_quantized_layer_norm.cpp @@ -0,0 +1,201 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::getLeadingDims; +using ::executorch::runtime::KernelRuntimeContext; + +namespace impl { +namespace vision { +namespace native { + +// Compute quantized layer_norm. The current implementation assumes that the +// input is per-tensor quantized. +template +void quantized_layer_norm_per_tensor_( + const Tensor& input, + double input_scale, + int64_t input_zero_point, + const Tensor& weight, + const Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + Tensor& out) { + // Get the raw pointers to input, output, weight, and bias + const T* __restrict__ in_data = input.const_data_ptr(); + T* __restrict__ out_data = out.mutable_data_ptr(); + const float* __restrict__ weight_data = weight.const_data_ptr(); + const float* __restrict__ bias_data = bias.const_data_ptr(); + + float output_inv_scale = 1.0f / output_scale; + + size_t last_dim = input.size(input.dim() - 1); + size_t leading_dims = getLeadingDims(input, input.dim() - 1); + + // Visualize the input tensor as a set of 1d vectors, and compute the + // layer_norm for each vector. + for (size_t i = 0; i < leading_dims; ++i) { + const T* x = in_data + i * last_dim; + T* y = out_data + i * last_dim; + + // compute sum and squared sum. The fp32 sum can be approximated as: + // (X_1 - in_zero_point) * in_scale + (X_2 - in_zero_point) * in_scale + ... + // (X_N - in_zero_point) * in_scale. + int32_t sum = 0; + int32_t sq_sum = last_dim * input_zero_point * input_zero_point; + for (size_t j = 0; j < last_dim; ++j) { + int32_t val = x[j]; + sum += val; + sq_sum += val * val; + } + sq_sum -= (2 * sum * input_zero_point); + sum -= (last_dim * input_zero_point); + + float mean = (input_scale * sum) / last_dim; + float variance = + (sq_sum * input_scale * input_scale) / last_dim - mean * mean; + float inv_std = 1.0f / std::sqrt(variance + eps); + + // y = (x - mean) / std * kGamma + kBeta + for (int j = 0; j < last_dim; ++j) { + // y[j] = (x[j] - mean) / std * kGamma + kBeta; + // Since X is quantized, we dequantize it, compute fp32 result, and + // quantize the result to an int8/uint8 value. + float val = kernels::dequantize(x[j], input_scale, input_zero_point); + + val = (val - mean) * inv_std * weight_data[j] + bias_data[j]; + y[j] = kernels::quantize(val, output_inv_scale, output_zero_point); + } + } +} + +// Compute quantized layer_norm. The current implementation assumes that the +// input is per-tensor quantized. +template +void quantized_layer_norm_( + const Tensor& input, + const Tensor& in_scale, + const Tensor& in_zero_point, + const Tensor& weight, + const Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + Tensor& out) { + // Extract the zero point and scale for input tensor. 
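// Note on the integer accumulation used by the per-tensor implementation
// above (a sketch of the algebra only): with quantized inputs X_j and zero
// point z,
//   sum_j (X_j - z)   = sum_j X_j - N * z
//   sum_j (X_j - z)^2 = sum_j X_j^2 - 2 * z * sum_j X_j + N * z^2
// which is why sq_sum is initialized to N * z * z and corrected by
// -2 * z * sum before both sums are scaled by input_scale to form the mean
// and variance.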
+ float input_scale = in_scale.const_data_ptr()[0]; + int64_t input_zero_point = in_zero_point.const_data_ptr()[0]; + + // Call other overload + quantized_layer_norm_per_tensor_( + input, + input_scale, + input_zero_point, + weight, + bias, + eps, + output_scale, + output_zero_point, + out); +} + +void quantized_layer_norm_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& in_scale, + const Tensor& in_zero_point, + __ET_UNUSED const executorch::aten::IntArrayRef normalized_shape, + const Tensor& weight, + const Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + Tensor& out) { + if (input.scalar_type() == executorch::aten::ScalarType::Byte) { + quantized_layer_norm_( + input, + in_scale, + in_zero_point, + weight, + bias, + eps, + output_scale, + output_zero_point, + out); + } else if (input.scalar_type() == executorch::aten::ScalarType::Char) { + quantized_layer_norm_( + input, + in_scale, + in_zero_point, + weight, + bias, + eps, + output_scale, + output_zero_point, + out); + } else { + ET_CHECK_MSG( + false, + "Unhandled input dtype %hhd", + static_cast(input.scalar_type())); + } +} + +void quantized_layer_norm_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + double in_scale, + int64_t in_zero_point, + __ET_UNUSED const executorch::aten::IntArrayRef normalized_shape, + const Tensor& weight, + const Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + Tensor& out) { + if (input.scalar_type() == executorch::aten::ScalarType::Byte) { + quantized_layer_norm_per_tensor_( + input, + in_scale, + in_zero_point, + weight, + bias, + eps, + output_scale, + output_zero_point, + out); + } else if (input.scalar_type() == executorch::aten::ScalarType::Char) { + quantized_layer_norm_per_tensor_( + input, + in_scale, + in_zero_point, + weight, + bias, + eps, + output_scale, + output_zero_point, + out); + } else { + ET_CHECK_MSG( + false, + "Unhandled input dtype %hhd", + static_cast(input.scalar_type())); + } +} + +}; // namespace native +}; // namespace vision +}; // namespace impl diff --git a/backends/cadence/vision/operators/op_quantized_linear_out.cpp b/backends/cadence/vision/operators/op_quantized_linear_out.cpp new file mode 100644 index 00000000000..b6b7cdd17bc --- /dev/null +++ b/backends/cadence/vision/operators/op_quantized_linear_out.cpp @@ -0,0 +1,159 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +namespace impl { +namespace vision { +namespace native { + +using executorch::aten::Tensor; +using executorch::runtime::getLeadingDims; +using executorch::runtime::KernelRuntimeContext; + +template +void inline _typed_quantized_linear( + const Tensor& src, + const Tensor& weight, + const Tensor& bias, + int64_t src_zero_point, + const Tensor& weight_zero_point_t, + const Tensor& out_multiplier, + const Tensor& out_shift, + int64_t out_zero_point, + Tensor& out) { + const T* __restrict__ src_data = src.const_data_ptr(); + const T* __restrict__ weight_data = weight.const_data_ptr(); + const int32_t* __restrict__ bias_data = bias.const_data_ptr(); + T* __restrict__ out_data = out.mutable_data_ptr(); + + int32_t weight_zero_point = weight_zero_point_t.const_data_ptr()[0]; + + // input comes in shape [batch_size, in_dim] + // weight comes in shape [out_dim, in_dim] + // output comes in empty with shape [batch_size, out_dim] + // Perform matrix multiply (M x N) x (N x P) => M x P + const auto M = weight.size(0); // = out_dim + const auto N = weight.size(1); // = in_dim + + // Given an N-dimensional input [d0, d1, d2, ..., d_{N-2}, d_{N-1}], the + // leading dimensions is d0 * d1 * ... * d_{N-2} + const auto leading_dims = getLeadingDims(src, src.dim() - 1); + + ET_CHECK_MSG( + out_multiplier.numel() == 1, "out_multiplier should have one element"); + ET_CHECK_MSG( + out_shift.numel() == 1, "out_multiplier should have one element"); + + const int32_t* __restrict__ out_multiplier_data = + out_multiplier.const_data_ptr(); + const int32_t* __restrict__ out_shift_data = + out_shift.const_data_ptr(); + + // Compute the out_scale from out_multiplier and out_shift + const float out_scale = + -out_multiplier_data[0] * 1.0 / (1 << 31) * pow(2, out_shift_data[0]); + + for (int i = 0; i < leading_dims; ++i) { + for (int j = 0; j < M; ++j) { + float sum = bias_data[j]; + for (int k = 0; k < N; ++k) { + sum += (src_data[i * N + k] - src_zero_point) * + (weight_data[j * N + k] - weight_zero_point); + } + out_data[i * M + j] = + kernels::quantize(sum, out_scale, out_zero_point); + } + } +} + +void quantized_linear_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& src, + const Tensor& weight, + const Tensor& bias, + int64_t src_zero_point, + const Tensor& weight_zero_point_t, + const Tensor& out_multiplier, + const Tensor& out_shift, + int64_t out_zero_point, + __ET_UNUSED const executorch::aten::optional& offset, + Tensor& out) { + // TODO: refactor to use switch case as quantized_linear_per_tensor_out + if (out.scalar_type() == executorch::aten::ScalarType::Byte) { + _typed_quantized_linear( + src, + weight, + bias, + src_zero_point, + weight_zero_point_t, + out_multiplier, + out_shift, + out_zero_point, + out); + } else if (out.scalar_type() == executorch::aten::ScalarType::Char) { + _typed_quantized_linear( + src, + weight, + bias, + src_zero_point, + weight_zero_point_t, + out_multiplier, + out_shift, + out_zero_point, + out); + } else { + ET_CHECK_MSG( + false, + "Unhandled input dtype %hhd", + static_cast(src.scalar_type())); + } +} + +void quantized_linear_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& src, + const Tensor& weight, + const Tensor& bias, + const int64_t src_zero_point, + const int64_t weight_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + const int64_t out_zero_point, + __ET_UNUSED const executorch::aten::optional& offset, + Tensor& out) { +#define typed_quantized_linear_per_tensor(ctype, 
dtype) \ + case executorch::aten::ScalarType::dtype: { \ + quantized_linear_per_tensor_( \ + src, \ + weight, \ + bias, \ + src_zero_point, \ + weight_zero_point, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + executorch::aten::ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear_per_tensor); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", executorch::runtime::toString(dtype)); + } +#undef typed_quantized_linear_per_tensor +} + +}; // namespace native +}; // namespace vision +}; // namespace impl diff --git a/backends/cadence/vision/operators/op_quantized_matmul_out.cpp b/backends/cadence/vision/operators/op_quantized_matmul_out.cpp new file mode 100644 index 00000000000..54a303288c3 --- /dev/null +++ b/backends/cadence/vision/operators/op_quantized_matmul_out.cpp @@ -0,0 +1,157 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace impl { +namespace vision { +namespace native { + +using executorch::aten::Tensor; +using executorch::runtime::getLeadingDims; +using executorch::runtime::KernelRuntimeContext; + +// The quantized matmul. The quantized matmul accumulates in a wider register, +// whose type is TA. +template < + typename TZ, + typename TA = float, + bool transposed = false, + typename TX = TZ, + typename TY = TZ> +__attribute__((noinline)) void qmatmul( + TZ* __restrict__ Z, + int32_t Z_multiplier, + int32_t Z_shift, + int32_t Z_zero_point, + const TX* __restrict__ X, + int32_t X_zero_point, + const TY* __restrict__ y, + int32_t Y_zero_point, + size_t m, + size_t n, + size_t p) { + // Compute the Z_scale from Z_multiplier and Z_shift + const float Z_scale = -Z_multiplier * 1.0 / (1 << 31) * pow(2, Z_shift); + for (size_t i = 0; i < m; ++i) { + for (size_t j = 0; j < p; ++j) { + TA sum = 0; + for (size_t k = 0; k < n; ++k) { + if (transposed) { + sum += (X[i * n + k] - X_zero_point) * (y[j * n + k] - Y_zero_point); + } else { + sum += (X[i * n + k] - X_zero_point) * (y[k * p + j] - Y_zero_point); + } + } + Z[i * p + j] = kernels::quantize(sum, Z_scale, Z_zero_point); + } + } +} + +template +void inline _typed_quantized_matmul( + const Tensor& X, + int64_t X_zero_point, + const Tensor& Y, + int64_t Y_zero_point, + const executorch::aten::optional& bias, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + bool transposed, + Tensor& out) { + size_t batch_size = getLeadingDims(X, X.dim() - 2); + size_t leading_dim = X.size(X.dim() - 2); + size_t out_dim = Y.size(Y.dim() - 1 - transposed); + size_t in_dim = X.size(X.dim() - 1); + + T* __restrict__ out_data = out.mutable_data_ptr(); + const T* __restrict__ X_data = X.const_data_ptr(); + const T* __restrict__ Y_data = Y.const_data_ptr(); + for (size_t i = 0; i < batch_size; ++i) { + const T* x = X_data + i * leading_dim * in_dim; + const T* y = Y_data + i * in_dim * out_dim; + T* z = out_data + i * leading_dim * out_dim; + if (transposed) { + qmatmul( + z, + static_cast(out_multiplier), + static_cast(out_shift), + static_cast(out_zero_point), + x, + static_cast(X_zero_point), + y, + static_cast(Y_zero_point), + leading_dim, + in_dim, + out_dim); + } else { + qmatmul( + z, + static_cast(out_multiplier), + static_cast(out_shift), + static_cast(out_zero_point), + x, + static_cast(X_zero_point), + y, + 
static_cast(Y_zero_point), + leading_dim, + in_dim, + out_dim); + } + } +} + +void quantized_matmul_out( + KernelRuntimeContext& ctx, + const Tensor& X, + int64_t X_zero_point, + const Tensor& Y, + int64_t Y_zero_point, + const executorch::aten::optional& bias, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + bool transposed, + Tensor& out) { + if (out.scalar_type() == executorch::aten::ScalarType::Byte) { + _typed_quantized_matmul( + X, + X_zero_point, + Y, + Y_zero_point, + bias, + out_multiplier, + out_shift, + out_zero_point, + transposed, + out); + } else if (out.scalar_type() == executorch::aten::ScalarType::Char) { + _typed_quantized_matmul( + X, + X_zero_point, + Y, + Y_zero_point, + bias, + out_multiplier, + out_shift, + out_zero_point, + transposed, + out); + } else { + ET_CHECK_MSG( + false, + "Unhandled input dtype %hhd", + static_cast(X.scalar_type())); + } +} + +}; // namespace native +}; // namespace vision +}; // namespace impl diff --git a/backends/cadence/vision/operators/op_quantized_relu_out.cpp b/backends/cadence/vision/operators/op_quantized_relu_out.cpp new file mode 100644 index 00000000000..45b9e09b1dd --- /dev/null +++ b/backends/cadence/vision/operators/op_quantized_relu_out.cpp @@ -0,0 +1,134 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +namespace impl { +namespace vision { +namespace native { + +using executorch::aten::Tensor; +using executorch::runtime::KernelRuntimeContext; + +template +void quantized_relu_( + const Tensor& input, + const Tensor& in_zero_point, + const int64_t out_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + Tensor& output) { + T q_zero_point = in_zero_point.const_data_ptr()[0]; + const T* __restrict__ in = input.const_data_ptr(); + T* __restrict__ out = output.mutable_data_ptr(); + + const int32_t* __restrict__ out_multiplier_data = + out_multiplier.const_data_ptr(); + const int32_t* __restrict__ out_shift_data = + out_shift.const_data_ptr(); + + // Compute the out_scale from out_multiplier and out_shift + const float out_scale = + -out_multiplier_data[0] * 1.0 / (1 << 31) * pow(2, out_shift_data[0]); + + for (size_t i = 0, e = input.numel(); i < e; ++i) { + const T temp = in[i] > q_zero_point ? 
(in[i] - q_zero_point) : 0; + out[i] = kernels::quantize(temp, out_scale, out_zero_point); + } +} + +void quantized_relu_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& in_zero_point, + const int64_t out_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + Tensor& output) { + if (input.scalar_type() == executorch::aten::ScalarType::Byte) { + quantized_relu_( + input, + in_zero_point, + out_zero_point, + out_multiplier, + out_shift, + output); + } else if (input.scalar_type() == executorch::aten::ScalarType::Char) { + quantized_relu_( + input, + in_zero_point, + out_zero_point, + out_multiplier, + out_shift, + output); + } else { + ET_CHECK_MSG( + false, + "Unhandled input dtype %hhd", + static_cast(input.scalar_type())); + } +} + +template +void quantized_relu_per_tensor_out_( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const int64_t in_zero_point, + const int64_t out_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + Tensor& output) { + const T* __restrict__ in = input.const_data_ptr(); + T* __restrict__ out = output.mutable_data_ptr(); + + // Compute the out_scale from out_multiplier and out_shift + const float out_scale = -out_multiplier * 1.0 / (1 << 31) * pow(2, out_shift); + + for (size_t i = 0, e = input.numel(); i < e; ++i) { + const float temp = in[i] > in_zero_point ? (in[i] - in_zero_point) : 0; + out[i] = kernels::quantize(temp, out_scale, out_zero_point); + } +} + +void quantized_relu_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const int64_t in_zero_point, + const int64_t out_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + Tensor& output) { +#define typed_quantized_relu(ctype, dtype) \ + case executorch::aten::ScalarType::dtype: { \ + quantized_relu_per_tensor_out_( \ + ctx, \ + input, \ + in_zero_point, \ + out_zero_point, \ + out_multiplier, \ + out_shift, \ + output); \ + break; \ + } + + executorch::aten::ScalarType dtype = input.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_relu) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_relu +} + +}; // namespace native +}; // namespace vision +}; // namespace impl diff --git a/backends/cadence/vision/operators/op_requantize_out.cpp b/backends/cadence/vision/operators/op_requantize_out.cpp new file mode 100644 index 00000000000..ef538bf4045 --- /dev/null +++ b/backends/cadence/vision/operators/op_requantize_out.cpp @@ -0,0 +1,266 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace impl { +namespace vision { +namespace native { + +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using executorch::runtime::KernelRuntimeContext; + +// Requantize the int8_t/uint8_t input tensor to a uint8_t/int8_t out tensor. +// The scale and zero_point for requantization are in the args. 
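// Illustrative sketch of the per-element math the requantize kernels below
// are expected to perform (a sketch only, assuming round-to-nearest and
// saturation to the output type's range; the vectorized vision kernel may
// differ in details):
//
//   template <typename OutT, typename InT>
//   OutT requantize_one(InT x, float in_scale, int32_t in_zero_point,
//                       float inv_out_scale, int32_t out_zero_point) {
//     // Dequantize to a real value, then quantize onto the output grid.
//     float real = (static_cast<float>(x) - in_zero_point) * in_scale;
//     float q = std::nearbyint(real * inv_out_scale) + out_zero_point;
//     q = std::max<float>(q, std::numeric_limits<OutT>::min());
//     q = std::min<float>(q, std::numeric_limits<OutT>::max());
//     return static_cast<OutT>(q);
//   }
//
// This is also why the call sites below pass 1.0 / out_scale rather than
// out_scale itself.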
+Tensor& requantize_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& in_scale_t, + const Tensor& in_zero_point_t, + const Tensor& out_scale_t, + const Tensor& out_zero_point_t, + const ScalarType out_dtype, + Tensor& out) { + ET_KERNEL_CHECK_MSG( + ctx, + in_scale_t.scalar_type() == ScalarType::Float, + InvalidArgument, + out, + "In scale is not a float: %s", + torch::executor::toString(in_scale_t.scalar_type())); + float in_scale = in_scale_t.const_data_ptr()[0]; + + ET_KERNEL_CHECK_MSG( + ctx, + in_zero_point_t.scalar_type() == ScalarType::Int, + InvalidArgument, + out, + "In zero point is not an int: %s", + torch::executor::toString(in_zero_point_t.scalar_type())); + int32_t in_zero_point = in_zero_point_t.const_data_ptr()[0]; + + ET_KERNEL_CHECK_MSG( + ctx, + out_scale_t.scalar_type() == ScalarType::Float, + InvalidArgument, + out, + "Out scale is not a float: %s", + torch::executor::toString(out_scale_t.scalar_type())); + float out_scale = out_scale_t.const_data_ptr()[0]; + + ET_KERNEL_CHECK_MSG( + ctx, + out_zero_point_t.scalar_type() == ScalarType::Int, + InvalidArgument, + out, + "Out zero point is not an int: %s", + torch::executor::toString(out_zero_point_t.scalar_type())); + int32_t out_zero_point = out_zero_point_t.const_data_ptr()[0]; + + ET_KERNEL_CHECK_MSG( + ctx, + out.scalar_type() == out_dtype, + InvalidArgument, + out, + "Out tensor dtype (%s) does not match the passed in out dtype (%s)", + torch::executor::toString(out.scalar_type()), + torch::executor::toString(out_dtype)); + + const size_t numel = out.numel(); + ScalarType in_dtype = input.scalar_type(); + + // Assert that the output tensor's dtype is same as out_dtype. + ET_KERNEL_CHECK_MSG( + ctx, + out_dtype == out.scalar_type(), + InvalidArgument, + out, + "Out dtype %s does not match requant dtype %s", + torch::executor::toString(out.scalar_type()), + torch::executor::toString(out_dtype)); + +#define typed_requantize(ctype, dtype) \ + const ctype* input_data = input.const_data_ptr(); \ + dtype* out_data = out.mutable_data_ptr(); \ + kernels::requantize( \ + out_data, \ + input_data, \ + in_scale, \ + in_zero_point, \ + 1.0 / out_scale, \ + out_zero_point, \ + numel); + +#define typed_requantize_in(ctype) \ + switch (out_dtype) { \ + case ScalarType::Byte: { \ + typed_requantize(ctype, uint8_t); \ + break; \ + } \ + case ScalarType::Char: { \ + typed_requantize(ctype, int8_t); \ + break; \ + } \ + case ScalarType::UInt16: { \ + typed_requantize(ctype, uint16_t); \ + break; \ + } \ + case ScalarType::Short: { \ + typed_requantize(ctype, int16_t); \ + break; \ + } \ + default: \ + ET_KERNEL_CHECK_MSG( \ + ctx, \ + false, \ + InvalidArgument, \ + out, \ + "Unhandled output dtype %s", \ + torch::executor::toString(out_dtype)); \ + } + + switch (in_dtype) { + case ScalarType::Byte: { + typed_requantize_in(uint8_t); + break; + } + case ScalarType::Char: { + typed_requantize_in(int8_t); + break; + } + case ScalarType::UInt16: { + typed_requantize_in(uint16_t); + break; + } + case ScalarType::Short: { + typed_requantize_in(int16_t); + break; + } + default: + ET_KERNEL_CHECK_MSG( + ctx, + false, + InvalidArgument, + out, + "Unhandled input dtype %s", + torch::executor::toString(in_dtype)); + } +#undef typed_requantize_in +#undef typed_requantize + return out; +} + +// Requantize the int8_t/uint8_t input tensor to a uint8_t/int8_t out tensor. +// The scale and zero_point for requantization are in the args. 
+Tensor& requantize_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + double in_scale, + int64_t in_zero_point, + double out_scale, + int64_t out_zero_point, + const ScalarType out_dtype, + Tensor& out) { + ET_KERNEL_CHECK_MSG( + ctx, + out.scalar_type() == out_dtype, + InvalidArgument, + out, + "Out tensor dtype (%s) does not match the passed in out dtype (%s)", + torch::executor::toString(out.scalar_type()), + torch::executor::toString(out_dtype)); + + const size_t numel = out.numel(); + ScalarType in_dtype = input.scalar_type(); + + // Assert that the output tensor's dtype is same as out_dtype. + ET_KERNEL_CHECK_MSG( + ctx, + out_dtype == out.scalar_type(), + InvalidArgument, + out, + "Out dtype %s does not match requant dtype %s", + torch::executor::toString(out.scalar_type()), + torch::executor::toString(out_dtype)); + +#define typed_requantize(ctype, dtype) \ + const ctype* input_data = input.const_data_ptr(); \ + dtype* out_data = out.mutable_data_ptr(); \ + kernels::requantize( \ + out_data, \ + input_data, \ + static_cast(in_scale), \ + static_cast(in_zero_point), \ + 1.0 / static_cast(out_scale), \ + static_cast(out_zero_point), \ + numel); + +#define typed_requantize_in(ctype) \ + switch (out_dtype) { \ + case ScalarType::Byte: { \ + typed_requantize(ctype, uint8_t); \ + break; \ + } \ + case ScalarType::Char: { \ + typed_requantize(ctype, int8_t); \ + break; \ + } \ + case ScalarType::UInt16: { \ + typed_requantize(ctype, uint16_t); \ + break; \ + } \ + case ScalarType::Short: { \ + typed_requantize(ctype, int16_t); \ + break; \ + } \ + default: \ + ET_KERNEL_CHECK_MSG( \ + ctx, \ + false, \ + InvalidArgument, \ + out, \ + "Unhandled output dtype %s", \ + torch::executor::toString(out_dtype)); \ + } + + switch (in_dtype) { + case ScalarType::Byte: { + typed_requantize_in(uint8_t); + break; + } + case ScalarType::Char: { + typed_requantize_in(int8_t); + break; + } + case ScalarType::UInt16: { + typed_requantize_in(uint16_t); + break; + } + case ScalarType::Short: { + typed_requantize_in(int16_t); + break; + } + default: + ET_KERNEL_CHECK_MSG( + ctx, + false, + InvalidArgument, + out, + "Unhandled input dtype %s", + torch::executor::toString(in_dtype)); + } +#undef typed_requantize_in +#undef typed_requantize + return out; +} + +} // namespace native +} // namespace vision +} // namespace impl diff --git a/backends/cadence/vision/operators/op_softmax.cpp b/backends/cadence/vision/operators/op_softmax.cpp new file mode 100644 index 00000000000..e2963bdcffe --- /dev/null +++ b/backends/cadence/vision/operators/op_softmax.cpp @@ -0,0 +1,303 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include + +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using executorch::runtime::KernelRuntimeContext; +using torch::executor::Error; + +namespace impl { +namespace vision { +namespace native { + +Tensor& _softmax_out( + KernelRuntimeContext& ctx, + const Tensor& in, + int64_t dim, + bool half_to_float, + Tensor& out) { + (void)ctx; + + ET_KERNEL_CHECK( + ctx, + torch::executor::check_softmax_args(in, dim, half_to_float, out), + InvalidArgument, + out); + + ET_KERNEL_CHECK( + ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out); + + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(in, out), + InvalidArgument, + out); + + // Adjust for negative dim + dim = dim < 0 ? dim + executorch::runtime::nonzero_dim(in) : dim; + + const executorch::aten::optional& dim_t = dim; + const size_t d = ET_NORMALIZE_IX(dim_t.value(), in.dim()); + const size_t size = in.size(d); + + size_t stride = 1, outer_size = 1; + + size_t outer_stride = 1; + + constexpr auto name = "_softmax.out"; + constexpr int MaxDim = 5; + + bool optimized = true; + bool ping_pong_process = false; + bool ping_process_pong = false; + + if ((d == in.dim() - 1)) { + if (size <= IDMA_BUFF_SIZE / 4 && in.dim() != 1) { + ping_pong_process = true; + } else if (size <= IDMA_BUFF_SIZE / 2) { + ping_process_pong = true; + } + } + + if (out.scalar_type() != ScalarType::Float) + optimized = false; + + if (in.dim() > MaxDim) + optimized = false; + + if (optimized) { + const float* ptr_inp = (float*)in.const_data_ptr(); + float* out_data = (float*)out.mutable_data_ptr(); + + /* Channel 0*/ + idma_init(0, 0, MAX_BLOCK_16, 8, TICK_CYCLES_1, 0, NULL); + idma_init_loop(0, descbuf[0], IDMA_2D_DESC, 1, NULL, NULL); + + /* Channel 1*/ + idma_init(1, 0, MAX_BLOCK_16, 8, TICK_CYCLES_1, 0, NULL); + idma_init_loop(1, descbuf[1], IDMA_2D_DESC, 1, NULL, NULL); + + if (ping_pong_process) { + for (int i = 0; i < in.dim(); i++) { + if (i != d) + outer_size *= in.size(i); + } + + outer_stride = size; + stride = size; + + int pp_swap = 0; + + float32_t* ptr_out = out_data; + float32_t* ptr_in = (float32_t*)ptr_inp; + + idma_copy_2d_desc( + 0, inpData[pp_swap], ptr_in, 4 * stride, DESC_IDMA_PRIOR_H, 1, 0, 0); + pp_swap = 1; + + for (int i = 0; i < (outer_size - 1); i++) { + IDMA_HW_WAIT_ALL(0); + ptr_in += outer_stride; + idma_copy_2d_desc( + 0, + inpData[pp_swap], + ptr_in, + 4 * stride, + DESC_IDMA_PRIOR_H, + 1, + 0, + 0); + pp_swap = pp_swap ^ 1; + + /* PROCESS CALL */ + vsoftmaxf(outData[pp_swap], inpData[pp_swap], stride); + + IDMA_HW_WAIT_ALL(1); + idma_copy_2d_desc( + 1, + ptr_out, + outData[pp_swap], + 4 * stride, + DESC_IDMA_PRIOR_H, + 1, + 0, + 0); + ptr_out += outer_stride; + } + + IDMA_HW_WAIT_ALL(0); + pp_swap = pp_swap ^ 1; + + /* PROCESS CALL */ + vsoftmaxf(outData[pp_swap], inpData[pp_swap], stride); + + IDMA_HW_WAIT_ALL(1); + idma_copy_2d_desc( + 1, ptr_out, outData[pp_swap], 4 * stride, DESC_IDMA_PRIOR_H, 1, 0, 0); + + IDMA_HW_WAIT_ALL(1); + + return out; + } else if (ping_process_pong) { + for (int i = 0; i < in.dim(); i++) { + if (i != d) + outer_size *= in.size(i); + } + + outer_stride = size; + stride = size; + + float32_t* ptr_out = out_data; + float32_t* ptr_in = (float32_t*)ptr_inp; + + for (int i = 0; i < outer_size; i++) { + idma_copy_2d_desc( + 0, data_dram0, ptr_in, 4 * stride, DESC_IDMA_PRIOR_H, 1, 0, 0); + IDMA_HW_WAIT_ALL(0); + + vsoftmaxf(data_dram1, data_dram0, stride); + + 
idma_copy_2d_desc( + 1, ptr_out, data_dram1, 4 * stride, DESC_IDMA_PRIOR_H, 1, 0, 0); + IDMA_HW_WAIT_ALL(1); + + ptr_in += outer_stride; + ptr_out += outer_stride; + } + + return out; + } else { + int num_inp_dims = in.dim(); + int num_out_dims = num_inp_dims; + + int ptr_inp_shape[MaxDim]; + int ptr_out_shape[MaxDim]; + int ptr_permute_vec[MaxDim]; + + for (int i = 0; i < num_inp_dims; i++) + ptr_inp_shape[i] = in.size(i); + + for (int i = 0; i < num_inp_dims; i++) { + if (i == d) + ptr_permute_vec[i] = num_inp_dims - 1; + else if (i == (num_inp_dims - 1)) + ptr_permute_vec[num_inp_dims - 1] = d; + else + ptr_permute_vec[i] = i; + + ptr_out_shape[i] = ptr_inp_shape[ptr_permute_vec[i]]; + + if (i != d) + outer_size = outer_size * ptr_inp_shape[i]; + } + + outer_stride = size; + + float* ptr_out = (float*)kernels::allocate_temp_memory( + ctx, out.numel() * sizeof(float)); + + ET_KERNEL_CHECK(ctx, ptr_out != nullptr, MemoryAllocationFailed, out); + + float* ptr_out1 = (float*)kernels::allocate_temp_memory( + ctx, out.numel() * sizeof(float)); + + ET_KERNEL_CHECK(ctx, ptr_out1 != nullptr, MemoryAllocationFailed, out); + + tensor_transposef( + ptr_out, + ptr_out_shape, + ptr_inp, + ptr_inp_shape, + ptr_permute_vec, + num_out_dims, + num_inp_dims); + + for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) { + size_t outer = outer_idx * outer_stride; + for (size_t inner_idx = 0; inner_idx < stride; ++inner_idx) { + size_t base = outer + inner_idx; + + float* ptr_in_data = &ptr_out[base]; + float* ptr_out_data = &ptr_out1[base]; + + vsoftmaxf(ptr_out_data, ptr_in_data, size); + } + } + + tensor_transposef( + out_data, + ptr_inp_shape, + ptr_out1, + ptr_out_shape, + ptr_permute_vec, + num_out_dims, + num_inp_dims); + + return out; + } + } + + ET_SWITCH_FLOATHBF16_TYPES( + in.scalar_type(), ctx, "_softmax.out", CTYPE, [&]() { + const CTYPE* const in_data = in.const_data_ptr(); + CTYPE* const out_data = out.mutable_data_ptr(); + + torch::executor::apply_over_dim( + [in_data, out_data]( + const size_t size, const size_t stride, const size_t base) { + // calculate max in softmax dim. During softmax computation each + // value is subtracted by the maximum in value before calling exp + // to preserve numerical stability. + const CTYPE max_in = torch::executor::apply_unary_reduce_fn( + [](const CTYPE val_in, CTYPE val_accum) { + return std::max(val_in, val_accum); + }, + in_data + base, + size, + stride); + + const CTYPE temp_sum = + torch::executor::apply_unary_map_reduce_fn( + [max_in](const CTYPE val_in) { + return std::exp(val_in - max_in); + }, + [](const CTYPE mapped_in, CTYPE val_accum) { + return val_accum + mapped_in; + }, + in_data + base, + size, + stride); + + torch::executor::apply_unary_map_fn( + [max_in, temp_sum](const CTYPE val_in) { + return std::exp(val_in - max_in) / temp_sum; + }, + in_data + base, + out_data + base, + size, + stride); + }, + in, + dim); + }); + + return out; +} + +} // namespace native +} // namespace vision +} // namespace impl diff --git a/backends/cadence/vision/operators/op_view_copy.cpp b/backends/cadence/vision/operators/op_view_copy.cpp new file mode 100644 index 00000000000..6d4d3a8a5e0 --- /dev/null +++ b/backends/cadence/vision/operators/op_view_copy.cpp @@ -0,0 +1,31 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +namespace impl { +namespace vision { +namespace native { + +using executorch::aten::IntArrayRef; +using ::executorch::aten::IntArrayRef; +using executorch::aten::Tensor; +using executorch::runtime::KernelRuntimeContext; + +Tensor& view_copy_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const IntArrayRef size, + Tensor& out) { + memcpy(out.mutable_data_ptr(), input.const_data_ptr(), input.nbytes()); + return out; +} + +} // namespace native +} // namespace vision +} // namespace impl diff --git a/backends/cadence/vision/operators/operators.h b/backends/cadence/vision/operators/operators.h new file mode 100644 index 00000000000..36c4486bf85 --- /dev/null +++ b/backends/cadence/vision/operators/operators.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +namespace impl { +namespace vision { +namespace native { + +using ::executorch::runtime::getLeadingDims; + +#define ET_FORALL_CADENCE_QUANTIZED_TYPES(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) + +inline __attribute__((always_inline)) void linear_( + const ::executorch::aten::Tensor& input, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::optional<::executorch::aten::Tensor>& bias, + ::executorch::aten::Tensor& output) { + const float* __restrict__ input_data = input.const_data_ptr(); + const float* __restrict__ weight_data = weight.const_data_ptr(); + const float* __restrict__ bias_data = bias.value().const_data_ptr(); + float* __restrict__ output_data = output.mutable_data_ptr(); + + // input comes in shape [batch_size, in_dim] + // weight comes in shape [out_dim, in_dim] + // output comes in empty with shape [batch_size, out_dim] + // Perform matrix multiply (M x N) x (N x P) => M x P + int64_t M = weight.size(0); // = out_dim + int64_t N = weight.size(1); // = in_dim + + // Given an N-dimensional input [d0, d1, d2, ..., d_{N-2}, d_{N-1}], the + // leading dimensions is d0 * d1 * ... * d_{N-2} + int64_t leading_dims = getLeadingDims(input, input.dim() - 1); + + for (int i = 0; i < leading_dims; ++i) { + for (int j = 0; j < M; ++j) { + float sum = bias_data[j]; + for (int k = 0; k < N; ++k) { + sum += input_data[i * N + k] * weight_data[j * N + k]; + } + output_data[i * M + j] = sum; + } + } +} + +} // namespace native +} // namespace vision +} // namespace impl diff --git a/backends/cadence/vision/operators/quantized_ops.h b/backends/cadence/vision/operators/quantized_ops.h new file mode 100644 index 00000000000..b42e45b0b3d --- /dev/null +++ b/backends/cadence/vision/operators/quantized_ops.h @@ -0,0 +1,196 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +template +inline __attribute__((always_inline)) void quantized_linear_per_tensor_( + const ::executorch::aten::Tensor& src, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + const int64_t src_zero_point, + const int64_t weight_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + const int64_t out_zero_point, + ::executorch::aten::Tensor& out) { + // input comes in shape [leading_dims, in_dim] + // weight comes in shape [out_dim, in_dim] + // output comes in empty with shape [leading_dims, out_dim] + // Perform matrix multiply (M x N) x (N x P)' => M x P + const int64_t leading_dims = + executorch::runtime::getLeadingDims(src, src.dim() - 1); + const int64_t out_dim = weight.size(0); // = out_dim + const int64_t in_dim = weight.size(1); // = in_dim + + const T* __restrict__ in_data = src.const_data_ptr(); + const T* __restrict__ weight_data = weight.const_data_ptr(); + const int32_t* __restrict__ bias_data = bias.const_data_ptr(); + T* __restrict__ out_data = out.mutable_data_ptr(); + + // Compute the requant_scale from out_multiplier and out_shift + const float requant_scale = + -out_multiplier * 1.0 / (1 << 31) * pow(2, out_shift); + + for (size_t i = 0; i < leading_dims; ++i) { + for (size_t j = 0; j < out_dim; ++j) { + int32_t sum = bias_data[j]; + for (size_t k = 0; k < in_dim; ++k) { + int32_t x = (int32_t)in_data[i * in_dim + k] - src_zero_point; + int32_t w = + (int32_t)weight_data[j * in_dim + k] - (int32_t)weight_zero_point; + sum += x * w; + } + out_data[i * out_dim + j] = ::impl::vision::native::kernels::quantize( + sum, requant_scale, out_zero_point); + } + } +} + +template +inline __attribute__((always_inline)) void quantized_linear_per_tensor_( + const ::executorch::aten::Tensor& src, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + int64_t src_zero_point, + const ::executorch::aten::Tensor& weight_zero_point_t, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + ::executorch::aten::Tensor& out) { + // Get the zero_point of weight. 
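// This overload assumes per-tensor weight quantization, so only element 0 of
// weight_zero_point_t is read before forwarding to the scalar overload above.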
+ int32_t weight_zero_point = weight_zero_point_t.const_data_ptr()[0]; + quantized_linear_per_tensor_( + src, + weight, + bias, + src_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + out); +} + +template +inline __attribute__((always_inline)) void quantized_linear_per_channel_( + const ::executorch::aten::Tensor& src, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + int64_t src_zero_point, + int64_t weight_zero_point, + const ::executorch::aten::Tensor& out_multiplier, + const ::executorch::aten::Tensor& out_shift, + int64_t out_zero_point, + ::executorch::aten::Tensor& out) { + // input comes in shape [leading_dims, in_dim] + // weight comes in shape [out_dim, in_dim] + // output comes in empty with shape [leading_dims, out_dim] + // Perform matrix multiply (M x N) x (N x P)' => M x P + int64_t leading_dims = + executorch::runtime::getLeadingDims(src, src.dim() - 1); + const int64_t out_dim = weight.size(0); // = out_dim + const int64_t in_dim = weight.size(1); // = in_dim + + const T* __restrict__ in_data = src.const_data_ptr(); + const T* __restrict__ weight_data = weight.const_data_ptr(); + const int32_t* __restrict__ bias_data = bias.const_data_ptr(); + T* __restrict__ out_data = out.mutable_data_ptr(); + const int32_t* __restrict__ out_multiplier_data = + out_multiplier.const_data_ptr(); + const int32_t* __restrict__ out_shift_data = + out_shift.const_data_ptr(); + + for (size_t i = 0; i < leading_dims; ++i) { + for (size_t j = 0; j < out_dim; ++j) { + int32_t sum = bias_data[j]; + for (size_t k = 0; k < in_dim; ++k) { + int32_t x = (int32_t)in_data[i * in_dim + k] - src_zero_point; + int32_t w = + (int32_t)weight_data[j * in_dim + k] - (int32_t)weight_zero_point; + sum += x * w; + } + // Compute the out_scale from out_multiplier and out_shift + const float out_scale = + -out_multiplier_data[j] * 1.0 / (1 << 31) * pow(2, out_shift_data[j]); + out_data[i * out_dim + j] = ::impl::vision::native::kernels::quantize( + sum, out_scale, out_zero_point); + } + } +} + +template +inline __attribute__((always_inline)) void quantized_linear_( + const ::executorch::aten::Tensor& src, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + int64_t src_zero_point, + int64_t weight_zero_point, + const ::executorch::aten::Tensor& out_multiplier, + const ::executorch::aten::Tensor& out_shift, + int64_t out_zero_point, + ::executorch::aten::Tensor& out) { + if (out_multiplier.numel() == 1) { + // Use per-tensor quantization kernel. + const int32_t* __restrict__ out_multiplier_data = + out_multiplier.const_data_ptr(); + const int32_t* __restrict__ out_shift_data = + out_shift.const_data_ptr(); + quantized_linear_per_tensor_( + src, + weight, + bias, + src_zero_point, + weight_zero_point, + out_multiplier_data[0], + out_shift_data[0], + out_zero_point, + out); + return; + } + + // Use per-channel quantization kernel. 
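// Here each output channel j derives its own requantization scale from
// out_multiplier[j] and out_shift[j] as
//   out_scale_j = -out_multiplier[j] / 2^31 * 2^out_shift[j]
// (for example, out_multiplier[j] = -(1 << 30) with out_shift[j] = 0 gives
// out_scale_j = 0.5; illustrative values only), whereas the per-tensor path
// above applies the single element 0 to every channel.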
+ quantized_linear_per_channel_( + src, + weight, + bias, + src_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + out); +} + +template +inline __attribute__((always_inline)) void quantized_linear_( + const ::executorch::aten::Tensor& src, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + int64_t src_zero_point, + const ::executorch::aten::Tensor& weight_zero_point_t, + const ::executorch::aten::Tensor& out_multiplier, + const ::executorch::aten::Tensor& out_shift, + int64_t out_zero_point, + ::executorch::aten::Tensor& out) { + // Get the zero_point of weight. + int32_t weight_zero_point = weight_zero_point_t.const_data_ptr()[0]; + quantized_linear_( + src, + weight, + bias, + src_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + out); +} diff --git a/backends/cadence/vision/operators/targets.bzl b/backends/cadence/vision/operators/targets.bzl new file mode 100644 index 00000000000..b12118a9c47 --- /dev/null +++ b/backends/cadence/vision/operators/targets.bzl @@ -0,0 +1,64 @@ +load("@fbsource//tools/build_defs:platform_defs.bzl", "CXX") +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + + +def define_operator(name: str, deps: list[str] | None = None) -> None: + op_name = "op_{}".format(name) + + # Deps used by all operators. + common_deps = [ + "//executorch/kernels/portable/cpu/util:all_deps", + "//executorch/kernels/portable/cpu/pattern:all_deps", + "//executorch/runtime/kernel:kernel_includes", + "//executorch/kernels/portable/cpu:scalar_utils", + "//executorch/backends/cadence/vision/kernels:cadence_kernels", + "//executorch/kernels/portable/cpu/util:dtype_util", + "//executorch/kernels/portable/cpu/util:elementwise_util", + "//executorch/kernels/portable/cpu/pattern:bitwise_op", + "//executorch/backends/cadence/vision/third-party:vision-nnlib", + "//executorch/kernels/portable/cpu/pattern:comparison_op" + ] + if deps == None: + deps = [] + + runtime.cxx_library( + name = op_name, + srcs = [op_name + ".cpp"], + platforms = CXX, + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + compatible_with = ["ovr_config//cpu:xtensa"], + deps = deps + common_deps, + exported_headers = ["operators.h"], + ) + +OPERATORS = [ + "add", + "full", + "quantized_fully_connected_out", + "quantized_matmul_out", + "requantize_out", + "dequantize_per_tensor", + "im2row_out", + "quantized_layer_norm", + "quantized_relu_out", + "softmax", + "embedding", + "quantized_conv_out", + "quantized_linear_out", + "quantize_per_tensor", + "view_copy" +] + +def define_common_targets(): + """Defines targets that should be shared between fbcode and xplat. + + The directory containing this targets.bzl file should also contain both + TARGETS and BUCK files that call this function. + """ + + # Define build targets for all operators registered in the tables above. + for op in OPERATORS: + define_operator(op) diff --git a/backends/cadence/vision/third-party/dummy.c b/backends/cadence/vision/third-party/dummy.c new file mode 100644 index 00000000000..52fb7c18c38 --- /dev/null +++ b/backends/cadence/vision/third-party/dummy.c @@ -0,0 +1,17 @@ +/* Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +/* Dummy source file for non-Xtensa builds + * This file is used when building the vision-nnlib library on platforms + * other than Xtensa, providing empty stubs for compatibility. + * The actual function implementations are provided as stubs via DISCARD_FUN + * in headers when COMPILER_XTENSA is not defined. + */ + +// This file intentionally contains no function definitions and no includes. +// When COMPILER_XTENSA is not defined, all functions are stubbed out +// using the DISCARD_FUN macro in the header files. diff --git a/backends/cadence/vision/third-party/include/api.h b/backends/cadence/vision/third-party/include/api.h new file mode 100644 index 00000000000..efb80c3d76d --- /dev/null +++ b/backends/cadence/vision/third-party/include/api.h @@ -0,0 +1,83 @@ +/* ------------------------------------------------------------------------ */ +/* Copyright (c) 2024 by Cadence Design Systems, Inc. ALL RIGHTS RESERVED. */ +/* These coded instructions, statements, and computer programs ('Cadence */ +/* Libraries') are the copyrighted works of Cadence Design Systems Inc. */ +/* Cadence IP is licensed for use with Cadence processor cores only and */ +/* must not be used for any other processors and platforms. Your use of the */ +/* Cadence Libraries is subject to the terms of the license agreement you */ +/* have entered into with Cadence Design Systems, or a sublicense granted */ +/* to you by a direct Cadence licensee. */ +/* ------------------------------------------------------------------------ */ +/* IntegrIT, Ltd. www.integrIT.com, info@integrIT.com */ +/* */ +/* NatureDSP_Baseband Library */ +/* */ +/* This library contains copyrighted materials, trade secrets and other */ +/* proprietary information of IntegrIT, Ltd. This software is licensed for */ +/* use with Cadence processor cores only and must not be used for any other */ +/* processors and platforms. The license to use these sources was given to */ +/* Cadence, Inc. under Terms and Condition of a Software License Agreement */ +/* between Cadence, Inc. and IntegrIT, Ltd. */ +/* ------------------------------------------------------------------------ */ +/* Copyright (C) 2009-2022 IntegrIT, Limited. */ +/* All Rights Reserved. */ +/* ------------------------------------------------------------------------ */ +/* + * API + */ + +#ifndef __API_H__ +#define __API_H__ + +#include "dtypes.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*------------------------------------------------------------------------- +Softmax + +Description: The function computes the softmax (normalized exponential +function) of input data. 16-bit fixed-point functions accept inputs in +Q3.12 and form outputs in Q7.8 format. + +vsoftmax 16-bit +vsoftmax_fp16 IEEE-754 Std. half precision floating-point. +vsoftmaxf IEEE-754 Std. single precision floating-point. + +Accuracy: +2 LSB for fixed point API +2 ULP for floating point API +NOTE: Accuracy of function may depend on amount of data and their +distribution. Given accuracy is achieved for N=2 for any pair of +data from input domain. 
+ + +Parameters: +Input: +x[N] input data, Q3.12 floating point +N Length of input/output data vectors +Output: +y[N] result, Q7.8 or floating point + +Restrictions: +x,y aligned on 2*BBE_SIMD_WIDTH-bytes boundary (vsoftmax) +x,y Must not overlap +N multiple of BBE_SIMD_WIDTH (vsoftmax) +-------------------------------------------------------------------------*/ +void vsoftmaxf(float32_t *y, const float32_t *x, int N); + +void tensor_transposef(float32_t *restrict ptr_out + ,const int *const ptr_out_shape + ,const float32_t *restrict ptr_inp + ,const int *const ptr_inp_shape + ,const int *restrict ptr_permute_vec + ,int num_out_dims + ,int num_inp_dims); + +#ifdef __cplusplus +}; +#endif + +#endif /* __API_H__ */ diff --git a/backends/cadence/vision/third-party/include/dtypes.h b/backends/cadence/vision/third-party/include/dtypes.h new file mode 100644 index 00000000000..c12bbf23ac2 --- /dev/null +++ b/backends/cadence/vision/third-party/include/dtypes.h @@ -0,0 +1,380 @@ +/* ------------------------------------------------------------------------ */ +/* Copyright (c) 2024 by Cadence Design Systems, Inc. ALL RIGHTS RESERVED. */ +/* These coded instructions, statements, and computer programs ('Cadence */ +/* Libraries') are the copyrighted works of Cadence Design Systems Inc. */ +/* Cadence IP is licensed for use with Cadence processor cores only and */ +/* must not be used for any other processors and platforms. Your use of the */ +/* Cadence Libraries is subject to the terms of the license agreement you */ +/* have entered into with Cadence Design Systems, or a sublicense granted */ +/* to you by a direct Cadence licensee. */ +/* ------------------------------------------------------------------------ */ +/* IntegrIT, Ltd. www.integrIT.com, info@integrIT.com */ +/* */ +/* NatureDSP_Baseband Library */ +/* */ +/* This library contains copyrighted materials, trade secrets and other */ +/* proprietary information of IntegrIT, Ltd. This software is licensed for */ +/* use with Cadence processor cores only and must not be used for any other */ +/* processors and platforms. The license to use these sources was given to */ +/* Cadence, Inc. under Terms and Condition of a Software License Agreement */ +/* between Cadence, Inc. and IntegrIT, Ltd. */ +/* ------------------------------------------------------------------------ */ +/* Copyright (C) 2009-2022 IntegrIT, Limited. */ +/* All Rights Reserved. 
*/ +/* ------------------------------------------------------------------------ */ +/* + * Cross-platform data type definitions and utility macros + */ + +#ifndef __DTYPES_H__ +#define __DTYPES_H__ + +#include + +#ifndef COMPILER_ANSI +/* ---------------------------------------------------------- + Compilers autodetection + ----------------------------------------------------------*/ +#define ___UNKNOWN_COMPILER_YET +#ifdef ___UNKNOWN_COMPILER_YET +#ifdef _MSC_VER + +#ifdef _ARM_ +#define COMPILER_CEARM9E /* Microsoft Visual C++,ARM9E */ +#else +#define COMPILER_MSVC /* Microsoft Visual C++ */ +#endif + +#undef ___UNKNOWN_COMPILER_YET +#endif +#endif + +#ifdef ___UNKNOWN_COMPILER_YET +#ifdef _TMS320C6X +#if defined(_TMS320C6400) +#define COMPILER_C64 +#undef ___UNKNOWN_COMPILER_YET +#endif +#if defined(_TMS320C6400_PLUS) +#define COMPILER_C64PLUS +#undef ___UNKNOWN_COMPILER_YET +#endif +#endif +#endif + +#ifdef ___UNKNOWN_COMPILER_YET +#ifdef __TMS320C55X__ +#define COMPILER_C55 +#undef ___UNKNOWN_COMPILER_YET +#endif +#endif + +#ifdef ___UNKNOWN_COMPILER_YET +#ifdef __ADSPBLACKFIN__ +#define COMPILER_ADSP_BLACKFIN +#undef ___UNKNOWN_COMPILER_YET +#endif +#endif + +#ifdef ___UNKNOWN_COMPILER_YET +#ifdef __XCC__ +#define COMPILER_XTENSA +#undef ___UNKNOWN_COMPILER_YET +#endif +#endif + +#ifdef ___UNKNOWN_COMPILER_YET +#ifdef __GNUC__ +#ifdef __arm__ +#ifndef COMPILER_GNU_ARM +#endif +#define COMPILER_GNUARM /* GNU C/C++ compiler*/ +#else +/* GNU GCC x86 compiler */ +#ifndef COMPILER_GNU +#endif +#define COMPILER_GNU /* GNU C/C++ */ +#endif +#undef ___UNKNOWN_COMPILER_YET +#endif +#endif + +#ifdef ___UNKNOWN_COMPILER_YET +#error Unknown compiler +#endif + +#endif /* #ifndef COMPILER_ANSI */ + +/* ---------------------------------------------------------- + Language-dependent definitions + ----------------------------------------------------------*/ +#ifdef __cplusplus + +#undef extern_C +#define extern_C extern "C" + +#else + +#undef extern_C +#define extern_C + +#ifndef false +#define false 0 +#endif +#ifndef true +#define true 1 +#endif + +#endif + +/* Assertion support */ +#if !defined(_ASSERT) +#include +#if defined(_DEBUG) /*&& defined(COMPILER_MSVC)*/ +#define ASSERT(x) \ + { assert(x); } +#else + +/*#undef ASSERT*/ +#ifndef ASSERT +#define ASSERT(_ignore) ((void)0) +#endif + +#endif /* _DEBUG */ +#else /* ASSERT*/ +#define ASSERT(exp) \ + { \ + extern void ExternalAssertHandler(void *, void *, unsigned); \ + (void)((exp) || (ExternalAssertHandler(#exp, __FILE__, __LINE__), 0)); \ + } +#endif /* ASSERT */ + +/*** Inline methods definition ***/ +#undef inline_ +#if (defined COMPILER_MSVC) || (defined COMPILER_CEARM9E) +#define inline_ __inline +#elif defined(COMPILER_ADSP_BLACKFIN) +#define inline_ inline +#elif defined(COMPILER_ANSI) +#define inline_ +#elif (defined COMPILER_GNU) || (defined COMPILER_GNUARM) || \ + (defined COMPILER_ARM) +#define inline_ static inline +#else +#define inline_ static inline +#endif + +#ifndef MAX_INT16 +#define MAX_INT16 ((int16_t)0x7FFF) +#endif +#ifndef MIN_INT16 +#define MIN_INT16 ((int16_t)0x8000) +#endif +#ifndef MAX_INT32 +#define MAX_INT32 ((int32_t)0x7FFFFFFFL) +#endif +#ifndef MIN_INT32 +#define MIN_INT32 ((int32_t)0x80000000L) +#endif +#ifndef MIN_INT64 +#define MIN_INT64 ((int64_t)0x8000000000000000LL) +#endif +#ifndef MAX_INT64 +#define MAX_INT64 ((int64_t)0x7fffffffffffffffLL) +#endif + +/* size of variables in bytes */ +#ifdef COMPILER_C55 +#define SIZEOF_BYTE(x) (sizeof(x) << 1) +#else +#define SIZEOF_BYTE(x) sizeof(x) +#endif + 
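/* Illustrative note: SIZEOF_BYTE above reports sizes in 8-bit bytes. On
   targets such as C55x, where the smallest addressable unit (and therefore
   sizeof(char)) is 16 bits, sizeof(x) is shifted left by one to convert to a
   byte count; on all other compilers it is plain sizeof(x). */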
+/*--------------------------------------- + special keywords definition + restrict keyword means that the memory + is addressed exclusively via + this pointer + onchip keyword means that the memory + is on-chip and can not be + accessed via external bus +---------------------------------------*/ +#if defined(COMPILER_C55) +#define NASSERT _nassert +#elif defined(COMPILER_C64) +#define onchip +#define NASSERT _nassert +#elif defined(COMPILER_ADSP_BLACKFIN) +#define onchip +#define NASSERT(x) __builtin_assert(x) +#elif defined(COMPILER_GNUARM) +#define onchip +#define NASSERT(x) \ + { (void)__builtin_expect((x) != 0, 1); } +#define restrict __restrict +#elif defined(COMPILER_GNU) +#define onchip +#define NASSERT(x) \ + { \ + (void)__builtin_expect((x) != 0, 1); \ + ASSERT(x); \ + } +#define restrict __restrict +#elif defined(COMPILER_CEARM9E) +#define onchip +#define NASSERT(x) +#define restrict +#elif defined(COMPILER_XTENSA) +#ifndef restrict +#define restrict __restrict +#endif +#define onchip +#define NASSERT(x) \ + { \ + (void)__builtin_expect((x) != 0, 1); \ + ASSERT(x); \ + } +#else +#define restrict +#define onchip +#define NASSERT ASSERT +#endif +#if defined(COMPILER_ADSP_BLACKFIN) +#define NASSERT_ALIGN(addr, align) __builtin_aligned(addr, align) +#else +#define NASSERT_ALIGN(addr, align) NASSERT(((uintptr_t)(addr)) % (align) == 0) +#endif +#define NASSERT_ALIGN2(addr) NASSERT_ALIGN(addr, 2) +#define NASSERT_ALIGN4(addr) NASSERT_ALIGN(addr, 4) +#define NASSERT_ALIGN8(addr) NASSERT_ALIGN(addr, 8) +#define NASSERT_ALIGN16(addr) NASSERT_ALIGN(addr, 16) +#define NASSERT_ALIGN32(addr) NASSERT_ALIGN(addr, 32) +#define NASSERT_ALIGN64(addr) NASSERT_ALIGN(addr, 64) +#define NASSERT_ALIGN128(addr) NASSERT_ALIGN(addr, 128) +/* ---------------------------------------------------------- + Common types + ----------------------------------------------------------*/ +#if defined(COMPILER_GNU) | defined(COMPILER_GNUARM) | defined(COMPILER_XTENSA) +/* + typedef signed char int8_t; + typedef unsigned char uint8_t; +*/ +#include +#elif defined(COMPILER_C64) +#include +#elif defined(COMPILER_C55) +#include +typedef signed char int8_t; +typedef unsigned char uint8_t; +#elif defined(COMPILER_ADSP_BLACKFIN) +typedef signed char int8_t; +typedef unsigned char uint8_t; +typedef unsigned long uint32_t; +typedef unsigned short uint16_t; +typedef long int32_t; +typedef short int16_t; +typedef long long int64_t; +typedef unsigned long long uint64_t; +typedef uint32_t uintptr_t; +#else +typedef signed char int8_t; +typedef unsigned char uint8_t; +typedef unsigned long uint32_t; +typedef unsigned short uint16_t; +typedef long int32_t; +typedef short int16_t; +typedef __int64 int64_t; +typedef unsigned __int64 uint64_t; +#endif + +#if defined(COMPILER_CEARM9E) +typedef uint32_t uintptr_t; +#endif + +#if defined(COMPILER_ARM) +typedef uint32_t uintptr_t; +#endif + +typedef int16_t float16_t; +typedef float float32_t; +typedef double float64_t; +typedef int16_t fract16; +typedef int32_t fract32; + +typedef union tag_complex_fract16 { + struct { + int16_t re, im; + } s; + uint32_t a; /* just for 32-bit alignment */ +} complex_fract16; + +typedef union tag_complex_fract32 { + struct { + int32_t re, im; + } s; + uint64_t a; /* just for 64-bit alignment */ +} complex_fract32; + +#if defined(COMPILER_MSVC) +#if 0 +/* Note: Visual Studio does not support C99 compatible complex types yet */ +typedef union tag_complex_float { + struct { + float32_t re, im; + } s; + uint64_t a; /* just for 64-bit alignment */ +} 
complex_float; +typedef union tag_complex_double { + struct { + float64_t re, im; + } s; + uint64_t a[2]; /* only 64-bit alignment under Visual Studio :(( */ +} complex_double; + +inline_ float32_t crealf(complex_float x) { return x.s.re; } +inline_ float32_t cimagf(complex_float x) { return x.s.im; } +inline_ float64_t creal(complex_double x) { return x.s.re; } +inline_ float64_t cimag(complex_double x) { return x.s.im; } +#else +#include +#define complex_float _Fcomplex +#define complex_double _Dcomplex +#endif + +#else +/* C99 compatible type */ +#include +#define complex_float __complex__ float +#define complex_double __complex__ double +#endif + +/* complex half-precision datatype */ +typedef union tag_complex_float16 { + struct { + float16_t re, im; + } s; + uint32_t a; /* just for 32-bit alignment */ +} complex_float16; + +inline_ float16_t crealh(complex_float16 x) { return x.s.re; } +inline_ float16_t cimagh(complex_float16 x) { return x.s.im; } +/* union data type for writing float32_t/float64_t constants in a bitexact + * form */ +union ufloat32uint32 { + uint32_t u; + float32_t f; +}; +union ufloat64uint64 { + uint64_t u; + float64_t f; +}; +union ufloat16uint16 { + uint16_t u; + float16_t f; +}; + +#if defined(__RENAMING__) +#include "__renaming__.h" +#endif + +#endif /* __DTYPE_H__ */ diff --git a/backends/cadence/vision/third-party/include_private/common.h b/backends/cadence/vision/third-party/include_private/common.h new file mode 100644 index 00000000000..4fc07d8b4d1 --- /dev/null +++ b/backends/cadence/vision/third-party/include_private/common.h @@ -0,0 +1,199 @@ +/* ------------------------------------------------------------------------ */ +/* Copyright (c) 2024 by Cadence Design Systems, Inc. ALL RIGHTS RESERVED. */ +/* These coded instructions, statements, and computer programs ('Cadence */ +/* Libraries') are the copyrighted works of Cadence Design Systems Inc. */ +/* Cadence IP is licensed for use with Cadence processor cores only and */ +/* must not be used for any other processors and platforms. Your use of the */ +/* Cadence Libraries is subject to the terms of the license agreement you */ +/* have entered into with Cadence Design Systems, or a sublicense granted */ +/* to you by a direct Cadence licensee. */ +/* ------------------------------------------------------------------------ */ +/* IntegrIT, Ltd. www.integrIT.com, info@integrIT.com */ +/* */ +/* NatureDSP_Baseband Library */ +/* */ +/* This library contains copyrighted materials, trade secrets and other */ +/* proprietary information of IntegrIT, Ltd. This software is licensed for */ +/* use with Cadence processor cores only and must not be used for any other */ +/* processors and platforms. The license to use these sources was given to */ +/* Cadence, Inc. under Terms and Condition of a Software License Agreement */ +/* between Cadence, Inc. and IntegrIT, Ltd. */ +/* ------------------------------------------------------------------------ */ +/* Copyright (C) 2009-2022 IntegrIT, Limited. */ +/* All Rights Reserved. 
*/ +/* ------------------------------------------------------------------------ */ + +#ifndef __COMMON_H__ +#define __COMMON_H__ + +#if defined COMPILER_XTENSA +#include +#include +#include +#include +#include +#include +#if XCHAL_HAVE_IDMA +#ifndef IDMA_USE_MULTICHANNEL + #define IDMA_USE_MULTICHANNEL 1 +#endif +#include +#endif +#define IVP_SIMD_WIDTH XCHAL_IVPN_SIMD_WIDTH + +#include "xtensa/config/core-isa.h" +#include "xtensa/tie/xt_ivpn.h" +#if XCHAL_HAVE_IDMA +#include "xtensa/idma.h" +#endif + +#ifdef _MSC_VER +#define ALIGN(x) _declspec(align(x)) +#else +#define ALIGN(x) __attribute__((aligned(x))) +#endif + +#ifdef COMPILER_XTENSA +#define ATTRIBUTE_ALWAYS_INLINE __attribute__((always_inline)) +#define ATTRIBUTE_NEVER_INLINE __attribute__((noinline)) +#define ATTRIBUTE_UNUSED __attribute__((unused)) +#else +#define ATTRIBUTE_ALWAYS_INLINE +#define ATTRIBUTE_NEVER_INLINE +#define ATTRIBUTE_UNUSED +#endif + +/* 'restrict' qualifier, is applied to pointers only under clang compiler */ +#ifdef __clang__ +#define restrict_clang restrict +#else +#define restrict_clang +#endif + +// Performance measurement macros +#define XTPERF_PRINTF(...) printf(__VA_ARGS__) +#define TIME_DECL(test) long start_time_##test, end_time_##test; +#define TIME_START(test) { start_time_##test = 0; XT_WSR_CCOUNT(0); } +#define TIME_END(test) { end_time_##test = XT_RSR_CCOUNT(); } +#define TIME_DISPLAY(test, opcnt, opname) { long long cycles_##test = end_time_##test - start_time_##test; \ + XTPERF_PRINTF("PERF_LOG : %s : %d : %s : %lld : cycles : %.2f : %s/cycle : %.2f : cycles/%s\n", \ + #test, opcnt, opname, cycles_##test, cycles_##test == 0 ? 0 : (double)(opcnt)/cycles_##test, \ + opname, cycles_##test == 0 ? 0 : 1/((double)(opcnt)/cycles_##test), opname); } + +//----------------------------------------------------- +// log2(BBE_SIMD_WIDTH) +//----------------------------------------------------- +#define LOG2_IVP_SIMD_WIDTH 5 +#define ALIGN_SIMD ALIGN(64) +#define ALIGN_2SIMD ALIGN(128) + +#define LOG2_SIMD_N_2 (LOG2_IVP_SIMD_WIDTH - 1) +#define LOG2_SIMD_2N (LOG2_IVP_SIMD_WIDTH + 1) +//----------------------------------------------------- +// some C++ support +//----------------------------------------------------- + +// special XCC type casting of pointers +#ifdef __cplusplus +#define castxcc(type_, ptr) (ptr) +#else +#define castxcc(type_, ptr) (type_ *)(ptr) +#endif + +//----------------------------------------------------- +// C99 pragma wrapper +//----------------------------------------------------- + +#ifdef COMPILER_XTENSA +#define __Pragma(a) _Pragma(a) +#else +#define __Pragma(a) +#endif + +//----------------------------------------------------- +// Conditionalization support +//----------------------------------------------------- +/* place DISCARD_FUN(retval_type,name) instead of function definition for + functions to be discarded from the executable THIS WORKS only for external + library functions declared as extern "C" and not supported for internal + references without "C" qualifier! 
+*/ +#ifdef COMPILER_MSVC +#pragma section("$DISCARDED_FUNCTIONS", execute, discard) +#pragma section("$$$$$$$$$$", execute, discard) +#define DISCARD_FUN(retval_type, name, arglist) \ + __pragma(alloc_text("$DISCARDED_FUNCTIONS", name)) \ + __pragma(section("$DISCARDED_FUNCTIONS", execute, discard)) \ + __pragma(warning(push)) __pragma(warning(disable : 4026 4716)) \ + retval_type name arglist {} \ + __pragma(warning(pop)) +#endif + +#if defined(COMPILER_XTENSA) || defined(COMPILER_GNU) +#define DISCARD_FUN(retval_type, name, arglist) \ + __asm__(".type " #name ", @object\n\t.global " #name \ + "\n\t.align 4\n\t" #name ":\n\t.long 0x49438B96,0x4D73F192\n\t"); +#endif + +/*------ LIST OF DEFINES DEPENDING ON ISA OPTIONS ------*/ + +/* Single-precision Extended Vector Floating-point option */ +#if ((XCHAL_HAVE_VISION_SP_VFPU)) +#define HAVE_SPX_VFPU 1 +#else +#define HAVE_SPX_VFPU 0 +#endif + +/* all vector single precision/Extended vector floating point instructions */ +#if ((XCHAL_HAVE_VISION_SP_VFPU)) +#define HAVE_SPX_VFPU 1 +#define HAVE_VFPU 1 +#else +#define HAVE_SPX_VFPU 0 +#define HAVE_VFPU 0 +#endif + +/* all scalar single precision floating point instructions */ +#if ((XCHAL_HAVE_VISION_SP_VFPU) || (XCHAL_HAVE_FP)) +#define HAVE_FPU 1 +#else +#define HAVE_FPU 0 +#endif + +#else +#define HAVE_VFPU 0 +#define HAVE_FPU 0 +#endif + +/* detect if half precision FPU is present in a core */ +#if ((XCHAL_HAVE_VISION_HP_VFPU)) +#define HAVE_HPFPU 1 +#include +#else +#define HAVE_HPFPU 0 +#endif + +/* detect if double precision FPU is present in a core */ +#if ((XCHAL_HAVE_VISION_DP_VFPU)) +#define HAVE_DPFPU 1 +#include +#else +#define HAVE_DPFPU 0 +#endif + +/* + 32x32 multiplier +*/ +#if defined(BBE_MULN_2X32) +#define HAVE_32X32 1 +#else +#define HAVE_32X32 0 +#endif + +#ifdef __cplusplus +#define externC extern "C" +#else +#define externC extern +#endif + +#endif // __COMMON_H__ diff --git a/backends/cadence/vision/third-party/include_private/expf_tbl.h b/backends/cadence/vision/third-party/include_private/expf_tbl.h new file mode 100644 index 00000000000..702164aba11 --- /dev/null +++ b/backends/cadence/vision/third-party/include_private/expf_tbl.h @@ -0,0 +1,53 @@ +/* ------------------------------------------------------------------------ */ +/* Copyright (c) 2024 by Cadence Design Systems, Inc. ALL RIGHTS RESERVED. */ +/* These coded instructions, statements, and computer programs ('Cadence */ +/* Libraries') are the copyrighted works of Cadence Design Systems Inc. */ +/* Cadence IP is licensed for use with Cadence processor cores only and */ +/* must not be used for any other processors and platforms. Your use of the */ +/* Cadence Libraries is subject to the terms of the license agreement you */ +/* have entered into with Cadence Design Systems, or a sublicense granted */ +/* to you by a direct Cadence licensee. */ +/* ------------------------------------------------------------------------ */ +/* IntegrIT, Ltd. www.integrIT.com, info@integrIT.com */ +/* */ +/* NatureDSP_Baseband Library */ +/* */ +/* This library contains copyrighted materials, trade secrets and other */ +/* proprietary information of IntegrIT, Ltd. This software is licensed for */ +/* use with Cadence processor cores only and must not be used for any other */ +/* processors and platforms. The license to use these sources was given to */ +/* Cadence, Inc. under Terms and Condition of a Software License Agreement */ +/* between Cadence, Inc. and IntegrIT, Ltd. 
*/ +/* ------------------------------------------------------------------------ */ +/* Copyright (C) 2009-2022 IntegrIT, Limited. */ +/* All Rights Reserved. */ +/* ------------------------------------------------------------------------ */ + +/* + tables for expf(x) approximation +*/ +#ifndef __EXPF_TBL_H__ +#define __EXPF_TBL_H__ + +/* Portable data types. */ +#include "dtypes.h" +#include "common.h" + +/* + polynomial coefficients for 2^x in range 0...1 + + derived by MATLAB code: + order=6; + x=(0:pow2(1,-16):1); + y=2.^x; + p=polyfit(x,y,6); + p(order+1)=1; + p(order)=p(order)-(sum(p)-2); +*/ +externC const int32_t expftbl_Q30[8]; +externC const union ufloat32uint32 + expfminmax[2]; /* minimum and maximum arguments of expf() input */ +externC const int32_t invln2_Q30; /* 1/ln(2), Q30 */ +externC const union ufloat32uint32 expftblf[7]; +externC const union ufloat32uint32 log2_e[2]; +#endif /* __EXPF_TBL_H__ */ diff --git a/backends/cadence/vision/third-party/include_private/idma_init.h b/backends/cadence/vision/third-party/include_private/idma_init.h new file mode 100644 index 00000000000..ee0666842fd --- /dev/null +++ b/backends/cadence/vision/third-party/include_private/idma_init.h @@ -0,0 +1,31 @@ +#ifndef __IDMA__INIT_H__ +#define __IDMA__INIT_H__ + +#include "dtypes.h" +#include "common.h" + +#define IDMA_BUFF_SIZE 16384 // 16 kb DRAM storage. Assume 4 buffers (2 input and 2 output) + +#ifndef PLACE_IN_DRAM0 + #define PLACE_IN_DRAM0 __attribute__ ((aligned(2*IVP_SIMD_WIDTH), section(".dram0.data"))) +#endif + +#ifndef PLACE_IN_DRAM1 + #define PLACE_IN_DRAM1 __attribute__ ((aligned(2*IVP_SIMD_WIDTH), section(".dram1.data"))) +#endif + +float32_t data_dram0[IDMA_BUFF_SIZE / 2] PLACE_IN_DRAM0; +float32_t data_dram1[IDMA_BUFF_SIZE / 2] PLACE_IN_DRAM1; + +float32_t *inpData[2] = {&data_dram0[0], &data_dram1[0]}; +float32_t *outData[2] = {&data_dram0[IDMA_BUFF_SIZE / 4], &data_dram1[IDMA_BUFF_SIZE / 4]}; + +IDMA_BUFFER_DEFINE(buffer_idma_ch0, 1, IDMA_2D_DESC); +IDMA_BUFFER_DEFINE(buffer_idma_ch1, 1, IDMA_2D_DESC); + +idma_buffer_t * descbuf[] = { + buffer_idma_ch0, + buffer_idma_ch1, +}; + +#endif // __IDMA__INIT_H__ \ No newline at end of file diff --git a/backends/cadence/vision/third-party/include_private/inff_tbl.h b/backends/cadence/vision/third-party/include_private/inff_tbl.h new file mode 100644 index 00000000000..1326e92a3c1 --- /dev/null +++ b/backends/cadence/vision/third-party/include_private/inff_tbl.h @@ -0,0 +1,39 @@ +/* ------------------------------------------------------------------------ */ +/* Copyright (c) 2024 by Cadence Design Systems, Inc. ALL RIGHTS RESERVED. */ +/* These coded instructions, statements, and computer programs ('Cadence */ +/* Libraries') are the copyrighted works of Cadence Design Systems Inc. */ +/* Cadence IP is licensed for use with Cadence processor cores only and */ +/* must not be used for any other processors and platforms. Your use of the */ +/* Cadence Libraries is subject to the terms of the license agreement you */ +/* have entered into with Cadence Design Systems, or a sublicense granted */ +/* to you by a direct Cadence licensee. */ +/* ------------------------------------------------------------------------ */ +/* IntegrIT, Ltd. www.integrIT.com, info@integrIT.com */ +/* */ +/* NatureDSP_Baseband Library */ +/* */ +/* This library contains copyrighted materials, trade secrets and other */ +/* proprietary information of IntegrIT, Ltd. 
This software is licensed for */ +/* use with Cadence processor cores only and must not be used for any other */ +/* processors and platforms. The license to use these sources was given to */ +/* Cadence, Inc. under Terms and Condition of a Software License Agreement */ +/* between Cadence, Inc. and IntegrIT, Ltd. */ +/* ------------------------------------------------------------------------ */ +/* Copyright (C) 2009-2022 IntegrIT, Limited. */ +/* All Rights Reserved. */ +/* ------------------------------------------------------------------------ */ + +/* + Infinities for single precision routines +*/ +#ifndef __INFF_TBL_H__ +#define __INFF_TBL_H__ + +#include "dtypes.h" +#include "common.h" + +externC const union ufloat32uint32 minusInff; /* -Inf */ +externC const union ufloat32uint32 plusInff; /* +Inf */ +externC const union ufloat32uint32 realmaxf; /* maximum floating point number */ +externC const union ufloat32uint32 realminf; /* minimum floating point number */ +#endif /* __INFF_TBL_H__ */ diff --git a/backends/cadence/vision/third-party/include_private/nanf_tbl.h b/backends/cadence/vision/third-party/include_private/nanf_tbl.h new file mode 100644 index 00000000000..4881b99f070 --- /dev/null +++ b/backends/cadence/vision/third-party/include_private/nanf_tbl.h @@ -0,0 +1,42 @@ +/* ------------------------------------------------------------------------ */ +/* Copyright (c) 2024 by Cadence Design Systems, Inc. ALL RIGHTS RESERVED. */ +/* These coded instructions, statements, and computer programs ('Cadence */ +/* Libraries') are the copyrighted works of Cadence Design Systems Inc. */ +/* Cadence IP is licensed for use with Cadence processor cores only and */ +/* must not be used for any other processors and platforms. Your use of the */ +/* Cadence Libraries is subject to the terms of the license agreement you */ +/* have entered into with Cadence Design Systems, or a sublicense granted */ +/* to you by a direct Cadence licensee. */ +/* ------------------------------------------------------------------------ */ +/* IntegrIT, Ltd. www.integrIT.com, info@integrIT.com */ +/* */ +/* NatureDSP_Baseband Library */ +/* */ +/* This library contains copyrighted materials, trade secrets and other */ +/* proprietary information of IntegrIT, Ltd. This software is licensed for */ +/* use with Cadence processor cores only and must not be used for any other */ +/* processors and platforms. The license to use these sources was given to */ +/* Cadence, Inc. under Terms and Condition of a Software License Agreement */ +/* between Cadence, Inc. and IntegrIT, Ltd. */ +/* ------------------------------------------------------------------------ */ +/* Copyright (C) 2009-2022 IntegrIT, Limited. */ +/* All Rights Reserved. */ +/* ------------------------------------------------------------------------ */ +/* + NaN values for single precision routines +*/ + +#ifndef __NANF_TBL_H__ +#define __NANF_TBL_H__ + +/* Portable data types. */ +#include "dtypes.h" +/* Common utility macros. 
*/ +#include "common.h" + +extern const union ufloat32uint32 sNaNf; /* Signalling NaN */ +extern const union ufloat32uint32 qNaNf; /* Quiet NaN */ +extern const union ufloat32uint32 minus_sNaNf; /* Negative Signalling NaN */ +extern const union ufloat32uint32 minus_qNaNf; /* Negative Quiet NaN */ + +#endif /* __NANF_TBL_H__ */ diff --git a/backends/cadence/vision/third-party/library/api/tensor_transposef.c b/backends/cadence/vision/third-party/library/api/tensor_transposef.c new file mode 100644 index 00000000000..e6865033740 --- /dev/null +++ b/backends/cadence/vision/third-party/library/api/tensor_transposef.c @@ -0,0 +1,167 @@ +#include "api.h" +#include "common.h" + +/* + * Currently only supports upto 5D input tensors. + * 1/2/3/4 D input tensors will be scaled up to 5D. + * For example, 2x3 -> 1x1x1x2x3. + */ + +void tensor_transposef(float32_t *restrict ptr_out + ,const int *const ptr_out_shape + ,const float32_t *restrict ptr_inp + ,const int *const ptr_inp_shape + ,const int *restrict ptr_permute_vec + ,int num_out_dims + ,int num_inp_dims) +{ + + /* Shift all dim with 1 in the outer part */ + int eff_output_shape[5]; + int eff_permute_vec[5]; + + for (int i = 0; i < num_out_dims; i++){ + eff_output_shape[i] = ptr_out_shape[i]; + eff_permute_vec[i] = ptr_permute_vec[i]; + } + + int one_i = num_out_dims - 1, non_one_i = num_out_dims - 1; + while (one_i > 0 && non_one_i >= 0){ + while (one_i > 0 && eff_output_shape[one_i] != 1){ + one_i--; + } + non_one_i = one_i; + while (non_one_i >= 0 && eff_output_shape[non_one_i]==1){ + non_one_i--; + } + if (one_i > 0 && non_one_i >= 0){ + int temp; + /*swap output_shape*/ + { + temp = eff_output_shape[one_i]; + eff_output_shape[one_i] = eff_output_shape[non_one_i]; + eff_output_shape[non_one_i] = temp; + } + /*swap permute_vec*/ + { + temp = eff_permute_vec[one_i]; + eff_permute_vec[one_i] = eff_permute_vec[non_one_i]; + eff_permute_vec[non_one_i] = temp; + } + } + } + + /* Promoting lesser dim tensors to 5D tensors. + * Also updating the permute_vec and shapes as needed for optimization */ + int ptr_5D_inp_shape[5] = {1, 1, 1, 1, 1}; + int ptr_5D_out_shape[5] = {1, 1, 1, 1, 1}; + int ptr_5D_permute_vec[5] = {0, 1, 2, 3, 4}; + + /* Check if any inner inp dimension is same in the output */ + int last_dim_same = 1, last_n_same_dim = 0; + int itr = num_inp_dims - 1; + while(itr >= 0){ + last_n_same_dim = (last_dim_same && (eff_permute_vec[itr] == itr)) ? (last_n_same_dim + 1) : last_n_same_dim; + last_dim_same = (eff_permute_vec[itr] == itr) ? last_dim_same & 1 : last_dim_same & 0; + itr--; + } + + int dims_added = 5 - num_inp_dims; + itr = num_inp_dims - 1; + int same_count = last_n_same_dim; + int count = 4; + while(itr >= 0){ + ptr_5D_inp_shape[count] = (same_count > 0) ? ptr_5D_inp_shape[count] * ptr_inp_shape[itr] : ptr_inp_shape[itr]; + ptr_5D_out_shape[count] = (same_count > 0) ? ptr_5D_out_shape[count] * eff_output_shape[itr] : eff_output_shape[itr]; + same_count--; + itr--; + count = (same_count > 0) ? count : count - 1; + } + + itr = num_inp_dims - 1; + same_count = (last_n_same_dim) ? num_inp_dims - (last_n_same_dim - 1) : 0; + count = 4; + while(itr >= 0){ + ptr_5D_permute_vec[count] = (same_count > 0) ? 
eff_permute_vec[itr-(last_n_same_dim - 1)] + dims_added + last_n_same_dim - 1 : eff_permute_vec[itr] + dims_added; + same_count--; + itr--; + count--; + } + + int out_dim0, out_dim1, out_dim2, out_dim3, out_dim4; + int inp_dim1, inp_dim2, inp_dim3, inp_dim4; + int inp_stride[5]; + + out_dim0 = ptr_5D_out_shape[0]; + out_dim1 = ptr_5D_out_shape[1]; + out_dim2 = ptr_5D_out_shape[2]; + out_dim3 = ptr_5D_out_shape[3]; + out_dim4 = ptr_5D_out_shape[4]; + + inp_dim1 = ptr_5D_inp_shape[1]; + inp_dim2 = ptr_5D_inp_shape[2]; + inp_dim3 = ptr_5D_inp_shape[3]; + inp_dim4 = ptr_5D_inp_shape[4]; + + inp_stride[0] = inp_dim1 * inp_dim2 * inp_dim3 * inp_dim4; + inp_stride[1] = inp_dim2 * inp_dim3 * inp_dim4; + inp_stride[2] = inp_dim3 * inp_dim4; + inp_stride[3] = inp_dim4; + inp_stride[4] = 1; + + if (last_n_same_dim){ + int itr0, itr1, itr2, itr3, itr4; + float32_t *ptr_inp0 = (float32_t *)ptr_inp; + for (itr0 = 0; itr0 < out_dim0; itr0++){ + float32_t *ptr_inp1 = ptr_inp0 + (itr0 * inp_stride[ptr_5D_permute_vec[0]]); +#pragma looptr_count min=1 + for (itr1 = 0; itr1 < out_dim1; itr1++){ + float32_t *ptr_inp2 = ptr_inp1 + (itr1 * inp_stride[ptr_5D_permute_vec[1]]); +#pragma looptr_count min=1 + for (itr2 = 0; itr2 < out_dim2; itr2++){ + float32_t *ptr_inp3 = ptr_inp2 + (itr2 * inp_stride[ptr_5D_permute_vec[2]]); +#pragma looptr_count min=1 + for (itr3 = 0; itr3 < out_dim3; itr3++, ptr_out += out_dim4){ + float32_t *ptr_inp4 = ptr_inp3 + (itr3 * inp_stride[ptr_5D_permute_vec[3]]); + xb_vecN_2xf32 *restrict pae_i = (xb_vecN_2xf32 *)(ptr_inp4); + xb_vecN_2xf32 *restrict pae_o = (xb_vecN_2xf32 *)(ptr_out); + valign a_inp = IVP_LAN_2XF32_PP(pae_i); + valign a_out = IVP_ZALIGN(); + xb_vecN_2xf32 d0; + for(itr4 = 0; itr4 < (out_dim4 >> (LOG2_IVP_SIMD_WIDTH - 1)); itr4++){ + IVP_LAN_2XF32_IP(d0, a_inp, pae_i); + IVP_SAN_2XF32_IP(d0, a_out, pae_o); + } + IVP_SAPOSN_2XF32_FP(a_out, pae_o); + float32_t *restrict puae_i = (float32_t *)(pae_i); + float32_t *restrict puae_o = (float32_t *)(pae_o); +#pragma looptr_count max = 17 + for(itr4 = 0; itr4 < (out_dim4 & (IVP_SIMD_WIDTH / 2 - 1)); itr4++){ + puae_o[itr4] = puae_i[itr4]; + } + } + } + } + } + } + else{ + int itr0, itr1, itr2, itr3, itr4; + float32_t *ptr_inp0 = (float32_t *)ptr_inp; + for(itr0 = 0; itr0 < out_dim0; itr0++){ + float32_t *ptr_inp1 = ptr_inp0 + (itr0 * inp_stride[ptr_5D_permute_vec[0]]); + for(itr1 = 0; itr1 < out_dim1; itr1++){ + float32_t *ptr_inp2 = ptr_inp1 + (itr1 * inp_stride[ptr_5D_permute_vec[1]]); + for(itr2 = 0; itr2 < out_dim2; itr2++){ + float32_t *ptr_inp3 = ptr_inp2 + (itr2 * inp_stride[ptr_5D_permute_vec[2]]); + for(itr3 = 0; itr3 < out_dim3; itr3++){ + float32_t *ptr_inp4 = ptr_inp3 + (itr3 * inp_stride[ptr_5D_permute_vec[3]]); + for(itr4 = 0; itr4 < out_dim4; itr4++){ + *ptr_out++ = *ptr_inp4; + ptr_inp4 = ptr_inp4 + inp_stride[ptr_5D_permute_vec[4]]; + } + } + } + } + } + } +} diff --git a/backends/cadence/vision/third-party/library/api/vsoftmaxf.c b/backends/cadence/vision/third-party/library/api/vsoftmaxf.c new file mode 100644 index 00000000000..413b6f10567 --- /dev/null +++ b/backends/cadence/vision/third-party/library/api/vsoftmaxf.c @@ -0,0 +1,241 @@ +/* ------------------------------------------------------------------------ */ +/* Copyright (c) 2024 by Cadence Design Systems, Inc. ALL RIGHTS RESERVED. */ +/* These coded instructions, statements, and computer programs ('Cadence */ +/* Libraries') are the copyrighted works of Cadence Design Systems Inc. 
*/ +/* Cadence IP is licensed for use with Cadence processor cores only and */ +/* must not be used for any other processors and platforms. Your use of the */ +/* Cadence Libraries is subject to the terms of the license agreement you */ +/* have entered into with Cadence Design Systems, or a sublicense granted */ +/* to you by a direct Cadence licensee. */ +/* ------------------------------------------------------------------------ */ +/* IntegrIT, Ltd. www.integrIT.com, info@integrIT.com */ +/* */ +/* NatureDSP_Baseband Library */ +/* */ +/* This library contains copyrighted materials, trade secrets and other */ +/* proprietary information of IntegrIT, Ltd. This software is licensed for */ +/* use with Cadence processor cores only and must not be used for any other */ +/* processors and platforms. The license to use these sources was given to */ +/* Cadence, Inc. under Terms and Condition of a Software License Agreement */ +/* between Cadence, Inc. and IntegrIT, Ltd. */ +/* ------------------------------------------------------------------------ */ +/* Copyright (C) 2009-2022 IntegrIT, Limited. */ +/* All Rights Reserved. */ +/* ------------------------------------------------------------------------ */ +/* + NatureDSP_Baseband library. Vector Mathematics. + Softmax, floating-point data +*/ +#include "api.h" +#include "common.h" +#include "expf_tbl.h" +#include "inff_tbl.h" +#include "nanf_tbl.h" + +/*------------------------------------------------------------------------- +Softmax + +Description: The function computes the softmax (normalized exponential +function) of input data. 16-bit fixed-point functions accept inputs in +Q3.12 and form outputs in Q7.8 format. + +vsoftmax 16-bit +vsoftmax_fp16 IEEE-754 Std. half precision floating-point. +vsoftmaxf IEEE-754 Std. single precision floating-point. + +Accuracy: +2 LSB for fixed point API +2 ULP for floating point API +NOTE: Accuracy of function may depend on amount of data and their +distribution. Given accuracy is achieved for N=2 for any pair of +data from input domain. 
+ + +Parameters: +Input +: +x[N] input data, Q3.12 floating point +N Length of input/output data vectors +Output: +y[N] result, Q7.8 or floating point + +Restrictions: +x,y Must not overlap +-------------------------------------------------------------------------*/ + +#define IVP_ADDSN_2X32(b_, c_) \ + ({ \ + xb_vecN_2x32v a_; \ + xb_vecN_2x64w tmp_a_; \ + tmp_a_ = IVP_MULN_2X32(b_, 1); \ + IVP_MULAN_2X32(tmp_a_, c_, 1); \ + a_ = IVP_PACKVRN_2X64W(tmp_a_, 0); \ + a_; \ + }) + +#if !HAVE_VFPU +DISCARD_FUN(void, vsoftmaxf, (float32_t * y, const float32_t *x, int N)) +#else +void vsoftmaxf(float32_t *y, const float32_t *x, int N) { +#if !defined(IVP_MULN_2X32) +#else + const int *pTbl = (const int *)expftbl_Q30; +#endif + const xb_vecN_2xf32 *restrict pX; + xb_vecN_2xf32 *restrict pY; + xb_vecN_2xf32 norm, ysum, xmax; + int n; + valign al_X, al_R, al_Y; + if (N < 0) + return; + xmax = minusInff.f; + pX = (const xb_vecN_2xf32 *)x; + al_X = IVP_LAN_2XF32_PP(pX); + al_Y = IVP_ZALIGN(); + for (n = 0; n < (N >> (LOG2_IVP_SIMD_WIDTH - 1)); n++) { + xb_vecN_2xf32 x; + IVP_LAN_2XF32_IP(x, al_X, pX); + xmax = IVP_MAXNUMN_2XF32(xmax, x); + } + if (N & (IVP_SIMD_WIDTH / 2 - 1)) { + xb_vecN_2xf32 x; + IVP_LAVN_2XF32_XP(x, al_X, pX, + sizeof(float32_t) * (N & (IVP_SIMD_WIDTH / 2 - 1))); + IVP_MAXNUMN_2XF32T(xmax, xmax, x, + IVP_LTRSN_2((N & (IVP_SIMD_WIDTH / 2 - 1)))); + } + + xmax = IVP_REPN_2XF32(IVP_RMAXNUMN_2XF32(xmax), 0); + __Pragma("no_reorder"); + ysum = 0.f; + pX = (const xb_vecN_2xf32 *)x; + pY = (xb_vecN_2xf32 *)y; + al_X = IVP_LAN_2XF32_PP(pX); + { + vboolN_2 bnan; + bnan = IVP_LTRN_2I(0); + for (n = 0; n < (N >> (LOG2_IVP_SIMD_WIDTH - 1)); n++) { + xb_vecN_2xf32 x; + IVP_LAN_2XF32_IP(x, al_X, pX); + x = IVP_SUBN_2XF32(x, xmax); + bnan |= IVP_UNN_2XF32(x, x); + { + xb_vecN_2xf32 gf, zout; + xb_vecN_2x32v xin_i, fr, exp, t; + xb_vecN_2x32v y, y1, y2, c1, c2, f2; + xb_vecN_2x64w w; + xin_i = IVP_TRUNCN_2XF32(x, 24); + /* Multiply by 1/ln2, extract the integer and fractional (Q32) + * components. 
*/ + /* Q54 <- Q24*Q30 */ + w = IVP_MULN_2X32(xin_i, invln2_Q30); + exp = IVP_PACKVRNRN_2X64W(w, 54); + fr = IVP_SRLN_2X32(IVP_PACKVRNRN_2X64W(w, 22), 1); + /* polynomial for 2^x */ + f2 = IVP_PACKVRN_2X64W(IVP_MULN_2X32(fr, fr), 31); + y1 = IVP_LSRN_2X32_I(pTbl, 0 * sizeof(int32_t)); + y2 = IVP_LSRN_2X32_I(pTbl, 1 * sizeof(int32_t)); + c1 = IVP_LSRN_2X32_I(pTbl, 2 * sizeof(int32_t)); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(f2, y1), 31); + y1 = IVP_ADDSN_2X32(c1, t); + c2 = IVP_LSRN_2X32_I(pTbl, 3 * sizeof(int32_t)); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(f2, y2), 31); + y2 = IVP_ADDSN_2X32(c2, t); + c1 = IVP_LSRN_2X32_I(pTbl, 4 * sizeof(int32_t)); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(f2, y1), 31); + y1 = IVP_ADDSN_2X32(c1, t); + c2 = IVP_LSRN_2X32_I(pTbl, 5 * sizeof(int32_t)); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(f2, y2), 31); + y2 = IVP_ADDSN_2X32(c2, t); + c1 = IVP_LSRN_2X32_I(pTbl, 6 * sizeof(int32_t)); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(f2, y1), 31); + y1 = IVP_ADDSN_2X32(c1, t); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(fr, y2), 31); + y = IVP_ADDSN_2X32(y1, t); + /* scale result to original exponent ignoring very low items */ + gf = IVP_FLOATN_2X32(y, 30); + exp = IVP_SLLIN_2X32(IVP_MAXN_2X32(IVP_ADDN_2X32(127, exp), 0), 23); + zout = IVP_MULN_2XF32(gf, IVP_MOVN_2XF32_FROMN_2X32(exp)); + x = zout; + } + ysum = IVP_ADDN_2XF32(ysum, x); + IVP_SAN_2XF32_IP(x, al_Y, pY); + } + if (N & (IVP_SIMD_WIDTH / 2 - 1)) { + xb_vecN_2xf32 x; + IVP_LAVN_2XF32_XP(x, al_X, pX, + sizeof(float32_t) * (N & (IVP_SIMD_WIDTH / 2 - 1))); + x = IVP_SUBN_2XF32(x, xmax); + bnan |= IVP_UNN_2XF32(x, x); + { + xb_vecN_2xf32 gf, zout; + xb_vecN_2x32v xin_i, fr, exp, t; + xb_vecN_2x32v y, y1, y2, c1, c2, f2; + xb_vecN_2x64w w; + xin_i = IVP_TRUNCN_2XF32(x, 24); + /* Multiply by 1/ln2, extract the integer and fractional (Q32) + * components. 
*/ + /* Q54 <- Q24*Q30 */ + w = IVP_MULN_2X32(xin_i, invln2_Q30); + exp = IVP_PACKVRNRN_2X64W(w, 54); + fr = IVP_SRLN_2X32(IVP_PACKVRNRN_2X64W(w, 22), 1); + /* polynomial for 2^x */ + f2 = IVP_PACKVRN_2X64W(IVP_MULN_2X32(fr, fr), 31); + y1 = IVP_LSRN_2X32_I(pTbl, 0 * sizeof(int32_t)); + y2 = IVP_LSRN_2X32_I(pTbl, 1 * sizeof(int32_t)); + c1 = IVP_LSRN_2X32_I(pTbl, 2 * sizeof(int32_t)); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(f2, y1), 31); + y1 = IVP_ADDSN_2X32(c1, t); + c2 = IVP_LSRN_2X32_I(pTbl, 3 * sizeof(int32_t)); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(f2, y2), 31); + y2 = IVP_ADDSN_2X32(c2, t); + c1 = IVP_LSRN_2X32_I(pTbl, 4 * sizeof(int32_t)); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(f2, y1), 31); + y1 = IVP_ADDSN_2X32(c1, t); + c2 = IVP_LSRN_2X32_I(pTbl, 5 * sizeof(int32_t)); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(f2, y2), 31); + y2 = IVP_ADDSN_2X32(c2, t); + c1 = IVP_LSRN_2X32_I(pTbl, 6 * sizeof(int32_t)); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(f2, y1), 31); + y1 = IVP_ADDSN_2X32(c1, t); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(fr, y2), 31); + y = IVP_ADDSN_2X32(y1, t); + /* scale result to original exponent ignoring very low items */ + gf = IVP_FLOATN_2X32(y, 30); + exp = IVP_SLLIN_2X32(IVP_MAXN_2X32(IVP_ADDN_2X32(127, exp), 0), 23); + zout = IVP_MULN_2XF32(gf, IVP_MOVN_2XF32_FROMN_2X32(exp)); + x = zout; + } + IVP_ADDN_2XF32T(ysum, ysum, x, + IVP_LTRSN_2((N & (IVP_SIMD_WIDTH / 2 - 1)))); + IVP_SAVN_2XF32_XP(x, al_Y, pY, + sizeof(float32_t) * (N & (IVP_SIMD_WIDTH / 2 - 1))); + } + IVP_SAPOSN_2XF32_FP(al_Y, pY); + ysum = IVP_MOVN_2XF32T(qNaNf.f, ysum, bnan); + } + norm = XT_RECIP_S(IVP_RADDN_2XF32(ysum)); + __Pragma("no_reorder"); + pX = (const xb_vecN_2xf32 *)y; + pY = (xb_vecN_2xf32 *)y; + + al_R = IVP_LAN_2XF32_PP(pX); + + for (n = 0; n < (N >> (LOG2_IVP_SIMD_WIDTH - 1)); n++) { + xb_vecN_2xf32 x; + IVP_LAN_2XF32_IP(x, al_R, pX); + x = IVP_MULN_2XF32(x, norm); + IVP_SAN_2XF32_IP(x, al_Y, pY); + } + if (N & (IVP_SIMD_WIDTH / 2 - 1)) { + xb_vecN_2xf32 x; + IVP_LAVN_2XF32_XP(x, al_R, pX, + sizeof(float32_t) * (N & (IVP_SIMD_WIDTH / 2 - 1))); + x = IVP_MULN_2XF32(x, norm); + IVP_SAVN_2XF32_XP(x, al_Y, pY, + sizeof(float32_t) * (N & (IVP_SIMD_WIDTH / 2 - 1))); + } + IVP_SAPOSN_2XF32_FP(al_Y, pY); + +} /* vsoftmaxf() */ +#endif diff --git a/backends/cadence/vision/third-party/library/tables/expf_tbl.c b/backends/cadence/vision/third-party/library/tables/expf_tbl.c new file mode 100644 index 00000000000..0ed5dd22257 --- /dev/null +++ b/backends/cadence/vision/third-party/library/tables/expf_tbl.c @@ -0,0 +1,74 @@ +/* ------------------------------------------------------------------------ */ +/* Copyright (c) 2024 by Cadence Design Systems, Inc. ALL RIGHTS RESERVED. */ +/* These coded instructions, statements, and computer programs ('Cadence */ +/* Libraries') are the copyrighted works of Cadence Design Systems Inc. */ +/* Cadence IP is licensed for use with Cadence processor cores only and */ +/* must not be used for any other processors and platforms. Your use of the */ +/* Cadence Libraries is subject to the terms of the license agreement you */ +/* have entered into with Cadence Design Systems, or a sublicense granted */ +/* to you by a direct Cadence licensee. */ +/* ------------------------------------------------------------------------ */ +/* IntegrIT, Ltd. www.integrIT.com, info@integrIT.com */ +/* */ +/* NatureDSP_Baseband Library */ +/* */ +/* This library contains copyrighted materials, trade secrets and other */ +/* proprietary information of IntegrIT, Ltd. 
This software is licensed for */ +/* use with Cadence processor cores only and must not be used for any other */ +/* processors and platforms. The license to use these sources was given to */ +/* Cadence, Inc. under Terms and Condition of a Software License Agreement */ +/* between Cadence, Inc. and IntegrIT, Ltd. */ +/* ------------------------------------------------------------------------ */ +/* Copyright (C) 2009-2022 IntegrIT, Limited. */ +/* All Rights Reserved. */ +/* ------------------------------------------------------------------------ */ + +/* + tables for expf(x) approximation +*/ +/* Portable data types. */ +#include "expf_tbl.h" +#include "dtypes.h" + +/* + polynomial coefficients for 2^x in range 0...1 + + derived by MATLAB code: + order=6; + x=(0:pow2(1,-16):1); + y=2.^x; + p=polyfit(x,y,6); + p(order+1)=1; + p(order)=p(order)-(sum(p)-2); +*/ +const int32_t ALIGN_2SIMD expftbl_Q30[8] = { + 234841, 1329551, 10400465, 59570027, + 257946177, 744260763, 1073741824, 0 /* Padding to allow for vector loads */ +}; + +const union ufloat32uint32 ALIGN_2SIMD + expfminmax[2] = /* minimum and maximum arguments of expf() input */ + { + {0xc2ce8ed0}, /*-1.0327893066e+002f */ + {0x42b17218} /* 8.8722839355e+001f */ +}; + +const int32_t invln2_Q30 = 1549082005L; /* 1/ln(2), Q30 */ + +const union ufloat32uint32 ALIGN_2SIMD log2_e[2] = { + {0x3fb8aa3b}, /* 1.4426950216 */ + {0x32a57060} /* 1.9259629891e-008 */ +}; + +/* +order=6; +x=(0:pow2(1,-16):1); +y=2.^x; +p=polyfit(x,y,order); +p(order+1)=1; +p(order)=p(order)-(sum(p)-2); +num2hex(single(p)); +*/ +const union ufloat32uint32 ALIGN_2SIMD expftblf[] = { + {0x39655635}, {0x3aa24c7a}, {0x3c1eb2d1}, {0x3d633ddb}, + {0x3e75ff24}, {0x3f317212}, {0x3f800000}}; diff --git a/backends/cadence/vision/third-party/library/tables/inff_tbl.c b/backends/cadence/vision/third-party/library/tables/inff_tbl.c new file mode 100644 index 00000000000..9b2bf62e6bf --- /dev/null +++ b/backends/cadence/vision/third-party/library/tables/inff_tbl.c @@ -0,0 +1,38 @@ +/* ------------------------------------------------------------------------ */ +/* Copyright (c) 2024 by Cadence Design Systems, Inc. ALL RIGHTS RESERVED. */ +/* These coded instructions, statements, and computer programs ('Cadence */ +/* Libraries') are the copyrighted works of Cadence Design Systems Inc. */ +/* Cadence IP is licensed for use with Cadence processor cores only and */ +/* must not be used for any other processors and platforms. Your use of the */ +/* Cadence Libraries is subject to the terms of the license agreement you */ +/* have entered into with Cadence Design Systems, or a sublicense granted */ +/* to you by a direct Cadence licensee. */ +/* ------------------------------------------------------------------------ */ +/* IntegrIT, Ltd. www.integrIT.com, info@integrIT.com */ +/* */ +/* NatureDSP_Baseband Library */ +/* */ +/* This library contains copyrighted materials, trade secrets and other */ +/* proprietary information of IntegrIT, Ltd. This software is licensed for */ +/* use with Cadence processor cores only and must not be used for any other */ +/* processors and platforms. The license to use these sources was given to */ +/* Cadence, Inc. under Terms and Condition of a Software License Agreement */ +/* between Cadence, Inc. and IntegrIT, Ltd. */ +/* ------------------------------------------------------------------------ */ +/* Copyright (C) 2009-2022 IntegrIT, Limited. */ +/* All Rights Reserved. 
*/ +/* ------------------------------------------------------------------------ */ + +/* + infinities for single precision routines +*/ + +#include "inff_tbl.h" +#include "dtypes.h" + +const union ufloat32uint32 minusInff = {0xff800000}; /* -Inf */ +const union ufloat32uint32 plusInff = {0x7f800000}; /* +Inf */ +const union ufloat32uint32 realmaxf = { + 0x7f7fffff}; /* maximum floating point number */ +const union ufloat32uint32 realminf = { + 0x00800000}; /* minimum floating point number */ diff --git a/backends/cadence/vision/third-party/library/tables/nanf_tbl.c b/backends/cadence/vision/third-party/library/tables/nanf_tbl.c new file mode 100644 index 00000000000..27c5f437b9a --- /dev/null +++ b/backends/cadence/vision/third-party/library/tables/nanf_tbl.c @@ -0,0 +1,38 @@ +/* ------------------------------------------------------------------------ */ +/* Copyright (c) 2024 by Cadence Design Systems, Inc. ALL RIGHTS RESERVED. */ +/* These coded instructions, statements, and computer programs ('Cadence */ +/* Libraries') are the copyrighted works of Cadence Design Systems Inc. */ +/* Cadence IP is licensed for use with Cadence processor cores only and */ +/* must not be used for any other processors and platforms. Your use of the */ +/* Cadence Libraries is subject to the terms of the license agreement you */ +/* have entered into with Cadence Design Systems, or a sublicense granted */ +/* to you by a direct Cadence licensee. */ +/* ------------------------------------------------------------------------ */ +/* IntegrIT, Ltd. www.integrIT.com, info@integrIT.com */ +/* */ +/* NatureDSP_Baseband Library */ +/* */ +/* This library contains copyrighted materials, trade secrets and other */ +/* proprietary information of IntegrIT, Ltd. This software is licensed for */ +/* use with Cadence processor cores only and must not be used for any other */ +/* processors and platforms. The license to use these sources was given to */ +/* Cadence, Inc. under Terms and Condition of a Software License Agreement */ +/* between Cadence, Inc. and IntegrIT, Ltd. */ +/* ------------------------------------------------------------------------ */ +/* Copyright (C) 2009-2022 IntegrIT, Limited. */ +/* All Rights Reserved. */ +/* ------------------------------------------------------------------------ */ +/* + NaN values for single precision routines +*/ + +/* Portable data types. */ +#include "dtypes.h" +/* NaN values for single precision routines. */ +#include "nanf_tbl.h" + +const union ufloat32uint32 sNaNf = {0x7f800001}; /* Signalling NaN */ +const union ufloat32uint32 qNaNf = {0x7fc00000}; /* Quiet NaN */ +const union ufloat32uint32 minus_sNaNf = { + 0xff800001}; /* Negative Signalling NaN */ +const union ufloat32uint32 minus_qNaNf = {0xffc00000}; /* Negative Quiet NaN */ diff --git a/backends/cadence/vision/third-party/targets.bzl b/backends/cadence/vision/third-party/targets.bzl new file mode 100644 index 00000000000..6bbb7da8d49 --- /dev/null +++ b/backends/cadence/vision/third-party/targets.bzl @@ -0,0 +1,34 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
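Taken together, expf_tbl.c, inff_tbl.c and nanf_tbl.c supply the constants vsoftmaxf consumes: the Q30 polynomial for 2^x, 1/ln(2) in Q30 (invln2_Q30), a bit-exact -Inf for seeding the maximum reduction, and qNaNf for NaN propagation. The kernel itself follows the standard numerically stable softmax: subtract the per-vector maximum, exponentiate, accumulate the sum, then scale by the reciprocal of the sum. The scalar sketch below mirrors that three-pass flow, using libm's exp2f in place of the fixed-point polynomial purely for readability; it is illustrative only and not part of this patch.

    #include <math.h>

    /* Scalar counterpart of vsoftmaxf: same three passes, no SIMD, no Q30 math. */
    static void softmaxf_ref(float *y, const float *x, int n) {
      const float log2e = 1.4426950408889634f; /* 1/ln(2), cf. invln2_Q30 */
      float xmax = -INFINITY;
      float sum = 0.0f;
      int i;
      for (i = 0; i < n; i++)    /* pass 1: running maximum */
        xmax = (x[i] > xmax) ? x[i] : xmax;
      for (i = 0; i < n; i++) {  /* pass 2: exp(x - max) via 2^(t*log2(e)) */
        y[i] = exp2f((x[i] - xmax) * log2e);
        sum += y[i];
      }
      for (i = 0; i < n; i++)    /* pass 3: normalize by the sum */
        y[i] /= sum;
    }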
+ +load("@fbsource//tools/build_defs:platform_defs.bzl", "CXX") +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load("@fbsource//arvr/tools/build_defs:oxx.bzl", "oxx_binary", "oxx_static_library") + + +def define_common_targets(): + runtime.cxx_library( + name = "vision-nnlib", + srcs = select({ + "DEFAULT": ["dummy.c"], # Use dummy file for non-Xtensa builds + "ovr_config//cpu:xtensa": glob(["library/**/*.c"]), + }), + exported_headers = glob([ + "include/*.h", + "include_private/*.h" + ]), + header_namespace = "backends/cadence/vision/third-party", + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + platforms = CXX, + compatible_with = select({ + "DEFAULT": [], + "ovr_config//cpu:xtensa": ["ovr_config//cpu:xtensa"], + }), + compiler_flags = select({ + "DEFAULT": ["-UCOMPILER_XTENSA"], # Ensure COMPILER_XTENSA is not defined for non-Xtensa builds + "ovr_config//cpu:xtensa": ["-DCOMPILER_XTENSA"], + }), + define_static_target = True, + )
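For orientation, a hypothetical caller of the public api.h entry points added in this patch could look like the sketch below; the buffer names, sizes and the permutation are illustrative assumptions, not code from this change. Per the header comments, the input and output of vsoftmaxf must not overlap, and tensor_transposef takes the output shape, input shape and a permutation vector for up to five dimensions, with out_shape[i] equal to in_shape[perm[i]].

    #include "api.h" /* backends/cadence/vision/third-party/include/api.h */

    void example_vision_calls(void) {
      /* Softmax over 128 floats; y and x must not overlap. */
      static float32_t logits[128];
      static float32_t probs[128];
      vsoftmaxf(probs, logits, 128);

      /* Transpose a 2x3x4 tensor to 4x2x3, i.e. permutation {2, 0, 1}. */
      static float32_t src[2 * 3 * 4];
      static float32_t dst[4 * 2 * 3];
      const int in_shape[3]  = {2, 3, 4};
      const int out_shape[3] = {4, 2, 3};
      const int perm[3]      = {2, 0, 1};
      tensor_transposef(dst, out_shape, src, in_shape, perm,
                        /* num_out_dims */ 3, /* num_inp_dims */ 3);
    }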