Skip to content

Commit

Permalink
kokkos#5635: Add parallel_scan overloads with value for Threads
Browse files Browse the repository at this point in the history
  • Loading branch information
thearusable authored and cz4rs committed Sep 27, 2023
1 parent 1675997 commit 190bfe4
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions core/src/Threads/Kokkos_ThreadsTeam.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -976,16 +976,19 @@ parallel_reduce(const Impl::ThreadVectorRangeBoundariesStruct<
* lambda(iType i, ValueType & val, bool final) for each i=0..N-1.
*
*/
template <typename iType, class FunctorType>
template <typename iType, class FunctorType, typename ValueType>
KOKKOS_INLINE_FUNCTION void parallel_scan(
const Impl::TeamThreadRangeBoundariesStruct<
iType, Impl::ThreadsExecTeamMember>& loop_bounds,
const FunctorType& lambda) {
using value_type = typename Kokkos::Impl::FunctorAnalysis<
const FunctorType& lambda, ValueType& return_val) {
// Extract ValueType from the Closure
using closure_value_type = typename Kokkos::Impl::FunctorAnalysis<
Kokkos::Impl::FunctorPatternInterface::SCAN, void, FunctorType,
void>::value_type;
static_assert(std::is_same<closure_value_type, ValueType>::value,
"Non-matching value types of closure and return type");

auto scan_val = value_type{};
auto scan_val = ValueType{};

// Intra-member scan
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
Expand All @@ -1006,6 +1009,21 @@ KOKKOS_INLINE_FUNCTION void parallel_scan(
i += loop_bounds.increment) {
lambda(i, scan_val, true);
}

return_val = scan_val;
}

/** \brief parallel_scan over TeamThreadRangeBoundariesStruct bounds for
 * callers that do not need the final scan value.
 *
 * The scan value type is taken from the functor through FunctorAnalysis
 * (SCAN pattern). The work is forwarded to the parallel_scan overload that
 * takes an output reference; whatever that overload writes to the reference
 * is dropped when this function returns.
 *
 * \param loop_bounds  Iteration bounds, passed through unchanged.
 * \param lambda       Scan functor, passed through unchanged.
 */
template <typename iType, class FunctorType>
KOKKOS_INLINE_FUNCTION void parallel_scan(
    const Impl::TeamThreadRangeBoundariesStruct<
        iType, Impl::ThreadsExecTeamMember>& loop_bounds,
    const FunctorType& lambda) {
  using functor_analysis = Kokkos::Impl::FunctorAnalysis<
      Kokkos::Impl::FunctorPatternInterface::SCAN, void, FunctorType, void>;
  using functor_value_type = typename functor_analysis::value_type;

  // Scratch destination for the forwarding call; it is never read here, so
  // it is left default-initialized exactly as before.
  functor_value_type discarded_total;
  parallel_scan(loop_bounds, lambda, discarded_total);
}

/** \brief Intra-thread vector parallel exclusive prefix sum. Executes
Expand Down

0 comments on commit 190bfe4

Please sign in to comment.