Skip to content

Commit

Permalink
kokkos#5635: HIP: Add Overloads for parallel_scan with return value f…
Browse files Browse the repository at this point in the history
…or TeamThreadRange (kokkos#6302)

* kokkos#5635: Move some tests for parallel_scan to TestTeamScan

* kokkos#5635: HIP: Add parallel_scan with return value for TeamThreadRange

* use shortcut

* Use {} initialization

---------

Co-authored-by: Francesco Rizzi <fnrizzi@sandia.gov>
Co-authored-by: Cezary Skrzyński <cezary.skrzynski@ng-analytics.com>
  • Loading branch information
3 people committed Sep 26, 2023
1 parent 9db1ea4 commit 1675997
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 10 deletions.
39 changes: 32 additions & 7 deletions core/src/HIP/Kokkos_HIP_Team.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -537,28 +537,30 @@ parallel_reduce(const Impl::TeamThreadRangeBoundariesStruct<
* final == true.
*/
// This is the same code as in CUDA and largely the same as in OpenMPTarget
template <typename iType, typename FunctorType>
template <typename iType, typename FunctorType, typename ValueType>
KOKKOS_INLINE_FUNCTION void parallel_scan(
const Impl::TeamThreadRangeBoundariesStruct<iType, Impl::HIPTeamMember>&
loop_bounds,
const FunctorType& lambda) {
// Extract value_type from lambda
using value_type = typename Kokkos::Impl::FunctorAnalysis<
const FunctorType& lambda, ValueType& return_val) {
// Extract ValueType from the Functor
using functor_value_type = typename Kokkos::Impl::FunctorAnalysis<
Kokkos::Impl::FunctorPatternInterface::SCAN, void, FunctorType,
void>::value_type;
ValueType>::value_type;
static_assert(std::is_same_v<functor_value_type, ValueType>,
"Non-matching value types of functor and return type");

const auto start = loop_bounds.start;
const auto end = loop_bounds.end;
auto& member = loop_bounds.member;
const auto team_size = member.team_size();
const auto team_rank = member.team_rank();
const auto nchunk = (end - start + team_size - 1) / team_size;
value_type accum = 0;
ValueType accum = {};
// each team has to process one or more chunks of the prefix scan
for (iType i = 0; i < nchunk; ++i) {
auto ii = start + i * team_size + team_rank;
// local accumulation for this chunk
value_type local_accum = 0;
ValueType local_accum = 0;
// user updates value with prefix value
if (ii < loop_bounds.end) lambda(ii, local_accum, false);
// perform team scan
Expand All @@ -572,6 +574,29 @@ KOKKOS_INLINE_FUNCTION void parallel_scan(
// broadcast last value to rest of the team
member.team_broadcast(accum, team_size - 1);
}
return_val = accum;
}

/** \brief Inter-thread parallel exclusive prefix sum.
*
* Executes closure(iType i, ValueType & val, bool final) for each i=[0..N)
*
* The range [0..N) is mapped to each rank in the team (whose global rank is
* less than N) and a scan operation is performed. The last call to closure has
* final == true.
*/
template <typename iType, typename FunctorType>
KOKKOS_INLINE_FUNCTION void parallel_scan(
const Impl::TeamThreadRangeBoundariesStruct<iType, Impl::HIPTeamMember>&
loop_bounds,
const FunctorType& lambda) {
// Extract value_type from lambda
using value_type = typename Kokkos::Impl::FunctorAnalysis<
Kokkos::Impl::FunctorPatternInterface::SCAN, void, FunctorType,
void>::value_type;

value_type scan_val;
parallel_scan(loop_bounds, lambda, scan_val);
}

template <typename iType, class Closure>
Expand Down
6 changes: 3 additions & 3 deletions core/unit_test/TestTeamScan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,9 @@ TEST(TEST_CATEGORY, team_scan) {

// Temporary: This condition will progressively be reduced when parallel_scan
// with return value will be implemented for more backends.
#if !defined(KOKKOS_ENABLE_HIP) && !defined(KOKKOS_ENABLE_OPENACC) && \
!defined(KOKKOS_ENABLE_SYCL) && !defined(KOKKOS_ENABLE_THREADS) && \
!defined(KOKKOS_ENABLE_OPENMPTARGET) && !defined(KOKKOS_ENABLE_HPX)
#if !defined(KOKKOS_ENABLE_OPENACC) && !defined(KOKKOS_ENABLE_SYCL) && \
!defined(KOKKOS_ENABLE_THREADS) && !defined(KOKKOS_ENABLE_OPENMPTARGET) && \
!defined(KOKKOS_ENABLE_HPX)
template <class ExecutionSpace, class DataType>
struct TestTeamScanRetVal {
using execution_space = ExecutionSpace;
Expand Down

0 comments on commit 1675997

Please sign in to comment.