Skip to content

Commit

Permalink
Address reviewer' comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Rombur committed Sep 8, 2023
1 parent 9fde326 commit 035d284
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 11 deletions.
1 change: 1 addition & 0 deletions core/src/HIP/Kokkos_HIP_GraphNodeKernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <Kokkos_PointerOwnership.hpp>

#include <HIP/Kokkos_HIP_SharedAllocationRecord.hpp>
#include <HIP/Kokkos_HIP_GraphNode_Impl.hpp>

namespace Kokkos {
namespace Impl {
Expand Down
2 changes: 0 additions & 2 deletions core/src/HIP/Kokkos_HIP_GraphNode_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,4 @@ struct GraphNodeBackendDetailsBeforeTypeErasure<Kokkos::HIP, Kernel,
} // namespace Impl
} // namespace Kokkos

#include <HIP/Kokkos_HIP_GraphNodeKernel.hpp>

#endif
4 changes: 2 additions & 2 deletions core/src/HIP/Kokkos_HIP_Graph_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
#include <Kokkos_Graph_fwd.hpp>

#include <impl/Kokkos_GraphImpl.hpp>

#include <impl/Kokkos_GraphNodeImpl.hpp>
#include <HIP/Kokkos_HIP_GraphNode_Impl.hpp>

#include <HIP/Kokkos_HIP_GraphNodeKernel.hpp>

namespace Kokkos {
namespace Impl {
Expand Down
20 changes: 13 additions & 7 deletions core/src/HIP/Kokkos_HIP_KernelLaunch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
#include <HIP/Kokkos_HIP_Space.hpp>

#if !((HIP_VERSION_MAJOR == 5) && (HIP_VERSION_MINOR == 2))
#define KOKKOS_IMPL_HIP_GRAPH_ENABLED
#endif

#ifdef KOKKOS_IMPL_HIP_GRAPH_ENABLED
#include <HIP/Kokkos_HIP_GraphNodeKernel.hpp>
#include <impl/Kokkos_GraphImpl_fwd.hpp>
#endif
Expand Down Expand Up @@ -380,13 +384,13 @@ struct HIPParallelLaunchKernelInvoker<DriverType, LaunchBounds,
driver);
}

#if !((HIP_VERSION_MAJOR == 5) && (HIP_VERSION_MINOR == 2))
#ifdef KOKKOS_IMPL_HIP_GRAPH_ENABLED
static void create_parallel_launch_graph_node(
DriverType const &driver, dim3 const &grid, dim3 const &block, int shmem,
HIPInternal const * /*hip_instance*/) {
auto const &graph = Impl::get_hip_graph_from_kernel(driver);
auto const &graph = get_hip_graph_from_kernel(driver);
KOKKOS_EXPECTS(graph);
auto &graph_node = Impl::get_hip_graph_node_from_kernel(driver);
auto &graph_node = get_hip_graph_node_from_kernel(driver);
// Expect node not yet initialized
KOKKOS_EXPECTS(!graph_node);

Expand Down Expand Up @@ -438,13 +442,13 @@ struct HIPParallelLaunchKernelInvoker<DriverType, LaunchBounds,
driver_ptr);
}

#if !((HIP_VERSION_MAJOR == 5) && (HIP_VERSION_MINOR == 2))
#ifdef KOKKOS_IMPL_HIP_GRAPH_ENABLED
static void create_parallel_launch_graph_node(
DriverType const &driver, dim3 const &grid, dim3 const &block, int shmem,
HIPInternal const *hip_instance) {
auto const &graph = Impl::get_hip_graph_from_kernel(driver);
auto const &graph = get_hip_graph_from_kernel(driver);
KOKKOS_EXPECTS(graph);
auto &graph_node = Impl::get_hip_graph_node_from_kernel(driver);
auto &graph_node = get_hip_graph_node_from_kernel(driver);
// Expect node not yet initialized
KOKKOS_EXPECTS(!graph_node);

Expand Down Expand Up @@ -581,7 +585,7 @@ void hip_parallel_launch(const DriverType &driver, const dim3 &grid,
const dim3 &block, const int shmem,
const HIPInternal *hip_instance,
const bool prefer_shmem) {
#if !((HIP_VERSION_MAJOR == 5) && (HIP_VERSION_MINOR == 2))
#ifdef KOKKOS_IMPL_HIP_GRAPH_ENABLED
if constexpr (DoGraph) {
// Graph launch
using base_t = HIPParallelLaunchKernelInvoker<DriverType, LaunchBounds,
Expand Down Expand Up @@ -624,6 +628,8 @@ void hip_parallel_launch(const DriverType &driver, const dim3 &grid,
} // namespace Impl
} // namespace Kokkos

#undef KOKKOS_IMPL_HIP_GRAPH_ENABLED

#endif

#endif
1 change: 1 addition & 0 deletions core/src/Kokkos_Graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ Graph<ExecutionSpace> create_graph(Closure&& arg_closure) {
#include <impl/Kokkos_Default_Graph_Impl.hpp>
#include <Cuda/Kokkos_Cuda_Graph_Impl.hpp>
#if defined(KOKKOS_ENABLE_HIP)
// The implementation of hipGraph in ROCm 5.2 is bugged, so we cannot use it.
#if !((HIP_VERSION_MAJOR == 5) && (HIP_VERSION_MINOR == 2))
#include <HIP/Kokkos_HIP_Graph_Impl.hpp>
#endif
Expand Down

0 comments on commit 035d284

Please sign in to comment.