diff --git a/gloo/CMakeLists.txt b/gloo/CMakeLists.txt index 4631d2162..f468eebfc 100644 --- a/gloo/CMakeLists.txt +++ b/gloo/CMakeLists.txt @@ -7,12 +7,16 @@ set(GLOO_HDRS) # Compiled sources in root directory list(APPEND GLOO_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/algorithm.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/allgather.cc" "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_local.cc" "${CMAKE_CURRENT_SOURCE_DIR}/context.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/gather.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/types.cc" ) list(APPEND GLOO_HDRS "${CMAKE_CURRENT_SOURCE_DIR}/algorithm.h" + "${CMAKE_CURRENT_SOURCE_DIR}/allgather.h" "${CMAKE_CURRENT_SOURCE_DIR}/allgather_ring.h" "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_halving_doubling.h" "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_bcube.h" @@ -23,6 +27,7 @@ list(APPEND GLOO_HDRS "${CMAKE_CURRENT_SOURCE_DIR}/barrier_all_to_all.h" "${CMAKE_CURRENT_SOURCE_DIR}/barrier_all_to_one.h" "${CMAKE_CURRENT_SOURCE_DIR}/broadcast_one_to_all.h" + "${CMAKE_CURRENT_SOURCE_DIR}/gather.h" "${CMAKE_CURRENT_SOURCE_DIR}/reduce_scatter.h" "${CMAKE_CURRENT_SOURCE_DIR}/context.h" "${CMAKE_CURRENT_SOURCE_DIR}/math.h" diff --git a/gloo/allgather.cc b/gloo/allgather.cc new file mode 100644 index 000000000..2000d29f2 --- /dev/null +++ b/gloo/allgather.cc @@ -0,0 +1,109 @@ +/** + * Copyright (c) 2018-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ + +#include "gloo/allgather.h" + +#include + +#include "gloo/common/logging.h" +#include "gloo/types.h" + +namespace gloo { + +void allgather(const std::shared_ptr& context, AllgatherOptions& opts) { + std::unique_ptr tmpInBuffer; + std::unique_ptr tmpOutBuffer; + transport::UnboundBuffer* in = nullptr; + transport::UnboundBuffer* out = nullptr; + const auto slot = Slot::build(kAllgatherSlotPrefix, opts.tag); + + // Sanity checks + GLOO_ENFORCE(opts.elementSize > 0); + const auto recvRank = (context->size + context->rank - 1) % context->size; + GLOO_ENFORCE(context->getPair(recvRank), "pair missing (rank ", recvRank, ")"); + const auto sendRank = (context->size + context->rank + 1) % context->size; + GLOO_ENFORCE(context->getPair(sendRank), "pair missing (rank ", sendRank, ")"); + + // Figure out pointer to input buffer + if (opts.inBuffer) { + in = opts.inBuffer.get(); + } else if (opts.inPtr != nullptr) { + GLOO_ENFORCE(opts.inElements > 0); + tmpInBuffer = + context->createUnboundBuffer(opts.inPtr, opts.inElements * opts.elementSize); + in = tmpInBuffer.get(); + } + + // Figure out pointer to output buffer + if (opts.outBuffer) { + out = opts.outBuffer.get(); + } else { + GLOO_ENFORCE(opts.outPtr != nullptr); + GLOO_ENFORCE(opts.outElements > 0); + tmpOutBuffer = + context->createUnboundBuffer(opts.outPtr, opts.outElements * opts.elementSize); + out = tmpOutBuffer.get(); + } + + GLOO_ENFORCE_EQ(out->size, in->size * context->size); + + // If the input buffer is specified, this is NOT an in place operation, + // and the output buffer needs to be primed with the input. + if (in != nullptr) { + memcpy( + (uint8_t*) out->ptr + context->rank * opts.inElements * opts.elementSize, + (uint8_t*) in->ptr, + opts.inElements * opts.elementSize); + } + + // The chunk size may not be divisible by 2; use dynamic lookup. + std::array chunkSize; + chunkSize[0] = (opts.inElements * opts.elementSize) / 2; + chunkSize[1] = (opts.inElements * opts.elementSize) - chunkSize[0]; + std::array chunkOffset; + chunkOffset[0] = 0; + chunkOffset[1] = chunkSize[0]; + + for (auto i = 0; i < (context->size - 1) * 2; i++) { + size_t sendOffset = + (((context->size + context->rank - (i / 2)) + * opts.inElements + * opts.elementSize) + + chunkOffset[i & 0x1]) + % (opts.outElements * opts.elementSize); + size_t recvOffset = + (((context->size + context->rank - 1 - (i / 2)) + * opts.inElements + * opts.elementSize) + + chunkOffset[i & 0x1]) + % (opts.outElements * opts.elementSize); + size_t size = chunkSize[i & 0x1]; + if (i < 2) { + out->send(sendRank, slot, sendOffset, size); + out->recv(recvRank, slot, recvOffset, size); + continue; + } + + // Wait for pending operations to complete to synchronize with the + // previous iteration. Because we kick off two operations before + // getting here we always wait for the next-to-last operation. + out->waitSend(); + out->waitRecv(); + out->send(sendRank, slot, sendOffset, size); + out->recv(recvRank, slot, recvOffset, size); + } + + // Wait for completes + for (auto i = 0; i < 2; i++) { + out->waitSend(); + out->waitRecv(); + } +} + +} // namespace gloo diff --git a/gloo/allgather.h b/gloo/allgather.h new file mode 100644 index 000000000..8c166b800 --- /dev/null +++ b/gloo/allgather.h @@ -0,0 +1,42 @@ +/** + * Copyright (c) 2018-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ + +#pragma once + +#include "gloo/context.h" +#include "gloo/transport/unbound_buffer.h" + +namespace gloo { + +struct AllgatherOptions { + // The input and output can either be specified as a unbound buffer + // (that can be cached and reused by the caller), or a literal + // pointer and number of elements stored at that pointer. + // + // The operation is executed in place on the output if the input is + // set to null. The input for this process is assumed to be at the + // location in the output buffer where it would otherwise be. + std::unique_ptr inBuffer; + void* inPtr; + size_t inElements; + std::unique_ptr outBuffer; + void* outPtr; + size_t outElements; + + // Number of bytes per element. + size_t elementSize; + + // Tag for this gather operation. + // Must be unique across operations executing in parallel. + uint32_t tag; +}; + +void allgather(const std::shared_ptr& context, AllgatherOptions& opts); + +} // namespace gloo diff --git a/gloo/gather.cc b/gloo/gather.cc new file mode 100644 index 000000000..5d1bdf5d4 --- /dev/null +++ b/gloo/gather.cc @@ -0,0 +1,81 @@ +/** + * Copyright (c) 2018-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ + +#include "gloo/gather.h" + +#include + +#include "gloo/common/logging.h" +#include "gloo/types.h" + +namespace gloo { + +void gather(const std::shared_ptr& context, GatherOptions& opts) { + std::unique_ptr tmpInBuffer; + std::unique_ptr tmpOutBuffer; + transport::UnboundBuffer* in = nullptr; + transport::UnboundBuffer* out = nullptr; + const auto slot = Slot::build(kGatherSlotPrefix, opts.tag); + + // Sanity checks + GLOO_ENFORCE(opts.elementSize > 0); + + // Figure out pointer to input buffer + if (opts.inBuffer) { + in = opts.inBuffer.get(); + } else { + GLOO_ENFORCE(opts.inPtr != nullptr); + GLOO_ENFORCE(opts.inElements > 0); + tmpInBuffer = + context->createUnboundBuffer(opts.inPtr, opts.inElements * opts.elementSize); + in = tmpInBuffer.get(); + } + + if (context->rank == opts.root) { + const size_t chunkSize = in->size; + + // Figure out pointer to output buffer (only for root rank) + if (opts.outBuffer) { + out = opts.outBuffer.get(); + } else { + GLOO_ENFORCE(opts.outPtr != nullptr); + GLOO_ENFORCE(opts.outElements > 0); + tmpOutBuffer = + context->createUnboundBuffer(opts.outPtr, opts.outElements * opts.elementSize); + out = tmpOutBuffer.get(); + } + + // Ensure the output buffer has the right size. + GLOO_ENFORCE(in->size * context->size == out->size); + + // Post receive operations from peers into out buffer + for (size_t i = 0; i < context->size; i++) { + if (i == context->rank) { + continue; + } + out->recv(i, slot, i * chunkSize, chunkSize); + } + + // Copy local input to output + memcpy((char*) out->ptr + (context->rank * chunkSize), in->ptr, chunkSize); + + // Wait for receive operations to complete + for (size_t i = 0; i < context->size; i++) { + if (i == context->rank) { + continue; + } + out->waitRecv(); + } + } else { + in->send(opts.root, slot); + in->waitSend(); + } +} + +} // namespace gloo diff --git a/gloo/gather.h b/gloo/gather.h new file mode 100644 index 000000000..f1bf9ae91 --- /dev/null +++ b/gloo/gather.h @@ -0,0 +1,41 @@ +/** + * Copyright (c) 2018-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ + +#pragma once + +#include "gloo/context.h" +#include "gloo/transport/unbound_buffer.h" + +namespace gloo { + +struct GatherOptions { + // The input and output buffers can either be specified as a unbound + // buffer (that can be cached and reused by the caller), or a + // literal pointer and number of elements stored at that pointer. + std::unique_ptr inBuffer; + void* inPtr; + size_t inElements; + std::unique_ptr outBuffer; + void* outPtr; + size_t outElements; + + // Number of bytes per element. + size_t elementSize; + + // Rank of receiving process. + int root; + + // Tag for this gather operation. + // Must be unique across operations executing in parallel. + uint32_t tag; +}; + +void gather(const std::shared_ptr& context, GatherOptions& opts); + +} // namespace gloo diff --git a/gloo/test/CMakeLists.txt b/gloo/test/CMakeLists.txt index ec7ec15e2..b84c3b261 100644 --- a/gloo/test/CMakeLists.txt +++ b/gloo/test/CMakeLists.txt @@ -1,9 +1,11 @@ set(GLOO_TEST_SRCS + "${CMAKE_CURRENT_SOURCE_DIR}/allgather_test.cc" "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_builder_test.cc" "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_test.cc" "${CMAKE_CURRENT_SOURCE_DIR}/barrier_test.cc" "${CMAKE_CURRENT_SOURCE_DIR}/broadcast_builder_test.cc" "${CMAKE_CURRENT_SOURCE_DIR}/broadcast_test.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/gather_test.cc" "${CMAKE_CURRENT_SOURCE_DIR}/linux_test.cc" "${CMAKE_CURRENT_SOURCE_DIR}/main.cc" "${CMAKE_CURRENT_SOURCE_DIR}/send_recv_test.cc" diff --git a/gloo/test/allgather_test.cc b/gloo/test/allgather_test.cc index eddcdebc7..87b82c324 100644 --- a/gloo/test/allgather_test.cc +++ b/gloo/test/allgather_test.cc @@ -11,6 +11,7 @@ #include #include +#include "gloo/allgather.h" #include "gloo/allgather_ring.h" #include "gloo/common/common.h" #include "gloo/test/base_test.h" @@ -109,6 +110,76 @@ INSTANTIATE_TEST_CASE_P( ::testing::ValuesIn(genMemorySizes()), ::testing::Range(1, 4))); +using NewParam = std::tuple; + +class AllgatherNewTest : public BaseTest, + public ::testing::WithParamInterface {}; + +TEST_P(AllgatherNewTest, Default) { + auto contextSize = std::get<0>(GetParam()); + auto dataSize = std::get<1>(GetParam()); + + auto validate = [dataSize]( + const std::shared_ptr& context, + Fixture& output) { + const auto ptr = output.getPointer(); + const auto stride = context->size; + for (auto j = 0; j < context->size; j++) { + for (auto k = 0; k < dataSize; k++) { + ASSERT_EQ(j + k * stride, ptr[k + j * dataSize]) + << "Mismatch at index " << (k + j * dataSize); + } + } + }; + + spawn(contextSize, [&](std::shared_ptr context) { + auto input = Fixture(context, 1, dataSize); + auto output = Fixture(context, 1, contextSize * dataSize); + + // Run with raw pointers and sizes in options + { + input.assignValues(); + output.clear(); + + AllgatherOptions opts; + opts.inPtr = input.getPointer(); + opts.inElements = dataSize; + opts.outPtr = output.getPointer(); + opts.outElements = contextSize * dataSize; + opts.elementSize = sizeof(uint64_t); + input.assignValues(); + output.clear(); + allgather(context, opts); + validate(context, output); + } + + // Run with (optionally cached) unbound buffers in options + { + input.assignValues(); + output.clear(); + + AllgatherOptions opts; + opts.inBuffer = context->createUnboundBuffer( + input.getPointer(), + dataSize * sizeof(uint64_t)); + opts.outBuffer = context->createUnboundBuffer( + output.getPointer(), + contextSize * dataSize * sizeof(uint64_t)); + opts.elementSize = sizeof(uint64_t); + allgather(context, opts); + validate(context, output); + } + }); +} + +INSTANTIATE_TEST_CASE_P( + AllgatherNewDefault, + AllgatherNewTest, + ::testing::Combine( + ::testing::Values(2, 4, 7), + ::testing::ValuesIn(genMemorySizes()))); + + } // namespace } // namespace test } // namespace gloo diff --git a/gloo/test/base_test.h b/gloo/test/base_test.h index d74a48983..9cea0c117 100644 --- a/gloo/test/base_test.h +++ b/gloo/test/base_test.h @@ -136,6 +136,14 @@ class Fixture { } } + void clear() { + for (auto i = 0; i < srcs.size(); i++) { + for (auto j = 0; j < count; j++) { + srcs[i][j] = 0; + } + } + } + void checkBroadcastResult(Fixture& fixture, int root, int rootPointer) { // Expected is set to the expected value at ptr[0] const auto expected = root * fixture.srcs.size() + rootPointer; @@ -172,6 +180,10 @@ class Fixture { } } + T* getPointer() const { + return srcs.front().get(); + } + std::vector getPointers() const { std::vector out; for (const auto& src : srcs) { diff --git a/gloo/test/gather_test.cc b/gloo/test/gather_test.cc new file mode 100644 index 000000000..13acf74af --- /dev/null +++ b/gloo/test/gather_test.cc @@ -0,0 +1,88 @@ +/** + * Copyright (c) 2018-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ + +#include "gloo/gather.h" +#include "gloo/test/base_test.h" + +namespace gloo { +namespace test { +namespace { + +// Test parameterization. +using Param = std::tuple; + +// Test fixture. +class GatherTest : public BaseTest, + public ::testing::WithParamInterface { +}; + +TEST_P(GatherTest, Default) { + auto contextSize = std::get<0>(GetParam()); + auto dataSize = std::get<1>(GetParam()); + + spawn(contextSize, [&](std::shared_ptr context) { + auto input = Fixture(context, 1, dataSize); + auto output = Fixture(context, 1, contextSize * dataSize); + + // Initialize fixture with globally unique values + input.assignValues(); + + GatherOptions opts; + opts.inPtr = input.getPointer(); + opts.inElements = dataSize; + opts.elementSize = sizeof(uint64_t); + + // Take turns being root + for (auto i = 0; i < context->size; i++) { + // Set output pointer only when root + if (i == context->rank) { + opts.outPtr = output.getPointer(); + opts.outElements = dataSize * contextSize; + } else { + opts.outPtr = nullptr; + opts.outElements = 0; + } + + opts.root = i; + gather(context, opts); + + // Validate result if root + if (i == context->rank) { + const auto ptr = output.getPointer(); + const auto stride = context->size; + for (auto j = 0; j < context->size; j++) { + for (auto k = 0; k < dataSize; k++) { + ASSERT_EQ(j + k * stride, ptr[k + j * dataSize]) + << "Mismatch at index " << (k + j * dataSize); + } + } + } + } + }); +} + +std::vector genMemorySizes() { + std::vector v; + v.push_back(1); + v.push_back(10); + v.push_back(100); + v.push_back(1000); + return v; +} + +INSTANTIATE_TEST_CASE_P( + GatherDefault, + GatherTest, + ::testing::Combine( + ::testing::Values(2, 4, 7), + ::testing::ValuesIn(genMemorySizes()))); + +} // namespace +} // namespace test +} // namespace gloo diff --git a/gloo/types.cc b/gloo/types.cc new file mode 100644 index 000000000..72c50b2ca --- /dev/null +++ b/gloo/types.cc @@ -0,0 +1,36 @@ +/** + * Copyright (c) 2018-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ + +#include "gloo/types.h" + +#include +#include + +namespace gloo { + +Slot Slot::build(uint8_t prefix, uint32_t tag) { + uint64_t u64prefix = ((uint64_t) prefix) << 56; + uint64_t u64tag = (((uint64_t) tag) & 0xffffffff) << 24; + return Slot(u64prefix || u64tag, 0); +} + +const Slot Slot::operator+(uint8_t i) const { + // Maximum of 8 bits for use in a single collective operation. + // To avoid conflicts between them, raise if it overflows. + auto delta = delta_ + i; + if (delta > 0xff) { + std::stringstream ss; + ss << "Slot overflow: delta " << delta << " > 0xff"; + throw std::runtime_error(ss.str()); + } + + return Slot(base_, delta); +} + +} // namespace gloo diff --git a/gloo/types.h b/gloo/types.h index bd678e669..33af1c2d5 100644 --- a/gloo/types.h +++ b/gloo/types.h @@ -27,6 +27,55 @@ namespace gloo { +// Unlike old style collectives that are class instances that hold +// some state, the new style collectives do not need initialization +// before they can run. Instead of asking the context for a series of +// slots and storing them for later use and reuse, the new style +// collectives take a slot (or tag) argument that allows for +// concurrent execution of multiple collectives on the same context. +// +// This tag is what determines the slot numbers for the send and recv +// operations that the collectives end up executing. A single +// collective may have many send and recv operations running in +// parallel, so instead of using the specified tag verbatim, we use it +// as a prefix. Also, to avoid conflicts between collectives with the +// same tag, we have another tag prefix per collective type. Out of +// the 64 bits we can use for a slot, we use 8 of them to identify a +// collective, 32 to identify the collective tag, another 8 for use by +// the collective operation itself (allowing for 256 independent send +// and recv operations against the same point to point pair), and +// leave 16 bits unused. +// +// Below, you find constexprs for the prefix per collective type, as +// well as a way to compute slots when executing a collective. The +// slot class below captures both a prefix and a delta on that prefix +// to support addition with bounds checking. It is usable as an +// uint64_t, but one that cannot overflow beyond the bits allocated +// for use within a collective. +// + +constexpr uint8_t kGatherSlotPrefix = 0x01; +constexpr uint8_t kAllgatherSlotPrefix = 0x02; + +class Slot { + public: + static Slot build(uint8_t prefix, uint32_t tag); + + operator uint64_t() const { + return base_ + delta_; + } + + const Slot operator+(uint8_t i) const; + + protected: + explicit Slot(uint64_t base, uint64_t delta) + : base_(base), delta_(delta) { + } + + const uint64_t base_; + const uint64_t delta_; +}; + struct float16; float16 cpu_float2half_rn(float f); float cpu_half2float(float16 h);