Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Add a public copy API to DeviceBuffer #1128

Merged
merged 10 commits into from
Oct 18, 2022
2 changes: 2 additions & 0 deletions python/rmm/_lib/device_buffer.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ cdef extern from "rmm/device_buffer.hpp" namespace "rmm" nogil:
device_buffer(size_t size, cuda_stream_view stream) except +
device_buffer(const void* source_data,
size_t size, cuda_stream_view stream) except +
device_buffer(const device_buffer buf,
cuda_stream_view stream) except +
void reserve(size_t new_capacity, cuda_stream_view stream) except +
void resize(size_t new_size, cuda_stream_view stream) except +
void shrink_to_fit(cuda_stream_view stream) except +
Expand Down
33 changes: 29 additions & 4 deletions python/rmm/_lib/device_buffer.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ cimport cython
from cpython.bytes cimport PyBytes_AS_STRING, PyBytes_FromStringAndSize
from cython.operator cimport dereference
from libc.stdint cimport uintptr_t
from libcpp.memory cimport unique_ptr
from libcpp.memory cimport make_unique, unique_ptr
from libcpp.utility cimport move

from rmm._cuda.stream cimport Stream
Expand Down Expand Up @@ -133,6 +133,32 @@ cdef class DeviceBuffer:
}
return intf

def copy(self):
"""Returns a copy of DeviceBuffer.

Returns
-------
A deep copy of existing ``DeviceBuffer``

Examples
--------
>>> import rmm
>>> db = rmm.DeviceBuffer.to_device(b"abc")
>>> db_copy = db.copy()
>>> db.copy_to_host()
array([97, 98, 99], dtype=uint8)
>>> db_copy.copy_to_host()
array([97, 98, 99], dtype=uint8)
galipremsagar marked this conversation as resolved.
Show resolved Hide resolved
>>> assert db is not db_copy
>>> assert db.ptr != db_copy.ptr
"""
ret = DeviceBuffer(ptr=self.ptr, size=self.size, stream=self.stream)
ret.mr = self.mr
return ret

def __copy__(self):
return self.copy()

@staticmethod
cdef DeviceBuffer c_from_unique_ptr(unique_ptr[device_buffer] ptr):
cdef DeviceBuffer buf = DeviceBuffer.__new__(DeviceBuffer)
Expand Down Expand Up @@ -475,13 +501,12 @@ cpdef void copy_device_to_ptr(uintptr_t d_src,
Examples
--------
>>> import rmm
>>> import numpy as np
>>> db = rmm.DeviceBuffer(size=5)
>>> db2 = rmm.DeviceBuffer.to_device(b"abc")
>>> rmm._lib.device_buffer.copy_device_to_ptr(db2.ptr, db.ptr, db2.size)
>>> hb = db.copy_to_host()
>>> print(hb)
array([10, 11, 12, 0, 0], dtype=uint8)
>>> hb
array([97, 98, 99, 0, 0], dtype=uint8)
"""

with nogil:
Expand Down
27 changes: 27 additions & 0 deletions python/rmm/tests/test_rmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import gc
import os
import sys
Expand Down Expand Up @@ -887,3 +888,29 @@ def func_with_arg(x):
rmm.unregister_reinitialize_hook(func_with_arg)
rmm.reinitialize()
assert L == [2]


@pytest.mark.parametrize(
"cuda_ary",
[
lambda: rmm.DeviceBuffer.to_device(b"abc"),
lambda: cuda.to_device(np.array([97, 98, 99, 0, 0], dtype="u1")),
],
)
@pytest.mark.parametrize(
"make_copy", [lambda db: db.copy(), lambda db: copy.copy(db)]
)
def test_rmm_device_buffer_copy(cuda_ary, make_copy):
cuda_ary = cuda_ary()
db = rmm.DeviceBuffer.to_device(np.zeros(5, dtype="u1"))
db.copy_from_device(cuda_ary)
db_copy = make_copy(db)

assert db is not db_copy
assert db.ptr != db_copy.ptr
assert len(db) == len(db_copy)

expected = np.array([97, 98, 99, 0, 0], dtype="u1")
result = db_copy.copy_to_host()

np.testing.assert_equal(expected, result)