From 75d76e9a986554faf59eef6fcbae991779b7f774 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 21 Jan 2025 16:02:29 -0800 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- kernels/portable/cpu/op_glu.cpp | 10 ++++------ kernels/test/op_glu_test.cpp | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/kernels/portable/cpu/op_glu.cpp b/kernels/portable/cpu/op_glu.cpp index 38bf8fd8db4..9374d17e86e 100644 --- a/kernels/portable/cpu/op_glu.cpp +++ b/kernels/portable/cpu/op_glu.cpp @@ -155,12 +155,10 @@ Tensor& glu_out( const size_t non_negative_dim = dim < 0 ? dim + self.dim() : dim; const auto in_dtype = self.scalar_type(); - ET_SWITCH_FLOAT_TYPES(in_dtype, ctx, "glu", CTYPE_IN, [&]() { - if (out.scalar_type() == ScalarType::Float) { - glu_out_tensor(self, non_negative_dim, out); - } else { - glu_out_tensor(self, non_negative_dim, out); - } + ET_SWITCH_FLOATHBF16_TYPES(in_dtype, ctx, "glu", CTYPE_IN, [&]() { + ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, "glu", CTYPE_OUT, [&]() { + glu_out_tensor(self, non_negative_dim, out); + }); }); return out; diff --git a/kernels/test/op_glu_test.cpp b/kernels/test/op_glu_test.cpp index ca5fb5f6f67..0501b3c7d10 100644 --- a/kernels/test/op_glu_test.cpp +++ b/kernels/test/op_glu_test.cpp @@ -128,6 +128,20 @@ TEST_F(OpGluOutTest, AllInputDoubleOutputSupport) { #undef TEST_ENTRY } +TEST_F(OpGluOutTest, AllInputHalfOutputSupport) { +#define TEST_ENTRY(ctype, dtype) \ + test_glu_out(); + ET_FORALL_FLOAT_TYPES(TEST_ENTRY); +#undef TEST_ENTRY +} + +TEST_F(OpGluOutTest, AllInputBFloat16OutputSupport) { +#define TEST_ENTRY(ctype, dtype) \ + test_glu_out(); + ET_FORALL_FLOAT_TYPES(TEST_ENTRY); +#undef TEST_ENTRY +} + TEST_F(OpGluOutTest, InfinityAndNANTest) { TensorFactory tf; const std::vector sizes = {4, 2}; From 447b0ae82497f25a03bd8e85213dfeb26a59ce4c Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 22 Jan 2025 09:55:04 -0800 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- kernels/test/op_glu_test.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/kernels/test/op_glu_test.cpp b/kernels/test/op_glu_test.cpp index 0501b3c7d10..f03ab40f115 100644 --- a/kernels/test/op_glu_test.cpp +++ b/kernels/test/op_glu_test.cpp @@ -117,28 +117,28 @@ class OpGluOutTest : public OperatorTest { TEST_F(OpGluOutTest, AllInputFloatOutputSupport) { #define TEST_ENTRY(ctype, dtype) \ test_glu_out(); - ET_FORALL_FLOAT_TYPES(TEST_ENTRY); + ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } TEST_F(OpGluOutTest, AllInputDoubleOutputSupport) { #define TEST_ENTRY(ctype, dtype) \ test_glu_out(); - ET_FORALL_FLOAT_TYPES(TEST_ENTRY); + ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } TEST_F(OpGluOutTest, AllInputHalfOutputSupport) { #define TEST_ENTRY(ctype, dtype) \ test_glu_out(); - ET_FORALL_FLOAT_TYPES(TEST_ENTRY); + ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } TEST_F(OpGluOutTest, AllInputBFloat16OutputSupport) { #define TEST_ENTRY(ctype, dtype) \ test_glu_out(); - ET_FORALL_FLOAT_TYPES(TEST_ENTRY); + ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY }