Skip to content

Commit

Permalink
Replace RMM CUDA Python bindings with those provided by CUDA-Python (#…
Browse files Browse the repository at this point in the history
…451)

As a follow up to rapidsai/rmm#930, fix RAFT to rely on CUDA Python directly rather than custom  CUDA bindings that RMM provided.

Authors:
  - Ashwin Srinath (https://github.com/shwina)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Jordan Jacobelli (https://github.com/Ethyling)

URL: #451
  • Loading branch information
shwina committed Jan 20, 2022
1 parent 73585f4 commit c52420d
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 49 deletions.
1 change: 1 addition & 0 deletions conda/environments/raft_dev_cuda11.5.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
- conda-forge
dependencies:
- cudatoolkit=11.5
- cuda-python >=11.5,<12.0
- clang=11.1.0
- clang-tools=11.1.0
- rapids-build-env=22.02.*
Expand Down
24 changes: 5 additions & 19 deletions python/raft/common/cuda.pxd
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019, NVIDIA CORPORATION.
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,23 +14,9 @@
# limitations under the License.
#

# cython: profile=False
# distutils: language = c++
# cython: embedsignature = True
# cython: language_level = 3
from cuda.ccudart cimport cudaStream_t

cdef class Stream:
cdef cudaStream_t s

# Populate this with more typedef's (eg: events) as and when needed
cdef extern from * nogil:
ctypedef void* _Stream "cudaStream_t"
ctypedef int _Error "cudaError_t"


# Populate this with more runtime api method declarations as and when needed
cdef extern from "cuda_runtime_api.h" nogil:
_Error cudaStreamCreate(_Stream* s)
_Error cudaStreamDestroy(_Stream s)
_Error cudaStreamSynchronize(_Stream s)
_Error cudaGetLastError()
const char* cudaGetErrorString(_Error e)
const char* cudaGetErrorName(_Error e)
cdef cudaStream_t getStream(self)
47 changes: 23 additions & 24 deletions python/raft/common/cuda.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -19,10 +19,22 @@
# cython: embedsignature = True
# cython: language_level = 3

from cuda.ccudart cimport(
cudaStream_t,
cudaError_t,
cudaSuccess,
cudaStreamCreate,
cudaStreamDestroy,
cudaStreamSynchronize,
cudaGetLastError,
cudaGetErrorString,
cudaGetErrorName
)


class CudaRuntimeError(RuntimeError):
def __init__(self, extraMsg=None):
cdef _Error e = cudaGetLastError()
cdef cudaError_t e = cudaGetLastError()
cdef bytes errMsg = cudaGetErrorString(e)
cdef bytes errName = cudaGetErrorName(e)
msg = "Error! %s reason='%s'" % (errName.decode(), errMsg.decode())
Expand All @@ -45,29 +57,17 @@ cdef class Stream:
stream.sync()
del stream # optional!
"""

# NOTE:
# If we store _Stream directly, this always leads to the following error:
# "Cannot convert Python object to '_Stream'"
# I was unable to find a good solution to this in reasonable time. Also,
# since cudaStream_t is a pointer anyways, storing it as an integer should
# be just fine (although, that certainly is ugly and hacky!).
cdef size_t s

def __cinit__(self):
if self.s != 0:
return
cdef _Stream stream
cdef _Error e = cudaStreamCreate(&stream)
if e != 0:
cdef cudaStream_t stream
cdef cudaError_t e = cudaStreamCreate(&stream)
if e != cudaSuccess:
raise CudaRuntimeError("Stream create")
self.s = <size_t>stream
self.s = stream

def __dealloc__(self):
self.sync()
cdef _Stream stream = <_Stream>self.s
cdef _Error e = cudaStreamDestroy(stream)
if e != 0:
cdef cudaError_t e = cudaStreamDestroy(self.s)
if e != cudaSuccess:
raise CudaRuntimeError("Stream destroy")

def sync(self):
Expand All @@ -76,10 +76,9 @@ cdef class Stream:
could raise exception due to issues with previous asynchronous
launches
"""
cdef _Stream stream = <_Stream>self.s
cdef _Error e = cudaStreamSynchronize(stream)
if e != 0:
cdef cudaError_t e = cudaStreamSynchronize(self.s)
if e != cudaSuccess:
raise CudaRuntimeError("Stream sync")

def getStream(self):
cdef cudaStream_t getStream(self):
return self.s
3 changes: 1 addition & 2 deletions python/raft/common/handle.pxd
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -21,7 +21,6 @@


from libcpp.memory cimport shared_ptr
from .cuda cimport _Stream
from rmm._lib.cuda_stream_view cimport cuda_stream_view
from rmm._lib.cuda_stream_pool cimport cuda_stream_pool
from libcpp.memory cimport shared_ptr
Expand Down
9 changes: 5 additions & 4 deletions python/raft/common/handle.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -24,9 +24,10 @@ from libcpp.memory cimport shared_ptr
from rmm._lib.cuda_stream_view cimport cuda_stream_per_thread
from rmm._lib.cuda_stream_view cimport cuda_stream_view

from .cuda cimport _Stream, _Error, cudaStreamSynchronize
from .cuda cimport Stream
from .cuda import CudaRuntimeError


cdef class Handle:
"""
Handle is a lightweight python wrapper around the corresponding C++ class
Expand All @@ -51,7 +52,7 @@ cdef class Handle:
del handle # optional!
"""

def __cinit__(self, stream=None, n_streams=0):
def __cinit__(self, stream: Stream = None, n_streams=0):
self.n_streams = n_streams
if n_streams > 0:
self.stream_pool.reset(new cuda_stream_pool(n_streams))
Expand All @@ -64,7 +65,7 @@ cdef class Handle:
self.stream_pool))
else:
# this constructor constructs a handle on user stream
c_stream = cuda_stream_view(<_Stream><size_t> stream.getStream())
c_stream = cuda_stream_view(stream.getStream())
self.c_obj.reset(new handle_t(c_stream,
self.stream_pool))

Expand Down

0 comments on commit c52420d

Please sign in to comment.