From 53393a03398a05aa457dece2e1b345d720d5b4db Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 22 Jan 2025 15:36:37 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- kernels/portable/cpu/op_relu.cpp | 8 ++++++-- kernels/test/op_relu_test.cpp | 8 ++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/kernels/portable/cpu/op_relu.cpp b/kernels/portable/cpu/op_relu.cpp index 2ec258e2c47..e8c265fba4d 100644 --- a/kernels/portable/cpu/op_relu.cpp +++ b/kernels/portable/cpu/op_relu.cpp @@ -33,12 +33,16 @@ Tensor& relu_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { ET_KERNEL_CHECK( ctx, tensors_have_same_shape_and_dtype(in, out), InvalidArgument, out); - ET_KERNEL_CHECK(ctx, tensor_is_real_type(out), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensor_is_realhbf16_type(out), + InvalidArgument, + out); ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); - ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "relu.out", CTYPE, [&]() { + ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "relu.out", CTYPE, [&]() { apply_unary_map_fn( [](const CTYPE val_in) { return (std::isnan(val_in) || val_in >= CTYPE(0)) ? val_in : CTYPE(0); diff --git a/kernels/test/op_relu_test.cpp b/kernels/test/op_relu_test.cpp index 042cf13ad58..23709046897 100644 --- a/kernels/test/op_relu_test.cpp +++ b/kernels/test/op_relu_test.cpp @@ -82,6 +82,14 @@ TEST_F(OpReluTest, DoubleTensors) { test_relu_execution_floats(); } +TEST_F(OpReluTest, HalfTensors) { + test_relu_execution_floats(); +} + +TEST_F(OpReluTest, BFloat16Tensors) { + test_relu_execution_floats(); +} + TEST_F(OpReluTest, ByteTensors) { TensorFactory tf;