Skip to content

Commit

Permalink
Explicitly use the current device resource in DeviceBuffer
Browse files Browse the repository at this point in the history
Previously we were relying on the C++ and Python-level device
resources to agree. But this need not be the case.

To avoid this, first get the current deivce resource and then use it
when allocating the wrapped C++ device_buffer when creating
DeviceBuffers.

- Closes #1506
  • Loading branch information
wence- committed Apr 3, 2024
1 parent bd3f0d8 commit f0486cd
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 18 deletions.
26 changes: 20 additions & 6 deletions python/rmm/_lib/device_buffer.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,31 @@ from libcpp.memory cimport unique_ptr

from rmm._cuda.stream cimport Stream
from rmm._lib.cuda_stream_view cimport cuda_stream_view
from rmm._lib.memory_resource cimport DeviceMemoryResource
from rmm._lib.memory_resource cimport (
DeviceMemoryResource,
device_memory_resource,
)


cdef extern from "rmm/device_buffer.hpp" namespace "rmm" nogil:
cdef cppclass device_buffer:
device_buffer()
device_buffer(size_t size, cuda_stream_view stream) except +
device_buffer(const void* source_data,
size_t size, cuda_stream_view stream) except +
device_buffer(const device_buffer buf,
cuda_stream_view stream) except +
device_buffer(
size_t size,
cuda_stream_view stream,
device_memory_resource *
) except +
device_buffer(
const void* source_data,
size_t size,
cuda_stream_view stream,
device_memory_resource *
) except +
device_buffer(
const device_buffer buf,
cuda_stream_view stream,
device_memory_resource *
) except +
void reserve(size_t new_capacity, cuda_stream_view stream) except +
void resize(size_t new_size, cuda_stream_view stream) except +
void shrink_to_fit(cuda_stream_view stream) except +
Expand Down
22 changes: 12 additions & 10 deletions python/rmm/_lib/device_buffer.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ from cuda.ccudart cimport (
cudaStream_t,
)

from rmm._lib.memory_resource cimport get_current_device_resource
from rmm._lib.memory_resource cimport (
device_memory_resource,
get_current_device_resource,
)


# The DeviceMemoryResource attribute could be released prematurely
Expand Down Expand Up @@ -75,24 +78,23 @@ cdef class DeviceBuffer:
>>> db = rmm.DeviceBuffer(size=5)
"""
cdef const void* c_ptr
cdef device_memory_resource * mr_ptr
# Save a reference to the MR and stream used for allocation
self.mr = get_current_device_resource()
self.stream = stream

mr_ptr = self.mr.get_mr()
with nogil:
c_ptr = <const void*>ptr

if size == 0:
self.c_obj.reset(new device_buffer())
elif c_ptr == NULL:
self.c_obj.reset(new device_buffer(size, stream.view()))
if c_ptr == NULL or size == 0:
self.c_obj.reset(new device_buffer(size, stream.view(), mr_ptr))
else:
self.c_obj.reset(new device_buffer(c_ptr, size, stream.view()))
self.c_obj.reset(new device_buffer(c_ptr, size, stream.view(), mr_ptr))

if stream.c_is_default():
stream.c_synchronize()

# Save a reference to the MR and stream used for allocation
self.mr = get_current_device_resource()
self.stream = stream

def __len__(self):
return self.size

Expand Down
2 changes: 1 addition & 1 deletion python/rmm/_lib/memory_resource.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ cdef extern from "rmm/mr/device/device_memory_resource.hpp" \

cdef class DeviceMemoryResource:
cdef shared_ptr[device_memory_resource] c_obj
cdef device_memory_resource* get_mr(self)
cdef device_memory_resource* get_mr(self) noexcept nogil

cdef class UpstreamResourceAdaptor(DeviceMemoryResource):
cdef readonly DeviceMemoryResource upstream_mr
Expand Down
2 changes: 1 addition & 1 deletion python/rmm/_lib/memory_resource.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ cdef extern from "rmm/mr/device/failure_callback_resource_adaptor.hpp" \

cdef class DeviceMemoryResource:

cdef device_memory_resource* get_mr(self):
cdef device_memory_resource* get_mr(self) noexcept nogil:
"""Get the underlying C++ memory resource object."""
return self.c_obj.get()

Expand Down

0 comments on commit f0486cd

Please sign in to comment.