Skip to content

Commit

Permalink
simd: add floor, ceil, round, trunc operations (kokkos#6393)
Browse files Browse the repository at this point in the history
Added simd floor, ceil, round, trunc for all types
  • Loading branch information
ldh4 committed Oct 12, 2023
1 parent 1095b64 commit c586fa1
Show file tree
Hide file tree
Showing 7 changed files with 1,495 additions and 768 deletions.
373 changes: 271 additions & 102 deletions simd/src/Kokkos_SIMD_AVX2.hpp

Large diffs are not rendered by default.

1,348 changes: 786 additions & 562 deletions simd/src/Kokkos_SIMD_AVX512.hpp

Large diffs are not rendered by default.

15 changes: 14 additions & 1 deletion simd/src/Kokkos_SIMD_Common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,21 @@ template <class T, class Abi>
return a == simd_mask<T, Abi>(false);
}

} // namespace Experimental
// Temporary device-callable implementation of rounding in which a value lying
// exactly halfway between two integers goes to the even one.
//
// x: the value to round.
// Returns Kokkos::ceil(x) or Kokkos::floor(x) when x is equally far from both
// (whichever leaves no remainder modulo two), and Kokkos::round(x) otherwise.
template <typename T>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION auto round_half_to_nearest_even(
    T const& x) {
  auto upper = Kokkos::ceil(x);
  auto lower = Kokkos::floor(x);

  // This branch is also taken when both distances are zero.
  bool const is_tie = Kokkos::abs(upper - x) == Kokkos::abs(lower - x);
  if (!is_tie) {
    // One neighbour is strictly closer: defer to the regular rounding.
    return Kokkos::round(x);
  }

  // Tie: keep the upper neighbour only if it is a multiple of two.
  auto parity = Kokkos::remainder(upper, 2.0);
  if (parity == 0) {
    return upper;
  }
  return lower;
}

} // namespace Experimental
} // namespace Kokkos

#endif
339 changes: 260 additions & 79 deletions simd/src/Kokkos_SIMD_NEON.hpp

Large diffs are not rendered by default.

32 changes: 32 additions & 0 deletions simd/src/Kokkos_SIMD_Scalar.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,38 @@ template <class T>
return a;
}

// floor overload for the single-lane (scalar ABI) simd type.
//
// a: scalar-ABI simd whose only element, a[0], is processed.
// Returns a scalar-ABI simd constructed from Kokkos::floor applied to a[0].
// The element type of the result is T when T is a floating-point type; for any
// other T the element is first converted to double and the result is a simd of
// double.
template <typename T>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION auto floor(
    Experimental::simd<T, Experimental::simd_abi::scalar> const& a) {
  // Non-floating-point element types are evaluated in double.
  using data_type = std::conditional_t<std::is_floating_point_v<T>, T, double>;
  using result_type =
      Experimental::simd<data_type, Experimental::simd_abi::scalar>;
  return result_type(Kokkos::floor(static_cast<data_type>(a[0])));
}

// ceil overload for the single-lane (scalar ABI) simd type.
//
// a: scalar-ABI simd whose only element, a[0], is processed.
// Returns a scalar-ABI simd constructed from Kokkos::ceil applied to a[0].
// The element type of the result is T when T is a floating-point type; for any
// other T the element is first converted to double and the result is a simd of
// double.
template <typename T>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION auto ceil(
    Experimental::simd<T, Experimental::simd_abi::scalar> const& a) {
  // Non-floating-point element types are evaluated in double.
  using data_type = std::conditional_t<std::is_floating_point_v<T>, T, double>;
  using result_type =
      Experimental::simd<data_type, Experimental::simd_abi::scalar>;
  return result_type(Kokkos::ceil(static_cast<data_type>(a[0])));
}

// round overload for the single-lane (scalar ABI) simd type.
//
// a: scalar-ABI simd whose only element, a[0], is processed.
// Returns a scalar-ABI simd constructed from
// Experimental::round_half_to_nearest_even applied to a[0] (not Kokkos::round).
// The element type of the result is T when T is a floating-point type; for any
// other T the element is first converted to double and the result is a simd of
// double.
template <typename T>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION auto round(
    Experimental::simd<T, Experimental::simd_abi::scalar> const& a) {
  // Non-floating-point element types are evaluated in double.
  using data_type = std::conditional_t<std::is_floating_point_v<T>, T, double>;
  using result_type =
      Experimental::simd<data_type, Experimental::simd_abi::scalar>;
  return result_type(
      Experimental::round_half_to_nearest_even(static_cast<data_type>(a[0])));
}

// trunc overload for the single-lane (scalar ABI) simd type.
//
// a: scalar-ABI simd whose only element, a[0], is processed.
// Returns a scalar-ABI simd constructed from Kokkos::trunc applied to a[0].
// The element type of the result is T when T is a floating-point type; for any
// other T the element is first converted to double and the result is a simd of
// double.
template <typename T>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION auto trunc(
    Experimental::simd<T, Experimental::simd_abi::scalar> const& a) {
  // Non-floating-point element types are evaluated in double.
  using data_type = std::conditional_t<std::is_floating_point_v<T>, T, double>;
  using result_type =
      Experimental::simd<data_type, Experimental::simd_abi::scalar>;
  return result_type(Kokkos::trunc(static_cast<data_type>(a[0])));
}

template <class T>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION
Experimental::simd<T, Experimental::simd_abi::scalar>
Expand Down
80 changes: 80 additions & 0 deletions simd/unit_tests/include/SIMDTesting_Ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,86 @@ class absolutes {
}
};

// Test functor for the floor operation. It exposes the four entry points used
// by the unary math-op checkers: the operation under test (on_host/on_device)
// and the reference evaluation (on_host_serial/on_device_serial). All four
// forward their argument to Kokkos::floor.
class floors {
 public:
  // Host: operation under test.
  template <typename Arg>
  auto on_host(Arg const& value) const {
    return Kokkos::floor(value);
  }

  // Host: reference result.
  template <typename Arg>
  auto on_host_serial(Arg const& value) const {
    return Kokkos::floor(value);
  }

  // Device: operation under test.
  template <typename Arg>
  KOKKOS_INLINE_FUNCTION auto on_device(Arg const& value) const {
    return Kokkos::floor(value);
  }

  // Device: reference result.
  template <typename Arg>
  KOKKOS_INLINE_FUNCTION auto on_device_serial(Arg const& value) const {
    return Kokkos::floor(value);
  }
};

// Test functor for the ceil operation. It exposes the four entry points used
// by the unary math-op checkers: the operation under test (on_host/on_device)
// and the reference evaluation (on_host_serial/on_device_serial). All four
// forward their argument to Kokkos::ceil.
class ceils {
 public:
  // Host: operation under test.
  template <typename Arg>
  auto on_host(Arg const& value) const {
    return Kokkos::ceil(value);
  }

  // Host: reference result.
  template <typename Arg>
  auto on_host_serial(Arg const& value) const {
    return Kokkos::ceil(value);
  }

  // Device: operation under test.
  template <typename Arg>
  KOKKOS_INLINE_FUNCTION auto on_device(Arg const& value) const {
    return Kokkos::ceil(value);
  }

  // Device: reference result.
  template <typename Arg>
  KOKKOS_INLINE_FUNCTION auto on_device_serial(Arg const& value) const {
    return Kokkos::ceil(value);
  }
};

// Test functor for the round operation. The operation under test
// (on_host/on_device) is Kokkos::round, while the reference evaluations use
// different functions: std::rint on the host and
// Kokkos::Experimental::round_half_to_nearest_even on the device.
// NOTE(review): presumably the references differ from Kokkos::round because
// the simd overloads resolve ties to even - confirm against the simd backends.
class rounds {
 public:
  // Host: operation under test.
  template <typename Arg>
  auto on_host(Arg const& value) const {
    return Kokkos::round(value);
  }

  // Host: reference result, computed with the standard library's rint.
  template <typename Arg>
  auto on_host_serial(Arg const& value) const {
    return std::rint(value);
  }

  // Device: operation under test.
  template <typename Arg>
  KOKKOS_INLINE_FUNCTION auto on_device(Arg const& value) const {
    return Kokkos::round(value);
  }

  // Device: reference result, computed with the device-callable helper.
  template <typename Arg>
  KOKKOS_INLINE_FUNCTION auto on_device_serial(Arg const& value) const {
    return Kokkos::Experimental::round_half_to_nearest_even(value);
  }
};

// Test functor for the trunc operation. It exposes the four entry points used
// by the unary math-op checkers: the operation under test (on_host/on_device)
// and the reference evaluation (on_host_serial/on_device_serial). All four
// forward their argument to Kokkos::trunc.
class truncates {
 public:
  // Host: operation under test.
  template <typename Arg>
  auto on_host(Arg const& value) const {
    return Kokkos::trunc(value);
  }

  // Host: reference result.
  template <typename Arg>
  auto on_host_serial(Arg const& value) const {
    return Kokkos::trunc(value);
  }

  // Device: operation under test.
  template <typename Arg>
  KOKKOS_INLINE_FUNCTION auto on_device(Arg const& value) const {
    return Kokkos::trunc(value);
  }

  // Device: reference result.
  template <typename Arg>
  KOKKOS_INLINE_FUNCTION auto on_device_serial(Arg const& value) const {
    return Kokkos::trunc(value);
  }
};

class shift_right {
public:
template <typename T, typename U>
Expand Down
76 changes: 52 additions & 24 deletions simd/unit_tests/include/TestSIMD_MathOps.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,13 @@ void host_check_math_op_one_loader(UnaryOp unary_op, std::size_t n,
simd_type arg;
bool const loaded_arg = loader.host_load(args + i, nlanes, arg);
if (!loaded_arg) continue;
simd_type expected_result;
auto computed_result = unary_op.on_host(arg);

decltype(computed_result) expected_result;
for (std::size_t lane = 0; lane < simd_type::size(); ++lane) {
if (lane < nlanes)
expected_result[lane] = unary_op.on_host_serial(T(arg[lane]));
}
simd_type const computed_result = unary_op.on_host(arg);
host_check_equality(expected_result, computed_result, nlanes);
}
}
Expand All @@ -85,12 +86,17 @@ inline void host_check_all_math_ops(const DataType (&first_args)[n],
host_check_math_op_all_loaders<Abi>(plus(), n, first_args, second_args);
host_check_math_op_all_loaders<Abi>(minus(), n, first_args, second_args);
host_check_math_op_all_loaders<Abi>(multiplies(), n, first_args, second_args);
host_check_math_op_all_loaders<Abi>(absolutes(), n, first_args);

// TODO: Place fallback division implementations for all simd integer types
if constexpr (std::is_same_v<DataType, double>)
host_check_math_op_all_loaders<Abi>(divides(), n, first_args, second_args);
host_check_math_op_all_loaders<Abi>(floors(), n, first_args);
host_check_math_op_all_loaders<Abi>(ceils(), n, first_args);
host_check_math_op_all_loaders<Abi>(rounds(), n, first_args);
host_check_math_op_all_loaders<Abi>(truncates(), n, first_args);

host_check_math_op_all_loaders<Abi>(absolutes(), n, first_args);
// TODO: Place fallback implementations for all simd integer types
if constexpr (std::is_floating_point_v<DataType>) {
host_check_math_op_all_loaders<Abi>(divides(), n, first_args, second_args);
}
}

template <typename Abi, typename DataType>
Expand All @@ -100,20 +106,28 @@ inline void host_check_abi_size() {
static_assert(simd_type::size() == mask_type::size());
}

template <class Abi, typename DataType>
template <typename Abi, typename DataType>
inline void host_check_math_ops() {
constexpr size_t n = 11;

host_check_abi_size<Abi, DataType>();

if constexpr (std::is_signed_v<DataType>) {
DataType const first_args[n] = {1, 2, -1, 10, 0, 1, -2, 10, 0, 1, -2};
DataType const second_args[n] = {1, 2, 1, 1, 1, -3, -2, 1, 13, -3, -2};
if constexpr (!std::is_integral_v<DataType>) {
DataType const first_args[n] = {0.1, 0.4, 0.5, 0.7, 1.0, 1.5,
-2.0, 10.0, 0.0, 1.2, -2.8};
DataType const second_args[n] = {1.0, 0.2, 1.1, 1.8, -0.1, -3.0,
-2.4, 1.0, 13.0, -3.2, -2.1};
host_check_all_math_ops<Abi>(first_args, second_args);
} else {
DataType const first_args[n] = {1, 2, 1, 10, 0, 1, 2, 10, 0, 1, 2};
DataType const second_args[n] = {1, 2, 1, 1, 1, 3, 2, 1, 13, 3, 2};
host_check_all_math_ops<Abi>(first_args, second_args);
if constexpr (std::is_signed_v<DataType>) {
DataType const first_args[n] = {1, 2, -1, 10, 0, 1, -2, 10, 0, 1, -2};
DataType const second_args[n] = {1, 2, 1, 1, 1, -3, -2, 1, 13, -3, -2};
host_check_all_math_ops<Abi>(first_args, second_args);
} else {
DataType const first_args[n] = {1, 2, 1, 10, 0, 1, 2, 10, 0, 1, 2};
DataType const second_args[n] = {1, 2, 1, 1, 1, 3, 2, 1, 13, 3, 2};
host_check_all_math_ops<Abi>(first_args, second_args);
}
}
}

Expand Down Expand Up @@ -171,11 +185,12 @@ KOKKOS_INLINE_FUNCTION void device_check_math_op_one_loader(UnaryOp unary_op,
simd_type arg;
bool const loaded_arg = loader.device_load(args + i, nlanes, arg);
if (!loaded_arg) continue;
simd_type expected_result;
auto computed_result = unary_op.on_device(arg);

decltype(computed_result) expected_result;
for (std::size_t lane = 0; lane < nlanes; ++lane) {
expected_result[lane] = unary_op.on_device_serial(arg[lane]);
}
simd_type const computed_result = unary_op.on_device(arg);
device_check_equality(expected_result, computed_result, nlanes);
}
}
Expand All @@ -196,12 +211,17 @@ KOKKOS_INLINE_FUNCTION void device_check_all_math_ops(
device_check_math_op_all_loaders<Abi>(minus(), n, first_args, second_args);
device_check_math_op_all_loaders<Abi>(multiplies(), n, first_args,
second_args);
device_check_math_op_all_loaders<Abi>(absolutes(), n, first_args);

if constexpr (std::is_same_v<DataType, double>)
device_check_math_op_all_loaders<Abi>(floors(), n, first_args);
device_check_math_op_all_loaders<Abi>(ceils(), n, first_args);
device_check_math_op_all_loaders<Abi>(rounds(), n, first_args);
device_check_math_op_all_loaders<Abi>(truncates(), n, first_args);

if constexpr (std::is_floating_point_v<DataType>) {
device_check_math_op_all_loaders<Abi>(divides(), n, first_args,
second_args);

device_check_math_op_all_loaders<Abi>(absolutes(), n, first_args);
}
}

template <typename Abi, typename DataType>
Expand All @@ -217,14 +237,22 @@ KOKKOS_INLINE_FUNCTION void device_check_math_ops() {

device_check_abi_size<Abi, DataType>();

if constexpr (std::is_signed_v<DataType>) {
DataType const first_args[n] = {1, 2, -1, 10, 0, 1, -2, 10, 0, 1, -2};
DataType const second_args[n] = {1, 2, 1, 1, 1, -3, -2, 1, 13, -3, -2};
if constexpr (!std::is_integral_v<DataType>) {
DataType const first_args[n] = {0.1, 0.4, 0.5, 0.7, 1.0, 1.5,
-2.0, 10.0, 0.0, 1.2, -2.8};
DataType const second_args[n] = {1.0, 0.2, 1.1, 1.8, -0.1, -3.0,
-2.4, 1.0, 13.0, -3.2, -2.1};
device_check_all_math_ops<Abi>(first_args, second_args);
} else {
DataType const first_args[n] = {1, 2, 1, 10, 0, 1, 2, 10, 0, 1, 2};
DataType const second_args[n] = {1, 2, 1, 1, 1, 3, 2, 1, 13, 3, 2};
device_check_all_math_ops<Abi>(first_args, second_args);
if constexpr (std::is_signed_v<DataType>) {
DataType const first_args[n] = {1, 2, -1, 10, 0, 1, -2, 10, 0, 1, -2};
DataType const second_args[n] = {1, 2, 1, 1, 1, -3, -2, 1, 13, -3, -2};
device_check_all_math_ops<Abi>(first_args, second_args);
} else {
DataType const first_args[n] = {1, 2, 1, 10, 0, 1, 2, 10, 0, 1, 2};
DataType const second_args[n] = {1, 2, 1, 1, 1, 3, 2, 1, 13, 3, 2};
device_check_all_math_ops<Abi>(first_args, second_args);
}
}
}

Expand Down

0 comments on commit c586fa1

Please sign in to comment.