Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cpp/include/rmm/cuda_stream_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ class cuda_stream_pool {
* @param stream_id Unique identifier for the desired stream
*
* @return rmm::cuda_stream_view
*
* @note @p stream_id is taken modulo the pool size, so any std::size_t value is
* valid.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A size_t cannot be negative so the non negative stipulation is unnecessary

*/
rmm::cuda_stream_view get_stream(std::size_t stream_id) const;

Expand Down
7 changes: 6 additions & 1 deletion python/rmm/rmm/librmm/cuda_stream.pxd
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2020-2024, NVIDIA CORPORATION.
# SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0

from cuda.bindings.cyruntime cimport cudaStream_t
from libcpp cimport bool
from libc.stdint cimport uint32_t

from rmm.librmm.cuda_stream_view cimport cuda_stream_view


cdef extern from "rmm/cuda_stream.hpp" namespace "rmm" nogil:

# Cython declaration of the C++ scoped enum rmm::cuda_stream::flags, with
# uint32_t as its underlying type. Declared cpdef so the members are also
# reachable from Python code.
cpdef enum class cuda_stream_flags "rmm::cuda_stream::flags" (uint32_t):
    # Created stream synchronizes with the default stream.
    sync_default "rmm::cuda_stream::flags::sync_default"
    # Created stream does not synchronize with the default stream.
    non_blocking "rmm::cuda_stream::flags::non_blocking"
cdef cppclass cuda_stream:
cuda_stream() except +
bool is_valid() except +
Expand Down
5 changes: 3 additions & 2 deletions python/rmm/rmm/librmm/cuda_stream_pool.pxd
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION.
# SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0

from rmm.librmm.cuda_stream_view cimport cuda_stream_view

from rmm.librmm.cuda_stream cimport cuda_stream_flags

cdef extern from "rmm/cuda_stream_pool.hpp" namespace "rmm" nogil:
    # Cython declarations for the C++ class rmm::cuda_stream_pool.
    cdef cppclass cuda_stream_pool:
        # Constructors are declared `except +` so that a C++ exception thrown
        # while building the pool is translated into a Python exception
        # instead of escaping through the Cython call site.
        # NOTE(review): RMM documents that a zero pool size is rejected;
        # confirm against the installed header.
        cuda_stream_pool(size_t pool_size) except +
        cuda_stream_pool(size_t pool_size, cuda_stream_flags flags) except +
        # Returns a view of a stream from the pool.
        cuda_stream_view get_stream()
        # Returns a view of the stream associated with `stream_id`.
        cuda_stream_view get_stream(size_t stream_id) except +
        # Returns the number of streams in the pool.
        size_t get_pool_size()
3 changes: 2 additions & 1 deletion python/rmm/rmm/pylibrmm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
# cmake-format: on
# =============================================================================

set(cython_sources device_buffer.pyx logger.pyx cuda_stream.pyx helper.pyx stream.pyx utils.pyx)
set(cython_sources cuda_stream.pyx cuda_stream_pool.pyx device_buffer.pyx logger.pyx helper.pyx
stream.pyx utils.pyx)
set(linked_libraries rmm::rmm)

# Build all of the Cython targets
Expand Down
9 changes: 8 additions & 1 deletion python/rmm/rmm/pylibrmm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@

from rmm.pylibrmm import memory_resource

from .cuda_stream_pool import CudaStreamPool
from .cuda_stream import CudaStreamFlags
from .device_buffer import DeviceBuffer

__all__ = ["DeviceBuffer", "memory_resource"]
__all__ = [
"CudaStreamPool",
"CudaStreamFlags",
"DeviceBuffer",
"memory_resource",
]
18 changes: 17 additions & 1 deletion python/rmm/rmm/pylibrmm/cuda_stream.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,26 @@
# SPDX-License-Identifier: Apache-2.0

cimport cython
from enum import IntEnum
from cuda.bindings.cyruntime cimport cudaStream_t
from libcpp cimport bool

from rmm.librmm.cuda_stream cimport cuda_stream
from rmm.librmm.cuda_stream cimport cuda_stream, cuda_stream_flags


class CudaStreamFlags(IntEnum):
    """
    Enumeration of CUDA stream creation flags.

    Each member takes its integer value from the member of the same name in
    the ``cuda_stream_flags`` enum cimported from ``rmm.librmm.cuda_stream``.

    Attributes
    ----------
    SYNC_DEFAULT : int
        Created stream synchronizes with the default stream.
    NON_BLOCKING : int
        Created stream does not synchronize with the default stream.
    """
    # The <int> casts turn the C-level enum values into plain integers so
    # they can be used as IntEnum member values.
    SYNC_DEFAULT = <int>cuda_stream_flags.sync_default
    NON_BLOCKING = <int>cuda_stream_flags.non_blocking


@cython.final
Expand Down
13 changes: 13 additions & 0 deletions python/rmm/rmm/pylibrmm/cuda_stream_pool.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0

cimport cython
from libc.stddef cimport size_t
from libcpp.memory cimport unique_ptr

from rmm.librmm.cuda_stream_pool cimport cuda_stream_pool


# Declaration of the CudaStreamPool extension type; marked final, so it
# cannot be subclassed.
@cython.final
cdef class CudaStreamPool:
    # Uniquely-owned pointer to the wrapped C++ cuda_stream_pool instance.
    cdef unique_ptr[cuda_stream_pool] c_obj
67 changes: 67 additions & 0 deletions python/rmm/rmm/pylibrmm/cuda_stream_pool.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0

cimport cython
from libc.stddef cimport size_t
from cython.operator cimport dereference as deref

from rmm.librmm.cuda_stream cimport cuda_stream_flags
from rmm.librmm.cuda_stream_pool cimport cuda_stream_pool

from rmm.pylibrmm.stream cimport Stream

from typing import Optional


@cython.final
cdef class CudaStreamPool:
    """
    A pool of CUDA streams for efficient stream management.

    Provides thread-safe access to a collection of CUDA stream objects.
    Successive calls may return views of identical streams.

    Parameters
    ----------
    pool_size : int, optional
        The number of streams in the pool. Must be greater than zero.
        Defaults to 16.
    flags : cuda_stream_flags, optional
        Flags used when creating the streams in the pool. Defaults to
        ``cuda_stream_flags.sync_default``.

    Raises
    ------
    ValueError
        If ``pool_size`` is zero.
    """

    def __cinit__(self, size_t pool_size = 16,
                  cuda_stream_flags flags = cuda_stream_flags.sync_default):
        # Validate before entering C++: an empty pool cannot hand out any
        # stream, and a Python-level error is clearer than a failure raised
        # from inside the C++ constructor.
        if pool_size == 0:
            raise ValueError("pool_size must be greater than zero")
        with nogil:
            self.c_obj.reset(new cuda_stream_pool(pool_size, flags))

    def __dealloc__(self):
        # Destroy the wrapped pool without holding the GIL.
        with nogil:
            self.c_obj.reset()

    def get_stream(self, stream_id: Optional[int] = None) -> Stream:
        """
        Get a Stream from the pool (optionally by ID).

        Parameters
        ----------
        stream_id : Optional[int], optional
            The ID of the stream to get. If None, the next stream from the
            pool is returned. The ID is wrapped around the pool size, so any
            non-negative value that fits in ``size_t`` is accepted.

        Returns
        -------
        Stream
            A non-owning Stream object from the pool
        """
        cdef size_t c_stream_id
        # The pool is passed as ``owner`` of the returned Stream.
        if stream_id is None:
            return Stream._from_cudaStream_t(
                deref(self.c_obj).get_stream().value(), owner=self)
        c_stream_id = <size_t>stream_id
        return Stream._from_cudaStream_t(
            deref(self.c_obj).get_stream(c_stream_id).value(), owner=self)

    def get_pool_size(self) -> int:
        """
        Get the pool size.

        Returns
        -------
        int
            The number of streams in the pool
        """
        return deref(self.c_obj).get_pool_size()
5 changes: 5 additions & 0 deletions python/rmm/rmm/pylibrmm/stream.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ cdef class Stream:
pass
raise TypeError(f"Cannot create stream from {type(obj)}")

def __eq__(self, other):
    """
    Compare this stream with another object for equality.

    Two ``Stream`` objects are equal when their stream views compare
    equal. For any other type ``NotImplemented`` is returned so Python can
    try the reflected comparison before falling back to its default
    (which yields ``False`` for ``==``).
    """
    # NOTE(review): a class that defines __eq__ without __hash__ is normally
    # unhashable; confirm whether Stream needs a matching __hash__.
    if not isinstance(other, Stream):
        return NotImplemented
    return self.view() == (<Stream>other).view()

cdef void _init_with_new_cuda_stream(self) except *:
cdef CudaStream stream = CudaStream()
self._cuda_stream = stream.value()
Expand Down
29 changes: 29 additions & 0 deletions python/rmm/rmm/tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import pytest
from cuda.core.experimental import Device

import rmm.pylibrmm.cuda_stream
import rmm.pylibrmm.cuda_stream_pool
import rmm.pylibrmm.stream

CUDA_CORE_VERSION = importlib.metadata.version("cuda-core")
Expand Down Expand Up @@ -93,3 +95,30 @@ def test_cuda_core_buffer(current_device):
buf = cuda_core_mr.allocate(1024, stream=rmm_stream)
buf.close(stream=rmm_stream)
rmm_stream.synchronize()


@pytest.mark.parametrize(
    "flags",
    [
        rmm.pylibrmm.cuda_stream.CudaStreamFlags.SYNC_DEFAULT,
        rmm.pylibrmm.cuda_stream.CudaStreamFlags.NON_BLOCKING,
    ],
)
def test_cuda_stream_pool(current_device, flags):
    """Check pool size, stream uniqueness and lookup by ID for each flag."""
    expected_size = 10

    device_default = rmm.pylibrmm.stream.Stream(current_device.default_stream)

    pool = rmm.pylibrmm.cuda_stream_pool.CudaStreamPool(
        pool_size=expected_size, flags=flags
    )
    assert pool.get_pool_size() == expected_size

    # Draw one stream per pool slot, in the order the pool hands them out.
    pooled = []
    for _ in range(expected_size):
        pooled.append(pool.get_stream())

    for idx, stream in enumerate(pooled):
        # Every stream must differ from all streams drawn after it.
        for later in pooled[idx + 1:]:
            assert stream != later
        # should not be the default stream
        assert stream != device_default
        # The idx-th stream drawn must match the stream fetched by that ID.
        assert stream == pool.get_stream(idx)