Skip to content

Commit

Permalink
[XLA:GPU] Add experimental, lightly tested support for multi-host and…
Browse files Browse the repository at this point in the history
… multi-process NCCL AllReduce.

This change makes several API changes:
* we allow the client to provide a mapping from the local device ordinals on the machine to global device IDs. If provided, we interpret the device IDs in the DeviceAssignment provided by the client as global IDs, not as local device ordinals. This allows us to describe computations that cross a host boundary.
* we allow the client to provide a callback for manufacturing a ncclUniqueId for a particular subset of global devices. The idea is that the client should use some other distributed system of their own (e.g., MPI) to share the ncclUniqueId values needed for a computation. NCCL allows cross-host/cross-process collectives if and only if the same ncclUniqueId value is used.

Refactors the common collective logic, and the NCCL collective logic in particular, to support the distinction between local device ordinals and global device IDs.

PiperOrigin-RevId: 296505571
Change-Id: I5ed42d65597b0960df78890745421f77e9789ba3
  • Loading branch information
hawkinsp authored and tensorflower-gardener committed Feb 21, 2020
1 parent 371c29f commit 8a72c44
Show file tree
Hide file tree
Showing 18 changed files with 450 additions and 139 deletions.
11 changes: 11 additions & 0 deletions tensorflow/compiler/xla/executable_run_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,17 @@ const DeviceAssignment* ExecutableRunOptions::device_assignment() const {
return device_assignment_;
}

// Stores the given GPU-specific run options pointer on this object and returns
// *this so that setter calls can be chained. Only the pointer is recorded; the
// pointee is not copied.
ExecutableRunOptions& ExecutableRunOptions::set_gpu_executable_run_options(
    const GpuExecutableRunOptions* gpu_executable_run_options) {
  this->gpu_executable_run_options_ = gpu_executable_run_options;
  return *this;
}

// Returns the GPU-specific run options pointer currently stored on this
// object, i.e. whatever was last passed to set_gpu_executable_run_options().
const GpuExecutableRunOptions*
ExecutableRunOptions::gpu_executable_run_options() const {
  return this->gpu_executable_run_options_;
}

ExecutableRunOptions& ExecutableRunOptions::set_rng_seed(int rng_seed) {
rng_seed_ = rng_seed;
return *this;
Expand Down
8 changes: 8 additions & 0 deletions tensorflow/compiler/xla/executable_run_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ namespace xla {

class DeviceAssignment;
class ExecutionProfile;
class GpuExecutableRunOptions;

// A unique identifier for a particular "logical execution" of an XLA model.
//
Expand Down Expand Up @@ -137,6 +138,12 @@ class ExecutableRunOptions {
return then_execute_function_;
}

// GPU-backend specific options. These are kept out-of-line to avoid bloating
// the size of this dependency for CPU-only AOT builds.
//
// The setter stores the given pointer as-is and returns *this so calls can be
// chained; the getter returns the stored pointer, which is nullptr until the
// setter has been called.
// NOTE(review): the pointee is presumably not owned by ExecutableRunOptions
// and must outlive any use of these options -- confirm against callers.
ExecutableRunOptions& set_gpu_executable_run_options(
const GpuExecutableRunOptions* gpu_executable_run_options);
const GpuExecutableRunOptions* gpu_executable_run_options() const;

private:
stream_executor::DeviceMemoryAllocator* allocator_ = nullptr;
int device_ordinal_ = -1;
Expand All @@ -148,6 +155,7 @@ class ExecutableRunOptions {
stream_executor::Stream* host_to_device_stream_ = nullptr;
ThenExecuteFunction* then_execute_function_ = nullptr;
RunId run_id_;
const GpuExecutableRunOptions* gpu_executable_run_options_ = nullptr;
};

} // namespace xla
Expand Down
17 changes: 6 additions & 11 deletions tensorflow/compiler/xla/refcounting_hash_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,7 @@ template <typename K, typename V>
class RefcountingHashMap {
public:
// Default-constructs new values.
RefcountingHashMap()
: value_factory_([](const K&) { return absl::make_unique<V>(); }) {}

// Constructs new values according to the given factory function.
explicit RefcountingHashMap(
std::function<std::unique_ptr<V>(const K&)> value_factory)
: value_factory_(std::move(value_factory)) {}
RefcountingHashMap() = default;

// Not copyable or movable because this contains internal pointers (namely,
// instances of Deleter contain pointers to `this` and into `map_`).
Expand All @@ -60,8 +54,10 @@ class RefcountingHashMap {
// Gets the value for the given key.
//
// If the map doesn't contain a live value for the key, constructs one
// according to the factory passed to the map's constructor.
std::shared_ptr<V> operator[](const K& key) {
// using `value_factory`.
std::shared_ptr<V> GetOrCreateIfAbsent(
const K& key,
const std::function<std::unique_ptr<V>(const K&)>& value_factory) {
absl::MutexLock lock(&mu_);
auto it = map_.find(key);
// We ensure that the entry has not expired in case deleter was running when
Expand All @@ -76,7 +72,7 @@ class RefcountingHashMap {
// Create entry in the map and then set its value, so the value can
// contain a pointer back into the map.
it = map_.emplace(key, std::weak_ptr<V>()).first;
std::shared_ptr<V> value(value_factory_(key).release(),
std::shared_ptr<V> value(value_factory(key).release(),
Deleter{&it->first, this});
it->second = value; // Set the weak ptr to the shared ptr.
return value;
Expand Down Expand Up @@ -112,7 +108,6 @@ class RefcountingHashMap {
}
};

std::function<std::unique_ptr<V>(const K&)> value_factory_;
absl::Mutex mu_;
absl::node_hash_map<K, std::weak_ptr<V>> map_ ABSL_GUARDED_BY(mu_);
};
Expand Down
26 changes: 15 additions & 11 deletions tensorflow/compiler/xla/refcounting_hash_map_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,33 +47,36 @@ struct DeleteNotifier {

TEST(RefcountingHashMapTest, PointerIdentity) {
RefcountingHashMap<int, int> m;
std::shared_ptr<int> a = m[0];
std::shared_ptr<int> b = m[0];
std::shared_ptr<int> c = m[1];
auto factory = [](const int&) { return absl::make_unique<int>(); };
std::shared_ptr<int> a = m.GetOrCreateIfAbsent(0, factory);
std::shared_ptr<int> b = m.GetOrCreateIfAbsent(0, factory);
std::shared_ptr<int> c = m.GetOrCreateIfAbsent(1, factory);
EXPECT_EQ(a.get(), b.get());
EXPECT_NE(a.get(), c.get());
}

TEST(RefcountingHashMapTest, DefaultInitialized) {
RefcountingHashMap<int, int> m;
EXPECT_EQ(*m[42], 0);
auto factory = [](const int&) { return absl::make_unique<int>(); };
EXPECT_EQ(*m.GetOrCreateIfAbsent(42, factory), 0);
}

TEST(RefcountingHashMapTest, DeletesEagerly) {
RefcountingHashMap<int, DeleteNotifier> m;
bool deleted = false;
auto handle = m[0];
auto factory = [](const int&) { return absl::make_unique<DeleteNotifier>(); };
auto handle = m.GetOrCreateIfAbsent(0, factory);
handle->fn = [&] { deleted = true; };
EXPECT_FALSE(deleted);
handle = nullptr;
EXPECT_TRUE(deleted);
}

TEST(RefcountingHashMapTest, CustomFactory) {
RefcountingHashMap<int, int> m(
[](const int& x) { return absl::make_unique<int>(x + 1); });
EXPECT_EQ(*m[0], 1);
EXPECT_EQ(*m[100], 101);
RefcountingHashMap<int, int> m;
auto factory = [](const int& x) { return absl::make_unique<int>(x + 1); };
EXPECT_EQ(*m.GetOrCreateIfAbsent(0, factory), 1);
EXPECT_EQ(*m.GetOrCreateIfAbsent(100, factory), 101);
}

TEST(RefcountingHashMapTest, ForEachEmpty) {
Expand All @@ -85,8 +88,9 @@ TEST(RefcountingHashMapTest, ForEachEmpty) {

TEST(RefcountingHashMapTest, ForEachNonempty) {
RefcountingHashMap<int, int> m;
auto a = m[0];
auto b = m[1];
auto factory = [](const int&) { return absl::make_unique<int>(); };
auto a = m.GetOrCreateIfAbsent(0, factory);
auto b = m.GetOrCreateIfAbsent(1, factory);

std::vector<int> seen_keys;
std::vector<int*> seen_values;
Expand Down
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4585,6 +4585,7 @@ cc_library(
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/service:pattern_matcher",
"//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal", # fixdeps: keep
"//tensorflow/stream_executor/lib",
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/compiler/xla/service/collective_ops_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ absl::optional<ReductionKind> MatchReductionComputation(
}

StatusOr<std::vector<int64>> GetParticipatingReplicas(
int64 device_ordinal, absl::Span<const ReplicaGroup> replica_groups,
GlobalDeviceId device_id, absl::Span<const ReplicaGroup> replica_groups,
int64 total_replica_count, const DeviceAssignment& device_assn) {
std::vector<int64> participating_replicas;

Expand All @@ -58,7 +58,7 @@ StatusOr<std::vector<int64>> GetParticipatingReplicas(

// Use the DeviceAssignment to figure out our replica-id.
TF_ASSIGN_OR_RETURN(int replica_id,
device_assn.ReplicaIdForDeviceOrdinal(device_ordinal));
device_assn.ReplicaIdForDeviceOrdinal(device_id.value()));

// Figure out the other replicas that go together with this one.
absl::optional<ReplicaGroup> replica_group;
Expand Down
51 changes: 29 additions & 22 deletions tensorflow/compiler/xla/service/collective_ops_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.

#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
Expand All @@ -37,9 +38,9 @@ absl::optional<ReductionKind> MatchReductionComputation(
const HloComputation* computation);

// Figures out which devices (named by their replica-ids) are participating in
// the all-reduce subgroup that contains device_ordinal.
// the all-reduce subgroup that contains device_id.
StatusOr<std::vector<int64>> GetParticipatingReplicas(
int64 device_ordinal, absl::Span<const ReplicaGroup> replica_groups,
GlobalDeviceId device_id, absl::Span<const ReplicaGroup> replica_groups,
int64 total_replica_count, const DeviceAssignment& device_assn);

// Key that identifies a particular Rendezvous object in our global hashtable.
Expand Down Expand Up @@ -72,16 +73,18 @@ struct RendezvousKey {
};

explicit RendezvousKey(const RunId& run_id,
std::vector<int64> participating_replicas,
std::vector<GlobalDeviceId> global_devices,
int num_local_participants,
CollectiveOpKind collective_op_kind, int64 op_id)
: run_id(run_id),
participating_replicas(participating_replicas),
global_devices(std::move(global_devices)),
num_local_participants(num_local_participants),
collective_op_kind(collective_op_kind),
op_id(op_id) {}

static RendezvousKey FromInstruction(
const RunId& run_id, std::vector<int64> participating_replicas,
const HloInstruction* instr) {
const RunId& run_id, std::vector<GlobalDeviceId> global_devices,
int num_local_participants, const HloInstruction* instr) {
CollectiveOpKind collective_op_kind;
int64 op_id;

Expand All @@ -91,20 +94,19 @@ struct RendezvousKey {
: std::make_pair(
kCrossReplica,
static_cast<int64>(instr->GetModule()->unique_id()));
return RendezvousKey(run_id, participating_replicas, collective_op_kind,
op_id);
return RendezvousKey(run_id, std::move(global_devices),
num_local_participants, collective_op_kind, op_id);
}

int num_participants() const { return participating_replicas.size(); }

template <typename H>
friend H AbslHashValue(H h, const RendezvousKey& k) {
return H::combine(std::move(h), k.run_id, k.participating_replicas,
return H::combine(std::move(h), k.run_id, k.global_devices,
k.num_local_participants,
static_cast<int>(k.collective_op_kind), k.op_id);
}
friend bool operator==(const RendezvousKey& a, const RendezvousKey& b) {
return a.run_id == b.run_id &&
a.participating_replicas == b.participating_replicas &&
return a.run_id == b.run_id && a.global_devices == b.global_devices &&
a.num_local_participants == b.num_local_participants &&
a.collective_op_kind == b.collective_op_kind && //
a.op_id == b.op_id;
}
Expand All @@ -114,14 +116,15 @@ struct RendezvousKey {

string ToString() const {
return absl::StrFormat(
"RendezvousKey{run_id=%s, participating_replicas=[%s], "
"collective_op_kind=%d, op_id=%d}",
run_id.ToString(), absl::StrJoin(participating_replicas, ","),
static_cast<int>(collective_op_kind), op_id);
"RendezvousKey{run_id=%s, global_devices=[%s], "
"num_local_participants=%d, collective_op_kind=%d, op_id=%d}",
run_id.ToString(), GlobalDeviceIdsToString(global_devices),
num_local_participants, static_cast<int>(collective_op_kind), op_id);
}

RunId run_id;
std::vector<int64> participating_replicas;
std::vector<GlobalDeviceId> global_devices;
int num_local_participants;
CollectiveOpKind collective_op_kind;
int64 op_id;
};
Expand Down Expand Up @@ -164,10 +167,13 @@ struct AllReduceParticipantData {
};
std::vector<Buffer> buffers;
se::Stream* stream;
const NcclUniqueIdCallback* nccl_unique_id_callback = nullptr;

ReductionKind reduction_kind;

int num_participants() const { return rendezvous_key.num_participants(); }
// For each local all-reduce participant a (global ID, local device ordinal)
// pair for the participant. Participants are in no particular order.
std::vector<std::pair<GlobalDeviceId, int64>> local_devices;

string ToString() const {
std::vector<std::string> buffer_strs;
Expand Down Expand Up @@ -303,12 +309,13 @@ class Rendezvous {
const RendezvousKey key_;

tensorflow::BlockingCounter all_participants_present_{
key_.num_participants()};
tensorflow::BlockingCounter done_{key_.num_participants()};
key_.num_local_participants};
tensorflow::BlockingCounter done_{key_.num_local_participants};

// tensorflow::BlockingCounter returned by SubmitParticipant.
std::shared_ptr<tensorflow::BlockingCounter> returned_blocking_counter_{
std::make_shared<tensorflow::BlockingCounter>(key_.num_participants())};
std::make_shared<tensorflow::BlockingCounter>(
key_.num_local_participants)};
};

} // end namespace xla
Expand Down
42 changes: 28 additions & 14 deletions tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,7 @@ class CpuAllReduceRendezvous : public xla::Rendezvous<std::nullptr_t> {
xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>&
GlobalRendezvousMap() {
static auto& m =
*new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>(
[](const xla::RendezvousKey& k) {
return absl::make_unique<CpuAllReduceRendezvous>(k);
});
*new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>;
return m;
}

Expand All @@ -411,18 +408,28 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(

std::vector<xla::ReplicaGroup> group =
xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
xla::int32 replica_count = run_options->device_assignment()->replica_count();
std::vector<xla::int64> participating_replicas_vec =
xla::GetParticipatingReplicas(device_ordinal, group, replica_count,
const xla::DeviceAssignment& device_assignment =
*run_options->device_assignment();
xla::int32 replica_count = device_assignment.replica_count();
CHECK_EQ(device_assignment.computation_count(), 1);
std::vector<xla::int64> participating_replicas =
xla::GetParticipatingReplicas(xla::GlobalDeviceId(device_ordinal), group,
replica_count,
*run_options->device_assignment())
.ValueOrDie();

xla::RendezvousKey::CollectiveOpKind op_kind =
channel_id_present ? xla::RendezvousKey::kCrossModule
: xla::RendezvousKey::kCrossReplica;
xla::RendezvousKey rendezvous_key(run_options->run_id(),
participating_replicas_vec, op_kind, op_id);

std::vector<xla::GlobalDeviceId> participating_devices;
participating_devices.reserve(participating_replicas.size());
for (xla::int64 replica : participating_replicas) {
participating_devices.push_back(
xla::GlobalDeviceId(device_assignment(replica, 0)));
}
xla::RendezvousKey rendezvous_key(
run_options->run_id(), std::move(participating_devices),
participating_replicas.size(), op_kind, op_id);
auto shape_str = ShapeString(shape_ptr, shape_length);
VLOG(2) << "All-reduce input/output shape : " << shape_str;

Expand All @@ -444,10 +451,17 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
participant.buffers = {buffer};
participant.reduction_kind = static_cast<xla::ReductionKind>(reduction_kind);

TF_CHECK_OK(
CpuAllReduceRendezvous::SubmitParticipant(
[&] { return GlobalRendezvousMap()[rendezvous_key]; }, participant)
.status());
auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
return absl::make_unique<CpuAllReduceRendezvous>(k);
};

TF_CHECK_OK(CpuAllReduceRendezvous::SubmitParticipant(
[&] {
return GlobalRendezvousMap().GetOrCreateIfAbsent(
rendezvous_key, make_cpu_rendezvous);
},
participant)
.status());
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ReplicaId(
Expand Down
Loading

0 comments on commit 8a72c44

Please sign in to comment.