[xla:gpu] Add support for device constraints to Send/Recv operations
PiperOrigin-RevId: 592996655
ezhulenev authored and tensorflower-gardener committed Dec 22, 2023
1 parent 1338377 commit a5ac85e
Showing 5 changed files with 88 additions and 25 deletions.
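The sketch below is not part of this commit. It illustrates, as an assumption about typical usage, how a Send or Recv HLO instruction would acquire the single-device ("maximal") sharding that the new DeviceConstraint() helper in ir_emitter_unnested.cc converts into a GlobalDeviceId. The helper name PinToDevice is hypothetical; HloSharding::AssignDevice and HloInstruction::set_sharding are existing XLA APIs.

// Hypothetical helper (illustration only, not from this commit): pin a
// send/recv instruction to one device so the emitted thunk carries a
// device constraint.
#include <cstdint>

#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_sharding.h"

void PinToDevice(xla::HloInstruction* send_or_recv, int64_t device_id) {
  // AssignDevice creates a maximal sharding naming exactly one device, so
  // has_sharding() and sharding().HasUniqueDevice() both become true and
  // DeviceConstraint() returns GlobalDeviceId(device_id) for this op.
  send_or_recv->set_sharding(xla::HloSharding::AssignDevice(device_id));
}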
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/BUILD
@@ -332,6 +332,7 @@ cc_library(
"//xla/service:buffer_assignment",
"//xla/service:custom_call_status",
"//xla/service:custom_call_target_registry",
"//xla/service:global_device_id",
"//xla/service:name_uniquer",
"//xla/service/gpu/fusions",
"//xla/service/gpu/fusions:fusion_emitter",
19 changes: 15 additions & 4 deletions third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
@@ -93,6 +93,7 @@ limitations under the License.
#include "xla/service/buffer_assignment.h"
#include "xla/service/custom_call_status.h"
#include "xla/service/custom_call_target_registry.h"
#include "xla/service/global_device_id.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/cublas_cudnn.h"
#include "xla/service/gpu/fusions/fusion_emitter.h"
@@ -3504,6 +3505,14 @@ static absl::flat_hash_map<std::string, std::string> ConvertFrontendAttributes(
return result;
}

static std::optional<GlobalDeviceId> DeviceConstraint(
const HloInstruction* hlo) {
if (hlo->has_sharding() && hlo->sharding().HasUniqueDevice()) {
return GlobalDeviceId(hlo->sharding().GetUniqueDevice());
}
return std::nullopt;
}

Status IrEmitterUnnested::EmitSendThunk(const HloSendInstruction* instr) {
if (!instr->channel_id().has_value())
return absl::InternalError("Unknown send instruction channel id");
@@ -3515,7 +3524,8 @@ Status IrEmitterUnnested::EmitSendThunk(const HloSendInstruction* instr) {
AddThunkToThunkSequence(std::make_unique<SendThunk>(
Thunk::ThunkInfo::WithProfileAnnotation(instr), src->shape(), buffer,
*instr->channel_id(), send_recv_events_,
ConvertFrontendAttributes(instr->frontend_attributes())));
ConvertFrontendAttributes(instr->frontend_attributes()),
DeviceConstraint(instr)));

return OkStatus();
}
@@ -3527,7 +3537,7 @@ Status IrEmitterUnnested::EmitSendDoneThunk(

AddThunkToThunkSequence(std::make_unique<SendDoneThunk>(
Thunk::ThunkInfo::WithProfileAnnotation(instr), *instr->channel_id(),
send_recv_events_));
send_recv_events_, DeviceConstraint(instr)));

return OkStatus();
}
@@ -3543,7 +3553,8 @@ Status IrEmitterUnnested::EmitRecvThunk(const HloRecvInstruction* instr) {
Thunk::ThunkInfo::WithProfileAnnotation(instr),
instr->shape().tuple_shapes()[0], buffer, *instr->channel_id(),
send_recv_events_,
ConvertFrontendAttributes(instr->frontend_attributes())));
ConvertFrontendAttributes(instr->frontend_attributes()),
DeviceConstraint(instr)));

return OkStatus();
}
@@ -3555,7 +3566,7 @@ Status IrEmitterUnnested::EmitRecvDoneThunk(

AddThunkToThunkSequence(std::make_unique<RecvDoneThunk>(
Thunk::ThunkInfo::WithProfileAnnotation(instr), *instr->channel_id(),
send_recv_events_));
send_recv_events_, DeviceConstraint(instr)));

return OkStatus();
}
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/runtime3/BUILD
@@ -364,6 +364,7 @@ cc_library(
"//xla:statusor",
"//xla:xla_data_proto_cc",
"//xla/service:buffer_assignment",
"//xla/service:global_device_id",
"//xla/service/gpu:thunk",
"//xla/stream_executor",
"@com_google_absl//absl/base:core_headers",
74 changes: 57 additions & 17 deletions third_party/xla/xla/service/gpu/runtime3/send_recv_thunk.cc
@@ -17,7 +17,9 @@ limitations under the License.

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>

#include "absl/container/flat_hash_map.h"
@@ -26,9 +28,11 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/global_device_id.h"
#include "xla/service/gpu/thunk.h"
#include "xla/shape.h"
#include "xla/status.h"
#include "xla/statusor.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/event.h"
#include "xla/stream_executor/stream_executor.h"
@@ -43,6 +47,25 @@ using tsl::AsyncValueRef;
using tsl::profiler::TraceMe;
using tsl::profiler::TraceMeEncode;

// For sharded buffers we should execute Send/Recv operations only on devices
// with maximal sharding, and do nothing on every other device.
static StatusOr<bool> ShouldSkip(
std::string_view operation, const Thunk::ExecuteParams& params,
const std::optional<GlobalDeviceId>& device_constraint) {
if (!device_constraint.has_value()) return false;

TF_ASSIGN_OR_RETURN(const GlobalDeviceId global_device_id,
params.nccl_params.GetGlobalDeviceId());

bool skip = global_device_id != *device_constraint;
if (skip) {
VLOG(3) << "Skip " << operation << " as device id " << global_device_id
<< " doesn't match device id constraint " << *device_constraint;
}

return skip;
}

//===----------------------------------------------------------------------===//
// SendRecvAsyncEvents
//===----------------------------------------------------------------------===//
@@ -78,18 +101,24 @@ absl::StatusOr<AsyncValueRef<se::Event>> SendRecvAsyncEvents::Extract(
SendThunk::SendThunk(
ThunkInfo thunk_info, Shape shape, BufferAllocation::Slice buffer,
int64_t channel_id, std::shared_ptr<SendRecvAsyncEvents> events,
absl::flat_hash_map<std::string, std::string> frontend_attrs)
absl::flat_hash_map<std::string, std::string> frontend_attrs,
std::optional<GlobalDeviceId> device_constraint)
: Thunk(Thunk::kSend, thunk_info),
shape_(shape),
buffer_(buffer),
channel_id_(channel_id),
events_(std::move(events)),
frontend_attrs_(std::move(frontend_attrs)) {}
frontend_attrs_(std::move(frontend_attrs)),
device_constraint_(device_constraint) {}

Status SendThunk::ExecuteOnStream(const ExecuteParams& params) {
VLOG(3) << "Send buffer: channel_id=" << channel_id_
<< "; shape=" << shape_.ToString();

TF_ASSIGN_OR_RETURN(bool skip,
ShouldSkip("sending buffer", params, device_constraint_));
if (skip) return OkStatus();

TraceMe trace([&] {
return TraceMeEncode("Send", {{"channel_id", channel_id_}});
});
@@ -122,15 +151,19 @@ Status SendThunk::ExecuteOnStream(const ExecuteParams& params) {
//===----------------------------------------------------------------------===//

SendDoneThunk::SendDoneThunk(ThunkInfo thunk_info, int64_t channel_id,

std::shared_ptr<SendRecvAsyncEvents> events)
std::shared_ptr<SendRecvAsyncEvents> events,
std::optional<GlobalDeviceId> device_constraint)
: Thunk(Thunk::kSend, thunk_info),
channel_id_(channel_id),
events_(std::move(events)) {}
events_(std::move(events)),
device_constraint_(device_constraint) {}

Status SendDoneThunk::ExecuteOnStream(const ExecuteParams& params) {
VLOG(3) << "Wait for Host Send completion:"
<< " channel_id=" << channel_id_;
VLOG(3) << "Wait for send completion: channel_id=" << channel_id_;

TF_ASSIGN_OR_RETURN(bool skip, ShouldSkip("waiting for send completion",
params, device_constraint_));
if (skip) return OkStatus();

TraceMe trace([&] {
return TraceMeEncode("SendDone", {{"channel_id", channel_id_}});
@@ -143,8 +176,7 @@ Status SendDoneThunk::ExecuteOnStream(const ExecuteParams& params) {
BlockUntilReady(done_event.GetAsyncValue());
if (done_event.IsError()) return done_event.GetError();

VLOG(5) << "Completed Host Send operation: "
<< " channel_id=" << channel_id_;
VLOG(5) << "Completed Send operation: channel_id=" << channel_id_;

// Once event is recorded we can add a stream dependency.
params.stream->ThenWaitFor(&done_event.get());
@@ -158,18 +190,24 @@ Status SendDoneThunk::ExecuteOnStream(const ExecuteParams& params) {
RecvThunk::RecvThunk(
ThunkInfo thunk_info, Shape shape, BufferAllocation::Slice buffer,
int64_t channel_id, std::shared_ptr<SendRecvAsyncEvents> events,
absl::flat_hash_map<std::string, std::string> frontend_attrs)
absl::flat_hash_map<std::string, std::string> frontend_attrs,
std::optional<GlobalDeviceId> device_constraint)
: Thunk(Thunk::kSend, thunk_info),
shape_(shape),
buffer_(buffer),
channel_id_(channel_id),
events_(std::move(events)),
frontend_attrs_(std::move(frontend_attrs)) {}
frontend_attrs_(std::move(frontend_attrs)),
device_constraint_(device_constraint) {}

Status RecvThunk::ExecuteOnStream(const ExecuteParams& params) {
VLOG(3) << "Recv buffer: channel_id=" << channel_id_
<< "; shape=" << shape_.ToString();

TF_ASSIGN_OR_RETURN(
bool skip, ShouldSkip("receiving buffer", params, device_constraint_));
if (skip) return OkStatus();

TraceMe trace([&] {
return TraceMeEncode("Recv", {{"channel_id", channel_id_}});
});
@@ -202,15 +240,18 @@ Status RecvThunk::ExecuteOnStream(const ExecuteParams& params) {
//===----------------------------------------------------------------------===//

RecvDoneThunk::RecvDoneThunk(ThunkInfo thunk_info, int64_t channel_id,

std::shared_ptr<SendRecvAsyncEvents> events)
std::shared_ptr<SendRecvAsyncEvents> events,
std::optional<GlobalDeviceId> device_constraint)
: Thunk(Thunk::kSend, thunk_info),
channel_id_(channel_id),
events_(std::move(events)) {}

Status RecvDoneThunk::ExecuteOnStream(const ExecuteParams& params) {
VLOG(3) << "Wait for Recv completion:"
<< " channel_id=" << channel_id_;
VLOG(3) << "Wait for recv completion: channel_id=" << channel_id_;

TF_ASSIGN_OR_RETURN(bool skip, ShouldSkip("waiting for recv completion",
params, device_constraint_));
if (skip) return OkStatus();

TraceMe trace([&] {
return TraceMeEncode("RecvDone", {{"channel_d", channel_id_}});
@@ -223,8 +264,7 @@ Status RecvDoneThunk::ExecuteOnStream(const ExecuteParams& params) {
BlockUntilReady(done_event.GetAsyncValue());
if (done_event.IsError()) return done_event.GetError();

VLOG(5) << "Completed Host Recv operation: "
<< " channel=" << channel_id_;
VLOG(5) << "Completed Recv operation: channel=" << channel_id_;

// Once event is recorded we can add a stream dependency.
params.stream->ThenWaitFor(&done_event.get());
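The following standalone sketch is not part of the commit either; it restates the skip rule implemented by the new ShouldSkip helper above, with GlobalDeviceId modeled as a plain int64_t purely for illustration.

#include <cstdint>
#include <optional>

// Illustration only: the real helper compares xla::GlobalDeviceId values and
// reads the executing device id from Thunk::ExecuteParams; here a plain
// int64_t stands in for the device id.
bool ShouldSkipSendRecv(std::optional<int64_t> device_constraint,
                        int64_t executing_device_id) {
  // No constraint: every device executes the Send/Recv.
  if (!device_constraint.has_value()) return false;
  // Constrained: only the device named by the sharding runs the operation;
  // every other device skips it (the thunk then returns OkStatus()).
  return *device_constraint != executing_device_id;
}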
18 changes: 14 additions & 4 deletions third_party/xla/xla/service/gpu/runtime3/send_recv_thunk.h
@@ -18,13 +18,15 @@ limitations under the License.

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>

#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/global_device_id.h"
#include "xla/service/gpu/thunk.h"
#include "xla/shape.h"
#include "xla/status.h"
@@ -84,7 +86,8 @@ class SendThunk : public Thunk {
public:
SendThunk(ThunkInfo thunk_info, Shape shape, BufferAllocation::Slice buffer,
int64_t channel_id, std::shared_ptr<SendRecvAsyncEvents> events,
absl::flat_hash_map<std::string, std::string> frontend_attrs);
absl::flat_hash_map<std::string, std::string> frontend_attrs,
std::optional<GlobalDeviceId> device_constraint);

Status ExecuteOnStream(const ExecuteParams& params) override;

@@ -96,6 +99,7 @@

std::shared_ptr<SendRecvAsyncEvents> events_;
absl::flat_hash_map<std::string, std::string> frontend_attrs_;
std::optional<GlobalDeviceId> device_constraint_;
};

//===----------------------------------------------------------------------===//
@@ -105,14 +109,16 @@ class SendThunk : public Thunk {
class SendDoneThunk : public Thunk {
public:
SendDoneThunk(ThunkInfo thunk_info, int64_t channel_id,
std::shared_ptr<SendRecvAsyncEvents> events);
std::shared_ptr<SendRecvAsyncEvents> events,
std::optional<GlobalDeviceId> device_constraint);

Status ExecuteOnStream(const ExecuteParams& params) override;

private:
int64_t channel_id_;

std::shared_ptr<SendRecvAsyncEvents> events_;
std::optional<GlobalDeviceId> device_constraint_;
};

//===----------------------------------------------------------------------===//
@@ -123,7 +129,8 @@ class RecvThunk : public Thunk {
public:
RecvThunk(ThunkInfo thunk_info, Shape shape, BufferAllocation::Slice buffer,
int64_t channel_id, std::shared_ptr<SendRecvAsyncEvents> events,
absl::flat_hash_map<std::string, std::string> frontend_attrs);
absl::flat_hash_map<std::string, std::string> frontend_attrs,
std::optional<GlobalDeviceId> device_constraint);

Status ExecuteOnStream(const ExecuteParams& params) override;

@@ -135,6 +142,7 @@

std::shared_ptr<SendRecvAsyncEvents> events_;
absl::flat_hash_map<std::string, std::string> frontend_attrs_;
std::optional<GlobalDeviceId> device_constraint_;
};

//===----------------------------------------------------------------------===//
@@ -144,14 +152,16 @@ class RecvDoneThunk : public Thunk {
class RecvDoneThunk : public Thunk {
public:
RecvDoneThunk(ThunkInfo thunk_info, int64_t channel_id,
std::shared_ptr<SendRecvAsyncEvents> events);
std::shared_ptr<SendRecvAsyncEvents> events,
std::optional<GlobalDeviceId> device_constraint);

Status ExecuteOnStream(const ExecuteParams& params) override;

private:
int64_t channel_id_;

std::shared_ptr<SendRecvAsyncEvents> events_;
std::optional<GlobalDeviceId> device_constraint_;
};

} // namespace xla::gpu
