diff --git a/gloo/CMakeLists.txt b/gloo/CMakeLists.txt index 4631d2162..592db5bd5 100644 --- a/gloo/CMakeLists.txt +++ b/gloo/CMakeLists.txt @@ -9,6 +9,8 @@ list(APPEND GLOO_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/algorithm.cc" "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_local.cc" "${CMAKE_CURRENT_SOURCE_DIR}/context.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/gather.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/types.cc" ) list(APPEND GLOO_HDRS @@ -23,6 +25,7 @@ list(APPEND GLOO_HDRS "${CMAKE_CURRENT_SOURCE_DIR}/barrier_all_to_all.h" "${CMAKE_CURRENT_SOURCE_DIR}/barrier_all_to_one.h" "${CMAKE_CURRENT_SOURCE_DIR}/broadcast_one_to_all.h" + "${CMAKE_CURRENT_SOURCE_DIR}/gather.h" "${CMAKE_CURRENT_SOURCE_DIR}/reduce_scatter.h" "${CMAKE_CURRENT_SOURCE_DIR}/context.h" "${CMAKE_CURRENT_SOURCE_DIR}/math.h" diff --git a/gloo/gather.cc b/gloo/gather.cc new file mode 100644 index 000000000..5d1bdf5d4 --- /dev/null +++ b/gloo/gather.cc @@ -0,0 +1,81 @@ +/** + * Copyright (c) 2018-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ + +#include "gloo/gather.h" + +#include + +#include "gloo/common/logging.h" +#include "gloo/types.h" + +namespace gloo { + +void gather(const std::shared_ptr& context, GatherOptions& opts) { + std::unique_ptr tmpInBuffer; + std::unique_ptr tmpOutBuffer; + transport::UnboundBuffer* in = nullptr; + transport::UnboundBuffer* out = nullptr; + const auto slot = Slot::build(kGatherSlotPrefix, opts.tag); + + // Sanity checks + GLOO_ENFORCE(opts.elementSize > 0); + + // Figure out pointer to input buffer + if (opts.inBuffer) { + in = opts.inBuffer.get(); + } else { + GLOO_ENFORCE(opts.inPtr != nullptr); + GLOO_ENFORCE(opts.inElements > 0); + tmpInBuffer = + context->createUnboundBuffer(opts.inPtr, opts.inElements * opts.elementSize); + in = tmpInBuffer.get(); + } + + if (context->rank == opts.root) { + const size_t chunkSize = in->size; + + // Figure out pointer to output buffer (only for root rank) + if (opts.outBuffer) { + out = opts.outBuffer.get(); + } else { + GLOO_ENFORCE(opts.outPtr != nullptr); + GLOO_ENFORCE(opts.outElements > 0); + tmpOutBuffer = + context->createUnboundBuffer(opts.outPtr, opts.outElements * opts.elementSize); + out = tmpOutBuffer.get(); + } + + // Ensure the output buffer has the right size. + GLOO_ENFORCE(in->size * context->size == out->size); + + // Post receive operations from peers into out buffer + for (size_t i = 0; i < context->size; i++) { + if (i == context->rank) { + continue; + } + out->recv(i, slot, i * chunkSize, chunkSize); + } + + // Copy local input to output + memcpy((char*) out->ptr + (context->rank * chunkSize), in->ptr, chunkSize); + + // Wait for receive operations to complete + for (size_t i = 0; i < context->size; i++) { + if (i == context->rank) { + continue; + } + out->waitRecv(); + } + } else { + in->send(opts.root, slot); + in->waitSend(); + } +} + +} // namespace gloo diff --git a/gloo/gather.h b/gloo/gather.h new file mode 100644 index 000000000..f1bf9ae91 --- /dev/null +++ b/gloo/gather.h @@ -0,0 +1,41 @@ +/** + * Copyright (c) 2018-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ + +#pragma once + +#include "gloo/context.h" +#include "gloo/transport/unbound_buffer.h" + +namespace gloo { + +struct GatherOptions { + // The input and output buffers can either be specified as a unbound + // buffer (that can be cached and reused by the caller), or a + // literal pointer and number of elements stored at that pointer. + std::unique_ptr inBuffer; + void* inPtr; + size_t inElements; + std::unique_ptr outBuffer; + void* outPtr; + size_t outElements; + + // Number of bytes per element. + size_t elementSize; + + // Rank of receiving process. + int root; + + // Tag for this gather operation. + // Must be unique across operations executing in parallel. + uint32_t tag; +}; + +void gather(const std::shared_ptr& context, GatherOptions& opts); + +} // namespace gloo diff --git a/gloo/test/CMakeLists.txt b/gloo/test/CMakeLists.txt index ec7ec15e2..a126b3b15 100644 --- a/gloo/test/CMakeLists.txt +++ b/gloo/test/CMakeLists.txt @@ -4,6 +4,7 @@ set(GLOO_TEST_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/barrier_test.cc" "${CMAKE_CURRENT_SOURCE_DIR}/broadcast_builder_test.cc" "${CMAKE_CURRENT_SOURCE_DIR}/broadcast_test.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/gather_test.cc" "${CMAKE_CURRENT_SOURCE_DIR}/linux_test.cc" "${CMAKE_CURRENT_SOURCE_DIR}/main.cc" "${CMAKE_CURRENT_SOURCE_DIR}/send_recv_test.cc" diff --git a/gloo/test/base_test.h b/gloo/test/base_test.h index d74a48983..091214520 100644 --- a/gloo/test/base_test.h +++ b/gloo/test/base_test.h @@ -172,6 +172,10 @@ class Fixture { } } + T* getPointer() const { + return srcs.front().get(); + } + std::vector getPointers() const { std::vector out; for (const auto& src : srcs) { diff --git a/gloo/test/gather_test.cc b/gloo/test/gather_test.cc new file mode 100644 index 000000000..13acf74af --- /dev/null +++ b/gloo/test/gather_test.cc @@ -0,0 +1,88 @@ +/** + * Copyright (c) 2018-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ + +#include "gloo/gather.h" +#include "gloo/test/base_test.h" + +namespace gloo { +namespace test { +namespace { + +// Test parameterization. +using Param = std::tuple; + +// Test fixture. +class GatherTest : public BaseTest, + public ::testing::WithParamInterface { +}; + +TEST_P(GatherTest, Default) { + auto contextSize = std::get<0>(GetParam()); + auto dataSize = std::get<1>(GetParam()); + + spawn(contextSize, [&](std::shared_ptr context) { + auto input = Fixture(context, 1, dataSize); + auto output = Fixture(context, 1, contextSize * dataSize); + + // Initialize fixture with globally unique values + input.assignValues(); + + GatherOptions opts; + opts.inPtr = input.getPointer(); + opts.inElements = dataSize; + opts.elementSize = sizeof(uint64_t); + + // Take turns being root + for (auto i = 0; i < context->size; i++) { + // Set output pointer only when root + if (i == context->rank) { + opts.outPtr = output.getPointer(); + opts.outElements = dataSize * contextSize; + } else { + opts.outPtr = nullptr; + opts.outElements = 0; + } + + opts.root = i; + gather(context, opts); + + // Validate result if root + if (i == context->rank) { + const auto ptr = output.getPointer(); + const auto stride = context->size; + for (auto j = 0; j < context->size; j++) { + for (auto k = 0; k < dataSize; k++) { + ASSERT_EQ(j + k * stride, ptr[k + j * dataSize]) + << "Mismatch at index " << (k + j * dataSize); + } + } + } + } + }); +} + +std::vector genMemorySizes() { + std::vector v; + v.push_back(1); + v.push_back(10); + v.push_back(100); + v.push_back(1000); + return v; +} + +INSTANTIATE_TEST_CASE_P( + GatherDefault, + GatherTest, + ::testing::Combine( + ::testing::Values(2, 4, 7), + ::testing::ValuesIn(genMemorySizes()))); + +} // namespace +} // namespace test +} // namespace gloo diff --git a/gloo/types.cc b/gloo/types.cc new file mode 100644 index 000000000..72c50b2ca --- /dev/null +++ b/gloo/types.cc @@ -0,0 +1,36 @@ +/** + * Copyright (c) 2018-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ + +#include "gloo/types.h" + +#include +#include + +namespace gloo { + +Slot Slot::build(uint8_t prefix, uint32_t tag) { + uint64_t u64prefix = ((uint64_t) prefix) << 56; + uint64_t u64tag = (((uint64_t) tag) & 0xffffffff) << 24; + return Slot(u64prefix || u64tag, 0); +} + +const Slot Slot::operator+(uint8_t i) const { + // Maximum of 8 bits for use in a single collective operation. + // To avoid conflicts between them, raise if it overflows. + auto delta = delta_ + i; + if (delta > 0xff) { + std::stringstream ss; + ss << "Slot overflow: delta " << delta << " > 0xff"; + throw std::runtime_error(ss.str()); + } + + return Slot(base_, delta); +} + +} // namespace gloo diff --git a/gloo/types.h b/gloo/types.h index bd678e669..d2ef9c2ce 100644 --- a/gloo/types.h +++ b/gloo/types.h @@ -27,6 +27,54 @@ namespace gloo { +// Unlike old style collectives that are class instances that hold +// some state, the new style collectives do not need initialization +// before they can run. Instead of asking the context for a series of +// slots and storing them for later use and reuse, the new style +// collectives take a slot (or tag) argument that allows for +// concurrent execution of multiple collectives on the same context. +// +// This tag is what determines the slot numbers for the send and recv +// operations that the collectives end up executing. A single +// collective may have many send and recv operations running in +// parallel, so instead of using the specified tag verbatim, we use it +// as a prefix. Also, to avoid conflicts between collectives with the +// same tag, we have another tag prefix per collective type. Out of +// the 64 bits we can use for a slot, we use 8 of them to identify a +// collective, 32 to identify the collective tag, another 8 for use by +// the collective operation itself (allowing for 256 independent send +// and recv operations against the same point to point pair), and +// leave 16 bits unused. +// +// Below, you find constexprs for the prefix per collective type, as +// well as a way to compute slots when executing a collective. The +// slot class below captures both a prefix and a delta on that prefix +// to support addition with bounds checking. It is usable as an +// uint64_t, but one that cannot overflow beyond the bits allocated +// for use within a collective. +// + +constexpr uint8_t kGatherSlotPrefix = 0x01; + +class Slot { + public: + static Slot build(uint8_t prefix, uint32_t tag); + + operator uint64_t() const { + return base_ + delta_; + } + + const Slot operator+(uint8_t i) const; + + protected: + explicit Slot(uint64_t base, uint64_t delta) + : base_(base), delta_(delta) { + } + + const uint64_t base_; + const uint64_t delta_; +}; + struct float16; float16 cpu_float2half_rn(float f); float cpu_half2float(float16 h);