From 01b07b7876bd6253fc4d745d0844bb8f2be801fc Mon Sep 17 00:00:00 2001
From: Scott Wolchok
Date: Thu, 5 Sep 2024 17:08:54 -0700
Subject: [PATCH] [ExecuTorch] Add optimized op_linear

If we happen to be running without a delegate, directly implementing
linear is much more efficient than permute_copy_out (which materializes
a transpose) followed by matmul.

Differential Revision: [D62154007](https://our.internmc.facebook.com/intern/diff/D62154007/)

[ghstack-poisoned]
---
 kernels/optimized/cpu/op_linear.cpp           |  73 +++++
 kernels/optimized/cpu/targets.bzl             |   7 +
 kernels/optimized/optimized.yaml              |   5 +
 kernels/portable/cpu/util/matmul_ops_util.cpp |  25 ++
 kernels/portable/cpu/util/matmul_ops_util.h   |   8 +
 kernels/test/op_linear_test.cpp               | 301 ++++++++++++++++++
 kernels/test/targets.bzl                      |   1 +
 7 files changed, 420 insertions(+)
 create mode 100644 kernels/optimized/cpu/op_linear.cpp
 create mode 100644 kernels/test/op_linear_test.cpp

diff --git a/kernels/optimized/cpu/op_linear.cpp b/kernels/optimized/cpu/op_linear.cpp
new file mode 100644
index 00000000000..316e4c70c43
--- /dev/null
+++ b/kernels/optimized/cpu/op_linear.cpp
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/optimized/blas/CPUBlas.h>
+#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+using Tensor = exec_aten::Tensor;
+
+Tensor& opt_linear_out(
+    RuntimeContext& ctx,
+    const Tensor& in,
+    const Tensor& mat2,
+    const optional<Tensor>& bias,
+    Tensor& out) {
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      !bias.has_value(),
+      InvalidArgument,
+      out,
+      "bias not supported yet in linear");
+  ET_KERNEL_CHECK(ctx, check_linear_args(in, mat2, out), InvalidArgument, out);
+
+  size_t output_ndim = 0;
+  exec_aten::SizesType output_sizes[kTensorDimensionLimit];
+  get_linear_out_target_size(in, mat2, output_sizes, &output_ndim);
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(out, {output_sizes, output_ndim}) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  int flattened_input_dim = 1;
+  for (int ii = 0; ii < in.dim() - 1; ++ii) {
+    flattened_input_dim *= in.sizes()[ii];
+  }
+  ET_SWITCH_REAL_TYPES_AND2(
+      Half, BFloat16, in.scalar_type(), ctx, "linear.out", CTYPE, [&]() {
+        size_t n = flattened_input_dim;
+        size_t k = in.sizes()[in.dim() - 1];
+        size_t m = mat2.size(0);
+
+        executorch::cpublas::gemm(
+            executorch::cpublas::TransposeType::Transpose,
+            executorch::cpublas::TransposeType::NoTranspose,
+            m,
+            n,
+            k,
+            static_cast<CTYPE>(1),
+            mat2.const_data_ptr<CTYPE>(),
+            k,
+            in.const_data_ptr<CTYPE>(),
+            k,
+            static_cast<CTYPE>(0),
+            out.mutable_data_ptr<CTYPE>(),
+            m);
+      });
+
+  return out;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl
index 225498aa8d1..488d2af7fa1 100644
--- a/kernels/optimized/cpu/targets.bzl
+++ b/kernels/optimized/cpu/targets.bzl
@@ -40,6 +40,13 @@ _OPTIMIZED_ATEN_OPS = (
             "//executorch/kernels/portable/cpu:scalar_utils",
         ],
     ),
+    op_target(
+        name = "op_linear",
+        deps = [
+            "//executorch/kernels/optimized:libblas",
+            "//executorch/kernels/portable/cpu/util:matmul_ops_util",
+        ],
+    ),
     op_target(
         name = "op_log_softmax",
         deps = select({
diff --git a/kernels/optimized/optimized.yaml b/kernels/optimized/optimized.yaml
index 7c2c4d35fd7..2421673f8a7 100644
--- a/kernels/optimized/optimized.yaml
+++ b/kernels/optimized/optimized.yaml
@@ -52,6 +52,11 @@
     - arg_meta: null
       kernel_name: torch::executor::opt_le_tensor_out
 
+- op: linear.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::opt_linear_out
+
 - op: mm.out
   kernels:
     - arg_meta: null
diff --git a/kernels/portable/cpu/util/matmul_ops_util.cpp b/kernels/portable/cpu/util/matmul_ops_util.cpp
index d7e49d64958..3d4f2e5e9ba 100644
--- a/kernels/portable/cpu/util/matmul_ops_util.cpp
+++ b/kernels/portable/cpu/util/matmul_ops_util.cpp
@@ -71,6 +71,19 @@ bool check_mm_args(const Tensor& in, const Tensor& mat2, Tensor& out) {
   return true;
 }
 
+bool check_linear_args(const Tensor& in, const Tensor& mat2, Tensor& out) {
+  ET_LOG_AND_RETURN_IF_FALSE(in.dim() == out.dim());
+  ET_LOG_AND_RETURN_IF_FALSE(in.dim() >= 2);
+  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(mat2, 2));
+
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mat2, out));
+
+  ET_LOG_AND_RETURN_IF_FALSE(
+      tensors_have_same_size_at_dims(in, in.dim() - 1, mat2, 1));
+
+  return true;
+}
+
 void get_mm_out_target_size(
     const Tensor& mat1,
     const Tensor& mat2,
@@ -81,5 +94,17 @@
   out_sizes[1] = mat2.size(1);
 }
 
+void get_linear_out_target_size(
+    const Tensor& mat1,
+    const Tensor& mat2,
+    Tensor::SizesType* out_sizes,
+    size_t* out_ndim) {
+  *out_ndim = mat1.dim();
+  for (int ii = 0; ii < mat1.dim() - 1; ++ii) {
+    out_sizes[ii] = mat1.sizes()[ii];
+  }
+  out_sizes[mat1.dim() - 1] = mat2.size(0);
+}
+
 } // namespace executor
 } // namespace torch
diff --git a/kernels/portable/cpu/util/matmul_ops_util.h b/kernels/portable/cpu/util/matmul_ops_util.h
index 91e27ff2cc9..d2991868e95 100644
--- a/kernels/portable/cpu/util/matmul_ops_util.h
+++ b/kernels/portable/cpu/util/matmul_ops_util.h
@@ -37,5 +37,13 @@ void get_mm_out_target_size(
     Tensor::SizesType* out_sizes,
     size_t* out_ndim);
 
+bool check_linear_args(const Tensor& in, const Tensor& mat2, Tensor& out);
+
+void get_linear_out_target_size(
+    const Tensor& mat1,
+    const Tensor& mat2,
+    Tensor::SizesType* out_sizes,
+    size_t* out_ndim);
+
 } // namespace executor
 } // namespace torch
diff --git a/kernels/test/op_linear_test.cpp b/kernels/test/op_linear_test.cpp
new file mode 100644
index 00000000000..1d65c9a733e
--- /dev/null
+++ b/kernels/test/op_linear_test.cpp
@@ -0,0 +1,301 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
+#include <executorch/kernels/test/TestUtil.h>
+#include <executorch/kernels/test/supported_features.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
+#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
+
+#include <gtest/gtest.h>
+#include <limits>
+
+using namespace ::testing;
+using exec_aten::ArrayRef;
+using exec_aten::Scalar;
+using exec_aten::ScalarType;
+using exec_aten::Tensor;
+using torch::executor::testing::TensorFactory;
+
+class OpLinearOutTest : public OperatorTest {
+ protected:
+  Tensor& op_linear_out(const Tensor& self, const Tensor& mat2, Tensor& out) {
+    return torch::executor::aten::linear_outf(
+        context_, self, mat2, torch::executor::nullopt, out);
+  }
+
+  template <class CTYPE, ScalarType DTYPE>
+  void test_dtype() {
+    TensorFactory<DTYPE> tf;
+
+    if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
+      if (DTYPE == ScalarType::Half) {
+        GTEST_SKIP()
+            << "skip Half because torch::executor::aten::mm_out does not support Half";
+        return;
+      }
+    }
+
+    // Each output element is a dot product over k = 4, so 4 * 2 * 3 = 24.
+    Tensor x = tf.full({3, 4}, 2);
+    Tensor y = tf.full({5, 4}, 3);
+
+    // Output shape should be (3, 5)
+    Tensor out = tf.zeros({3, 5});
+
+    op_linear_out(x, y, out);
+
+    Tensor expected = tf.full({3, 5}, 24);
+
+    EXPECT_TENSOR_EQ(out, expected);
+  }
+};
+
+TEST_F(OpLinearOutTest, OutputDim) {
+  TensorFactory<ScalarType::Int> tf;
+
+  // 3 tensors with compatible dimensions: (3, 4), (5, 4) and (3, 5).
+  Tensor x = tf.ones({3, 4});
+  Tensor y = tf.ones({5, 4});
+  Tensor out = tf.zeros({3, 5});
+
+  Tensor ret = op_linear_out(x, y, out);
+
+  // Should always return the provided out Tensor.
+  EXPECT_TENSOR_EQ(ret, out);
+
+  // Expected tensor, filled with 4.
+  Tensor expected = tf.full({3, 5}, 4);
+
+  EXPECT_TENSOR_EQ(out, expected);
+}
+
+/// A generic smoke test that works for any dtype that supports ones() and
+/// zeros().
+TEST_F(OpLinearOutTest, AllDtypesSupported) {
+#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
+  ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+  // TODO: Also add tests for half, complex, quantized, and other types.
+  // Easiest way to do that would be to make TensorFactory support zeros() and
+  // ones() for those types.
+}
+
+TEST_F(OpLinearOutTest, EmptyInputWithEmptyOutTensorPasses) {
+  TensorFactory<ScalarType::Float> tf;
+
+  // Empty input matrices
+  Tensor x = tf.make({0, 3}, {});
+  Tensor y = tf.make({0, 3}, {});
+
+  // Output matrix is also empty
+  Tensor out = tf.make({0, 0}, {});
+
+  Tensor expected = tf.make({0, 0}, {});
+
+  EXPECT_TENSOR_EQ(op_linear_out(x, y, out), expected);
+}
+
+TEST_F(OpLinearOutTest, InfinityTensorPasses) {
+  TensorFactory<ScalarType::Float> tff;
+
+  Tensor x = tff.full({3, 4}, std::numeric_limits<float>::infinity());
+  Tensor y = tff.full({5, 4}, 3);
+
+  // Output shape should be (3, 5)
+  Tensor out = tff.zeros({3, 5});
+
+  Tensor expected = tff.full({3, 5}, std::numeric_limits<float>::infinity());
+
+  EXPECT_TENSOR_EQ(op_linear_out(x, y, out), expected);
+}
+
+TEST_F(OpLinearOutTest, MismatchedDimensionsDies) {
+  TensorFactory<ScalarType::Int> tf;
+
+  Tensor x = tf.full({2, 2}, 3);
+
+  Tensor wrong_y = tf.full({1, 3}, 1);
+  Tensor right_y = tf.full({2, 2}, 1);
+
+  // Make a zero-filled out tensor with the correct output shape.
+  Tensor out = tf.full({2, 2}, 0);
+
+  Tensor expected = tf.full({2, 2}, 6);
+  ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, wrong_y, out));
+
+  EXPECT_TENSOR_EQ(op_linear_out(x, right_y, out), expected);
+}
+
+TEST_F(OpLinearOutTest, MismatchedDimensionSizeDies) {
+  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
+    GTEST_SKIP() << "ATen kernel can handle mismatched dimension size";
+  }
+  TensorFactory<ScalarType::Int> tf;
+  Tensor x = tf.full({2, 2}, 3);
+
+  // wrong_y has incompatible dim
+  Tensor wrong_y = tf.full({2, 2, 2}, 1);
+  Tensor right_y = tf.full({2, 2}, 1);
+
+  // wrong_out has incompatible dim
+  Tensor right_out = tf.ones({2, 2});
+  Tensor wrong_out = tf.ones({2, 2, 3});
+
+  ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, right_y, wrong_out));
+  ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, wrong_y, right_out));
+}
+
+TEST_F(OpLinearOutTest, WrongOutShapeDies) {
+  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
+    GTEST_SKIP() << "ATen kernel can handle wrong out shape";
+  }
+  TensorFactory<ScalarType::Int> tf;
+  Tensor x = tf.ones({10, 3});
+
+  Tensor y = tf.ones({4, 3});
+
+  // wrong_out has incompatible shape
+  Tensor right_out = tf.ones({10, 4});
+  Tensor wrong_out = tf.ones({7, 5});
+
+  ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, y, wrong_out));
+
+  EXPECT_TENSOR_EQ(op_linear_out(x, y, right_out), tf.full({10, 4}, 3));
+}
+
+TEST_F(OpLinearOutTest, DynamicShapeUpperBoundSameAsExpected) {
+  TensorFactory<ScalarType::Float> tf;
+
+  Tensor x = tf.make(
+      {3, 2},
+      {0.17412060499191284,
+       0.34793388843536377,
+       0.8187907934188843,
+       0.9979893565177917,
+       0.7049332857131958,
+       0.4255824089050293});
+  Tensor y = tf.make(
+      {4, 2},
+      {0.8071839213371277,
+       0.31638312339782715,
+       0.13667285442352295,
+       0.3691965937614441,
+       0.9002121090888977,
+       0.09420186281204224,
+       0.9070476293563843,
+       0.9310881495475769});
+  Tensor expected_result = tf.make(
+      {3, 4},
+      {0.2506277561187744,
+       0.15225356817245483,
+       0.18952149152755737,
+       0.48189279437065125,
+       0.976661741733551,
+       0.480360746383667,
+       0.8310978412628174,
+       1.6718982458114624,
+       0.703657865524292,
+       0.2534688115119934,
+       0.6746801733970642,
+       1.0356627702713013});
+
+  Tensor out =
+      tf.zeros({3, 4}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
+  Tensor ret = op_linear_out(x, y, out);
+  EXPECT_TENSOR_CLOSE(out, expected_result);
+}
+
+TEST_F(OpLinearOutTest, DynamicShapeUpperBoundLargerThanExpected) {
+  TensorFactory<ScalarType::Float> tf;
+
+  Tensor x = tf.make(
+      {3, 2},
+      {0.17412060499191284,
+       0.34793388843536377,
+       0.8187907934188843,
+       0.9979893565177917,
+       0.7049332857131958,
+       0.4255824089050293});
+  Tensor y = tf.make(
+      {4, 2},
+      {0.8071839213371277,
+       0.31638312339782715,
+       0.13667285442352295,
+       0.3691965937614441,
+       0.9002121090888977,
+       0.09420186281204224,
+       0.9070476293563843,
+       0.9310881495475769});
+  Tensor expected_result = tf.make(
+      {3, 4},
+      {0.2506277561187744,
+       0.15225356817245483,
+       0.18952149152755737,
+       0.48189279437065125,
+       0.976661741733551,
+       0.480360746383667,
+       0.8310978412628174,
+       1.6718982458114624,
+       0.703657865524292,
+       0.2534688115119934,
+       0.6746801733970642,
+       1.0356627702713013});
+
+  Tensor out =
+      tf.zeros({10, 10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
+  Tensor ret = op_linear_out(x, y, out);
+  EXPECT_TENSOR_CLOSE(out, expected_result);
+}
+
+TEST_F(OpLinearOutTest, DynamicShapeUnbound) {
+  GTEST_SKIP() << "Dynamic shape not supported";
+  TensorFactory<ScalarType::Float> tf;
+
+  Tensor x = tf.make(
+      {3, 2},
+      {0.17412060499191284,
+       0.34793388843536377,
+       0.8187907934188843,
+       0.9979893565177917,
+       0.7049332857131958,
+       0.4255824089050293});
+  Tensor y = tf.make(
+      {4, 2},
+      {0.8071839213371277,
+       0.31638312339782715,
+       0.13667285442352295,
+       0.3691965937614441,
+       0.9002121090888977,
+       0.09420186281204224,
+       0.9070476293563843,
+       0.9310881495475769});
+  Tensor expected_result = tf.make(
+      {3, 4},
+      {0.2506277561187744,
+       0.15225356817245483,
+       0.18952149152755737,
+       0.48189279437065125,
+       0.976661741733551,
+       0.480360746383667,
+       0.8310978412628174,
+       1.6718982458114624,
+       0.703657865524292,
+       0.2534688115119934,
+       0.6746801733970642,
+       1.0356627702713013});
+
+  Tensor out =
+      tf.zeros({1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
+  Tensor ret = op_linear_out(x, y, out);
+  EXPECT_TENSOR_CLOSE(out, expected_result);
+}
+
+// TODO: support and test bias
diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl
index cd3ca556fe6..f8ea484435a 100644
--- a/kernels/test/targets.bzl
+++ b/kernels/test/targets.bzl
@@ -226,6 +226,7 @@ def define_common_targets():
     _common_op_test("op_le_test", ["aten", "portable", "optimized"])
     _common_op_test("op_leaky_relu_test", ["aten", "portable"])
    _common_op_test("op_lift_fresh_copy_test", ["aten", "portable"])
+    _common_op_test("op_linear_test", ["aten", "optimized"])
     _common_op_test("op_log_softmax_test", ["aten", "portable", "optimized"])
     _common_op_test("op_log_test", ["aten", "portable"])
     _common_op_test("op_log10_test", ["aten", "portable"])
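
Not part of the patch: a minimal standalone C++ sketch of what the gemm call in opt_linear_out computes, for readers skimming the kernel. It evaluates out = in @ mat2^T directly from the row-major buffers (in is n x k, mat2 is m x k, out is n x m, matching the n/k/m variables in the kernel), which is the point of the change: the weight matrix is read in place rather than materialized as a transposed copy and fed to a plain matmul. The naive_linear function and the hard-coded shapes below are illustrative assumptions, not code from the repository.

// Illustrative reference only -- not part of the patch above.
// out = in @ mat2^T, with in (n x k), mat2 (m x k), out (n x m), all row-major.
#include <cstddef>
#include <iostream>
#include <vector>

void naive_linear(
    const std::vector<float>& in,
    const std::vector<float>& mat2,
    std::vector<float>& out,
    size_t n,
    size_t k,
    size_t m) {
  for (size_t row = 0; row < n; ++row) {
    for (size_t col = 0; col < m; ++col) {
      float acc = 0.0f;
      for (size_t i = 0; i < k; ++i) {
        // mat2 is indexed as mat2[col][i]: its rows are the output features,
        // so no transposed copy of the weights is ever created.
        acc += in[row * k + i] * mat2[col * k + i];
      }
      out[row * m + col] = acc;
    }
  }
}

int main() {
  // Same shapes as the test_dtype smoke test: in = 3x4 of 2s, mat2 = 5x4 of 3s.
  const size_t n = 3, k = 4, m = 5;
  std::vector<float> in(n * k, 2.0f);
  std::vector<float> mat2(m * k, 3.0f);
  std::vector<float> out(n * m, 0.0f);
  naive_linear(in, mat2, out, n, k, m);
  std::cout << out[0] << "\n"; // prints 24 (= 4 * 2 * 3), matching the test
  return 0;
}

The optimized kernel reaches the same result by passing TransposeType::Transpose for the mat2 operand to cpublas::gemm, so the BLAS routine walks the weight rows exactly as the inner loop above does and no permute_copy of the weights is allocated.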