[Multi-host GPU] Integrate GPU topology into PjRtClient.
PiperOrigin-RevId: 633640627
tensorflower-gardener committed May 16, 2024
1 parent a6f1278 commit a23440b
Showing 13 changed files with 91 additions and 47 deletions.
5 changes: 3 additions & 2 deletions ci/official/utilities/setup_docker.sh
@@ -14,11 +14,12 @@
# limitations under the License.
# ==============================================================================
if [[ "$TFCI_DOCKER_PULL_ENABLE" == 1 ]]; then
# Simple retry logic for docker-pull errors. Sleeps for 15s if a pull fails.
# Simple retry logic for docker-pull errors. Sleeps if a pull fails.
# Pulling an already-pulled container image will finish instantly, so
# repeating the command costs nothing.
docker pull "$TFCI_DOCKER_IMAGE" || sleep 15
docker pull "$TFCI_DOCKER_IMAGE" || sleep 15
docker pull "$TFCI_DOCKER_IMAGE" || sleep 30
docker pull "$TFCI_DOCKER_IMAGE" || sleep 60
docker pull "$TFCI_DOCKER_IMAGE"
fi

2 changes: 2 additions & 0 deletions tensorflow/core/common_runtime/eager/BUILD
@@ -272,6 +272,8 @@ tf_cuda_library(
"//tensorflow/core",
"//tensorflow/core/framework:resource_base",
"@local_xla//xla/pjrt/distributed:key_value_store_interface",
"@local_xla//xla/pjrt/gpu:gpu_topology",
"@local_xla//xla/pjrt/gpu:gpu_topology_proto_cc",
"@local_xla//xla/pjrt:local_device_state",
"@local_xla//xla/pjrt:pjrt_client",
"@local_xla//xla/pjrt:pjrt_compiler",
@@ -79,6 +79,8 @@ limitations under the License.
#if (defined(PLATFORM_GOOGLE) && defined(TF_PLATFORM_LINUX_X86_64))
#define TF_GPU_USE_PJRT
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/pjrt/gpu/gpu_topology.h"
#include "xla/pjrt/gpu/gpu_topology.pb.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "xla/pjrt/local_device_state.h"
#include "xla/pjrt/pjrt_compiler.h"
@@ -328,9 +330,11 @@ absl::Status CreateClientOnce(
// proceed.
creation_state->SetReady();
}
xla::GpuTopologyProto gpu_topology;
auto status = BuildDistributedDevices(
platform_name, std::move(unique_local_device_states), node_id, num_nodes,
&pjrt_devices, gpu_run_options.get(), kv_store,
&pjrt_devices, gpu_run_options.get(),
use_creation_info ? &gpu_topology : nullptr, kv_store,
/*enable_mock_nccl=*/false);
if (!status.ok()) {
if (use_creation_info) {
@@ -360,7 +364,10 @@
/*allocator=*/std::move(info->allocator),
/*host_memory_allocator=*/std::move(info->host_memory_allocator),
/*should_stage_host_to_device_transfers=*/true,
/*gpu_run_options=*/std::move(gpu_run_options));
/*gpu_run_options=*/std::move(gpu_run_options),
/*gpu_topology=*/
std::shared_ptr<const xla::GpuTopology>(
xla::GpuTopology::FromProto(gpu_topology)));
VLOG(2) << "PJRT GPU client with remote devices created.";
status = SetPjRtClientInTFGlobalResourceManager(DeviceType(DEVICE_GPU),
std::move(pjrt_client));
@@ -1139,6 +1146,7 @@ Status EagerContextDistributedManager::EnableCollectiveOps(
for (auto& local_device : local_devices) {
devices.mutable_device()->Add()->PackFrom(local_device);
}
LOG(INFO) << "xiangll device size is " << devices.device_size();
LOG_AND_RETURN_IF_ERROR(coordination_service_agent_->Connect());
LOG_AND_RETURN_IF_ERROR(
coordination_service_agent_->WaitForAllTasks(devices));
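Both PJRT GPU client call sites touched by this commit (CreateClientOnce here, and GetStreamExecutorGpuClient below) end with the same conversion: the GpuTopologyProto filled in by BuildDistributedDevices is rehydrated and wrapped in a shared_ptr for the client. A minimal sketch of that step, assuming only the headers added above and that GpuTopology::FromProto returns an owning pointer, as its uses in this diff suggest:

#include <memory>

#include "xla/pjrt/gpu/gpu_topology.h"
#include "xla/pjrt/gpu/gpu_topology.pb.h"

// Sketch, not part of the commit: convert the proto produced by
// BuildDistributedDevices into the shared, immutable form the client keeps.
std::shared_ptr<const xla::GpuTopology> ToSharedTopology(
    const xla::GpuTopologyProto& proto) {
  return std::shared_ptr<const xla::GpuTopology>(
      xla::GpuTopology::FromProto(proto));
}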
6 changes: 5 additions & 1 deletion tensorflow/core/ops/compat/op_compatibility_lib.h
@@ -52,7 +52,11 @@ class OpCompatibilityLib {

// Should match the contents of ops_file(). Run before calling
// ValidateCompatible().
string OpsString() const { return op_list_.DebugString(); }
string OpsString() const {
string result;
google::protobuf::TextFormat::PrintToString(op_list_, &result);
return result;
}

// Returns the number of ops in OpsString(), includes all ops, not
// just stable ops.
1 change: 1 addition & 0 deletions third_party/xla/xla/pjrt/c/BUILD
@@ -237,6 +237,7 @@ cc_library(
"//xla/pjrt:pjrt_compiler",
"//xla/pjrt:pjrt_device_description",
"//xla/pjrt/gpu:gpu_helpers",
"//xla/pjrt/gpu:gpu_topology",
"//xla/pjrt/gpu:se_gpu_pjrt_client",
"//xla/pjrt/gpu:se_gpu_pjrt_compiler", # To register GPU AOT compiler
"//xla/python:custom_partition_callback",
8 changes: 7 additions & 1 deletion third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc
@@ -40,6 +40,7 @@ limitations under the License.
#include "xla/pjrt/c/pjrt_c_api_stream_extension.h"
#include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
#include "xla/pjrt/gpu/gpu_helpers.h"
#include "xla/pjrt/gpu/gpu_topology.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_common.h"
@@ -175,9 +176,14 @@ PJRT_Error* PJRT_GpuDeviceTopology_Create(
device_ids.push_back(executor->device_ordinal());
}
auto gpu_target_config = xla::Compiler::TargetConfig(executor);
auto gpu_topology = std::make_shared<const xla::GpuTopology>(
device_ids, description.name(),
/*num_slices=*/1,
/*num_hosts_per_slice=*/1,
/*num_devices_per_host=*/device_ids.size());
auto pjrt_topology =
std::make_unique<xla::StreamExecutorGpuTopologyDescription>(
xla::CudaId(), xla::CudaName(), description.name(), device_ids,
xla::CudaId(), xla::CudaName(), description.name(), gpu_topology,
absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute>{
{"target_config",
gpu_target_config.ToProto().SerializeAsString()}});
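For the compile-only topology exposed through the C API, the commit now builds the single-host GpuTopology explicitly instead of passing bare device ids. A sketch of that construction in isolation; the device ordinals and platform string are illustrative, and the parameter order follows the diff above:

#include <memory>
#include <vector>

#include "xla/pjrt/gpu/gpu_topology.h"

// Sketch: a one-slice, one-host topology over the local devices, as
// PJRT_GpuDeviceTopology_Create now builds. With a single host,
// number_of_devices() matches device_ids.size().
std::shared_ptr<const xla::GpuTopology> MakeSingleHostTopology() {
  std::vector<int> device_ids = {0, 1};  // illustrative ordinals
  return std::make_shared<const xla::GpuTopology>(
      device_ids, /*platform_version=*/"fake-platform",
      /*num_slices=*/1,
      /*num_hosts_per_slice=*/1,
      /*num_devices_per_host=*/static_cast<int>(device_ids.size()));
}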
2 changes: 2 additions & 0 deletions third_party/xla/xla/pjrt/gpu/BUILD
@@ -43,6 +43,7 @@ cc_library(
":gpu_helpers",
":gpu_metrics",
":gpu_topology",
":gpu_topology_proto_cc",
"//xla:literal",
"//xla:shape_util",
"//xla:status",
@@ -296,6 +297,7 @@ xla_cc_test(
"requires-gpu-nvidia",
] + if_google(["config-cuda-only"]),
deps = [
":gpu_topology",
":se_gpu_pjrt_client",
":se_gpu_pjrt_compiler",
"//xla:test",
23 changes: 18 additions & 5 deletions third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc
@@ -54,6 +54,8 @@ limitations under the License.
#include "xla/pjrt/distributed/topology_util.h"
#include "xla/pjrt/event_pool.h"
#include "xla/pjrt/gpu/gpu_helpers.h"
#include "xla/pjrt/gpu/gpu_topology.h"
#include "xla/pjrt/gpu/gpu_topology.pb.h"
#include "xla/pjrt/local_device_state.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_compiler.h"
@@ -900,6 +902,7 @@ Status BuildDistributedDevices(
int node_id, int num_nodes,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>* devices,
gpu::GpuExecutableRunOptions* gpu_executable_run_options,
GpuTopologyProto* gpu_topology,
std::shared_ptr<KeyValueStoreInterface> kv_store, bool enable_mock_nccl,
absl::Duration get_local_topology_timeout,
absl::Duration get_global_topology_timeout) {
Expand Down Expand Up @@ -942,6 +945,14 @@ Status BuildDistributedDevices(
&global_topology));
}

if (gpu_topology != nullptr) {
TF_ASSIGN_OR_RETURN(*gpu_topology, BuildGpuTopology(global_topology));
} else {
LOG(INFO)
<< "Skipping building GpuTopology. This is expected in tests that use "
"multiple threads to simulate multiple workers.";
}

std::map<int, GlobalDeviceId> gpu_device_ids;
absl::flat_hash_map<GlobalDeviceId, int> device_to_node;
for (const LocalTopologyProto& node : global_topology.nodes()) {
@@ -1129,24 +1140,26 @@ absl::StatusOr<std::unique_ptr<PjRtClient>> GetStreamExecutorGpuClient(
kv_store = std::make_shared<InMemoryKeyValueStore>();
}
TF_RET_CHECK(options.num_nodes == 1 || kv_store != nullptr);
GpuTopologyProto gpu_topology;
TF_RETURN_IF_ERROR(BuildDistributedDevices(
pjrt_platform_name, std::move(local_device_states), options.node_id,
options.num_nodes, &devices, gpu_run_options.get(), kv_store,
options.enable_mock_nccl));
options.num_nodes, &devices, gpu_run_options.get(), &gpu_topology,
kv_store, options.enable_mock_nccl));
auto memory_spaces = BuildMemorySpaces(devices);

return std::unique_ptr<PjRtClient>(std::make_unique<StreamExecutorGpuClient>(
pjrt_platform_name, xla_client, std::move(devices),
std::move(memory_spaces), options.node_id, std::move(allocator),
std::move(host_memory_allocator),
options.should_stage_host_to_device_transfers,
std::move(gpu_run_options)));
options.should_stage_host_to_device_transfers, std::move(gpu_run_options),
std::shared_ptr<const GpuTopology>(
GpuTopology::FromProto(gpu_topology))));
}

absl::StatusOr<std::string> StreamExecutorGpuTopologyDescription::Serialize()
const {
std::string result;
if (!tsl::SerializeToStringDeterministic(gpu_topology_.ToProto(), &result)) {
if (!tsl::SerializeToStringDeterministic(gpu_topology_->ToProto(), &result)) {
return absl::InternalError("Failed to serialize gpu_topology");
}
return result;
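From a caller's perspective this layer is unchanged: GetStreamExecutorGpuClient now derives the topology and attaches it to the client internally. A hedged usage sketch, assuming default GpuClientOptions give a single-node client as in the compiler tests below:

#include <memory>

#include "absl/status/statusor.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "xla/pjrt/pjrt_client.h"

// Sketch: create a GPU client. After this commit it carries a
// GpuTopology-backed topology description built during distributed device
// setup, with no extra plumbing at the call site.
absl::StatusOr<std::unique_ptr<xla::PjRtClient>> MakeGpuClient() {
  return xla::GetStreamExecutorGpuClient(xla::GpuClientOptions());
}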
41 changes: 17 additions & 24 deletions third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h
@@ -67,30 +67,21 @@ class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription {
static StreamExecutorGpuTopologyDescription Create(
const PjRtPlatformId platform_id, const absl::string_view platform_name,
const absl::string_view platform_version,
const std::vector<PjRtDevice*>& devices) {
std::vector<int> device_ids;
device_ids.reserve(devices.size());
for (PjRtDevice* device : devices) {
device_ids.push_back(device->id());
}
std::shared_ptr<const GpuTopology> gpu_topology) {
return StreamExecutorGpuTopologyDescription(platform_id, platform_name,
platform_version, device_ids);
platform_version, gpu_topology);
}
// `gpu_device_ids` is the list of logical device ids for the GPU devices and
// will be used to initialize the GPU topology.

StreamExecutorGpuTopologyDescription(
const PjRtPlatformId platform_id, const absl::string_view platform_name,
const absl::string_view platform_version,
const std::vector<int>& gpu_device_ids,
std::shared_ptr<const GpuTopology> gpu_topology,
const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& attributes =
{})
: platform_id_(platform_id),
platform_name_(platform_name),
platform_version_(platform_version),
// TODO(b/331224674): Add support for multi-host.
gpu_topology_(gpu_device_ids, platform_version, /*num_slices=*/1,
/*num_hosts_per_slice=*/1,
/*num_devices_per_host=*/gpu_device_ids.size()),
gpu_topology_(gpu_topology),
attributes_(attributes) {}

bool operator==(const StreamExecutorGpuTopologyDescription& other) const {
@@ -111,16 +102,16 @@ class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription {
std::vector<std::unique_ptr<const PjRtDeviceDescription>> DeviceDescriptions()
const override {
std::vector<std::unique_ptr<const PjRtDeviceDescription>> devices;
devices.reserve(gpu_topology_.number_of_devices());
for (const int device_id : gpu_topology_.device_ids()) {
devices.reserve(gpu_topology_->number_of_devices());
for (const int device_id : gpu_topology_->device_ids()) {
devices.push_back(std::make_unique<PjRtStreamExecutorDeviceDescription>(
device_id, platform_version_));
}
return devices;
}

const GpuTopology& gpu_topology() const { return gpu_topology_; }
const GpuTopology* gpu_topology_ptr() const { return &gpu_topology_; }
const GpuTopology& gpu_topology() const { return *gpu_topology_; }
const GpuTopology* gpu_topology_ptr() const { return gpu_topology_.get(); }

// No subslice is supported.
bool is_subslice_topology() const override { return false; }
Expand All @@ -129,15 +120,15 @@ class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription {
absl::StatusOr<int> ProcessCount() const override { return 1; }

absl::StatusOr<int> CoreCountOfDefaultType() const override {
return gpu_topology_.number_of_devices();
return gpu_topology_->number_of_devices();
}

absl::StatusOr<int> LogicalDeviceCountOfDefaultType() const override {
return gpu_topology_.number_of_devices();
return gpu_topology_->number_of_devices();
}

absl::StatusOr<int> CoreCountOfDefaultTypePerProcess() const override {
return gpu_topology_.number_of_devices();
return gpu_topology_->number_of_devices();
}

absl::StatusOr<int> CoreCountOfDefaultTypePerChip() const override {
@@ -160,7 +151,7 @@ class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription {
const PjRtPlatformId platform_id_;
const std::string platform_name_;
const std::string platform_version_;
const GpuTopology gpu_topology_;
std::shared_ptr<const GpuTopology> gpu_topology_;
absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute> attributes_;
};

@@ -209,15 +200,16 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient {
int process_index, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tsl::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options)
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options,
std::shared_ptr<const GpuTopology> gpu_topology = nullptr)
: xla::PjRtStreamExecutorClient(
platform_name, client, std::move(devices), std::move(memory_spaces),
process_index, std::move(allocator),
std::move(host_memory_allocator),
should_stage_host_to_device_transfers, std::move(gpu_run_options)),
topology_(xla::StreamExecutorGpuTopologyDescription::Create(
tsl::Fingerprint64(platform_name), platform_name,
devices_.back()->device_kind(), devices_)) {}
devices_.back()->device_kind(), gpu_topology)) {}

absl::StatusOr<xla::DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const override;
Expand Down Expand Up @@ -279,6 +271,7 @@ absl::Status BuildDistributedDevices(
int node_id, int num_nodes,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>* devices,
gpu::GpuExecutableRunOptions* gpu_executable_run_options,
GpuTopologyProto* gpu_topology,
std::shared_ptr<KeyValueStoreInterface> kv_store, bool enable_mock_nccl,
absl::Duration get_local_topology_timeout = absl::Minutes(2),
absl::Duration get_global_topology_timeout = absl::Minutes(5));
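The header change replaces the by-value GpuTopology, previously built inside the class from a raw device-id vector, with a caller-supplied shared pointer. A sketch of constructing a description under the new signature; the values mirror the compiler tests below, and the pjrt_compiler.h include for CudaId()/CudaName() is an assumption:

#include <memory>
#include <utility>
#include <vector>

#include "xla/pjrt/gpu/gpu_topology.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "xla/pjrt/pjrt_compiler.h"  // assumed to declare CudaId()/CudaName()

// Sketch: callers now build (or receive) the GpuTopology first and hand it
// to the topology description, rather than passing device ids directly.
xla::StreamExecutorGpuTopologyDescription MakeDescription() {
  auto gpu_topology = std::make_shared<const xla::GpuTopology>(
      std::vector<int>{0, 1}, /*platform_version=*/"fake-platform",
      /*num_slices=*/1, /*num_hosts_per_slice=*/1,
      /*num_devices_per_host=*/2);
  return xla::StreamExecutorGpuTopologyDescription::Create(
      xla::CudaId(), xla::CudaName(), /*platform_version=*/"fake-platform",
      std::move(gpu_topology));
}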
29 changes: 21 additions & 8 deletions third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "mlir/Parser/Parser.h" // from @llvm-project
#include "xla/client/xla_computation.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/pjrt/gpu/gpu_topology.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_compiler.h"
@@ -60,10 +61,19 @@ absl::StatusOr<xla::XlaComputation> GetXlaComputation(
return XlaComputation(hlo_module->ToProto());
}

std::shared_ptr<xla::GpuTopology> GetGpuTopology(
std::vector<int> device_ids, absl::string_view platform_version,
int num_slices, int num_hosts_per_slice, int num_devices_per_host) {
return std::make_shared<xla::GpuTopology>(device_ids, platform_version,
num_slices, num_hosts_per_slice,
num_devices_per_host);
}

TEST(StreamExecutorGpuCompilerTest, NoClientXla) {
StreamExecutorGpuCompiler compiler;
StreamExecutorGpuTopologyDescription topology(CudaId(), CudaName(),
"Fake_device", {0, 1});
StreamExecutorGpuTopologyDescription topology(
CudaId(), CudaName(), "Fake_platform",
GetGpuTopology({0, 1}, "Fake_platform", 1, 1, 2));

TF_ASSERT_OK_AND_ASSIGN(auto computation, GetXlaComputation(kProgram));
EXPECT_THAT(compiler.Compile(xla::CompileOptions(), computation, topology,
@@ -73,8 +83,9 @@

TEST(StreamExecutorGpuCompilerTest, TopologyNotSameXla) {
StreamExecutorGpuCompiler compiler;
StreamExecutorGpuTopologyDescription topology(CudaId(), CudaName(),
"Fake_device", {0, 1});
StreamExecutorGpuTopologyDescription topology(
CudaId(), CudaName(), "Fake_device",
GetGpuTopology({0, 1}, "Fake_platform", 1, 1, 2));

TF_ASSERT_OK_AND_ASSIGN(auto client,
GetStreamExecutorGpuClient(GpuClientOptions()));
@@ -119,8 +130,9 @@ TEST(StreamExecutorGpuCompilerTest, NoClientMlir) {
auto mlir_module =
mlir::parseSourceString<mlir::ModuleOp>(mlir_str, &context);

StreamExecutorGpuTopologyDescription topology(CudaId(), CudaName(),
"Fake_device", {0, 1});
StreamExecutorGpuTopologyDescription topology(
CudaId(), CudaName(), "Fake_device",
GetGpuTopology({0, 1}, "Fake_platform", 1, 1, 2));

EXPECT_THAT(
compiler.Compile(xla::CompileOptions(), mlir_module.get(), topology,
@@ -137,8 +149,9 @@ TEST(StreamExecutorGpuCompilerTest, TopologyNotSameMlir) {
auto mlir_module =
mlir::parseSourceString<mlir::ModuleOp>(mlir_str, &context);

StreamExecutorGpuTopologyDescription topology(CudaId(), CudaName(),
"Fake_device", {0, 1});
StreamExecutorGpuTopologyDescription topology(
CudaId(), CudaName(), "Fake_device",
GetGpuTopology({0, 1}, "Fake_platform", 1, 1, 2));

TF_ASSERT_OK_AND_ASSIGN(auto client,
GetStreamExecutorGpuClient(GpuClientOptions()));
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/gpu/BUILD
@@ -4473,7 +4473,7 @@ cc_library(
]),
)

cuda_library(
gpu_kernel_library(
name = "buffer_comparator_kernel",
srcs = if_gpu_is_configured(["buffer_comparator.cu.cc"]),
copts = rocm_copts(),
@@ -24,6 +24,7 @@ limitations under the License.
#include <mutex> // NOLINT
#include <optional>
#include <string>
#include <string_view>
#include <system_error> // NOLINT
#include <utility>
#include <variant>
@@ -54,9 +54,9 @@ struct EstimateRunTimeData {
" compute_time: %s\n"
" exec_time: %s\n"
"}",
flops, bytes_written, num_threads, FormatDuration(read_time),
FormatDuration(write_time), FormatDuration(compute_time),
FormatDuration(exec_time));
flops, bytes_written, num_threads, absl::FormatDuration(read_time),
absl::FormatDuration(write_time), absl::FormatDuration(compute_time),
absl::FormatDuration(exec_time));
}
};

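The last change qualifies FormatDuration with its absl:: namespace at each call site. A minimal self-contained sketch of the corrected pattern:

#include <string>

#include "absl/strings/str_format.h"
#include "absl/time/time.h"

// Sketch: absl::FormatDuration renders an absl::Duration as a short
// human-readable string ("1.5s", "2ms", ...), which %s then consumes.
std::string RenderTimes(absl::Duration read_time, absl::Duration write_time) {
  return absl::StrFormat("read_time: %s, write_time: %s",
                         absl::FormatDuration(read_time),
                         absl::FormatDuration(write_time));
}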
