Skip to content

Commit

Permalink
PR #12864: [ROCm] fix DeviceAllocate code
Browse files Browse the repository at this point in the history
Imported from GitHub PR openxla/xla#12864

Copybara import of the project:

--
2e6dfa841a495c1f0babf0ff6f877bc1866bf319 by Ruturaj4 <ruturaj.vaidya@amd.com>:

[ROCm] fix DeviceAllocate code

Merging this change closes #12864

PiperOrigin-RevId: 636848518
  • Loading branch information
Ruturaj4 authored and tensorflower-gardener committed May 24, 2024
1 parent e20343f commit 5a75dba
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions third_party/xla/xla/stream_executor/rocm/rocm_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1318,13 +1318,19 @@ struct BitPatternToValue {

/* static */ void* GpuDriver::DeviceAllocate(GpuContext* context,
uint64_t bytes) {
if (bytes == 0) {
return nullptr;
}

ScopedActivateContext activated{context};
hipDeviceptr_t result = 0;
hipError_t res = wrap::hipMalloc(&result, bytes);
if (res != hipSuccess) {
LOG(ERROR) << "failed to allocate "
<< tsl::strings::HumanReadableNumBytes(bytes) << " (" << bytes
<< " bytes) from device: " << ToString(res);
// LOG(INFO) because this isn't always important to users (e.g. BFCAllocator
// implements a retry if the first allocation fails).
LOG(INFO) << "failed to allocate "
<< tsl::strings::HumanReadableNumBytes(bytes) << " (" << bytes
<< " bytes) from device: " << ToString(res);
return nullptr;
}
void* ptr = reinterpret_cast<void*>(result);
Expand Down

0 comments on commit 5a75dba

Please sign in to comment.