diff --git a/python/rmm/__init__.py b/python/rmm/__init__.py index 8dddeae3a..00d9270be 100644 --- a/python/rmm/__init__.py +++ b/python/rmm/__init__.py @@ -14,9 +14,6 @@ from rmm import mr from rmm._lib.device_buffer import DeviceBuffer -from rmm.allocators.cupy import rmm_cupy_allocator -from rmm.allocators.numba import RMMNumbaManager, _numba_memory_manager -from rmm.allocators.torch import rmm_torch_allocator from rmm.mr import disable_logging, enable_logging, get_log_filenames from rmm.rmm import ( RMMError, @@ -29,7 +26,6 @@ __all__ = [ "DeviceBuffer", "RMMError", - "RMMNumbaManager", "disable_logging", "enable_logging", "get_log_filenames", @@ -37,8 +33,6 @@ "mr", "register_reinitialize_hook", "reinitialize", - "rmm_cupy_allocator", - "rmm_torch_allocator", "unregister_reinitialize_hook", ] diff --git a/python/rmm/allocators/cupy.py b/python/rmm/allocators/cupy.py index baf57beaa..8b55ead5a 100644 --- a/python/rmm/allocators/cupy.py +++ b/python/rmm/allocators/cupy.py @@ -26,7 +26,7 @@ def rmm_cupy_allocator(nbytes): Examples -------- - >>> import rmm + >>> from rmm.allocators.cupy import rmm_cupy_allocator >>> import cupy - >>> cupy.cuda.set_allocator(rmm.rmm_cupy_allocator) + >>> cupy.cuda.set_allocator(rmm_cupy_allocator) """ diff --git a/python/rmm/tests/test_rmm.py b/python/rmm/tests/test_rmm.py index f79c60b43..95afc8db3 100644 --- a/python/rmm/tests/test_rmm.py +++ b/python/rmm/tests/test_rmm.py @@ -24,6 +24,8 @@ import rmm import rmm._cuda.stream +from rmm.allocators.cupy import rmm_cupy_allocator +from rmm.allocators.numba import RMMNumbaManager if sys.version_info < (3, 8): try: @@ -33,7 +35,7 @@ else: import pickle -cuda.set_memory_manager(rmm.RMMNumbaManager) +cuda.set_memory_manager(RMMNumbaManager) _driver_version = rmm._cuda.gpu.driverGetVersion() _runtime_version = rmm._cuda.gpu.runtimeGetVersion() @@ -303,17 +305,17 @@ def test_rmm_cupy_allocator(): cupy = pytest.importorskip("cupy") - m = rmm.rmm_cupy_allocator(42) + m = rmm_cupy_allocator(42) assert 
m.mem.size == 42 assert m.mem.ptr != 0 assert isinstance(m.mem._owner, rmm.DeviceBuffer) - m = rmm.rmm_cupy_allocator(0) + m = rmm_cupy_allocator(0) assert m.mem.size == 0 assert m.mem.ptr == 0 assert isinstance(m.mem._owner, rmm.DeviceBuffer) - cupy.cuda.set_allocator(rmm.rmm_cupy_allocator) + cupy.cuda.set_allocator(rmm_cupy_allocator) a = cupy.arange(10) assert isinstance(a.data.mem._owner, rmm.DeviceBuffer) @@ -323,7 +325,7 @@ def test_rmm_pool_cupy_allocator_with_stream(stream): cupy = pytest.importorskip("cupy") rmm.reinitialize(pool_allocator=True) - cupy.cuda.set_allocator(rmm.rmm_cupy_allocator) + cupy.cuda.set_allocator(rmm_cupy_allocator) if stream == "null": stream = cupy.cuda.stream.Stream.null @@ -331,12 +333,12 @@ def test_rmm_pool_cupy_allocator_with_stream(stream): stream = cupy.cuda.stream.Stream() with stream: - m = rmm.rmm_cupy_allocator(42) + m = rmm_cupy_allocator(42) assert m.mem.size == 42 assert m.mem.ptr != 0 assert isinstance(m.mem._owner, rmm.DeviceBuffer) - m = rmm.rmm_cupy_allocator(0) + m = rmm_cupy_allocator(0) assert m.mem.size == 0 assert m.mem.ptr == 0 assert isinstance(m.mem._owner, rmm.DeviceBuffer) @@ -355,7 +357,7 @@ def test_rmm_pool_cupy_allocator_stream_lifetime(): cupy = pytest.importorskip("cupy") rmm.reinitialize(pool_allocator=True) - cupy.cuda.set_allocator(rmm.rmm_cupy_allocator) + cupy.cuda.set_allocator(rmm_cupy_allocator) stream = cupy.cuda.stream.Stream() diff --git a/python/rmm/tests/test_rmm_pytorch.py b/python/rmm/tests/test_rmm_pytorch.py index eaa40c0ed..065507b61 100644 --- a/python/rmm/tests/test_rmm_pytorch.py +++ b/python/rmm/tests/test_rmm_pytorch.py @@ -2,7 +2,7 @@ import pytest -import rmm +from rmm.allocators.torch import rmm_torch_allocator torch = pytest.importorskip("torch") @@ -13,7 +13,7 @@ def torch_allocator(): from torch.cuda.memory import change_current_allocator except ImportError: pytest.skip("pytorch pluggable allocator not available") - change_current_allocator(rmm.rmm_torch_allocator) + 
change_current_allocator(rmm_torch_allocator) def test_rmm_torch_allocator(torch_allocator, stats_mr):