Automated rollback of commit 681f6a6
PiperOrigin-RevId: 229930698
dubey authored and tensorflower-gardener committed Jan 18, 2019
1 parent 72571e0 commit f6b81f4
Showing 22 changed files with 927 additions and 128 deletions.
2 changes: 1 addition & 1 deletion tensorflow/core/BUILD
@@ -3719,7 +3719,6 @@ tf_cc_tests(
srcs = [
"common_runtime/buf_rendezvous_test.cc",
"common_runtime/collective_executor_mgr_test.cc",
"common_runtime/collective_param_resolver_local_test.cc",
"common_runtime/collective_rma_local_test.cc",
"common_runtime/device_resolver_local_test.cc",
"common_runtime/device_set_test.cc",
@@ -3835,6 +3834,7 @@ tf_cc_tests(
name = "higher_level_tests_needing_kernels",
size = "small",
srcs = [
"common_runtime/collective_param_resolver_local_test.cc",
"graph/graph_constructor_test.cc",
],
linkopts = select({
33 changes: 33 additions & 0 deletions tensorflow/core/common_runtime/base_collective_executor.cc
@@ -296,4 +296,37 @@ Status BaseCollectiveExecutor::CreateCollective(
return status;
}

bool BaseCollectiveExecutor::CheckDependencies(
const CollectiveParams& col_params) {
for (int32 instance : col_params.instance.impl_details.dependencies) {
auto find_iter = launched_.find(instance);
if (find_iter == launched_.end() || find_iter->second != 0) {
return false;
}
}
return true;
}

void BaseCollectiveExecutor::WaitForDependencies(
const CollectiveParams& col_params) {
mutex_lock l(launch_mu_);
while (!CheckDependencies(col_params)) {
launch_cv_.wait(l);
}
}

void BaseCollectiveExecutor::Launched(const CollectiveParams& col_params) {
mutex_lock l(launch_mu_);
if (launched_.find(col_params.instance.instance_key) == launched_.end()) {
const string& task_name =
col_params.instance.task_names[col_params.default_rank];
const int32 num_devices =
col_params.instance.num_devices_per_task.at(task_name);
launched_[col_params.instance.instance_key] = num_devices;
}
if (--launched_[col_params.instance.instance_key] == 0) {
launch_cv_.notify_all();
}
}

} // namespace tensorflow
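
For orientation, here is a minimal caller-side sketch (hypothetical helper names, not part of this commit) of how the hooks added above are intended to be used: block on declared dependencies, launch the collective kernel, then report completion so dependent collectives can proceed.

```cpp
// Hypothetical usage sketch; RunOrderedCollective and launch_kernel are
// illustrative names, not TensorFlow APIs. Assumes the TensorFlow headers that
// declare BaseCollectiveExecutor and CollectiveParams are available.
#include <functional>

#include "tensorflow/core/common_runtime/base_collective_executor.h"

void RunOrderedCollective(tensorflow::BaseCollectiveExecutor* executor,
                          const tensorflow::CollectiveParams& col_params,
                          const std::function<void()>& launch_kernel) {
  // Blocks until every instance key listed in
  // col_params.instance.impl_details.dependencies has had Launched() called
  // on all of its local devices (its entry in launched_ reaches zero).
  executor->WaitForDependencies(col_params);

  launch_kernel();  // e.g. enqueue the NCCL kernel for this device

  // Decrements this instance's per-device count; once it reaches zero,
  // waiters blocked in WaitForDependencies() are notified.
  executor->Launched(col_params);
}
```
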
18 changes: 18 additions & 0 deletions tensorflow/core/common_runtime/base_collective_executor.h
@@ -135,15 +135,33 @@ class BaseCollectiveExecutor : public CollectiveExecutor {
client_locality, done);
}

// If we need to enforce an ordering on any portion of the collective
// implementation, and the ordering is encoded via an attribute on the
// collective op, this function will block until all dependencies for this
// collective have completed.
void WaitForDependencies(const CollectiveParams& col_params) override;
// Record that this collective has completed the portion of the implementation
// that needs to be ordered wrt other collectives, to unblock any of its
// dependent ops.
void Launched(const CollectiveParams& col_params) override;

protected:
const int64 step_id_;
const DeviceMgr* dev_mgr_; // Not owned.
std::unique_ptr<PerStepCollectiveRemoteAccess> remote_access_;
const string* gpu_ring_order_; // Not owned.
mutex launch_mu_;
condition_variable launch_cv_;
// collective instance key -> number of local devices for which NCCL ops have
// been launched.
std::unordered_map<int32, int32> launched_ GUARDED_BY(launch_mu_);

private:
Status CreateCollective(const CollectiveParams& col_params,
CollectiveImplementationInterface** col_impl);
// Check if all ops on which this collective depends have launched.
bool CheckDependencies(const CollectiveParams& col_params)
EXCLUSIVE_LOCKS_REQUIRED(launch_mu_);
};

} // namespace tensorflow
159 changes: 105 additions & 54 deletions tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"
@@ -39,7 +40,10 @@ void CollectiveParamResolverLocal::InstanceRec::WaitForOutMu(mutex_lock& lock) {
CollectiveParamResolverLocal::CollectiveParamResolverLocal(
const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
const string& task_name)
: dev_mgr_(dev_mgr), dev_resolver_(dev_resolver), task_name_(task_name) {}
: nccl_(false), // (b/111897089): turn on NCCL collectives.
dev_mgr_(dev_mgr),
dev_resolver_(dev_resolver),
task_name_(task_name) {}

void CollectiveParamResolverLocal::CompleteGroupAsync(
const CompleteGroupRequest* request, CompleteGroupResponse* response,
@@ -316,29 +320,28 @@ GlobalDeviceMap EstablishGlobalRank(
// cp->same_num_devices_per_task. Requires cp->instance.task_names
// be sorted.
void SetDevPerTask(CollectiveParams* cp) {
cp->instance.same_num_devices_per_task = false;
if (cp->instance.task_names.empty()) return;
int dev_per_task = -1;
int count = 0;
cp->instance.num_devices_per_task.clear();
const string* last_task_name = &cp->instance.task_names[0];
int count = 0;
for (const string& task_name : cp->instance.task_names) {
if (task_name != *last_task_name) {
CHECK_GT(count, 0);
if (dev_per_task < 0) {
dev_per_task = count;
} else {
CHECK_GT(dev_per_task, 0);
if (count != dev_per_task) return;
}
if (task_name == *last_task_name) {
++count;
} else {
cp->instance.num_devices_per_task[*last_task_name] = count;
count = 1;
last_task_name = &task_name;
} else {
++count;
}
}
CHECK_GT(count, 0);
if ((dev_per_task > 0) && (count != dev_per_task)) {
return;
cp->instance.num_devices_per_task[*last_task_name] = count;

cp->instance.same_num_devices_per_task = false;
int dev_per_task = -1;
for (const auto& task_dev : cp->instance.num_devices_per_task) {
if (dev_per_task == -1) {
dev_per_task = task_dev.second;
} else if (dev_per_task != task_dev.second) {
return;
}
}
cp->instance.same_num_devices_per_task = true;
CHECK_EQ((cp->group.group_size % cp->group.num_tasks), 0);
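
As a standalone illustration of the rewritten SetDevPerTask above (hypothetical example, not TensorFlow code): with a sorted task list, counting contiguous runs yields num_devices_per_task, and same_num_devices_per_task holds iff every count is equal.

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Hypothetical sorted task list: two workers with two devices each.
  std::vector<std::string> task_names = {
      "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:0",
      "/job:worker/replica:0/task:1", "/job:worker/replica:0/task:1"};

  // First pass: devices per task (sorted order keeps each task contiguous).
  std::map<std::string, int> num_devices_per_task;
  for (const std::string& t : task_names) ++num_devices_per_task[t];

  // Second pass: same_num_devices_per_task is true iff all counts are equal.
  int dev_per_task = -1;
  bool same_num_devices_per_task = true;
  for (const auto& kv : num_devices_per_task) {
    if (dev_per_task == -1) {
      dev_per_task = kv.second;
    } else if (kv.second != dev_per_task) {
      same_num_devices_per_task = false;
    }
  }
  std::cout << "devices per task: " << dev_per_task << ", uniform: "
            << std::boolalpha << same_num_devices_per_task << "\n";
  // prints: devices per task: 2, uniform: true
}
```
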
@@ -398,7 +401,6 @@ void CollectiveParamResolverLocal::SetDefaultRank(const string& device,
void CollectiveParamResolverLocal::InitInstanceSharedParams(
const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
const StatusCallback& done) {
VLOG(1) << "InitInstanceSharedParams " << ir;
ir->shared.instance = cp->instance;
{
mutex_lock gl(gr->mu);
@@ -412,8 +414,8 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams(
}
ir->shared.default_rank = -1;

// Sort devce_names lexicographcally, keeping task_names in
// corresponding order.
// Sort device_names lexicographically, keeping task_names in corresponding
// order. Also set number of devices per task.
SortDevicesAndTasks(&ir->shared);

// Get Locality data for all devices.
@@ -605,6 +607,25 @@ void CollectiveParamResolverLocal::CompleteInstanceAsync(
"intended only for non-distributed deployment."));
}

// TODO(b/111897089): we need a better way to pick the collective
// implementation. Ideally the choice would take topology and link strength
// into account.
void CollectiveParamResolverLocal::AssignCollectiveType(CollectiveParams* cp) {
if (cp->instance.type == BROADCAST_COLLECTIVE) {
cp->instance.impl_details.collective_name = "HierarchicalTreeBroadcast";
} else if (cp->instance.type == REDUCTION_COLLECTIVE) {
if (nccl_) {
cp->instance.impl_details.collective_name = "NcclReduce";
} else {
cp->instance.impl_details.collective_name = "RingReduce";
}
} else {
cp->instance.impl_details.collective_name = "undef";
}
VLOG(1) << "AssignCollectiveType "
<< cp->instance.impl_details.collective_name;
}
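
The assignment above is a direct mapping from the collective type and the nccl_ flag to an implementation name; a minimal standalone sketch of the same decision (hypothetical PickCollectiveName helper, not TensorFlow code):

```cpp
#include <iostream>
#include <string>

// Stand-ins for the real TensorFlow enum values.
enum CollectiveType { REDUCTION_COLLECTIVE, BROADCAST_COLLECTIVE, OTHER_COLLECTIVE };

std::string PickCollectiveName(CollectiveType type, bool use_nccl) {
  if (type == BROADCAST_COLLECTIVE) return "HierarchicalTreeBroadcast";
  if (type == REDUCTION_COLLECTIVE) return use_nccl ? "NcclReduce" : "RingReduce";
  return "undef";
}

int main() {
  std::cout << PickCollectiveName(BROADCAST_COLLECTIVE, false) << "\n";  // HierarchicalTreeBroadcast
  std::cout << PickCollectiveName(REDUCTION_COLLECTIVE, true) << "\n";   // NcclReduce
  std::cout << PickCollectiveName(REDUCTION_COLLECTIVE, false) << "\n";  // RingReduce
}
```
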

void CollectiveParamResolverLocal::CompleteInstanceLocal(
const string& device, const GroupRec* gr, CollectiveParams* cp,
bool is_source, const StatusCallback& done) {
@@ -641,48 +662,57 @@ void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
// custom operator= does a deep copy.
cp->instance = ir->shared.instance;
}
// Populate the fields common across task, also default_rank.
// Populate the fields common across task.
AssignCollectiveType(cp);
SetDefaultRank(device, cp);
CompleteTaskIsLocal(task_name_, cp);
// TODO(b/113171733): we need a better way to pick the collective
// implementation. The ideal way would depend upon the topology and link
// strength before picking a particular implementation.
cp->instance.impl_details.collective_name =
(cp->instance.type == BROADCAST_COLLECTIVE) ? "HierarchicalTreeBroadcast"
: "RingReduce";

CollectiveImplementationInterface* col_impl;
Status lookup_status = CollectiveRegistry::LookupParamResolverInstance(
Status status = CollectiveRegistry::LookupParamResolverInstance(
cp->instance.impl_details.collective_name, &col_impl);
if (!lookup_status.ok()) {
done(lookup_status);
if (status.ok()) {
status = col_impl->InitializeInstanceBeforeGroupDiscovery(cp);
}
if (!status.ok()) {
done(status);
return;
}
// If broadcast, may need to wait for source discovery.
if (cp->instance.type == BROADCAST_COLLECTIVE) {
CompleteInstanceSource(ir, cp, is_source,
[col_impl, ir, device, cp, done](InstanceRec* irec) {
CHECK_EQ(ir, irec);
Status s;
{
mutex_lock l(irec->out_mu);
irec->WaitForOutMu(l);
s = irec->status;
cp->source_rank = irec->source_rank;
}
if (s.ok()) {
s = col_impl->InitializeCollectiveParams(cp);
}
done(s);
});

// We may need to wait for the group if:
// * this is a broadcast, for source discovery;
// * we are using NCCL with more than 1 worker, for the communicator key from
// rank 0.
bool broadcast = cp->instance.type == BROADCAST_COLLECTIVE;
bool nccl = cp->instance.type == REDUCTION_COLLECTIVE &&
cp->instance.impl_details.collective_name == "NcclReduce" &&
cp->group.num_tasks > 1;
if (broadcast || nccl) {
WaitForGroup(ir, cp, is_source, broadcast, nccl,
[col_impl, ir, device, cp, done](InstanceRec* irec) {
Status s;
if (ir != irec) {
s = errors::Internal("Expected ir ", ir, " and irec ",
irec, " to be equal");
} else {
mutex_lock l(irec->out_mu);
irec->WaitForOutMu(l);
s = irec->status;
cp->source_rank = irec->source_rank;
cp->instance.communicator_key = irec->communicator_key;
}
if (s.ok()) {
s = col_impl->InitializeCollectiveParams(cp);
}
done(s);
});
} else {
done(col_impl->InitializeCollectiveParams(cp));
}
}
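
To summarize the branching above: the callback only routes through WaitForGroup when group-wide coordination is needed, namely broadcasts (source discovery) and multi-task NcclReduce (communicator key from rank 0). A tiny standalone sketch of that predicate (hypothetical NeedsGroupWait helper, not TensorFlow code):

```cpp
#include <string>

enum CollectiveType { REDUCTION_COLLECTIVE, BROADCAST_COLLECTIVE };  // stand-ins

bool NeedsGroupWait(CollectiveType type, const std::string& collective_name,
                    int num_tasks) {
  const bool broadcast = type == BROADCAST_COLLECTIVE;
  const bool nccl = type == REDUCTION_COLLECTIVE &&
                    collective_name == "NcclReduce" && num_tasks > 1;
  return broadcast || nccl;
}
```
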

void CollectiveParamResolverLocal::CompleteInstanceSource(InstanceRec* ir,
CollectiveParams* cp,
bool is_source,
const IRConsumer& f) {
void CollectiveParamResolverLocal::WaitForGroup(
InstanceRec* ir, CollectiveParams* cp, bool is_source, bool init_source,
bool init_nccl, const IRConsumer& f) {
std::vector<IRConsumer> ready_waiters;
{
mutex_lock l(ir->out_mu);
@@ -692,7 +722,8 @@ void CollectiveParamResolverLocal::CompleteInstanceSource(InstanceRec* ir,
if (!ir->known[cp->default_rank]) {
ir->known[cp->default_rank] = true;
++ir->known_count;
if (is_source) {
if (init_source && is_source) {
// Initialize source rank.
if (ir->source_rank >= 0) {
ir->status = errors::Internal("Instance ", cp->instance.instance_key,
" already has source ", ir->source_rank,
Expand All @@ -702,13 +733,26 @@ void CollectiveParamResolverLocal::CompleteInstanceSource(InstanceRec* ir,
ir->source_rank = cp->default_rank;
}
}
if (init_nccl && cp->default_rank == 0) {
// Initialize communicator key.
if (!ir->communicator_key.empty()) {
ir->status =
errors::Internal("Instance ", cp->instance.instance_key,
" already has communicator_key ",
str_util::CEscape(ir->communicator_key),
", received second claim from device ",
cp->instance.device_names[cp->default_rank]);
} else {
ir->communicator_key = cp->instance.communicator_key;
}
}
}
if (ir->known_count < ir->shared.group.group_size) {
ir->known_waiters.push_back(f);
return;
}
CHECK_EQ(ir->known_count, ir->shared.group.group_size);
if (ir->source_rank < 0) {
if (init_source && ir->source_rank < 0) {
// NOTE(ayushd): changing the error message below would also require
// updating CompleteParamsBroadcastForgotSend test in
// CollectiveParamResolverLocalTest.
@@ -718,6 +762,13 @@ void CollectiveParamResolverLocal::CompleteInstanceSource(InstanceRec* ir,
"could mean that there were group_size=",
ir->known_count, " BcastRecvs but no BcastSend.");
}
if (init_nccl && ir->communicator_key.empty()) {
ir->status = errors::Internal(
"Instance ", cp->instance.instance_key, " device ",
cp->instance.device_names[cp->default_rank],
" did not find rank 0 for setting communicator key. This is an "
"internal error in collective param resolution");
}
if (!ir->known_waiters.empty()) {
ready_waiters = std::move(ir->known_waiters);
}
15 changes: 11 additions & 4 deletions tensorflow/core/common_runtime/collective_param_resolver_local.h
@@ -130,8 +130,10 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
Status status GUARDED_BY(out_mu);

// These fields are used to count the instances that have called
// in and become known while resolving broadcast source identity.
// in and become known while resolving broadcast source identity and
// communicator key.
int source_rank GUARDED_BY(out_mu);
string communicator_key GUARDED_BY(out_mu);
int known_count GUARDED_BY(out_mu);
std::vector<bool> known GUARDED_BY(out_mu);
std::vector<IRConsumer> known_waiters GUARDED_BY(out_mu);
@@ -197,10 +199,10 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
const StatusCallback& done)
LOCKS_EXCLUDED(ir->out_mu);

// Complete source data for a broadcast instance.
// Complete source data and/or nccl communicator key.
// Precondition: *cp has complete group data and default_rank.
void CompleteInstanceSource(InstanceRec* ir, CollectiveParams* cp,
bool is_source, const IRConsumer& f)
void WaitForGroup(InstanceRec* ir, CollectiveParams* cp, bool is_source,
bool init_source, bool init_nccl, const IRConsumer& f)
LOCKS_EXCLUDED(ir->out_mu);

// If cp.device_names contains only devices local to this process
@@ -216,10 +218,15 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
// current ordering of cp->instance.device_names.
void SetDefaultRank(const string& device, CollectiveParams* cp);

// Assigns cp->instance.impl_details.collective_name based on the collective
// op type (cp->instance.type), attempting to pick the best implementation.
void AssignCollectiveType(CollectiveParams* cp);

// Helper to grab status under lock, invoke callback out of lock.
void CallbackWithStatus(const InstanceRecCallback& done, InstanceRec* irec)
LOCKS_EXCLUDED(irec->out_mu);

const bool nccl_;
const DeviceMgr* dev_mgr_;
DeviceResolverInterface* dev_resolver_; // Not owned.
string task_name_;