From ff3d11a2a4543709f096c302149e28e74b48253b Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Mon, 12 May 2025 16:36:46 -0700 Subject: [PATCH] gloo/cuda: use torch dtype bf16 --- CMakeLists.txt | 1 + gloo/CMakeLists.txt | 10 ++++++++++ gloo/config.h.in | 2 ++ gloo/cuda.cu | 9 +++++++++ gloo/cuda_allreduce_bcube.cc | 4 ++++ gloo/cuda_allreduce_halving_doubling.cc | 4 ++++ gloo/cuda_allreduce_local.cc | 4 ++++ gloo/cuda_allreduce_ring.cc | 4 ++++ gloo/cuda_allreduce_ring_chunked.cc | 4 ++++ gloo/cuda_broadcast_one_to_all.cc | 4 ++++ gloo/cuda_private.h | 4 ++++ 11 files changed, 50 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 359a2764d..caf3bf388 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -42,6 +42,7 @@ if(${USE_TCP_OPENSSL_LINK} AND ${USE_TCP_OPENSSL_LOAD}) endif() option(USE_CUDA "Build with CUDA support" OFF) option(GLOO_USE_CUDA_TOOLKIT "Build CUDA with FindCUDATookit.cmake and enable_language(CUDA)" OFF) +option(GLOO_USE_TORCH_DTYPES "Build CUDA kernels with pytorch dtypes" 0) if(MSVC) message(STATUS "MSVC detected") diff --git a/gloo/CMakeLists.txt b/gloo/CMakeLists.txt index d5e6a1a55..186fe1288 100644 --- a/gloo/CMakeLists.txt +++ b/gloo/CMakeLists.txt @@ -171,9 +171,19 @@ target_link_libraries(gloo PRIVATE ${gloo_DEPENDENCY_LIBS}) target_include_directories(gloo INTERFACE $) if(USE_CUDA) target_include_directories(gloo_cuda INTERFACE $) + + message(STATUS "GLOO_USE_TORCH_DTYPES : ${GLOO_USE_TORCH_DTYPES} ${GLOO_TORCH_DIR}") + if(GLOO_USE_TORCH_DTYPES) + target_include_directories(gloo_cuda PRIVATE ${GLOO_TORCH_DIR}) + endif() endif() if(USE_ROCM) target_include_directories(gloo_hip INTERFACE $) + + message(STATUS "GLOO_USE_TORCH_DTYPES : ${GLOO_USE_TORCH_DTYPES} ${GLOO_TORCH_DIR}") + if(GLOO_USE_TORCH_DTYPES) + target_include_directories(gloo_hip PRIVATE ${GLOO_TORCH_DIR}) + endif() endif() # Install if necessary. diff --git a/gloo/config.h.in b/gloo/config.h.in index 70462bd8e..a30be3ed9 100644 --- a/gloo/config.h.in +++ b/gloo/config.h.in @@ -37,3 +37,5 @@ static_assert( #cmakedefine01 GLOO_HAVE_TRANSPORT_TCP_TLS #cmakedefine01 GLOO_HAVE_TRANSPORT_IBVERBS #cmakedefine01 GLOO_HAVE_TRANSPORT_UV + +#cmakedefine01 GLOO_USE_TORCH_DTYPES diff --git a/gloo/cuda.cu b/gloo/cuda.cu index 8ae2ccaea..0eb2e4dd2 100644 --- a/gloo/cuda.cu +++ b/gloo/cuda.cu @@ -391,4 +391,13 @@ DELEGATE_SIMPLE_CUDA_BINARY_COMPARE(double, cudaMax, >); DELEGATE_HALF_PRECISION_CUDA_BINARY_COMPARE(cudaMin, <); DELEGATE_HALF_PRECISION_CUDA_BINARY_COMPARE(cudaMax, >); +#if GLOO_USE_TORCH_DTYPES +using BFloat16 = c10::BFloat16; +INSTANTIATE_COPY_ASYNC(BFloat16); +DELEGATE_SIMPLE_CUDA_BINARY_OPERATOR(BFloat16, cudaSum, +); +DELEGATE_SIMPLE_CUDA_BINARY_OPERATOR(BFloat16, cudaProduct, *); +DELEGATE_SIMPLE_CUDA_BINARY_COMPARE(BFloat16, cudaMin, <); +DELEGATE_SIMPLE_CUDA_BINARY_COMPARE(BFloat16, cudaMax, >); +#endif + } // namespace gloo diff --git a/gloo/cuda_allreduce_bcube.cc b/gloo/cuda_allreduce_bcube.cc index beb9c3e0c..c4859be81 100644 --- a/gloo/cuda_allreduce_bcube.cc +++ b/gloo/cuda_allreduce_bcube.cc @@ -514,4 +514,8 @@ INSTANTIATE_TEMPLATE(float); INSTANTIATE_TEMPLATE(double); INSTANTIATE_TEMPLATE(float16); +#if GLOO_USE_TORCH_DTYPES +INSTANTIATE_TEMPLATE(c10::BFloat16); +#endif + } // namespace gloo diff --git a/gloo/cuda_allreduce_halving_doubling.cc b/gloo/cuda_allreduce_halving_doubling.cc index f83092071..56ca136db 100644 --- a/gloo/cuda_allreduce_halving_doubling.cc +++ b/gloo/cuda_allreduce_halving_doubling.cc @@ -657,4 +657,8 @@ INSTANTIATE_TEMPLATE(float); INSTANTIATE_TEMPLATE(double); INSTANTIATE_TEMPLATE(float16); +#if GLOO_USE_TORCH_DTYPES +INSTANTIATE_TEMPLATE(c10::BFloat16); +#endif + } // namespace gloo diff --git a/gloo/cuda_allreduce_local.cc b/gloo/cuda_allreduce_local.cc index 4997e09aa..40d366ef9 100644 --- a/gloo/cuda_allreduce_local.cc +++ b/gloo/cuda_allreduce_local.cc @@ -76,4 +76,8 @@ INSTANTIATE_TEMPLATE(float); INSTANTIATE_TEMPLATE(double); INSTANTIATE_TEMPLATE(float16); +#if GLOO_USE_TORCH_DTYPES +INSTANTIATE_TEMPLATE(c10::BFloat16); +#endif + } // namespace gloo diff --git a/gloo/cuda_allreduce_ring.cc b/gloo/cuda_allreduce_ring.cc index 3bf2c1481..b23a7de6c 100644 --- a/gloo/cuda_allreduce_ring.cc +++ b/gloo/cuda_allreduce_ring.cc @@ -188,4 +188,8 @@ INSTANTIATE_TEMPLATE(float); INSTANTIATE_TEMPLATE(double); INSTANTIATE_TEMPLATE(float16); +#if GLOO_USE_TORCH_DTYPES +INSTANTIATE_TEMPLATE(c10::BFloat16); +#endif + } // namespace gloo diff --git a/gloo/cuda_allreduce_ring_chunked.cc b/gloo/cuda_allreduce_ring_chunked.cc index ed1dc94b3..2907bdb64 100644 --- a/gloo/cuda_allreduce_ring_chunked.cc +++ b/gloo/cuda_allreduce_ring_chunked.cc @@ -365,4 +365,8 @@ INSTANTIATE_TEMPLATE(float); INSTANTIATE_TEMPLATE(double); INSTANTIATE_TEMPLATE(float16); +#if GLOO_USE_TORCH_DTYPES +INSTANTIATE_TEMPLATE(c10::BFloat16); +#endif + } // namespace gloo diff --git a/gloo/cuda_broadcast_one_to_all.cc b/gloo/cuda_broadcast_one_to_all.cc index cc36757f4..79ccf03de 100644 --- a/gloo/cuda_broadcast_one_to_all.cc +++ b/gloo/cuda_broadcast_one_to_all.cc @@ -197,4 +197,8 @@ INSTANTIATE_TEMPLATE(float); INSTANTIATE_TEMPLATE(double); INSTANTIATE_TEMPLATE(float16); +#if GLOO_USE_TORCH_DTYPES +INSTANTIATE_TEMPLATE(c10::BFloat16); +#endif + } // namespace gloo diff --git a/gloo/cuda_private.h b/gloo/cuda_private.h index 02f1dfd16..2ac45aac0 100644 --- a/gloo/cuda_private.h +++ b/gloo/cuda_private.h @@ -20,6 +20,10 @@ #include "gloo/cuda.h" #include "gloo/transport/device.h" +#if GLOO_USE_TORCH_DTYPES +#include +#endif + namespace gloo { #define CUDA_CHECK(condition) \