PR #8874: [GPU] Use NCCL user buffers for collective permute and all-to-all

Imported from GitHub PR openxla/xla#8874

This PR enables XLA to take advantage of NCCL user buffers for ncclSend/ncclRecv when `--xla_gpu_enable_nccl_user_buffers=true` is used. Requires NCCL 2.20.
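
For context, "NCCL user buffers" refers to NCCL's registered-buffer path: memory allocated with `ncclMemAlloc` and registered against a communicator via `ncclCommRegister` lets `ncclSend`/`ncclRecv` avoid staging data through NCCL's internal buffers. In XLA the flag is typically passed through the `XLA_FLAGS` environment variable, e.g. `XLA_FLAGS=--xla_gpu_enable_nccl_user_buffers=true`. The sketch below is illustrative only: the NCCL calls are real API (registration since NCCL 2.19, the send/recv path needs 2.20), but the wrapper function and its error handling are assumptions, not XLA code.

```cpp
// Illustrative sketch, not XLA source: allocate collective memory and register
// it with a communicator so later send/recv calls can use the user-buffer path.
#include <cstddef>

#include "nccl.h"

ncclResult_t AllocateRegisteredBuffer(ncclComm_t comm, size_t bytes,
                                      void** buffer, void** reg_handle) {
  // ncclMemAlloc returns memory with the backing NCCL needs for zero-copy
  // collectives; XLA's collective memory space is allocated this way.
  ncclResult_t res = ncclMemAlloc(buffer, bytes);
  if (res != ncclSuccess) return res;
  // Registration is per-communicator; pass reg_handle to ncclCommDeregister
  // when the buffer is released.
  return ncclCommRegister(comm, *buffer, bytes, reg_handle);
}
```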

Copybara import of the project:

--
98acdf27d4eba6b19652a76d3f7dcd6630349fc5 by Trevor Morris <tmorris@nvidia.com>:

Use NCCL user buffers for ncclSend/ncclRecv ops

--
bcc289b49bcf2086b50a86a2381ea1b80acd3dd2 by Trevor Morris <tmorris@nvidia.com>:

Include memory space in buffers for collective permute and send/recv

--
4a83d8906b6b5e305dad23fc1d8b9a5069637279 by Trevor Morris <tmorris@nvidia.com>:

Don't offload send, recv

--
0083a418c4ab119ed5a0eb061113104980476943 by Trevor Morris <tmorris@nvidia.com>:

Fix conditional

Merging this change closes #8874

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#8874 from trevor-m:p2p-user-buffers 0083a418c4ab119ed5a0eb061113104980476943
PiperOrigin-RevId: 615104094
trevor-m authored and tensorflower-gardener committed Mar 18, 2024
1 parent 1e3478b commit 5cd046b
Showing 13 changed files with 240 additions and 18 deletions.
1 change: 1 addition & 0 deletions tensorflow/core/kernels/batching_util/BUILD
@@ -190,6 +190,7 @@ tf_cc_test(
name = "shared_batch_scheduler_test",
size = "small",
srcs = ["shared_batch_scheduler_test.cc"],
tags = ["no_windows"],
deps = [
":batch_scheduler",
":fake_clock_env",
25 changes: 25 additions & 0 deletions tensorflow/python/compiler/xla/experimental/BUILD
@@ -1,4 +1,5 @@
load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test")
load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_strict_test")

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -40,3 +41,27 @@ py_strict_test(
"@absl_py//absl/testing:absltest",
],
)

tpu_py_strict_test(
name = "resource_variable_xla_sharding_test",
srcs = ["resource_variable_xla_sharding_test.py"],
disable_v3_4chips = False,
python_version = "PY3",
srcs_version = "PY3",
tags = ["requires-net:external"],
deps = [
":xla_sharding",
"//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",
"//tensorflow/python/framework:config",
"//tensorflow/python/framework:dtypes",
"//tensorflow/python/ops:array_ops",
"//tensorflow/python/ops:math_ops",
"//tensorflow/python/ops:variables",
"//tensorflow/python/tpu:device_assignment",
"//tensorflow/python/tpu:tpu_py",
"//tensorflow/python/training:adagrad",
],
)
136 changes: 136 additions & 0 deletions tensorflow/python/compiler/xla/experimental/resource_variable_xla_sharding_test.py
@@ -0,0 +1,136 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from tensorflow.python.compiler.xla.experimental import xla_sharding
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.tpu import device_assignment
from tensorflow.python.tpu import tpu
from tensorflow.python.training import adagrad


# Gets all the nodes of `op` in graph that have `input_node_name` as one of the
# inputs
def _get_op_nodes_with_input(input_node_name, op, graph):
nodes_with_input = []
for node in graph.node:
nodes_with_input += [
node
for input in node.input
if input == input_node_name and node.op == op
]
return nodes_with_input


# Gets XlaSharding ops connected to ReadVariableOp for the given variable_name
def _get_xla_sharding_nodes_for_variable(variable_name, graph):
read_variable_op_nodes = _get_op_nodes_with_input(
variable_name, 'ReadVariableOp', graph
)
xla_sharding_op_nodes = []
for read_variable_op_node in read_variable_op_nodes:
xla_sharding_op_nodes += _get_op_nodes_with_input(
read_variable_op_node.name, 'XlaSharding', graph
)
return xla_sharding_op_nodes


def _get_xla_sharding_proto_from_node(node):
sharding_proto = xla_sharding.xla_data_pb2.OpSharding()
sharding_proto.ParseFromString(node.attr['sharding'].s)
return sharding_proto


class ResourceVariableXlaShardingTest(test.TestCase):

def setUp(self) -> None:
super().setUp()

context.enable_xla_sharding_for_resource_variables()
self.topology = tpu_cluster_resolver.initialize_tpu_system()
if len(config.list_logical_devices('TPU')) != 8:
self.skipTest('All tests require 8 TPUs.')

self.da = device_assignment.DeviceAssignment.build(
self.topology, computation_shape=[2, 2, 1, 2], num_replicas=1
)

def test_xla_sharding_ops_created_for_optimizer_slot_variables(self):
w = variables.Variable(
initial_value=math_ops.range(8, dtype=dtypes.float32),
name='w',
)
self.assertIsInstance(w, resource_variable_ops.BaseResourceVariable)
w = xla_sharding.split(
w,
split_dimension=0,
num_devices=8,
)
sharding_proto = xla_sharding.xla_data_pb2.OpSharding()
sharding_proto.ParseFromString(xla_sharding.get_tensor_sharding(w))
opt = adagrad.AdagradOptimizer(1.0)

@def_function.function
def computation(x):
def tpu_fn(x):
y = math_ops.add(w, x)
loss = math_ops.reduce_sum(y)
opt.minimize(loss, None, [w])
return loss

output = tpu.replicate(tpu_fn, [[x]], device_assignment=self.da)
return output

inputs = array_ops.reshape(math_ops.range(16, dtype=dtypes.float32), (2, 8))
result = computation(inputs)
self.assertSequenceEqual([[176.0]], self.evaluate(result))
graph = computation.get_concrete_function(inputs).graph.as_graph_def()

update_op_nodes = [
node for node in graph.node if node.op == 'ResourceApplyAdagrad'
]
self.assertLen(update_op_nodes, 1)
update_op_node = update_op_nodes[0]

var_input_name = update_op_node.input[0]
var_sharding_nodes = _get_xla_sharding_nodes_for_variable(
var_input_name, graph
)
self.assertLen(var_sharding_nodes, 1)
self.assertProtoEquals(
_get_xla_sharding_proto_from_node(var_sharding_nodes[0]), sharding_proto
)

slot_var_input_name = update_op_node.input[1]
slot_var_sharding_nodes = _get_xla_sharding_nodes_for_variable(
slot_var_input_name, graph
)
self.assertLen(slot_var_sharding_nodes, 1)
self.assertProtoEquals(
_get_xla_sharding_proto_from_node(slot_var_sharding_nodes[0]),
sharding_proto,
)


if __name__ == '__main__':
test.main()
25 changes: 24 additions & 1 deletion tensorflow/python/training/optimizer.py
@@ -169,7 +169,30 @@ def update_op(self, optimizer, g):
"Cannot use a constraint function on a sparse variable.")
return optimizer._resource_apply_sparse_duplicate_indices(
g.values, self._v, g.indices)
update_op = optimizer._resource_apply_dense(g, self._v)

if context.xla_sharding_for_resource_variables_enabled():
# For each slot variable that is annotated with an XLA sharding, we read
# the variable and assign the value to itself. This is done to trigger the
# creation of an XlaShardingOp when a ReadVariableOp is created upon the
# call to `slot_var.read_value()`. This is needed to ensure that slot
# variables with XLA sharding are sharded correctly. Please see
# b/307541427 for more details.
assign_ops = []
for variable_dict in optimizer._slots.values():
for slot_var in variable_dict.values():
if (
isinstance(slot_var, resource_variable_ops.BaseResourceVariable)
and slot_var._get_xla_sharding() is not None
):
assign_ops.append(slot_var.assign(slot_var.read_value()))

# The assign_ops created above are added as a control dependency for the
# update op to make sure these appear before the update_op.
with ops.control_dependencies(assign_ops):
update_op = optimizer._resource_apply_dense(g, self._v)
else:
update_op = optimizer._resource_apply_dense(g, self._v)

if self._v.constraint is not None:
with ops.control_dependencies([update_op]):
return self._v.assign(self._v.constraint(self._v))
4 changes: 3 additions & 1 deletion third_party/xla/xla/hlo/ir/hlo_computation.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/functional/function_ref.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
@@ -506,6 +507,7 @@ void HloComputation::ForEachInstructionPostOrderImpl(
absl::FunctionRef<void(HloInstruction*)> func, HloInstruction* root,
const ChannelDependencies& channel_dependencies, VisitMap& visited,
std::vector<HloInstruction*>* dfs_stack_scratch) const {
bool has_channel_dependencies = !channel_dependencies.empty();
auto* dfs_stack = dfs_stack_scratch;
dfs_stack->clear();
dfs_stack->push_back(root);
@@ -532,7 +534,7 @@ void HloComputation::ForEachInstructionPostOrderImpl(
// Collectives with the same channel ID must be performed together, as these
// represent MPMD-partitioned computations that will later be split into separate modules
// and the order must be preserved.
if (&current != root) {
if (has_channel_dependencies && &current != root) {
auto it = channel_dependencies.find(&current);
if (it != channel_dependencies.end()) {
dfs_stack->insert(dfs_stack->end(), it->second.begin(),
7 changes: 4 additions & 3 deletions third_party/xla/xla/hlo/ir/hlo_dfs_reachability.cc
@@ -96,10 +96,11 @@ std::unique_ptr<HloDfsReachability> HloDfsReachability::Build(
const HloComputation* computation) {
auto res = std::make_unique<HloDfsReachability>();

HloComputation::ChannelDependencies channel_dependencies =
computation->ComputeChannelDependencies();
// For instruction reachability we do not care about correct order of
// collective operations as we only care about use-def chains.
HloComputation::ChannelDependencies empty_channel_dependencies;
std::vector<HloInstruction*> instructions =
computation->MakeInstructionPostOrder(channel_dependencies);
computation->MakeInstructionPostOrder(empty_channel_dependencies);

res->instruction_to_idx_.reserve(instructions.size());
for (size_t i = 0; i < instructions.size(); ++i) {
6 changes: 6 additions & 0 deletions third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc
@@ -186,6 +186,12 @@ class GemmAutotuner {
-> absl::StatusOr<se::blas::ProfileResult> {
se::OwningScratchAllocator<> scratch_allocator(
stream_->parent()->device_ordinal(), autotune_config_.GetAllocator());
// Run a warmup iteration without the profiler active.
TF_RETURN_IF_ERROR(plan->ExecuteOnStream(
stream_, lhs_buffer_, rhs_buffer_, output_buffer_, output_buffer_,
bias_buffer, aux_buffer, a_scale_buffer, b_scale_buffer,
c_scale_buffer, d_scale_buffer, d_amax_buffer, algorithm,
scratch_allocator));
se::blas::ProfileResult profile_result;
TF_RETURN_IF_ERROR(plan->ExecuteOnStream(
stream_, lhs_buffer_, rhs_buffer_, output_buffer_, output_buffer_,
26 changes: 17 additions & 9 deletions third_party/xla/xla/service/gpu/gpu_memory_space_assignment.h
@@ -35,20 +35,28 @@ inline constexpr int64_t kCollectiveMemorySpaceColor = 1;
// collective memory using ncclMemAlloc in the runtime.
inline BufferAssigner::Colorer CollectiveColorer() {
return [](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
static const auto* kSupportedOpcodes = new absl::flat_hash_set<HloOpcode>{
HloOpcode::kAllReduce,
HloOpcode::kAllReduceStart,
HloOpcode::kAllReduceDone,
HloOpcode::kAllGather,
HloOpcode::kAllGatherStart,
HloOpcode::kAllGatherDone,
HloOpcode::kReduceScatter,
HloOpcode::kCollectivePermute,
HloOpcode::kCollectivePermuteStart,
HloOpcode::kCollectivePermuteDone,
HloOpcode::kAllToAll,
};
for (HloValue* value : alias_analysis->dataflow_analysis().values()) {
auto& buffer = alias_analysis->GetBufferContainingValue(*value);
for (const auto& alias : buffer.values()) {
if ((alias->instruction()->opcode() == HloOpcode::kAllReduce ||
alias->instruction()->opcode() == HloOpcode::kAllReduceStart ||
alias->instruction()->opcode() == HloOpcode::kAllReduceDone ||
alias->instruction()->opcode() == HloOpcode::kAllGather ||
alias->instruction()->opcode() == HloOpcode::kAllGatherStart ||
alias->instruction()->opcode() == HloOpcode::kAllGatherDone ||
alias->instruction()->opcode() == HloOpcode::kReduceScatter) ||
// opcode or async wrapped opcode is in kSupportedOpcodes.
if (kSupportedOpcodes->contains(alias->instruction()->opcode()) ||
((alias->instruction()->opcode() == HloOpcode::kAsyncStart ||
alias->instruction()->opcode() == HloOpcode::kAsyncDone) &&
alias->instruction()->async_wrapped_opcode() ==
HloOpcode::kReduceScatter)) {
kSupportedOpcodes->contains(
alias->instruction()->async_wrapped_opcode()))) {
value->set_color(kCollectiveMemorySpaceColor);
}
}
20 changes: 16 additions & 4 deletions third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
@@ -2179,14 +2179,18 @@ Status IrEmitterUnnested::EmitCollectivePermute(
// First output is aliased.
TF_RET_CHECK(
instr->shape().IsTuple() && instr->shape().tuple_shapes_size() == 2 &&
instr->shape().tuple_shapes(0) == instr->shape().tuple_shapes(1));
Shape::Equal().IgnoreMemorySpaceInLayout()(
instr->shape().tuple_shapes(0), instr->shape().tuple_shapes(1)));
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
GetAllocationSliceForHlo(instr, {1}));

const Shape shape = operand->shape();
const auto& hlo_config = ir_emitter_context_->hlo_module().config();
const int64_t replica_count = hlo_config.replica_count();
const int64_t partition_count = hlo_config.num_partitions();
const int64_t src_memory_space = shape.layout().memory_space();
const int64_t dst_memory_space =
instr->shape().tuple_shapes(1).layout().memory_space();

if (NcclCollectivePermuteStartThunk::IsDegenerate(instr, replica_count,
partition_count)) {
@@ -2202,7 +2206,9 @@ Status IrEmitterUnnested::EmitCollectivePermute(
const NcclCollectiveThunk::Buffer buffer = {
/*element_count=*/ShapeUtil::ElementsIn(shape),
/*source_buffer=*/source_slice,
/*destination_buffer=*/result_slice};
/*destination_buffer=*/result_slice,
/*source_memory_space=*/src_memory_space,
/*destination_memory_space=*/dst_memory_space};
auto thunk = std::make_unique<NcclCollectivePermuteStartThunk>(
Thunk::ThunkInfo::WithProfileAnnotation(instr), NcclApi::Default(),
instr, replica_count, partition_count, buffer);
@@ -2619,10 +2625,13 @@ absl::Status IrEmitterUnnested::EmitSendThunk(const HloSendInstruction* instr) {
const auto& hlo_config = ir_emitter_context_->hlo_module().config();
const int64_t replica_count = hlo_config.replica_count();
const int64_t partition_count = hlo_config.num_partitions();
const int64_t memory_space = src->shape().layout().memory_space();
const NcclCollectiveThunk::Buffer nccl_buffer = {
/*element_count=*/ShapeUtil::ElementsIn(src->shape()),
/*source_buffer=*/buffer,
/*destination_buffer=*/buffer};
/*destination_buffer=*/buffer,
/*source_memory_space=*/memory_space,
/*destination_memory_space=*/memory_space};
auto thunk = std::make_unique<NcclSendThunk>(
Thunk::ThunkInfo::WithProfileAnnotation(instr), NcclApi::Default(),
instr, replica_count, partition_count, nccl_buffer);
@@ -2685,10 +2694,13 @@ absl::Status IrEmitterUnnested::EmitRecvThunk(const HloRecvInstruction* instr) {
const auto& hlo_config = ir_emitter_context_->hlo_module().config();
const int64_t replica_count = hlo_config.replica_count();
const int64_t partition_count = hlo_config.num_partitions();
const int64_t memory_space = instr->shape().layout().memory_space();
const NcclCollectiveThunk::Buffer nccl_buffer = {
/*element_count=*/ShapeUtil::ElementsIn(instr->shape().tuple_shapes(0)),
/*source_buffer=*/buffer,
/*destination_buffer=*/buffer};
/*destination_buffer=*/buffer,
/*source_memory_space=*/memory_space,
/*destination_memory_space=*/memory_space};
auto thunk = std::make_unique<NcclRecvThunk>(
Thunk::ThunkInfo::WithProfileAnnotation(instr), NcclApi::Default(),
instr, replica_count, partition_count, nccl_buffer);
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc
@@ -108,6 +108,8 @@ absl::Status RunAllToAll(NcclApi* nccl_api, bool has_split_dimension,
se::Stream& stream, NcclApi::NcclCommHandle comm) {
int device_ordinal = stream.parent()->device_ordinal();
VLOG(3) << "Performing all-to-all from device ordinal: " << device_ordinal;
TF_RETURN_IF_ERROR(
MaybeRegisterBuffers(nccl_api, device_ordinal, buffers, comm));

TF_ASSIGN_OR_RETURN(int32_t num_participants, nccl_api->CommCount(comm));
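
The reason registering buffers for `ncclSend`/`ncclRecv` also speeds up all-to-all and collective-permute is that XLA's GPU runtime implements both as grouped point-to-point transfers, as the thunk changes in this commit suggest. Below is a stand-alone sketch of that pattern; the NCCL group/send/recv calls are real API, but the chunk layout and loop are a simplified assumption, not XLA's actual `RunAllToAll`.

```cpp
// Illustrative sketch, not XLA source: all-to-all as grouped send/recv of
// equal-sized per-peer chunks on a single communicator.
#include <cstddef>

#include <cuda_runtime.h>
#include "nccl.h"

ncclResult_t SimpleAllToAll(const void* send, void* recv, size_t bytes_per_peer,
                            int nranks, ncclComm_t comm, cudaStream_t stream) {
  ncclResult_t res = ncclGroupStart();
  if (res != ncclSuccess) return res;
  for (int peer = 0; peer < nranks; ++peer) {
    const char* src = static_cast<const char*>(send) + peer * bytes_per_peer;
    char* dst = static_cast<char*>(recv) + peer * bytes_per_peer;
    // If send/recv were registered via ncclCommRegister, these point-to-point
    // calls can take NCCL's zero-copy user-buffer path (NCCL 2.20+).
    if ((res = ncclSend(src, bytes_per_peer, ncclInt8, peer, comm, stream)) !=
        ncclSuccess) {
      return res;
    }
    if ((res = ncclRecv(dst, bytes_per_peer, ncclInt8, peer, comm, stream)) !=
        ncclSuccess) {
      return res;
    }
  }
  return ncclGroupEnd();
}
```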

2 changes: 2 additions & 0 deletions third_party/xla/xla/service/gpu/runtime/nccl_collective_permute_thunk.cc
@@ -173,6 +173,8 @@ absl::Status RunCollectivePermute(
int device_ordinal = stream.parent()->device_ordinal();
VLOG(3) << "Performing collective permute from device ordinal: "
<< device_ordinal << "current_id " << current_id;
TF_RETURN_IF_ERROR(
MaybeRegisterBuffers(nccl_api, device_ordinal, {buffer}, comm));

const std::optional<int64_t> source_id = source_target.source;
const std::optional<int64_t> target_id = source_target.target;
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/gpu/runtime/nccl_recv_thunk.cc
@@ -93,6 +93,8 @@ absl::Status NcclRecvThunk::RunNcclCollective(const ExecuteParams& params,
int device_ordinal = stream.parent()->device_ordinal();
VLOG(3) << "Performing Recv from device ordinal: " << device_ordinal
<< "current_id " << current_id;
TF_RETURN_IF_ERROR(
MaybeRegisterBuffers(nccl_api(), device_ordinal, {buffer}, comm));

const std::optional<int64_t> source_id = source_target.source;
se::DeviceMemoryBase dest_addr = buffer.destination_buffer;
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/gpu/runtime/nccl_send_thunk.cc
@@ -93,6 +93,8 @@ absl::Status NcclSendThunk::RunNcclCollective(const ExecuteParams& params,
int device_ordinal = stream.parent()->device_ordinal();
VLOG(3) << "Performing collective permute from device ordinal: "
<< device_ordinal << "current_id " << current_id;
TF_RETURN_IF_ERROR(
MaybeRegisterBuffers(nccl_api(), device_ordinal, {buffer}, comm));

const std::optional<int64_t> target_id = source_target.target;
se::DeviceMemoryBase src_addr = buffer.source_buffer;
