Skip to content

Use a stream pool for gpuCalloc*() #509

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 4, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions include/mscclpp/gpu_utils.hpp
Original file line number Diff line number Diff line change
@@ -71,6 +71,53 @@ struct CudaStreamWithFlags {
cudaStream_t stream_;
};

class GpuStreamPool;

/// A managed non-blocking GPU stream object created by GpuStreamPool.
/// This object does not own the stream.
/// When this object is destroyed, it will return the underlying CudaStreamWithFlags to the pool.
class GpuStream {
protected:
/// Constructor. Only called by a GpuStreamPool.
/// @param pool A shared pointer to the GpuStreamPool that manages this stream.
/// @param stream A shared pointer to the CudaStreamWithFlags that represents the underlying stream.
GpuStream(std::shared_ptr<GpuStreamPool> pool, std::shared_ptr<CudaStreamWithFlags> stream);

public:
/// Destructor. This will return the underlying CudaStreamWithFlags to the pool, not destroy it.
~GpuStream();

operator cudaStream_t() const { return stream_->stream_; }

private:
friend class GpuStreamPool;

std::shared_ptr<GpuStreamPool> pool_;
std::shared_ptr<CudaStreamWithFlags> stream_;
};

/// A pool of managed GPU streams. Only provides non-blocking streams.
/// This is intended to be used for reusing temporal streams.
class GpuStreamPool : public std::enable_shared_from_this<GpuStreamPool> {
public:
GpuStreamPool();

/// Get a non-blocking GPU stream from the pool. If no streams are available, a new one will be created.
/// @return A GpuStream object.
GpuStream getStream();

/// Clear the pool, which will remove all streams from the pool.
void clear();

protected:
friend class GpuStream;
std::vector<std::shared_ptr<CudaStreamWithFlags>> streams_;
};

/// Get the singleton instance of GpuStreamPool.
/// @return A shared pointer to the GpuStreamPool instance.
std::shared_ptr<GpuStreamPool> gpuStreamPool();

namespace detail {

void setReadWriteMemoryAccess(void* base, size_t size);
29 changes: 26 additions & 3 deletions src/gpu_utils.cc
Original file line number Diff line number Diff line change
@@ -28,6 +28,29 @@ void CudaStreamWithFlags::set(unsigned int flags) {

bool CudaStreamWithFlags::empty() const { return stream_ == nullptr; }

GpuStream::GpuStream(std::shared_ptr<GpuStreamPool> pool, std::shared_ptr<CudaStreamWithFlags> stream)
: pool_(pool), stream_(stream) {}

GpuStream::~GpuStream() { pool_->streams_.push_back(stream_); }

GpuStreamPool::GpuStreamPool() {}

GpuStream GpuStreamPool::getStream() {
if (!streams_.empty()) {
auto stream = streams_.back();
streams_.pop_back();
return GpuStream(shared_from_this(), stream);
}
return GpuStream(shared_from_this(), std::make_shared<CudaStreamWithFlags>(cudaStreamNonBlocking));
}

void GpuStreamPool::clear() { streams_.clear(); }

std::shared_ptr<GpuStreamPool> gpuStreamPool() {
static std::shared_ptr<GpuStreamPool> pool = std::make_shared<GpuStreamPool>();
return pool;
}

namespace detail {

CUmemAllocationHandleType nvlsCompatibleMemHandleType = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
@@ -48,7 +71,7 @@ void setReadWriteMemoryAccess(void* base, size_t size) {
void* gpuCalloc(size_t bytes) {
AvoidCudaGraphCaptureGuard cgcGuard;
void* ptr;
CudaStreamWithFlags stream(cudaStreamNonBlocking);
auto stream = gpuStreamPool()->getStream();
MSCCLPP_CUDATHROW(cudaMalloc(&ptr, bytes));
MSCCLPP_CUDATHROW(cudaMemsetAsync(ptr, 0, bytes, stream));
MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream));
@@ -67,7 +90,7 @@ void* gpuCallocHost(size_t bytes) {
void* gpuCallocUncached(size_t bytes) {
AvoidCudaGraphCaptureGuard cgcGuard;
void* ptr;
CudaStreamWithFlags stream(cudaStreamNonBlocking);
auto stream = gpuStreamPool()->getStream();
MSCCLPP_CUDATHROW(hipExtMallocWithFlags((void**)&ptr, bytes, hipDeviceMallocUncached));
MSCCLPP_CUDATHROW(cudaMemsetAsync(ptr, 0, bytes, stream));
MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream));
@@ -137,7 +160,7 @@ void* gpuCallocPhysical(size_t bytes, size_t gran, size_t align) {
MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&devicePtr, nbytes, align, 0U, 0));
MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)devicePtr, nbytes, 0, memHandle, 0));
setReadWriteMemoryAccess(devicePtr, nbytes);
CudaStreamWithFlags stream(cudaStreamNonBlocking);
auto stream = gpuStreamPool()->getStream();
MSCCLPP_CUDATHROW(cudaMemsetAsync(devicePtr, 0, nbytes, stream));
MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream));

2 changes: 1 addition & 1 deletion test/unit/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@

target_sources(unit_tests PRIVATE
core_tests.cc
cuda_utils_tests.cc
gpu_utils_tests.cc
errors_tests.cc
fifo_tests.cu
numa_tests.cc
30 changes: 25 additions & 5 deletions test/unit/cuda_utils_tests.cc → test/unit/gpu_utils_tests.cc
Original file line number Diff line number Diff line change
@@ -5,27 +5,47 @@

#include <mscclpp/gpu_utils.hpp>

TEST(CudaUtilsTest, AllocShared) {
TEST(GpuUtilsTest, StreamPool) {
auto streamPool = mscclpp::gpuStreamPool();
cudaStream_t s;
{
auto stream1 = streamPool->getStream();
s = stream1;
EXPECT_NE(s, nullptr);
}
{
auto stream2 = streamPool->getStream();
EXPECT_EQ(cudaStream_t(stream2), s);
}
{
auto stream3 = streamPool->getStream();
auto stream4 = streamPool->getStream();
EXPECT_NE(cudaStream_t(stream3), cudaStream_t(stream4));
}
streamPool->clear();
}

TEST(GpuUtilsTest, AllocShared) {
auto p1 = mscclpp::detail::gpuCallocShared<uint32_t>();
auto p2 = mscclpp::detail::gpuCallocShared<int64_t>(5);
}

TEST(CudaUtilsTest, AllocUnique) {
TEST(GpuUtilsTest, AllocUnique) {
auto p1 = mscclpp::detail::gpuCallocUnique<uint32_t>();
auto p2 = mscclpp::detail::gpuCallocUnique<int64_t>(5);
}

TEST(CudaUtilsTest, MakeSharedHost) {
TEST(GpuUtilsTest, MakeSharedHost) {
auto p1 = mscclpp::detail::gpuCallocHostShared<uint32_t>();
auto p2 = mscclpp::detail::gpuCallocHostShared<int64_t>(5);
}

TEST(CudaUtilsTest, MakeUniqueHost) {
TEST(GpuUtilsTest, MakeUniqueHost) {
auto p1 = mscclpp::detail::gpuCallocHostUnique<uint32_t>();
auto p2 = mscclpp::detail::gpuCallocHostUnique<int64_t>(5);
}

TEST(CudaUtilsTest, Memcpy) {
TEST(GpuUtilsTest, Memcpy) {
const int nElem = 1024;
std::vector<int> hostBuff(nElem);
for (int i = 0; i < nElem; ++i) {
Loading
Oops, something went wrong.