From 9090b2cbd3a56db9aaf23ee31eac6fe2575662bb Mon Sep 17 00:00:00 2001
From: Hardik Sharma
Date: Mon, 1 Dec 2025 08:56:00 -0800
Subject: [PATCH 1/4] Fix -Wimplicit-const-int-float-conversion error. (#15982)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/15982

as titled

Differential Revision: D87882617

Reviewed By: skrtskrtfb, zonglinpeng
---
 backends/cadence/vision/kernels/kernels.cpp | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/backends/cadence/vision/kernels/kernels.cpp b/backends/cadence/vision/kernels/kernels.cpp
index 70c811df741..66cfcadbf13 100644
--- a/backends/cadence/vision/kernels/kernels.cpp
+++ b/backends/cadence/vision/kernels/kernels.cpp
@@ -11,7 +11,6 @@
 #include
 #include
 #include
-#include

 namespace impl {
 namespace vision {
@@ -25,10 +24,10 @@ void* allocate_temp_memory(KernelRuntimeContext& ctx, size_t size) {

 // Quantize a fp32 value to an int8_t/uint8_t value
 template <typename T>
 T quantize(const float x, float scale, int32_t zero_point) {
-  constexpr float min_val = std::numeric_limits<T>::min();
-  constexpr float max_val = std::numeric_limits<T>::max();
+  constexpr float kMinValue = static_cast<float>(std::numeric_limits<T>::min());
+  constexpr float kMaxValue = static_cast<float>(std::numeric_limits<T>::max());
   float tmp = roundf(x * scale + zero_point);
-  return std::max(std::min(tmp, max_val), min_val);
+  return std::max(std::min(tmp, kMaxValue), kMinValue);
 }

 // Quantize an fp32 array to an int8_t/uint8_t array

From 1d14aa98a04a609d21f12b10e657c4ddc6acf148 Mon Sep 17 00:00:00 2001
From: Hardik Sharma
Date: Mon, 1 Dec 2025 09:23:00 -0800
Subject: [PATCH 2/4] Add generic operator implementations. (#15983)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/15983

Migrates cadence custom operators to oss.

Differential Revision: D87880917

Reviewed By: ethansfng
---
 .../generic/operators/op_linalg_svd.cpp       | 365 ++++++++++++++++++
 .../cadence/generic/operators/op_linalg_svd.h |  36 ++
 .../operators/op_roi_align_box_processor.cpp  | 179 +++++++++
 .../operators/op_roi_align_box_processor.h    |  29 ++
 .../cadence/generic/operators/op_rope.cpp     |  65 ++++
 backends/cadence/generic/operators/op_rope.h  |  28 ++
 .../generic/operators/op_where_scalar.cpp     |  33 ++
 .../generic/operators/op_where_scalar.h       |  27 ++
 .../cadence/generic/operators/targets.bzl     |  65 ++++
 9 files changed, 827 insertions(+)
 create mode 100644 backends/cadence/generic/operators/op_linalg_svd.cpp
 create mode 100644 backends/cadence/generic/operators/op_linalg_svd.h
 create mode 100644 backends/cadence/generic/operators/op_roi_align_box_processor.cpp
 create mode 100644 backends/cadence/generic/operators/op_roi_align_box_processor.h
 create mode 100644 backends/cadence/generic/operators/op_rope.cpp
 create mode 100644 backends/cadence/generic/operators/op_rope.h
 create mode 100644 backends/cadence/generic/operators/op_where_scalar.cpp
 create mode 100644 backends/cadence/generic/operators/op_where_scalar.h

diff --git a/backends/cadence/generic/operators/op_linalg_svd.cpp b/backends/cadence/generic/operators/op_linalg_svd.cpp
new file mode 100644
index 00000000000..e84a8914f35
--- /dev/null
+++ b/backends/cadence/generic/operators/op_linalg_svd.cpp
@@ -0,0 +1,365 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */ + +#include + +#include +#include +#include + +#include +#include +#include + +const float EPSILON = 1e-10; +#ifndef M_PI +#define M_PI 3.14159265358979323846 +#endif + +namespace impl { +namespace generic { +namespace native { +namespace { + +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::Error; +using ::executorch::runtime::KernelRuntimeContext; + +// A simple 3x3 matrix struct. +struct Matrix3x3 { + float m[3][3]; +}; + +// Returns the 3x3 identity matrix. +Matrix3x3 identityMatrix() { + Matrix3x3 I{}; + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + I.m[i][j] = (i == j) ? 1.0 : 0.0; + } + } + return I; +} + +// Transposes matrix A. +Matrix3x3 transpose(const Matrix3x3& A) { + Matrix3x3 At{}; + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + At.m[i][j] = A.m[j][i]; + } + } + return At; +} + +// Multiplies matrices A and B. +Matrix3x3 multiply(const Matrix3x3& A, const Matrix3x3& B) { + Matrix3x3 C{}; + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + C.m[i][j] = 0.0; + for (int k = 0; k < 3; k++) { + C.m[i][j] += A.m[i][k] * B.m[k][j]; + } + } + } + return C; +} + +// Jacobi method to compute the eigen-decomposition of a symmetric 3x3 matrix A. +// It outputs the eigenvalues (in 'diag') and the eigenvectors as columns in V. +void jacobiEigenDecomposition(const Matrix3x3& A, float diag[3], Matrix3x3& V) { + Matrix3x3 D = A; // Make a copy; D will be transformed into a diagonal matrix. + V = identityMatrix(); + + // Iterate until convergence (or max iterations) + for (int iter = 0; iter < 100; iter++) { + // Find the largest off-diagonal element in D. + int p = 0, q = 1; + float maxOff = std::fabs(D.m[0][1]); + if (std::fabs(D.m[0][2]) > maxOff) { + maxOff = std::fabs(D.m[0][2]); + p = 0; + q = 2; + } + if (std::fabs(D.m[1][2]) > maxOff) { + maxOff = std::fabs(D.m[1][2]); + p = 1; + q = 2; + } + + if (maxOff < EPSILON) { + break; + } + + // Compute the Jacobi rotation angle. + float theta = 0.0; + if (std::fabs(D.m[p][p] - D.m[q][q]) < EPSILON) { + theta = M_PI / 4.0; + } else { + theta = 0.5 * std::atan2(2 * D.m[p][q], D.m[q][q] - D.m[p][p]); + } + + float c = std::cos(theta); + float s = std::sin(theta); + + // Update the diagonal elements. + float D_pp = c * c * D.m[p][p] - 2 * s * c * D.m[p][q] + s * s * D.m[q][q]; + float D_qq = s * s * D.m[p][p] + 2 * s * c * D.m[p][q] + c * c * D.m[q][q]; + D.m[p][p] = D_pp; + D.m[q][q] = D_qq; + D.m[p][q] = D.m[q][p] = 0.0; + + // Update the remaining elements. + for (int j = 0; j < 3; j++) { + if (j != p && j != q) { + float D_pj = c * D.m[p][j] - s * D.m[q][j]; + float D_qj = s * D.m[p][j] + c * D.m[q][j]; + D.m[p][j] = D.m[j][p] = D_pj; + D.m[q][j] = D.m[j][q] = D_qj; + } + } + + // Update the eigenvector matrix V. + for (int i = 0; i < 3; i++) { + float V_ip = c * V.m[i][p] - s * V.m[i][q]; + float V_iq = s * V.m[i][p] + c * V.m[i][q]; + V.m[i][p] = V_ip; + V.m[i][q] = V_iq; + } + } + + diag[0] = D.m[0][0]; + diag[1] = D.m[1][1]; + diag[2] = D.m[2][2]; +} + +// Sorts the eigenvalues (and the corresponding eigenvectors in V) in descending +// order. 
+void sortEigenDecomposition(float eigenvalues[3], Matrix3x3& V) { + int indices[3] = {0, 1, 2}; + std::sort(indices, indices + 3, [&](int a, int b) { + return eigenvalues[a] > eigenvalues[b]; + }); + + float sortedEigenvalues[3]; + Matrix3x3 sortedV{}; + for (int i = 0; i < 3; i++) { + sortedEigenvalues[i] = eigenvalues[indices[i]]; + for (int j = 0; j < 3; j++) { + sortedV.m[j][i] = V.m[j][indices[i]]; + } + } + for (int i = 0; i < 3; i++) { + eigenvalues[i] = sortedEigenvalues[i]; + for (int j = 0; j < 3; j++) { + V.m[j][i] = sortedV.m[j][i]; + } + } +} + +// Computes the cross product of two 3D vectors. +void crossProduct(const float a[3], const float b[3], float result[3]) { + result[0] = a[1] * b[2] - a[2] * b[1]; + result[1] = a[2] * b[0] - a[0] * b[2]; + result[2] = a[0] * b[1] - a[1] * b[0]; +} + +// Normalizes a 3D vector. +void normalize(float v[3]) { + float norm = std::sqrt(v[0] * v[0] + v[1] * v[1] + v[2] * v[2]); + if (norm > EPSILON) { + v[0] /= norm; + v[1] /= norm; + v[2] /= norm; + } +} + +// Computes the singular value decomposition of A such that A = U * S * Vt. +// U and Vt are orthogonal matrices and S is a diagonal matrix with singular +// values. +std::tuple svd(const Matrix3x3& A) { + // Compute A^T * A (which is symmetric). + Matrix3x3 At = transpose(A); + Matrix3x3 ATA = multiply(At, A); + + // Compute the eigen-decomposition of ATA. + float eigenvalues[3]; + Matrix3x3 V{}; + jacobiEigenDecomposition(ATA, eigenvalues, V); + sortEigenDecomposition(eigenvalues, V); + + // The singular values are the square roots of the eigenvalues. + float sigma[3]; + for (int i = 0; i < 3; i++) { + sigma[i] = (eigenvalues[i] > 0.0) ? std::sqrt(eigenvalues[i]) : 0.0; + } + + // Compute U = A * V * S^{-1}. + Matrix3x3 U{}; + for (int i = 0; i < 3; i++) { + float av[3] = {0, 0, 0}; + // Multiply A by the i-th eigenvector (column of V). + for (int r = 0; r < 3; r++) { + for (int c = 0; c < 3; c++) { + av[r] += A.m[r][c] * V.m[c][i]; + } + } + if (sigma[i] > EPSILON) { + for (int r = 0; r < 3; r++) { + U.m[r][i] = av[r] / sigma[i]; + } + } else { + // If sigma[i] is nearly zero, choose a vector orthogonal to the previous + // ones. + float vec[3] = {0, 0, 0}; + if (i == 1) { + float u0[3] = {U.m[0][0], U.m[1][0], U.m[2][0]}; + float tmp[3] = {1, 0, 0}; + float dot = u0[0] * tmp[0] + u0[1] * tmp[1] + u0[2] * tmp[2]; + if (std::fabs(dot) > 0.9) { + tmp[0] = 0; + tmp[1] = 1; + tmp[2] = 0; + } + crossProduct(u0, tmp, vec); + } else if (i == 2) { + float u0[3] = {U.m[0][0], U.m[1][0], U.m[2][0]}; + float u1[3] = {U.m[0][1], U.m[1][1], U.m[2][1]}; + crossProduct(u0, u1, vec); + } + normalize(vec); + for (int r = 0; r < 3; r++) { + U.m[r][i] = vec[r]; + } + } + } + + // Construct the diagonal S matrix. + Matrix3x3 S{}; + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + S.m[i][j] = (i == j) ? sigma[i] : 0.0; + } + } + + // Vt is the transpose of V. 
+ Matrix3x3 Vt = transpose(V); + + return std::make_tuple(U, S, Vt); +} +} // namespace + +std::tuple linalg_svd_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& A, + bool full_matrices, + bool compute_uv, + ::executorch::aten::optional<::executorch::aten::string_view> driver, + Tensor& U, + Tensor& S, + Tensor& Vh) { + std::tuple ret_val(U, S, Vh); + + ET_KERNEL_CHECK_MSG( + ctx, + A.scalar_type() == ScalarType::Float, + InvalidArgument, + ret_val, + "input.dtype(): %s must be %s", + ::torch::executor::toString(A.scalar_type()), + ::torch::executor::toString(ScalarType::Float)); + + ET_KERNEL_CHECK_MSG( + ctx, A.numel() > 0, InvalidArgument, ret_val, "input.size() must be > 0"); + + ET_KERNEL_CHECK_MSG( + ctx, + A.numel() % 9 == 0, + InvalidArgument, + ret_val, + "SVD of only 3x3 matrix is supported! Expected the input to have (batch_size x 9) number of elements, but received %d elements instead", + int(A.numel())); + + int batch_size = A.numel() / 9; + + ET_KERNEL_CHECK_MSG( + ctx, + U.numel() / 9 == batch_size, + InvalidArgument, + ret_val, + "Output tensor U must have the same batch_size as input: %d, but got: %d instead", + batch_size, + int(U.numel() / 9)); + + ET_KERNEL_CHECK_MSG( + ctx, + S.numel() / 3 == batch_size, + InvalidArgument, + ret_val, + "Output tensor S must have the same batch_size as input: %d, but got: %d instead", + batch_size, + int(S.numel() / 3)); + + ET_KERNEL_CHECK_MSG( + ctx, + Vh.numel() / 9 == batch_size, + InvalidArgument, + ret_val, + "Output tensor Vh must have the same batch_size as input: %d, but got: %d instead", + batch_size, + int(Vh.numel() / 9)); + + const float* A_data = A.const_data_ptr(); + float* U_data = U.mutable_data_ptr(); + float* S_data = S.mutable_data_ptr(); + float* Vh_data = Vh.mutable_data_ptr(); + + for (int i = 0; i < batch_size; i++) { + int offset = i * 9; + Matrix3x3 A_mat = {{ + {A_data[offset + 0], A_data[offset + 1], A_data[offset + 2]}, + {A_data[offset + 3], A_data[offset + 4], A_data[offset + 5]}, + {A_data[offset + 6], A_data[offset + 7], A_data[offset + 8]}, + }}; + + Matrix3x3 U_mat{}, S_mat{}, Vh_mat{}; + std::tie(U_mat, S_mat, Vh_mat) = svd(A_mat); + + U_data[offset + 0] = U_mat.m[0][0]; + U_data[offset + 1] = U_mat.m[0][1]; + U_data[offset + 2] = U_mat.m[0][2]; + U_data[offset + 3] = U_mat.m[1][0]; + U_data[offset + 4] = U_mat.m[1][1]; + U_data[offset + 5] = U_mat.m[1][2]; + U_data[offset + 6] = U_mat.m[2][0]; + U_data[offset + 7] = U_mat.m[2][1]; + U_data[offset + 8] = U_mat.m[2][2]; + + S_data[offset + 0] = S_mat.m[0][0]; + S_data[offset + 1] = S_mat.m[1][1]; + S_data[offset + 2] = S_mat.m[2][2]; + + Vh_data[offset + 0] = Vh_mat.m[0][0]; + Vh_data[offset + 1] = Vh_mat.m[0][1]; + Vh_data[offset + 2] = Vh_mat.m[0][2]; + Vh_data[offset + 3] = Vh_mat.m[1][0]; + Vh_data[offset + 4] = Vh_mat.m[1][1]; + Vh_data[offset + 5] = Vh_mat.m[1][2]; + Vh_data[offset + 6] = Vh_mat.m[2][0]; + Vh_data[offset + 7] = Vh_mat.m[2][1]; + Vh_data[offset + 8] = Vh_mat.m[2][2]; + } + + return ret_val; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_linalg_svd.h b/backends/cadence/generic/operators/op_linalg_svd.h new file mode 100644 index 00000000000..7635276c4f5 --- /dev/null +++ b/backends/cadence/generic/operators/op_linalg_svd.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +std::tuple< + ::executorch::aten::Tensor&, + ::executorch::aten::Tensor&, + ::executorch::aten::Tensor&> +linalg_svd_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& A, + bool full_matrices, + bool compute_uv, + ::executorch::aten::optional<::executorch::aten::string_view> driver, + ::executorch::aten::Tensor& U, + ::executorch::aten::Tensor& S, + ::executorch::aten::Tensor& Vh); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_roi_align_box_processor.cpp b/backends/cadence/generic/operators/op_roi_align_box_processor.cpp new file mode 100644 index 00000000000..a2f7fdfde7c --- /dev/null +++ b/backends/cadence/generic/operators/op_roi_align_box_processor.cpp @@ -0,0 +1,179 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace impl { +namespace generic { +namespace native { +namespace { + +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +using UnpackedVec = std::array; +using PackedVec = std::array; +using IterVec = std::array; + +IterVec computeAddrIncr(const IterVec& shape, const IterVec& strides) { + auto rank = shape.size(); + auto inc = strides; + for (int n = 1; n < static_cast(rank); ++n) { + inc[n] = strides[n] - strides[n - 1] * shape[n - 1] + inc[n - 1]; + } + return inc; +} + +template +PackedVec packTuringVals(const UnpackedVec& vals, bool is_signed) { + PackedVec result{}; + int bitPos = 0; // bit position in output vector + for (int v : vals) { + assert(is_signed || v >= 0); + if (is_signed) { + assert( + v >= -(1 << (perItemBitWidth - 1)) && + v < (1 << (perItemBitWidth - 1))); + } else { + assert(v < (1 << perItemBitWidth)); + } + + if (v < 0) { + v = (1 << perItemBitWidth) + v; + } + + // Extract bit by bit and store in the output array + for (int bit = 0; bit < perItemBitWidth; ++bit) { + auto outBitIndex = bitPos + bit; + auto byteIndex = outBitIndex / 8; + auto bitInByte = outBitIndex % 8; + // Extract bit from val + uint8_t bitVal = (v >> bit) & 1; + // Set bit in output byte + result[byteIndex] |= (bitVal << bitInByte); + } + bitPos += perItemBitWidth; + } + assert(bitPos == vals.size() * perItemBitWidth); + return result; +} + +template +constexpr int get_fp_scale() { + return 1 << frac_bits; +} + +template +int convert_to_S13(float fp) { + return int(std::round(fp * get_fp_scale())); +} + +PackedVec convertBoxPosToTuringConfig( + float topLeftX, + float topLeftY, + float bottomRightX, + float bottomRightY, + int roiAlignNumBoxes, + int output_size_h, + int output_size_w, + int sampling_ratio, + bool aligned) { + constexpr int precisionMode = 0; + auto dstImgH = output_size_h * sampling_ratio; + auto dstImgW = output_size_w * sampling_ratio; + auto dstTileH = dstImgH; + auto dstTileW = dstImgW; + + float stepX = (bottomRightX - topLeftX) / dstImgW; + float stepY = (bottomRightY - topLeftY) / dstImgH; + + if (aligned) { + topLeftX -= 0.5; + topLeftY -= 0.5; + } + + auto anchorX = convert_to_S13(topLeftX + stepX / 2); + auto anchorY = 
convert_to_S13(topLeftY + stepY / 2); + + UnpackedVec vals{}; + vals[0] = anchorX; + vals[1] = anchorY; + + IterVec shape = {dstTileW, dstTileH, 1, 1, 1, roiAlignNumBoxes}; + auto addrIncrementX = computeAddrIncr( + shape, + {convert_to_S13(stepX), + 0, + convert_to_S13(stepX * dstTileW), + 0, + 0, + 0}); + auto addrIncrementY = computeAddrIncr( + shape, + {0, + convert_to_S13(stepY), + 0, + convert_to_S13(stepY * dstTileH), + 0, + 0}); + + for (int i = 0; i < 10; ++i) { + vals[i + 2] = i < addrIncrementX.size() + ? addrIncrementX[i] + : addrIncrementX[addrIncrementX.size() - 1]; + vals[i + 12] = i < addrIncrementY.size() + ? addrIncrementY[i] + : addrIncrementY[addrIncrementY.size() - 1]; + } + + return packTuringVals(vals, true); +} + +} // namespace + +Tensor& roi_align_box_processor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& rois, + int64_t output_size_h, + int64_t output_size_w, + int64_t sampling_ratio, + bool aligned, + Tensor& out) { + int K = static_cast(rois.size(0)); + auto roi = rois.const_data_ptr(); + for (int i = 0; i < K; ++i) { + assert( + static_cast(roi[i * 5]) == 0 && "Only support 1 image for now."); + auto x1 = roi[i * 5 + 1]; + auto y1 = roi[i * 5 + 2]; + auto x2 = roi[i * 5 + 3]; + auto y2 = roi[i * 5 + 4]; + auto turing_roi = convertBoxPosToTuringConfig( + x1, + y1, + x2, + y2, + static_cast(K), + static_cast(output_size_h), + static_cast(output_size_w), + static_cast(sampling_ratio), + aligned); + static_assert(turing_roi.size() == 80); + + auto out_ptr = out.mutable_data_ptr() + i * turing_roi.size(); + for (auto val : turing_roi) { + *out_ptr++ = val; + } + } + return out; +} +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_roi_align_box_processor.h b/backends/cadence/generic/operators/op_roi_align_box_processor.h new file mode 100644 index 00000000000..34d07d7b700 --- /dev/null +++ b/backends/cadence/generic/operators/op_roi_align_box_processor.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& roi_align_box_processor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& rois, + int64_t output_size_h, + int64_t output_size_w, + int64_t sampling_ratio, + bool aligned, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_rope.cpp b/backends/cadence/generic/operators/op_rope.cpp new file mode 100644 index 00000000000..f47e9752125 --- /dev/null +++ b/backends/cadence/generic/operators/op_rope.cpp @@ -0,0 +1,65 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "executorch/backends/cadence/generic/operators/op_rope.h" + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::optional; +using ::executorch::aten::Tensor; + +Tensor& rope_out( + ET_UNUSED ::executorch::runtime::KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& sin_tensor, + const Tensor& cos_tensor, + const optional& pos, + Tensor& out) { + // Input shape is [1, seq, h, hd / 2, 2] or [1, seq, h, hd] + const auto kSeq = input.size(1); + const auto kH = input.size(2); + const auto kHd = input.numel() / (kSeq * kH); + for (int32_t s = 0; s < kSeq; ++s) { + for (int32_t h = 0; h < kH; ++h) { + for (int32_t hd_o = 0; hd_o < kHd / 2; ++hd_o) { + float x_0 = + input.const_data_ptr()[s * kH * kHd + h * kHd + hd_o * 2]; + float x_1 = + input + .const_data_ptr()[s * kH * kHd + h * kHd + hd_o * 2 + 1]; + int64_t token_id = s; + if (pos.has_value()) { + if (pos->scalar_type() == ::executorch::aten::ScalarType::Int) { + token_id = pos.has_value() ? pos->const_data_ptr()[s] : s; + } else { + token_id = pos.has_value() ? pos->const_data_ptr()[s] : s; + } + } + float sin = + sin_tensor.const_data_ptr()[token_id * kHd / 2 + hd_o]; + float cos = + cos_tensor.const_data_ptr()[token_id * kHd / 2 + hd_o]; + + float out_0 = x_0 * cos - x_1 * sin; + float out_1 = x_0 * sin + x_1 * cos; + out.mutable_data_ptr()[s * kH * kHd + h * kHd + hd_o * 2] = + out_0; + out.mutable_data_ptr()[s * kH * kHd + h * kHd + hd_o * 2 + 1] = + out_1; + } + } + } + + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_rope.h b/backends/cadence/generic/operators/op_rope.h new file mode 100644 index 00000000000..cdd4db1be0f --- /dev/null +++ b/backends/cadence/generic/operators/op_rope.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& rope_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + const ::executorch::aten::Tensor& sin_tensor, + const ::executorch::aten::Tensor& cos_tensor, + const ::executorch::aten::optional<::executorch::aten::Tensor>& pos, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_where_scalar.cpp b/backends/cadence/generic/operators/op_where_scalar.cpp new file mode 100644 index 00000000000..33c3afdc9f0 --- /dev/null +++ b/backends/cadence/generic/operators/op_where_scalar.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& where_Scalar_out( + ET_UNUSED ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& condition, + const double val1, + const double val2, + ::executorch::aten::Tensor& out) { + const float val1_f = static_cast(val1); + const float val2_f = static_cast(val2); + for (int i = 0; i < out.numel(); ++i) { + out.mutable_data_ptr()[i] = + condition.const_data_ptr()[i] ? val1_f : val2_f; + } + + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_where_scalar.h b/backends/cadence/generic/operators/op_where_scalar.h new file mode 100644 index 00000000000..80a0b8054d5 --- /dev/null +++ b/backends/cadence/generic/operators/op_where_scalar.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& where_Scalar_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& condition, + double val1, + double val2, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/targets.bzl b/backends/cadence/generic/operators/targets.bzl index fa0f128b229..a04dd70f53e 100644 --- a/backends/cadence/generic/operators/targets.bzl +++ b/backends/cadence/generic/operators/targets.bzl @@ -183,6 +183,71 @@ def define_common_targets(): ], ) + runtime.cxx_library( + name = "op_where_scalar", + srcs = ["op_where_scalar.cpp"], + exported_headers = ["op_where_scalar.h", "operators.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/kernel:kernel_runtime_context", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_rope", + srcs = ["op_rope.cpp"], + exported_headers = ["op_rope.h", "operators.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/kernel:kernel_runtime_context", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_linalg_svd", + srcs = ["op_linalg_svd.cpp"], + headers = ["op_linalg_svd.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/core/exec_aten/util:tensor_util", + "//executorch/runtime/kernel:kernel_runtime_context", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_roi_align_box_processor", + srcs = ["op_roi_align_box_processor.cpp"], + exported_headers = ["op_roi_align_box_processor.h", "operators.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/kernel:kernel_runtime_context", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + # Combined target for backward 
compatibility # NOTE: cadence_aot_lib now uses individual targets directly for better linking runtime.cxx_library( From 907a468aa73b9cdab6eb23518d1d9933cf5bd1f5 Mon Sep 17 00:00:00 2001 From: Hardik Sharma Date: Mon, 1 Dec 2025 10:44:55 -0800 Subject: [PATCH 3/4] Rename generic operator cpp/headers. Summary: Stick to `op_` convention for operator library. Differential Revision: D88079745 --- ...ensor.cpp => op_dequantize_per_tensor.cpp} | 2 + .../operators/op_dequantize_per_tensor.h | 30 +++++++++++++ .../{im2row_out.cpp => op_im2row.cpp} | 27 ++++++++---- .../cadence/generic/operators/op_im2row.h | 42 +++++++++++++++++++ ..._tensor.cpp => op_quantize_per_tensor.cpp} | 2 + .../operators/op_quantize_per_tensor.h | 29 +++++++++++++ .../cadence/generic/operators/targets.bzl | 13 +++--- 7 files changed, 130 insertions(+), 15 deletions(-) rename backends/cadence/generic/operators/{dequantize_per_tensor.cpp => op_dequantize_per_tensor.cpp} (98%) create mode 100644 backends/cadence/generic/operators/op_dequantize_per_tensor.h rename backends/cadence/generic/operators/{im2row_out.cpp => op_im2row.cpp} (95%) create mode 100644 backends/cadence/generic/operators/op_im2row.h rename backends/cadence/generic/operators/{quantize_per_tensor.cpp => op_quantize_per_tensor.cpp} (98%) create mode 100644 backends/cadence/generic/operators/op_quantize_per_tensor.h diff --git a/backends/cadence/generic/operators/dequantize_per_tensor.cpp b/backends/cadence/generic/operators/op_dequantize_per_tensor.cpp similarity index 98% rename from backends/cadence/generic/operators/dequantize_per_tensor.cpp rename to backends/cadence/generic/operators/op_dequantize_per_tensor.cpp index ec05272da1b..a17fa48778f 100644 --- a/backends/cadence/generic/operators/dequantize_per_tensor.cpp +++ b/backends/cadence/generic/operators/op_dequantize_per_tensor.cpp @@ -6,6 +6,8 @@ * LICENSE file in the root directory of this source tree. */ +#include + #include #include diff --git a/backends/cadence/generic/operators/op_dequantize_per_tensor.h b/backends/cadence/generic/operators/op_dequantize_per_tensor.h new file mode 100644 index 00000000000..b4965c4e7ab --- /dev/null +++ b/backends/cadence/generic/operators/op_dequantize_per_tensor.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& dequantize_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& context, + const ::executorch::aten::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ::executorch::aten::ScalarType dtype, + ::executorch::aten::Tensor& out); + +} +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/im2row_out.cpp b/backends/cadence/generic/operators/op_im2row.cpp similarity index 95% rename from backends/cadence/generic/operators/im2row_out.cpp rename to backends/cadence/generic/operators/op_im2row.cpp index c79a7bf1b7f..8c939c7ad5c 100644 --- a/backends/cadence/generic/operators/im2row_out.cpp +++ b/backends/cadence/generic/operators/op_im2row.cpp @@ -6,10 +6,18 @@ * LICENSE file in the root directory of this source tree. 
*/ -#include +#include #include +#include + +#ifndef DISABLE_ALWAYS_INLINE +#define ALWAYS_INLINE __attribute__((always_inline)) +#else +#define ALWAYS_INLINE inline +#endif + namespace impl { namespace generic { namespace native { @@ -20,7 +28,7 @@ using ::executorch::aten::Tensor; using ::executorch::runtime::KernelRuntimeContext; template -__attribute__((always_inline)) void im2row_( +ALWAYS_INLINE void im2row_( const T* __restrict__ data_im, const int32_t in_zero_point, /* input parameters*/ @@ -76,7 +84,7 @@ __attribute__((always_inline)) void im2row_( // 'channels' contiguous values. Otherwise we will fill the output // with 0's. if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { - std::memcpy(slice_col, slice_im, channels * sizeof(T)); + memcpy(slice_col, slice_im, channels * sizeof(T)); } else { std::fill_n(slice_col, channels, T(in_zero_point)); } @@ -115,8 +123,8 @@ __attribute__((always_inline)) void im2row_( } } -void im2row_out( - __ET_UNUSED KernelRuntimeContext& ctx, +Tensor& im2row_out( + ET_UNUSED KernelRuntimeContext& ctx, const Tensor& input, IntArrayRef kernel_size, IntArrayRef dilation, @@ -170,7 +178,7 @@ void im2row_out( in_zero_point.const_data_ptr(); \ int32_t in_plane = in_c * in_h * in_w; \ int32_t out_plane = kernel_h * kernel_w * in_c * out_h * out_w; \ - for (size_t n = 0; n < batch_size; ++n) { \ + for (int32_t n = 0; n < batch_size; ++n) { \ im2row_( \ &in_data[n * in_plane], \ per_tensor_quantized ? zero_point[0] : zero_point[n], \ @@ -205,10 +213,12 @@ void im2row_out( torch::executor::toString(dtype)); } #undef typed_im2row + + return out; } -void im2row_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, +Tensor& im2row_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, const Tensor& input, IntArrayRef kernel_size, IntArrayRef dilation, @@ -291,6 +301,7 @@ void im2row_per_tensor_out( torch::executor::toString(dtype)); } #undef typed_im2row_per_tensor + return out; } } // namespace native diff --git a/backends/cadence/generic/operators/op_im2row.h b/backends/cadence/generic/operators/op_im2row.h new file mode 100644 index 00000000000..0eed36c9139 --- /dev/null +++ b/backends/cadence/generic/operators/op_im2row.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& im2row_out( + __ET_UNUSED ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + ::executorch::aten::IntArrayRef kernel_size, + ::executorch::aten::IntArrayRef dilation, + ::executorch::aten::IntArrayRef padding, + ::executorch::aten::IntArrayRef stride, + const ::executorch::aten::Tensor& in_zero_point, + bool channel_last, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& im2row_per_tensor_out( + __ET_UNUSED ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + ::executorch::aten::IntArrayRef kernel_size, + ::executorch::aten::IntArrayRef dilation, + ::executorch::aten::IntArrayRef padding, + ::executorch::aten::IntArrayRef stride, + int64_t in_zero_point, + bool channel_last, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/quantize_per_tensor.cpp b/backends/cadence/generic/operators/op_quantize_per_tensor.cpp similarity index 98% rename from backends/cadence/generic/operators/quantize_per_tensor.cpp rename to backends/cadence/generic/operators/op_quantize_per_tensor.cpp index 8ce70d2b51d..51a0a1d0f9d 100644 --- a/backends/cadence/generic/operators/quantize_per_tensor.cpp +++ b/backends/cadence/generic/operators/op_quantize_per_tensor.cpp @@ -6,6 +6,8 @@ * LICENSE file in the root directory of this source tree. */ +#include + #include #include diff --git a/backends/cadence/generic/operators/op_quantize_per_tensor.h b/backends/cadence/generic/operators/op_quantize_per_tensor.h new file mode 100644 index 00000000000..c1e826a7cf9 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantize_per_tensor.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& quantize_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& context, + const ::executorch::aten::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ::executorch::aten::ScalarType dtype, + ::executorch::aten::Tensor& out); +} +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/targets.bzl b/backends/cadence/generic/operators/targets.bzl index a04dd70f53e..0a6859c64b0 100644 --- a/backends/cadence/generic/operators/targets.bzl +++ b/backends/cadence/generic/operators/targets.bzl @@ -6,8 +6,8 @@ def define_common_targets(): runtime.cxx_library( name = "im2row_out", - srcs = ["im2row_out.cpp"], - exported_headers = ["operators.h"], + srcs = ["op_im2row.cpp"], + exported_headers = ["op_im2row.h"], platforms = CXX, deps = [ "//executorch/runtime/kernel:kernel_includes", @@ -32,11 +32,10 @@ def define_common_targets(): ], ) - # Quantized operators that need cadence kernels for quantize/dequantize runtime.cxx_library( name = "dequantize_per_tensor", - srcs = ["dequantize_per_tensor.cpp"], - exported_headers = ["quantized_ops.h"], + srcs = ["op_dequantize_per_tensor.cpp"], + exported_headers = ["op_dequantize_per_tensor.h"], platforms = CXX, deps = [ "//executorch/runtime/kernel:kernel_includes", @@ -50,8 +49,8 @@ def define_common_targets(): runtime.cxx_library( name = "quantize_per_tensor", - srcs = ["quantize_per_tensor.cpp"], - exported_headers = ["quantized_ops.h"], + srcs = ["op_quantize_per_tensor.cpp"], + exported_headers = ["op_quantize_per_tensor.h"], platforms = CXX, deps = [ "//executorch/runtime/kernel:kernel_includes", From 7fd66d26ad30fbd58a299d1053b75cf3c076a2d5 Mon Sep 17 00:00:00 2001 From: Hardik Sharma Date: Mon, 1 Dec 2025 14:20:47 -0800 Subject: [PATCH 4/4] Migrate generic cadence operators to oss. 
(#16023) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/16023 * op_avg_pool2d * op_softmax * op_conv*d Reviewed By: mcremon-meta Differential Revision: D88081853 --- .../generic/operators/op_avg_pool2d.cpp | 156 ++++++++++++++++++ .../cadence/generic/operators/op_avg_pool2d.h | 34 ++++ .../cadence/generic/operators/op_conv1d.cpp | 104 ++++++++++++ .../cadence/generic/operators/op_conv1d.h | 35 ++++ .../cadence/generic/operators/op_conv2d.cpp | 122 ++++++++++++++ .../cadence/generic/operators/op_conv2d.h | 35 ++++ .../cadence/generic/operators/op_conv3d.cpp | 139 ++++++++++++++++ .../cadence/generic/operators/op_conv3d.h | 35 ++++ .../cadence/generic/operators/op_softmax.cpp | 137 +++++++++++++++ .../cadence/generic/operators/op_softmax.h | 34 ++++ .../cadence/generic/operators/targets.bzl | 77 +++++++-- 11 files changed, 896 insertions(+), 12 deletions(-) create mode 100644 backends/cadence/generic/operators/op_avg_pool2d.cpp create mode 100644 backends/cadence/generic/operators/op_avg_pool2d.h create mode 100644 backends/cadence/generic/operators/op_conv1d.cpp create mode 100644 backends/cadence/generic/operators/op_conv1d.h create mode 100644 backends/cadence/generic/operators/op_conv2d.cpp create mode 100644 backends/cadence/generic/operators/op_conv2d.h create mode 100644 backends/cadence/generic/operators/op_conv3d.cpp create mode 100644 backends/cadence/generic/operators/op_conv3d.h create mode 100644 backends/cadence/generic/operators/op_softmax.cpp create mode 100644 backends/cadence/generic/operators/op_softmax.h diff --git a/backends/cadence/generic/operators/op_avg_pool2d.cpp b/backends/cadence/generic/operators/op_avg_pool2d.cpp new file mode 100644 index 00000000000..936a5583d8c --- /dev/null +++ b/backends/cadence/generic/operators/op_avg_pool2d.cpp @@ -0,0 +1,156 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::optional; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::getLeadingDims; +using ::executorch::runtime::KernelRuntimeContext; + +// Compute the avg_pool2d for in_data in NCHW layout. IT is the input datatype, +// and AT is the accumulation datatype. 'quantized' is true when the input is +// quantized tensor. 
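+// For each output element, the kernel averages the input values under the
+// kh x kw window. When count_include_pad is true, padded (out-of-bounds)
+// positions contribute in_zero_point and the divisor stays kh * kw (or the
+// divisor override); otherwise only the in-bounds values are averaged.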
+template <typename IT, typename AT, bool quantized>
+void avg_pool2d_nchw(
+    const IT* __restrict__ in_data,
+    const int32_t in_zero_point,
+    IT* __restrict__ out_data,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    bool count_include_pad,
+    int64_t divisor,
+    int leading_dims,
+    int ih,
+    int iw,
+    int oh,
+    int ow) {
+  int kh = kernel_size[0];
+  int kw = kernel_size[1];
+  int s0 = stride[0];
+  int s1 = stride[1];
+  int p0 = padding[0];
+  int p1 = padding[1];
+
+  for (int _n = 0; _n < leading_dims; ++_n) {
+    for (int _ih = 0, _oh = 0; _oh < oh; ++_oh, _ih += s0) {
+      int input_offset = _n * ih * iw;
+      int output_offset = _n * oh * ow + _oh * ow;
+      for (int _iw = 0, _ow = 0; _ow < ow; ++_ow, _iw += s1) {
+        int kh_lo = std::max(0, _ih - p0);
+        int kh_hi = std::min(ih, _ih + kh - p0);
+        int kw_lo = std::max(0, _iw - p1);
+        int kw_hi = std::min(iw, _iw + kw - p1);
+        // Count the number of contributions sans padding
+        int count = (kh_hi - kh_lo) * (kw_hi - kw_lo);
+        // Set the accumulator
+        AT acc = count_include_pad ? in_zero_point * (kh * kw - count) : 0;
+        // Accumulate values
+        for (int _kh = kh_lo; _kh < kh_hi; ++_kh) {
+          for (int _kw = kw_lo; _kw < kw_hi; ++_kw) {
+            int input_addr = input_offset + _kh * iw + _kw;
+            acc += in_data[input_addr];
+          }
+        }
+        // The divisor changes depending on whether the count includes
+        // padded cells or not.
+        float inv_divisor = 1. / (count_include_pad ? divisor : count);
+        float val = acc * inv_divisor;
+        if (quantized) {
+          int32_t min_val =
+              static_cast<int32_t>(std::numeric_limits<IT>::min());
+          int32_t max_val =
+              static_cast<int32_t>(std::numeric_limits<IT>::max());
+          out_data[output_offset + _ow] = std::min(
+              std::max(int32_t(std::nearbyint(val)), min_val), max_val);
+        } else {
+          out_data[output_offset + _ow] = val;
+        }
+      }
+    }
+  }
+}
+
+Tensor& avg_pool2d_out(
+    ET_UNUSED KernelRuntimeContext& ctx,
+    const Tensor& input,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    bool ceil_mode,
+    bool count_include_pad,
+    optional<int64_t> divisor_override,
+    const optional<Tensor>& in_zero_point_t,
+    bool channel_last,
+    Tensor& out) {
+  ET_DCHECK_MSG(!channel_last, "NHWC layout for avg_pool2d not yet supported");
+  const int32_t in_zero_point = in_zero_point_t.has_value()
+      ? in_zero_point_t.value().const_data_ptr<int32_t>()[0]
+      : 0;
+  const int64_t divisor = divisor_override.has_value()
+      ? divisor_override.value()
+      : kernel_size[0] * kernel_size[1];
+
+  const int odim = out.dim();
+  const int on = getLeadingDims(out, odim - 2);
+  const int oh = out.size(odim - 2);
+  const int ow = out.size(odim - 1);
+  const int ih = input.size(odim - 2);
+  const int iw = input.size(odim - 1);
+
+  // We generate the kernel for float and uint8_t types. The operator also
+  // works for double, but does not support other dtypes.
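+  // Illustration only (this assumes avg_pool2d_nchw takes <IT, AT, quantized>
+  // as template parameters): the Float case of the macro below expands to
+  // roughly
+  //   case ScalarType::Float: {
+  //     avg_pool2d_nchw<float, float, false>(
+  //         input.const_data_ptr<float>(), in_zero_point,
+  //         out.mutable_data_ptr<float>(), kernel_size, stride, padding,
+  //         count_include_pad, divisor, on, ih, iw, oh, ow);
+  //     break;
+  //   }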
+#define typed_avg_pool2d(btype, ctype, quantized, dtype) \ + case ScalarType::dtype: { \ + avg_pool2d_nchw( \ + input.const_data_ptr(), \ + in_zero_point, \ + out.mutable_data_ptr(), \ + kernel_size, \ + stride, \ + padding, \ + count_include_pad, \ + divisor, \ + on, \ + ih, \ + iw, \ + oh, \ + ow); \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + typed_avg_pool2d(float, float, false, Float); + typed_avg_pool2d(uint8_t, int32_t, true, Byte); + default: + ET_DCHECK_MSG( + false, + "avg_pool2d not implemented for dtype %s", + torch::executor::toString(dtype)); + } + + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_avg_pool2d.h b/backends/cadence/generic/operators/op_avg_pool2d.h new file mode 100644 index 00000000000..05f1810bb61 --- /dev/null +++ b/backends/cadence/generic/operators/op_avg_pool2d.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& avg_pool2d_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + ::executorch::aten::IntArrayRef kernel_size, + ::executorch::aten::IntArrayRef stride, + ::executorch::aten::IntArrayRef padding, + bool ceil_mode, + bool count_include_pad, + ::executorch::aten::optional divisor_override, + const ::executorch::aten::optional<::executorch::aten::Tensor>& + in_zero_point_t, + bool channel_last, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_conv1d.cpp b/backends/cadence/generic/operators/op_conv1d.cpp new file mode 100644 index 00000000000..9a69c9fd549 --- /dev/null +++ b/backends/cadence/generic/operators/op_conv1d.cpp @@ -0,0 +1,104 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +// This implements a generic 1D float32 convolution kernel. 
+// The input is of shape [n x c x w] (batch x channels x width) +// The weight is of shape [oc x wc x ww] (out_channels x weight_channels x +// weight_width) The output is of shape [n x oc x ow] (batch x out_channels x +// out_width) The bias is of shape [oc] + +Tensor& conv1d_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + Tensor& out) { + // Extract dimensions + const int n = input.size(0); + const int c = input.size(1); + const int w = input.size(2); + const int oc = weight.size(0); + const int wc = weight.size(1); + const int ww = weight.size(2); + const int ow = out.size(2); + + const int16_t s = static_cast(stride[0]); + const int16_t p = static_cast(padding[0]); + const int16_t d = static_cast(dilation[0]); + const int16_t g = static_cast(groups); + + const float* p_in = input.const_data_ptr(); + const float* p_weight = weight.const_data_ptr(); + const float* p_bias = bias.const_data_ptr(); + float* p_out = out.mutable_data_ptr(); + + const bool zero_pad_unit_dilation = d == 1 && p == 0; + const int ocpg = oc / g; + const int icpg = c / g; + + for (int _n = 0; _n < n; ++_n) { + const float* in_batch = p_in + _n * c * w; + float* out_batch = p_out + _n * oc * ow; + for (int _g = 0; _g < g; ++_g) { + int sic = _g * icpg; + int soc = _g * ocpg; + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + float* out_plane = out_batch + _oc * ow; + const float* weight_batch = p_weight + _oc * wc * ww; + for (int _w = 0, _ow = 0; _ow < ow; _w += s, ++_ow) { + float acc = p_bias[_oc]; + if (zero_pad_unit_dilation) { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const float* in_plane = in_batch + _ic * w; + const float* weight_plane = weight_batch + (_ic - sic) * ww; + for (int _ww = 0; _ww < ww; ++_ww) { + int ioff = _w + _ww; + acc += in_plane[ioff] * weight_plane[_ww]; + } + } + } else { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const float* in_plane = in_batch + _ic * w; + const float* weight_plane = weight_batch + (_ic - sic) * ww; + for (int _ww = 0; _ww < ww; ++_ww) { + int w_pos = _w + d * _ww - p; + if (w_pos >= 0 && w_pos < w) { + acc += in_plane[w_pos] * weight_plane[_ww]; + } + } + } + } + out_plane[_ow] = acc; + } + } + } + } + + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_conv1d.h b/backends/cadence/generic/operators/op_conv1d.h new file mode 100644 index 00000000000..457c46ea358 --- /dev/null +++ b/backends/cadence/generic/operators/op_conv1d.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +Tensor& conv1d_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_conv2d.cpp b/backends/cadence/generic/operators/op_conv2d.cpp new file mode 100644 index 00000000000..b71695fe367 --- /dev/null +++ b/backends/cadence/generic/operators/op_conv2d.cpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +// This implements a generic 2D float32 convolution kernel. +// The input is of shape [n x c x h x w] (batch x channels x height x width) +// The weight is of shape [oc x wc x wh x ww] (out_channels x weight_channels x +// weight_height x weight_width) The output is of shape [n x oc x oh x ow] +// (batch x out_channels x out_height x out_width) The bias is of shape [oc] + +Tensor& conv2d_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + Tensor& out) { + // Extract dimensions + const int n = input.size(0); + const int c = input.size(1); + const int h = input.size(2); + const int w = input.size(3); + const int oc = weight.size(0); + const int wc = weight.size(1); + const int wh = weight.size(2); + const int ww = weight.size(3); + const int oh = out.size(2); + const int ow = out.size(3); + + const int16_t s0 = static_cast(stride[0]); + const int16_t s1 = static_cast(stride[1]); + const int16_t p0 = static_cast(padding[0]); + const int16_t p1 = static_cast(padding[1]); + const int16_t d0 = static_cast(dilation[0]); + const int16_t d1 = static_cast(dilation[1]); + const int16_t g = static_cast(groups); + + const float* p_in = input.const_data_ptr(); + const float* p_weight = weight.const_data_ptr(); + const float* p_bias = bias.const_data_ptr(); + float* p_out = out.mutable_data_ptr(); + + const bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0; + const int ocpg = oc / g; + const int icpg = c / g; + + for (int _n = 0; _n < n; ++_n) { + const float* in_batch = p_in + _n * c * h * w; + float* out_batch = p_out + _n * oc * oh * ow; + for (int _g = 0; _g < g; ++_g) { + int sic = _g * icpg; + int soc = _g * ocpg; + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + float* out_plane = out_batch + _oc * oh * ow; + const float* weight_batch = p_weight + _oc * wc * wh * ww; + for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { + for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { + float acc = p_bias[_oc]; + if (zero_pad_unit_dilation) { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const float* in_plane = in_batch + _ic * h * w; + const float* weight_plane = + weight_batch + (_ic - sic) 
* wh * ww; + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + int ioff = (_h + _wh) * w + (_w + _ww); + int woff = _wh * ww + _ww; + acc += in_plane[ioff] * weight_plane[woff]; + } + } + } + } else { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const float* in_plane = in_batch + _ic * h * w; + const float* weight_plane = + weight_batch + (_ic - sic) * wh * ww; + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + int h_pos = _h + d0 * _wh - p0; + int w_pos = _w + d1 * _ww - p1; + if (h_pos >= 0 && h_pos < h && w_pos >= 0 && w_pos < w) { + int ioff = h_pos * w + w_pos; + int woff = _wh * ww + _ww; + acc += in_plane[ioff] * weight_plane[woff]; + } + } + } + } + } + out_plane[_oh * ow + _ow] = acc; + } + } + } + } + } + + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_conv2d.h b/backends/cadence/generic/operators/op_conv2d.h new file mode 100644 index 00000000000..576cb5d5cb5 --- /dev/null +++ b/backends/cadence/generic/operators/op_conv2d.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +Tensor& conv2d_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_conv3d.cpp b/backends/cadence/generic/operators/op_conv3d.cpp new file mode 100644 index 00000000000..0ea6b063311 --- /dev/null +++ b/backends/cadence/generic/operators/op_conv3d.cpp @@ -0,0 +1,139 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +// This implements a generic 3D float32 convolution kernel. 
+// The input is of shape [n x c x d x h x w] (batch x channels x depth x height +// x width) The weight is of shape [oc x wc x wd x wh x ww] (out_channels x +// weight_channels x weight_depth x weight_height x weight_width) The output is +// of shape [n x oc x od x oh x ow] (batch x out_channels x out_depth x +// out_height x out_width) The bias is of shape [oc] + +Tensor& conv3d_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + Tensor& out) { + // Extract dimensions + const int n = input.size(0); + const int c = input.size(1); + const int d = input.size(2); + const int h = input.size(3); + const int w = input.size(4); + const int oc = weight.size(0); + const int wc = weight.size(1); + const int wd = weight.size(2); + const int wh = weight.size(3); + const int ww = weight.size(4); + const int od = out.size(2); + const int oh = out.size(3); + const int ow = out.size(4); + + const int16_t s0 = static_cast(stride[0]); + const int16_t s1 = static_cast(stride[1]); + const int16_t s2 = static_cast(stride[2]); + const int16_t p0 = static_cast(padding[0]); + const int16_t p1 = static_cast(padding[1]); + const int16_t p2 = static_cast(padding[2]); + const int16_t d0 = static_cast(dilation[0]); + const int16_t d1 = static_cast(dilation[1]); + const int16_t d2 = static_cast(dilation[2]); + const int16_t g = static_cast(groups); + + const float* p_in = input.const_data_ptr(); + const float* p_weight = weight.const_data_ptr(); + const float* p_bias = bias.const_data_ptr(); + float* p_out = out.mutable_data_ptr(); + + const bool zero_pad_unit_dilation = + d0 == 1 && d1 == 1 && d2 == 1 && p0 == 0 && p1 == 0 && p2 == 0; + const int ocpg = oc / g; + const int icpg = c / g; + + for (int _n = 0; _n < n; ++_n) { + const float* in_batch = p_in + _n * c * d * h * w; + float* out_batch = p_out + _n * oc * od * oh * ow; + for (int _g = 0; _g < g; ++_g) { + int sic = _g * icpg; + int soc = _g * ocpg; + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + float* out_volume = out_batch + _oc * od * oh * ow; + const float* weight_batch = p_weight + _oc * wc * wd * wh * ww; + for (int _d = 0, _od = 0; _od < od; _d += s0, ++_od) { + for (int _h = 0, _oh = 0; _oh < oh; _h += s1, ++_oh) { + for (int _w = 0, _ow = 0; _ow < ow; _w += s2, ++_ow) { + float acc = p_bias[_oc]; + if (zero_pad_unit_dilation) { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const float* in_volume = in_batch + _ic * d * h * w; + const float* weight_volume = + weight_batch + (_ic - sic) * wd * wh * ww; + for (int _wd = 0; _wd < wd; ++_wd) { + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + int ioff = + (_d + _wd) * h * w + (_h + _wh) * w + (_w + _ww); + int woff = _wd * wh * ww + _wh * ww + _ww; + acc += in_volume[ioff] * weight_volume[woff]; + } + } + } + } + } else { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const float* in_volume = in_batch + _ic * d * h * w; + const float* weight_volume = + weight_batch + (_ic - sic) * wd * wh * ww; + for (int _wd = 0; _wd < wd; ++_wd) { + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + int d_pos = _d + d0 * _wd - p0; + int h_pos = _h + d1 * _wh - p1; + int w_pos = _w + d2 * _ww - p2; + if (d_pos >= 0 && d_pos < d && h_pos >= 0 && + h_pos < h && w_pos >= 0 && w_pos < w) { + int ioff = d_pos * h * w + h_pos * w + w_pos; + int woff = _wd * wh * ww + _wh * ww + _ww; + acc += in_volume[ioff] * 
+                              weight_volume[woff];
+                        }
+                      }
+                    }
+                  }
+                }
+              }
+              out_volume[_od * oh * ow + _oh * ow + _ow] = acc;
+            }
+          }
+        }
+      }
+    }
+  }
+
+  return out;
+}
+
+} // namespace native
+} // namespace generic
+} // namespace impl
diff --git a/backends/cadence/generic/operators/op_conv3d.h b/backends/cadence/generic/operators/op_conv3d.h
new file mode 100644
index 00000000000..db896f5f318
--- /dev/null
+++ b/backends/cadence/generic/operators/op_conv3d.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+#include
+
+namespace impl {
+namespace generic {
+namespace native {
+
+using ::executorch::aten::IntArrayRef;
+using ::executorch::aten::Tensor;
+using ::executorch::runtime::KernelRuntimeContext;
+
+Tensor& conv3d_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& input,
+    const Tensor& weight,
+    const Tensor& bias,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    IntArrayRef dilation,
+    int64_t groups,
+    Tensor& out);
+
+} // namespace native
+} // namespace generic
+} // namespace impl
diff --git a/backends/cadence/generic/operators/op_softmax.cpp b/backends/cadence/generic/operators/op_softmax.cpp
new file mode 100644
index 00000000000..97c64a22511
--- /dev/null
+++ b/backends/cadence/generic/operators/op_softmax.cpp
@@ -0,0 +1,137 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+
+#include
+#include
+
+#include
+#include
+
+namespace impl {
+namespace generic {
+namespace native {
+
+namespace {
+
+using ::executorch::aten::Tensor;
+using ::executorch::runtime::KernelRuntimeContext;
+
+void vec_softmax_f32_f32(
+    float* __restrict__ y,
+    const float* __restrict__ x,
+    int n) {
+  // Compute softmax over x[0, n) and write the result into y:
+  // y = e^(x - max(x)) / sum(e^(x - max(x)))
+  float max_x = *std::max_element(x, x + n);
+  float sum = 0;
+
+  for (int i = 0; i < n; ++i) {
+    y[i] = expf(x[i] - max_x);
+    sum += y[i];
+  }
+
+  for (int i = 0; i < n; ++i) {
+    y[i] /= sum;
+  }
+}
+
+// This function is borrowed from the portable kernel implementation, with only
+// the float type supported.
+void _softmax_portable(const Tensor& in, int64_t dim, Tensor& out) {
+  const float* const in_data = in.const_data_ptr<float>();
+  float* const out_data = out.mutable_data_ptr<float>();
+
+  torch::executor::apply_over_dim(
+      [in_data, out_data](
+          const size_t size, const size_t stride, const size_t base) {
+        // Calculate the max along the softmax dim. During the softmax
+        // computation each value is subtracted by the maximum input
+        // value before calling exp to preserve numerical stability.
+        const float max_in = torch::executor::apply_unary_reduce_fn(
+            [](const float val_in, float val_accum) {
+              return std::max(val_in, val_accum);
+            },
+            in_data + base,
+            size,
+            stride);
+
+        float temp_sum =
+            torch::executor::apply_unary_map_reduce_fn<float, float>(
+                [max_in](const float val_in) {
+                  return std::exp(val_in - max_in);
+                },
+                [](const float mapped_in, float val_accum) {
+                  return val_accum + mapped_in;
+                },
+                in_data + base,
+                size,
+                stride);
+
+        torch::executor::apply_unary_map_fn(
+            [max_in, temp_sum](const float val_in) {
+              return std::exp(val_in - max_in) / temp_sum;
+            },
+            in_data + base,
+            out_data + base,
+            size,
+            stride);
+      },
+      in,
+      dim);
+}
+
+} // namespace
+
+Tensor& _softmax_out(
+    ET_UNUSED KernelRuntimeContext& ctx,
+    const Tensor& X,
+    int64_t dim,
+    ET_UNUSED bool half_to_float,
+    Tensor& Y) {
+  if (dim < 0) {
+    dim += X.dim();
+  }
+
+  // If dim is not the last dimension, we cannot use the vectorized kernel
+  // below, so fall back on the more generic portable kernel.
+  if (dim < X.dim() - 1) {
+    _softmax_portable(X, dim, Y);
+    return Y;
+  }
+
+  const float* __restrict__ x_data = X.const_data_ptr<float>();
+  float* __restrict__ y_data = Y.mutable_data_ptr<float>();
+
+  size_t K = X.size(X.dim() - 1);
+  size_t leading_dim = ::executorch::runtime::getLeadingDims(X, X.dim() - 1);
+
+  for (size_t i = 0; i < leading_dim; ++i) {
+    const float* x = x_data + i * K;
+    float* y = y_data + i * K;
+    vec_softmax_f32_f32(y, x, K);
+  }
+
+  return Y;
+}
+
+Tensor& _softmax_f32_f32_out(
+    __ET_UNUSED KernelRuntimeContext& ctx,
+    const Tensor& X,
+    int64_t dim,
+    __ET_UNUSED ::executorch::aten::optional<bool> half_to_float,
+    Tensor& Y) {
+  _softmax_out(ctx, X, dim, false, Y);
+
+  return Y;
+}
+
+} // namespace native
+} // namespace generic
+} // namespace impl
diff --git a/backends/cadence/generic/operators/op_softmax.h b/backends/cadence/generic/operators/op_softmax.h
new file mode 100644
index 00000000000..ec51b1d00c0
--- /dev/null
+++ b/backends/cadence/generic/operators/op_softmax.h
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& _softmax_out( + __ET_UNUSED ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& X, + int64_t dim, + __ET_UNUSED bool half_to_float, + ::executorch::aten::Tensor& Y); + +::executorch::aten::Tensor& _softmax_f32_f32_out( + __ET_UNUSED ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& X, + int64_t dim, + __ET_UNUSED ::executorch::aten::optional half_to_float, + ::executorch::aten::Tensor& Y); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/targets.bzl b/backends/cadence/generic/operators/targets.bzl index 0a6859c64b0..623233e4732 100644 --- a/backends/cadence/generic/operators/targets.bzl +++ b/backends/cadence/generic/operators/targets.bzl @@ -247,22 +247,75 @@ def define_common_targets(): ], ) - # Combined target for backward compatibility - # NOTE: cadence_aot_lib now uses individual targets directly for better linking runtime.cxx_library( - name = "cadence_generic_ops", - srcs = glob([ - "*.cpp", - ]), - exported_headers = glob([ - "*.h", - ]), + name = "op_softmax", + srcs = ["op_softmax.cpp"], + exported_headers = ["op_softmax.h"], platforms = CXX, deps = [ - "//executorch/kernels/portable/cpu/util:broadcast_util", "//executorch/runtime/kernel:kernel_includes", - "//executorch/kernels/portable/cpu:scalar_utils", - "//executorch/backends/cadence/generic/kernels:cadence_kernels", + "//executorch/kernels/portable/cpu/util:functional_util", + "//executorch/kernels/portable/cpu/util:reduce_util", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/kernel:kernel_runtime_context", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_conv1d", + srcs = ["op_conv1d.cpp"], + exported_headers = ["op_conv1d.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_conv2d", + srcs = ["op_conv2d.cpp"], + exported_headers = ["op_conv2d.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_conv3d", + srcs = ["op_conv3d.cpp"], + exported_headers = ["op_conv3d.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_avg_pool2d", + srcs = ["op_avg_pool2d.cpp"], + exported_headers = ["op_avg_pool2d.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/kernel:kernel_runtime_context", ], visibility = [ "//executorch/backends/cadence/...",