
Move external allocators into rmm.allocators module to defer imports #1221

Merged
merged 14 commits into from Feb 27, 2023
20 changes: 12 additions & 8 deletions README.md
@@ -691,16 +691,18 @@ resources
MemoryResources are highly configurable and can be composed together in different ways.
See `help(rmm.mr)` for more information.
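
For example (a minimal sketch), you can put a pool on top of the default
device resource and make it current; `initial_pool_size` here is an
illustrative choice, not a recommendation:

```python
>>> import rmm
>>> pool = rmm.mr.PoolMemoryResource(
...     rmm.mr.CudaMemoryResource(),
...     initial_pool_size=2**30,  # 1 GiB, chosen for illustration
... )
>>> rmm.mr.set_current_device_resource(pool)
```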

### Using RMM with CuPy
## Using RMM with third-party libraries

### Using RMM with CuPy

You can configure [CuPy](https://cupy.dev/) to use RMM for memory
allocations by setting the CuPy CUDA allocator to
`rmm_cupy_allocator`:

```python
>>> import rmm
>>> from rmm.allocators.cupy import rmm_cupy_allocator
>>> import cupy
>>> cupy.cuda.set_allocator(rmm.rmm_cupy_allocator)
>>> cupy.cuda.set_allocator(rmm_cupy_allocator)
```
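
If you want to confirm that CuPy allocations are now served by RMM, one
way (a minimal sketch, assuming a working CUDA environment; it uses
`rmm.mr.StatisticsResourceAdaptor` to count allocations) is:

```python
>>> import cupy
>>> import rmm
>>> from rmm.allocators.cupy import rmm_cupy_allocator
>>> mr = rmm.mr.StatisticsResourceAdaptor(rmm.mr.CudaMemoryResource())
>>> rmm.mr.set_current_device_resource(mr)
>>> cupy.cuda.set_allocator(rmm_cupy_allocator)
>>> arr = cupy.zeros(1000)  # allocated via RMM
>>> mr.allocation_counts["current_bytes"] > 0
True
```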


@@ -718,15 +720,15 @@ This can be done in two ways:
1. Setting the environment variable `NUMBA_CUDA_MEMORY_MANAGER`:

```bash
$ NUMBA_CUDA_MEMORY_MANAGER=rmm python (args)
$ NUMBA_CUDA_MEMORY_MANAGER=rmm.allocators.numba python (args)
```

2. Using the `set_memory_manager()` function provided by Numba:

```python
>>> from numba import cuda
>>> import rmm
>>> cuda.set_memory_manager(rmm.RMMNumbaManager)
>>> from rmm.allocators.numba import RMMNumbaManager
>>> cuda.set_memory_manager(RMMNumbaManager)
```

**Note:** This only configures Numba to use the current RMM resource for allocations.
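
For example (a minimal sketch, assuming a working CUDA environment), you
can point the current resource at a pool and then allocate from Numba:

```python
>>> import rmm
>>> from numba import cuda
>>> from rmm.allocators.numba import RMMNumbaManager
>>> cuda.set_memory_manager(RMMNumbaManager)
>>> rmm.reinitialize(pool_allocator=True)  # a pool becomes the current resource
>>> arr = cuda.device_array(1000)  # device memory now comes from the RMM pool
```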
@@ -741,10 +743,11 @@ RMM-managed pool:

```python
import rmm
from rmm.allocators.torch import rmm_torch_allocator
import torch

rmm.reinitialize(pool_allocator=True)
torch.cuda.memory.change_current_allocator(rmm.rmm_torch_allocator)
torch.cuda.memory.change_current_allocator(rmm_torch_allocator)
```

PyTorch and RMM will now share the same memory pool.
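
As a quick check (a sketch, assuming a CUDA build of PyTorch), any tensor
allocation now draws from that shared pool:

```python
x = torch.empty(1024, device="cuda")  # served by the shared RMM pool
```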
@@ -753,13 +756,14 @@ You can, of course, use a custom memory resource with PyTorch as well:

```python
import rmm
from rmm.allocators.torch import rmm_torch_allocator
import torch

# note that you can configure PyTorch to use RMM either before or
# after changing RMM's memory resource. PyTorch will use whatever
# memory resource is configured to be the "current" memory resource at
# the time of allocation.
torch.cuda.change_current_allocator(rmm.rmm_torch_allocator)
torch.cuda.change_current_allocator(rmm_torch_allocator)

# configure RMM to use a managed memory resource, wrapped with a
# statistics resource adaptor that can report information about the
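# amount of memory allocated.
# (The remainder of this example is elided here; the lines below are a
# sketch reconstructing it from the comment above using rmm.mr APIs.)
mr = rmm.mr.StatisticsResourceAdaptor(rmm.mr.ManagedMemoryResource())
rmm.mr.set_current_device_resource(mr)
```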
18 changes: 18 additions & 0 deletions python/docs/api.rst
@@ -17,3 +17,21 @@ Memory Resources
:members:
:undoc-members:
:show-inheritance:

Memory Allocators
-----------------

.. automodule:: rmm.allocators.cupy
:members:
:undoc-members:
:show-inheritance:

.. automodule:: rmm.allocators.numba
:members:
:undoc-members:
:show-inheritance:

.. automodule:: rmm.allocators.torch
:members:
:undoc-members:
:show-inheritance:
40 changes: 32 additions & 8 deletions python/docs/basics.md
@@ -131,35 +131,59 @@ resources
MemoryResources are highly configurable and can be composed together in different ways.
See `help(rmm.mr)` for more information.

### Using RMM with CuPy
## Using RMM with third-party libraries

A number of libraries provide hooks to control their device
allocations. RMM provides implementations of these for
[CuPy](https://cupy.dev),
[numba](https://numba.readthedocs.io/en/stable/), and [PyTorch](https://pytorch.org) in the
`rmm.allocators` submodule. All of these approaches configure the library
to use the _current_ RMM memory resource for device
allocations.

### Using RMM with CuPy

You can configure [CuPy](https://cupy.dev/) to use RMM for memory
allocations by setting the CuPy CUDA allocator to
`rmm_cupy_allocator`:
`rmm.allocators.cupy.rmm_cupy_allocator`:

```python
>>> import rmm
>>> from rmm.allocators.cupy import rmm_cupy_allocator
>>> import cupy
>>> cupy.cuda.set_allocator(rmm.rmm_cupy_allocator)
>>> cupy.cuda.set_allocator(rmm_cupy_allocator)
```
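
If you only need RMM for a bounded region of code, CuPy's
`cupy.cuda.using_allocator` context manager accepts the same allocator
(a minimal sketch):

```python
>>> import cupy
>>> from rmm.allocators.cupy import rmm_cupy_allocator
>>> with cupy.cuda.using_allocator(rmm_cupy_allocator):
...     arr = cupy.arange(10)  # allocated via RMM only inside this block
```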

### Using RMM with Numba

You can configure Numba to use RMM for memory allocations using the
You can configure [Numba](https://numba.readthedocs.io/en/stable/) to use RMM for memory allocations using the
Numba [EMM Plugin](https://numba.readthedocs.io/en/stable/cuda/external-memory.html#setting-emm-plugin).

This can be done in two ways:

1. Setting the environment variable `NUMBA_CUDA_MEMORY_MANAGER`:

```bash
$ NUMBA_CUDA_MEMORY_MANAGER=rmm python (args)
$ NUMBA_CUDA_MEMORY_MANAGER=rmm.allocators.numba python (args)
```

2. Using the `set_memory_manager()` function provided by Numba:

```python
>>> from numba import cuda
>>> import rmm
>>> cuda.set_memory_manager(rmm.RMMNumbaManager)
>>> from rmm.allocators.numba import RMMNumbaManager
>>> cuda.set_memory_manager(RMMNumbaManager)
```
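
For example (a minimal sketch, assuming a working CUDA environment),
copying host data to the device then allocates through RMM:

```python
>>> import numpy as np
>>> from numba import cuda
>>> from rmm.allocators.numba import RMMNumbaManager
>>> cuda.set_memory_manager(RMMNumbaManager)
>>> d_arr = cuda.to_device(np.arange(10))  # device buffer comes from RMM
```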

### Using RMM with PyTorch

You can configure
[PyTorch](https://pytorch.org/docs/stable/notes/cuda.html) to use RMM
for memory allocations by configuring the current allocator.

```python
from rmm.allocators.torch import rmm_torch_allocator
import torch

torch.cuda.memory.change_current_allocator(rmm_torch_allocator)
```
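
You can combine this with an RMM pool so that PyTorch allocations are
served from it (a sketch, assuming a CUDA build of PyTorch; the allocator
must be changed before the first CUDA allocation):

```python
import rmm
import torch
from rmm.allocators.torch import rmm_torch_allocator

torch.cuda.memory.change_current_allocator(rmm_torch_allocator)
rmm.reinitialize(pool_allocator=True)
x = torch.empty(1024, device="cuda")  # allocated from the RMM pool
```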
6 changes: 0 additions & 6 deletions python/rmm/__init__.py
@@ -17,28 +17,22 @@
from rmm.mr import disable_logging, enable_logging, get_log_filenames
from rmm.rmm import (
RMMError,
RMMNumbaManager,
_numba_memory_manager,
is_initialized,
register_reinitialize_hook,
reinitialize,
rmm_cupy_allocator,
rmm_torch_allocator,
unregister_reinitialize_hook,
)

__all__ = [
"DeviceBuffer",
"RMMError",
"RMMNumbaManager",
"disable_logging",
"enable_logging",
"get_log_filenames",
"is_initialized",
"mr",
"register_reinitialize_hook",
"reinitialize",
"rmm_cupy_allocator",
"unregister_reinitialize_hook",
]

3 changes: 2 additions & 1 deletion python/rmm/_cuda/gpu.py
@@ -1,6 +1,5 @@
# Copyright (c) 2020, NVIDIA CORPORATION.

import numba.cuda
from cuda import cuda, cudart


@@ -84,6 +83,8 @@ def runtimeGetVersion():
"""
# TODO: Replace this with `cuda.cudart.cudaRuntimeGetVersion()` when the
# limitation is fixed.
import numba.cuda

    major, minor = numba.cuda.runtime.get_version()
    # get_version() returns e.g. (11, 8); encode it the way the runtime API
    # does: 11 * 1000 + 8 * 10 == 11080
    return major * 1000 + minor * 10

22 changes: 11 additions & 11 deletions python/rmm/_cuda/stream.pyx
@@ -16,19 +16,14 @@ from cuda.ccudart cimport cudaStream_t
from libc.stdint cimport uintptr_t
from libcpp cimport bool

from rmm._lib.cuda_stream cimport CudaStream
from rmm._lib.cuda_stream_view cimport (
cuda_stream_default,
cuda_stream_legacy,
cuda_stream_per_thread,
cuda_stream_view,
)

from numba import cuda

from rmm._lib.cuda_stream cimport CudaStream

from rmm._lib.cuda_stream import CudaStream


cdef class Stream:
def __init__(self, obj=None):
@@ -46,10 +41,11 @@ cdef class Stream:
self._init_with_new_cuda_stream()
elif isinstance(obj, Stream):
self._init_from_stream(obj)
elif isinstance(obj, cuda.cudadrv.driver.Stream):
self._init_from_numba_stream(obj)
else:
self._init_from_cupy_stream(obj)
try:
self._init_from_numba_stream(obj)
except TypeError:
self._init_from_cupy_stream(obj)

@staticmethod
cdef Stream _from_cudaStream_t(cudaStream_t s, object owner=None):
@@ -94,8 +90,12 @@ cdef class Stream:
return self.c_is_default()

def _init_from_numba_stream(self, obj):
self._cuda_stream = <cudaStream_t><uintptr_t>(int(obj))
self._owner = obj
        # numba is imported lazily so that `import rmm` does not require numba
        from numba import cuda
if isinstance(obj, cuda.cudadrv.driver.Stream):
self._cuda_stream = <cudaStream_t><uintptr_t>(int(obj))
self._owner = obj
else:
raise TypeError(f"Cannot create stream from {type(obj)}")

def _init_from_cupy_stream(self, obj):
try:
Empty file.
44 changes: 44 additions & 0 deletions python/rmm/allocators/cupy.py
@@ -0,0 +1,44 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from rmm import _lib as librmm
from rmm._cuda.stream import Stream

try:
import cupy
except ImportError:
cupy = None


def rmm_cupy_allocator(nbytes):
"""
A CuPy allocator that makes use of RMM.

Examples
--------
>>> from rmm.allocators.cupy import rmm_cupy_allocator
>>> import cupy
    >>> cupy.cuda.set_allocator(rmm_cupy_allocator)
"""
if cupy is None:
raise ModuleNotFoundError("No module named 'cupy'")

    # Allocate a DeviceBuffer from the current RMM memory resource on the
    # current CuPy stream.
    stream = Stream(obj=cupy.cuda.get_current_stream())
    buf = librmm.device_buffer.DeviceBuffer(size=nbytes, stream=stream)
    # Hand the buffer to CuPy as UnownedMemory; `buf` is passed as the owner
    # so it stays alive for the lifetime of the CuPy array.
    dev_id = -1 if buf.ptr else cupy.cuda.device.get_device_id()
    mem = cupy.cuda.UnownedMemory(
        ptr=buf.ptr, size=buf.size, owner=buf, device_id=dev_id
    )
    ptr = cupy.cuda.memory.MemoryPointer(mem, 0)

return ptr