diff --git a/gloo/allgather.cc b/gloo/allgather.cc index 70600f04b..8ab4e6f4f 100644 --- a/gloo/allgather.cc +++ b/gloo/allgather.cc @@ -53,8 +53,8 @@ void allgather(AllgatherOptions& opts) { in->size); } - // Short circuit if there is only a single process. - if (context->size == 1) { + // Short circuit if there is only a single process or the output is empty. + if (context->size == 1 || outBytes == 0) { return; } diff --git a/gloo/allgather_ring.h b/gloo/allgather_ring.h index 03b79e2ca..aa438c2a9 100644 --- a/gloo/allgather_ring.h +++ b/gloo/allgather_ring.h @@ -55,6 +55,10 @@ class AllgatherRing : public Algorithm { virtual ~AllgatherRing() {} void run() { + // Short circuit if there is only a single process or the output is empty. + if (this->contextSize_ == 1 || count_ == 0) { + return; + } const int rank = this->contextRank_; const int numRounds = this->contextSize_ - 1; diff --git a/gloo/allgatherv.cc b/gloo/allgatherv.cc index 770da3314..961f4254e 100644 --- a/gloo/allgatherv.cc +++ b/gloo/allgatherv.cc @@ -111,8 +111,8 @@ void allgatherv(AllgathervOptions& opts) { } } - // Short circuit if there is only a single process. - if (context->size == 1) { + // Short circuit if there is only a single process or the output is empty. + if (context->size == 1 || offset == 0) { return; } diff --git a/gloo/test/allgather_test.cc b/gloo/test/allgather_test.cc index aad9f6604..56b1a19a7 100644 --- a/gloo/test/allgather_test.cc +++ b/gloo/test/allgather_test.cc @@ -100,7 +100,7 @@ INSTANTIATE_TEST_CASE_P( ::testing::Combine( ::testing::ValuesIn(kTransportsForClassAlgorithms), ::testing::Range(2, 10), - ::testing::Values(4, 100, 1000, 10000), + ::testing::Values(0, 4, 100, 1000, 10000), ::testing::Range(1, 4))); using NewParam = std::tuple; @@ -157,7 +157,7 @@ INSTANTIATE_TEST_CASE_P( ::testing::Combine( ::testing::ValuesIn(kTransportsForFunctionAlgorithms), ::testing::Values(1, 2, 4, 7), - ::testing::Values(4, 100, 1000, 10000), + ::testing::Values(0, 4, 100, 1000, 10000), ::testing::Values(false, true))); TEST_F(AllgatherNewTest, TestTimeout) { diff --git a/gloo/test/allgatherv_test.cc b/gloo/test/allgatherv_test.cc index 2450c4e98..d9ec40fe2 100644 --- a/gloo/test/allgatherv_test.cc +++ b/gloo/test/allgatherv_test.cc @@ -86,7 +86,7 @@ INSTANTIATE_TEST_CASE_P( ::testing::Combine( ::testing::ValuesIn(kTransportsForFunctionAlgorithms), ::testing::Values(1, 2, 4, 7), - ::testing::Values(1, 10, 100, 1000), + ::testing::Values(0, 1, 10, 100, 1000), ::testing::Values(false, true))); TEST_F(AllgathervTest, TestTimeout) {