diff --git a/kernels/portable/cpu/op_native_batch_norm.cpp b/kernels/portable/cpu/op_native_batch_norm.cpp index 546212a6b3f..100b1a7fb27 100644 --- a/kernels/portable/cpu/op_native_batch_norm.cpp +++ b/kernels/portable/cpu/op_native_batch_norm.cpp @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -18,6 +19,7 @@ namespace executor { namespace native { using Tensor = exec_aten::Tensor; +using SizesType = exec_aten::SizesType; std::tuple _native_batch_norm_legit_no_training_out( KernelRuntimeContext& ctx, @@ -184,27 +186,131 @@ std::tuple _native_batch_norm_legit_no_stats_out( Tensor& mean_out, Tensor& invstd_out) { (void)ctx; - (void)in; - (void)weight; - (void)bias; - (void)momentum; - (void)eps; + (void)training; std::tuple ret_val(out, mean_out, invstd_out); - ET_KERNEL_CHECK_MSG( + ET_KERNEL_CHECK( ctx, - training == false, + check_batch_norm_args( + in, + weight, + bias, + exec_aten::optional(), + exec_aten::optional(), + momentum, + eps, + out, + mean_out, + invstd_out), InvalidArgument, - ret_val, - "Portable kernels only support inference mode!"); + ret_val); - ET_KERNEL_CHECK_MSG( + ET_KERNEL_CHECK( ctx, - training == true, + is_contiguous_dim_order(in.dim_order().data(), in.dim_order().size()), InvalidArgument, - ret_val, - "running_mean & running_var must be provided during inference!"); + ret_val); + + ET_KERNEL_CHECK( + ctx, + tensors_have_same_dim_order(in, out, mean_out, invstd_out), + InvalidArgument, + ret_val); + + if (weight.has_value()) { + ET_KERNEL_CHECK( + ctx, + tensors_have_same_dim_order(in, weight.value()), + InvalidArgument, + ret_val); + } + + if (bias.has_value()) { + ET_KERNEL_CHECK( + ctx, + tensors_have_same_dim_order(in, bias.value()), + InvalidArgument, + ret_val); + } + + ET_KERNEL_CHECK(ctx, in.dim() >= 2, InvalidArgument, ret_val); + + size_t N = in.size(0); + size_t C = in.size(1); + size_t inner = getTrailingDims(in, 1); + size_t elements_per_channel = N * inner; + + ET_KERNEL_CHECK( + ctx, + 
resize_tensor(out, in.sizes()) == Error::Ok, + InvalidArgument, + ret_val); + + ET_KERNEL_CHECK( + ctx, + resize_tensor(mean_out, {static_cast(C)}) == Error::Ok, + InvalidArgument, + ret_val); + + ET_KERNEL_CHECK( + ctx, + resize_tensor(invstd_out, {static_cast(C)}) == Error::Ok, + InvalidArgument, + ret_val); + + constexpr auto name = "_native_batch_norm_legit.no_stats_out"; + + ET_SWITCH_FLOAT_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] { + const CTYPE* in_data = in.const_data_ptr(); + CTYPE* out_data = out.mutable_data_ptr(); + CTYPE* mean_data = mean_out.mutable_data_ptr(); + CTYPE* invstd_data = invstd_out.mutable_data_ptr(); + + // Compute sum and sum of squares for each channel + for (size_t b = 0; b < N; ++b) { + const CTYPE* b_in_data = in_data + b * C * inner; + for (size_t c = 0; c < C; ++c) { + const CTYPE* x = b_in_data + c * inner; + + CTYPE sum = reduce_add(x, inner); + CTYPE sq_sum = vec_powerf(x, inner); + + mean_data[c] = (b == 0) ? sum : mean_data[c] + sum; + invstd_data[c] = (b == 0) ? sq_sum : invstd_data[c] + sq_sum; + } + } + + // Compute mean and invstd for each channel + for (size_t c = 0; c < C; ++c) { + CTYPE mean = mean_data[c] / elements_per_channel; + // Var[x] = E[x^2] - E[x]^2 + CTYPE var = invstd_data[c] / elements_per_channel - mean * mean; + CTYPE invstd = 1.0 / std::sqrt(var + eps); + mean_data[c] = mean; + invstd_data[c] = invstd; + } + + for (size_t i = 0; i < N; ++i) { + for (size_t c = 0; c < C; ++c) { + CTYPE mean = mean_data[c]; + CTYPE invstd = invstd_data[c]; + CTYPE weight_val = 1; + if (weight.has_value()) { + weight_val = weight.value().const_data_ptr()[c]; + } + CTYPE bias_val = 0; + if (bias.has_value()) { + bias_val = bias.value().const_data_ptr()[c]; + } + for (size_t j = 0; j < inner; ++j) { + *out_data = (*in_data - mean) * invstd * weight_val + bias_val; + out_data++; + in_data++; + } + } + } + }); return ret_val; } diff --git a/kernels/portable/cpu/util/normalization_ops_util.cpp b/kernels/portable/cpu/util/normalization_ops_util.cpp index 6b2b12bf14e..f16963aa5f8 100644 --- 
a/kernels/portable/cpu/util/normalization_ops_util.cpp +++ b/kernels/portable/cpu/util/normalization_ops_util.cpp @@ -19,35 +19,35 @@ bool check_batch_norm_args( const Tensor& in, const exec_aten::optional& weight, const exec_aten::optional& bias, - const Tensor& running_mean, - const Tensor& running_var, + const exec_aten::optional& running_mean, + const exec_aten::optional& running_var, double momentum, double eps, Tensor& out, Tensor& mean_out, Tensor& var_out) { // All tensors must be the same dtype - ET_LOG_AND_RETURN_IF_FALSE( - tensors_have_same_dtype(in, running_mean, running_var)); - ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out)); - ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mean_out)); - ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, var_out)); if (weight.has_value()) { ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, weight.value())); } if (bias.has_value()) { ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, bias.value())); } + if (running_mean.has_value()) { + ET_LOG_AND_RETURN_IF_FALSE( + tensors_have_same_dtype(in, running_mean.value())); + } + if (running_var.has_value()) { + ET_LOG_AND_RETURN_IF_FALSE( + tensors_have_same_dtype(in, running_var.value())); + } + ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out)); + ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mean_out)); + ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, var_out)); size_t C_dim = in.dim() >= 
1 : 0; // All parameter tensors must be of dim 1 and have length equal to the // channels dim of in - ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(running_mean, 1)); - ET_LOG_AND_RETURN_IF_FALSE( - tensors_have_same_size_at_dims(running_mean, 0, in, C_dim)); - ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(running_var, 1)); - ET_LOG_AND_RETURN_IF_FALSE( - tensors_have_same_size_at_dims(running_var, 0, in, C_dim)); if (weight.has_value()) { ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(weight.value(), 1)); ET_LOG_AND_RETURN_IF_FALSE( @@ -58,6 +58,16 @@ bool check_batch_norm_args( ET_LOG_AND_RETURN_IF_FALSE( tensors_have_same_size_at_dims(bias.value(), 0, in, C_dim)); } + if (running_mean.has_value()) { + ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(running_mean.value(), 1)); + ET_LOG_AND_RETURN_IF_FALSE( + tensors_have_same_size_at_dims(running_mean.value(), 0, in, C_dim)); + } + if (running_var.has_value()) { + ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(running_var.value(), 1)); + ET_LOG_AND_RETURN_IF_FALSE( + tensors_have_same_size_at_dims(running_var.value(), 0, in, C_dim)); + } return true; } diff --git a/kernels/portable/cpu/util/normalization_ops_util.h b/kernels/portable/cpu/util/normalization_ops_util.h index 59d43d700c7..fb4d889785f 100644 --- a/kernels/portable/cpu/util/normalization_ops_util.h +++ b/kernels/portable/cpu/util/normalization_ops_util.h @@ -17,8 +17,8 @@ bool check_batch_norm_args( const Tensor& in, const exec_aten::optional& weight, const exec_aten::optional& bias, - const Tensor& running_mean, - const Tensor& running_var, + const exec_aten::optional& running_mean, + const exec_aten::optional& running_var, double momentum, double eps, Tensor& out, diff --git a/kernels/test/op_native_batch_norm_test.cpp b/kernels/test/op_native_batch_norm_test.cpp index c6810f737fd..ba593d8dc4d 100644 --- a/kernels/test/op_native_batch_norm_test.cpp +++ b/kernels/test/op_native_batch_norm_test.cpp @@ -78,6 +78,33 @@ class OpNativeBatchNormLegitOutTest : public OperatorTest { 
} }; +class OpNativeBatchNormLegitNoStatsOutTest : public OperatorTest { + protected: + ::std::tuple + op_native_batch_norm_legit_no_stats_out( + const exec_aten::Tensor& input, + const exec_aten::optional& weight, + const exec_aten::optional& bias, + bool training, + double momentum, + double eps, + exec_aten::Tensor& out0, + exec_aten::Tensor& out1, + exec_aten::Tensor& out2) { + return torch::executor::aten::_native_batch_norm_legit_outf( + context_, + input, + weight, + bias, + training, + momentum, + eps, + out0, + out1, + out2); + } +}; + TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTest2D) { torch::executor::testing::TensorFactory tfFloat; @@ -949,3 +976,111 @@ TEST_F(OpNativeBatchNormLegitOutTest, SampleAtomicTest2D) { EXPECT_TENSOR_CLOSE(out1, out1_expected); EXPECT_TENSOR_CLOSE(out2, out2_expected); } + +TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest2D) { + torch::executor::testing::TensorFactory tfFloat; + + exec_aten::Tensor input = + tfFloat.make({3, 4}, {0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121}); + exec_aten::optional weight = + exec_aten::optional(); + exec_aten::optional bias = + exec_aten::optional(); + bool training = true; + double momentum = 1e-3; + double eps = 1e-5; + exec_aten::Tensor out0 = tfFloat.zeros({3, 4}); + exec_aten::Tensor out1 = tfFloat.zeros({4}); + exec_aten::Tensor out2 = tfFloat.zeros({4}); + exec_aten::Tensor out0_expected = tfFloat.make( + {3, 4}, + {-0.98058063, + -1.03422451, + -1.06904495, + -1.09332705, + -0.39223224, + -0.31822300, + -0.26726127, + -0.23017406, + 1.37281299, + 1.35244739, + 1.33630610, + 1.32350123}); + exec_aten::Tensor out1_expected = + tfFloat.make({4}, {26.66666603, 35.66666794, 46.66666794, 59.66666794}); + exec_aten::Tensor out2_expected = + tfFloat.make({4}, {0.03677177, 0.02983340, 0.02505574, 0.02157882}); + op_native_batch_norm_legit_no_stats_out( + input, weight, bias, training, momentum, eps, out0, out1, out2); + EXPECT_TENSOR_CLOSE(out0, out0_expected); + 
EXPECT_TENSOR_CLOSE(out1, out1_expected); + EXPECT_TENSOR_CLOSE(out2, out2_expected); +} + +TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest3D) { + torch::executor::testing::TensorFactory tfFloat; + + exec_aten::Tensor input = tfFloat.make( + {2, 3, 4}, {0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, + 144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529}); + exec_aten::optional weight = + exec_aten::optional(); + exec_aten::optional bias = + exec_aten::optional(); + bool training = true; + double momentum = 1e-3; + double eps = 1e-5; + exec_aten::Tensor out0 = tfFloat.zeros({2, 3, 4}); + exec_aten::Tensor out1 = tfFloat.zeros({3}); + exec_aten::Tensor out2 = tfFloat.zeros({3}); + exec_aten::Tensor out0_expected = tfFloat.make( + {2, 3, 4}, + {-1.01045656, -0.99964952, -0.96722847, -0.91319335, -1.08850884, + -1.02468753, -0.94668359, -0.85449719, -1.12558389, -1.03595889, + -0.93578988, -0.82507670, 0.54575467, 0.81593025, 1.10771990, + 1.42112350, 0.61339414, 0.84740579, 1.09560001, 1.35797679, + 0.64582670, 0.86198103, 1.08867943, 1.32592189}); + exec_aten::Tensor out1_expected = tfFloat.make({3}, {93.5, 169.5, 277.5}); + exec_aten::Tensor out2_expected = + tfFloat.make({3}, {0.01080702, 0.00709126, 0.00527206}); + op_native_batch_norm_legit_no_stats_out( + input, weight, bias, training, momentum, eps, out0, out1, out2); + EXPECT_TENSOR_CLOSE(out0, out0_expected); + EXPECT_TENSOR_CLOSE(out1, out1_expected); + EXPECT_TENSOR_CLOSE(out2, out2_expected); +} + +TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest4D) { + torch::executor::testing::TensorFactory tfFloat; + + exec_aten::Tensor input = + tfFloat.make({2, 3, 2, 2}, {0, 1, 4, 9, 16, 25, 36, 49, + 64, 81, 100, 121, 144, 169, 196, 225, + 256, 289, 324, 361, 400, 441, 484, 529}); + exec_aten::optional weight = + exec_aten::optional( + tfFloat.make({3}, {1.1, 0.7, 0.3})); + exec_aten::optional bias = + exec_aten::optional( + tfFloat.make({3}, {1.7, 2.2, 3.3})); + bool training = true; 
+ double momentum = 1e-3; + double eps = 1e-5; + exec_aten::Tensor out0 = tfFloat.zeros({2, 3, 2, 2}); + exec_aten::Tensor out1 = tfFloat.zeros({3}); + exec_aten::Tensor out2 = tfFloat.zeros({3}); + exec_aten::Tensor out0_expected = tfFloat.make( + {2, 3, 2, 2}, + {0.58849782, 0.60038555, 0.63604873, 0.69548732, 1.43804383, 1.48271883, + 1.53732157, 1.60185206, 2.96232486, 2.98921227, 3.01926303, 3.05247688, + 2.30033016, 2.59752321, 2.91849184, 3.26323581, 2.62937593, 2.79318404, + 2.96691990, 3.15058374, 3.49374819, 3.55859423, 3.62660384, 3.69777656}); + exec_aten::Tensor out1_expected = tfFloat.make({3}, {93.5, 169.5, 277.5}); + exec_aten::Tensor out2_expected = + tfFloat.make({3}, {0.01080702, 0.00709126, 0.00527206}); + op_native_batch_norm_legit_no_stats_out( + input, weight, bias, training, momentum, eps, out0, out1, out2); + EXPECT_TENSOR_CLOSE(out0, out0_expected); + EXPECT_TENSOR_CLOSE(out1, out1_expected); + EXPECT_TENSOR_CLOSE(out2, out2_expected); +} diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl index 140a71cbfe8..53698e7f216 100644 --- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl +++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl @@ -867,6 +867,7 @@ ATEN_OPS = ( op_target( name = "op_native_batch_norm", deps = [ + ":vec_ops", "//executorch/kernels/portable/cpu/util:normalization_ops_util", ], ),