Commit 393389d

PR #8874: [GPU] Use NCCL user buffers for collective permute and all-to-all

Imported from GitHub PR openxla/xla#8874

This PR enables XLA to take advantage of NCCL user buffers for ncclSend/ncclRecv when `--xla_gpu_enable_nccl_user_buffers=true` is used. Requires NCCL 2.20.
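For background, the user-buffer path builds on NCCL's buffer allocation and registration API (ncclMemAlloc/ncclCommRegister, introduced around NCCL 2.19, with send/recv support arriving in 2.20). A minimal standalone sketch, not XLA code; the helper name, element type, and sizes are illustrative, and error handling is elided:

#include <nccl.h>

// Allocate buffers with ncclMemAlloc and register them with the
// communicator so NCCL can take the zero-copy user-buffer path for
// point-to-point traffic. Illustrative sketch; error checks omitted.
void SendRecvWithUserBuffers(ncclComm_t comm, int peer, cudaStream_t stream) {
  const size_t bytes = 1 << 20;  // illustrative size
  void *send_buf = nullptr, *recv_buf = nullptr;
  void *send_handle = nullptr, *recv_handle = nullptr;
  ncclMemAlloc(&send_buf, bytes);
  ncclMemAlloc(&recv_buf, bytes);
  ncclCommRegister(comm, send_buf, bytes, &send_handle);
  ncclCommRegister(comm, recv_buf, bytes, &recv_handle);
  ncclGroupStart();
  ncclSend(send_buf, bytes, ncclUint8, peer, comm, stream);
  ncclRecv(recv_buf, bytes, ncclUint8, peer, comm, stream);
  ncclGroupEnd();
  ncclCommDeregister(comm, send_handle);
  ncclCommDeregister(comm, recv_handle);
  ncclMemFree(send_buf);
  ncclMemFree(recv_buf);
}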

Copybara import of the project:

--
98acdf27d4eba6b19652a76d3f7dcd6630349fc5 by Trevor Morris <tmorris@nvidia.com>:

Use NCCL user buffers for ncclSend/ncclRecv ops

--
bcc289b49bcf2086b50a86a2381ea1b80acd3dd2 by Trevor Morris <tmorris@nvidia.com>:

Include memory space in buffers for collective permute and send/recv

--
4a83d8906b6b5e305dad23fc1d8b9a5069637279 by Trevor Morris <tmorris@nvidia.com>:

Don't offload send, recv

--
0083a418c4ab119ed5a0eb061113104980476943 by Trevor Morris <tmorris@nvidia.com>:

Fix conditional

Merging this change closes #8874

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#8874 from trevor-m:p2p-user-buffers 0083a418c4ab119ed5a0eb061113104980476943
PiperOrigin-RevId: 615104094
trevor-m authored and tensorflower-gardener committed Mar 19, 2024
1 parent 8957235 commit 393389d
Showing 7 changed files with 49 additions and 18 deletions.
13 changes: 8 additions & 5 deletions tensorflow/lite/experimental/shlo/tensor.h
@@ -33,8 +33,10 @@ constexpr TensorElementType BaselineType(TensorElementType type) {
   return type;
 }
 
-std::variant<TensorElementType, QuantizedTensorElementType> BaselineType(
-    const std::variant<TensorElementType, QuantizedTensorElementType>& type);
+using TensorElementTypeVariant =
+    std::variant<TensorElementType, QuantizedTensorElementType>;
+
+TensorElementTypeVariant BaselineType(const TensorElementTypeVariant& type);
 
 struct TensorType {
   Shape shape;
@@ -46,6 +48,8 @@ struct QuantizedTensorType {
   QuantizedTensorElementType element_type;
 };
 
+using TensorTypeVariant = std::variant<TensorType, QuantizedTensorType>;
+
 struct Tensor {
   const Shape& shape() const;
   Shape& shape();
@@ -69,8 +73,7 @@ struct Tensor {
   const TensorElementType& tensor_element_type() const;
   const QuantizedTensorElementType& quantized_tensor_element_type() const;
 
-  std::variant<TensorElementType, QuantizedTensorElementType> element_type()
-      const;
+  TensorElementTypeVariant element_type() const;
 
   template <DataType data_type, typename T = typename Storage<data_type>::Type>
   T* GetDataAs() {
@@ -88,7 +91,7 @@ struct Tensor {
         static_cast<size_t>(NumElements()));
   }
 
-  std::variant<TensorType, QuantizedTensorType> type;
+  TensorTypeVariant type;
 
   // If type is TensorType, the type should be Storage<type.element_type>::Type.
   // If type is QuantizedTensorType, the type should be
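The tensor.h change is a pure refactor: the repeated std::variant spellings are collapsed into the new TensorElementTypeVariant and TensorTypeVariant aliases. A hypothetical call-site sketch (this helper is not part of the commit, and namespace qualifiers are omitted):

#include <variant>

#include "tensorflow/lite/experimental/shlo/tensor.h"

// Hypothetical helper: true when a Tensor's type holds the quantized
// alternative of the TensorTypeVariant alias.
bool IsQuantized(const Tensor& tensor) {
  return std::holds_alternative<QuantizedTensorType>(tensor.type);
}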
26 changes: 17 additions & 9 deletions third_party/xla/xla/service/gpu/gpu_memory_space_assignment.h
@@ -35,20 +35,28 @@ inline constexpr int64_t kCollectiveMemorySpaceColor = 1;
 // collective memory using ncclMemAlloc in the runtime.
 inline BufferAssigner::Colorer CollectiveColorer() {
   return [](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
+    static const auto* kSupportedOpcodes = new absl::flat_hash_set<HloOpcode>{
+        HloOpcode::kAllReduce,
+        HloOpcode::kAllReduceStart,
+        HloOpcode::kAllReduceDone,
+        HloOpcode::kAllGather,
+        HloOpcode::kAllGatherStart,
+        HloOpcode::kAllGatherDone,
+        HloOpcode::kReduceScatter,
+        HloOpcode::kCollectivePermute,
+        HloOpcode::kCollectivePermuteStart,
+        HloOpcode::kCollectivePermuteDone,
+        HloOpcode::kAllToAll,
+    };
     for (HloValue* value : alias_analysis->dataflow_analysis().values()) {
       auto& buffer = alias_analysis->GetBufferContainingValue(*value);
       for (const auto& alias : buffer.values()) {
-        if ((alias->instruction()->opcode() == HloOpcode::kAllReduce ||
-             alias->instruction()->opcode() == HloOpcode::kAllReduceStart ||
-             alias->instruction()->opcode() == HloOpcode::kAllReduceDone ||
-             alias->instruction()->opcode() == HloOpcode::kAllGather ||
-             alias->instruction()->opcode() == HloOpcode::kAllGatherStart ||
-             alias->instruction()->opcode() == HloOpcode::kAllGatherDone ||
-             alias->instruction()->opcode() == HloOpcode::kReduceScatter) ||
+        // opcode or async wrapped opcode is in kSupportedOpcodes.
+        if (kSupportedOpcodes->contains(alias->instruction()->opcode()) ||
             ((alias->instruction()->opcode() == HloOpcode::kAsyncStart ||
               alias->instruction()->opcode() == HloOpcode::kAsyncDone) &&
-             alias->instruction()->async_wrapped_opcode() ==
-                 HloOpcode::kReduceScatter)) {
+             kSupportedOpcodes->contains(
+                 alias->instruction()->async_wrapped_opcode()))) {
           value->set_color(kCollectiveMemorySpaceColor);
         }
       }
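The net effect of this hunk: the chained opcode comparisons move into a flat_hash_set, the collective-permute and all-to-all opcodes join the supported set (which is what lets this PR place their buffers in collective memory), and the async-start/done branch now accepts any supported wrapped opcode instead of only kReduceScatter. A simplified restatement of the predicate, written as a free function with assumed signatures rather than the XLA API:

#include <optional>

#include "absl/container/flat_hash_set.h"

// Sketch: a value is colored into collective memory if its instruction's
// opcode, or the opcode wrapped by an async-start/done pair, is supported.
bool UsesCollectiveMemory(const absl::flat_hash_set<HloOpcode>& supported,
                          HloOpcode opcode,
                          std::optional<HloOpcode> async_wrapped) {
  if (supported.contains(opcode)) return true;
  return (opcode == HloOpcode::kAsyncStart ||
          opcode == HloOpcode::kAsyncDone) &&
         async_wrapped.has_value() && supported.contains(*async_wrapped);
}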
20 changes: 16 additions & 4 deletions third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
@@ -2179,14 +2179,18 @@ Status IrEmitterUnnested::EmitCollectivePermute(
   // First output is aliased.
   TF_RET_CHECK(
       instr->shape().IsTuple() && instr->shape().tuple_shapes_size() == 2 &&
-      instr->shape().tuple_shapes(0) == instr->shape().tuple_shapes(1));
+      Shape::Equal().IgnoreMemorySpaceInLayout()(
+          instr->shape().tuple_shapes(0), instr->shape().tuple_shapes(1)));
   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
                       GetAllocationSliceForHlo(instr, {1}));
 
   const Shape shape = operand->shape();
   const auto& hlo_config = ir_emitter_context_->hlo_module().config();
   const int64_t replica_count = hlo_config.replica_count();
   const int64_t partition_count = hlo_config.num_partitions();
+  const int64_t src_memory_space = shape.layout().memory_space();
+  const int64_t dst_memory_space =
+      instr->shape().tuple_shapes(1).layout().memory_space();
 
   if (NcclCollectivePermuteStartThunk::IsDegenerate(instr, replica_count,
                                                     partition_count)) {
@@ -2202,7 +2206,9 @@ Status IrEmitterUnnested::EmitCollectivePermute(
   const NcclCollectiveThunk::Buffer buffer = {
       /*element_count=*/ShapeUtil::ElementsIn(shape),
       /*source_buffer=*/source_slice,
-      /*destination_buffer=*/result_slice};
+      /*destination_buffer=*/result_slice,
+      /*source_memory_space=*/src_memory_space,
+      /*destination_memory_space=*/dst_memory_space};
   auto thunk = std::make_unique<NcclCollectivePermuteStartThunk>(
       Thunk::ThunkInfo::WithProfileAnnotation(instr), NcclApi::Default(),
       instr, replica_count, partition_count, buffer);
@@ -2619,10 +2625,13 @@ absl::Status IrEmitterUnnested::EmitSendThunk(const HloSendInstruction* instr) {
   const auto& hlo_config = ir_emitter_context_->hlo_module().config();
   const int64_t replica_count = hlo_config.replica_count();
   const int64_t partition_count = hlo_config.num_partitions();
+  const int64_t memory_space = src->shape().layout().memory_space();
   const NcclCollectiveThunk::Buffer nccl_buffer = {
       /*element_count=*/ShapeUtil::ElementsIn(src->shape()),
       /*source_buffer=*/buffer,
-      /*destination_buffer=*/buffer};
+      /*destination_buffer=*/buffer,
+      /*source_memory_space=*/memory_space,
+      /*destination_memory_space=*/memory_space};
   auto thunk = std::make_unique<NcclSendThunk>(
       Thunk::ThunkInfo::WithProfileAnnotation(instr), NcclApi::Default(),
       instr, replica_count, partition_count, nccl_buffer);
@@ -2685,10 +2694,13 @@ absl::Status IrEmitterUnnested::EmitRecvThunk(const HloRecvInstruction* instr) {
   const auto& hlo_config = ir_emitter_context_->hlo_module().config();
   const int64_t replica_count = hlo_config.replica_count();
   const int64_t partition_count = hlo_config.num_partitions();
+  const int64_t memory_space = instr->shape().layout().memory_space();
   const NcclCollectiveThunk::Buffer nccl_buffer = {
       /*element_count=*/ShapeUtil::ElementsIn(instr->shape().tuple_shapes(0)),
       /*source_buffer=*/buffer,
-      /*destination_buffer=*/buffer};
+      /*destination_buffer=*/buffer,
+      /*source_memory_space=*/memory_space,
+      /*destination_memory_space=*/memory_space};
   auto thunk = std::make_unique<NcclRecvThunk>(
       Thunk::ThunkInfo::WithProfileAnnotation(instr), NcclApi::Default(),
       instr, replica_count, partition_count, nccl_buffer);
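The /*field=*/ comments in the initializers above imply that NcclCollectiveThunk::Buffer gained two memory-space members in this PR. A sketch inferred from this diff alone; the authoritative definition lives in nccl_collective_thunk.h, and any members not visible here are omitted:

// Inferred layout of the buffer descriptor initialized above.
struct Buffer {
  int64_t element_count;
  BufferAllocation::Slice source_buffer;
  BufferAllocation::Slice destination_buffer;
  int64_t source_memory_space;       // 1 == kCollectiveMemorySpaceColor,
  int64_t destination_memory_space;  // i.e. backed by ncclMemAlloc
};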
@@ -108,6 +108,8 @@ absl::Status RunAllToAll(NcclApi* nccl_api, bool has_split_dimension,
                          se::Stream& stream, NcclApi::NcclCommHandle comm) {
   int device_ordinal = stream.parent()->device_ordinal();
   VLOG(3) << "Performing all-to-all from device ordinal: " << device_ordinal;
+  TF_RETURN_IF_ERROR(
+      MaybeRegisterBuffers(nccl_api, device_ordinal, buffers, comm));
 
   TF_ASSIGN_OR_RETURN(int32_t num_participants, nccl_api->CommCount(comm));
 
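This MaybeRegisterBuffers call, and the identical ones added to the collective-permute, recv, and send thunks below, follow a register-on-first-use pattern: presumably only buffers whose layout carries the collective memory space get registered, and each (communicator, buffer) pair is registered at most once. A standalone sketch under assumed names; the real helper goes through XLA's NcclApi rather than raw pointers:

#include <cstdint>
#include <mutex>
#include <set>
#include <utility>

constexpr int64_t kCollectiveMemorySpace = 1;  // matches the colorer's color

// Sketch: register `ptr` with `comm` at most once, and only when the
// buffer lives in the collective memory space (i.e. came from
// ncclMemAlloc). `register_fn` stands in for a wrapper over
// ncclCommRegister.
void MaybeRegisterBuffer(void* comm, void* ptr, int64_t memory_space,
                         void (*register_fn)(void* comm, void* ptr)) {
  if (memory_space != kCollectiveMemorySpace) return;
  static std::mutex mu;
  static std::set<std::pair<void*, void*>> registered;
  std::lock_guard<std::mutex> lock(mu);
  if (registered.insert({comm, ptr}).second) {
    register_fn(comm, ptr);
  }
}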
@@ -173,6 +173,8 @@ absl::Status RunCollectivePermute(
   int device_ordinal = stream.parent()->device_ordinal();
   VLOG(3) << "Performing collective permute from device ordinal: "
           << device_ordinal << "current_id " << current_id;
+  TF_RETURN_IF_ERROR(
+      MaybeRegisterBuffers(nccl_api, device_ordinal, {buffer}, comm));
 
   const std::optional<int64_t> source_id = source_target.source;
   const std::optional<int64_t> target_id = source_target.target;
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/gpu/runtime/nccl_recv_thunk.cc
@@ -93,6 +93,8 @@ absl::Status NcclRecvThunk::RunNcclCollective(const ExecuteParams& params,
   int device_ordinal = stream.parent()->device_ordinal();
   VLOG(3) << "Performing Recv from device ordinal: " << device_ordinal
           << "current_id " << current_id;
+  TF_RETURN_IF_ERROR(
+      MaybeRegisterBuffers(nccl_api(), device_ordinal, {buffer}, comm));
 
   const std::optional<int64_t> source_id = source_target.source;
   se::DeviceMemoryBase dest_addr = buffer.destination_buffer;
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/gpu/runtime/nccl_send_thunk.cc
@@ -93,6 +93,8 @@ absl::Status NcclSendThunk::RunNcclCollective(const ExecuteParams& params,
   int device_ordinal = stream.parent()->device_ordinal();
   VLOG(3) << "Performing collective permute from device ordinal: "
           << device_ordinal << "current_id " << current_id;
+  TF_RETURN_IF_ERROR(
+      MaybeRegisterBuffers(nccl_api(), device_ordinal, {buffer}, comm));
 
   const std::optional<int64_t> target_id = source_target.target;
   se::DeviceMemoryBase src_addr = buffer.source_buffer;
