diff --git a/kernels/optimized/blas/CPUBlas.cpp b/kernels/optimized/blas/CPUBlas.cpp
index 35b208d30fc..99003f8f0ea 100644
--- a/kernels/optimized/blas/CPUBlas.cpp
+++ b/kernels/optimized/blas/CPUBlas.cpp
@@ -173,5 +173,28 @@ void gemm(
 }
 // clang-format on
 
+// clang-format off
+void gemm(
+    TransposeType transa, TransposeType transb,
+    int64_t m, int64_t n, int64_t k,
+    const BFloat16 alpha,
+    const BFloat16 *a, int64_t lda,
+    const BFloat16 *b, int64_t ldb,
+    const BFloat16 beta,
+    BFloat16 *c, int64_t ldc) {
+  normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
+
+  using acc_type = utils::compute_dtype<BFloat16>;
+  gemm_impl(
+      transa, transb,
+      m, n, k,
+      static_cast<const acc_type>(alpha),
+      a, lda,
+      b, ldb,
+      static_cast<const acc_type>(beta),
+      c, ldc);
+}
+// clang-format on
+
 } // namespace cpublas
 } // namespace executorch
diff --git a/kernels/optimized/blas/CPUBlas.h b/kernels/optimized/blas/CPUBlas.h
index dd4a24cbce0..71e50601238 100644
--- a/kernels/optimized/blas/CPUBlas.h
+++ b/kernels/optimized/blas/CPUBlas.h
@@ -17,6 +17,7 @@
 namespace executorch {
 namespace cpublas {
 
+using BFloat16 = torch::executor::BFloat16;
 using Half = torch::executor::Half;
 
 enum class TransposeType {
@@ -104,6 +105,15 @@ void gemm(
     const Half *b, int64_t ldb,
     const Half beta,
     Half *c, int64_t ldc);
+
+void gemm(
+    TransposeType transa, TransposeType transb,
+    int64_t m, int64_t n, int64_t k,
+    const BFloat16 alpha,
+    const BFloat16 *a, int64_t lda,
+    const BFloat16 *b, int64_t ldb,
+    const BFloat16 beta,
+    BFloat16 *c, int64_t ldc);
 // clang-format on
 
 // clang-format off
diff --git a/kernels/optimized/test/libblas_test.cpp b/kernels/optimized/test/libblas_test.cpp
index 8f30a357e1a..24aeaba776a 100644
--- a/kernels/optimized/test/libblas_test.cpp
+++ b/kernels/optimized/test/libblas_test.cpp
@@ -9,6 +9,7 @@
 
 #include <executorch/kernels/optimized/blas/CPUBlas.h>
 #include <executorch/runtime/core/portable_type/half.h>
+#include <executorch/runtime/core/portable_type/bfloat16.h>
 
 #include <gtest/gtest.h>
@@ -17,7 +18,8 @@
   _<double>();                          \
   _<float>();                           \
   _<int64_t>();                         \
-  _<executorch::cpublas::Half>();
+  _<executorch::cpublas::Half>();       \
+  _<executorch::cpublas::BFloat16>();
 
 namespace {