Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions gloo/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ list(APPEND GLOO_SRCS
"${CMAKE_CURRENT_SOURCE_DIR}/algorithm.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/allreduce_local.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/context.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/gather.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/types.cc"
)

list(APPEND GLOO_HDRS
Expand All @@ -23,6 +25,7 @@ list(APPEND GLOO_HDRS
"${CMAKE_CURRENT_SOURCE_DIR}/barrier_all_to_all.h"
"${CMAKE_CURRENT_SOURCE_DIR}/barrier_all_to_one.h"
"${CMAKE_CURRENT_SOURCE_DIR}/broadcast_one_to_all.h"
"${CMAKE_CURRENT_SOURCE_DIR}/gather.h"
"${CMAKE_CURRENT_SOURCE_DIR}/reduce_scatter.h"
"${CMAKE_CURRENT_SOURCE_DIR}/context.h"
"${CMAKE_CURRENT_SOURCE_DIR}/math.h"
Expand Down
81 changes: 81 additions & 0 deletions gloo/gather.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/**
* Copyright (c) 2018-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/

#include "gloo/gather.h"

#include <cstring>

#include "gloo/common/logging.h"
#include "gloo/types.h"

namespace gloo {

void gather(const std::shared_ptr<Context>& context, GatherOptions& opts) {
std::unique_ptr<transport::UnboundBuffer> tmpInBuffer;
std::unique_ptr<transport::UnboundBuffer> tmpOutBuffer;
transport::UnboundBuffer* in = nullptr;
transport::UnboundBuffer* out = nullptr;
const auto slot = Slot::build(kGatherSlotPrefix, opts.tag);

// Sanity checks
GLOO_ENFORCE(opts.elementSize > 0);

// Figure out pointer to input buffer
if (opts.inBuffer) {
in = opts.inBuffer.get();
} else {
GLOO_ENFORCE(opts.inPtr != nullptr);
GLOO_ENFORCE(opts.inElements > 0);
tmpInBuffer =
context->createUnboundBuffer(opts.inPtr, opts.inElements * opts.elementSize);
in = tmpInBuffer.get();
}

if (context->rank == opts.root) {
const size_t chunkSize = in->size;

// Figure out pointer to output buffer (only for root rank)
if (opts.outBuffer) {
out = opts.outBuffer.get();
} else {
GLOO_ENFORCE(opts.outPtr != nullptr);
GLOO_ENFORCE(opts.outElements > 0);
tmpOutBuffer =
context->createUnboundBuffer(opts.outPtr, opts.outElements * opts.elementSize);
out = tmpOutBuffer.get();
}

// Ensure the output buffer has the right size.
GLOO_ENFORCE(in->size * context->size == out->size);

// Post receive operations from peers into out buffer
for (size_t i = 0; i < context->size; i++) {
if (i == context->rank) {
continue;
}
out->recv(i, slot, i * chunkSize, chunkSize);
}

// Copy local input to output
memcpy((char*) out->ptr + (context->rank * chunkSize), in->ptr, chunkSize);

// Wait for receive operations to complete
for (size_t i = 0; i < context->size; i++) {
if (i == context->rank) {
continue;
}
out->waitRecv();
}
} else {
in->send(opts.root, slot);
in->waitSend();
}
}

} // namespace gloo
41 changes: 41 additions & 0 deletions gloo/gather.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/**
* Copyright (c) 2018-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/

#pragma once

#include "gloo/context.h"
#include "gloo/transport/unbound_buffer.h"

namespace gloo {

struct GatherOptions {
// The input and output buffers can either be specified as a unbound
// buffer (that can be cached and reused by the caller), or a
// literal pointer and number of elements stored at that pointer.
std::unique_ptr<transport::UnboundBuffer> inBuffer;
void* inPtr;
size_t inElements;
std::unique_ptr<transport::UnboundBuffer> outBuffer;
void* outPtr;
size_t outElements;

// Number of bytes per element.
size_t elementSize;

// Rank of receiving process.
int root;

// Tag for this gather operation.
// Must be unique across operations executing in parallel.
uint32_t tag;
};

void gather(const std::shared_ptr<Context>& context, GatherOptions& opts);

} // namespace gloo
1 change: 1 addition & 0 deletions gloo/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ set(GLOO_TEST_SRCS
"${CMAKE_CURRENT_SOURCE_DIR}/barrier_test.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/broadcast_builder_test.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/broadcast_test.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/gather_test.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/linux_test.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/main.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/send_recv_test.cc"
Expand Down
4 changes: 4 additions & 0 deletions gloo/test/base_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ class Fixture {
}
}

T* getPointer() const {
return srcs.front().get();
}

std::vector<T*> getPointers() const {
std::vector<T*> out;
for (const auto& src : srcs) {
Expand Down
88 changes: 88 additions & 0 deletions gloo/test/gather_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/**
* Copyright (c) 2018-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/

#include "gloo/gather.h"
#include "gloo/test/base_test.h"

namespace gloo {
namespace test {
namespace {

// Test parameterization.
using Param = std::tuple<int, size_t>;

// Test fixture.
class GatherTest : public BaseTest,
public ::testing::WithParamInterface<Param> {
};

TEST_P(GatherTest, Default) {
auto contextSize = std::get<0>(GetParam());
auto dataSize = std::get<1>(GetParam());

spawn(contextSize, [&](std::shared_ptr<Context> context) {
auto input = Fixture<uint64_t>(context, 1, dataSize);
auto output = Fixture<uint64_t>(context, 1, contextSize * dataSize);

// Initialize fixture with globally unique values
input.assignValues();

GatherOptions opts;
opts.inPtr = input.getPointer();
opts.inElements = dataSize;
opts.elementSize = sizeof(uint64_t);

// Take turns being root
for (auto i = 0; i < context->size; i++) {
// Set output pointer only when root
if (i == context->rank) {
opts.outPtr = output.getPointer();
opts.outElements = dataSize * contextSize;
} else {
opts.outPtr = nullptr;
opts.outElements = 0;
}

opts.root = i;
gather(context, opts);

// Validate result if root
if (i == context->rank) {
const auto ptr = output.getPointer();
const auto stride = context->size;
for (auto j = 0; j < context->size; j++) {
for (auto k = 0; k < dataSize; k++) {
ASSERT_EQ(j + k * stride, ptr[k + j * dataSize])
<< "Mismatch at index " << (k + j * dataSize);
}
}
}
}
});
}

std::vector<size_t> genMemorySizes() {
std::vector<size_t> v;
v.push_back(1);
v.push_back(10);
v.push_back(100);
v.push_back(1000);
return v;
}

INSTANTIATE_TEST_CASE_P(
GatherDefault,
GatherTest,
::testing::Combine(
::testing::Values(2, 4, 7),
::testing::ValuesIn(genMemorySizes())));

} // namespace
} // namespace test
} // namespace gloo
36 changes: 36 additions & 0 deletions gloo/types.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/**
* Copyright (c) 2018-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/

#include "gloo/types.h"

#include <sstream>
#include <stdexcept>

namespace gloo {

Slot Slot::build(uint8_t prefix, uint32_t tag) {
uint64_t u64prefix = ((uint64_t) prefix) << 56;
uint64_t u64tag = (((uint64_t) tag) & 0xffffffff) << 24;
return Slot(u64prefix || u64tag, 0);
}

const Slot Slot::operator+(uint8_t i) const {
// Maximum of 8 bits for use in a single collective operation.
// To avoid conflicts between them, raise if it overflows.
auto delta = delta_ + i;
if (delta > 0xff) {
std::stringstream ss;
ss << "Slot overflow: delta " << delta << " > 0xff";
throw std::runtime_error(ss.str());
}

return Slot(base_, delta);
}

} // namespace gloo
48 changes: 48 additions & 0 deletions gloo/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,54 @@

namespace gloo {

// Unlike old style collectives that are class instances that hold
// some state, the new style collectives do not need initialization
// before they can run. Instead of asking the context for a series of
// slots and storing them for later use and reuse, the new style
// collectives take a slot (or tag) argument that allows for
// concurrent execution of multiple collectives on the same context.
//
// This tag is what determines the slot numbers for the send and recv
// operations that the collectives end up executing. A single
// collective may have many send and recv operations running in
// parallel, so instead of using the specified tag verbatim, we use it
// as a prefix. Also, to avoid conflicts between collectives with the
// same tag, we have another tag prefix per collective type. Out of
// the 64 bits we can use for a slot, we use 8 of them to identify a
// collective, 32 to identify the collective tag, another 8 for use by
// the collective operation itself (allowing for 256 independent send
// and recv operations against the same point to point pair), and
// leave 16 bits unused.
//
// Below, you find constexprs for the prefix per collective type, as
// well as a way to compute slots when executing a collective. The
// slot class below captures both a prefix and a delta on that prefix
// to support addition with bounds checking. It is usable as an
// uint64_t, but one that cannot overflow beyond the bits allocated
// for use within a collective.
//

constexpr uint8_t kGatherSlotPrefix = 0x01;

class Slot {
public:
static Slot build(uint8_t prefix, uint32_t tag);

operator uint64_t() const {
return base_ + delta_;
}

const Slot operator+(uint8_t i) const;

protected:
explicit Slot(uint64_t base, uint64_t delta)
: base_(base), delta_(delta) {
}

const uint64_t base_;
const uint64_t delta_;
};

struct float16;
float16 cpu_float2half_rn(float f);
float cpu_half2float(float16 h);
Expand Down