1 change: 0 additions & 1 deletion test/pjrt/test_runtime_gpu.py
@@ -178,7 +178,6 @@ def _reduce_scatter(pin_layout):
return out.cpu().numpy()

# 2023-08-02 04:16:36.520884: F external/xla/xla/service/layout_assignment.cc:157] Check failed: ShapeUtil::Compatible(shape_layout.shape(), instruction->operand(operand_no)->shape()) f32[1]{0} is not compatible with f32[2]{0} (for operand 0 of instruction %reduce-scatter.10 = f32[1]{0} reduce-scatter(f32[2]{0} %add.5), replica_groups={}, constrain_layout=true, dimensions={0}, to_apply=%AddComputation.6)
@unittest.skip("Failed with known error.")
@parameterized.named_parameters(('pinned', True), ('unpinned', False))
def test_reduce_scatter(self, pin_layout):
results = pjrt.run_multiprocess(self._reduce_scatter, pin_layout)
29 changes: 0 additions & 29 deletions torch_xla/_internal/gpu.py
@@ -1,10 +1,6 @@
import os
import atexit
import torch_xla
import torch_xla.core.xla_env_vars as xenv

distributed_service = None


def num_local_processes() -> int:
"""Returns number of processes to create on this host.
@@ -17,28 +13,3 @@ def num_local_processes() -> int:
"Must set `GPU_NUM_DEVICES` environment variable to use the PjRt GPU client"
os.environ[xenv.LOCAL_WORLD_SIZE] = os.environ[xenv.GPU_NUM_DEVICES]
return int(os.environ[xenv.LOCAL_WORLD_SIZE])


def initialize_distributed_runtime(global_world_size: int) -> None:
"""Configures GPU distributed runtime parameters.

Must be run before using any XLA devices.

Args:
global_world_size: number of devices in the cluster.
"""
if global_world_size > 1:
global distributed_service
if distributed_service is None:
num_nodes = global_world_size
distributed_service = torch_xla._XLAC._xla_get_distributed_runtime_service(
num_nodes)
atexit.register(shutdown_distributed_runtime)


def shutdown_distributed_runtime() -> None:
"""Destroy the distributed runtime after a distributed computation."""
global distributed_service
if distributed_service:
distributed_service.shutdown()
distributed_service = None
13 changes: 0 additions & 13 deletions torch_xla/_internal/pjrt.py
@@ -104,13 +104,6 @@ def _run_singleprocess(fn: Callable[..., R], *args, **kwargs) -> Dict[int, R]:
return fn(*args, **kwargs)


def should_initialize_dist_runtime(local_rank: int):
if dist.is_torchelastic_launched():
assert xenv.RANK in os.environ, 'Environment variable is not set.'
return xu.getenv_as(xenv.RANK, int) == 0
return local_rank == 0


@runtime.requires_pjrt
def initialize_multiprocess(local_rank: int, local_world_size: int):
os.environ.setdefault(xenv.PJRT_LOCAL_PROCESS_RANK, str(local_rank))
@@ -120,12 +113,6 @@ def initialize_multiprocess(local_rank: int, local_world_size: int):
tpu.configure_topology(local_rank, local_world_size)
elif runtime.device_type() == 'NEURON':
neuron.initialize_env(local_rank)
elif runtime.device_type() in ('GPU', 'ROCM', 'CUDA'):
global_world_size = xu.getenv_as(
xenv.WORLD_SIZE, int, xu.getenv_as(xenv.LOCAL_WORLD_SIZE, int, 1))
assert global_world_size >= 0
if should_initialize_dist_runtime(local_rank):
gpu.initialize_distributed_runtime(global_world_size)

devices = xm.get_xla_supported_devices()
xm.set_replication(xm.xla_device(), devices)
24 changes: 0 additions & 24 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1915,30 +1915,6 @@ void InitXlaModuleBindings(py::module m) {
SetAllReduceToken(device, token);
});

/* The distributed runtime service is used by the PjRt GPU client. */
py::class_<xla::DistributedRuntimeService,
std::unique_ptr<xla::DistributedRuntimeService>>
distributed_runtime_service(m, "DistributedRuntimeService");
distributed_runtime_service.def("shutdown",
&xla::DistributedRuntimeService::Shutdown,
py::call_guard<py::gil_scoped_release>());
m.def(
"_xla_get_distributed_runtime_service",
[](int num_nodes) -> std::unique_ptr<xla::DistributedRuntimeService> {
std::string master_addr =
runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost");
std::string port =
runtime::sys_util::GetEnvString("XLA_COORDINATOR_PORT", "8547");
std::string dist_service_addr = absl::StrJoin({master_addr, port}, ":");
XLA_CHECK(num_nodes > 0) << "num_nodes must be positive: " << num_nodes;

xla::CoordinationServiceImpl::Options options;
options.num_nodes = num_nodes;
return std::move(
xla::GetDistributedRuntimeService(dist_service_addr, options)
.value());
});

BuildProfilerSubmodule(&m);
BuildLoweringContextSubmodule(&m);

12 changes: 12 additions & 0 deletions torch_xla/csrc/runtime/BUILD
@@ -88,6 +88,7 @@ cc_library(
deps = [
":computation_client",
":debug_macros",
":distributed_runtime",
":env_vars",
":multi_wait",
":stablehlo_helper",
@@ -163,6 +164,17 @@ cc_library(
],
)

cc_library(
name = "distributed_runtime",
srcs = ["distributed_runtime.cc"],
hdrs = ["distributed_runtime.h"],
deps = [
":debug_macros",
":sys_util",
"@xla//xla/pjrt/distributed",
],
)

cc_library(
name = "metrics",
srcs = ["metrics.cc"],
54 changes: 54 additions & 0 deletions torch_xla/csrc/runtime/distributed_runtime.cc
@@ -0,0 +1,54 @@
#include "torch_xla/csrc/runtime/distributed_runtime.h"

#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/sys_util.h"

namespace torch_xla {
namespace runtime {

const std::string DistributedRuntime::default_coordinator_port = "8547";

DistributedRuntime::DistributedRuntime(int global_rank, std::string master_addr,
std::string port) {
std::string dist_service_addr = absl::StrJoin({master_addr, port}, ":");
if (global_rank == 0) {
int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1);
int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size);
xla::CoordinationServiceImpl::Options service_options;
service_options.num_nodes = global_world_size;
xla::StatusOr<std::unique_ptr<xla::DistributedRuntimeService>>
dist_runtime_service = xla::GetDistributedRuntimeService(
dist_service_addr, service_options);
XLA_CHECK(dist_runtime_service.ok())
<< "Failed to initialize distributed runtime service.";
dist_runtime_service_ = std::move(dist_runtime_service.value());
}

xla::DistributedRuntimeClient::Options client_options;
client_options.node_id = global_rank;
dist_runtime_client_ =
xla::GetDistributedRuntimeClient(dist_service_addr, client_options);
XLA_CHECK(dist_runtime_client_->Connect().ok())
<< "Failed to initialize distributed runtime client";
}

DistributedRuntime::~DistributedRuntime() {
if (dist_runtime_client_ != nullptr) {
XLA_CHECK(dist_runtime_client_->Shutdown().ok())
<< "Failed to shut down the distributed runtime client.";
dist_runtime_client_ = nullptr;
}
if (dist_runtime_service_ != nullptr) {
dist_runtime_service_->Shutdown();
dist_runtime_service_ = nullptr;
}
}

std::shared_ptr<xla::DistributedRuntimeClient> DistributedRuntime::GetClient() {
XLA_CHECK(dist_runtime_client_ != nullptr)
<< "distributed runtime client is null.";
return dist_runtime_client_;
}

} // namespace runtime
} // namespace torch_xla
38 changes: 38 additions & 0 deletions torch_xla/csrc/runtime/distributed_runtime.h
@@ -0,0 +1,38 @@
#ifndef XLA_CLIENT_DISTRIBUTED_RUNTIME_H_
#define XLA_CLIENT_DISTRIBUTED_RUNTIME_H_

#include <memory>

#include "xla/pjrt/distributed/distributed.h"

namespace torch_xla {
namespace runtime {

class DistributedRuntime {
public:
static const std::string default_coordinator_port;
static DistributedRuntime& getInstance(int global_rank,
std::string master_addr,
std::string port) {
static DistributedRuntime dist_runtime_instance(global_rank, master_addr,
port);
return dist_runtime_instance;
}
~DistributedRuntime();
DistributedRuntime(DistributedRuntime const&) = delete;
void operator=(DistributedRuntime const&) = delete;

std::shared_ptr<xla::DistributedRuntimeClient> GetClient();

private:
DistributedRuntime(int global_rank, std::string master_addr,
std::string port);

std::unique_ptr<xla::DistributedRuntimeService> dist_runtime_service_;
std::shared_ptr<xla::DistributedRuntimeClient> dist_runtime_client_;
};

} // namespace runtime
} // namespace torch_xla

#endif // XLA_CLIENT_DISTRIBUTED_RUNTIME_H_
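Note on how the new singleton is consumed, with a minimal sketch; the GetCoordinationClient helper below is hypothetical and not part of this PR. The arguments of the first getInstance() call fix the singleton; later calls return the same DistributedRuntime and ignore their arguments, which is why PjRtComputationClient passes the global rank, MASTER_ADDR, and XLA_COORDINATOR_PORT on first use (see the GPU client changes below).

#include <memory>
#include <string>

#include "torch_xla/csrc/runtime/distributed_runtime.h"

namespace torch_xla {
namespace runtime {

// Hypothetical helper (not in this PR): fetch the shared coordination client.
// Rank 0 also starts the DistributedRuntimeService inside the constructor of
// DistributedRuntime; every rank connects a DistributedRuntimeClient.
std::shared_ptr<xla::DistributedRuntimeClient> GetCoordinationClient(
    int global_rank, const std::string& master_addr, const std::string& port) {
  DistributedRuntime& dist_runtime =
      DistributedRuntime::getInstance(global_rank, master_addr, port);
  return dist_runtime.GetClient();
}

}  // namespace runtime
}  // namespace torch_xla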
44 changes: 13 additions & 31 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -9,6 +9,7 @@
#include "pjrt_computation_client.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/distributed_runtime.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/multi_wait.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
@@ -37,29 +38,6 @@ namespace {

static std::string spmd_device_str = "SPMD:0";

// Initializes a distributed runtime client if dist_service_addr is specified
std::shared_ptr<xla::DistributedRuntimeClient>
MaybeInitializeDistributedRuntimeClient(int local_rank) {
std::shared_ptr<xla::DistributedRuntimeClient> client;
int global_world_size = sys_util::GetEnvInt(
"WORLD_SIZE", sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1));
if (global_world_size < 2) {
return std::move(client);
}
std::string master_addr = sys_util::GetEnvString("MASTER_ADDR", "localhost");
std::string port =
runtime::sys_util::GetEnvString("XLA_COORDINATOR_PORT", "8547");
std::string dist_service_addr = absl::StrJoin({master_addr, port}, ":");
xla::DistributedRuntimeClient::Options options;
options.node_id = local_rank;
TF_VLOG(3) << "Getting distributed runtime client for address="
<< dist_service_addr << ", node_id=" << options.node_id;
client = xla::GetDistributedRuntimeClient(dist_service_addr, options);
XLA_CHECK(client->Connect().ok())
<< "Failed to initialize distributed runtime client";
return std::move(client);
}

// Builds a map from the device's global ordinal to its index in the `devices`
// array.
std::unordered_map<int, int> build_index_map(
@@ -131,10 +109,14 @@ PjRtComputationClient::PjRtComputationClient() {
TF_VLOG(1) << "Initializing PjRt GPU client...";
bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncGpuClient, true);
int local_process_rank = sys_util::GetEnvInt(env::kEnvPjRtLocalRank, 0);

int global_process_rank = sys_util::GetEnvInt("RANK", local_process_rank);
auto distributed_client =
MaybeInitializeDistributedRuntimeClient(global_process_rank);
std::string master_addr =
runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost");
std::string port = runtime::sys_util::GetEnvString(
"XLA_COORDINATOR_PORT", DistributedRuntime::default_coordinator_port);
std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
DistributedRuntime::getInstance(global_process_rank, master_addr, port)
.GetClient();
auto allowed_devices =
std::make_optional<std::set<int>>(std::set{local_process_rank});
xla::PjRtClient::KeyValueGetCallback kv_get = nullptr;
@@ -151,15 +133,15 @@
return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v);
};
}
int global_world_size = sys_util::GetEnvInt(
"WORLD_SIZE", sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1));
int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1);
int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size);
TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id="
<< global_process_rank << ", num_nodes=" << global_world_size;
client_ = std::move(xla::GetStreamExecutorGpuClient(
/*asynchronous=*/async, xla::GpuAllocatorConfig{},
/*asynchronous=*/async,
/*allocator_config=*/xla::GpuAllocatorConfig{},
/*node_id=*/global_process_rank,
/*num_nodes=*/
global_world_size,
/*num_nodes=*/global_world_size,
/*allowed_devices=*/allowed_devices,
/*platform_name=*/"gpu",
/*should_stage_host_to_device_transfers=*/true,
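The body of the kv_get callback sits in the collapsed hunk above. As an illustration only (not the PR's exact code), such a callback could wrap the coordination service's key-value store much like the visible kv_set lambda, assuming the xla::DistributedRuntimeClient::BlockingKeyValueGet API and the same key_prefix string:

  // Illustrative sketch, assuming key_prefix is the namespace string used by
  // the kv_set lambda; the actual definition is in the collapsed hunk above.
  kv_get = [distributed_client, key_prefix](
               const std::string& k,
               absl::Duration timeout) -> xla::StatusOr<std::string> {
    return distributed_client->BlockingKeyValueGet(
        absl::StrCat(key_prefix, k), timeout);
  };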