Skip to content

Commit

Permalink
PR #8874: [GPU] Use NCCL user buffers for collective permute and all-to-all
Browse files Browse the repository at this point in the history

Imported from GitHub PR openxla/xla#8874

This PR enables XLA to take advantage of NCCL user buffers for ncclSend/ncclRecv when `--xla_gpu_enable_nccl_user_buffers=true` is used. Requires NCCL 2.20

Copybara import of the project:

--
98acdf27d4eba6b19652a76d3f7dcd6630349fc5 by Trevor Morris <tmorris@nvidia.com>:

Use NCCL user buffers for ncclSend/ncclRecv ops

--
bcc289b49bcf2086b50a86a2381ea1b80acd3dd2 by Trevor Morris <tmorris@nvidia.com>:

Include memory space in buffers for collective permute and send/recv

--
4a83d8906b6b5e305dad23fc1d8b9a5069637279 by Trevor Morris <tmorris@nvidia.com>:

Don't offload send, recv

--
0083a418c4ab119ed5a0eb061113104980476943 by Trevor Morris <tmorris@nvidia.com>:

Fix conditional

Merging this change closes #8874

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#8874 from trevor-m:p2p-user-buffers 0083a418c4ab119ed5a0eb061113104980476943
PiperOrigin-RevId: 615104094
  • Loading branch information
trevor-m authored and tensorflower-gardener committed Mar 18, 2024
1 parent b960be6 commit 6d82b56
Show file tree
Hide file tree
Showing 13 changed files with 179 additions and 29 deletions.
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/lite/flatbuffer_export.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
return tflite::TensorType_FLOAT32;
} else if (type.isF16()) {
return tflite::TensorType_FLOAT16;
} else if (type.isBF16()) {
return tflite::TensorType_BFLOAT16;
} else if (type.isF64()) {
return tflite::TensorType_FLOAT64;
} else if (type.isa<mlir::TF::StringType>()) {
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/compiler/mlir/lite/ir/tfl_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -3926,10 +3926,10 @@ def TFL_CastOp : TFL_Op<"cast", [
}];

let arguments = (ins
TFL_TensorOf<[F16, F32, F64, I1, TFL_I4, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex<F<32>>]>:$input
TFL_TensorOf<[F16, BF16, F32, F64, I1, TFL_I4, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex<F<32>>]>:$input
);

let results = (outs TFL_TensorOf<[F16, F32, F64, I1, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex<F<32>>]>:$output);
let results = (outs TFL_TensorOf<[F16, BF16, F32, F64, I1, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex<F<32>>]>:$output);

// TFLite's cast op does not utilize CastOptions, instead derives types
// from the TfLiteTensors.
Expand Down
12 changes: 12 additions & 0 deletions tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/cast_bf16.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// Ensure casts involving bfloat16 round-trip exactly.
// The module is exported to a TFLite flatbuffer and imported back to MLIR;
// FileCheck then verifies that both tfl.cast ops (bf16 -> f32 and
// f32 -> bf16) reappear with their original element types intact.

func.func @main(tensor<4x5xbf16>) -> tensor<4x5xbf16> {
^bb0(%arg0: tensor<4x5xbf16>):
// CHECK-LABEL: @main
// CHECK: (tensor<4x5xbf16>) -> tensor<4x5xf32>
// CHECK-NEXT: (tensor<4x5xf32>) -> tensor<4x5xbf16>
%0 = "tfl.cast" (%arg0) : (tensor<4x5xbf16>) -> tensor<4x5xf32> loc("cast1")
%1 = "tfl.cast" (%0) : (tensor<4x5xf32>) -> tensor<4x5xbf16> loc("cast2")
func.return %1 : tensor<4x5xbf16>
}
12 changes: 12 additions & 0 deletions tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1875,6 +1875,18 @@ func.func @matmul_batchv3_unknown_dim(%arg0: tensor<?x10x15xf32>, %arg1: tensor<
// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<?x10x15xf32>, tensor<15x17xf32>) -> tensor<?x10x17xf32>
}

// Verifies legalization of the bf16 BatchMatMulV3 pattern (cast up to f32,
// batched matmul, cast back to bf16) when the batch dimension is dynamic:
// each tf op should lower to its tfl counterpart with types preserved.
func.func @matmul_batchv3_unknown_dim_bf16(%arg0: tensor<?x4x5xbf16>, %arg1: tensor<5x6xf32>) -> tensor<?x4x6xbf16> {
%0 = "tf.Cast"(%arg0) : (tensor<?x4x5xbf16>) -> tensor<?x4x5xf32>
%1 = "tf.BatchMatMulV3"(%0, %arg1) {Ta = "tfdtype$DT_FLOAT", Tb = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} :
(tensor<?x4x5xf32>, tensor<5x6xf32>) -> tensor<?x4x6xf32>
%2 = "tf.Cast"(%1) : (tensor<?x4x6xf32>) -> tensor<?x4x6xbf16>
func.return %2 : tensor<?x4x6xbf16>
// CHECK-LABEL: matmul_batchv3_unknown_dim_bf16
// CHECK: [[CST:%.*]] = "tfl.cast"(%arg0) : (tensor<?x4x5xbf16>) -> tensor<?x4x5xf32>
// CHECK: [[BMM:%.*]] = "tfl.batch_matmul"([[CST]], %arg1) {adj_x = false, adj_y = false} : (tensor<?x4x5xf32>, tensor<5x6xf32>) -> tensor<?x4x6xf32>
// CHECK: "tfl.cast"([[BMM]]) : (tensor<?x4x6xf32>) -> tensor<?x4x6xbf16>
}

// -----

func.func @select_v2_with_6d_broadcasting(%arg0: tensor<1x1x1x1x3x1xi1>, %arg1 : tensor<1x1x1x1x1x4xf32>, %arg2 : tensor<1x1x1x2x1x1xf32>) -> tensor<1x1x1x2x3x4xf32> {
Expand Down
74 changes: 74 additions & 0 deletions tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/cast_bf16.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
// Verifies flatbuffer serialization of tfl.cast with bfloat16 tensors:
// the bf16 tensors (arg0, cast2) must serialize with type BFLOAT16, and the
// CAST operator must be emitted at op version 7 — the version this change
// introduces for bf16 casts (see op_version.cc).

func.func @main(tensor<4x5xbf16>) -> tensor<4x5xbf16> {
^bb0(%arg0: tensor<4x5xbf16>):

// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: deprecated_builtin_code: 53,
// CHECK-NEXT: version: 7,
// CHECK-NEXT: builtin_code: CAST
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 4, 5 ],
// CHECK-NEXT: type: BFLOAT16,
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: },
// CHECK-NEXT: has_rank: true
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4, 5 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "cast1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: },
// CHECK-NEXT: has_rank: true
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4, 5 ],
// CHECK-NEXT: type: BFLOAT16,
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "cast2",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: },
// CHECK-NEXT: has_rank: true
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 1 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: inputs: [ 1 ],
// CHECK-NEXT: outputs: [ 2 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 4
// CHECK-NEXT: } ],
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: }

%0 = "tfl.cast" (%arg0) : (tensor<4x5xbf16>) -> tensor<4x5xf32> loc("cast1")
%1 = "tfl.cast" (%0) : (tensor<4x5xf32>) -> tensor<4x5xbf16> loc("cast2")
func.return %1 : tensor<4x5xbf16>
}
43 changes: 31 additions & 12 deletions tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -345,22 +345,41 @@ StatusOr<mlir::ElementsAttr> ConvertFloatBuffer(
switch (elem_type.getIntOrFloatBitWidth()) {
case 16: {
assert(bytes_len % 2 == 0);
assert(elem_type.isF16());
// Supports both BF16 and F16.
assert(elem_type.isF16() || elem_type.isBF16());
int elem_count = bytes_len / 2;
std::vector<Eigen::half> values;
values.reserve(elem_count);

const char* data = reinterpret_cast<const char*>(buffer.data());
if (elem_type.isF16()) {
std::vector<Eigen::half> values;
values.reserve(elem_count);

for (int i = 0; i < elem_count; i++) {
uint16_t bit_repr =
llvm::support::endian::readNext<uint16_t, llvm::endianness::native,
llvm::support::unaligned>(data);
values.push_back(Eigen::numext::bit_cast<Eigen::half>(bit_repr));
}
const char* data = reinterpret_cast<const char*>(buffer.data());

return mlir::ElementsAttr(
DenseElementsAttr::get(shaped_type, ArrayRef<Eigen::half>(values)));
for (int i = 0; i < elem_count; i++) {
uint16_t bit_repr = llvm::support::endian::readNext<
uint16_t, llvm::endianness::native, llvm::support::unaligned>(
data);
values.push_back(Eigen::numext::bit_cast<Eigen::half>(bit_repr));
}

return mlir::ElementsAttr(
DenseElementsAttr::get(shaped_type, ArrayRef<Eigen::half>(values)));
} else {
std::vector<Eigen::bfloat16> values;
values.reserve(elem_count);

const char* data = reinterpret_cast<const char*>(buffer.data());

for (int i = 0; i < elem_count; i++) {
uint16_t bit_repr = llvm::support::endian::readNext<
uint16_t, llvm::endianness::native, llvm::support::unaligned>(
data);
values.push_back(Eigen::numext::bit_cast<Eigen::bfloat16>(bit_repr));
}

return mlir::ElementsAttr(DenseElementsAttr::get(
shaped_type, ArrayRef<Eigen::bfloat16>(values)));
}
}
case 32: {
assert(bytes_len % 4 == 0);
Expand Down
7 changes: 5 additions & 2 deletions tensorflow/lite/tools/versioning/op_version.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1045,8 +1045,11 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
}
return 2;
case BuiltinOperator_CAST:
if (op_sig.inputs.at(0).type == kTfLiteInt4 &&
op_sig.outputs.at(0).type == kTfLiteFloat32) {
if (op_sig.inputs.at(0).type == kTfLiteBFloat16 ||
op_sig.outputs.at(0).type == kTfLiteBFloat16) {
return 7;
} else if (op_sig.inputs.at(0).type == kTfLiteInt4 &&
op_sig.outputs.at(0).type == kTfLiteFloat32) {
return 6;
} else if (op_sig.inputs.at(0).type == kTfLiteFloat64 ||
op_sig.outputs.at(0).type == kTfLiteFloat64 ||
Expand Down
26 changes: 17 additions & 9 deletions third_party/xla/xla/service/gpu/gpu_memory_space_assignment.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,28 @@ inline constexpr int64_t kCollectiveMemorySpaceColor = 1;
// collective memory using ncclMemAlloc in the runtime.
inline BufferAssigner::Colorer CollectiveColorer() {
return [](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
static const auto* kSupportedOpcodes = new absl::flat_hash_set<HloOpcode>{
HloOpcode::kAllReduce,
HloOpcode::kAllReduceStart,
HloOpcode::kAllReduceDone,
HloOpcode::kAllGather,
HloOpcode::kAllGatherStart,
HloOpcode::kAllGatherDone,
HloOpcode::kReduceScatter,
HloOpcode::kCollectivePermute,
HloOpcode::kCollectivePermuteStart,
HloOpcode::kCollectivePermuteDone,
HloOpcode::kAllToAll,
};
for (HloValue* value : alias_analysis->dataflow_analysis().values()) {
auto& buffer = alias_analysis->GetBufferContainingValue(*value);
for (const auto& alias : buffer.values()) {
if ((alias->instruction()->opcode() == HloOpcode::kAllReduce ||
alias->instruction()->opcode() == HloOpcode::kAllReduceStart ||
alias->instruction()->opcode() == HloOpcode::kAllReduceDone ||
alias->instruction()->opcode() == HloOpcode::kAllGather ||
alias->instruction()->opcode() == HloOpcode::kAllGatherStart ||
alias->instruction()->opcode() == HloOpcode::kAllGatherDone ||
alias->instruction()->opcode() == HloOpcode::kReduceScatter) ||
// opcode or async wrapped opcode is in kSupportedOpcodes.
if (kSupportedOpcodes->contains(alias->instruction()->opcode()) ||
((alias->instruction()->opcode() == HloOpcode::kAsyncStart ||
alias->instruction()->opcode() == HloOpcode::kAsyncDone) &&
alias->instruction()->async_wrapped_opcode() ==
HloOpcode::kReduceScatter)) {
kSupportedOpcodes->contains(
alias->instruction()->async_wrapped_opcode()))) {
value->set_color(kCollectiveMemorySpaceColor);
}
}
Expand Down
20 changes: 16 additions & 4 deletions third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2179,14 +2179,18 @@ Status IrEmitterUnnested::EmitCollectivePermute(
// First output is aliased.
TF_RET_CHECK(
instr->shape().IsTuple() && instr->shape().tuple_shapes_size() == 2 &&
instr->shape().tuple_shapes(0) == instr->shape().tuple_shapes(1));
Shape::Equal().IgnoreMemorySpaceInLayout()(
instr->shape().tuple_shapes(0), instr->shape().tuple_shapes(1)));
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
GetAllocationSliceForHlo(instr, {1}));

const Shape shape = operand->shape();
const auto& hlo_config = ir_emitter_context_->hlo_module().config();
const int64_t replica_count = hlo_config.replica_count();
const int64_t partition_count = hlo_config.num_partitions();
const int64_t src_memory_space = shape.layout().memory_space();
const int64_t dst_memory_space =
instr->shape().tuple_shapes(1).layout().memory_space();

if (NcclCollectivePermuteStartThunk::IsDegenerate(instr, replica_count,
partition_count)) {
Expand All @@ -2202,7 +2206,9 @@ Status IrEmitterUnnested::EmitCollectivePermute(
const NcclCollectiveThunk::Buffer buffer = {
/*element_count=*/ShapeUtil::ElementsIn(shape),
/*source_buffer=*/source_slice,
/*destination_buffer=*/result_slice};
/*destination_buffer=*/result_slice,
/*source_memory_space=*/src_memory_space,
/*destination_memory_space=*/dst_memory_space};
auto thunk = std::make_unique<NcclCollectivePermuteStartThunk>(
Thunk::ThunkInfo::WithProfileAnnotation(instr), NcclApi::Default(),
instr, replica_count, partition_count, buffer);
Expand Down Expand Up @@ -2619,10 +2625,13 @@ absl::Status IrEmitterUnnested::EmitSendThunk(const HloSendInstruction* instr) {
const auto& hlo_config = ir_emitter_context_->hlo_module().config();
const int64_t replica_count = hlo_config.replica_count();
const int64_t partition_count = hlo_config.num_partitions();
const int64_t memory_space = src->shape().layout().memory_space();
const NcclCollectiveThunk::Buffer nccl_buffer = {
/*element_count=*/ShapeUtil::ElementsIn(src->shape()),
/*source_buffer=*/buffer,
/*destination_buffer=*/buffer};
/*destination_buffer=*/buffer,
/*source_memory_space=*/memory_space,
/*destination_memory_space=*/memory_space};
auto thunk = std::make_unique<NcclSendThunk>(
Thunk::ThunkInfo::WithProfileAnnotation(instr), NcclApi::Default(),
instr, replica_count, partition_count, nccl_buffer);
Expand Down Expand Up @@ -2685,10 +2694,13 @@ absl::Status IrEmitterUnnested::EmitRecvThunk(const HloRecvInstruction* instr) {
const auto& hlo_config = ir_emitter_context_->hlo_module().config();
const int64_t replica_count = hlo_config.replica_count();
const int64_t partition_count = hlo_config.num_partitions();
const int64_t memory_space = instr->shape().layout().memory_space();
const NcclCollectiveThunk::Buffer nccl_buffer = {
/*element_count=*/ShapeUtil::ElementsIn(instr->shape().tuple_shapes(0)),
/*source_buffer=*/buffer,
/*destination_buffer=*/buffer};
/*destination_buffer=*/buffer,
/*source_memory_space=*/memory_space,
/*destination_memory_space=*/memory_space};
auto thunk = std::make_unique<NcclRecvThunk>(
Thunk::ThunkInfo::WithProfileAnnotation(instr), NcclApi::Default(),
instr, replica_count, partition_count, nccl_buffer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ absl::Status RunAllToAll(NcclApi* nccl_api, bool has_split_dimension,
se::Stream& stream, NcclApi::NcclCommHandle comm) {
int device_ordinal = stream.parent()->device_ordinal();
VLOG(3) << "Performing all-to-all from device ordinal: " << device_ordinal;
TF_RETURN_IF_ERROR(
MaybeRegisterBuffers(nccl_api, device_ordinal, buffers, comm));

TF_ASSIGN_OR_RETURN(int32_t num_participants, nccl_api->CommCount(comm));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ absl::Status RunCollectivePermute(
int device_ordinal = stream.parent()->device_ordinal();
VLOG(3) << "Performing collective permute from device ordinal: "
<< device_ordinal << "current_id " << current_id;
TF_RETURN_IF_ERROR(
MaybeRegisterBuffers(nccl_api, device_ordinal, {buffer}, comm));

const std::optional<int64_t> source_id = source_target.source;
const std::optional<int64_t> target_id = source_target.target;
Expand Down
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/gpu/runtime/nccl_recv_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ absl::Status NcclRecvThunk::RunNcclCollective(const ExecuteParams& params,
int device_ordinal = stream.parent()->device_ordinal();
VLOG(3) << "Performing Recv from device ordinal: " << device_ordinal
<< "current_id " << current_id;
TF_RETURN_IF_ERROR(
MaybeRegisterBuffers(nccl_api(), device_ordinal, {buffer}, comm));

const std::optional<int64_t> source_id = source_target.source;
se::DeviceMemoryBase dest_addr = buffer.destination_buffer;
Expand Down
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/gpu/runtime/nccl_send_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ absl::Status NcclSendThunk::RunNcclCollective(const ExecuteParams& params,
int device_ordinal = stream.parent()->device_ordinal();
VLOG(3) << "Performing collective permute from device ordinal: "
<< device_ordinal << "current_id " << current_id;
TF_RETURN_IF_ERROR(
MaybeRegisterBuffers(nccl_api(), device_ordinal, {buffer}, comm));

const std::optional<int64_t> target_id = source_target.target;
se::DeviceMemoryBase src_addr = buffer.source_buffer;
Expand Down

0 comments on commit 6d82b56

Please sign in to comment.