diff --git a/kernels/portable/cpu/op_glu.cpp b/kernels/portable/cpu/op_glu.cpp
index 38bf8fd8db4..9374d17e86e 100644
--- a/kernels/portable/cpu/op_glu.cpp
+++ b/kernels/portable/cpu/op_glu.cpp
@@ -155,12 +155,10 @@ Tensor& glu_out(
   const size_t non_negative_dim = dim < 0 ? dim + self.dim() : dim;
   const auto in_dtype = self.scalar_type();
 
-  ET_SWITCH_FLOAT_TYPES(in_dtype, ctx, "glu", CTYPE_IN, [&]() {
-    if (out.scalar_type() == ScalarType::Float) {
-      glu_out_tensor<CTYPE_IN, float>(self, non_negative_dim, out);
-    } else {
-      glu_out_tensor<CTYPE_IN, double>(self, non_negative_dim, out);
-    }
+  ET_SWITCH_FLOATHBF16_TYPES(in_dtype, ctx, "glu", CTYPE_IN, [&]() {
+    ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, "glu", CTYPE_OUT, [&]() {
+      glu_out_tensor<CTYPE_IN, CTYPE_OUT>(self, non_negative_dim, out);
+    });
   });
 
   return out;
diff --git a/kernels/test/op_glu_test.cpp b/kernels/test/op_glu_test.cpp
index ca5fb5f6f67..f03ab40f115 100644
--- a/kernels/test/op_glu_test.cpp
+++ b/kernels/test/op_glu_test.cpp
@@ -117,14 +117,28 @@ class OpGluOutTest : public OperatorTest {
 TEST_F(OpGluOutTest, AllInputFloatOutputSupport) {
 #define TEST_ENTRY(ctype, dtype) \
   test_glu_out<ScalarType::dtype, ScalarType::Float>();
-  ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
 #undef TEST_ENTRY
 }
 
 TEST_F(OpGluOutTest, AllInputDoubleOutputSupport) {
 #define TEST_ENTRY(ctype, dtype) \
   test_glu_out<ScalarType::dtype, ScalarType::Double>();
-  ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+}
+
+TEST_F(OpGluOutTest, AllInputHalfOutputSupport) {
+#define TEST_ENTRY(ctype, dtype) \
+  test_glu_out<ScalarType::dtype, ScalarType::Half>();
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+}
+
+TEST_F(OpGluOutTest, AllInputBFloat16OutputSupport) {
+#define TEST_ENTRY(ctype, dtype) \
+  test_glu_out<ScalarType::dtype, ScalarType::BFloat16>();
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
 #undef TEST_ENTRY
 }
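
Note on the dispatch pattern above (a sketch for review context, not part of the change): ET_SWITCH_FLOATHBF16_TYPES dispatches on a runtime ScalarType covering Float, Double, Half, and BFloat16, binding the matching C type to the given alias (CTYPE_IN / CTYPE_OUT) inside the lambda it receives. Nesting two of these switches therefore instantiates glu_out_tensor<CTYPE_IN, CTYPE_OUT> for every input/output dtype pair, replacing the old float-or-double branch. A minimal plain-C++ sketch of that pattern follows; the Half/BF16 structs and the switch_* helpers are hypothetical stand-ins for illustration (the real kernels use executorch's Half/BFloat16 types, and the real macros also report unsupported dtypes through ctx):

#include <cstdint>
#include <cstdio>

enum class ScalarType { Float, Double, Half, BFloat16 };

struct Half { uint16_t bits; };  // hypothetical stand-in for executorch's Half
struct BF16 { uint16_t bits; };  // hypothetical stand-in for executorch's BFloat16

// Stand-in for glu_out_tensor<CTYPE_IN, CTYPE_OUT>(self, dim, out): the real
// kernel splits the input along `dim` and writes a * sigmoid(b) to `out`.
template <typename CTYPE_IN, typename CTYPE_OUT>
void glu_out_tensor_sketch() {
  std::printf("instantiated: %zu-byte input, %zu-byte output\n",
              sizeof(CTYPE_IN), sizeof(CTYPE_OUT));
}

// Inner switch: CTYPE_IN is already fixed; resolve the output dtype.
template <typename CTYPE_IN>
void switch_out_dtype(ScalarType out_dtype) {
  switch (out_dtype) {
    case ScalarType::Float:    glu_out_tensor_sketch<CTYPE_IN, float>();  break;
    case ScalarType::Double:   glu_out_tensor_sketch<CTYPE_IN, double>(); break;
    case ScalarType::Half:     glu_out_tensor_sketch<CTYPE_IN, Half>();   break;
    case ScalarType::BFloat16: glu_out_tensor_sketch<CTYPE_IN, BF16>();   break;
  }
}

// Outer switch: resolve the input dtype, then dispatch on the output dtype,
// yielding all 4 x 4 = 16 template instantiations of the kernel.
void switch_in_dtype(ScalarType in_dtype, ScalarType out_dtype) {
  switch (in_dtype) {
    case ScalarType::Float:    switch_out_dtype<float>(out_dtype);  break;
    case ScalarType::Double:   switch_out_dtype<double>(out_dtype); break;
    case ScalarType::Half:     switch_out_dtype<Half>(out_dtype);   break;
    case ScalarType::BFloat16: switch_out_dtype<BF16>(out_dtype);   break;
  }
}

int main() {
  switch_in_dtype(ScalarType::BFloat16, ScalarType::Float);  // e.g. bf16 in, f32 out
  return 0;
}

This is also why the tests grow the way they do: each AllInput*OutputSupport test pins one output dtype and iterates the input dtype over the same four types via ET_FORALL_FLOATHBF16_TYPES, covering the full 16-combination matrix.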