diff --git a/kernels/portable/cpu/op_native_group_norm.cpp b/kernels/portable/cpu/op_native_group_norm.cpp index b13e6c2e5ed..7882204e57e 100644 --- a/kernels/portable/cpu/op_native_group_norm.cpp +++ b/kernels/portable/cpu/op_native_group_norm.cpp @@ -30,7 +30,7 @@ void group_norm( int64_t sC, int64_t sHxW, int64_t group, - CTYPE eps, + double eps, Tensor& out, Tensor& mean, Tensor& rstd) { @@ -77,37 +77,43 @@ void group_norm( const CTYPE* x = input_data + i * inner_size; // compute E[X] and Var[x] = E[x^2] - E[x]^2 - CTYPE sum = reduce_add(x, inner_size); - CTYPE sq_sum = vec_powerf(x, inner_size); - CTYPE mean_value = sum / inner_size; - CTYPE variance = sq_sum / inner_size - mean_value * mean_value; - CTYPE std = std::sqrt(variance + eps); - CTYPE rstd_value = 1.0 / std; + CTYPE sum = reduce_add(x, static_cast(inner_size)); + CTYPE sq_sum = vec_powerf(x, static_cast(inner_size)); + double mean_value = + static_cast(sum) / static_cast(inner_size); + double variance = + static_cast(sq_sum) / static_cast(inner_size) - + mean_value * mean_value; + double std = std::sqrt(variance + eps); + double rstd_value = 1.0 / std; // Calculate the elements of output if (weight_data == nullptr && bias_data == nullptr) { CTYPE* y = out_data + i * inner_size; for (const auto j : c10::irange(inner_size)) { - y[j] = (x[j] - mean_value) * rstd_value; + y[j] = static_cast( + (static_cast(x[j]) - mean_value) * rstd_value); } } else { const size_t g = i % G; for (const auto j : c10::irange(D)) { const size_t ch = g * D + j; - const CTYPE scale = - rstd_value * (weight_data == nullptr ? 1.0 : weight_data[ch]); - const CTYPE beta = - -scale * mean_value + (bias_data == nullptr ? 0.0 : bias_data[ch]); + const double scale = rstd_value * + (weight_data == nullptr ? double(1.0) + : static_cast(weight_data[ch])); + const double beta = -scale * mean_value + + (bias_data == nullptr ? 
double(0.0) + : static_cast(bias_data[ch])); x = input_data + (i * D + j) * HxW; CTYPE* y = out_data + (i * D + j) * HxW; for (const auto k : c10::irange(HxW)) { - y[k] = scale * x[k] + beta; + y[k] = static_cast(scale * static_cast(x[k]) + beta); } } } - mean_data[i] = mean_value; - rstd_data[i] = rstd_value; + mean_data[i] = static_cast(mean_value); + rstd_data[i] = static_cast(rstd_value); } } @@ -186,7 +192,7 @@ std::tuple native_group_norm_out( constexpr auto name = "native_group_norm.out"; - ET_SWITCH_FLOAT_TYPES(input.scalar_type(), ctx, name, CTYPE, [&]() { + ET_SWITCH_FLOATHBF16_TYPES(input.scalar_type(), ctx, name, CTYPE, [&]() { group_norm( input, weight, bias, N, C, HxW, group, eps, out, mean_out, rstd_out); }); diff --git a/kernels/test/op_native_group_norm_test.cpp b/kernels/test/op_native_group_norm_test.cpp index e196899fbca..591df6e186b 100644 --- a/kernels/test/op_native_group_norm_test.cpp +++ b/kernels/test/op_native_group_norm_test.cpp @@ -20,110 +20,319 @@ using executorch::aten::Tensor; using std::optional; using torch::executor::testing::TensorFactory; -::std::tuple op_native_group_norm_out( - const Tensor& input, - const optional& weight, - const optional& bias, - int64_t N, - int64_t C, - int64_t HxW, - int64_t group, - double eps, - Tensor& out0, - Tensor& out1, - Tensor& out2) { - executorch::ET_RUNTIME_NAMESPACE::KernelRuntimeContext context{}; - return torch::executor::aten::native_group_norm_outf( - context, input, weight, bias, N, C, HxW, group, eps, out0, out1, out2); +class OpNativeGroupNormTest : public OperatorTest { + protected: + ::std::tuple op_native_group_norm_out( + const Tensor& input, + const optional& weight, + const optional& bias, + int64_t N, + int64_t C, + int64_t HxW, + int64_t group, + double eps, + Tensor& out0, + Tensor& out1, + Tensor& out2) { + return torch::executor::aten::native_group_norm_outf( + context_, input, weight, bias, N, C, HxW, group, eps, out0, out1, out2); + } + + template + struct 
NativeGroupNormTestCase { + using ctype = typename TensorFactory::ctype; + + // Size vector for the input/output + const std::vector sizes; + // Data for the input tensor; must agree with `sizes`. + const std::vector input_data; + // Affine transform weight. + const std::vector weight_data; + // Affine transform bias. + const std::vector bias_data; + // Batch size N + const int64_t N; + // Number of channels C + const int64_t C; + // Spatial size HxW + const int64_t HxW; + // Number of groups + const int64_t group; + // a value added to the denominator for numerical stability + const ctype eps; + // The expected output data. + const std::vector expected_data; + // The expected mean data. + const std::vector expected_mean_data; + // The expected rstd data. + const std::vector expected_rstd_data; + }; + + /// Runs the provided test cases. + template + void run_test_cases(std::vector> test_cases) { + TensorFactory tf; + for (const auto& test_case : test_cases) { + Tensor in = tf.make(test_case.sizes, test_case.input_data); + optional weight = + tf.make({static_cast(test_case.C)}, test_case.weight_data); + optional bias = + tf.make({static_cast(test_case.C)}, test_case.bias_data); + Tensor out0 = tf.zeros(test_case.sizes); + Tensor out1 = tf.zeros( + {static_cast(test_case.N), + static_cast(test_case.group)}, + torch::executor::TensorShapeDynamism::DYNAMIC_BOUND); + Tensor out2 = tf.zeros( + {static_cast(test_case.N), + static_cast(test_case.group)}, + torch::executor::TensorShapeDynamism::DYNAMIC_BOUND); + + auto result = op_native_group_norm_out( + in, + weight, + bias, + test_case.N, + test_case.C, + test_case.HxW, + test_case.group, + test_case.eps, + out0, + out1, + out2); + EXPECT_TENSOR_CLOSE(out0, std::get<0>(result)); + + Tensor expected = tf.make(test_case.sizes, test_case.expected_data); + Tensor expected_mean = tf.make( + {static_cast(test_case.N), + static_cast(test_case.group)}, + test_case.expected_mean_data); + Tensor expected_rstd = tf.make( + 
{static_cast(test_case.N), + static_cast(test_case.group)}, + test_case.expected_rstd_data); + + if constexpr (DTYPE == ScalarType::Half) { + EXPECT_TENSOR_CLOSE_WITH_TOL( + out0, + expected, + 1e-2, + executorch::runtime::testing::internal::kDefaultHalfAtol); + EXPECT_TENSOR_CLOSE_WITH_TOL( + out1, + expected_mean, + 1e-2, + executorch::runtime::testing::internal::kDefaultHalfAtol); + EXPECT_TENSOR_CLOSE_WITH_TOL( + out2, + expected_rstd, + 1e-2, + executorch::runtime::testing::internal::kDefaultHalfAtol); + } else if constexpr (DTYPE == ScalarType::BFloat16) { + EXPECT_TENSOR_CLOSE_WITH_TOL( + out0, + expected, + 1e-2, + executorch::runtime::testing::internal::kDefaultBFloat16Atol); + EXPECT_TENSOR_CLOSE_WITH_TOL( + out1, + expected_mean, + 1e-2, + executorch::runtime::testing::internal::kDefaultBFloat16Atol); + EXPECT_TENSOR_CLOSE_WITH_TOL( + out2, + expected_rstd, + 1e-2, + executorch::runtime::testing::internal::kDefaultBFloat16Atol); + } else { + EXPECT_TENSOR_CLOSE(out0, expected); + EXPECT_TENSOR_CLOSE(out1, expected_mean); + EXPECT_TENSOR_CLOSE(out2, expected_rstd); + } + } + } + + template + void run_floating_point_test_cases() { + std::vector> test_cases = { + { + {5, 6, 2, 2}, // sizes + {-0.8125, 0.0625, -2.7500, -3.0625, -1.1250, -2.1250, -1.3125, + -4.0625, 2.8125, -2.0625, 4.2500, 3.5000, -0.3750, 1.6250, + 4.3125, -1.0625, -2.8750, 3.3750, 4.9375, 4.0625, -3.0625, + -1.8750, -2.7500, -2.5625, -0.1875, -3.0000, -2.7500, 0.6875, + -3.2500, -3.1875, 1.0000, -4.6250, -0.1875, -1.7500, 4.5000, + -1.8750, -2.6875, 4.8125, -3.8125, -2.9375, -1.1875, 2.8750, + 0.7500, 2.8750, 1.1250, -0.6250, -2.2500, -3.7500, 3.2500, + -0.3750, -2.0625, -4.7500, 2.0625, 3.0000, -3.1875, -4.1250, + -3.7500, 1.2500, -2.3125, 1.5625, 3.1250, 0.3125, 3.2500, + -2.7500, -3.8125, -4.2500, -4.3125, -0.5625, -0.4375, 2.9375, + -1.3750, -0.6250, -2.5625, -4.5625, 0.1250, -3.5000, -5.0000, + -1.0000, -4.6875, -0.6875, 1.1250, 1.8750, -4.5000, 4.3125, + 4.5625, 0.2500, -3.6250, 
4.5625, -3.5000, -2.1250, -3.6250, + -2.9375, 3.6875, 3.9375, 4.3750, 3.0625, 2.4375, 2.0625, + -2.4375, -3.9375, 3.6875, 2.7500, -0.8750, -0.9375, 2.7500, + -2.4375, -2.3750, -0.9375, -4.8750, 0.1875, 3.5000, -2.0000, + -0.2500, -2.7500, 0.3125, 1.2500, -0.5625, 0.0000, 1.8125, + 1.0625}, // input_data + {4.5625, -2.8750, -0.6875, 0.5625, -2.0625, -2.7500}, // weight_data + {-0.5000, -2.7500, 1.1875, 3.6875, 3.8125, 4.6875}, // bias_data + 5, // N + 6, // C + 4, // HxW + 3, // group + 1e-5, // eps + {3.419882, 6.578348, -3.573864, -4.701888, + -4.509254, -2.234663, -4.082768, 2.172355, + 0.838826, 2.270225, 0.416747, 0.636962, + 3.207030, 3.687500, 4.333131, 3.041869, + 5.547079, 1.649148, 0.674665, 1.220376, + 7.156189, 6.168714, 6.896327, 6.740410, + 3.509863, -3.022041, -2.441427, 5.542011, + -0.794903, -0.886369, -7.014627, 1.217361, + 1.120617, 1.463606, 0.091652, 1.491045, + 3.293219, 4.640229, 3.091168, 3.248319, + 4.895990, 1.114683, 3.092597, 1.114683, + 3.262238, 5.434066, 7.450763, 9.312329, + 5.570122, 0.101119, -2.444796, -6.499403, + -5.446074, -6.337338, -0.454995, 0.436269, + 2.228491, 0.871598, 1.838385, 0.786793, + 4.362284, 3.737805, 4.390039, 3.057817, + 5.814659, 6.202621, 6.258044, 2.932658, + 3.366583, -0.623879, 4.475045, 3.588276, + -0.082914, -4.936279, 6.438795, -2.357929, + 0.714463, -5.402106, 0.236606, -5.879963, + 1.176247, 1.021916, 2.333727, 0.520341, + 4.275447, 3.549392, 2.896994, 4.275447, + 6.120910, 5.298480, 6.195676, 5.784461, + 2.033296, 1.833920, 1.485010, 2.531738, + 3.193988, 2.532378, -5.406940, -8.053379, + -6.467402, -5.425139, -1.395059, -1.325575, + 0.266062, 1.622680, 1.606336, 1.230405, + 2.809896, 3.893110, 4.601880, 3.425055, + 4.374411, 8.283354, 3.494898, 2.029045, + 6.088204, 4.915522, 1.136877, 2.700454}, // expected_data + {-1.89843750, + 1.62500000, + -0.09375000, + -1.91406250, + -0.49218744, + -0.02343750, + -0.77343756, + 0.08593753, + -1.55468738, + -2.73437500, + 1.07031238, + 0.35937503, + 
0.34374997, + -0.77343750, + 0.10937499}, // expected_mean_data + {0.79116172, + 0.42708409, + 0.30238494, + 0.50903118, + 0.31929117, + 0.45128885, + 0.33067191, + 0.39473253, + 0.42994878, + 0.53187561, + 0.29930803, + 0.29000264, + 0.38669431, + 0.38038814, + 0.75809801}, // expected_rstd_data + }, + }; + + run_test_cases(test_cases); + } + + // Runs death test cases. + template + void run_death_test_cases( + std::vector> test_cases) { + TensorFactory tf; + for (const auto& test_case : test_cases) { + Tensor in = tf.make(test_case.sizes, test_case.input_data); + std::optional weight, bias; + if (!test_case.weight_data.empty()) { + weight = + tf.make({static_cast(test_case.C)}, test_case.weight_data); + } + if (!test_case.bias_data.empty()) { + bias = + tf.make({static_cast(test_case.C)}, test_case.bias_data); + } + Tensor out0 = tf.zeros(test_case.sizes); + Tensor out1 = tf.zeros( + {static_cast(test_case.N), + static_cast(test_case.group)}); + Tensor out2 = tf.zeros( + {static_cast(test_case.N), + static_cast(test_case.group)}); + + ET_EXPECT_KERNEL_FAILURE( + context_, + op_native_group_norm_out( + in, + weight, + bias, + test_case.N, + test_case.C, + test_case.HxW, + test_case.group, + test_case.eps, + out0, + out1, + out2)); + } + } + + // Test cases with incompatible types. + template + void run_int_test_cases() { + std::vector> test_cases = { + { + // Cannot be represented by a type other than float. 
+ {2, 4, 2, 2}, // sizes + {1, 0, -1, -1, 4, 0, 2, -2, 1, 0, -1, -1, 4, 0, 2, -2, 1, + 0, -1, -1, 4, 0, 2, -2, 1, 0, -1, -1, 4, 0, 2, -2}, // input_data + {1, 1, 1, 1}, // weights + {0, 0, 0, 0}, // bias + 2, // N + 4, // C + 4, // HxW + 2, // group + 1, // eps + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, // expected_data + {0, 0, 0, 0}, // expected_mean_data + {1, 1, 1, 1}, // expected_rstd_data + }, + }; + run_death_test_cases(test_cases); + } +}; + +/// Describes a test case, using tensors of the specified DTYPE. +TEST_F(OpNativeGroupNormTest, DoubleTensors) { + run_floating_point_test_cases(); +} + +TEST_F(OpNativeGroupNormTest, FloatTensors) { + run_floating_point_test_cases(); +} + +TEST_F(OpNativeGroupNormTest, HalfTensors) { + run_floating_point_test_cases(); +} + +TEST_F(OpNativeGroupNormTest, BFloat16Tensors) { + run_floating_point_test_cases(); } -TEST(OpNativeGroupNormOutTest, SmokeTest) { - TensorFactory tfFloat; - - Tensor input = tfFloat.make( - {5, 6, 2, 2}, - {-0.8125, 0.0625, -2.7500, -3.0625, -1.1250, -2.1250, -1.3125, -4.0625, - 2.8125, -2.0625, 4.2500, 3.5000, -0.3750, 1.6250, 4.3125, -1.0625, - -2.8750, 3.3750, 4.9375, 4.0625, -3.0625, -1.8750, -2.7500, -2.5625, - -0.1875, -3.0000, -2.7500, 0.6875, -3.2500, -3.1875, 1.0000, -4.6250, - -0.1875, -1.7500, 4.5000, -1.8750, -2.6875, 4.8125, -3.8125, -2.9375, - -1.1875, 2.8750, 0.7500, 2.8750, 1.1250, -0.6250, -2.2500, -3.7500, - 3.2500, -0.3750, -2.0625, -4.7500, 2.0625, 3.0000, -3.1875, -4.1250, - -3.7500, 1.2500, -2.3125, 1.5625, 3.1250, 0.3125, 3.2500, -2.7500, - -3.8125, -4.2500, -4.3125, -0.5625, -0.4375, 2.9375, -1.3750, -0.6250, - -2.5625, -4.5625, 0.1250, -3.5000, -5.0000, -1.0000, -4.6875, -0.6875, - 1.1250, 1.8750, -4.5000, 4.3125, 4.5625, 0.2500, -3.6250, 4.5625, - -3.5000, -2.1250, -3.6250, -2.9375, 3.6875, 3.9375, 4.3750, 3.0625, - 2.4375, 2.0625, -2.4375, -3.9375, 3.6875, 2.7500, -0.8750, -0.9375, - 2.7500, -2.4375, -2.3750, 
-0.9375, -4.8750, 0.1875, 3.5000, -2.0000, - -0.2500, -2.7500, 0.3125, 1.2500, -0.5625, 0.0000, 1.8125, 1.0625}); - optional weight = - tfFloat.make({6}, {4.5625, -2.8750, -0.6875, 0.5625, -2.0625, -2.7500}); - optional bias = - tfFloat.make({6}, {-0.5000, -2.7500, 1.1875, 3.6875, 3.8125, 4.6875}); - double eps = 1e-5; - Tensor out0 = tfFloat.zeros({5, 6, 2, 2}); - Tensor out1 = tfFloat.zeros({5, 3}); - Tensor out2 = tfFloat.zeros({5, 3}); - Tensor out0_expected = tfFloat.make( - {5, 6, 2, 2}, - {3.419882, 6.578348, -3.573864, -4.701888, -4.509254, -2.234663, - -4.082768, 2.172355, 0.838826, 2.270225, 0.416747, 0.636962, - 3.207030, 3.687500, 4.333131, 3.041869, 5.547079, 1.649148, - 0.674665, 1.220376, 7.156189, 6.168714, 6.896327, 6.740410, - 3.509863, -3.022041, -2.441427, 5.542011, -0.794903, -0.886369, - -7.014627, 1.217361, 1.120617, 1.463606, 0.091652, 1.491045, - 3.293219, 4.640229, 3.091168, 3.248319, 4.895990, 1.114683, - 3.092597, 1.114683, 3.262238, 5.434066, 7.450763, 9.312329, - 5.570122, 0.101119, -2.444796, -6.499403, -5.446074, -6.337338, - -0.454995, 0.436269, 2.228491, 0.871598, 1.838385, 0.786793, - 4.362284, 3.737805, 4.390039, 3.057817, 5.814659, 6.202621, - 6.258044, 2.932658, 3.366583, -0.623879, 4.475045, 3.588276, - -0.082914, -4.936279, 6.438795, -2.357929, 0.714463, -5.402106, - 0.236606, -5.879963, 1.176247, 1.021916, 2.333727, 0.520341, - 4.275447, 3.549392, 2.896994, 4.275447, 6.120910, 5.298480, - 6.195676, 5.784461, 2.033296, 1.833920, 1.485010, 2.531738, - 3.193988, 2.532378, -5.406940, -8.053379, -6.467402, -5.425139, - -1.395059, -1.325575, 0.266062, 1.622680, 1.606336, 1.230405, - 2.809896, 3.893110, 4.601880, 3.425055, 4.374411, 8.283354, - 3.494898, 2.029045, 6.088204, 4.915522, 1.136877, 2.700454}); - Tensor out1_expected = tfFloat.make( - {5, 3}, - {-1.89843750, - 1.62500000, - -0.09375000, - -1.91406250, - -0.49218744, - -0.02343750, - -0.77343756, - 0.08593753, - -1.55468738, - -2.73437500, - 1.07031238, - 0.35937503, - 
0.34374997, - -0.77343750, - 0.10937499}); - Tensor out2_expected = tfFloat.make( - {5, 3}, - {0.79116172, - 0.42708409, - 0.30238494, - 0.50903118, - 0.31929117, - 0.45128885, - 0.33067191, - 0.39473253, - 0.42994878, - 0.53187561, - 0.29930803, - 0.29000264, - 0.38669431, - 0.38038814, - 0.75809801}); - op_native_group_norm_out( - input, weight, bias, 5, 6, 4, 3, eps, out0, out1, out2); - EXPECT_TENSOR_CLOSE(out0, out0_expected); - EXPECT_TENSOR_CLOSE(out1, out1_expected); - EXPECT_TENSOR_CLOSE(out2, out2_expected); +TEST_F(OpNativeGroupNormTest, IntTensorsDies) { + // Cannot be represented by a type other than float. + run_int_test_cases(); }