Skip to content

Commit

Permalink
WIP: Correct signatures for torch allocator plug in
Browse files Browse the repository at this point in the history
Since pytorch/pytorch#91398, the signature of
the pluggable allocate and deallocate functions must accept the device
id. The current version only accepts a device id for allocate, which
means that when using a stream ordered allocator with devices other
than device zero, we pass an invalid stream into the deallocation
function. To fix this, adapt the signature to match the one pytorch
expects.

Now, since we have the device available during allocation and
deallocation, we would like to use that device to obtain the
appropriate memory resource.

Unfortunately, since RMM's cuda_device_id does not have a nullary
constructor, we can't use it in Cython without some hacky workarounds.

However, since we don't actually need to build a Python module, but
rather just a single shared library that offers two extern "C"
functions, let's just write our allocator hooks directly in C++.

- Closes #1405
  • Loading branch information
wence- committed Dec 12, 2023
1 parent 53c8043 commit fe84d63
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 31 deletions.
5 changes: 1 addition & 4 deletions python/rmm/_lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,9 @@
# the License.
# =============================================================================

set(cython_sources device_buffer.pyx lib.pyx logger.pyx memory_resource.pyx cuda_stream.pyx
torch_allocator.pyx)
set(cython_sources device_buffer.pyx lib.pyx logger.pyx memory_resource.pyx cuda_stream.pyx)
set(linked_libraries rmm::rmm)

# Build all of the Cython targets
rapids_cython_create_modules(SOURCE_FILES "${cython_sources}" LINKED_LIBRARIES "${linked_libraries}"
CXX)
# The cdef public functions in this file need to have a C ABI
target_compile_definitions(torch_allocator PRIVATE CYTHON_EXTERN_C=extern\ "C")
24 changes: 0 additions & 24 deletions python/rmm/_lib/torch_allocator.pyx

This file was deleted.

56 changes: 56 additions & 0 deletions python/rmm/allocators/_torch_allocator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cuda_runtime_api.h>

#include <rmm/cuda_device.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

// These signatures must match those required by CUDAPluggableAllocator in
// github.com/pytorch/pytorch/blob/main/torch/csrc/cuda/CUDAPluggableAllocator.h
// Since the loading is done at runtime via dlopen, no error checking
// can before performed.

/**
* @brief Allocate memory of at least \p size bytes.
*
* @throws rmm::bad_alloc When the requested allocation cannot be satisfied.
*
* @param size The number of bytes to allocate
* @param device The device whose memory resource one should use
* @param stream CUDA stream to perform allocation on
* @return void* Pointer to the newly allocated memory
*/
extern "C" void* allocate(std::size_t size, int device, void* stream)
{
auto mr = rmm::mr::get_per_device_resource(rmm::cuda_device_id{device});
return mr->allocate(size, rmm::cuda_stream_view{static_cast<cudaStream_t>(stream)});
}

/**
* @brief Deallocate memory pointed to by \p ptr.
*
* @param ptr Pointer to be deallocated
* @param size The number of bytes in the allocation
* @param device The device whose memory resource one should use
* @param stream CUDA stream to perform deallocation on
*/
extern "C" void deallocate(void* ptr, std::size_t size, int device, void* stream)
{
auto mr = rmm::mr::get_per_device_resource(rmm::cuda_device_id{device});
mr->deallocate(ptr, size, rmm::cuda_stream_view{static_cast<cudaStream_t>(stream)});
}
6 changes: 3 additions & 3 deletions python/rmm/allocators/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
except ImportError:
rmm_torch_allocator = None
else:
import rmm._lib.torch_allocator
import pathlib

_alloc_free_lib_path = rmm._lib.torch_allocator.__file__
rmm_torch_allocator = CUDAPluggableAllocator(
_alloc_free_lib_path,
pathlib.Path(__file__).absolute().parent / "_torch_allocator.so",
alloc_fn_name="allocate",
free_fn_name="deallocate",
)
del pathlib

0 comments on commit fe84d63

Please sign in to comment.