In [1]:
from mpi_helpers import (
    init_mpi,
    mpi_allgather,
    world_barrier,
)
from hip import (
    set_device,
    get_device,
    count_devices,
    malloc_fine_grained,
    get_ipc_handle,
    open_ipc_handle,
    hip_try,
    to_hip_c_byte_array_64
)
import numpy as np
from mpi4py import MPI
import ctypes

In [2]:
# init_mpi() yields three values, unpacked here as communicator, this process's
# rank, and the total number of ranks.
comm, rank, world_size = init_mpi()
num_gpus = count_devices()
# Map ranks onto devices round-robin so several ranks can share one node.
gpu_id = rank % num_gpus
set_device(gpu_id)
print(f"Rank {rank} using device {gpu_id}")
# presumably blocks until every rank has selected its device — confirm in mpi_helpers
world_barrier()

Rank 0 using device 0


In [3]:
# Requested heap size: 2**30 bytes.
bytes_count = 1 << 30  # 1 GiB
heap_base = malloc_fine_grained(bytes_count)
# A ctypes pointer whose `.value` is None is NULL — assumes heap_base is a
# ctypes.c_void_p-like object; TODO confirm against hip.malloc_fine_grained.
if heap_base.value is None:
    raise RuntimeError(f"Rank {rank}: Memory allocation failed.")

In [4]:
# One uint64 slot per rank for heap base addresses. Only this rank's slot is
# written in this cell; every other entry stays 0 here.
heap_bases = np.zeros(world_size, dtype=np.uint64)
heap_bases[rank] = heap_base.value

In [5]:
# One zero-initialised 64-byte row per rank, to hold each rank's IPC handle bytes.
ipc_handles = np.zeros((world_size, 64), dtype=np.uint8)


In [6]:
def to_hip_c_byte_array_64(ipc_handle_np):
    """Copy a 64-byte uint8 NumPy array into a new ``ctypes.c_byte * 64`` array.

    NOTE(review): a helper with this same name is also imported from ``hip`` in
    the first cell; this definition shadows that import for all later cells.

    Args:
        ipc_handle_np: NumPy array with dtype ``np.uint8`` and exactly 64
            elements. Any memory layout is accepted (including strided views);
            the elements are read in C order.

    Returns:
        A ``ctypes.c_byte * 64`` instance holding a private copy of the bytes.

    Raises:
        ValueError: If the dtype is not ``np.uint8`` or the size is not 64.
    """
    if ipc_handle_np.dtype != np.uint8 or ipc_handle_np.size != 64:
        raise ValueError("Input must be a 64-element NumPy array of type np.uint8")
    # tobytes() always yields a contiguous C-order copy, so non-contiguous
    # views no longer fail inside from_buffer_copy ("ndarray is not C-contiguous").
    return (ctypes.c_byte * 64).from_buffer_copy(ipc_handle_np.tobytes())

In [7]:
# presumably exports an IPC handle for the heap allocation — confirm in hip.get_ipc_handle
ipc_handle = get_ipc_handle(heap_base)
# np.frombuffer creates a uint8 view over the handle's buffer (no copy).
ipc_handle_np = np.frombuffer(ipc_handle, dtype=np.uint8)
# Copy this rank's handle bytes into its own row of the per-rank table.
ipc_handles[rank] = ipc_handle_np

def allgather(data):
    """Gather a 1-D byte buffer from every rank of the shared-memory communicator.

    The communicator is obtained by splitting ``MPI.COMM_WORLD`` with
    ``MPI.COMM_TYPE_SHARED``, so the result has one row per rank of that
    split communicator (which may be fewer than the world size).

    Args:
        data: 1-D array-like convertible to ``np.uint8``. Every participating
            rank must pass the same length.

    Returns:
        ``np.ndarray`` of shape ``(shm_size, len(data))`` and dtype ``np.uint8``;
        row ``i`` holds the bytes contributed by rank ``i`` of the split
        communicator.

    Raises:
        AssertionError: If ``data`` is not one-dimensional.
    """
    data = np.asarray(data, dtype=np.uint8)
    assert len(data.shape) == 1, "Input data must be a 1D array."
    # MPI needs a contiguous send buffer; this is a no-op for contiguous input.
    data = np.ascontiguousarray(data)

    shmcomm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED)
    try:
        recv_data = np.empty((shmcomm.Get_size(), len(data)), dtype=data.dtype)
        shmcomm.Allgather(sendbuf=data, recvbuf=recv_data)
    finally:
        # Always release the split communicator, even if the collective raised.
        shmcomm.Free()

    return recv_data

# Collect the handle bytes of every rank in the communicator used inside
# allgather(); row i holds the bytes contributed by that communicator's rank i.
all_ipc_handles = allgather(ipc_handle_np)

In [8]:
# Rebuild a ctypes handle object from row 0 of the gathered handle bytes.
handle_back = to_hip_c_byte_array_64(all_ipc_handles[0])

In [9]:
def verify_ipc_handle(ipc_handle_original, ipc_handle_reconstructed):
    """Report whether two handle objects contain identical bytes.

    Both arguments are converted with ``bytes()`` and compared. The verdict is
    printed; on a mismatch both byte sequences are printed as integer lists.

    Returns:
        True when the byte contents are equal, False otherwise.
    """
    expected = bytes(ipc_handle_original)
    actual = bytes(ipc_handle_reconstructed)

    # Handle the failure path first so the success path reads straight through.
    if expected != actual:
        print("Handles do NOT match!")
        print(f"Original: {list(expected)}")
        print(f"Reconstructed: {list(actual)}")
        return False

    print("Handles match!")
    return True

# Example usage: compare this process's own handle with the one rebuilt from
# row 0 of the gathered handles (a byte-for-byte round-trip check).
is_valid = verify_ipc_handle(ipc_handle, handle_back)

Handles match!


In [10]:
# Print both handle objects; the recorded output shows their default object repr.
print(f"ipc_handle: {ipc_handle}")
print(f"handle_back: {handle_back}")
# Keep a second name for the locally exported handle so later cells can
# rebind `ipc_handle` without losing it.
ipc_handle_orig = ipc_handle

ipc_handle: <hip.c_byte_Array_64 object at 0x7f2dab9371c0>
handle_back: <hip.c_byte_Array_64 object at 0x7f2dabde6240>


In [13]:
hip_runtime = ctypes.cdll.LoadLibrary("libamdhip64.so")


class hipIpcMemHandle_t(ctypes.Structure):
    """ctypes mirror of HIP's ``hipIpcMemHandle_t``: 64 opaque bytes."""

    _fields_ = [("reserved", ctypes.c_char * 64)]


# C prototype:
#   hipError_t hipIpcOpenMemHandle(void** devPtr, hipIpcMemHandle_t handle,
#                                  unsigned int flags);
# The handle is taken BY VALUE. A ctypes array argument is passed as a pointer,
# which is what previously produced "HIP error code 1: invalid argument".
# Declaring argtypes with a Structure makes ctypes pass the 64 bytes by value.
hip_runtime.hipIpcOpenMemHandle.argtypes = [
    ctypes.POINTER(ctypes.c_void_p),
    hipIpcMemHandle_t,
    ctypes.c_uint,
]
hip_runtime.hipIpcOpenMemHandle.restype = ctypes.c_int

# Receives the device pointer mapped from the IPC handle.
ptr = ctypes.c_void_p()

print(f"ipc_handle: {type(ipc_handle_orig)}")
ipc_handle = ipc_handle_orig
hipIpcMemLazyEnablePeerAccess = 1
print(f"ipc_handle: {type(ipc_handle)}")
print(f" ctypes.byref(ptr): {type(ctypes.byref(ptr))}")

# Copy the raw handle bytes into a by-value struct for the call.
ipc_handle_struct = hipIpcMemHandle_t.from_buffer_copy(bytes(ipc_handle))

# NOTE(review): `ipc_handle_orig` was exported by this same process. The HIP
# runtime may still refuse to open a handle in the process that created it; in
# a multi-rank run the handle to open is a peer's row of the gathered handles.
# Confirm against the HIP IPC documentation.
hip_try(
    hip_runtime.hipIpcOpenMemHandle(
        ctypes.byref(ptr),
        ipc_handle_struct,
        hipIpcMemLazyEnablePeerAccess,
    )
)

ipc_handle: <class 'hip.c_byte_Array_64'>
ipc_handle: <class 'hip.c_byte_Array_64'>
 ctypes.byref(ptr): <class 'CArgObject'>


RuntimeError: HIP error code 1: invalid argument