From 0ed161ae6ef8d2c0ff80ea926c77f86fd040eb29 Mon Sep 17 00:00:00 2001
From: Scott Wolchok
Date: Thu, 5 Sep 2024 17:08:51 -0700
Subject: [PATCH 1/3] [ExecuTorch] Optimized op_mm using CPUBlas gemm

No immediate need for this, but it is extremely simple to implement, so
why not support it?

Differential Revision: [D62151659](https://our.internmc.facebook.com/intern/diff/D62151659/)

[ghstack-poisoned]
---
 kernels/optimized/cpu/op_mm.cpp   | 66 +++++++++++++++++++++++++++++++
 kernels/optimized/cpu/targets.bzl |  7 ++++
 kernels/optimized/optimized.yaml  |  5 +++
 kernels/test/targets.bzl          |  2 +-
 4 files changed, 79 insertions(+), 1 deletion(-)
 create mode 100644 kernels/optimized/cpu/op_mm.cpp

diff --git a/kernels/optimized/cpu/op_mm.cpp b/kernels/optimized/cpu/op_mm.cpp
new file mode 100644
index 00000000000..aee11497d1a
--- /dev/null
+++ b/kernels/optimized/cpu/op_mm.cpp
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/optimized/blas/CPUBlas.h>
+#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+using Tensor = exec_aten::Tensor;
+
+Tensor& opt_mm_out(
+    RuntimeContext& ctx,
+    const Tensor& in,
+    const Tensor& mat2,
+    Tensor& out) {
+  ET_KERNEL_CHECK(ctx, check_mm_args(in, mat2, out), InvalidArgument, out);
+
+  size_t output_ndim = 0;
+  exec_aten::SizesType output_sizes[kTensorDimensionLimit];
+  get_mm_out_target_size(in, mat2, output_sizes, &output_ndim);
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(out, {output_sizes, output_ndim}) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  ET_SWITCH_REAL_TYPES_AND2(
+      Half, BFloat16, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() {
+        size_t n = in.size(0);
+        size_t k = in.size(1);
+        size_t m = mat2.size(1);
+
+        // gemm expects column-major inputs and produces column-major
+        // output. So, we take advantage of the identity (A @ B).t()
+        // = B.t() @ A.t() here; row-major B is B.t() from gemm's
+        // column-major perspective, etc.
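+        //
+        // Concretely: out is n x m row-major, i.e. an m x n matrix
+        // out.t() in column-major terms, and we compute out.t() =
+        // mat2.t() @ in.t(). mat2's row-major buffer is an m x k
+        // column-major matrix (lda = m), in's is a k x n column-major
+        // matrix (ldb = k), and the product lands in out's buffer as
+        // m x n column-major (ldc = m).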
+        executorch::cpublas::gemm(
+            executorch::cpublas::TransposeType::NoTranspose,
+            executorch::cpublas::TransposeType::NoTranspose,
+            m,
+            n,
+            k,
+            static_cast<CTYPE>(1),
+            mat2.const_data_ptr<CTYPE>(),
+            m,
+            in.const_data_ptr<CTYPE>(),
+            k,
+            static_cast<CTYPE>(0),
+            out.mutable_data_ptr<CTYPE>(),
+            m);
+      });
+
+  return out;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl
index e7bb2d36bf4..225498aa8d1 100644
--- a/kernels/optimized/cpu/targets.bzl
+++ b/kernels/optimized/cpu/targets.bzl
@@ -52,6 +52,13 @@ _OPTIMIZED_ATEN_OPS = (
             ],
         }),
     ),
+    op_target(
+        name = "op_mm",
+        deps = [
+            "//executorch/kernels/optimized:libblas",
+            "//executorch/kernels/portable/cpu/util:matmul_ops_util",
+        ],
+    ),
     op_target(
         name = "op_mul",
         deps = [
diff --git a/kernels/optimized/optimized.yaml b/kernels/optimized/optimized.yaml
index 0d445deb3e8..7c2c4d35fd7 100644
--- a/kernels/optimized/optimized.yaml
+++ b/kernels/optimized/optimized.yaml
@@ -52,6 +52,11 @@
     - arg_meta: null
      kernel_name: torch::executor::opt_le_tensor_out
 
+- op: mm.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::opt_mm_out
+
 - op: mul.out
   kernels:
     - arg_meta: null
diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl
index 7ae17c5237a..cd3ca556fe6 100644
--- a/kernels/test/targets.bzl
+++ b/kernels/test/targets.bzl
@@ -244,7 +244,7 @@ def define_common_targets():
     _common_op_test("op_mean_test", ["aten", "portable"])
     _common_op_test("op_min_test", ["aten", "portable"])
     _common_op_test("op_minimum_test", ["aten", "portable"])
-    _common_op_test("op_mm_test", ["aten", "portable"])
+    _common_op_test("op_mm_test", ["aten", "portable", "optimized"])
     _common_op_test("op_mul_test", ["aten", "portable", "optimized"])
     _common_op_test("op_narrow_copy_test", ["aten", "portable"])
     _common_op_test("op_native_batch_norm_test", ["aten", "portable"])

From 237e6bd299ce09309a7ca1a366a41aad0e2481b9 Mon Sep 17 00:00:00 2001
From: Scott Wolchok
Date: Mon, 9 Sep 2024 16:01:18 -0700
Subject: [PATCH 2/3] Update base for Update on "[ExecuTorch] Optimized op_mm
 using CPUBlas gemm"

No immediate need for this, but it is extremely simple to implement, so
why not support it?

Differential Revision: [D62151659](https://our.internmc.facebook.com/intern/diff/D62151659/)

[ghstack-poisoned]

From 83ccadf1c9c3c9d847205d8a5f1fcc322a83b79b Mon Sep 17 00:00:00 2001
From: Scott Wolchok
Date: Tue, 10 Sep 2024 10:22:04 -0700
Subject: [PATCH 3/3] Update base for Update on "[ExecuTorch] Optimized op_mm
 using CPUBlas gemm"

No immediate need for this, but it is extremely simple to implement, so
why not support it?

Differential Revision: [D62151659](https://our.internmc.facebook.com/intern/diff/D62151659/)

[ghstack-poisoned]
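
For intuition, here is a minimal standalone sketch (not part of the patch
series) of the row-major-via-column-major trick the kernel relies on.
naive_gemm_colmajor below is a hypothetical stand-in for
executorch::cpublas::gemm, and the shapes and values are made up for
illustration: calling a column-major gemm with the operands swapped
(mat2 first, in second) yields the row-major product in @ mat2 directly.

    #include <cstdio>

    // Naive column-major GEMM: C (M x N) = A (M x K) * B (K x N),
    // with leading dimensions lda, ldb, ldc.
    void naive_gemm_colmajor(
        int M, int N, int K,
        const float* A, int lda,
        const float* B, int ldb,
        float* C, int ldc) {
      for (int j = 0; j < N; ++j) {
        for (int i = 0; i < M; ++i) {
          float acc = 0;
          for (int p = 0; p < K; ++p) {
            acc += A[i + p * lda] * B[p + j * ldb];
          }
          C[i + j * ldc] = acc;
        }
      }
    }

    int main() {
      // Row-major in (2 x 3) and mat2 (3 x 2); expect out = in @ mat2 (2 x 2).
      const float in[6] = {1, 2, 3, 4, 5, 6};
      const float mat2[6] = {7, 8, 9, 10, 11, 12};
      float out[4] = {};
      const int n = 2, k = 3, m = 2;
      // Same argument order as the kernel: gemm's "A" is mat2, "B" is in.
      naive_gemm_colmajor(m, n, k, mat2, m, in, k, out, m);
      // out, read row-major, is in @ mat2:
      //   58  64
      //  139 154
      for (int i = 0; i < n; ++i) {
        printf("%6.0f %6.0f\n", out[i * m], out[i * m + 1]);
      }
      return 0;
    }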