diff --git a/kernels/test/op_glu_test.cpp b/kernels/test/op_glu_test.cpp index f03ab40f115..63e06da4c16 100644 --- a/kernels/test/op_glu_test.cpp +++ b/kernels/test/op_glu_test.cpp @@ -28,6 +28,19 @@ class OpGluOutTest : public OperatorTest { return torch::executor::aten::glu_outf(context_, self, dim, out); } + template + void expect_tensor_close(Tensor actual, Tensor expected) { + if (DTYPE == ScalarType::Half || DTYPE == ScalarType::BFloat16) { + EXPECT_TENSOR_CLOSE_WITH_TOL( + actual, + expected, + 1e-2, + executorch::runtime::testing::internal::kDefaultAtol); + } else { + EXPECT_TENSOR_CLOSE(actual, expected); + } + } + // Common testing for glu operator template void test_glu_out() { @@ -41,14 +54,14 @@ class OpGluOutTest : public OperatorTest { Tensor in = tf.ones(sizes); Tensor out = tf_out.zeros(out_sizes_1); op_glu_out(in, 0, out); - EXPECT_TENSOR_CLOSE( + expect_tensor_close( out, tf_out.make( out_sizes_1, /*data=*/{0.731059, 0.731059, 0.731059, 0.731059})); const std::vector out_sizes_2 = {4, 1}; out = tf_out.zeros(out_sizes_2); op_glu_out(in, 1, out); - EXPECT_TENSOR_CLOSE( + expect_tensor_close( out, tf_out.make( out_sizes_2, /*data=*/{0.731059, 0.731059, 0.731059, 0.731059}));