Skip to content

Commit

Permalink
PR #9640: [ROCm] Replace usages of hipGraphAddMemcpyNode with hipGrap…
Browse files Browse the repository at this point in the history
…hAddMemcpy…

Imported from GitHub PR openxla/xla#9640

…Node1D

It seems that setting up the params structure is inconsistent across rocm versions so fall back to more stable 1D variant. Minor cleanup of "unwrapped" hip runtime calls. Should fix openxla/xla#8692
Copybara import of the project:

--
7de30f1c2fa05941dd2fb59d70e43466145a1ad6 by Dragan Mladjenovic <Dragan.Mladjenovic@amd.com>:

[ROCm] Replace usages of hipGraphAddMemcpyNode with hipGraphAddMemcpyNode1D

It seems that setting up the params structure is inconsistent across rocm
versions so fall back to more stable 1D variant. Minor cleanup of "unwrapped"
hip runtime calls.

Merging this change closes #9640

PiperOrigin-RevId: 609124252
  • Loading branch information
draganmladjenovic authored and tensorflower-gardener committed Feb 21, 2024
1 parent f691003 commit 0904e4e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 29 deletions.
34 changes: 8 additions & 26 deletions third_party/xla/xla/stream_executor/rocm/rocm_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ ScopedActivateContext::ScopedActivateContext(GpuContext* hip_context) {
if (tls->depth == 0) {
VLOG(3) << "ScopedActivateContext switching to "
<< hip_context->device_ordinal();
FAIL_IF_ROCM_ERROR(hipCtxSetCurrent(hip_context->context()),
FAIL_IF_ROCM_ERROR(wrap::hipCtxSetCurrent(hip_context->context()),
"Failed setting context");
tls->depth = 1;
tls->current_device_ordinal = hip_context->device_ordinal();
Expand All @@ -205,7 +205,7 @@ ScopedActivateContext::ScopedActivateContext(GpuContext* hip_context) {

to_restore_ = tls->context;
// Set the device and update thread local.
FAIL_IF_ROCM_ERROR(hipCtxSetCurrent(hip_context->context()),
FAIL_IF_ROCM_ERROR(wrap::hipCtxSetCurrent(hip_context->context()),
"Failed setting context");
tls->current_device_ordinal = hip_context->device_ordinal();
tls->context = hip_context;
Expand All @@ -229,7 +229,7 @@ ScopedActivateContext::~ScopedActivateContext() {
}

// Set context and update thread local.
FAIL_IF_ROCM_ERROR(hipCtxSetCurrent(to_restore_->context()),
FAIL_IF_ROCM_ERROR(wrap::hipCtxSetCurrent(to_restore_->context()),
"Failed setting context");
tls->current_device_ordinal = to_restore_->device_ordinal();
tls->context = to_restore_;
Expand Down Expand Up @@ -959,18 +959,9 @@ GpuDriver::GraphGetMemAllocNodeParams(GpuGraphNodeHandle node) {
<< "; src: " << reinterpret_cast<void*>(gpu_src) << "; size: " << size
<< "; context: " << context->context() << "; deps: " << deps.size();

hipMemcpy3DParms params{
.srcArray = {},
.srcPos = {},
.srcPtr = {.ptr = gpu_src, .pitch = size, .xsize = size, .ysize = 1},
.dstArray = {},
.dstPos = {},
.dstPtr = {.ptr = gpu_dst, .pitch = size, .xsize = size, .ysize = 1},
.extent = hipExtent{.width = size, .height = 1, .depth = 1},
.kind = hipMemcpyDeviceToDevice};

RETURN_IF_ROCM_ERROR(wrap::hipGraphAddMemcpyNode(node, graph, deps.data(),
deps.size(), &params),
RETURN_IF_ROCM_ERROR(wrap::hipGraphAddMemcpyNode1D(
node, graph, deps.data(), deps.size(), gpu_dst,
gpu_src, size, hipMemcpyDeviceToDevice),
"Failed to add memcpy d2d node to a HIP graph");

return absl::OkStatus();
Expand All @@ -984,18 +975,9 @@ GpuDriver::GraphGetMemAllocNodeParams(GpuGraphNodeHandle node) {
<< "; src: " << reinterpret_cast<void*>(gpu_src) << "; size: " << size
<< "; context: " << context->context();

hipMemcpy3DParms params{
.srcArray = {},
.srcPos = {},
.srcPtr = {.ptr = gpu_src, .pitch = size, .xsize = size, .ysize = 1},
.dstArray = {},
.dstPos = {},
.dstPtr = {.ptr = gpu_dst, .pitch = size, .xsize = size, .ysize = 1},
.extent = hipExtent{.width = size, .height = 1, .depth = 1},
.kind = hipMemcpyDeviceToDevice};

RETURN_IF_ROCM_ERROR(
wrap::hipGraphExecMemcpyNodeSetParams(exec, node, &params),
wrap::hipGraphExecMemcpyNodeSetParams1D(exec, node, gpu_dst, gpu_src,
size, hipMemcpyDeviceToDevice),
"Failed to set memcpy d2d node params");

return absl::OkStatus();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ namespace wrap {
static FuncPtrT loaded = []() -> FuncPtrT { \
static const char *kName = TO_STR(hipSymbolName); \
void *f; \
auto s = tsl::Env::Default() -> GetSymbolFromLibrary( \
auto s = tsl::Env::Default()->GetSymbolFromLibrary( \
stream_executor::internal::CachedDsoLoader::GetHipDsoHandle() \
.value(), \
kName, &f); \
Expand Down Expand Up @@ -106,7 +106,6 @@ namespace wrap {
__macro(hipGraphAddKernelNode) \
__macro(hipGraphAddChildGraphNode) \
__macro(hipGraphAddMemAllocNode) \
__macro(hipGraphAddMemcpyNode) \
__macro(hipGraphAddMemcpyNode1D) \
__macro(hipGraphAddMemsetNode) \
__macro(hipGraphAddMemFreeNode) \
Expand All @@ -116,7 +115,7 @@ namespace wrap {
__macro(hipGraphExecChildGraphNodeSetParams) \
__macro(hipGraphExecDestroy) \
__macro(hipGraphExecKernelNodeSetParams) \
__macro(hipGraphExecMemcpyNodeSetParams) \
__macro(hipGraphExecMemcpyNodeSetParams1D) \
__macro(hipGraphExecMemsetNodeSetParams) \
__macro(hipGraphExecUpdate) \
__macro(hipGraphInstantiate) \
Expand Down

0 comments on commit 0904e4e

Please sign in to comment.