[Multi-host GPU] Integrate GPU topology into PjRtClient for multi-host GPU support #68444

Draft · wants to merge 1 commit into base: master
1 change: 1 addition & 0 deletions tensorflow/core/kernels/BUILD
@@ -65,6 +65,7 @@ package_group(
packages = [
"//tensorflow/...",
"//tensorflow_text/...",
"//waymo/ml/compiler/frontend/kernels/...",
"//waymo/onboard/ml/...",
],
)
63 changes: 42 additions & 21 deletions tensorflow/core/kernels/scatter_nd_op.cc
@@ -878,7 +878,7 @@ class IndexFlattener {
namespace {

template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp Op>
scatter_nd_op::UpdateOp Op, bool kDropBadIndices>
Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape,
Tensor* out, bool allocate) {
@@ -925,7 +925,11 @@ Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices,
for (int i = 0; i < IXDIM; ++i) { \
output_shape_prefix[i] = shape.dim_size(i); \
} \
functor::ScatterNdFunctor<Device, T, Index, Op, IXDIM> functor; \
constexpr bool kShallDropBadIndices = \
kDropBadIndices || std::is_same<Device, GPUDevice>::value; \
functor::ScatterNdFunctor<Device, T, Index, Op, IXDIM, \
kShallDropBadIndices> \
functor; \
bad_i = \
functor(c->eigen_device<Device>(), slice_size, output_shape_prefix, \
output_matrix, indices_flat, updates_flat, output_matrix); \
@@ -947,6 +951,9 @@ Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices,
slice_dim);
}
}
if constexpr (kDropBadIndices) {
return absl::OkStatus();
}
if (bad_i >= 0) {
auto slice_shape = indices.shape();
slice_shape.RemoveLastDims(1);
@@ -970,7 +977,8 @@ Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
// back to GPU. This is useful because the CPU implementation is deterministic
// and the GPU implementation is not. Tensor inputs to this function must be on
// the GPU.
template <typename T, typename Index, scatter_nd_op::UpdateOp Op>
template <typename T, typename Index, scatter_nd_op::UpdateOp Op,
bool kDropBadIndices>
Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape,
Tensor* out, bool allocate) {
@@ -1015,7 +1023,7 @@ Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
}

TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
TF_RETURN_IF_ERROR(DoScatterNd<CPUDevice, T, Index, Op>(
TF_RETURN_IF_ERROR(DoScatterNd<CPUDevice, T, Index, Op, kDropBadIndices>(
c, host_indices, host_updates, shape, &host_out, /*allocate=*/false));

// Copy 'host_out' to device.
@@ -1033,44 +1041,57 @@ Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
} // namespace

template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp Op>
scatter_nd_op::UpdateOp Op, bool kDropBadIndices>
Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape, Tensor* out,
bool allocate) {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
if (std::is_same<Device, GPUDevice>::value &&
tensorflow::OpDeterminismRequired() && !DisableScatterOpDeterminism()) {
return DoScatterNdOnCpu<T, Index, Op>(c, indices, updates, shape, out,
allocate);
return DoScatterNdOnCpu<T, Index, Op, kDropBadIndices>(
c, indices, updates, shape, out, allocate);
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Run on the CPU for integer types, since the GPU implementation uses
// atomics, which are not supported for all integer types.
if constexpr (std::is_same<Device, GPUDevice>::value &&
std::is_integral<T>::value) {
return DoScatterNdOnCpu<T, Index, Op>(c, indices, updates, shape, out,
allocate);
return DoScatterNdOnCpu<T, Index, Op, kDropBadIndices>(
c, indices, updates, shape, out, allocate);
} else {
return DoScatterNdImpl<Device, T, Index, Op>(c, indices, updates, shape,
out, allocate);
return DoScatterNdImpl<Device, T, Index, Op, kDropBadIndices>(
c, indices, updates, shape, out, allocate);
}
}
} // namespace functor

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \
template <> \
Index ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>::operator()( \
const GPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput); \
extern template struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>;
#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \
template <> \
Index \
ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, /*kDropBadIndices=*/true>:: \
operator()(const GPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput); \
extern template struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, \
/*kDropBadIndices=*/true>; \
template <> \
Index ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, \
/*kDropBadIndices=*/false>:: \
operator()(const GPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput); \
extern template struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, \
/*kDropBadIndices=*/false>;

#define DECLARE_GPU_SPECS_INDEX_OP(T, Index, op) \
DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 1); \
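For readers of the hunks above: the new `kDropBadIndices` template parameter defaults to `false` (see `scatter_nd_op.h` below), so existing call sites keep their error-reporting behavior, while a caller can explicitly opt into silently skipping out-of-range rows. The following is a minimal sketch of such a call site, not code from this change: the helper name and the `CPUDevice` alias are illustrative, it assumes `indices`/`updates`/`shape` were already validated by the surrounding kernel, and whether a given combination links depends on the explicit instantiations in `scatter_nd_op.cc`.

```cpp
#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/scatter_nd_op.h"

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;  // the kernels' usual alias

// Illustrative helper: scatter-add `updates` into a newly allocated output of
// shape `shape`, skipping (rather than erroring on) out-of-range indices.
Status ScatterAddDroppingBadIndices(OpKernelContext* ctx,
                                    const Tensor& indices,
                                    const Tensor& updates,
                                    const TensorShape& shape, Tensor* out) {
  // Omitting the last template argument keeps the pre-existing behavior,
  // since kDropBadIndices defaults to false.
  return functor::DoScatterNd<CPUDevice, float, int32,
                              scatter_nd_op::UpdateOp::ADD,
                              /*kDropBadIndices=*/true>(
      ctx, indices, updates, shape, out, /*allocate=*/true);
}

}  // namespace tensorflow
```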
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/scatter_nd_op.h
@@ -44,7 +44,7 @@ namespace functor {

// Functor used by ScatterOp to do the computations.
template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp op, int IXDIM>
scatter_nd_op::UpdateOp op, int IXDIM, bool kDropBadIndices>
struct ScatterNdFunctor {
// Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index.
Index operator()(
@@ -63,7 +63,7 @@ struct ScatterNdFunctor {
// right type (T) and shape. This tensor will not be zeroed out
// before the scatter is executed.
template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp Op>
scatter_nd_op::UpdateOp Op, bool kDropBadIndices = false>
Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape, Tensor* out,
bool allocate);
52 changes: 32 additions & 20 deletions tensorflow/core/kernels/scatter_nd_op_cpu_impl.h
@@ -103,8 +103,9 @@ class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::MAX> {
namespace functor {

// Implementation of update functor for CPU.
template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM>
struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM> {
template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM,
bool kDropBadIndices>
struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM, kDropBadIndices> {
Index operator()(
const CPUDevice& d, const Index slice_size,
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
@@ -136,33 +137,44 @@ struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM> {
i += ix_d * batch_strides[dim];
}
if (TF_PREDICT_FALSE(out_of_bounds)) {
if constexpr (kDropBadIndices) {
continue;
}
error_loc = loc;
break;
} else {
auto input_chip = Toutput.template chip<0>(i);
auto output_chip = input_chip;
auto update_chip = Tupdates.template chip<0>(loc);
update_executor::UpdateExecutor<
CPUDevice, decltype(input_chip), decltype(update_chip),
decltype(output_chip), OP>::Execute(d, input_chip, update_chip,
output_chip);
}
auto input_chip = Toutput.template chip<0>(i);
auto output_chip = input_chip;
auto update_chip = Tupdates.template chip<0>(loc);
update_executor::UpdateExecutor<
CPUDevice, decltype(input_chip), decltype(update_chip),
decltype(output_chip), OP>::Execute(d, input_chip, update_chip,
output_chip);
}

return error_loc;
}
};

#define REGISTER_SCATTER_ND_FULL(T, Index, op) \
template Index \
ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM>::operator()( \
const CPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM> \
output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput)
#define REGISTER_SCATTER_ND_FULL(T, Index, op) \
template Index ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM, \
/*kDropBadIndices=*/false>:: \
operator()(const CPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM> \
output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput); \
template Index ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM, \
/*kDropBadIndices=*/true>:: \
operator()(const CPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM> \
output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput)

#define REGISTER_SCATTER_ND_INDEX(type, op) \
REGISTER_SCATTER_ND_FULL(type, int32, op); \
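To make the two modes concrete, here is a standalone C++17 illustration (plain code, not the TensorFlow functor itself) of the contract the CPU functor follows for an out-of-bounds row: with `kDropBadIndices == false` it returns the position of the first bad row so the caller can raise an error, and with `kDropBadIndices == true` the row is skipped and the remaining updates still apply. The `ScatterAdd` helper is purely illustrative.

```cpp
#include <cstdio>
#include <vector>

// Mirrors ScatterNdFunctor's convention: returns -1 on success, or the
// position of the first out-of-bounds index when reporting is requested.
template <bool kDropBadIndices>
int ScatterAdd(std::vector<float>& out, const std::vector<int>& indices,
               const std::vector<float>& updates) {
  for (int loc = 0; loc < static_cast<int>(indices.size()); ++loc) {
    const int i = indices[loc];
    if (i < 0 || i >= static_cast<int>(out.size())) {
      if constexpr (kDropBadIndices) continue;  // silently skip this row
      return loc;                               // report the offending row
    }
    out[i] += updates[loc];
  }
  return -1;
}

int main() {
  std::vector<float> a(4, 0.f), b(4, 0.f);
  const std::vector<int> indices = {0, 7, 2};  // 7 is out of bounds
  const std::vector<float> updates = {1.f, 2.f, 3.f};
  std::printf("report mode -> bad_i = %d\n", ScatterAdd<false>(a, indices, updates));
  std::printf("drop mode   -> bad_i = %d\n", ScatterAdd<true>(b, indices, updates));
  // Report mode stops at the bad row (a == {1, 0, 0, 0});
  // drop mode applies all valid rows (b == {1, 0, 3, 0}).
}
```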
10 changes: 6 additions & 4 deletions tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc
@@ -124,8 +124,9 @@ __global__ void ScatterNdOpKernel(
namespace functor {

// Functor used by ScatterOp to do the computations.
template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>
struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM> {
template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM,
bool kDropBadIndices>
struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, kDropBadIndices> {
Index operator()(
const GPUDevice& d, const Index slice_size,
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
@@ -164,8 +165,9 @@ struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM> {

} // namespace functor

#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \
template struct functor::ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>;
#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \
template struct functor::ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, \
/*kDropBadIndices=*/true>;

#define DECLARE_GPU_SPECS_INDEX_OP(T, Index, op) \
DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 1); \
6 changes: 2 additions & 4 deletions tensorflow/core/profiler/convert/xplane_to_op_stats.cc
@@ -240,11 +240,9 @@ OpStats ConvertXSpaceToOpStats(const XSpace& space,
*op_stats.mutable_host_op_metrics_db() =
ConvertHostThreadsXPlaneToOpMetricsDb(*host_plane);
}
if (options.generate_step_db) {
const StepEvents* device_step_events =
has_device ? &step_events : nullptr;
if (options.generate_step_db && !has_device) {
StepEvents host_step_events =
ConvertHostThreadsXPlaneToStepEvents(*host_plane, device_step_events);
ConvertHostThreadsXPlaneToStepEvents(*host_plane, nullptr);
CombineStepEvents(host_step_events, &step_events);
}
XPlaneVisitor visitor = tsl::profiler::CreateTfXPlaneVisitor(host_plane);
@@ -212,7 +212,7 @@ TEST(ConvertXPlaneToOpStats, GpuStepDbTest) {
options, &op_stats));
const StepDatabaseResult& step_db = op_stats.step_db();

EXPECT_EQ(step_db.step_sequence_size(), 1);
EXPECT_EQ(step_db.step_sequence_size(), 0);

PrecisionStats precision_stats =
op_stats.device_op_metrics_db().precision_stats();
4 changes: 2 additions & 2 deletions tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc
@@ -310,8 +310,8 @@ AotCompileToGpuPjRtExecutable(
xla::Compiler::TargetConfig gpu_config(gpu_target_config);
xla::StreamExecutorGpuCompiler pjrt_gpu_compiler;
// Create a trivial topology, which won't be used.
xla::StreamExecutorGpuTopologyDescription topology(
xla::CudaId(), xla::CudaName(), "fake_device", {0});
xla::StreamExecutorGpuTopologyDescription topology(xla::CudaId(),
xla::CudaName(), nullptr);
xla::CompileOptions pjrt_options =
GetPjRtCompileOptions(options, **compilation_result);
pjrt_options.target_config = gpu_config;
1 change: 1 addition & 0 deletions third_party/xla/xla/pjrt/c/BUILD
@@ -237,6 +237,7 @@ cc_library(
"//xla/pjrt:pjrt_compiler",
"//xla/pjrt:pjrt_device_description",
"//xla/pjrt/gpu:gpu_helpers",
"//xla/pjrt/gpu:gpu_topology",
"//xla/pjrt/gpu:se_gpu_pjrt_client",
"//xla/pjrt/gpu:se_gpu_pjrt_compiler", # To register GPU AOT compiler
"//xla/python:custom_partition_callback",
10 changes: 9 additions & 1 deletion third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc
@@ -40,6 +40,7 @@ limitations under the License.
#include "xla/pjrt/c/pjrt_c_api_stream_extension.h"
#include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
#include "xla/pjrt/gpu/gpu_helpers.h"
#include "xla/pjrt/gpu/gpu_topology.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_common.h"
@@ -175,9 +176,16 @@ PJRT_Error* PJRT_GpuDeviceTopology_Create(
device_ids.push_back(executor->device_ordinal());
}
auto gpu_target_config = xla::Compiler::TargetConfig(executor);
// TODO(b/341334898): Create a single-host GPU topology. Will be updated for
// multi-host support in the future.
auto gpu_topology = std::make_shared<const xla::GpuTopology>(
device_ids, description.name(),
/*num_slices=*/1,
/*num_hosts_per_slice=*/1,
/*num_devices_per_host=*/device_ids.size());
auto pjrt_topology =
std::make_unique<xla::StreamExecutorGpuTopologyDescription>(
xla::CudaId(), xla::CudaName(), description.name(), device_ids,
xla::CudaId(), xla::CudaName(), std::move(gpu_topology),
absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute>{
{"target_config",
gpu_target_config.ToProto().SerializeAsString()}});
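The hunk above wires a single-host `GpuTopology` (one slice, one host) into the topology description, with multi-host support left as a TODO. As a hedged sketch of what a multi-host description could look like using the same constructors exercised in this diff, the snippet below builds a 1-slice, 2-host, 8-devices-per-host topology; the function name, device-id values, and the "illustrative-gpu" string are assumptions for illustration only and are not part of this change.

```cpp
#include <memory>
#include <numeric>
#include <vector>

#include "xla/pjrt/gpu/gpu_topology.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"

// Illustrative only: describe 1 slice x 2 hosts x 8 devices per host.
std::unique_ptr<xla::StreamExecutorGpuTopologyDescription>
MakeTwoHostTopologySketch() {
  std::vector<int> device_ids(16);
  std::iota(device_ids.begin(), device_ids.end(), 0);  // global ids 0..15
  auto gpu_topology = std::make_shared<const xla::GpuTopology>(
      device_ids, "illustrative-gpu",  // the slot that takes description.name() above
      /*num_slices=*/1,
      /*num_hosts_per_slice=*/2,
      /*num_devices_per_host=*/8);
  // The attributes map is optional, as in the saved_model_aot_compile.cc hunk above.
  return std::make_unique<xla::StreamExecutorGpuTopologyDescription>(
      xla::CudaId(), xla::CudaName(), std::move(gpu_topology));
}
```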
2 changes: 2 additions & 0 deletions third_party/xla/xla/pjrt/gpu/BUILD
@@ -43,6 +43,7 @@ cc_library(
":gpu_helpers",
":gpu_metrics",
":gpu_topology",
":gpu_topology_proto_cc",
"//xla:literal",
"//xla:shape_util",
"//xla:status",
@@ -297,6 +298,7 @@ xla_cc_test(
"requires-gpu-nvidia",
] + if_google(["config-cuda-only"]),
deps = [
":gpu_topology",
":se_gpu_pjrt_client",
":se_gpu_pjrt_compiler",
"//xla:test",