diff --git a/README.md b/README.md index 2f6d54a64..a05ffa5e5 100644 --- a/README.md +++ b/README.md @@ -691,6 +691,8 @@ resources MemoryResources are highly configurable and can be composed together in different ways. See `help(rmm.mr)` for more information. +## Using RMM with third-party libraries + ### Using RMM with CuPy You can configure [CuPy](https://cupy.dev/) to use RMM for memory @@ -698,9 +700,9 @@ allocations by setting the CuPy CUDA allocator to `rmm_cupy_allocator`: ```python ->>> import rmm +>>> from rmm.allocators.cupy import rmm_cupy_allocator >>> import cupy ->>> cupy.cuda.set_allocator(rmm.rmm_cupy_allocator) +>>> cupy.cuda.set_allocator(rmm_cupy_allocator) ``` @@ -718,15 +720,15 @@ This can be done in two ways: 1. Setting the environment variable `NUMBA_CUDA_MEMORY_MANAGER`: ```python - $ NUMBA_CUDA_MEMORY_MANAGER=rmm python (args) + $ NUMBA_CUDA_MEMORY_MANAGER=rmm.allocators.numba python (args) ``` 2. Using the `set_memory_manager()` function provided by Numba: ```python >>> from numba import cuda - >>> import rmm - >>> cuda.set_memory_manager(rmm.RMMNumbaManager) + >>> from rmm.allocators.numba import RMMNumbaManager + >>> cuda.set_memory_manager(RMMNumbaManager) ``` **Note:** This only configures Numba to use the current RMM resource for allocations. @@ -741,10 +743,11 @@ RMM-managed pool: ```python import rmm +from rmm.allocators.torch import rmm_torch_allocator import torch rmm.reinitialize(pool_allocator=True) -torch.cuda.memory.change_current_allocator(rmm.rmm_torch_allocator) +torch.cuda.memory.change_current_allocator(rmm_torch_allocator) ``` PyTorch and RMM will now share the same memory pool. @@ -753,13 +756,14 @@ You can, of course, use a custom memory resource with PyTorch as well: ```python import rmm +from rmm.allocators.torch import rmm_torch_allocator import torch # note that you can configure PyTorch to use RMM either before or # after changing RMM's memory resource. PyTorch will use whatever # memory resource is configured to be the "current" memory resource at # the time of allocation. -torch.cuda.change_current_allocator(rmm.rmm_torch_allocator) +torch.cuda.change_current_allocator(rmm_torch_allocator) # configure RMM to use a managed memory resource, wrapped with a # statistics resource adaptor that can report information about the diff --git a/python/docs/api.rst b/python/docs/api.rst index ebc68c354..73cd5dd81 100644 --- a/python/docs/api.rst +++ b/python/docs/api.rst @@ -17,3 +17,21 @@ Memory Resources :members: :undoc-members: :show-inheritance: + +Memory Allocators +----------------- + +.. automodule:: rmm.allocators.cupy + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: rmm.allocators.numba + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: rmm.allocators.torch + :members: + :undoc-members: + :show-inheritance: diff --git a/python/docs/basics.md b/python/docs/basics.md index 368145b3d..0c47073c1 100644 --- a/python/docs/basics.md +++ b/python/docs/basics.md @@ -131,21 +131,31 @@ resources MemoryResources are highly configurable and can be composed together in different ways. See `help(rmm.mr)` for more information. +## Using RMM with third-party libraries + +A number of libraries provide hooks to control their device +allocations. RMM provides implementations of these for +[CuPy](https://cupy.dev), +[numba](https://numba.readthedocs.io/en/stable/), and [PyTorch](https://pytorch.org) in the +`rmm.allocators` submodule. 
+All these approaches configure the library
+to use the _current_ RMM memory resource for device
+allocations.
+
 ### Using RMM with CuPy

 You can configure [CuPy](https://cupy.dev/) to use RMM for memory
 allocations by setting the CuPy CUDA allocator to
-`rmm_cupy_allocator`:
+`rmm.allocators.cupy.rmm_cupy_allocator`:

 ```python
->>> import rmm
+>>> from rmm.allocators.cupy import rmm_cupy_allocator
 >>> import cupy
->>> cupy.cuda.set_allocator(rmm.rmm_cupy_allocator)
+>>> cupy.cuda.set_allocator(rmm_cupy_allocator)
 ```

 ### Using RMM with Numba

-You can configure Numba to use RMM for memory allocations using the
+You can configure [Numba](https://numba.readthedocs.io/en/stable/) to use RMM for memory allocations using the
 Numba [EMM Plugin](https://numba.readthedocs.io/en/stable/cuda/external-memory.html#setting-emm-plugin).

 This can be done in two ways:
@@ -153,13 +163,27 @@ This can be done in two ways:
 1. Setting the environment variable `NUMBA_CUDA_MEMORY_MANAGER`:

    ```bash
-   $ NUMBA_CUDA_MEMORY_MANAGER=rmm python (args)
+   $ NUMBA_CUDA_MEMORY_MANAGER=rmm.allocators.numba python (args)
    ```

 2. Using the `set_memory_manager()` function provided by Numba:

    ```python
    >>> from numba import cuda
-   >>> import rmm
-   >>> cuda.set_memory_manager(rmm.RMMNumbaManager)
+   >>> from rmm.allocators.numba import RMMNumbaManager
+   >>> cuda.set_memory_manager(RMMNumbaManager)
    ```
+
+### Using RMM with PyTorch
+
+You can configure
+[PyTorch](https://pytorch.org/docs/stable/notes/cuda.html) to use RMM
+for memory allocations by configuring the current
+allocator.
+
+```python
+from rmm.allocators.torch import rmm_torch_allocator
+import torch
+
+torch.cuda.memory.change_current_allocator(rmm_torch_allocator)
+```
diff --git a/python/rmm/__init__.py b/python/rmm/__init__.py
index e2d8f57d7..d9e86c13e 100644
--- a/python/rmm/__init__.py
+++ b/python/rmm/__init__.py
@@ -17,20 +17,15 @@
 from rmm.mr import disable_logging, enable_logging, get_log_filenames
 from rmm.rmm import (
     RMMError,
-    RMMNumbaManager,
-    _numba_memory_manager,
     is_initialized,
     register_reinitialize_hook,
     reinitialize,
-    rmm_cupy_allocator,
-    rmm_torch_allocator,
     unregister_reinitialize_hook,
 )

 __all__ = [
     "DeviceBuffer",
     "RMMError",
-    "RMMNumbaManager",
     "disable_logging",
     "enable_logging",
     "get_log_filenames",
@@ -38,8 +33,35 @@
     "mr",
     "register_reinitialize_hook",
     "reinitialize",
-    "rmm_cupy_allocator",
     "unregister_reinitialize_hook",
 ]

 __version__ = "23.04.00"
+
+
+_deprecated_names = {
+    "rmm_cupy_allocator": "cupy",
+    "rmm_torch_allocator": "torch",
+    "RMMNumbaManager": "numba",
+    "_numba_memory_manager": "numba",
+}
+
+
+def __getattr__(name):
+    if name in _deprecated_names:
+        import importlib
+        import warnings
+
+        package = _deprecated_names[name]
+        warnings.warn(
+            f"Use of 'rmm.{name}' is deprecated and will be removed. "
+            f"'{name}' now lives in the 'rmm.allocators.{package}' sub-module, "
+            "please update your imports.",
+            FutureWarning,
+        )
+        module = importlib.import_module(
+            f".allocators.{package}", package=__name__
+        )
+        return getattr(module, name)
+    else:
+        raise AttributeError(f"Module '{__name__}' has no attribute '{name}'")
diff --git a/python/rmm/_cuda/gpu.py b/python/rmm/_cuda/gpu.py
index e7f768349..2a23b41e6 100644
--- a/python/rmm/_cuda/gpu.py
+++ b/python/rmm/_cuda/gpu.py
@@ -1,6 +1,5 @@
 # Copyright (c) 2020, NVIDIA CORPORATION.
-import numba.cuda
 from cuda import cuda, cudart

@@ -84,6 +83,8 @@ def runtimeGetVersion():
     """
     # TODO: Replace this with `cuda.cudart.cudaRuntimeGetVersion()` when the
     # limitation is fixed.
+    import numba.cuda
+
     major, minor = numba.cuda.runtime.get_version()
     return major * 1000 + minor * 10
diff --git a/python/rmm/_cuda/stream.pyx b/python/rmm/_cuda/stream.pyx
index 4f2ce26d0..d60dde4e1 100644
--- a/python/rmm/_cuda/stream.pyx
+++ b/python/rmm/_cuda/stream.pyx
@@ -16,6 +16,7 @@ from cuda.ccudart cimport cudaStream_t
 from libc.stdint cimport uintptr_t
 from libcpp cimport bool

+from rmm._lib.cuda_stream cimport CudaStream
 from rmm._lib.cuda_stream_view cimport (
     cuda_stream_default,
     cuda_stream_legacy,
@@ -23,12 +24,6 @@ from rmm._lib.cuda_stream_view cimport (
     cuda_stream_view,
 )

-from numba import cuda
-
-from rmm._lib.cuda_stream cimport CudaStream
-
-from rmm._lib.cuda_stream import CudaStream
-

 cdef class Stream:
     def __init__(self, obj=None):
@@ -46,10 +41,11 @@ cdef class Stream:
             self._init_with_new_cuda_stream()
         elif isinstance(obj, Stream):
             self._init_from_stream(obj)
-        elif isinstance(obj, cuda.cudadrv.driver.Stream):
-            self._init_from_numba_stream(obj)
         else:
-            self._init_from_cupy_stream(obj)
+            try:
+                self._init_from_numba_stream(obj)
+            except TypeError:
+                self._init_from_cupy_stream(obj)

     @staticmethod
     cdef Stream _from_cudaStream_t(cudaStream_t s, object owner=None):
@@ -94,8 +90,12 @@ cdef class Stream:
         return self.c_is_default()

     def _init_from_numba_stream(self, obj):
-        self._cuda_stream = <cudaStream_t><uintptr_t>(int(obj))
-        self._owner = obj
+        from numba import cuda
+        if isinstance(obj, cuda.cudadrv.driver.Stream):
+            self._cuda_stream = <cudaStream_t><uintptr_t>(int(obj))
+            self._owner = obj
+        else:
+            raise TypeError(f"Cannot create stream from {type(obj)}")

     def _init_from_cupy_stream(self, obj):
         try:
diff --git a/python/rmm/allocators/__init__.py b/python/rmm/allocators/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/python/rmm/allocators/cupy.py b/python/rmm/allocators/cupy.py
new file mode 100644
index 000000000..89947c46b
--- /dev/null
+++ b/python/rmm/allocators/cupy.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2023, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from rmm import _lib as librmm
+from rmm._cuda.stream import Stream
+
+try:
+    import cupy
+except ImportError:
+    cupy = None
+
+
+def rmm_cupy_allocator(nbytes):
+    """
+    A CuPy allocator that makes use of RMM.
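+
+    Parameters
+    ----------
+    nbytes : int
+        Size in bytes of the requested allocation, which is satisfied by a
+        ``DeviceBuffer`` drawn from the current RMM memory resource.
+
+    Returns
+    -------
+    cupy.cuda.memory.MemoryPointer
+        Pointer to the newly allocated memory; the underlying buffer is
+        kept alive as the owner of the returned pointer.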
+ + Examples + -------- + >>> from rmm.allocators.cupy import rmm_cupy_allocator + >>> import cupy + >>> cupy.cuda.set_allocator(rmm_cupy_allocator) + """ + if cupy is None: + raise ModuleNotFoundError("No module named 'cupy'") + + stream = Stream(obj=cupy.cuda.get_current_stream()) + buf = librmm.device_buffer.DeviceBuffer(size=nbytes, stream=stream) + dev_id = -1 if buf.ptr else cupy.cuda.device.get_device_id() + mem = cupy.cuda.UnownedMemory( + ptr=buf.ptr, size=buf.size, owner=buf, device_id=dev_id + ) + ptr = cupy.cuda.memory.MemoryPointer(mem, 0) + + return ptr diff --git a/python/rmm/allocators/numba.py b/python/rmm/allocators/numba.py new file mode 100644 index 000000000..18a010e1c --- /dev/null +++ b/python/rmm/allocators/numba.py @@ -0,0 +1,125 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ctypes + +from cuda.cuda import CUdeviceptr, cuIpcGetMemHandle +from numba import config, cuda +from numba.cuda import HostOnlyCUDAMemoryManager, IpcHandle, MemoryPointer + +from rmm import _lib as librmm + + +def _make_emm_plugin_finalizer(handle, allocations): + """ + Factory to make the finalizer function. + We need to bind *handle* and *allocations* into the actual finalizer, which + takes no args. + """ + + def finalizer(): + """ + Invoked when the MemoryPointer is freed + """ + # At exit time (particularly in the Numba test suite) allocations may + # have already been cleaned up by a call to Context.reset() for the + # context, even if there are some DeviceNDArrays and their underlying + # allocations lying around. Finalizers then get called by weakref's + # atexit finalizer, at which point allocations[handle] no longer + # exists. This is harmless, except that a traceback is printed just + # prior to exit (without abnormally terminating the program), but is + # worrying for the user. To avoid the traceback, we check if + # allocations is already empty. + # + # In the case where allocations is not empty, but handle is not in + # allocations, then something has gone wrong - so we only guard against + # allocations being completely empty, rather than handle not being in + # allocations. + if allocations: + del allocations[handle] + + return finalizer + + +class RMMNumbaManager(HostOnlyCUDAMemoryManager): + """ + External Memory Management Plugin implementation for Numba. Provides + on-device allocation only. + + See https://numba.readthedocs.io/en/stable/cuda/external-memory.html for + details of the interface being implemented here. + """ + + def initialize(self): + # No special initialization needed to use RMM within a given context. + pass + + def memalloc(self, size): + """ + Allocate an on-device array from the RMM pool. 
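+
+        Parameters
+        ----------
+        size : int
+            Number of bytes to allocate.
+
+        Returns
+        -------
+        MemoryPointer
+            Numba memory pointer wrapping the RMM ``DeviceBuffer`` that
+            backs this allocation; the buffer is released when the
+            attached finalizer runs.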
+ """ + buf = librmm.DeviceBuffer(size=size) + ctx = self.context + + if config.CUDA_USE_NVIDIA_BINDING: + ptr = CUdeviceptr(int(buf.ptr)) + else: + # expect ctypes bindings in numba + ptr = ctypes.c_uint64(int(buf.ptr)) + + finalizer = _make_emm_plugin_finalizer(int(buf.ptr), self.allocations) + + # self.allocations is initialized by the parent, HostOnlyCUDAManager, + # and cleared upon context reset, so although we insert into it here + # and delete from it in the finalizer, we need not do any other + # housekeeping elsewhere. + self.allocations[int(buf.ptr)] = buf + + return MemoryPointer(ctx, ptr, size, finalizer=finalizer) + + def get_ipc_handle(self, memory): + """ + Get an IPC handle for the MemoryPointer memory with offset modified by + the RMM memory pool. + """ + start, end = cuda.cudadrv.driver.device_extents(memory) + + if config.CUDA_USE_NVIDIA_BINDING: + _, ipc_handle = cuIpcGetMemHandle(start) + offset = int(memory.handle) - int(start) + else: + ipc_handle = (ctypes.c_byte * 64)() # IPC handle is 64 bytes + cuda.cudadrv.driver.driver.cuIpcGetMemHandle( + ctypes.byref(ipc_handle), + start, + ) + offset = memory.handle.value - start + source_info = cuda.current_context().device.get_device_identity() + + return IpcHandle( + memory, ipc_handle, memory.size, source_info, offset=offset + ) + + def get_memory_info(self): + raise NotImplementedError() + + @property + def interface_version(self): + return 1 + + +# Enables the use of RMM for Numba via an environment variable setting, +# NUMBA_CUDA_MEMORY_MANAGER=rmm. See: +# https://numba.readthedocs.io/en/stable/cuda/external-memory.html#environment-variable +_numba_memory_manager = RMMNumbaManager diff --git a/python/rmm/allocators/torch.py b/python/rmm/allocators/torch.py new file mode 100644 index 000000000..65b310a89 --- /dev/null +++ b/python/rmm/allocators/torch.py @@ -0,0 +1,26 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +try: + from torch.cuda.memory import CUDAPluggableAllocator +except ImportError: + rmm_torch_allocator = None +else: + import rmm._lib.torch_allocator + + _alloc_free_lib_path = rmm._lib.torch_allocator.__file__ + rmm_torch_allocator = CUDAPluggableAllocator( + _alloc_free_lib_path, + alloc_fn_name="allocate", + free_fn_name="deallocate", + ) diff --git a/python/rmm/rmm.py b/python/rmm/rmm.py index cae9971dc..e5290905c 100644 --- a/python/rmm/rmm.py +++ b/python/rmm/rmm.py @@ -11,15 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import ctypes - -from cuda.cuda import CUdeviceptr, cuIpcGetMemHandle -from numba import config, cuda -from numba.cuda import HostOnlyCUDAMemoryManager, IpcHandle, MemoryPointer - -import rmm -from rmm import _lib as librmm -from rmm._cuda.stream import Stream +from rmm import mr # Utility Functions @@ -86,7 +78,7 @@ def reinitialize( for func, args, kwargs in reversed(_reinitialize_hooks): func(*args, **kwargs) - rmm.mr._initialize( + mr._initialize( pool_allocator=pool_allocator, managed_memory=managed_memory, initial_pool_size=initial_pool_size, @@ -101,155 +93,7 @@ def is_initialized(): """ Returns True if RMM has been initialized, False otherwise. """ - return rmm.mr.is_initialized() - - -class RMMNumbaManager(HostOnlyCUDAMemoryManager): - """ - External Memory Management Plugin implementation for Numba. Provides - on-device allocation only. - - See https://numba.readthedocs.io/en/stable/cuda/external-memory.html for - details of the interface being implemented here. - """ - - def initialize(self): - # No special initialization needed to use RMM within a given context. - pass - - def memalloc(self, size): - """ - Allocate an on-device array from the RMM pool. - """ - buf = librmm.DeviceBuffer(size=size) - ctx = self.context - - if config.CUDA_USE_NVIDIA_BINDING: - ptr = CUdeviceptr(int(buf.ptr)) - else: - # expect ctypes bindings in numba - ptr = ctypes.c_uint64(int(buf.ptr)) - - finalizer = _make_emm_plugin_finalizer(int(buf.ptr), self.allocations) - - # self.allocations is initialized by the parent, HostOnlyCUDAManager, - # and cleared upon context reset, so although we insert into it here - # and delete from it in the finalizer, we need not do any other - # housekeeping elsewhere. - self.allocations[int(buf.ptr)] = buf - - return MemoryPointer(ctx, ptr, size, finalizer=finalizer) - - def get_ipc_handle(self, memory): - """ - Get an IPC handle for the MemoryPointer memory with offset modified by - the RMM memory pool. - """ - start, end = cuda.cudadrv.driver.device_extents(memory) - - if config.CUDA_USE_NVIDIA_BINDING: - _, ipchandle = cuIpcGetMemHandle(start) - offset = int(memory.handle) - int(start) - else: - ipchandle = (ctypes.c_byte * 64)() # IPC handle is 64 bytes - cuda.cudadrv.driver.driver.cuIpcGetMemHandle( - ctypes.byref(ipchandle), - start, - ) - offset = memory.handle.value - start - source_info = cuda.current_context().device.get_device_identity() - - return IpcHandle( - memory, ipchandle, memory.size, source_info, offset=offset - ) - - def get_memory_info(self): - raise NotImplementedError() - - @property - def interface_version(self): - return 1 - - -def _make_emm_plugin_finalizer(handle, allocations): - """ - Factory to make the finalizer function. - We need to bind *handle* and *allocations* into the actual finalizer, which - takes no args. - """ - - def finalizer(): - """ - Invoked when the MemoryPointer is freed - """ - # At exit time (particularly in the Numba test suite) allocations may - # have already been cleaned up by a call to Context.reset() for the - # context, even if there are some DeviceNDArrays and their underlying - # allocations lying around. Finalizers then get called by weakref's - # atexit finalizer, at which point allocations[handle] no longer - # exists. This is harmless, except that a traceback is printed just - # prior to exit (without abnormally terminating the program), but is - # worrying for the user. To avoid the traceback, we check if - # allocations is already empty. 
- # - # In the case where allocations is not empty, but handle is not in - # allocations, then something has gone wrong - so we only guard against - # allocations being completely empty, rather than handle not being in - # allocations. - if allocations: - del allocations[handle] - - return finalizer - - -# Enables the use of RMM for Numba via an environment variable setting, -# NUMBA_CUDA_MEMORY_MANAGER=rmm. See: -# https://numba.readthedocs.io/en/stable/cuda/external-memory.html#environment-variable -_numba_memory_manager = RMMNumbaManager - -try: - import cupy -except Exception: - cupy = None - - -def rmm_cupy_allocator(nbytes): - """ - A CuPy allocator that makes use of RMM. - - Examples - -------- - >>> import rmm - >>> import cupy - >>> cupy.cuda.set_allocator(rmm.rmm_cupy_allocator) - """ - if cupy is None: - raise ModuleNotFoundError("No module named 'cupy'") - - stream = Stream(obj=cupy.cuda.get_current_stream()) - buf = librmm.device_buffer.DeviceBuffer(size=nbytes, stream=stream) - dev_id = -1 if buf.ptr else cupy.cuda.device.get_device_id() - mem = cupy.cuda.UnownedMemory( - ptr=buf.ptr, size=buf.size, owner=buf, device_id=dev_id - ) - ptr = cupy.cuda.memory.MemoryPointer(mem, 0) - - return ptr - - -try: - from torch.cuda.memory import CUDAPluggableAllocator -except ImportError: - rmm_torch_allocator = None -else: - import rmm._lib.torch_allocator - - _alloc_free_lib_path = rmm._lib.torch_allocator.__file__ - rmm_torch_allocator = CUDAPluggableAllocator( - _alloc_free_lib_path, - alloc_fn_name="allocate", - free_fn_name="deallocate", - ) + return mr.is_initialized() def register_reinitialize_hook(func, *args, **kwargs): diff --git a/python/rmm/tests/test_rmm.py b/python/rmm/tests/test_rmm.py index f79c60b43..95afc8db3 100644 --- a/python/rmm/tests/test_rmm.py +++ b/python/rmm/tests/test_rmm.py @@ -24,6 +24,8 @@ import rmm import rmm._cuda.stream +from rmm.allocators.cupy import rmm_cupy_allocator +from rmm.allocators.numba import RMMNumbaManager if sys.version_info < (3, 8): try: @@ -33,7 +35,7 @@ else: import pickle -cuda.set_memory_manager(rmm.RMMNumbaManager) +cuda.set_memory_manager(RMMNumbaManager) _driver_version = rmm._cuda.gpu.driverGetVersion() _runtime_version = rmm._cuda.gpu.runtimeGetVersion() @@ -303,17 +305,17 @@ def test_rmm_pool_numba_stream(stream): def test_rmm_cupy_allocator(): cupy = pytest.importorskip("cupy") - m = rmm.rmm_cupy_allocator(42) + m = rmm_cupy_allocator(42) assert m.mem.size == 42 assert m.mem.ptr != 0 assert isinstance(m.mem._owner, rmm.DeviceBuffer) - m = rmm.rmm_cupy_allocator(0) + m = rmm_cupy_allocator(0) assert m.mem.size == 0 assert m.mem.ptr == 0 assert isinstance(m.mem._owner, rmm.DeviceBuffer) - cupy.cuda.set_allocator(rmm.rmm_cupy_allocator) + cupy.cuda.set_allocator(rmm_cupy_allocator) a = cupy.arange(10) assert isinstance(a.data.mem._owner, rmm.DeviceBuffer) @@ -323,7 +325,7 @@ def test_rmm_pool_cupy_allocator_with_stream(stream): cupy = pytest.importorskip("cupy") rmm.reinitialize(pool_allocator=True) - cupy.cuda.set_allocator(rmm.rmm_cupy_allocator) + cupy.cuda.set_allocator(rmm_cupy_allocator) if stream == "null": stream = cupy.cuda.stream.Stream.null @@ -331,12 +333,12 @@ def test_rmm_pool_cupy_allocator_with_stream(stream): stream = cupy.cuda.stream.Stream() with stream: - m = rmm.rmm_cupy_allocator(42) + m = rmm_cupy_allocator(42) assert m.mem.size == 42 assert m.mem.ptr != 0 assert isinstance(m.mem._owner, rmm.DeviceBuffer) - m = rmm.rmm_cupy_allocator(0) + m = rmm_cupy_allocator(0) assert m.mem.size == 0 assert m.mem.ptr == 
0 assert isinstance(m.mem._owner, rmm.DeviceBuffer) @@ -355,7 +357,7 @@ def test_rmm_pool_cupy_allocator_stream_lifetime(): cupy = pytest.importorskip("cupy") rmm.reinitialize(pool_allocator=True) - cupy.cuda.set_allocator(rmm.rmm_cupy_allocator) + cupy.cuda.set_allocator(rmm_cupy_allocator) stream = cupy.cuda.stream.Stream() diff --git a/python/rmm/tests/test_rmm_pytorch.py b/python/rmm/tests/test_rmm_pytorch.py index eaa40c0ed..065507b61 100644 --- a/python/rmm/tests/test_rmm_pytorch.py +++ b/python/rmm/tests/test_rmm_pytorch.py @@ -2,7 +2,7 @@ import pytest -import rmm +from rmm.allocators.torch import rmm_torch_allocator torch = pytest.importorskip("torch") @@ -13,7 +13,7 @@ def torch_allocator(): from torch.cuda.memory import change_current_allocator except ImportError: pytest.skip("pytorch pluggable allocator not available") - change_current_allocator(rmm.rmm_torch_allocator) + change_current_allocator(rmm_torch_allocator) def test_rmm_torch_allocator(torch_allocator, stats_mr):
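
The `__getattr__` shim added to `python/rmm/__init__.py` above keeps the old top-level names importable while steering users to the new sub-modules. A minimal sketch of what downstream code observes after this change (assuming only that `rmm` is installed; the warning text comes from the shim in the diff, and this snippet is illustrative rather than part of the patch):

```python
import warnings

import rmm
from rmm.allocators.cupy import rmm_cupy_allocator  # new, preferred location

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Old-style attribute access still resolves, but now routes through
    # rmm.__getattr__ and emits a FutureWarning naming the new sub-module.
    deprecated = rmm.rmm_cupy_allocator

# Both paths hand back the same function object.
assert deprecated is rmm_cupy_allocator
assert any(issubclass(w.category, FutureWarning) for w in caught)
```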