From 11c8ca797d0fdcc536fd40e0c2ccef3c8f6d1194 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 22 Jan 2025 16:05:36 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- kernels/portable/cpu/op_roll.cpp | 2 +- kernels/test/op_roll_test.cpp | 30 +++++++++++++++++++----------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/kernels/portable/cpu/op_roll.cpp b/kernels/portable/cpu/op_roll.cpp index ade564c300b..ee735758c52 100644 --- a/kernels/portable/cpu/op_roll.cpp +++ b/kernels/portable/cpu/op_roll.cpp @@ -81,7 +81,7 @@ Tensor& roll_out( constexpr auto name = "roll.out"; - ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] { + ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] { const CTYPE* in_data = in.const_data_ptr(); CTYPE* out_data = out.mutable_data_ptr(); diff --git a/kernels/test/op_roll_test.cpp b/kernels/test/op_roll_test.cpp index 16e09ec83f5..e6cf2c43e11 100644 --- a/kernels/test/op_roll_test.cpp +++ b/kernels/test/op_roll_test.cpp @@ -37,18 +37,26 @@ class OpRollOutTest : public ::testing::Test { // first. torch::executor::runtime_init(); } + + template + void test_dtype() { + TensorFactory tf; + + Tensor input = tf.make({4, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + int64_t shifts_data[2] = {2, 1}; + ArrayRef shifts = ArrayRef(shifts_data, 2); + int64_t dims_data[2] = {0, 1}; + ArrayRef dims = ArrayRef(dims_data, 2); + Tensor out = tf.zeros({4, 2}); + Tensor out_expected = tf.make({4, 2}, {6, 5, 8, 7, 2, 1, 4, 3}); + op_roll_out(input, shifts, dims, out); + EXPECT_TENSOR_CLOSE(out, out_expected); + } }; TEST_F(OpRollOutTest, SmokeTest) { - TensorFactory tfFloat; - - Tensor input = tfFloat.make({4, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - int64_t shifts_data[2] = {2, 1}; - ArrayRef shifts = ArrayRef(shifts_data, 2); - int64_t dims_data[2] = {0, 1}; - ArrayRef dims = ArrayRef(dims_data, 2); - Tensor out = tfFloat.zeros({4, 2}); - Tensor out_expected = tfFloat.make({4, 2}, {6, 5, 8, 7, 2, 1, 4, 3}); - op_roll_out(input, shifts, dims, out); - EXPECT_TENSOR_CLOSE(out, out_expected); +#define TEST_ENTRY(ctype, dtype) test_dtype(); + // TODO: enable bool test after #7856 lands. + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); +#undef TEST_ENTRY }