From 30fa84e696f956a5825692caa74d7c98792a41ac Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 23 Jun 2023 16:53:24 +0000 Subject: [PATCH 01/22] Delete XRT from the main branch --- .github/workflows/_build.yml | 10 +- .github/workflows/build_and_test.yml | 1 - infra/ansible/config/env.yaml | 1 - setup.py | 3 - test/cpp/BUILD | 10 - test/cpp/test_op_by_op_executor.cpp | 85 - test/run_tests.sh | 31 - test/test_xla_dist.py | 887 ------- torch_xla/core/_xrt_run_server.py | 6 - torch_xla/csrc/init_python_bindings.cpp | 2 - torch_xla/csrc/runtime/BUILD | 138 +- torch_xla/csrc/runtime/runtime.cc | 30 - torch_xla/csrc/runtime/triggered_task.cc | 74 - torch_xla/csrc/runtime/triggered_task.h | 57 - .../csrc/runtime/xrt_computation_client.cc | 2302 ----------------- .../csrc/runtime/xrt_computation_client.h | 677 ----- torch_xla/csrc/runtime/xrt_local_service.cc | 67 - torch_xla/csrc/runtime/xrt_local_service.h | 45 - torch_xla/csrc/runtime/xrt_session.cc | 25 - torch_xla/csrc/runtime/xrt_session.h | 102 - torch_xla/csrc/runtime/xrt_session_cache.cc | 74 - torch_xla/csrc/runtime/xrt_session_cache.h | 101 - torch_xla/distributed/_xrt_run_server.py | 45 - torch_xla/distributed/cluster.py | 477 ---- torch_xla/distributed/worker.py | 114 - torch_xla/distributed/xla_dist.py | 696 ----- torch_xla/distributed/xrt_init.py | 249 -- 27 files changed, 2 insertions(+), 6307 deletions(-) delete mode 100644 test/cpp/test_op_by_op_executor.cpp delete mode 100644 test/test_xla_dist.py delete mode 100644 torch_xla/core/_xrt_run_server.py delete mode 100644 torch_xla/csrc/runtime/triggered_task.cc delete mode 100644 torch_xla/csrc/runtime/triggered_task.h delete mode 100644 torch_xla/csrc/runtime/xrt_computation_client.cc delete mode 100644 torch_xla/csrc/runtime/xrt_computation_client.h delete mode 100644 torch_xla/csrc/runtime/xrt_local_service.cc delete mode 100644 torch_xla/csrc/runtime/xrt_local_service.h delete mode 100644 torch_xla/csrc/runtime/xrt_session.cc delete mode 100644 torch_xla/csrc/runtime/xrt_session.h delete mode 100644 torch_xla/csrc/runtime/xrt_session_cache.cc delete mode 100644 torch_xla/csrc/runtime/xrt_session_cache.h delete mode 100644 torch_xla/distributed/_xrt_run_server.py delete mode 100644 torch_xla/distributed/cluster.py delete mode 100644 torch_xla/distributed/worker.py delete mode 100755 torch_xla/distributed/xla_dist.py delete mode 100644 torch_xla/distributed/xrt_init.py diff --git a/.github/workflows/_build.yml b/.github/workflows/_build.yml index 674dc145735a..6a9510b64141 100644 --- a/.github/workflows/_build.yml +++ b/.github/workflows/_build.yml @@ -48,7 +48,6 @@ jobs: SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 GCLOUD_SERVICE_KEY: ${{ secrets.gcloud-service-key }} XLA_CUDA: ${{ inputs.cuda }} - DISABLE_XRT: ${{ inputs.disable_xrt }} steps: - name: Setup Linux uses: pytorch/test-infra/.github/actions/setup-linux@main @@ -88,7 +87,6 @@ jobs: run: | echo "declare -x SCCACHE_BUCKET=${SCCACHE_BUCKET}" | docker exec -i "${pid}" sh -c "cat >> env" echo "declare -x CC=clang-8 CXX=clang++-8" | docker exec -i "${pid}" sh -c "cat >> xla_env" - echo "declare -x DISABLE_XRT=${DISABLE_XRT}" | docker exec -i "${pid}" sh -c "cat >> xla_env" echo "declare -x XLA_CUDA=${XLA_CUDA}" | docker exec -i "${pid}" sh -c "cat >> xla_env" echo "declare -x BAZEL_REMOTE_CACHE=1" | docker exec -i "${pid}" sh -c "cat >> xla_env" echo "${GCLOUD_SERVICE_KEY}" | docker exec -i "${pid}" sh -c "cat >> default_credentials.json" @@ -107,13 +105,7 @@ jobs: id: upload-docker-image shell: bash run: | - if 
[[ ${DISABLE_XRT} == 1 ]]; then - image_tag_base=latest - else - image_tag_base=latest-xrt - fi - - export COMMIT_DOCKER_IMAGE="${ECR_DOCKER_IMAGE_BASE}:${image_tag_base}-${GITHUB_SHA}" + export COMMIT_DOCKER_IMAGE="${ECR_DOCKER_IMAGE_BASE}:latest-${GITHUB_SHA}" time docker commit "${pid}" "${COMMIT_DOCKER_IMAGE}" time docker push "${COMMIT_DOCKER_IMAGE}" echo "docker-image=${COMMIT_DOCKER_IMAGE}" >> "${GITHUB_OUTPUT}" diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index e45a7a4d41bb..6a80fde05d81 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -22,7 +22,6 @@ jobs: with: ecr-docker-image-base: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/xla_base gcr-docker-image: gcr.io/tpu-pytorch/xla_base:latest - disable_xrt: 1 cuda: 1 secrets: gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }} diff --git a/infra/ansible/config/env.yaml b/infra/ansible/config/env.yaml index d2996f750367..3a25c298ae34 100644 --- a/infra/ansible/config/env.yaml +++ b/infra/ansible/config/env.yaml @@ -32,7 +32,6 @@ build_env: XLA_SANDBOX_BUILD: 1 BAZEL_REMOTE_CACHE: 1 SILO_NAME: "cache-silo-{{ arch }}-{{ accelerator }}" - DISABLE_XRT: "{{ disable_xrt }}" amd64: ARCH: amd64 diff --git a/setup.py b/setup.py index 2a9b1f4c1ca8..3ae615a500bc 100644 --- a/setup.py +++ b/setup.py @@ -40,9 +40,6 @@ # TPUVM_MODE=0 # whether to build for TPU # -# DISABLE_XRT=0 -# whether to exclude XRT from the build -# # SILO_NAME="" # name of the remote build cache silo # diff --git a/test/cpp/BUILD b/test/cpp/BUILD index d46153da1549..1f96881d9a3f 100644 --- a/test/cpp/BUILD +++ b/test/cpp/BUILD @@ -70,16 +70,6 @@ ptxla_cc_test( ], ) -ptxla_cc_test( - name = "test_op_by_op_executor", - srcs = ["test_op_by_op_executor.cpp"], - deps = [ - ":cpp_test_util", - "//torch_xla/csrc:tensor", - "@com_google_googletest//:gtest_main", - ], -) - ptxla_cc_test( name = "test_replication", srcs = ["test_replication.cpp"], diff --git a/test/cpp/test_op_by_op_executor.cpp b/test/cpp/test_op_by_op_executor.cpp deleted file mode 100644 index fddedde24798..000000000000 --- a/test/cpp/test_op_by_op_executor.cpp +++ /dev/null @@ -1,85 +0,0 @@ -#include - -#include "test/cpp/cpp_test_util.h" -#include "torch_xla/csrc/ir.h" -#include "torch_xla/csrc/op_by_op_executor.h" -#include "torch_xla/csrc/ops/arithmetic_ir_ops.h" -#include "torch_xla/csrc/ops/ops.h" -#include "torch_xla/csrc/ops/scalar.h" -#include "torch_xla/csrc/ops/stack.h" -#include "torch_xla/csrc/tensor_util.h" - -namespace torch_xla { -namespace cpp_test { - -TEST(OpByOpExecutorTest, TestSimpleAdd) { - if (UsingPjRt()) { - GTEST_SKIP(); - } - - ForEachDevice([&](const torch::lazy::BackendDevice& device) { - at::Tensor a = at::rand({4, 16, 3}, at::TensorOptions(at::kFloat)); - at::Tensor b = at::rand({4, 16, 3}, at::TensorOptions(at::kFloat)); - at::Tensor c = a + b; - - torch::lazy::Value v_a = GetTensorIrValue(a, device); - torch::lazy::Value v_b = GetTensorIrValue(b, device); - torch::lazy::Value v_c = v_a + v_b; - - auto results_data = - OpByOpExecutor::Get()->Execute({v_c}, device.toString(), {}); - auto results = Fetch(UnwrapXlaData(results_data)); - - AllClose(results.front(), c); - }); -} - -TEST(OpByOpExecutorTest, TestStack) { - if (UsingPjRt()) { - GTEST_SKIP(); - } - - ForEachDevice([&](const torch::lazy::BackendDevice& device) { - at::Tensor a = at::rand({4, 8, 3}, at::TensorOptions(at::kFloat)); - at::Tensor b = at::rand({4, 8, 3}, at::TensorOptions(at::kFloat)); - at::Tensor c = at::stack({a, 
b}, 1); - - torch::lazy::Value v_a = GetTensorIrValue(a, device); - torch::lazy::Value v_b = GetTensorIrValue(b, device); - torch::lazy::Value v_c = torch::lazy::MakeNode( - std::vector({v_a, v_b}), 1); - - auto results_data = - OpByOpExecutor::Get()->Execute({v_c}, device.toString(), {}); - auto results = Fetch(UnwrapXlaData(results_data)); - - AllClose(results.front(), c); - }); -} - -TEST(OpByOpExecutorTest, TestAsyncStack) { - if (UsingPjRt()) { - GTEST_SKIP(); - } - - ForEachDevice([&](const torch::lazy::BackendDevice& device) { - at::Tensor a = at::rand({4, 8, 3}, at::TensorOptions(at::kFloat)); - at::Tensor b = at::rand({4, 8, 3}, at::TensorOptions(at::kFloat)); - at::Tensor c = at::stack({a, b}, 1); - - torch::lazy::Value v_a = GetTensorIrValue(a, device); - torch::lazy::Value v_b = GetTensorIrValue(b, device); - torch::lazy::Value v_c = torch::lazy::MakeNode( - std::vector({v_a, v_b}), 1); - - auto async = - OpByOpExecutor::Get()->ExecuteAsync({v_c}, device.toString(), {}); - async.Wait(); - auto results = Fetch(UnwrapXlaData(async.ConsumeValue())); - - AllClose(results.front(), c); - }); -} - -} // namespace cpp_test -} // namespace torch_xla diff --git a/test/run_tests.sh b/test/run_tests.sh index 04df4ea13581..e898ea6d6c48 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -116,34 +116,6 @@ function run_xrt { fi } -function run_opbyop { - echo "Running in OpByOp mode: $@" - XLA_GET_TENSORS_OPBYOP=1 XLA_SYNC_TENSORS_OPBYOP=1 run_xrt "$@" -} - -function run_async_scalar { - echo "Running in Async Scalar Upload mode: $@" - XLA_TRANSFER_SCALAR_ASYNC=1 run_xrt "$@" -} - -function run_torchrun { - echo "Running tests spawned by torchrun" - if [ -x "$(command -v nvidia-smi)" ]; then - run_xrt "$@" - else - echo "the tests need atleast two XLA workers to validate" - fi -} - -function run_xrt_tests { - # For features not supported in PJRT - echo "Running XRT tests" - run_xrt "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY - run_opbyop "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY - run_async_scalar "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY - run_torchrun "$CDIR/test_allreduce_torchrun.py" -} - function run_torch_op_tests { run_dynamic "$CDIR/../../test/test_view_ops.py" "$@" -v TestViewOpsXLA run_test_without_functionalization "$CDIR/../../test/test_view_ops.py" "$@" -v TestViewOpsXLA @@ -235,9 +207,6 @@ function run_tests { if [[ "$XLA_SKIP_MP_OP_TESTS" != "1" ]]; then run_mp_op_tests fi - if [[ "$XLA_SKIP_XRT_TESTS" != "1" ]]; then - run_xrt_tests - fi } if [ "$LOGFILE" != "" ]; then diff --git a/test/test_xla_dist.py b/test/test_xla_dist.py deleted file mode 100644 index d8e7bd462599..000000000000 --- a/test/test_xla_dist.py +++ /dev/null @@ -1,887 +0,0 @@ -"""Tests for xla_dist.""" - -import cloud_tpu_client -import uuid -import unittest -from unittest import mock - -from googleapiclient import discovery -from oauth2client.client import GoogleCredentials -from torch_xla.distributed.cluster import Cluster -from torch_xla.distributed.cluster import ClusterResolver -from torch_xla.distributed.worker import ClientWorker -from torch_xla.distributed.worker import ServiceWorker - -PROJECT_ZONE_PREFIX = ('https://www.googleapis.com/compute/v1/' - 'projects/fake-project/zones/fake-zone') -TPUVM_HOSTNAME_PREFIX = 't1v-n-5d9c8fb2-w-' - - -class ClusterTest(unittest.TestCase): - - def test_validate_good_cluster(self): - client_workers = [ - ClientWorker('10.0.0.0', 'n1-standard-16', 'europe-west4-a'), - ClientWorker('10.0.0.1', 'n1-standard-16', 
'europe-west4-a'), - ClientWorker('10.0.0.2', 'n1-standard-16', 'europe-west4-a'), - ClientWorker( - '10.0.0.3', 'n1-standard-16', 'europe-west4-a', hostname='test'), - ] - service_workers = [ - ServiceWorker('10.0.0.0', '8470', 'v3-32', 'europe-west4-a', - 'pytorch-0.2'), - ServiceWorker('10.0.0.1', '8470', 'v3-32', 'europe-west4-a', - 'pytorch-0.2'), - ServiceWorker('10.0.0.2', '8470', 'v3-32', 'europe-west4-a', - 'pytorch-0.2'), - ServiceWorker('10.0.0.3', '8470', 'v3-32', 'europe-west4-a', - 'pytorch-0.2'), - ] - cluster = Cluster( - client_workers, service_workers, client_master_ip='10.0.0.0') - cluster.validate() # Does not raise exception - - def test_create_bad_client_workers(self): - service_workers = [ - ServiceWorker('10.0.0.1', '8470', 'v3-8', 'europe-west4-a', - 'pytorch-0.2'), - ] - client_workers = [ - ClientWorker('10.0.0.1', 'v3-8', 'europe-west4-a'), - ServiceWorker('10.0.0.1', '8470', 'v3-8', 'europe-west4-a', - 'pytorch-0.2'), - ] - self.assertRaisesRegex( - ValueError, - 'client_workers argument must be a list of ClientWorker', - Cluster, - client_workers, - service_workers, - client_master_ip='10.0.0.1') - - def test_create_bad_service_workers(self): - client_workers = [ - ClientWorker( - '10.0.0.1', 'n1-standard-16', 'europe-west4-a', hostname='test'), - ] - self.assertRaisesRegex( - ValueError, - 'service_workers argument must be a list of ServiceWorker', - Cluster, - client_workers, - client_workers, - client_master_ip='10.0.0.1') - - def test_validate_machine_type_client_cluster(self): - client_workers = [ - ClientWorker('10.0.0.0', 'n1-standard-16', 'europe-west4-a'), - ClientWorker('10.0.0.1', 'n1-standard-8', 'europe-west4-a'), - ] - service_workers = [ - ServiceWorker('10.0.0.0', '8470', 'v3-8', 'europe-west4-a', - 'pytorch-0.2'), - ServiceWorker('10.0.0.1', '8470', 'v3-8', 'europe-west4-a', - 'pytorch-0.2'), - ] - - no_check_cluster = Cluster( - client_workers, - service_workers, - check_client_machine_type=False, - client_master_ip='10.0.0.0') - no_check_cluster.validate() # Does not raise exception - - check_cluster = Cluster( - client_workers, service_workers, client_master_ip='10.0.0.0') - self.assertRaisesRegex( - RuntimeError, 'All client_workers must have the same machine_type', - check_cluster.validate) - - def test_validate_machine_type_service_cluster(self): - client_workers = [ - ClientWorker('10.0.0.0', 'n1-standard-16', 'europe-west4-a'), - ClientWorker('10.0.0.1', 'n1-standard-16', 'europe-west4-a'), - ] - service_workers = [ - ServiceWorker('10.0.0.0', '8470', 'v3-8', 'europe-west4-a', - 'pytorch-0.2'), - ServiceWorker('10.0.0.1', '8470', 'v2-8', 'europe-west4-a', - 'pytorch-0.2'), - ] - - no_check_cluster = Cluster( - client_workers, - service_workers, - check_service_machine_type=False, - client_master_ip='10.0.0.0') - no_check_cluster.validate() # Does not raise exception - - check_cluster = Cluster( - client_workers, service_workers, client_master_ip='10.0.0.0') - self.assertRaisesRegex( - RuntimeError, 'All service_workers must have the same machine_type', - check_cluster.validate) - - def test_validate_bad_zone_cluster(self): - client_workers = [ - ClientWorker('10.0.0.0', 'n1-standard-16', 'europe-west4-a'), - ClientWorker('10.0.0.1', 'n1-standard-16', 'us-central1-b'), - ] - service_workers = [ - ServiceWorker('10.0.0.0', '8470', 'v3-8', 'europe-west4-a', - 'pytorch-0.2'), - ServiceWorker('10.0.0.1', '8470', 'v3-8', 'europe-west4-a', - 'pytorch-0.2'), - ] - cluster = Cluster( - client_workers, service_workers, 
client_master_ip='10.0.0.0') - self.assertRaisesRegex(RuntimeError, 'All workers must be in the same zone', - cluster.validate) - - def test_validate_diff_num_workers(self): - client_workers = [ - ClientWorker('10.0.0.0', 'n1-standard-16', 'europe-west4-a'), - ClientWorker('10.0.0.1', 'n1-standard-16', 'europe-west4-a'), - ClientWorker('10.0.0.2', 'n1-standard-16', 'europe-west4-a'), - ] - service_workers = [ - ServiceWorker('10.0.0.0', '8470', 'v3-32', 'europe-west4-a', - 'pytorch-0.2'), - ServiceWorker('10.0.0.1', '8470', 'v3-32', 'europe-west4-a', - 'pytorch-0.2'), - ServiceWorker('10.0.0.2', '8470', 'v3-32', 'europe-west4-a', - 'pytorch-0.2'), - ServiceWorker('10.0.0.3', '8470', 'v3-32', 'europe-west4-a', - 'pytorch-0.2'), - ] - cluster = Cluster( - client_workers, service_workers, client_master_ip='10.0.0.0') - self.assertRaisesRegex( - RuntimeError, - 'The client_workers and service_workers must have a 1:1 mapping', - cluster.validate) - - def test_validate_empty_workers(self): - client_workers = [ - ClientWorker('10.0.0.0', 'n1-standard-16', 'europe-west4-a') - ] - cluster = Cluster(client_workers, [], client_master_ip='10.0.0.0') - self.assertRaisesRegex( - RuntimeError, - 'Both client_workers and service_workers should not be empty', - cluster.validate) - - def test_validate_diff_runtime_versions(self): - client_workers = [ - ClientWorker('10.0.0.0', 'n1-standard-16', 'europe-west4-a'), - ClientWorker('10.0.0.1', 'n1-standard-16', 'europe-west4-a'), - ClientWorker('10.0.0.2', 'n1-standard-16', 'europe-west4-a'), - ClientWorker('10.0.0.3', 'n1-standard-16', 'europe-west4-a'), - ] - service_workers = [ - ServiceWorker('10.0.0.0', '8470', 'v3-32', 'europe-west4-a', - 'pytorch-0.1'), - ServiceWorker('10.0.0.1', '8470', 'v3-32', 'europe-west4-a', - 'pytorch-0.2'), - ServiceWorker('10.0.0.2', '8470', 'v3-32', 'europe-west4-a', - 'pytorch-0.1'), - ServiceWorker('10.0.0.3', '8470', 'v3-32', 'europe-west4-a', - 'pytorch-0.2'), - ] - cluster = Cluster( - client_workers, service_workers, client_master_ip='10.0.0.0') - self.assertRaisesRegex( - RuntimeError, - 'All service workers must have the same runtime_version.*', - cluster.validate) - - -def mock_request_metadata(metadata): - fake_metadata = { - 'project/project-id': 'fake-project', - 'instance/zone': 'project/fake-project/zones/fake-zone', - 'instance/name': 'fake-ig-a', - 'instance/network-interfaces/0/ip': '10.0.0.0', - # Adding this field to prevent crashing when ClusterResolver querying this - # metadata to identify the TPUVM case. - 'instance/attributes/accelerator-type': '', - } - return fake_metadata[metadata] - - -def mock_request_tpuvm_metadata(metadata): - fake_metadata = { - 'project/project-id': 'fake-project', - 'instance/zone': 'project/fake-project/zones/fake-zone', - 'instance/name': TPUVM_HOSTNAME_PREFIX + '0', - 'instance/network-interfaces/0/ip': '10.1.0.0', - 'instance/attributes/accelerator-type': 'v3-32', - } - return fake_metadata[metadata] - - -def mock_ip_to_hostname_mapping(tpu_name, zone, num_vm): - ip_to_hostname_map = {} - for index in range(num_vm): - ip_to_hostname_map[f'10.1.0.{index}'] = f'{TPUVM_HOSTNAME_PREFIX}{index}' - return ip_to_hostname_map - - -def build_mock_cloud_tpu_client_library(tpu_map): - - def mock_cloud_tpu_client_constructor(*args, **kwargs): - # Patch to mock cloud_tpu_client.Client.__init__ method. 
- tpu_name = kwargs['tpu'] - tpu_dict = tpu_map[tpu_name] - ctc = mock.MagicMock() - ctc.name.return_value = tpu_name - ctc.state.return_value = tpu_dict.get('state') - ctc.health.return_value = tpu_dict.get('health') - ctc.runtime_version.return_value = tpu_dict.get('runtime_version') - ctc.accelerator_type.return_value = tpu_dict.get('accelerator_type') - ctc.network_endpoints.return_value = tpu_dict.get('network_endpoints') - # TODO: add a api to get the tpu api version directly - ctc._get_tpu_property.return_value = tpu_dict.get('api_version') - ctc._full_name.return_value = \ - f'projects/fake-project/locations/fake-zone/nodes/{tpu_name}' - return ctc - - return mock_cloud_tpu_client_constructor - - -def build_mock_compute_service(get_instance_map, list_instances_map): - # Instances mock - def get_instance_fn(*args, **kwargs): - resp = get_instance_map[kwargs['instance']] - get_instance = mock.MagicMock() - get_instance.execute.return_value = resp - get_instance.resumable = None - return get_instance - - instances = mock.MagicMock() - instances.get.side_effect = get_instance_fn - - # Instance groups mock - def list_instances_fn(*args, **kwargs): - resp = list_instances_map[kwargs['instanceGroup']] - list_instances = mock.MagicMock() - list_instances.execute.return_value = resp - return list_instances - - instance_groups = mock.MagicMock() - instance_groups.listInstances.side_effect = list_instances_fn - - # Compute service mock - compute_service = mock.MagicMock() - compute_service.instances.return_value = instances - compute_service.instanceGroups.return_value = instance_groups - compute_service.new_batch_http_request.return_value = build_mock_batch_call() - - return compute_service - - -def build_mock_services_fn(mock_compute_service): - - def mock_google_services(serviceName, version, **kwargs): - if serviceName == 'compute': - return mock_compute_service - else: - raise RuntimeError(f'Service name "{serviceName}" is not mocked.') - - return mock_google_services - - -def build_mock_batch_call(): - batcher = mock.MagicMock() - - def build_execute_requests_fn(call_list): - - def execute_requests(*args): - del args - for args, _ in call_list: - req, callback = args - resp = None - exception = None - try: - resp = req.execute() - except e: - exception = e - callback(uuid.uuid4(), resp, exception) - - return execute_requests - - batcher.execute.side_effect = build_execute_requests_fn( - batcher.add.call_args_list) - return batcher - - -def gen_fake_instances_get_entry(instance_name, machine_type, internal_ip, - status): - return { - 'machineType': f'{PROJECT_ZONE_PREFIX}/machineTypes/{machine_type}', - 'metadata': { - 'fingerprint': 'abc', - 'items': [{ - 'key': - 'instance-template', - 'value': ('projects/123456789012/global/' - 'instanceTemplates/fake-ig-template'), - }, { - 'key': - 'created-by', - 'value': ('projects/123456789012/zones/fake-zone/' - 'instanceGroupManagers/fake-ig'), - }], - 'kind': 'compute#metadata', - }, - 'selfLink': f'{PROJECT_ZONE_PREFIX}/instances/{instance_name}', - 'networkInterfaces': [{ - 'networkIP': internal_ip, - }], - 'status': status, - 'zone': PROJECT_ZONE_PREFIX, - } - - -def gen_fake_ig_list_instances_entry(instance_name, status): - return { - 'instance': f'{PROJECT_ZONE_PREFIX}/instances/{instance_name}', - 'status': status, - } - - -class ClusterResolverTest(unittest.TestCase): - - def setUp(self): - super(ClusterResolverTest, self).setUp() - self.addCleanup(mock.patch.stopall) - mock.patch.object(ClusterResolver, 'get_instance_metadata', - 
mock_request_metadata).start() - mock.patch.object(ClusterResolver, '_get_internal_ip_to_hostname_mapping', - mock_ip_to_hostname_mapping).start() - mock.patch.object(GoogleCredentials, 'get_application_default', - lambda *args, **kwargs: None).start() - self.mock_discovery = mock.patch.object( - discovery, 'build', autospec=True).start() - self.mock_ctc = mock.patch.object( - cloud_tpu_client, 'Client', autospec=True).start() - - def test_bad_empty_tpu_constructor(self): - tpus = '' - self.assertRaisesRegex(ValueError, 'tpu must be a non-empty string', - ClusterResolver, tpus) - - def test_bad_none_tpu_constructor(self): - tpus = None - self.assertRaisesRegex(ValueError, 'tpu must be a non-empty string', - ClusterResolver, tpus) - - def test_bad_vm_constructor(self): - tpus = ['fake-tpu'] - vms = {'abc'} - self.assertRaisesRegex(ValueError, - 'vms must be a non-empty list if provided', - ClusterResolver, tpus, vms) - - def test_healthy_instance_group_client_cluster(self): - # Arrange - list_instances_map = { - 'fake-ig': { - 'kind': - 'compute#instanceGroupsListInstances', - 'items': [ - gen_fake_ig_list_instances_entry('fake-ig-' + c, 'RUNNING') - for c in 'abcd' - ], - }, - } - instance_resp_map = { - 'fake-ig-' + c: - gen_fake_instances_get_entry('fake-ig-' + c, 'n1-standard-16', - '10.0.0.' + ip, 'RUNNING') - for c, ip in zip('abcd', '0123') - } - compute_service = build_mock_compute_service(instance_resp_map, - list_instances_map) - self.mock_discovery.side_effect = build_mock_services_fn(compute_service) - - # Act - cr = ClusterResolver(['fake-tpu']) - vm_cluster = cr.get_client_workers() - - # Assert - expected = [ - ClientWorker( - internal_ip='10.0.0.' + ip, - machine_type='n1-standard-16', - zone='fake-zone', - hostname='fake-ig-' + c) for c, ip in zip('abcd', '0123') - ] - self.assertCountEqual(expected, vm_cluster) - - def test_healthy_vm_list_client_cluster(self): - # Arrange - list_instances_map = {} - instance_resp_map = { - 'fake-ig-' + c: - gen_fake_instances_get_entry('fake-ig-' + c, 'n1-standard-16', - '10.0.0.' + ip, 'RUNNING') - for c, ip in zip('abcd', '0123') - } - compute_service = build_mock_compute_service(instance_resp_map, - list_instances_map) - self.mock_discovery.side_effect = build_mock_services_fn(compute_service) - - # Act - vms = ['fake-ig-a', 'fake-ig-b', 'fake-ig-c', 'fake-ig-d'] - cr = ClusterResolver(['fake-tpu'], vms=vms) - vm_cluster = cr.get_client_workers() - - # Assert - expected = [ - ClientWorker( - internal_ip='10.0.0.' 
+ ip, - machine_type='n1-standard-16', - zone='fake-zone', - hostname='fake-ig-' + c) for c, ip in zip('abcd', '0123') - ] - self.assertCountEqual(expected, vm_cluster) - - def test_empty_instance_group_client_cluster(self): - list_instances_map = { - 'fake-ig': { - 'kind': 'compute#instanceGroupsListInstances', - 'items': [], - }, - } - instance_resp_map = { - 'fake-ig-a': - gen_fake_instances_get_entry('fake-ig-a', 'n1-standard-16', - '10.0.0.0', 'RUNNING'), - } - compute_service = build_mock_compute_service(instance_resp_map, - list_instances_map) - self.mock_discovery.side_effect = build_mock_services_fn(compute_service) - - # Act - cr = ClusterResolver(['fake-tpu']) - - # Assert - self.assertRaisesRegex(RuntimeError, '.*vms is empty in instance group.*', - cr.get_client_workers) - - def test_unhealthy_client_cluster(self): - # Arrange - list_instances_map = { - 'fake-ig': { - 'kind': - 'compute#instanceGroupsListInstances', - 'items': [ - gen_fake_ig_list_instances_entry('fake-ig-a', 'RUNNING'), - gen_fake_ig_list_instances_entry('fake-ig-b', 'PROVISIONING'), - gen_fake_ig_list_instances_entry('fake-ig-c', 'RUNNING'), - gen_fake_ig_list_instances_entry('fake-ig-d', 'RUNNING'), - ], - }, - } - instance_resp_map = { - 'fake-ig-a': - gen_fake_instances_get_entry('fake-ig-a', 'n1-standard-16', - '10.0.0.0', 'RUNNING'), - 'fake-ig-b': - gen_fake_instances_get_entry('fake-ig-b', 'n1-standard-16', - '10.0.0.1', 'PROVISIONING'), - 'fake-ig-c': - gen_fake_instances_get_entry('fake-ig-c', 'n1-standard-16', - '10.0.0.2', 'RUNNING'), - 'fake-ig-d': - gen_fake_instances_get_entry('fake-ig-d', 'n1-standard-16', - '10.0.0.3', 'RUNNING'), - } - compute_service = build_mock_compute_service(instance_resp_map, - list_instances_map) - self.mock_discovery.side_effect = build_mock_services_fn(compute_service) - - # Act - cr = ClusterResolver(['fake-tpu']) - - # Assert - self.assertRaisesRegex(RuntimeError, - 'Instance fake-ig-b is not running yet.*', - cr.get_client_workers) - - def test_healthy_pod_service_cluster(self): - tpu_map = { - 'fake-pod': { - 'state': - 'READY', - 'health': - 'HEALTHY', - 'runtime_version': - 'pytorch-nightly', - 'accelerator_type': - 'v3-32', - 'network_endpoints': [{ - 'ipAddress': f'10.0.0.{ip}', - 'port': '8470' - } for ip in range(4)], - } - } - self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(tpu_map) - - tpus = list(tpu_map.keys()) - cr = ClusterResolver(tpus) - service_workers = cr.get_tpu_workers() - - expected = [ - ServiceWorker( - internal_ip=f'10.0.0.{ip}', - port='8470', - machine_type='v3-32', - zone='fake-zone', - runtime_version='pytorch-nightly', - tpu='fake-pod') for ip in range(4) - ] - self.assertCountEqual(expected, service_workers) - - def test_healthy_sea_service_cluster(self): - noop_compute_service = build_mock_compute_service({}, {}) - self.mock_discovery.side_effect = build_mock_services_fn( - noop_compute_service) - tpu_map = { - f'fake-tpu-{ip}': { - 'state': - 'READY', - 'health': - 'HEALTHY', - 'runtime_version': - 'pytorch-nightly', - 'accelerator_type': - 'v3-8', - 'network_endpoints': [{ - 'ipAddress': f'10.0.0.{ip}', - 'port': '8470' - }], - } for ip in range(256) - } - self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(tpu_map) - - tpus = list(tpu_map.keys()) - cr = ClusterResolver(tpus) - service_workers = cr.get_tpu_workers() - - expected = [ - ServiceWorker( - internal_ip=f'10.0.0.{ip}', - port='8470', - machine_type='v3-8', - zone='fake-zone', - runtime_version='pytorch-nightly', - tpu=f'fake-tpu-{ip}') for ip 
in range(256) - ] - self.assertCountEqual(expected, service_workers) - - def test_unhealthy_pod_service_cluster(self): - tpu_map = { - 'fake-pod': { - 'state': - 'READY', - 'health': - 'UNHEALTHY_TENSORFLOW', - 'runtime_version': - 'pytorch-nightly', - 'accelerator_type': - 'v3-128', - 'network_endpoints': [{ - 'ipAddress': f'10.0.0.{ip}', - 'port': '8470' - } for ip in range(16)], - } - } - self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(tpu_map) - - tpus = list(tpu_map.keys()) - cr = ClusterResolver(tpus) - self.assertRaisesRegex(RuntimeError, 'TPU fake-pod is not HEALTHY yet.*', - cr.get_tpu_workers) - - def test_non_ready_sea_service_cluster(self): - noop_compute_service = build_mock_compute_service({}, {}) - self.mock_discovery.side_effect = build_mock_services_fn( - noop_compute_service) - - tpu_map = { - f'fake-tpu-{ip}': { - 'state': - 'READY', - 'health': - 'HEALTHY', - 'runtime_version': - 'pytorch-nightly', - 'accelerator_type': - 'v3-8', - 'network_endpoints': [{ - 'ipAddress': f'10.0.0.{ip}', - 'port': '8470' - }], - } for ip in range(3) - } - tpu_map['fake-tpu-3'] = { - 'state': 'CREATING', - 'runtime_version': 'pytorch-nightly', - 'accelerator_type': 'v3-8', - } - self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(tpu_map) - - tpus = list(tpu_map.keys()) - cr = ClusterResolver(tpus) - self.assertRaisesRegex(RuntimeError, 'TPU fake-tpu-3 is not READY yet.*', - cr.get_tpu_workers) - - def test_unknown_health_pod_service_cluster(self): - noop_compute_service = build_mock_compute_service({}, {}) - self.mock_discovery.side_effect = build_mock_services_fn( - noop_compute_service) - tpu_map = { - 'fake-pod': { - 'state': - 'READY', - 'runtime_version': - 'pytorch-nightly', - 'accelerator_type': - 'v3-32', - 'network_endpoints': [{ - 'ipAddress': f'10.0.0.{ip}', - 'port': '8470' - } for ip in range(4)], - } - } - self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(tpu_map) - - tpus = list(tpu_map.keys()) - cr = ClusterResolver(tpus) - self.assertRaisesRegex(RuntimeError, 'TPU fake-pod is not HEALTHY yet.*', - cr.get_tpu_workers) - - def test_healthy_cluster(self): - list_instances_map = { - 'fake-ig': { - 'kind': - 'compute#instanceGroupsListInstances', - 'items': [ - gen_fake_ig_list_instances_entry('fake-ig-' + c, 'RUNNING') - for c in 'abcd' - ], - }, - } - instance_resp_map = { - 'fake-ig-' + c: - gen_fake_instances_get_entry('fake-ig-' + c, 'n1-standard-16', - '10.0.0.' + ip, 'RUNNING') - for c, ip in zip('abcd', '0123') - } - compute_service = build_mock_compute_service(instance_resp_map, - list_instances_map) - self.mock_discovery.side_effect = build_mock_services_fn(compute_service) - - tpu_map = { - 'fake-pod': { - 'state': - 'READY', - 'health': - 'HEALTHY', - 'runtime_version': - 'pytorch-nightly', - 'accelerator_type': - 'v3-32', - 'network_endpoints': [{ - 'ipAddress': f'10.0.0.{ip}', - 'port': '8470' - } for ip in range(4)], - } - } - self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(tpu_map) - - tpus = list(tpu_map.keys()) - cr = ClusterResolver(tpus) - cluster = cr.get_cluster() - - expected_client_workers = [ - ClientWorker( - internal_ip='10.0.0.' 
+ ip, - machine_type='n1-standard-16', - zone='fake-zone', - hostname='fake-ig-' + c) for c, ip in zip('abcd', '0123') - ] - expected_service_workers = [ - ServiceWorker( - internal_ip=f'10.0.0.{ip}', - port='8470', - machine_type='v3-32', - zone='fake-zone', - runtime_version='pytorch-nightly', - tpu='fake-pod') for ip in range(4) - ] - expected = Cluster( - expected_client_workers, - expected_service_workers, - client_master_ip='10.0.0.0') - self.assertEqual(expected, cluster) - - def test_healthy_remote_coordinator(self): - noop_compute_service = build_mock_compute_service({}, {}) - self.mock_discovery.side_effect = build_mock_services_fn( - noop_compute_service) - - tpu_map = { - 'fake-pod': { - 'state': - 'READY', - 'health': - 'HEALTHY', - 'runtime_version': - 'v2-nightly', - 'accelerator_type': - 'v3-32', - 'api_version': - 'V2_ALPHA1', - 'network_endpoints': [{ - 'ipAddress': f'10.1.0.{index}', - 'port': '8470', - } for index in range(4)], - } - } - self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(tpu_map) - - tpus = list(tpu_map.keys()) - cr = ClusterResolver(tpus) - cluster = cr.get_cluster() - - expected_client_workers = [ - ClientWorker( - internal_ip=f'10.1.0.{index}', - machine_type='v3-32', - zone='fake-zone', - hostname=f'{TPUVM_HOSTNAME_PREFIX}{index}') for index in range(4) - ] - expected_service_workers = [ - ServiceWorker( - internal_ip=f'10.1.0.{ip}', - port='8470', - machine_type='v3-32', - zone='fake-zone', - runtime_version='v2-nightly', - tpu='fake-pod') for ip in range(4) - ] - expected = Cluster( - expected_client_workers, - expected_service_workers, - client_master_ip='10.1.0.0') - self.assertEqual(expected, cluster) - - def test_healthy_tpuvm_cluster(self): - # Using TPUVM flavor of metadata. - mock.patch.object(ClusterResolver, 'get_instance_metadata', - mock_request_tpuvm_metadata).start() - noop_compute_service = build_mock_compute_service({}, {}) - self.mock_discovery.side_effect = build_mock_services_fn( - noop_compute_service) - - tpu_map = { - 'fake-pod': { - 'state': - 'READY', - 'health': - 'HEALTHY', - 'runtime_version': - 'v2-nightly', - 'accelerator_type': - 'v3-32', - 'api_version': - 'V2_ALPHA1', - 'network_endpoints': [{ - 'ipAddress': f'10.1.0.{index}', - 'port': '8470', - } for index in range(4)], - } - } - self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(tpu_map) - - tpus = list(tpu_map.keys()) - cr = ClusterResolver(tpus) - cluster = cr.get_cluster() - - expected_client_workers = [ - ClientWorker( - internal_ip=f'10.1.0.{index}', - machine_type='v3-32', - zone='fake-zone', - hostname=f'{TPUVM_HOSTNAME_PREFIX}{index}') for index in range(4) - ] - expected_service_workers = [ - ServiceWorker( - internal_ip=f'10.1.0.{ip}', - port='8470', - machine_type='v3-32', - zone='fake-zone', - runtime_version='v2-nightly', - tpu='fake-pod') for ip in range(4) - ] - expected = Cluster( - expected_client_workers, - expected_service_workers, - client_master_ip='10.1.0.0') - self.assertEqual(expected, cluster) - mock.patch.object(ClusterResolver, 'get_instance_metadata', - mock_request_metadata).start() - - def test_bad_cluster(self): - list_instances_map = { - 'fake-ig': { - 'kind': - 'compute#instanceGroupsListInstances', - 'items': [ - gen_fake_ig_list_instances_entry('fake-ig-' + c, 'RUNNING') - for c in 'abc' - ], - }, - } - instance_resp_map = { - 'fake-ig-' + c: - gen_fake_instances_get_entry('fake-ig-' + c, 'n1-standard-16', - '10.0.0.' 
+ ip, 'RUNNING') - for c, ip in zip('abcd', '0123') - } - compute_service = build_mock_compute_service(instance_resp_map, - list_instances_map) - self.mock_discovery.side_effect = build_mock_services_fn(compute_service) - - tpu_map = { - 'fake-pod': { - 'state': - 'READY', - 'health': - 'HEALTHY', - 'runtime_version': - 'pytorch-nightly', - 'accelerator_type': - 'v3-32', - 'network_endpoints': [{ - 'ipAddress': f'10.0.0.{ip}', - 'port': '8470' - } for ip in range(4)], - } - } - self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(tpu_map) - - tpus = list(tpu_map.keys()) - cr = ClusterResolver(tpus) - self.assertRaisesRegex( - RuntimeError, - 'The client_workers and service_workers must have a 1:1 mapping', - cr.get_cluster) - - -if __name__ == '__main__': - test = unittest.main() - sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/torch_xla/core/_xrt_run_server.py b/torch_xla/core/_xrt_run_server.py deleted file mode 100644 index 74c4b6936dee..000000000000 --- a/torch_xla/core/_xrt_run_server.py +++ /dev/null @@ -1,6 +0,0 @@ -import torch_xla -import sys - -if __name__ == '__main__': - assert len(sys.argv) == 2, 'Need to provide the local service port' - torch_xla._XLAC._run_xrt_local_service(int(sys.argv[1])) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index c6a4bd399441..24cf6cdb8948 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1306,8 +1306,6 @@ void InitXlaModuleBindings(py::module m) { const std::vector& operands, py::dict args) { return op_builder::CreateOp(builder, opname, operands, args); }); - m.def("_run_xrt_local_service", - [](uint64_t service_port) { runtime::RunLocalService(service_port); }); m.def("_xla_sgd_optimizer_step_", [](const at::Tensor& found_inf, at::Tensor& step, at::Tensor& param, at::Tensor& buf, const at::Tensor& d_p, double weight_decay, diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 956f48a2f592..384bcc2fb810 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -20,11 +20,6 @@ exports_files([ "tf_exported_symbols.lds", ]) -config_setting( - name = "disable_xrt", - define_values = {"disable_xrt": "true"}, -) - tf_proto_library_cc( name = "mesh_service_proto", srcs = ["mesh_service.proto"], @@ -64,22 +59,12 @@ cc_library( hdrs = [ "runtime.h", ], - local_defines = select({ - ":disable_xrt": ["DISABLE_XRT"], - "//conditions:default": [], - }), deps = [ ":computation_client", ":env_vars", ":pjrt_computation_client", "@org_tensorflow//tensorflow/tsl/platform:stacktrace", - ] + select({ - ":disable_xrt": [], - "//conditions:default": [ - ":xrt_computation_client", - ":xrt_local_service", - ], - }), + ], ) cc_library( @@ -140,75 +125,6 @@ cc_library( ], ) -cc_library( - name = "xrt_computation_client", - srcs = [ - "xrt_computation_client.cc", - ], - hdrs = [ - "xrt_computation_client.h", - ], - deps = [ - ":cache", - ":computation_client", - ":debug_macros", - ":env_vars", - ":mesh_service", - ":multi_wait", - ":sys_util", - ":thread_pool", - ":triggered_task", - ":types", - ":unique", - ":util", - ":xla_util", - ":xrt_local_service", - ":xrt_session", - ":xrt_session_cache", - "@org_tensorflow//tensorflow:grpc++", - "@org_tensorflow//tensorflow/cc:client_session", - "@org_tensorflow//tensorflow/cc:scope", - "@org_tensorflow//tensorflow/compiler/jit:xla_cpu_device", - "@org_tensorflow//tensorflow/compiler/xla:debug_options_flags", - 
"@org_tensorflow//tensorflow/compiler/xla:literal", - "@org_tensorflow//tensorflow/compiler/xla:literal_util", - "@org_tensorflow//tensorflow/compiler/xla:shape_util", - "@org_tensorflow//tensorflow/compiler/xla:xla_proto_cc", - "@org_tensorflow//tensorflow/compiler/xla/client", - "@org_tensorflow//tensorflow/compiler/xla/client:global_data", - "@org_tensorflow//tensorflow/compiler/xla/client:xla_computation", - "@org_tensorflow//tensorflow/compiler/xla/rpc:grpc_stub", - "@org_tensorflow//tensorflow/compiler/xla/service:cpu_plugin", - "@org_tensorflow//tensorflow/compiler/xla/service:platform_util", - "@org_tensorflow//tensorflow/compiler/xla:statusor", - "@org_tensorflow//tensorflow/compiler/xla:xla_data_proto_cc", - "@org_tensorflow//tensorflow/compiler/xla/service:hlo_proto_cc", - "@org_tensorflow//tensorflow/compiler/xla/hlo/ir:hlo", - "@org_tensorflow//tensorflow/compiler/xla/service/spmd:spmd_partitioner", - "@org_tensorflow//tensorflow/compiler/xrt:xrt_proto_cc", - "@org_tensorflow//tensorflow/compiler/xrt:xrt_server", - "@org_tensorflow//tensorflow/compiler/xrt:xrt_utils", - "@org_tensorflow//tensorflow/compiler/xrt/cc:xrt_ops", - "@org_tensorflow//tensorflow/core:core_cpu", - "@org_tensorflow//tensorflow/core:framework_internal", - "@org_tensorflow//tensorflow/core:lib", - "@org_tensorflow//tensorflow/core/distributed_runtime:server_lib", - "@org_tensorflow//tensorflow/core/distributed_runtime/rpc:grpc_runtime", - "@org_tensorflow//tensorflow/core/protobuf/tpu:topology_proto_cc", - "@org_tensorflow//tensorflow/compiler/xla/stream_executor:stream_executor_impl", - "@com_google_absl//absl/numeric:int128", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/types:optional", - ] + if_cuda_is_configured([ - "@org_tensorflow//tensorflow/compiler/jit:xla_gpu_device", - "@org_tensorflow//tensorflow/compiler/xla/stream_executor:cuda_platform", - ]) + if_with_tpu_support([ - "@org_tensorflow//tensorflow/compiler/jit:xla_tpu_device", - "@org_tensorflow//tensorflow/compiler/jit:xla_tpu_jit", - ]), -) - cc_library( name = "cache", hdrs = ["cache.h"], @@ -395,12 +311,6 @@ cc_library( ], ) -cc_library( - name = "triggered_task", - srcs = ["triggered_task.cc"], - hdrs = ["triggered_task.h"], -) - cc_library( name = "types", hdrs = ["types.h"], @@ -456,7 +366,6 @@ cc_library( ":tf_logging", ":types", ":util", - ":xrt_session", "@com_google_absl//absl/types:span", "@org_tensorflow//tensorflow/compiler/xla:shape_util", "@org_tensorflow//tensorflow/compiler/xla:status_macros", @@ -486,51 +395,6 @@ cc_test( ], ) -cc_library( - name = "xrt_local_service", - srcs = ["xrt_local_service.cc"], - hdrs = ["xrt_local_service.h"], - deps = [ - ":debug_macros", - ":xrt_session", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@org_tensorflow//tensorflow/compiler/xla:types", - "@org_tensorflow//tensorflow/compiler/xla/stream_executor/tpu:tpu_initializer_helper", - "@org_tensorflow//tensorflow/core:lib", - "@org_tensorflow//tensorflow/core/distributed_runtime:server_lib", - "@org_tensorflow//tensorflow/tsl/platform:errors", - "@org_tensorflow//tensorflow/tsl/platform:status", - ], -) - -cc_library( - name = "xrt_session_cache", - srcs = ["xrt_session_cache.cc"], - hdrs = ["xrt_session_cache.h"], - deps = [ - ":metrics", - ":sys_util", - ":xrt_session", - "@org_tensorflow//tensorflow/compiler/xla:types", - ], -) - -cc_library( - name = "xrt_session", - srcs = ["xrt_session.cc"], - hdrs = ["xrt_session.h"], - deps = [ - ":debug_macros", - 
"@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@org_tensorflow//tensorflow/cc:cc_ops", - "@org_tensorflow//tensorflow/cc:client_session", - "@org_tensorflow//tensorflow/cc:scope", - "@org_tensorflow//tensorflow/compiler/xla:types", - ], -) - cc_binary( name = "libxla_computation_client.so", linkopts = select({ diff --git a/torch_xla/csrc/runtime/runtime.cc b/torch_xla/csrc/runtime/runtime.cc index 82fd91068f6c..b7fd14a7c72c 100644 --- a/torch_xla/csrc/runtime/runtime.cc +++ b/torch_xla/csrc/runtime/runtime.cc @@ -3,11 +3,6 @@ #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/pjrt_computation_client.h" -#ifndef DISABLE_XRT -#include "torch_xla/csrc/runtime/xrt_computation_client.h" -#include "torch_xla/csrc/runtime/xrt_local_service.h" -#endif - namespace torch_xla { namespace runtime { namespace { @@ -25,11 +20,7 @@ ComputationClient* CreateClient() { if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") { client = new PjRtComputationClient(); } else { -#ifndef DISABLE_XRT - client = new XrtComputationClient(); -#else XLA_ERROR() << "$PJRT_DEVICE is not set." << std::endl; -#endif } XLA_CHECK(client != nullptr); @@ -49,26 +40,5 @@ ComputationClient* GetComputationClientIfInitialized() { return g_computation_client.load(); } -void RunLocalService(uint64_t service_port) { -#ifndef DISABLE_XRT - try { - XrtLocalService* service = new XrtLocalService( - "localservice|localhost:" + std::to_string(service_port), - "localservice", 0); - service->Start(); - service->Join(); - } catch (const std::runtime_error& error) { - if (std::string(error.what()).find("Couldn't open device: /dev/accel0") != - std::string::npos) { - TF_LOG(INFO) << "Local service has been created by other process, return"; - } else { - throw; - } - } -#else - XLA_ERROR() << "PyTorch/XLA was not built with XRT support." << std::endl; -#endif -} - } // namespace runtime } // namespace torch_xla diff --git a/torch_xla/csrc/runtime/triggered_task.cc b/torch_xla/csrc/runtime/triggered_task.cc deleted file mode 100644 index 1e17b7ce4f80..000000000000 --- a/torch_xla/csrc/runtime/triggered_task.cc +++ /dev/null @@ -1,74 +0,0 @@ -#include "torch_xla/csrc/runtime/triggered_task.h" - -namespace torch_xla { -namespace runtime { -namespace util { - -TriggeredTask::TriggeredTask(std::function function, size_t num_threads) - : function_(std::move(function)), running_(num_threads) { - // We set running_ to num_threads because until the threads reach the - // condition wait point (the cv_.wait() call) in the Runner() function, they - // are effectively running. 
- for (size_t i = 0; i < num_threads; ++i) { - threads_.emplace_back(new std::thread([this]() { Runner(); })); - } -} - -void TriggeredTask::Stop() { - { - std::lock_guard lock(mutex_); - stopped_ = true; - } - run_cv_.notify_all(); - cv_.notify_all(); - for (auto& thread : threads_) { - thread->join(); - } -} - -size_t TriggeredTask::Activate() { - bool notify = false; - size_t run_id; - { - std::lock_guard lock(mutex_); - notify = !activated_; - activated_ = true; - run_id = run_id_ + running_; - } - if (notify) { - cv_.notify_one(); - } - return run_id; -} - -size_t TriggeredTask::WaitForRun(size_t run_id) { - std::unique_lock lock(mutex_); - ++run_waiters_; - run_cv_.wait(lock, [this, run_id] { return run_id_ > run_id || stopped_; }); - --run_waiters_; - return run_id_; -} - -void TriggeredTask::Runner() { - while (true) { - { - std::unique_lock lock(mutex_); - ++run_id_; - if (run_waiters_ > 0) { - run_cv_.notify_all(); - } - --running_; - cv_.wait(lock, [this] { return activated_ || stopped_; }); - if (stopped_) { - break; - } - ++running_; - activated_ = false; - } - function_(); - } -} - -} // namespace util -} // namespace runtime -} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/triggered_task.h b/torch_xla/csrc/runtime/triggered_task.h deleted file mode 100644 index bde9c1605fda..000000000000 --- a/torch_xla/csrc/runtime/triggered_task.h +++ /dev/null @@ -1,57 +0,0 @@ -#ifndef XLA_CLIENT_TRIGGERED_TASK_H_ -#define XLA_CLIENT_TRIGGERED_TASK_H_ - -#include -#include -#include -#include -#include -#include - -namespace torch_xla { -namespace runtime { -namespace util { - -// Wraps a function which should be run many times upon user activations. -class TriggeredTask { - public: - // Note that if num_threads > 1, the function will be run concurrently from - // multiple threads, so it will have to be thread safe. This condition does - // not apply if num_threads is 1. - TriggeredTask(std::function function, size_t num_threads); - - // Stops the background thread and waits for it to complete. - void Stop(); - - // Triggers a function run. If the function is already running, it will run - // again immediately after it completes. Returns tthe value of thte run-ID the - // caller should eventually wait with the WaitForRun() API, to be sure that a - // full function run happened after its Activate() call. - size_t Activate(); - - // Wait until a run-ID returned by the Activate() API completed. Returns the - // value of the current run-ID. If such value or less or equal to run_id, the - // wait did not complete successfully. - size_t WaitForRun(size_t run_id); - - private: - // Function implementing the main thread loop running the user function. 
- void Runner(); - - std::function function_; - std::mutex mutex_; - std::condition_variable cv_; - std::condition_variable run_cv_; - size_t run_id_ = 0; - size_t run_waiters_ = 0; - size_t running_ = 0; - bool activated_ = false; - bool stopped_ = false; - std::vector> threads_; -}; - -} // namespace util -} // namespace runtime -} // namespace torch_xla - -#endif // XLA_CLIENT_TRIGGERED_TASK_H_ diff --git a/torch_xla/csrc/runtime/xrt_computation_client.cc b/torch_xla/csrc/runtime/xrt_computation_client.cc deleted file mode 100644 index ade446a9f39a..000000000000 --- a/torch_xla/csrc/runtime/xrt_computation_client.cc +++ /dev/null @@ -1,2302 +0,0 @@ -#include "torch_xla/csrc/runtime/xrt_computation_client.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/strings/str_split.h" -#include "tensorflow/cc/ops/const_op.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xrt/xrt_util.h" -#include "tensorflow/tsl/framework/allocator.h" -#include "tensorflow/tsl/lib/math/math_util.h" -#include "tensorflow/tsl/platform/net.h" -#include "tensorflow/tsl/profiler/lib/traceme.h" -#include "tensorflow/tsl/util/device_name_utils.h" -#include "torch_xla/csrc/runtime/env_vars.h" -#include "torch_xla/csrc/runtime/multi_wait.h" -#include "torch_xla/csrc/runtime/sys_util.h" -#include "torch_xla/csrc/runtime/thread_pool.h" -#include "torch_xla/csrc/runtime/unique.h" -#include "torch_xla/csrc/runtime/util.h" -#include "torch_xla/csrc/runtime/xla_util.h" - -namespace torch_xla { -namespace runtime { -namespace { - -static const char* const kLocalService = "localservice"; - -// A simple Tensorflow Allocator which caches Tensor allocations in order to -// avoid paying the kernel's clear_page_c() price. -class TensorAllocator : public tensorflow::Allocator { - struct AllocKey { - struct Hash { - size_t operator()(const AllocKey& hk) const { - return util::StdHashCombine(hk.alignment, hk.num_bytes); - } - }; - - bool operator==(const AllocKey& rhs) const { - return num_bytes == rhs.num_bytes && alignment == rhs.alignment; - } - - size_t alignment = 0; - size_t num_bytes = 0; - }; - - struct AllocBlocks { - AllocBlocks(const AllocKey& alloc_key) : alloc_key(alloc_key) {} - - AllocKey alloc_key; - std::vector blocks; - }; - - using AllocList = std::list; - - public: - static TensorAllocator* Get() { - static size_t max_size = - sys_util::GetEnvInt("XLA_TENSOR_ALLOCATOR_MAXSIZE", 1000000000); - static TensorAllocator* allocator = new TensorAllocator(max_size); - return allocator; - } - - std::string Name() override { return "XLA_TensorAllocator"; } - - void* AllocateRaw(size_t alignment, size_t num_bytes) override { - // We use an alignment-sized area before the memory returned to the caller, - // to store a pointer to its AllocBlocks. - alignment = std::max(alignment, sizeof(void*)); - // To call aligned_alloc(), num_bytes must be multiple of alignment. 
- num_bytes = tsl::MathUtil::CeilOfRatio(num_bytes, alignment) * alignment; - - AllocKey alloc_key = {alignment, num_bytes}; - void* block = nullptr; - AllocBlocks* alloc_blocks = nullptr; - std::lock_guard lock(lock_); - auto it = allocs_.find(alloc_key); - if (it != allocs_.end()) { - alloc_blocks = &*it->second; - if (!alloc_blocks->blocks.empty()) { - block = alloc_blocks->blocks.back(); - alloc_blocks->blocks.pop_back(); - } - // LRU - alloc_list_.splice(alloc_list_.begin(), alloc_list_, it->second); - } else { - allocs_.emplace(alloc_key, - alloc_list_.insert(alloc_list_.begin(), alloc_key)); - alloc_blocks = &alloc_list_.front(); - } - if (block == nullptr) { - TrimCache(alloc_key.num_bytes); - block = NewBlock(alloc_blocks); - } - return block; - } - - void DeallocateRaw(void* ptr) override { - if (ptr != nullptr) { - // The pointer to AllocBlocks is right before the user memory. - AllocBlocks* alloc_blocks = reinterpret_cast(ptr)[-1]; - std::lock_guard lock(lock_); - if (alloc_blocks->alloc_key.num_bytes < max_size_) { - alloc_blocks->blocks.push_back(ptr); - } else { - // We do not cache blocks whose size is bigger than the max cache size. - FreeBlock(ptr, alloc_blocks); - } - } - } - - private: - explicit TensorAllocator(size_t max_size) : max_size_(max_size) {} - - void* NewBlock(AllocBlocks* alloc_blocks) { - // We allocate an extra alignment sized area to store the AllocBlocks - // pointer. - void* ptr = ::aligned_alloc( - alloc_blocks->alloc_key.alignment, - alloc_blocks->alloc_key.alignment + alloc_blocks->alloc_key.num_bytes); - XLA_CHECK(ptr != nullptr); - ptr = reinterpret_cast(ptr) + alloc_blocks->alloc_key.alignment; - // Store the pointer to AllocBlocks right before the user memory. - reinterpret_cast(ptr)[-1] = alloc_blocks; - size_ += alloc_blocks->alloc_key.num_bytes; - return ptr; - } - - void FreeBlock(void* ptr, AllocBlocks* alloc_blocks) { - size_ -= alloc_blocks->alloc_key.num_bytes; - std::free(reinterpret_cast(ptr) - alloc_blocks->alloc_key.alignment); - } - - void TrimCache(size_t num_bytes) { - auto it = alloc_list_.rbegin(); - for (; size_ + num_bytes > max_size_ && it != alloc_list_.rend(); ++it) { - AllocBlocks* alloc_blocks = &*it; - while (!alloc_blocks->blocks.empty() && size_ + num_bytes > max_size_) { - FreeBlock(alloc_blocks->blocks.back(), alloc_blocks); - alloc_blocks->blocks.pop_back(); - } - } - } - - size_t max_size_ = 0; - std::mutex lock_; - size_t size_ = 0; - AllocList alloc_list_; - std::unordered_map allocs_; -}; - -bool ShouldStartLocalService(const std::set& devices) { - // In the tpuvm pod setup, LocalService will be started in a separate process - bool tpuvm_mode = sys_util::GetEnvBool(env::kEnvTpuvmMode, false); - int shard_ordinal = sys_util::GetEnvInt(env::kEnvShardOrdinal, -1); - int world_size = sys_util::GetEnvInt(env::kEnvWorldSize, -1); - if (tpuvm_mode && (shard_ordinal % 8 == 0) && (world_size > 8)) { - return false; - } - if (!sys_util::GetEnvBool(env::kEnvStartService, true)) { - return false; - } - // Only the process with CPU device and GPU device should start the local - // service. - for (const std::string& device : devices) { - if (device.find("CPU", 0) == 0 || device.find("GPU", 0) == 0) { - return true; - } - } - return false; -} - -std::string StripPrefix(const std::string& value, const std::string& prefix) { - return value.find(prefix) == 0 ? 
value.substr(prefix.size()) : value; -} - -tensorflow::DeviceNameUtils::ParsedName ParseFullXrtDevice( - const std::string& device) { - tensorflow::DeviceNameUtils::ParsedName parsed_device; - XLA_CHECK(tensorflow::DeviceNameUtils::ParseFullName(device, &parsed_device)) - << device; - XLA_CHECK(parsed_device.has_job && parsed_device.has_task && - parsed_device.has_id && parsed_device.has_type) - << "Missing device information (" << device << "): " - << tensorflow::DeviceNameUtils::ParsedNameToString(parsed_device); - return parsed_device; -} - -void MaybeSaveLongCompileHlo(double compile_time, - const xla::XlaComputation& computation) { - static double compile_time_threshold = sys_util::GetEnvDouble( - "XLA_COMPILE_TIME_THRESHOLD", std::numeric_limits::max()); - static const std::string* hlo_folder = new std::string( - sys_util::GetEnvString("XLA_SLOW_COMPILE_HLO_FOLDER", "")); - if (compile_time > compile_time_threshold && !hlo_folder->empty()) { - static std::atomic hlo_count(0); - std::stringstream ss; - ss << *hlo_folder << "/hlo_module-" << hlo_count.fetch_add(1) << "-" - << static_cast(compile_time) << "s.txt"; - std::string hlo_text = - ConsumeValue(util::GetComputationHloText(computation)); - std::ofstream graph_file(ss.str()); - graph_file << hlo_text << "\n"; - } -} - -template -T ParseProto(const tensorflow::Tensor& tensor) { - const tensorflow::tstring& tensor_data = - tensor.scalar()(); - // The ParseFromArray() API takes an 'int' as size argument, so the tensor - // size better be fitting the 'int' domain. - XLA_CHECK_LE(tensor_data.size(), - static_cast(std::numeric_limits::max())); - T proto; - XLA_CHECK(proto.ParseFromArray(tensor_data.data(), tensor_data.size())); - return proto; -} - -int64_t GetMaxTensorsPartitionSize() { - // We need to limit the amount of data we send to the XRT backend since - // Protocol Buffers does not allow sizes greater than 2GB. We keep some margin - // to avoid extra metadata pushing us over the limit. - static int64_t max_partition_size = - sys_util::GetEnvInt("XRT_MAX_TENSORS_PARTITION", 1800000000); - return max_partition_size; -} - -struct DeviceCountDefaults { - int num_tpus = 0; - int num_gpus = 0; - int num_cpus = 1; -}; - -std::string MakeGrpcEndPoint(const std::string& server) { - return server.compare(0, 7, "grpc://") == 0 ? 
server - : absl::StrCat("grpc://", server); -} - -std::string GetXrtDevicePath(const std::string& worker, int task_no, - const std::string& device_type, int ordinal) { - return absl::StrCat("/job:", worker, "/replica:0/task:", task_no, - "/device:", device_type, ":", ordinal); -} - -std::string BuildTaskDeviceKey(int task_no, const std::string& kind) { - return absl::StrCat(task_no, ":", kind); -} - -tensorflow::DeviceNameUtils::ParsedName ParseXrtDevice( - const std::string& device) { - tensorflow::DeviceNameUtils::ParsedName parsed_device; - XLA_CHECK( - tensorflow::DeviceNameUtils::ParseFullName(device, &parsed_device) && - parsed_device.has_job && parsed_device.has_task && parsed_device.has_id && - parsed_device.has_type) - << device; - return parsed_device; -} - -bool IsLocalDevice(const XrtComputationClient::Worker& worker, - const tensorflow::DeviceNameUtils::ParsedName& parsed_device, - const std::map& dev_task_map) { - if (worker.name != parsed_device.job || - worker.task_no != parsed_device.task) { - return false; - } - std::string mp_device = XrtComputationClient::GetMultiProcessingDevice(); - if (mp_device.empty()) { - return true; - } - XrtComputationClient::Device device(mp_device); - std::string task_device_key = - BuildTaskDeviceKey(parsed_device.task, device.kind); - auto it = dev_task_map.find(task_device_key); - return it != dev_task_map.end() - ? (device.ordinal == it->second + parsed_device.id) - : false; -} - -std::map BuildDeviceTaskMap( - const XrtComputationClient::Options& options) { - // Builds a map from "TASK:DEV_KIND" (ie, "0:TPU") keys to the minimum global - // device ordinal assigned for that task+devkind couple. - std::map dev_task_map; - for (auto& device_xrt_device : options.global_device_map) { - XrtComputationClient::Device global_device(device_xrt_device.first); - tensorflow::DeviceNameUtils::ParsedName parsed_device = - ParseXrtDevice(device_xrt_device.second); - std::string task_device_key = - BuildTaskDeviceKey(parsed_device.task, global_device.kind); - util::InsertCombined(&dev_task_map, task_device_key, global_device.ordinal, - [](int a, int b) { return std::min(a, b); }); - } - return dev_task_map; -} - -void PopulateLocalDevices(XrtComputationClient::Options* options) { - std::string local_worker = sys_util::GetEnvString(env::kEnvLocalWorker, ""); - XrtComputationClient::Worker worker("", -1); - if (!local_worker.empty()) { - worker = XrtComputationClient::ParseWorker(local_worker); - } - auto dev_task_map = BuildDeviceTaskMap(*options); - std::map min_ordinals; - for (auto& device_xrt_device : options->global_device_map) { - if (worker.task_no >= 0) { - tensorflow::DeviceNameUtils::ParsedName parsed_device = - ParseXrtDevice(device_xrt_device.second); - if (!IsLocalDevice(worker, parsed_device, dev_task_map)) { - continue; - } - } - options->devices.insert(device_xrt_device.first); - - XrtComputationClient::Device global_device(device_xrt_device.first); - util::InsertCombined(&min_ordinals, global_device.kind, - global_device.ordinal, - [](int a, int b) { return std::min(a, b); }); - } - for (auto kind : {"TPU", "GPU", "CPU"}) { - auto it = min_ordinals.find(kind); - if (it != min_ordinals.end()) { - options->default_device = absl::StrCat(kind, ":", it->second); - break; - } - } -} - -void AddXrtHostDevices(const std::string& worker_name, int task_no, - const std::string& server, - const DeviceCountDefaults& device_counts, - std::map* device_ordinals, - XrtComputationClient::Options* options) { - struct Devices { - const char* name; - const 
char* tf_name; - int64_t count; - } const devices[] = { - {"TPU", "TPU", - sys_util::GetEnvInt(env::kEnvNumTpu, device_counts.num_tpus)}, - {"GPU", "XLA_GPU", - sys_util::GetEnvInt(env::kEnvNumGpu, device_counts.num_gpus)}, - {"CPU", "XLA_CPU", device_counts.num_cpus}, - }; - options->workers_map.emplace( - XrtComputationClient::Worker(worker_name, task_no), - MakeGrpcEndPoint(server)); - for (auto& device : devices) { - int& device_ordinal = (*device_ordinals)[device.name]; - for (int j = 0; j < device.count; ++j, ++device_ordinal) { - std::string device_name = absl::StrCat(device.name, ":", device_ordinal); - std::string xrt_device_name = - GetXrtDevicePath(worker_name, task_no, device.tf_name, j); - options->global_device_map.emplace(device_name, xrt_device_name); - } - } -} - -bool ParseEnvBasedTpuClusterConfig(XrtComputationClient::Options* options) { - std::string tpu_config = sys_util::GetEnvString(env::kEnvTpuConfig, ""); - if (tpu_config.empty()) { - return false; - } - std::map device_ordinals; - std::vector spec_parts = absl::StrSplit(tpu_config, '|'); - XLA_CHECK(!spec_parts.empty()) << tpu_config; - DeviceCountDefaults device_counts; - device_counts.num_tpus = 8; - for (const auto& spec : spec_parts) { - std::vector host_parts = absl::StrSplit(spec, ';'); - XLA_CHECK_EQ(host_parts.size(), 3) << spec; - AddXrtHostDevices(host_parts[0], std::stoi(host_parts[1]), host_parts[2], - device_counts, &device_ordinals, options); - } - return true; -} - -bool ParseMeshConfig( - XrtComputationClient::Options* options, - std::unique_ptr* topology_proto) { - service::MeshClient* client = service::MeshClient::Get(); - if (client == nullptr) { - return false; - } - std::string local_worker_env = - sys_util::GetEnvString(env::kEnvLocalWorker, ""); - XLA_CHECK(!local_worker_env.empty()) - << "In a mesh client setup the XRT_LOCAL_WORKER must be specified"; - - XrtComputationClient::Worker local_worker = - XrtComputationClient::ParseWorker(local_worker_env); - int host_ordinal = sys_util::GetEnvInt(env::kEnvHostOrdinal, 0); - - TF_LOG(INFO) << "Fetching mesh configuration for worker " << local_worker.name - << " (host_ordinal=" << host_ordinal - << "):" << local_worker.task_no << " from mesh service at " - << client->address(); - service::grpc::Config config = client->GetConfig(host_ordinal); - TF_VLOG(3) << "Mesh Config: " << config.DebugString(); - - std::string mp_device = XrtComputationClient::GetMultiProcessingDevice(); - for (auto& config_worker : config.workers()) { - XrtComputationClient::Worker worker(config_worker.name(), - config_worker.task_no()); - options->workers_map.emplace(worker, config_worker.address()); - - for (auto& device : config_worker.devices()) { - XrtComputationClient::Device local_device(device.local_name()); - options->global_device_map.emplace( - device.global_name(), - GetXrtDevicePath(worker.name, worker.task_no, local_device.kind, - local_device.ordinal)); - if (local_worker == worker && - (mp_device.empty() || device.global_name() == mp_device)) { - options->devices.insert(device.global_name()); - } - } - } - (*topology_proto) = absl::make_unique( - std::move(*config.mutable_proto())); - return true; -} - -bool ParseEnvDeviceCounts(XrtComputationClient::Options* options) { - DeviceCountDefaults device_counts; - device_counts.num_tpus = sys_util::GetEnvInt(env::kEnvNumTpu, 0); - device_counts.num_gpus = sys_util::GetEnvInt(env::kEnvNumGpu, 0); - if (device_counts.num_tpus > 0 || device_counts.num_gpus > 0) { - std::map device_ordinals; - std::string host_port = - 
absl::StrCat("localhost:", tsl::internal::PickUnusedPortOrDie()); - AddXrtHostDevices("localservice", 0, host_port, device_counts, - &device_ordinals, options); - } - return !options->global_device_map.empty(); -} - -bool ParseEnvDevices(XrtComputationClient::Options* options) { - std::string device_spec = sys_util::GetEnvString(env::kEnvDeviceMap, ""); - std::string workers_spec = sys_util::GetEnvString(env::kEnvWorkers, ""); - if (!device_spec.empty() && !workers_spec.empty()) { - for (const auto& device_target : absl::StrSplit(device_spec, '|')) { - std::vector parts = absl::StrSplit(device_target, ';'); - XLA_CHECK_EQ(parts.size(), 2) << device_target; - options->global_device_map.emplace(parts[0], parts[1]); - } - for (const auto& name_target : absl::StrSplit(workers_spec, '|')) { - std::vector parts = absl::StrSplit(name_target, ';'); - XLA_CHECK_EQ(parts.size(), 2) << name_target; - options->workers_map.emplace(XrtComputationClient::ParseWorker(parts[0]), - MakeGrpcEndPoint(parts[1])); - } - } - return !options->global_device_map.empty(); -} - -} // namespace - -const int64_t DataHandleLocker::dummy_handle = -151235; - -XrtComputationClient::Device::Device(const std::string& device_str) { - std::vector parts = absl::StrSplit(device_str, ':'); - XLA_CHECK_EQ(parts.size(), 2) << device_str; - kind = std::move(parts[0]); - ordinal = std::stoi(parts[1]); -} - -void XrtComputationClient::XrtData::Assign(const Data& data) { - const XrtData& xrt_data = dynamic_cast(data); - if (&xrt_data != this) { - handle_ptr = xrt_data.handle_ptr; - } -} - -XrtComputationClient::XrtComputationClient() - : compilation_cache_(sys_util::GetEnvInt("XLA_COMPILATION_CACHE_SIZE", 64)), - rng_seed_(0x5a2d296e9) { - // XrtComputationClient::Options options; - std::unique_ptr topology_proto; - if (!ParseEnvBasedTpuClusterConfig(&options_) && - !ParseEnvDeviceCounts(&options_) && !ParseEnvDevices(&options_) && - !ParseMeshConfig(&options_, &topology_proto)) { - XLA_ERROR() << "Missing XLA configuration"; - } - PopulateLocalDevices(&options_); - - tensorflow::ConfigProto config = CreateConfigProto(options_); - std::string local_target = GetLocalTarget(options_); - session_cache_ = absl::make_unique( - config, [this](XrtSession* s) { InitSession(s); }, local_target); - alloc_session_cache_ = - absl::make_unique(config, nullptr, local_target); - - auto default_device_target = - options_.global_device_map.find(options_.default_device); - XLA_CHECK(default_device_target != options_.global_device_map.end()) - << options_.default_device; - for (auto& device : options_.devices) { - XLA_CHECK(options_.global_device_map.find(device) != - options_.global_device_map.end()) - << "Missing device in global map: " << device; - } - for (const auto& dev_target : options_.global_device_map) { - const char* tag = - options_.devices.count(dev_target.first) > 0 ? 
"LOCAL" : "REMOTE"; - TF_VLOG(1) << "XRT device (" << tag << ") " << dev_target.first << " -> " - << dev_target.second; - } - for (auto& worker_target : options_.workers_map) { - TF_VLOG(1) << "Worker " << worker_target.second - << " for /job:" << worker_target.first.name - << "/replica:0/task:" << worker_target.first.task_no; - } - - TF_VLOG(1) << "XRT default device: " << options_.default_device; - if (ShouldStartLocalService(options_.devices)) { - MaybeCreateLocalService(options_); - } - InitializeDevices(std::move(topology_proto)); - StartHandleReleaser(); -} - -ComputationClient::DataPtr XrtComputationClient::CreateDataPlaceholder( - std::string device, xla::Shape shape) { - return std::make_shared(std::move(device), std::move(shape)); -} - -std::vector -XrtComputationClient::CreateAsyncDatas(absl::Span tensors) { - std::vector results( - tensors.size()); - for (size_t i = 0; i < tensors.size(); ++i) { - // Create a XrtHandle with dummy handle, releasr needs to take the - // real handle upon destructon. - XrtHandlePtr handle_ptr = std::make_shared( - DataHandleLocker::dummy_handle, - [this, device = tensors[i].device](int64_t handle) { - this->ReleaseXrtData(device, handle); - }, - /*async=*/true); - results[i] = std::make_shared(this, tensors[i].device, - tensors[i].shape, handle_ptr); - } - CreateAsyncDataHandlesCounter()->AddValue(results.size()); - return results; -} - -std::vector -XrtComputationClient::LockAsyncDatas( - absl::Span datas) { - std::vector unlcoker; - unlcoker.reserve(datas.size()); - for (int i = 0; i < datas.size(); i++) { - unlcoker.emplace_back( - dynamic_cast(*datas[i]).handle_ptr->LockHandle()); - } - return unlcoker; -} - -std::vector XrtComputationClient::PartitionTransferToServer( - absl::Span tensors) { - int64_t max_partition_size = GetMaxTensorsPartitionSize(); - uint64_t current_size = 0; - std::vector partitions; - for (size_t i = 0; i < tensors.size(); ++i) { - int64_t tensor_size = xla::ShapeUtil::ByteSizeOfElements(tensors[i].shape); - if (current_size + tensor_size > max_partition_size) { - if (partitions.empty() && i > 0) { - partitions.push_back(0); - } - partitions.push_back(i); - current_size = 0; - } - current_size += tensor_size; - } - if (partitions.empty()) { - partitions.push_back(0); - } - return partitions; -} - -std::vector XrtComputationClient::TransferToServer( - absl::Span tensors) { - return TransferToServerHelper(tensors, {}); -} - -void XrtComputationClient::TransferToServer( - absl::Span tensors, absl::Span datas) { - XLA_CHECK_EQ(tensors.size(), datas.size()); - TransferToServerHelper(tensors, datas); - return; -} - -std::vector -XrtComputationClient::TransferToServerHelper( - absl::Span tensors, absl::Span datas) { - auto partitions = PartitionTransferToServer(tensors); - if (partitions.size() == 1) { - // Fast path in case of single partition. Avoid creating threads and - // waiting, since this is the common case. - return TransferToServerInternal(tensors, datas); - } - XLA_COUNTER("XrtPartitionedTransferToServer", 1); - - auto mwait = std::make_shared(partitions.size()); - std::vector results(tensors.size()); - for (size_t i = 0; i < partitions.size(); ++i) { - auto sender = [&, i]() { - size_t base_index = partitions[i]; - size_t length = (i + 1 < partitions.size()) - ? partitions[i + 1] - base_index - : tensors.size() - base_index; - std::vector partitions_results; - // Only pass datas if it is not empty. 
- if (datas.size()) { - partitions_results = - TransferToServerInternal(tensors.subspan(base_index, length), - datas.subspan(base_index, length)); - } else { - partitions_results = - TransferToServerInternal(tensors.subspan(base_index, length), {}); - } - for (size_t r = 0; r < length; ++r) { - results[base_index + r] = std::move(partitions_results[r]); - } - }; - env::ScheduleIoClosure( - util::MultiWait::Completer(mwait, std::move(sender))); - } - mwait->Wait(); - return results; -} - -std::vector -XrtComputationClient::TransferToServerInternal( - absl::Span tensors, absl::Span datas) { - metrics::TimedSection timed(TransferToServerMetric()); - tsl::profiler::TraceMe activity("TransferToServerInternal", - tsl::profiler::TraceMeLevel::kInfo); - - // If datas are passed in, don't create new datas but modify the passed in - // datas. - bool create_new_data = (datas.size() == 0); - std::mutex lock; - XrtSessionCache::SessionMap session_map; - int64_t total_size = 0; - auto mwait = std::make_shared(tensors.size()); - std::map session_work_map; - { - tsl::profiler::TraceMe activity("TransferToServerTransform", - tsl::profiler::TraceMeLevel::kInfo); - metrics::TimedSection timed(TransferToServerTransformMetric()); - - for (size_t i = 0; i < tensors.size(); ++i) { - auto converter = [&, i]() { - const std::string& xrt_device = - TorchDeviceToXrtDevice(tensors[i].device); - tensorflow::Tensor tensor( - TensorAllocator::Get(), - XlaTypeToDataType(tensors[i].shape.element_type()), - MakeEquivalentTensorShape(tensors[i].shape)); - auto tdata = tensor.tensor_data(); - tensors[i].populate_fn(tensors[i], const_cast(tdata.data()), - tdata.size()); - - { - std::lock_guard slock(lock); - XrtSession* session = GetSessionForXrtDevice( - alloc_session_cache_.get(), xrt_device, &session_map); - SessionWork* session_work = &session_work_map[session]; - tensorflow::Scope device_scope = - session->root()->WithDevice(xrt_device); - const XrtSession::CachedNode& cached_node = GetAllocateNode( - session, device_scope, tensors[i].device, tensors[i].shape); - session_work->feed_inputs.insert({cached_node.holders[0], tensor}); - session_work->outputs_handles.push_back(cached_node.outputs[0]); - session_work->index_mapping.push_back(i); - - total_size += tdata.size(); - } - }; - env::ScheduleClosure( - util::MultiWait::Completer(mwait, std::move(converter))); - } - mwait->Wait(); - } - OutboundDataMetric()->AddSample(total_size); - - mwait->Reset(session_work_map.size()); - std::vector results; - if (create_new_data) { - results.resize(tensors.size()); - } - { - tsl::profiler::TraceMe activity( - [&] { - return tsl::profiler::TraceMeEncode( - "TransferToServerExecute", - {{"total_size", absl::StrCat(std::to_string(total_size), "B")}, - {"num_tensors", std::to_string(tensors.size())}}); - }, - tsl::profiler::TraceMeLevel::kInfo); - for (auto& session_session_work : session_work_map) { - XrtSession* session = session_session_work.first; - SessionWork* session_work = &session_session_work.second; - auto runner = [&, session, session_work]() { - std::vector outputs; - XLA_CHECK_OK(session->session()->Run(session_work->feed_inputs, - session_work->outputs_handles, - &outputs)); - XLA_CHECK_EQ(outputs.size(), session_work->outputs_handles.size()); - - for (size_t i = 0; i < outputs.size(); ++i) { - size_t li = session_work->index_mapping[i]; - if (create_new_data) { - results[li] = std::make_shared( - this, tensors[li].device, tensors[li].shape, - outputs[i].scalar()()); - } else { - dynamic_cast(*datas[li]) - 
.handle_ptr->update_handle(outputs[i].scalar()()); - } - } - CreateDataHandlesCounter()->AddValue(outputs.size()); - }; - env::ScheduleIoClosure( - util::MultiWait::Completer(mwait, std::move(runner))); - } - mwait->Wait(); - } - return results; -} - -std::vector XrtComputationClient::TransferFromServer( - absl::Span handles) { - metrics::TimedSection timed(TransferFromServerMetric()); - tsl::profiler::TraceMe activity("TransferFromServer", - tsl::profiler::TraceMeLevel::kInfo); - - int64_t max_partition_size = GetMaxTensorsPartitionSize(); - std::list session_maps; - int64_t current_size = 0; - session_maps.emplace_back(); - std::map session_work_map; - for (size_t i = 0; i < handles.size(); ++i) { - const XrtData& xrt_data = dynamic_cast(*handles[i]); - - int64_t shape_size = xla::ShapeUtil::ByteSizeOfElements(xrt_data.shape()); - if (current_size + shape_size >= max_partition_size) { - session_maps.emplace_back(); - current_size = 0; - } - current_size += shape_size; - - XrtSession* session = GetSessionForDevice( - session_cache_.get(), xrt_data.device(), &session_maps.back()); - SessionWork* session_work = &session_work_map[session]; - tensorflow::Scope device_scope = - session->root()->WithDevice(TorchDeviceToXrtDevice(xrt_data.device())); - const XrtSession::CachedNode& cached_node = - GetReadNode(session, device_scope, xrt_data.device()); - session_work->feed_inputs.insert( - {cached_node.holders[0], xrt_data.get_handle()}); - session_work->outputs_handles.push_back(cached_node.outputs[0]); - session_work->index_mapping.push_back(i); - } - - auto mwait = std::make_shared(session_work_map.size()); - std::atomic total_size(0); - std::atomic num_tensors(0); - std::vector results(handles.size()); - for (auto& session_session_work : session_work_map) { - XrtSession* session = session_session_work.first; - SessionWork* session_work = &session_session_work.second; - auto runner = [&, session, session_work]() { - std::vector outputs; - XLA_CHECK_OK(session->session()->Run( - session_work->feed_inputs, session_work->outputs_handles, &outputs)); - XLA_CHECK_EQ(outputs.size(), session_work->outputs_handles.size()); - num_tensors += outputs.size(); - - for (size_t i = 0; i < outputs.size(); ++i) { - size_t li = session_work->index_mapping[i]; - xla::LiteralProto response = ParseProto(outputs[i]); - results[li] = - std::move(xla::Literal::CreateFromProto(response).value()); - total_size += results[li].size_bytes(); - } - }; - env::ScheduleIoClosure( - util::MultiWait::Completer(mwait, std::move(runner))); - } - mwait->Wait(); - InboundDataMetric()->AddSample(total_size.load()); - activity.AppendMetadata([&total_size, &num_tensors]() { - return tsl::profiler::TraceMeEncode( - {{"total_size", absl::StrCat(std::to_string(total_size), "B")}, - {"num_tensors", std::to_string(num_tensors)}}); - }); - return results; -} - -std::vector XrtComputationClient::Compile( - std::vector instances) { - metrics::TimedSection timed(CompileMetric()); - tsl::profiler::TraceMe activity("Compile", - tsl::profiler::TraceMeLevel::kInfo); - - std::mutex lock; - auto mwait = std::make_shared(instances.size()); - std::vector program_shapes(instances.size()); - std::vector results(instances.size()); - std::vector cache_keys(instances.size()); - XrtSessionCache::SessionMap session_map; - std::map session_work_map; - for (size_t i = 0; i < instances.size(); ++i) { - auto builder = [&, this, i]() { - const CompileInstance& instance = instances[i]; - XLA_CHECK(!instance.is_sharded) - << "XrtComputationClient doesn't support 
SPMD."; - - std::unique_ptr xrt_computation = - CreateXrtComputation(instance.computation, instance.devices, - instance.output_shape); - CompilationCacheKey cache_key( - GetResourceDomain(instance.compilation_device), - xrt_computation->SerializeAsString()); - auto computation_ptr = compilation_cache_.Get(cache_key); - if (computation_ptr == nullptr) { - cache_keys[i] = std::move(cache_key); - program_shapes[i] = - xla::ProgramShape(xrt_computation->config().program_shape()); - - const std::string& xrt_device = - TorchDeviceToXrtDevice(instance.compilation_device); - { - std::lock_guard slock(lock); - XrtSession* session = GetSessionForXrtDevice( - session_cache_.get(), xrt_device, &session_map); - SessionWork* session_work = &session_work_map[session]; - tensorflow::Scope device_scope = - session->root()->WithDevice(xrt_device); - const XrtSession::CachedNode& cached_node = GetCompileNode( - session, device_scope, instance.compilation_device); - session_work->feed_inputs.insert( - {cached_node.holders[0], cache_keys[i].serialized_computation}); - session_work->outputs_handles.push_back(cached_node.outputs[0]); - session_work->index_mapping.push_back(i); - } - } else { - results[i] = computation_ptr; - } - }; - env::ScheduleClosure(util::MultiWait::Completer(mwait, std::move(builder))); - } - mwait->Wait(); - mwait->Reset(session_work_map.size()); - - for (auto& session_and_work : session_work_map) { - XrtSession* session = session_and_work.first; - const SessionWork& session_work = session_and_work.second; - - auto session_runner = [&, this, session]() { - std::vector outputs; - CheckCompileStatus( - session->session()->Run(session_work.feed_inputs, - session_work.outputs_handles, &outputs), - instances, session_work); - XLA_CHECK_EQ(outputs.size(), session_work.outputs_handles.size()); - - double compile_time = timed.Elapsed(); - size_t output_index = 0; - for (auto li : session_work.index_mapping) { - CompileInstance* instance = &instances[li]; - MaybeSaveLongCompileHlo(compile_time, instance->computation); - results[li] = std::make_shared( - this, std::move(instance->computation), program_shapes[li], - std::move(instance->devices), - outputs[output_index].scalar()(), - instance->compilation_device); - ++output_index; - - compilation_cache_.Add(std::move(cache_keys[li]), results[li]); - CreateCompileHandlesCounter()->AddValue(1); - } - }; - env::ScheduleIoClosure( - util::MultiWait::Completer(mwait, std::move(session_runner))); - } - mwait->Wait(); - return results; -} - -void XrtComputationClient::CheckCompileStatus( - const xla::Status& status, const std::vector& instances, - const SessionWork& session_work) { - if (!status.ok()) { - std::vector computations; - std::vector output_shapes; - for (auto li : session_work.index_mapping) { - computations.push_back(&instances[li].computation); - output_shapes.push_back(instances[li].output_shape); - } - util::ReportComputationError(status, computations, output_shapes); - } -} - -std::vector -XrtComputationClient::ExecuteComputation( - const Computation& computation, absl::Span arguments, - const std::string& device, const ExecuteComputationOptions& options) { - metrics::TimedSection timed(ExecuteMetric()); - tsl::profiler::TraceMe activity("ExecuteComputation", - tsl::profiler::TraceMeLevel::kInfo); - - XrtSessionCache::SessionMap session_map; - tensorflow::ClientSession::FeedType feed_inputs; - std::vector exec_ops = CreateExecuteOps( - &session_map, dynamic_cast(computation), - BuildParallelArguments(arguments), options.explode_tuple, 
{device},
-      &feed_inputs);
-
-  XrtSession* session =
-      GetSessionForDevice(session_cache_.get(), device, &session_map);
-  std::vector<tensorflow::Tensor> outputs;
-  util::CheckComputationStatus(
-      session->session()->Run(feed_inputs, {exec_ops.front()}, &outputs),
-      {&computation.computation()}, {&computation.program_shape().result()});
-  XLA_CHECK_EQ(outputs.size(), 1);
-
-  return GetComputationResults(outputs[0], computation.program_shape().result(),
-                               device);
-}
-
-std::vector<std::vector<ComputationClient::DataPtr>>
-XrtComputationClient::ExecuteReplicated(
-    const Computation& computation,
-    const std::vector<std::vector<DataPtr>>& arguments,
-    absl::Span<const std::string> devices,
-    const ExecuteReplicatedOptions& options) {
-  metrics::TimedSection timed(ExecuteReplicatedMetric());
-  tsl::profiler::TraceMe activity("ExecuteReplicated",
-                                  tsl::profiler::TraceMeLevel::kInfo);
-
-  XrtSessionCache::SessionMap session_map;
-  tensorflow::ClientSession::FeedType feed_inputs;
-  std::vector<tensorflow::Output> exec_ops = CreateExecuteOps(
-      &session_map, dynamic_cast<const XrtComputation&>(computation), arguments,
-      options.explode_tuple, devices, &feed_inputs);
-  std::vector<const Computation*> computations(devices.size());
-  std::fill(computations.begin(), computations.end(), &computation);
-
-  return RunComputations(session_map, exec_ops, computations, devices,
-                         feed_inputs);
-}
-
-std::vector<std::vector<ComputationClient::DataPtr>>
-XrtComputationClient::RunComputations(
-    const XrtSessionCache::SessionMap& session_map,
-    const std::vector<tensorflow::Output>& exec_ops,
-    absl::Span<const Computation* const> computations,
-    absl::Span<const std::string> devices,
-    const tensorflow::ClientSession::FeedType& feed_inputs) {
-  tsl::profiler::TraceMe activity("RunComputations",
-                                  tsl::profiler::TraceMeLevel::kInfo);
-  // In the PyTorch/XRT interface we keep a map (options_.workers_map) from a
-  // worker+taskno to the GRPC server which is the entry point for that worker.
-  // Since XRT could re-distribute ops internally, if we have N hosts
-  // (worker+taskno), we could have all the workers pointing to a single GRPC
-  // entry point, or we could have each worker pointing directly to the target
-  // host.
-  // The advantage of the latter approach is that we do not bottleneck
-  // (especially when feeding inputs) the single GRPC entry point.
-  // Using the N:1 approach, the session_replicas below will contain a single
-  // session, and all the replica executions will go through it (and be
-  // distributed by XRT on the service side).
-  // Choosing the 1:1 approach (one session per worker), we will have N
-  // sessions within the session_replicas map, which we will be executing
-  // independently.
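A standalone sketch of the grouping this comment describes (illustrative only, not part of the deleted client: the endpoint strings and the endpoint-keyed map are hypothetical stand-ins for the XrtSession* grouping performed by the code just below):

#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Hypothetical 1:1 setup: each replica's device is served by its own worker.
  std::vector<std::string> device_endpoints = {
      "grpc://host0:51000", "grpc://host0:51000",   // replicas 0, 1 -> worker 0
      "grpc://host1:51000", "grpc://host1:51000"};  // replicas 2, 3 -> worker 1
  // Bucket replica indices by the session (endpoint) that will run them,
  // mirroring what RunComputations() does with XrtSession pointers.
  std::map<std::string, std::vector<size_t>> session_replicas;
  for (size_t i = 0; i < device_endpoints.size(); ++i) {
    session_replicas[device_endpoints[i]].push_back(i);
  }
  for (const auto& entry : session_replicas) {
    std::cout << entry.first << " runs " << entry.second.size()
              << " replicas\n";
  }
  // In an N:1 setup every entry would share one endpoint, collapsing the map
  // to a single session that executes all replicas.
  return 0;
}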
- std::map> session_replicas; - for (size_t i = 0; i < devices.size(); ++i) { - auto worker_hostport = GetWorkerForDevice(devices[i]); - XrtSession* session = session_map.at(worker_hostport.second).get(); - session_replicas[session].push_back(i); - } - XLA_CHECK_EQ(computations.size(), devices.size()); - - auto mwait = std::make_shared(session_replicas.size()); - std::vector> results(devices.size()); - for (auto& sess_replica : session_replicas) { - XrtSession* session = sess_replica.first; - const std::vector& replicas = sess_replica.second; - - auto session_runner = [&, this, session]() { - std::vector exec_nodes; - std::vector xla_computations; - std::vector output_shapes; - for (auto replica : replicas) { - exec_nodes.push_back(exec_ops[replica]); - xla_computations.push_back(&computations[replica]->computation()); - output_shapes.push_back( - &computations[replica]->program_shape().result()); - } - std::vector outputs; - util::CheckComputationStatus( - session->session()->Run(feed_inputs, exec_nodes, &outputs), - xla_computations, output_shapes); - XLA_CHECK_EQ(outputs.size(), exec_nodes.size()); - - for (size_t i = 0; i < outputs.size(); ++i) { - auto replica = replicas[i]; - results[replica] = GetComputationResults( - outputs[i], computations[replica]->program_shape().result(), - devices[replica]); - } - }; - env::ScheduleIoClosure( - util::MultiWait::Completer(mwait, std::move(session_runner))); - } - mwait->Wait(); - return results; -} - -std::vector> -XrtComputationClient::ExecuteParallel( - absl::Span computations, - const std::vector>& arguments, - absl::Span devices, - const ExecuteParallelOptions& options) { - metrics::TimedSection timed(ExecuteParallelMetric()); - tsl::profiler::TraceMe activity("ExecuteParallel", - tsl::profiler::TraceMeLevel::kInfo); - - XrtSessionCache::SessionMap session_map; - tensorflow::ClientSession::FeedType feed_inputs; - std::vector exec_ops = - CreateExecuteOps(&session_map, computations, arguments, - options.explode_tuple, devices, &feed_inputs); - return RunComputations(session_map, exec_ops, computations, devices, - feed_inputs); -} - -template -void XrtComputationClient::SetupExecConfig(const Device& device, - T* exec_config) const { - exec_config->set_core_index_in_replica(0); - exec_config->set_rng_seed(rng_seed_); - if (device.kind != "TPU") { - // TPU ignores those fields, and given that the device list can be in the - // thousands for POD scale, we avoid wasting time filling it up. - xrt::CommonExecutionConfig* cmn_config = - exec_config->mutable_common_config(); - cmn_config->set_replica_id(device.ordinal); - cmn_config->set_run_id(1); - for (auto& devstr : options_.devices) { - Device local_device(devstr); - if (local_device.kind == device.kind) { - cmn_config->add_local_replica_mapping(local_device.ordinal); - } - } - } -} - -std::vector XrtComputationClient::ExecuteChained( - absl::Span ops, const std::string& device) { - tsl::profiler::TraceMe activity("ExecuteChained", - tsl::profiler::TraceMeLevel::kInfo); - static int64_t split_mode = sys_util::GetEnvInt("XRT_SPLIT_CHAINED_EXEC", 0); - return split_mode ? 
ExecuteChainedSplit(ops, device) - : ExecuteChainedXrt(ops, device); -} - -std::vector XrtComputationClient::ExecuteChainedXrt( - absl::Span ops, const std::string& device) { - metrics::TimedSection timed(ExecuteChainedMetric()); - - XrtSessionCache::SessionMap session_map; - const std::string& xrt_device = TorchDeviceToXrtDevice(device); - tensorflow::ClientSession::FeedType feed_inputs; - XrtSession* session = - GetSessionForXrtDevice(session_cache_.get(), xrt_device, &session_map); - tensorflow::Scope device_scope = session->root()->WithDevice(xrt_device); - - xrt::XRTChainedExecuteConfig exec_config; - SetupExecConfig(Device(device), &exec_config); - - xrt::XRTChainedExecutePlan plan; - std::vector result_shapes; - for (size_t i = 0; i < ops.size(); ++i) { - const ExecuteChainedOp& op = ops[i]; - xrt::XRTChainedExecuteOp* plan_op = plan.add_ops(); - const xla::Shape* op_shape = nullptr; - if (op.device_data != nullptr) { - const XrtData& xrt_data = dynamic_cast(*op.device_data); - op_shape = &xrt_data.shape(); - plan_op->set_data_handle(xrt_data.get_handle()); - } else { - const XrtComputation& xrt_computation = - dynamic_cast(*op.computation); - op_shape = &xrt_computation.program_shape().result(); - plan_op->set_computation_handle(xrt_computation.get_handle()); - for (auto& input : op.inputs) { - XLA_CHECK_LT(input.op_index, i); - - xrt::XRTChainedExecuteOp::Input* plan_input = plan_op->add_inputs(); - plan_input->set_op_index(input.op_index); - if (input.output_index) { - plan_input->set_output_index(*input.output_index + 1); - } - } - } - for (auto& output : op.outputs) { - XLA_CHECK(op_shape != nullptr); - - xrt::XRTChainedExecuteOp::Output* plan_output = plan_op->add_outputs(); - plan_output->set_result_index(output.result_index); - if (output.result_index >= result_shapes.size()) { - result_shapes.resize(output.result_index + 1); - } - if (output.output_index) { - plan_output->set_output_index(*output.output_index + 1); - result_shapes[output.result_index] = - xla::ShapeUtil::GetTupleElementShape(*op_shape, - *output.output_index); - } else { - result_shapes[output.result_index] = *op_shape; - } - } - } - - const XrtSession::CachedNode& cached_node = - GetExecuteChainedNode(session, device_scope, device); - feed_inputs.insert({cached_node.holders[0], plan.SerializeAsString()}); - feed_inputs.insert({cached_node.holders[1], exec_config.SerializeAsString()}); - - std::vector outputs; - util::CheckComputationStatus( - session->session()->Run(feed_inputs, {cached_node.outputs[0]}, &outputs), - {}, {}); - XLA_CHECK_EQ(outputs.size(), 1); - - std::vector results; - auto handles_vec = outputs[0].vec(); - for (int64_t i = 0; i < handles_vec.size(); ++i) { - results.push_back(std::make_shared( - this, device, std::move(result_shapes.at(i)), handles_vec(i))); - } - CreateDataHandlesCounter()->AddValue(results.size()); - return results; -} - -std::vector -XrtComputationClient::ExecuteChainedSplit( - absl::Span ops, const std::string& device) { - metrics::TimedSection timed(ExecuteChainedMetric()); - - std::vector uses(ops.size(), 0); - for (auto& op : ops) { - for (auto& input : op.inputs) { - uses[input.op_index] += 1; - } - } - XrtSessionCache::SessionMap session_map; - const std::string& xrt_device = TorchDeviceToXrtDevice(device); - XrtSession* session = - GetSessionForXrtDevice(session_cache_.get(), xrt_device, &session_map); - tensorflow::Scope device_scope = session->root()->WithDevice(xrt_device); - std::vector> ops_outputs(ops.size()); - std::vector results; - for (size_t i = 0; i 
< ops.size(); ++i) { - const ExecuteChainedOp& op = ops[i]; - if (op.device_data != nullptr) { - ops_outputs[i].push_back(op.device_data); - } else { - tensorflow::ClientSession::FeedType feed_inputs; - std::vector arguments; - arguments.reserve(op.inputs.size()); - for (auto& input : op.inputs) { - XLA_CHECK_LT(input.op_index, i); - XLA_CHECK_LT(input.output_index.value_or(0), - ops_outputs[input.op_index].size()); - arguments.push_back( - ops_outputs[input.op_index][input.output_index.value_or(0)]); - } - - std::vector exec_ops = CreateExecuteOps( - &session_map, dynamic_cast(*op.computation), - BuildParallelArguments(arguments), /*explode_tuple=*/true, {device}, - &feed_inputs); - - std::vector outputs; - util::CheckComputationStatus( - session->session()->Run(feed_inputs, {exec_ops.front()}, &outputs), - {&op.computation->computation()}, - {&op.computation->program_shape().result()}); - XLA_CHECK_EQ(outputs.size(), 1); - ops_outputs[i] = GetComputationResults( - outputs[0], op.computation->program_shape().result(), device); - } - - for (auto& output : op.outputs) { - if (output.result_index >= results.size()) { - results.resize(output.result_index + 1); - } - XLA_CHECK_LT(output.output_index.value_or(0), ops_outputs[i].size()); - results[output.result_index] = - ops_outputs[i][output.output_index.value_or(0)]; - } - // Drop references to any intermediate result which is not used anymore. - for (auto& input : op.inputs) { - uses[input.op_index] -= 1; - if (uses[input.op_index] == 0) { - ops_outputs[input.op_index].clear(); - } - } - // We can reset the TF op cache here so that we don't keep allocating new - // TF op nodes on the session graph. - session->Reset(); - } - return results; -} - -std::vector> -XrtComputationClient::DeconstructTuple(absl::Span tuples) { - metrics::TimedSection timed(DeconstructTupleMetric()); - - XrtSessionCache::SessionMap session_map; - std::map session_work_map; - std::vector tuple_elements_count(tuples.size()); - for (size_t i = 0; i < tuples.size(); ++i) { - const XrtData& xrt_data = dynamic_cast(*tuples[i]); - XrtSession* session = GetSessionForDevice(session_cache_.get(), - xrt_data.device(), &session_map); - SessionWork* session_work = &session_work_map[session]; - session_work->index_mapping.push_back(i); - - tensorflow::Scope device_scope = - session->root()->WithDevice(TorchDeviceToXrtDevice(xrt_data.device())); - int64_t count = xla::ShapeUtil::TupleElementCount(xrt_data.shape()); - tuple_elements_count[i] = count; - for (int64_t j = 0; j < count; ++j) { - const XrtSession::CachedNode& cached_node = - GetSubTupleNode(session, device_scope, xrt_data.device()); - session_work->feed_inputs.insert( - {cached_node.holders[0], xrt_data.get_handle()}); - tensorflow::Tensor index_tensor(tensorflow::DT_INT32, - tensorflow::TensorShape({1})); - index_tensor.flat()(0) = j; - session_work->feed_inputs.insert({cached_node.holders[1], index_tensor}); - session_work->outputs_handles.push_back(cached_node.outputs[0]); - } - } - - std::vector> results(tuples.size()); - for (auto& session_work : session_work_map) { - std::vector outputs; - XLA_CHECK_OK(session_work.first->session()->Run( - session_work.second.feed_inputs, session_work.second.outputs_handles, - &outputs)); - XLA_CHECK_EQ(outputs.size(), session_work.second.outputs_handles.size()); - - size_t output_index = 0; - for (auto li : session_work.second.index_mapping) { - const XrtData& xrt_data = dynamic_cast(*tuples[li]); - std::vector tuple_results; - for (size_t i = 0; i < tuple_elements_count[li]; ++i, 
++output_index) { - tuple_results.push_back(std::make_shared( - this, xrt_data.device(), - xla::ShapeUtil::GetTupleElementShape(xrt_data.shape(), i), - outputs[output_index].scalar()())); - } - results[li] = std::move(tuple_results); - CreateDataHandlesCounter()->AddValue(tuple_elements_count[li]); - } - } - return results; -} - -XrtSession* XrtComputationClient::GetSessionForTarget( - XrtSessionCache* cache, const std::string& target, - XrtSessionCache::SessionMap* session_map) { - return cache->GetSession(target, session_map); -} - -XrtSession* XrtComputationClient::GetSessionForXrtDevice( - XrtSessionCache* cache, const std::string& xrt_device, - XrtSessionCache::SessionMap* session_map) { - auto worker_hostport = GetWorkerForXrtDevice(xrt_device); - return GetSessionForTarget(cache, worker_hostport.second, session_map); -} - -XrtSession* XrtComputationClient::GetSessionForDevice( - XrtSessionCache* cache, const std::string& device, - XrtSessionCache::SessionMap* session_map) { - return GetSessionForXrtDevice(cache, TorchDeviceToXrtDevice(device), - session_map); -} - -const std::string& XrtComputationClient::TorchDeviceToXrtDevice( - const std::string& device) const { - auto device_target = options_.global_device_map.find(device); - XLA_CHECK(device_target != options_.global_device_map.end()) - << "Unable to find device: " << device; - return device_target->second; -} - -std::unique_ptr XrtComputationClient::CreateXrtComputation( - const xla::XlaComputation& computation, - absl::Span devices, - const xla::Shape* output_shape) const { - std::unique_ptr xrt_computation( - new xrt::XLAComputation()); - auto config = xrt_computation->mutable_config(); - - // The computation here assumes that all devices participate in replication. - config->set_num_cores_per_replica(1); - if (devices.size() > 1) { - auto device_assignment = config->mutable_device_assignment(); - auto computation_device = device_assignment->add_computation_devices(); - for (int64_t i = 0; i < devices.size(); ++i) { - Device device(devices[i]); - auto replica_device = computation_device->add_replica_devices(); - if (device.kind == "TPU") { - const std::string& xrt_device = TorchDeviceToXrtDevice(devices[i]); - const auto& core_coords = GetDeviceMeshCoords(xrt_device); - for (auto coord : core_coords) { - replica_device->add_value(coord); - } - } else if (device.kind == "GPU") { - // For GPU use X,Y,Z=0 and CORE=GPU_ORDINAL (where GPU_ORDINAL is the - // global ordinal value). 
- replica_device->add_value(0); - replica_device->add_value(0); - replica_device->add_value(0); - replica_device->add_value(device.ordinal); - } else { - XLA_ERROR() << "Unsupported replication device type: " << device.kind; - } - } - config->set_num_replicas(devices.size()); - } - - *config->mutable_program_shape() = - computation.GetProgramShape().value().ToProto(); - if (output_shape != nullptr) { - *config->mutable_program_shape()->mutable_result() = - output_shape->ToProto(); - } - *xrt_computation->mutable_hlo_snapshot() = - std::move(*computation.Snapshot().value()); - return xrt_computation; -} - -tensorflow::Tensor XrtComputationClient::GetArgumentsInputs( - absl::Span arguments, const std::string& device) { - tensorflow::Tensor inputs_tensor( - tensorflow::DT_INT64, - tensorflow::TensorShape({static_cast(arguments.size())})); - for (size_t i = 0; i < arguments.size(); ++i) { - const XrtData& xrt_data = dynamic_cast(*arguments[i]); - XLA_CHECK_EQ(device, xrt_data.device()); - inputs_tensor.flat()(i) = xrt_data.get_handle(); - } - return inputs_tensor; -} - -std::vector XrtComputationClient::CreateExecuteOps( - XrtSessionCache::SessionMap* session_map, - absl::Span computations, - const std::vector>& arguments, bool explode_tuple, - absl::Span devices, - tensorflow::ClientSession::FeedType* feed_inputs) { - std::vector exec_ops; - for (size_t i = 0; i < computations.size(); ++i) { - const XrtComputation* xrt_computation = - dynamic_cast(computations[i]); - auto inputs = GetArgumentsInputs(arguments[i], devices[i]); - const std::string& xrt_device = TorchDeviceToXrtDevice(devices[i]); - XrtSession* session = - GetSessionForXrtDevice(session_cache_.get(), xrt_device, session_map); - tensorflow::Scope device_scope = session->root()->WithDevice(xrt_device); - const XrtSession::CachedNode& cached_node = - GetExecuteNode(session, device_scope, devices[i]); - feed_inputs->insert( - {cached_node.holders[0], xrt_computation->get_handle()}); - - xrt::XRTExecutionConfig exec_config; - exec_config.set_release_input_handles(false); - exec_config.set_release_compilation_handle(false); - exec_config.set_return_exploded_tuple(explode_tuple); - SetupExecConfig(Device(devices[i]), &exec_config); - - feed_inputs->insert( - {cached_node.holders[1], exec_config.SerializeAsString()}); - feed_inputs->insert({cached_node.holders[2], inputs}); - - exec_ops.push_back(cached_node.outputs[0]); - } - return exec_ops; -} - -std::vector XrtComputationClient::CreateExecuteOps( - XrtSessionCache::SessionMap* session_map, const XrtComputation& computation, - const std::vector>& arguments, bool explode_tuple, - absl::Span devices, - tensorflow::ClientSession::FeedType* feed_inputs) { - std::vector exec_ops; - for (size_t i = 0; i < arguments.size(); ++i) { - auto inputs = GetArgumentsInputs(arguments[i], devices[i]); - const std::string& xrt_device = TorchDeviceToXrtDevice(devices[i]); - XrtSession* session = - GetSessionForXrtDevice(session_cache_.get(), xrt_device, session_map); - tensorflow::Scope device_scope = session->root()->WithDevice(xrt_device); - const XrtSession::CachedNode& cached_node = - GetExecuteNode(session, device_scope, devices[i]); - feed_inputs->insert({cached_node.holders[0], computation.get_handle()}); - - xrt::XRTExecutionConfig exec_config; - exec_config.set_release_input_handles(false); - exec_config.set_release_compilation_handle(false); - exec_config.set_return_exploded_tuple(explode_tuple); - SetupExecConfig(Device(devices[i]), &exec_config); - - feed_inputs->insert( - 
{cached_node.holders[1], exec_config.SerializeAsString()}); - feed_inputs->insert({cached_node.holders[2], inputs}); - - exec_ops.push_back(cached_node.outputs[0]); - } - return exec_ops; -} - -void XrtComputationClient::ReleaseHandles( - std::vector* handles, - const std::function& - op_generator, - metrics::Metric* timed_metric, metrics::Counter* destroy_counter) { - tsl::profiler::TraceMe activity("ReleaseHandles", - tsl::profiler::TraceMeLevel::kInfo); - std::vector released_handles; - { - std::lock_guard lock(lock_); - released_handles.swap(*handles); - } - if (!released_handles.empty()) { - metrics::TimedSection timed(timed_metric); - - XrtSessionCache::SessionMap session_map; - std::map> session_handles_map; - for (auto& handle : released_handles) { - XrtSession* session = GetSessionForDevice(session_cache_.get(), - handle.device, &session_map); - session_handles_map[session].push_back(handle); - } - for (const auto& session_and_handles : session_handles_map) { - XrtSession* session = session_and_handles.first; - const std::vector& session_handles = - session_and_handles.second; - tensorflow::Tensor handles_tensor( - tensorflow::DT_INT64, - tensorflow::TensorShape({static_cast(session_handles.size())})); - auto flat_handles_tensor = handles_tensor.flat(); - for (size_t i = 0; i < session_handles.size(); ++i) { - flat_handles_tensor(i) = session_handles[i].handle; - } - tensorflow::Scope device_scope = session->root()->WithDevice( - TorchDeviceToXrtDevice(session_handles.front().device)); - const XrtSession::CachedNode& cached_node = - op_generator(session, device_scope, session_handles.front().device); - tensorflow::ClientSession::FeedType feed_inputs; - feed_inputs.insert({cached_node.holders[0], handles_tensor}); - - std::vector outputs; - XLA_CHECK_OK(session->session()->Run( - feed_inputs, {}, {cached_node.operations[0]}, &outputs)); - } - destroy_counter->AddValue(released_handles.size()); - } -} - -void XrtComputationClient::StartHandleReleaser() { - static const size_t kMinReleaserThreads = 8; - int64_t num_threads = sys_util::GetEnvInt( - "XLA_HANDLE_RELEASE_THREADS", - std::max(options_.devices.size(), kMinReleaserThreads)); - triggered_task_.reset( - new util::TriggeredTask([this]() { HandleReleaser(); }, num_threads)); -} - -void XrtComputationClient::HandleReleaser() { - auto data_op_generator = - [this](XrtSession* session, const tensorflow::Scope& scope, - const std::string& device) -> const XrtSession::CachedNode& { - return GetReleaseAllocationHandleNode(session, scope, device); - }; - ReleaseHandles(&released_data_handles_, data_op_generator, - ReleaseDataHandlesTimeMetric(), DestroyDataHandlesCounter()); - - auto compile_op_generator = - [this](XrtSession* session, const tensorflow::Scope& scope, - const std::string& device) -> const XrtSession::CachedNode& { - return GetReleaseCompileHandleNode(session, scope, device); - }; - ReleaseHandles(&released_compile_handles_, compile_op_generator, - ReleaseCompileHandlesTimeMetric(), - DestroyCompileHandlesCounter()); -} - -void XrtComputationClient::ReleaseHandle(int64_t handle, - const std::string& device, - std::vector* handles) { - { - std::lock_guard lock(lock_); - handles->push_back({device, handle}); - } - triggered_task_->Activate(); -} - -void XrtComputationClient::ReleaseXrtData(const std::string& device, - int64_t handle) { - ReleaseHandle(handle, device, &released_data_handles_); - ReleaseDataHandlesCounter()->AddValue(1); -} - -void XrtComputationClient::ReleaseXrtComputation( - const std::string& 
compilation_device, int64_t handle) { - ReleaseHandle(handle, compilation_device, &released_compile_handles_); - ReleaseCompileHandlesCounter()->AddValue(1); -} - -std::pair -XrtComputationClient::GetWorkerForXrtDevice( - const std::string& xrt_device) const { - tensorflow::DeviceNameUtils::ParsedName parsed_device = - ParseFullXrtDevice(xrt_device); - auto worker_hostport = - options_.workers_map.find(Worker(parsed_device.job, parsed_device.task)); - XLA_CHECK(worker_hostport != options_.workers_map.end()) << xrt_device; - return std::pair(worker_hostport->first, - worker_hostport->second); -} - -std::pair -XrtComputationClient::GetWorkerForDevice(const std::string& device) const { - return GetWorkerForXrtDevice(TorchDeviceToXrtDevice(device)); -} - -const std::vector& XrtComputationClient::GetDeviceMeshCoords( - const std::string& xrt_device) const { - auto it = device_mesh_coords_.find(xrt_device); - if (it == device_mesh_coords_.end()) { - TF_LOG(FATAL) << "Missing mesh coordinates for device: " << xrt_device; - } - return it->second; -} - -tensorflow::tpu::TopologyProto XrtComputationClient::InitializeAndFetchTopology( - const std::string& job, int task_no, const std::string& worker_host_port, - const tensorflow::ConfigProto& config) { - tensorflow::SessionOptions session_options; - session_options.env = tsl::Env::Default(); - session_options.target = worker_host_port; - session_options.config = config; - - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - tensorflow::ClientSession session(root, session_options); - std::string system_device = absl::StrCat( - "/job:", job, "/replica:0/task:", task_no, "/device:TPU_SYSTEM:0"); - tensorflow::Scope tpu_system_scope = root.WithDevice(system_device); - const auto unique_name = - tpu_system_scope.GetUniqueNameForOp("ConfigureDistributedTPU"); - tensorflow::NodeBuilder builder = - tensorflow::NodeBuilder(unique_name, "ConfigureDistributedTPU") - .Attr("embedding_config", "") - .Attr("tpu_embedding_config", "") - .Attr("is_global_init", false); - // TODO: Remove this once the new TF build can be relied upon, on the Cloud - // TPU side. 
- const tensorflow::ClusterDef cluster_def = config.cluster_def(); - if (cluster_def.job_size() > 1 || - (cluster_def.job_size() == 1 && cluster_def.job()[0].tasks_size() > 1)) { - builder.Attr("enable_whole_mesh_compilations", true); - } - - tpu_system_scope.UpdateBuilder(&builder); - - tensorflow::Node* result; - root.UpdateStatus(builder.Finalize(tpu_system_scope.graph(), &result)); - XLA_CHECK_OK(tpu_system_scope.status()); - root.UpdateStatus(tpu_system_scope.DoShapeInference(result)); - - std::vector outputs; - XLA_CHECK_OK(root.status()); - XLA_CHECK_OK(session.Run({tensorflow::Output(result, 0)}, &outputs)); - XLA_CHECK_EQ(outputs.size(), 1); - - return ParseProto(outputs[0]); -} - -void XrtComputationClient::InitializeDevices( - std::unique_ptr topology_proto) { - if (topology_proto == nullptr) { - std::set tpu_workers; - for (const auto& dev_target : options_.global_device_map) { - tensorflow::DeviceNameUtils::ParsedName parsed_device = - ParseFullXrtDevice(dev_target.second); - if (parsed_device.type == "TPU") { - tpu_workers.emplace(parsed_device.job, parsed_device.task); - } - } - if (!tpu_workers.empty()) { - const Worker& worker = *tpu_workers.begin(); - auto it = options_.workers_map.find(worker); - XLA_CHECK(it != options_.workers_map.end()); - - TF_VLOG(1) << "Configuring TPU for master worker " << worker.name << ":" - << worker.task_no << " at " << it->second; - tensorflow::tpu::TopologyProto worker_topology_proto = - InitializeAndFetchTopology(worker.name, worker.task_no, it->second, - session_cache_->GetConfig()); - if (topology_proto == nullptr) { - topology_proto = absl::make_unique( - std::move(worker_topology_proto)); - } - } - if (topology_proto != nullptr) { - TF_VLOG(1) << "TPU topology: " << topology_proto->DebugString(); - } - } - for (const auto& dev_target : options_.global_device_map) { - tensorflow::DeviceNameUtils::ParsedName parsed_device = - ParseFullXrtDevice(dev_target.second); - if (parsed_device.type != "TPU") { - continue; - } - XLA_CHECK_LE(parsed_device.task, topology_proto->num_tasks()); - XLA_CHECK_LE(parsed_device.id, topology_proto->num_tpu_devices_per_task()); - // The topology proto 'device_coordinates' is a linear list of - // [num_tasks][devices_per_task][mesh_shape_size] coordinates, where the - // mesh coordinates are usually [x, y, z, c] ('x', 'y' and 'z' being the - // spatial chip coordinated and 'c' the core number). - int64_t base_index = parsed_device.task * - topology_proto->num_tpu_devices_per_task() * - topology_proto->mesh_shape_size() + - parsed_device.id * topology_proto->mesh_shape_size(); - std::vector device_mesh_coords(topology_proto->mesh_shape_size()); - for (int i = 0; i < topology_proto->mesh_shape_size(); ++i) { - device_mesh_coords[i] = - topology_proto->device_coordinates(base_index + i); - } - device_mesh_coords_.insert( - {dev_target.second, std::move(device_mesh_coords)}); - } - - // Create the mesh service only if we have more than one worker, or if - // multi-processing is active. - std::string mesh_service_address = - sys_util::GetEnvString(env::kEnvMeshService, ""); - std::string mp_device = GetMultiProcessingDevice(); - if (!mesh_service_address.empty() && !mp_device.empty()) { - int host_ordinal = sys_util::GetEnvInt(env::kEnvHostOrdinal, -1); - Device device(mp_device); - if (host_ordinal <= 0) { - if (device.ordinal == 0) { - CreateMeshService(mesh_service_address, topology_proto.get()); - } - } else { - // Here we are in the sea-of-devices case. 
- if (device.ordinal == 0) { - service::grpc::Config config = - CreateMeshServiceConfig(topology_proto.get()); - service::MeshClient::Get()->SetConfig(host_ordinal, config); - } - } - SetupGpuRuntime(); - } -} - -void XrtComputationClient::SetupGpuRuntime() { - struct NcclUniqueIdFactory : public tensorflow::NcclUniqueIdFactory { - std::string GetUniqueId(absl::Span replicas) override { - return service::MeshClient::Get()->GetNcclUniqueUid(replicas); - } - }; - - tensorflow::SetNcclUniqueIdFactory(std::make_shared()); -} - -service::grpc::Config XrtComputationClient::CreateMeshServiceConfig( - const tensorflow::tpu::TopologyProto* topology_proto) const { - struct Device { - std::string local_name; - std::string global_name; - }; - - service::grpc::Config config; - config.set_mesh_size(sys_util::GetEnvInt(env::kEnvWorldSize, 1)); - if (topology_proto != nullptr) { - config.mutable_proto()->CopyFrom(*topology_proto); - } - - std::map> workers_devices; - for (const auto& dev_target : options_.global_device_map) { - tensorflow::DeviceNameUtils::ParsedName parsed_device = - ParseFullXrtDevice(dev_target.second); - std::string local_name = - absl::StrCat(parsed_device.type, ":", parsed_device.id); - workers_devices[Worker(parsed_device.job, parsed_device.task)].push_back( - {local_name, dev_target.first}); - } - for (auto& worker_address : options_.workers_map) { - service::grpc::Worker* worker = config.add_workers(); - worker->set_name(worker_address.first.name); - worker->set_task_no(worker_address.first.task_no); - worker->set_address(worker_address.second); - for (auto& worker_device : workers_devices[worker_address.first]) { - service::grpc::Device* device = worker->add_devices(); - device->set_local_name(worker_device.local_name); - device->set_global_name(worker_device.global_name); - } - } - return config; -} - -void XrtComputationClient::CreateMeshService( - const std::string& address, - const tensorflow::tpu::TopologyProto* topology_proto) { - service::grpc::Config config = CreateMeshServiceConfig(topology_proto); - - TF_VLOG(1) << "Creating mesh service bound to " << address; - mesh_service_ = - absl::make_unique(address, std::move(config)); -} - -std::vector -XrtComputationClient::GetComputationResults( - const tensorflow::Tensor& xrt_result, const xla::Shape& result_shape, - const std::string& device) { - std::vector results; - if (xrt_result.dims() == 1) { - auto handles_vec = xrt_result.vec(); - for (int64_t i = 0; i < handles_vec.size(); ++i) { - results.push_back(std::make_shared( - this, device, xla::ShapeUtil::GetTupleElementShape(result_shape, i), - handles_vec(i))); - } - } else { - results.push_back(std::make_shared( - this, device, result_shape, xrt_result.scalar()())); - } - CreateDataHandlesCounter()->AddValue(results.size()); - return results; -} - -std::string XrtComputationClient::GetResourceDomain( - const std::string& device) const { - return GetWorkerForDevice(device).second; -} - -std::string XrtComputationClient::GetDefaultDevice() const { - return options_.default_device; -} - -size_t XrtComputationClient::GetNumDevices() const { - return options_.devices.size(); -} - -std::vector XrtComputationClient::GetLocalDevices() const { - return std::vector(options_.devices.begin(), - options_.devices.end()); -} - -std::vector XrtComputationClient::GetAllDevices() const { - std::vector devices; - for (const auto& dev_target : options_.global_device_map) { - devices.push_back(dev_target.first); - } - return devices; -} - -void XrtComputationClient::SetReplicationDevices( 
- std::shared_ptr> devices) { - std::lock_guard lock(lock_); - replication_devices_ = std::move(devices); -} - -std::shared_ptr> -XrtComputationClient::GetReplicationDevices() { - std::lock_guard lock(lock_); - return replication_devices_; -} - -void XrtComputationClient::SetRngSeed(size_t seed) { rng_seed_ = seed; } - -std::map XrtComputationClient::GetMetrics() const { - static const std::map* metric_remap = - new std::map{ - {"/tensorflow/xrt/ops/allocate", "XrtAllocate"}, - {"/tensorflow/xrt/ops/allocate_from_tensor", "XrtAllocateFromTensor"}, - {"/tensorflow/xrt/ops/sub_tuple", "XrtSubTuple"}, - {"/tensorflow/xrt/ops/make_tuple", "XrtMakeTuple"}, - {"/tensorflow/xrt/ops/compile", "XrtCompile"}, - {"/tensorflow/xrt/ops/release_compilation", "XrtReleaseCompilation"}, - {"/tensorflow/xrt/ops/execute", "XrtExecute"}, - {"/tensorflow/xrt/ops/execute_chained", "XrtExecuteChained"}, - {"/tensorflow/xrt/ops/read_literal", "XrtReadLiteral"}, - {"/tensorflow/xrt/ops/read_tensor", "XrtReadTensor"}, - {"/tensorflow/xrt/ops/write_literal", "XrtWriteLiteral"}, - {"/tensorflow/xrt/ops/release_allocation", "XrtReleaseAllocation"}, - {"/tensorflow/xrt/ops/release_all_allocations", - "XrtReleaseAllAllocations"}, - {"/tensorflow/xrt/ops/compact_allocations", "XrtCompactAllocations"}, - {"/tensorflow/xrt/memory_manager/compaction", "XrtCompaction"}, - {"/tensorflow/xrt/memory_manager/try_free_memory", - "XrtTryFreeMemory"}, - {"/tensorflow/xrt/executor/program_memory_evict", "XrtExecutorEvict"}, - {"/tensorflow/xrt/ds_executor/program_memory_evict", - "XrtExecutorEvict"}}; - - std::map metrics_data; - xrt::XRTMetricsCollect metrics; - metrics.add_metrics_regex("/tensorflow/xrt/.*"); - - for (auto& worker_target : options_.workers_map) { - tensorflow::SessionOptions session_options; - session_options.env = tsl::Env::Default(); - session_options.target = worker_target.second; - - // GPU device cannot reuse ClusterSpec from session cache, otherwise - // tensorflow throws an error here: - // https://github.com/tensorflow/tensorflow/blob/1cb0c5b850657ae1362a241fabb16253336dd8c3/tensorflow/core/distributed_runtime/master.cc#L402 - if (!absl::StartsWith(GetDefaultDevice(), "GPU")) { - session_options.config = session_cache_->GetConfig(); - } - - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - tensorflow::ClientSession session(root, session_options); - std::string cpu0_device = absl::StrCat( - "/job:", worker_target.first.name, - "/replica:0/task:", worker_target.first.task_no, "/device:CPU:0"); - tensorflow::Scope cpu_system_scope = root.WithDevice(cpu0_device); - auto metrics_value = - tensorflow::ops::Const(cpu_system_scope, metrics.SerializeAsString()); - tensorflow::Output result = - tensorflow::ops::XRTMetricsCollect(cpu_system_scope, metrics_value); - XLA_CHECK_OK(cpu_system_scope.status()); - - std::vector outputs; - XLA_CHECK_OK(session.Run({result}, &outputs)); - XLA_CHECK_EQ(outputs.size(), 1); - - xrt::MetricsReport report = ParseProto(outputs[0]); - for (auto& xrt_metric : report.metrics()) { - Metric metric; - if (xrt_metric.values_oneof_case() == - xrt::MetricValues::kPercentilesValue) { - const xrt::Percentiles& xrt_percentile = xrt_metric.percentiles_value(); - Percentile percentile; - switch (xrt_metric.unit_of_measure()) { - case xrt::MetricValues::NUMBER: - percentile.unit_of_measure = Percentile::UnitOfMeaure::kNumber; - break; - case xrt::MetricValues::TIME: - percentile.unit_of_measure = Percentile::UnitOfMeaure::kTime; - break; - case xrt::MetricValues::BYTES: - 
percentile.unit_of_measure = Percentile::UnitOfMeaure::kBytes; - break; - default: - TF_LOG(FATAL) << "Invalid unit of measure for xrt metric: " - << xrt_metric.name(); - break; - } - percentile.start_nstime = xrt_percentile.start_nstime(); - percentile.end_nstime = xrt_percentile.end_nstime(); - percentile.min_value = xrt_percentile.min_value(); - percentile.max_value = xrt_percentile.max_value(); - percentile.mean = xrt_percentile.mean(); - percentile.stddev = xrt_percentile.stddev(); - percentile.num_samples = xrt_percentile.num_samples(); - percentile.total_samples = xrt_percentile.total_samples(); - percentile.accumulator = xrt_percentile.accumulator(); - for (auto& xrt_point : xrt_percentile.points()) { - percentile.points.push_back( - Percentile::Point{xrt_point.percentile(), xrt_point.value()}); - } - metric.percentile = std::move(percentile); - } else if (xrt_metric.values_oneof_case() == - xrt::MetricValues::kInt64Value) { - metric.int64_value = xrt_metric.int64_value(); - } else { - continue; - } - - std::string metric_name; - auto it = metric_remap->find(xrt_metric.name()); - if (it != metric_remap->end()) { - metric_name = it->second; - } else { - metric_name = xrt_metric.name(); - } - if (options_.workers_map.size() > 1) { - metric_name = absl::StrCat(metric_name, ".", worker_target.first.name, - ".", worker_target.first.task_no); - } - metrics_data.emplace(std::move(metric_name), std::move(metric)); - } - } - return metrics_data; -} - -ComputationClient::MemoryInfo XrtComputationClient::GetMemoryInfo( - const std::string& device) { - const std::string& xrt_device = TorchDeviceToXrtDevice(device); - XrtSessionCache::SessionMap session_map; - XrtSession* session = - GetSessionForXrtDevice(session_cache_.get(), xrt_device, &session_map); - tensorflow::Scope device_scope = session->root()->WithDevice(xrt_device); - const XrtSession::CachedNode& cached_node = - GetMemoryInfoNode(session, device_scope, device); - - std::vector outputs; - XLA_CHECK_OK(session->session()->Run({cached_node.outputs[0]}, &outputs)); - - xrt::MemoryInfo mem_info = ParseProto(outputs[0]); - return {mem_info.kb_free(), mem_info.kb_total()}; -} - -void XrtComputationClient::PrepareToExit() { - if (mesh_service_ != nullptr) { - TF_VLOG(1) << "Shutting down mesh service ..."; - mesh_service_->Shutdown(); - TF_VLOG(1) << "Shutting down mesh service ... done!"; - } - if (triggered_task_ != nullptr) { - TF_VLOG(1) << "Waiting XRT handle releaser thread ..."; - size_t run_id = triggered_task_->Activate(); - triggered_task_->WaitForRun(run_id); - TF_VLOG(1) << "Waiting XRT handle releaser thread ... 
done!"; - triggered_task_->Stop(); - } -} - -void XrtComputationClient::InitSession(XrtSession* session) const { - struct InitNode { - int count; - const XrtSession::CachedNode& (XrtComputationClient::*node_ctor)( - XrtSession*, const tensorflow::Scope&, const std::string&)const; - } const init_nodes[] = { - {16, &XrtComputationClient::GetCompileNode}, - {16, &XrtComputationClient::GetExecuteNode}, - {16, &XrtComputationClient::GetExecuteChainedNode}, - {16, &XrtComputationClient::GetReadNode}, - {16, &XrtComputationClient::GetReleaseAllocationHandleNode}, - {16, &XrtComputationClient::GetReleaseCompileHandleNode}, - {16, &XrtComputationClient::GetSubTupleNode}, - {16, &XrtComputationClient::GetMemoryInfoNode}, - }; - auto devices = GetLocalDevices(); - for (auto& device : devices) { - // HACK: The XRT ops on the remote GRPC service has only recently been - // enabled, so until TF 1.14 is out, we cannot add XRT ops on CPU. - // If there is only one device, even if CPU, this is the local session, - // which carries the XRT op (as we include them in the BUILD). - if (device.compare(0, 4, "CPU:") == 0 && devices.size() > 1) { - continue; - } - const std::string& xrt_device = TorchDeviceToXrtDevice(device); - tensorflow::Scope device_scope = session->root()->WithDevice(xrt_device); - for (auto& init : init_nodes) { - for (int i = 0; i < init.count; ++i) { - (this->*init.node_ctor)(session, device_scope, device); - } - } - } - session->Reset(); -} - -const XrtSession::CachedNode& XrtComputationClient::GetCompileNode( - XrtSession* session, const tensorflow::Scope& scope, - const std::string& device) const { - static const std::string op_name("XrtCompile"); - XrtSession::NodeCache* cache = - session->GetNodeCache(XrtSession::GetCacheKey(op_name, device)); - if (cache->Empty()) { - XLA_COUNTER("XrtCompile_Empty", 1); - std::vector holders( - {tensorflow::ops::Placeholder(scope, tensorflow::DT_STRING)}); - cache->Add(std::make_shared( - tensorflow::ops::XRTCompile(scope, holders[0]).handle, holders)); - } - return cache->Get(); -} - -const XrtSession::CachedNode& XrtComputationClient::GetExecuteNode( - XrtSession* session, const tensorflow::Scope& scope, - const std::string& device) const { - static const std::string op_name("XrtExecute"); - XrtSession::NodeCache* cache = - session->GetNodeCache(XrtSession::GetCacheKey(op_name, device)); - if (cache->Empty()) { - XLA_COUNTER("XrtExecute_Empty", 1); - std::vector holders( - {tensorflow::ops::Placeholder(scope, tensorflow::DT_INT64), - tensorflow::ops::Placeholder(scope, tensorflow::DT_STRING), - tensorflow::ops::Placeholder( - scope, tensorflow::DT_INT64, - tensorflow::ops::Placeholder::Shape({-1}))}); - cache->Add(std::make_shared( - tensorflow::ops::XRTExecute(scope, holders[0], holders[1], - {tensorflow::Output(holders[2])}), - holders)); - } - return cache->Get(); -} - -const XrtSession::CachedNode& XrtComputationClient::GetExecuteChainedNode( - XrtSession* session, const tensorflow::Scope& scope, - const std::string& device) const { - static const std::string op_name("XrtExecuteChained"); - XrtSession::NodeCache* cache = - session->GetNodeCache(XrtSession::GetCacheKey(op_name, device)); - if (cache->Empty()) { - XLA_COUNTER("XrtExecuteChained_Empty", 1); - std::vector holders( - {tensorflow::ops::Placeholder(scope, tensorflow::DT_STRING), - tensorflow::ops::Placeholder(scope, tensorflow::DT_STRING)}); - cache->Add(std::make_shared( - tensorflow::ops::XRTExecuteChained(scope, holders[0], holders[1]), - holders)); - } - return cache->Get(); -} - 
-const XrtSession::CachedNode& XrtComputationClient::GetReadNode( - XrtSession* session, const tensorflow::Scope& scope, - const std::string& device) const { - static const std::string op_name("XrtRead"); - XrtSession::NodeCache* cache = - session->GetNodeCache(XrtSession::GetCacheKey(op_name, device)); - if (cache->Empty()) { - XLA_COUNTER("XrtRead_Empty", 1); - std::vector holders( - {tensorflow::ops::Placeholder(scope, tensorflow::DT_INT64)}); - cache->Add(std::make_shared( - tensorflow::ops::XRTReadLiteral(scope, holders[0]), holders)); - } - return cache->Get(); -} - -const XrtSession::CachedNode& XrtComputationClient::GetAllocateNode( - XrtSession* session, const tensorflow::Scope& scope, - const std::string& device, const xla::Shape& shape) const { - // Create the proper key for the allocation node. Since the node has shape and - // layouts attributes, these need to be included within the key. - std::stringstream ss; - ss << "XRTAllocateFromTensor(" << shape << ")"; - XrtSession::NodeCache* cache = - session->GetNodeCache(XrtSession::GetCacheKey(ss.str(), device)); - if (cache->Empty()) { - XLA_COUNTER("XRTAllocateFromTensor_Empty", 1); - tensorflow::TensorShape tensor_shape(shape.dimensions()); - tensorflow::TensorShape equiv_tensor_shape = - MakeEquivalentTensorShape(shape); - std::vector layout(shape.layout().minor_to_major().begin(), - shape.layout().minor_to_major().end()); - std::vector holders( - {tensorflow::ops::Placeholder( - scope, XlaTypeToDataType(shape.element_type()), - tensorflow::ops::Placeholder::Shape(equiv_tensor_shape))}); - tensorflow::ops::XRTAllocateFromTensor::Attrs alloc_attrs = - tensorflow::ops::XRTAllocateFromTensor::Layouts(layout); - cache->Add(std::make_shared( - tensorflow::ops::XRTAllocateFromTensor(scope, {holders[0].output}, - {tensor_shape}, alloc_attrs), - holders)); - } - return cache->Get(); -} - -const XrtSession::CachedNode& -XrtComputationClient::GetReleaseAllocationHandleNode( - XrtSession* session, const tensorflow::Scope& scope, - const std::string& device) const { - static const std::string op_name("XrtReleaseAllocationHandle"); - XrtSession::NodeCache* cache = - session->GetNodeCache(XrtSession::GetCacheKey(op_name, device)); - if (cache->Empty()) { - XLA_COUNTER("XrtReleaseAllocationHandle_Empty", 1); - std::vector holders( - {tensorflow::ops::Placeholder(scope, tensorflow::DT_INT64)}); - cache->Add(std::make_shared( - tensorflow::ops::XRTReleaseAllocationHandle(scope, holders[0]), - holders)); - } - return cache->Get(); -} - -const XrtSession::CachedNode& XrtComputationClient::GetReleaseCompileHandleNode( - XrtSession* session, const tensorflow::Scope& scope, - const std::string& device) const { - static const std::string op_name("XrtReleaseCompileHandle"); - XrtSession::NodeCache* cache = - session->GetNodeCache(XrtSession::GetCacheKey(op_name, device)); - if (cache->Empty()) { - XLA_COUNTER("XrtReleaseCompileHandle_Empty", 1); - std::vector holders( - {tensorflow::ops::Placeholder(scope, tensorflow::DT_INT64)}); - cache->Add(std::make_shared( - tensorflow::ops::XRTReleaseCompilationHandle(scope, holders[0]), - holders)); - } - return cache->Get(); -} - -const XrtSession::CachedNode& XrtComputationClient::GetSubTupleNode( - XrtSession* session, const tensorflow::Scope& scope, - const std::string& device) const { - static const std::string op_name("XrtSubTuple"); - XrtSession::NodeCache* cache = - session->GetNodeCache(XrtSession::GetCacheKey(op_name, device)); - if (cache->Empty()) { - XLA_COUNTER("XrtSubTuple_Empty", 1); - 
std::vector holders( - {tensorflow::ops::Placeholder(scope, tensorflow::DT_INT64), - tensorflow::ops::Placeholder( - scope, tensorflow::DT_INT32, - tensorflow::ops::Placeholder::Shape({1}))}); - cache->Add(std::make_shared( - tensorflow::ops::XRTSubTuple(scope, holders[0], holders[1]), holders)); - } - return cache->Get(); -} - -const XrtSession::CachedNode& XrtComputationClient::GetMemoryInfoNode( - XrtSession* session, const tensorflow::Scope& scope, - const std::string& device) const { - static const std::string op_name("XrtMemoryInfo"); - XrtSession::NodeCache* cache = - session->GetNodeCache(XrtSession::GetCacheKey(op_name, device)); - if (cache->Empty()) { - XLA_COUNTER("XrtMemoryInfo_Empty", 1); - std::vector holders; - cache->Add(std::make_shared( - tensorflow::ops::XRTMemoryInfo(scope), holders)); - } - return cache->Get(); -} - -tensorflow::DataType XrtComputationClient::XlaTypeToDataType( - xla::PrimitiveType dtype) { - switch (dtype) { - case xla::PrimitiveType::PRED: - return tensorflow::DT_BOOL; - case xla::PrimitiveType::S8: - return tensorflow::DT_INT8; - case xla::PrimitiveType::U8: - return tensorflow::DT_UINT8; - case xla::PrimitiveType::S16: - return tensorflow::DT_INT16; - case xla::PrimitiveType::U16: - return tensorflow::DT_UINT16; - case xla::PrimitiveType::S32: - return tensorflow::DT_INT32; - case xla::PrimitiveType::U32: - return tensorflow::DT_UINT32; - case xla::PrimitiveType::S64: - return tensorflow::DT_INT64; - case xla::PrimitiveType::U64: - return tensorflow::DT_UINT64; - case xla::PrimitiveType::F32: - return tensorflow::DT_FLOAT; - case xla::PrimitiveType::F64: - return tensorflow::DT_DOUBLE; - case xla::PrimitiveType::BF16: - return tensorflow::DT_BFLOAT16; - case xla::PrimitiveType::F16: - return tensorflow::DT_HALF; - case xla::PrimitiveType::C64: - return tensorflow::DT_COMPLEX64; - case xla::PrimitiveType::C128: - return tensorflow::DT_COMPLEX128; - default: - break; - } - XLA_ERROR() << "Unable to convert XLA type " << dtype - << " to tensorflow DataType"; -} - -tensorflow::TensorShape XrtComputationClient::MakeEquivalentTensorShape( - const xla::Shape& shape) { - xla::Shape eqiv_shape = - xla::ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(shape); - return tensorflow::TensorShape(eqiv_shape.dimensions()); -} - -std::vector> -XrtComputationClient::BuildParallelArguments( - absl::Span arguments) { - std::vector> para_arguments(1); - para_arguments[0].insert(para_arguments[0].end(), arguments.begin(), - arguments.end()); - return para_arguments; -} - -tensorflow::ConfigProto XrtComputationClient::CreateConfigProto( - const Options& options) { - static const std::string* const grpc_proto = new std::string("grpc://"); - tensorflow::ConfigProto config; - if (options.workers_map.size() > 1) { - tensorflow::ClusterDef* cluster_def = config.mutable_cluster_def(); - std::map jobs; - for (auto& worker_target : options.workers_map) { - auto it = jobs.find(worker_target.first.name); - if (it == jobs.end()) { - tensorflow::JobDef* job = cluster_def->add_job(); - job->set_name(worker_target.first.name); - it = jobs.emplace(worker_target.first.name, job).first; - } - tensorflow::JobDef* job = it->second; - (*job->mutable_tasks())[worker_target.first.task_no] = - StripPrefix(worker_target.second, *grpc_proto); - } - } - return config; -} - -XrtComputationClient::Worker XrtComputationClient::ParseWorker( - const std::string& worker) { - std::vector parts = absl::StrSplit(worker, ':'); - XLA_CHECK(parts.size() == 1 || parts.size() == 2) << worker; - 
return parts.size() == 1 ? Worker(parts[0], 0) - : Worker(parts[0], std::stoi(parts[1])); -} - -std::string XrtComputationClient::GetLocalTarget(const Options& options) { - std::string local_worker = sys_util::GetEnvString(env::kEnvLocalWorker, ""); - std::string local_target; - if (!local_worker.empty()) { - XrtComputationClient::Worker worker = ParseWorker(local_worker); - if (worker.name == kLocalService) { - auto it = options.workers_map.find(worker); - if (it != options.workers_map.end()) { - local_target = it->second; - } - } - } - return local_target; -} - -void XrtComputationClient::MaybeCreateLocalService(const Options& options) { - if (local_service_ != nullptr) { - TF_VLOG(1) << "Local service has been created, return"; - return; - } - std::string grpc_root("grpc://"); - std::string local_worker = sys_util::GetEnvString(env::kEnvLocalWorker, ""); - XrtComputationClient::Worker worker("", -1); - if (!local_worker.empty()) { - worker = ParseWorker(local_worker); - } - int task_index = -1; - std::string job_name; - std::vector hosts; - for (auto& worker_target : options.workers_map) { - if (worker_target.first.name == kLocalService && - worker_target.second.compare(0, grpc_root.size(), grpc_root) == 0) { - hosts.push_back(worker_target.second.substr(grpc_root.size())); - if (worker.task_no < 0 || worker_target.first == worker) { - XLA_CHECK_EQ(task_index, -1) - << "Multiple workers matching the local one: '" << local_worker - << "'"; - job_name = worker_target.first.name; - task_index = worker_target.first.task_no; - } - } - } - if (task_index >= 0 && !job_name.empty()) { - std::string cluster_spec = - absl::StrCat(job_name, "|", absl::StrJoin(hosts, ";")); - TF_VLOG(2) << "Local Service Cluster Spec: " << cluster_spec; - local_service_ = new XrtLocalService(cluster_spec, job_name, task_index); - local_service_->Start(); - } -} - -std::string XrtComputationClient::GetMultiProcessingDevice() { - return sys_util::GetEnvString(env::kEnvMpDevice, ""); -} - -} // namespace runtime -} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/xrt_computation_client.h b/torch_xla/csrc/runtime/xrt_computation_client.h deleted file mode 100644 index 97b830c443b4..000000000000 --- a/torch_xla/csrc/runtime/xrt_computation_client.h +++ /dev/null @@ -1,677 +0,0 @@ -#ifndef XLA_CLIENT_XRT_COMPUTATION_CLIENT_H_ -#define XLA_CLIENT_XRT_COMPUTATION_CLIENT_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/types/optional.h" -#include "tensorflow/cc/client/client_session.h" -#include "tensorflow/cc/framework/ops.h" -#include "tensorflow/cc/framework/scope.h" -#include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h" -#include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h" -#include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h" -#include "tensorflow/compiler/xrt/xrt.pb.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/protobuf/tpu/topology.pb.h" -#include "torch_xla/csrc/runtime/cache.h" -#include "torch_xla/csrc/runtime/computation_client.h" -#include "torch_xla/csrc/runtime/debug_macros.h" -#include "torch_xla/csrc/runtime/mesh_service.h" -#include "torch_xla/csrc/runtime/metrics.h" -#include "torch_xla/csrc/runtime/triggered_task.h" -#include "torch_xla/csrc/runtime/util.h" -#include "torch_xla/csrc/runtime/xrt_local_service.h" -#include "torch_xla/csrc/runtime/xrt_session.h" -#include "torch_xla/csrc/runtime/xrt_session_cache.h" 
- -namespace torch_xla { -namespace runtime { - -class XrtLocker { - public: - void Lock() { - std::unique_lock lock(mutex_); - cv_.wait(lock, [this] { return !locked_; }); - CheckResetException(); - locked_ = true; - } - - void Unlock(std::exception_ptr exptr) { - std::lock_guard lock(mutex_); - locked_ = false; - exptr_ = std::move(exptr); - cv_.notify_all(); - } - - void Barrier() { - std::unique_lock lock(mutex_); - cv_.wait(lock, [this] { return !locked_; }); - cv_.notify_all(); - CheckResetException(); - } - - private: - void CheckResetException() { - std::exception_ptr exptr = std::move(exptr_); - exptr_ = nullptr; - if (exptr != nullptr) { - std::rethrow_exception(exptr); - } - } - - std::mutex mutex_; - std::condition_variable cv_; - bool locked_ = false; - std::exception_ptr exptr_; -}; - -class DataHandleLocker : public XrtLocker { - public: - static const int64_t dummy_handle; -}; - -class XrtComputationClient : public ComputationClient { - struct DeviceHandle { - std::string device; - int64_t handle; - }; - - class XrtHandle { - public: - XrtHandle(int64_t handle, std::function releaser, - bool async = false) - : handle_(handle), releaser(std::move(releaser)) { - if (async) { - locker = std::make_shared(); - } else { - locker = nullptr; - } - } - - ~XrtHandle() { - // Handle might only contain dummy value, need to wait for the - // true handle assigniment. - if (locker) { - XLA_TIMED("HandleBarrierWait"); - locker->Barrier(); - } - releaser(handle_); - } - - // Lock the current XrtHandle and prevent other caller from accessing the - // handle_ value. This function will return an ExceptionCleanup object which - // will rethrow the exception if there is one and unlock the XrtHandle upon - // destruction. - torch_xla::runtime::util::ExceptionCleanup LockHandle() { - std::shared_ptr locker_copy = this->locker; - locker_copy->Lock(); - return torch_xla::runtime::util::ExceptionCleanup( - [locker_copy = std::move(locker_copy)]( - torch_xla::runtime::util::ExceptionCleanup::StatusType status) { - locker_copy->Unlock(std::move(status)); - }); - } - - void update_handle(int64_t handle) { - // handle can only be updated once when it is dummy. 
- XLA_CHECK_EQ(handle_, DataHandleLocker::dummy_handle); - handle_ = handle; - } - - int64_t handle() { - // Handle might only contain dummy value, need to wait for the - // true handle assigniment - if (locker) { - XLA_TIMED("HandleBarrierWait"); - locker->Barrier(); - } - return handle_; - ; - } - - private: - int64_t handle_; - std::shared_ptr locker; - std::function releaser; - }; - - using XrtHandlePtr = std::shared_ptr; - - struct XrtData : public Data { - XrtData(std::string device, xla::Shape device_shape) - : Data(std::move(device), std::move(device_shape)), - handle_ptr(nullptr) {} - XrtData(XrtComputationClient* self, std::string device, - xla::Shape device_shape, int64_t handle) - : Data(std::move(device), std::move(device_shape)), - handle_ptr(std::make_shared( - handle, [self, device = this->device()](int64_t handle) { - self->ReleaseXrtData(device, handle); - })) {} - - XrtData(XrtComputationClient* self, std::string device, - xla::Shape device_shape, XrtHandlePtr handle) - : Data(std::move(device), std::move(device_shape)), - handle_ptr(handle) {} - - int64_t get_handle() const { - XLA_CHECK(HasValue()); - return handle_ptr->handle(); - } - - OpaqueHandle GetOpaqueHandle() override { return get_handle(); } - - void Assign(const Data& data) override; - - bool HasValue() const override { return handle_ptr != nullptr; } - - XrtHandlePtr handle_ptr; - }; - - struct XrtComputation : public Computation { - XrtComputation(XrtComputationClient* self, xla::XlaComputation computation, - xla::ProgramShape program_shape, - std::vector devices, int64_t handle, - std::string compilation_device) - : Computation(std::move(computation), std::move(program_shape), - std::move(devices)), - handle_ptr(std::make_shared( - handle, [self, compilation_device = std::move( - compilation_device)](int64_t handle) { - self->ReleaseXrtComputation(compilation_device, handle); - })) {} - - int64_t get_handle() const { return handle_ptr->handle(); } - - XrtHandlePtr handle_ptr; - }; - - public: - struct Device { - Device() = default; - Device(const std::string& device_str); - - std::string kind; - int ordinal = 0; - }; - - struct Worker { - Worker(std::string name, int task_no) - : name(std::move(name)), task_no(task_no) {} - - bool operator<(const Worker& rhs) const { - if (task_no != rhs.task_no) { - return task_no < rhs.task_no; - } - return name.compare(rhs.name) < 0; - } - - bool operator==(const Worker& rhs) const { - return task_no == rhs.task_no && name == rhs.name; - } - - std::string name; - int task_no; - }; - - struct Options { - std::string default_device; - // Maps a PyTorch device ID (example, "GPU:0", "TPU:0") to the full - // coordinates in TF device format - // (ie, /job:tpu_worker/replica:0/task:0/device:TPU:0), of the worker - // exposing that device. These devices are all the devices present within - // the TPU mesh. - std::map global_device_map; - // These are the devices that this instance of PyTorch is handling. These - // devices are in the form of "CPU:0", "TPU:3", ... For each of these - // devices, there is an entry within the global_device_map. - std::set devices; - // Maps a TPU Worker with an EndPoint. 
- std::map workers_map; - }; - - XrtComputationClient(); - - DataPtr CreateDataPlaceholder(std::string device, xla::Shape shape) override; - - std::vector CreateAsyncDatas( - absl::Span tensors) override; - - std::vector LockAsyncDatas( - absl::Span datas) override; - - std::vector GetDataShards(DataPtr data) override { return {data}; } - - DataPtr WrapDataShards(const std::vector& shards, std::string device, - xla::Shape shape, xla::OpSharding sharding) override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - } - - std::optional GetDataSharding(DataPtr handle) override { - // Returns an empty sharding result, since XRT does not support sharding. - return std::optional(); - } - - std::vector TransferToServer( - absl::Span tensors) override; - - void TransferToServer(absl::Span tensors, - absl::Span datas) override; - - DataPtr TransferShardsToServer(absl::Span tensor_shards, - std::string device, xla::Shape shape, - xla::OpSharding sharding) override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - } - - DataPtr CopyToDevice(DataPtr data, std::string dst) override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - } - - std::vector TransferFromServer( - absl::Span handles) override; - - std::vector Compile( - std::vector instances) override; - - std::vector ExecuteComputation( - const Computation& computation, absl::Span arguments, - const std::string& device, - const ExecuteComputationOptions& options) override; - - std::vector> ExecuteReplicated( - const Computation& computation, - const std::vector>& arguments, - absl::Span devices, - const ExecuteReplicatedOptions& options) override; - - std::vector> ExecuteParallel( - absl::Span computations, - const std::vector>& arguments, - absl::Span devices, - const ExecuteParallelOptions& options) override; - - std::vector ExecuteChained(absl::Span ops, - const std::string& device) override; - - std::vector> DeconstructTuple( - absl::Span tuples) override; - - std::string GetResourceDomain(const std::string& device) const override; - - std::string GetDefaultDevice() const override; - - size_t GetNumDevices() const override; - - std::vector GetLocalDevices() const override; - - std::vector GetAllDevices() const override; - - int GetProcessIndex() const override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - } - - int GetNumProcesses() const override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - } - - const absl::flat_hash_map< - std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>& - GetDeviceAttributes(const std::string& device) override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - } - - void SetReplicationDevices( - std::shared_ptr> devices) override; - - std::shared_ptr> GetReplicationDevices() override; - - void SetRngSeed(size_t seed) override; - - std::map GetMetrics() const override; - - MemoryInfo GetMemoryInfo(const std::string& device) override; - - void PrepareToExit() override; - - void WaitDeviceOps(const std::vector& devices) override { - // XRT Device Computation is guranteed to finish when ExecuteComputation - // returns. No need to implement WaitDeviceOps. - return; - }; - - static Worker ParseWorker(const std::string& worker); - - static std::string GetMultiProcessingDevice(); - - private: - // The data structure used for the key in the compilation cache. Compilations - // handles are valid within given domain (essentially the host+port worker - // endpoints), so the key must include the domain. 
- struct CompilationCacheKey { - struct Hash { - size_t operator()(const CompilationCacheKey& entry) const { - util::PartialHasher hasher; - hash_t h = util::DataHash(entry.domain.data(), entry.domain.size()); - return util::HashReduce( - util::HashCombine(h, hasher(entry.serialized_computation))); - } - }; - - CompilationCacheKey(std::string domain, std::string serialized_computation) - : domain(std::move(domain)), - serialized_computation(std::move(serialized_computation)) {} - CompilationCacheKey() = default; - CompilationCacheKey(CompilationCacheKey&&) = default; - CompilationCacheKey& operator=(CompilationCacheKey&&) = default; - bool operator==(const CompilationCacheKey& rhs) const { - return domain == rhs.domain && - serialized_computation == rhs.serialized_computation; - } - - std::string domain; - std::string serialized_computation; - }; - - // When we split a batch operation into per-session batches, we use this data - // structure to collect the per-session work. - struct SessionWork { - tensorflow::ClientSession::FeedType feed_inputs; - std::vector outputs_handles; - std::vector operations; - std::vector index_mapping; - }; - - XrtSession* GetSessionForTarget(XrtSessionCache* cache, - const std::string& target, - XrtSessionCache::SessionMap* session_map); - XrtSession* GetSessionForXrtDevice(XrtSessionCache* cache, - const std::string& xrt_device, - XrtSessionCache::SessionMap* session_map); - XrtSession* GetSessionForDevice(XrtSessionCache* cache, - const std::string& device, - XrtSessionCache::SessionMap* session_map); - - const std::string& TorchDeviceToXrtDevice(const std::string& device) const; - - template - void SetupExecConfig(const Device& device, T* exec_config) const; - - std::unique_ptr CreateXrtComputation( - const xla::XlaComputation& computation, - absl::Span devices, - const xla::Shape* output_shape) const; - - tensorflow::Tensor GetArgumentsInputs(absl::Span arguments, - const std::string& device); - - std::vector CreateExecuteOps( - XrtSessionCache::SessionMap* session_map, - absl::Span computations, - const std::vector>& arguments, bool explode_tuple, - absl::Span devices, - tensorflow::ClientSession::FeedType* feed_inputs); - - std::vector CreateExecuteOps( - XrtSessionCache::SessionMap* session_map, - const XrtComputation& computation, - const std::vector>& arguments, bool explode_tuple, - absl::Span devices, - tensorflow::ClientSession::FeedType* feed_inputs); - - std::vector> RunComputations( - const XrtSessionCache::SessionMap& session_map, - const std::vector& exec_ops, - absl::Span computations, - absl::Span devices, - const tensorflow::ClientSession::FeedType& feed_inputs); - - std::vector TransferToServerHelper( - absl::Span tensors, absl::Span datas); - - std::vector TransferToServerInternal( - absl::Span tensors, absl::Span datas); - - // Retrieves the worker,worker_host pair for a given PyTorch device (ie, - // TPU:0). - std::pair GetWorkerForDevice( - const std::string& device) const; - - // Retrieves the worker,worker_host pair for a given XRT device (ie, - // /job:tpu_worker/replica:0/task:0/device:TPU:0). 
- std::pair GetWorkerForXrtDevice( - const std::string& xrt_device) const; - - void ReleaseHandles(std::vector* handles, - const std::function& op_generator, - metrics::Metric* timed_metric, - metrics::Counter* destroy_counter); - - void ReleaseHandle(int64_t handle, const std::string& device, - std::vector* handles); - - void ReleaseXrtData(const std::string& device, int64_t handle); - - void ReleaseXrtComputation(const std::string& compilation_device, - int64_t handle); - - // Starts the handle releaser thread (which runs the HandleReleaser() API). - void StartHandleReleaser(); - - // The handler releaser function. Runs in the releaser thread and never - // returns. - void HandleReleaser(); - - // Retrieves the mesh coordinates of a given XRT device. - const std::vector& GetDeviceMeshCoords( - const std::string& xrt_device) const; - - void InitializeDevices( - std::unique_ptr topology_proto); - - service::grpc::Config CreateMeshServiceConfig( - const tensorflow::tpu::TopologyProto* topology_proto) const; - - void CreateMeshService(const std::string& address, - const tensorflow::tpu::TopologyProto* topology_proto); - - void SetupGpuRuntime(); - - std::vector GetComputationResults( - const tensorflow::Tensor& xrt_result, const xla::Shape& result_shape, - const std::string& device); - - void InitSession(XrtSession* session) const; - - // Implement the chained execution using the XRTExecuteChained op support. - std::vector ExecuteChainedXrt(absl::Span ops, - const std::string& device); - - // Implement the chained execution using multiple XRTExecute in many RPC round - // trips. - std::vector ExecuteChainedSplit( - absl::Span ops, const std::string& device); - - // Creates an XRT graph with an XRTCompile operation: - // - // XRTCompile( - // holders[0] - // ) - // - // With: - // holders[0] = XLA Computation place-holder (DT_STRING) - const XrtSession::CachedNode& GetCompileNode(XrtSession* session, - const tensorflow::Scope& scope, - const std::string& device) const; - - // Creates an XRT graph with an XRTExecute operation: - // - // XRTExecute( - // holders[0], - // holders[1], - // holders[2] - // ) - // - // With: - // holders[0] = XLA Computation handle place-holder (DT_INT64) - // holders[1] = xrt::XRTExecutionConfig place-holder (DT_STRING) - // holders[2] = Inputs for the XRTExecute (DT_INT64[]) - const XrtSession::CachedNode& GetExecuteNode(XrtSession* session, - const tensorflow::Scope& scope, - const std::string& device) const; - - // Creates an XRT graph with an XRTExecute operation: - // - // XRTExecuteChained( - // holders[0], - // holders[1] - // ) - // - // With: - // holders[0] = xrt::XRTChainedExecutePlan place-holder (DT_STRING) - // holders[1] = xrt::XRTChainedExecuteConfig place-holder (DT_STRING) - const XrtSession::CachedNode& GetExecuteChainedNode( - XrtSession* session, const tensorflow::Scope& scope, - const std::string& device) const; - - // Creates an XRT graph with an XRTReadLiteral operation: - // - // XRTReadLiteral( - // holders[0] - // ) - // - // With: - // holders[0] = The handle place-holder to be read (DT_INT64) - const XrtSession::CachedNode& GetReadNode(XrtSession* session, - const tensorflow::Scope& scope, - const std::string& device) const; - - // Creates an XRTAllocateFromTensor node for creating a device tensor with - // the given shape and layout: - // - // XRTAllocateFromTensor( - // holders[0] - // ) - // - // With: - // holders[0] = Tensor place-holder (DT_* - depends on shape type) - const XrtSession::CachedNode& GetAllocateNode(XrtSession* 
session, - const tensorflow::Scope& scope, - const std::string& device, - const xla::Shape& shape) const; - - // Creates an XRTReleaseAllocationHandle node: - // - // XRTReleaseAllocationHandle( - // holders[0] - // ) - // - // With: - // holders[0] = To be released handle place-holder (DT_INT64) - const XrtSession::CachedNode& GetReleaseAllocationHandleNode( - XrtSession* session, const tensorflow::Scope& scope, - const std::string& device) const; - - // Creates an XRTReleaseCompilationHandle node: - // - // XRTReleaseCompilationHandle( - // holders[0] - // ) - // - // With: - // holders[0] = To be released compilation handle place-holder (DT_INT64) - const XrtSession::CachedNode& GetReleaseCompileHandleNode( - XrtSession* session, const tensorflow::Scope& scope, - const std::string& device) const; - - // Creates an XRTSubTuple node: - // - // XRTSubTuple( - // holders[0], - // holders[1] - // ) - // - // With: - // holders[0] = Tuple handle place-holder (DT_INT64) - // holders[1] = Tuple index place-holder (DT_INT32[]) - const XrtSession::CachedNode& GetSubTupleNode( - XrtSession* session, const tensorflow::Scope& scope, - const std::string& device) const; - - // Creates an XRTMemoryInfo node: - // - // XRTMemoryInfo() - const XrtSession::CachedNode& GetMemoryInfoNode( - XrtSession* session, const tensorflow::Scope& scope, - const std::string& device) const; - - // Checks the result of a compile operation, and dumps the XLA computation - // graphs in case of error. - static void CheckCompileStatus(const xla::Status& status, - const std::vector& instances, - const SessionWork& session_work); - - // Converts an XLA data type to a tensorflow data type. - static tensorflow::DataType XlaTypeToDataType(xla::PrimitiveType dtype); - - static tensorflow::TensorShape MakeEquivalentTensorShape( - const xla::Shape& shape); - - // Builds an argument vector usable in a replicated context, out of a single - // replica argument vector. Essentially turns a [N] into a [1][N]. - static std::vector> BuildParallelArguments( - absl::Span arguments); - - static std::vector PartitionTransferToServer( - absl::Span tensors); - - // Extracts the xla::XlaComputation pointers out of Computation ones. Used to - // be passed to xrt_util::CheckComputationStatus() for its error reporting. - static std::vector GetXlaComputations( - absl::Span computations); - - static tensorflow::ConfigProto CreateConfigProto(const Options& options); - - static tensorflow::tpu::TopologyProto InitializeAndFetchTopology( - const std::string& job, int task_no, const std::string& worker_host_port, - const tensorflow::ConfigProto& config); - - static std::string GetLocalTarget(const Options& options); - - // Checks whether a local GRPC service is required, and starts it if need it. - void MaybeCreateLocalService(const Options& options); - - Options options_; - std::mutex lock_; - std::map> device_mesh_coords_; - std::unique_ptr session_cache_; - std::unique_ptr alloc_session_cache_; - std::unique_ptr triggered_task_; - XrtLocalService* local_service_ = nullptr; - util::Cache - compilation_cache_; - std::atomic rng_seed_; - // Access to the following members must be done while holding lock_. - // XRT thread safety semantics. - std::vector released_data_handles_; - std::vector released_compile_handles_; - // The mesh service which is used to coordinate all the client hosts which are - // feeding different TPU devices in a POD (or slice) training. 
- std::unique_ptr mesh_service_; - std::shared_ptr> replication_devices_; -}; - -} // namespace runtime -} // namespace torch_xla - -#endif // XLA_CLIENT_XRT_COMPUTATION_CLIENT_H_ diff --git a/torch_xla/csrc/runtime/xrt_local_service.cc b/torch_xla/csrc/runtime/xrt_local_service.cc deleted file mode 100644 index b9fe4c263c48..000000000000 --- a/torch_xla/csrc/runtime/xrt_local_service.cc +++ /dev/null @@ -1,67 +0,0 @@ -#include "torch_xla/csrc/runtime/xrt_local_service.h" - -#include - -#include "absl/strings/str_join.h" -#include "absl/strings/str_split.h" -#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.h" -#include "tensorflow/core/protobuf/cluster.pb.h" -#include "tensorflow/core/protobuf/tensorflow_server.pb.h" -#include "tensorflow/core/public/session_options.h" -#include "tensorflow/tsl/platform/errors.h" -#include "tensorflow/tsl/platform/status.h" - -namespace torch_xla { -namespace runtime { -namespace { - -void FillServerDef(const std::string& cluster_spec, const std::string& job_name, - int task_index, tensorflow::ServerDef* options) { - options->set_protocol("grpc"); - options->set_job_name(job_name); - options->set_task_index(task_index); - - size_t my_num_tasks = 0; - tensorflow::ClusterDef* cluster = options->mutable_cluster(); - for (auto& job_str : absl::StrSplit(cluster_spec, ',')) { - tensorflow::JobDef* job_def = cluster->add_job(); - // Split each entry in the flag into 2 pieces, separated by "|". - std::vector job_pieces = absl::StrSplit(job_str, '|'); - XLA_CHECK_EQ(2, job_pieces.size()) << job_str; - const std::string& cjob_name = job_pieces[0]; - const std::string& spec = job_pieces[1]; - job_def->set_name(cjob_name); - std::vector host_ports = absl::StrSplit(spec, ';'); - for (size_t i = 0; i < host_ports.size(); ++i) { - (*job_def->mutable_tasks())[i] = host_ports[i]; - } - size_t num_tasks = host_ports.size(); - if (job_name == options->job_name()) { - my_num_tasks = num_tasks; - } - LOG(INFO) << "Peer " << cjob_name << " " << num_tasks << " {" - << absl::StrJoin(host_ports, ", ") << "}"; - } - XLA_CHECK_NE(my_num_tasks, 0) << "Job '" << options->job_name() - << "' does not appear in the cluster spec"; - XLA_CHECK_LT(options->task_index(), my_num_tasks) - << "Task index " << options->task_index() << " is invalid (job '" - << options->job_name() << "' contains " << my_num_tasks << " tasks"; -} - -} // namespace - -XrtLocalService::XrtLocalService(const std::string& cluster_spec, - const std::string& job_name, int task_index) { - TF_LOG(INFO) << "libtpu status: " << tensorflow::tpu::FindAndLoadTpuLibrary(); - tensorflow::ServerDef server_def; - FillServerDef(cluster_spec, job_name, task_index, &server_def); - TF_CHECK_OK(tensorflow::NewServer(server_def, &server_)); -} - -void XrtLocalService::Start() { TF_CHECK_OK(server_->Start()); } - -void XrtLocalService::Join() { TF_CHECK_OK(server_->Join()); } - -} // namespace runtime -} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/xrt_local_service.h b/torch_xla/csrc/runtime/xrt_local_service.h deleted file mode 100644 index 6cd5964860ef..000000000000 --- a/torch_xla/csrc/runtime/xrt_local_service.h +++ /dev/null @@ -1,45 +0,0 @@ -#ifndef XLA_CLIENT_XRT_LOCAL_SERVICE_H_ -#define XLA_CLIENT_XRT_LOCAL_SERVICE_H_ - -#include -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/distributed_runtime/server_lib.h" -#include "torch_xla/csrc/runtime/debug_macros.h" - -namespace torch_xla { -namespace runtime { - -// A TF server running on a local interface. 
-class XrtLocalService { - public: - // The cluster_spec has format: - // CLUSTER_SPEC = JOB,... - // JOB = NAME|ADDRESS_LIST - // NAME = The name of the job - // ADDRESS_LIST = HOST:PORT;... - // HOST = Hostname or IP address - // PORT = Port number - // - // The job_name must match one of the job names in the cluster_spec, and - // represents this job. - // The task_index must be within the range of the ADDRESS_LIST of the current - // job in the cluster_spec. - XrtLocalService(const std::string& cluster_spec, const std::string& job_name, - int task_index); - - // Starts the service. - void Start(); - - // Joins the service. - void Join(); - - private: - std::unique_ptr server_; -}; - -} // namespace runtime -} // namespace torch_xla - -#endif // XLA_CLIENT_XRT_LOCAL_SERVICE_H_ diff --git a/torch_xla/csrc/runtime/xrt_session.cc b/torch_xla/csrc/runtime/xrt_session.cc deleted file mode 100644 index b58455d418de..000000000000 --- a/torch_xla/csrc/runtime/xrt_session.cc +++ /dev/null @@ -1,25 +0,0 @@ -#include "torch_xla/csrc/runtime/xrt_session.h" - -#include "absl/strings/str_cat.h" - -namespace torch_xla { -namespace runtime { - -XrtSession::XrtSession(const tensorflow::SessionOptions& session_options) - : target_(session_options.target), - root_(tensorflow::Scope::NewRootScope()), - session_(root_, session_options) {} - -void XrtSession::Reset() { - for (auto& name_cache : node_cache_) { - name_cache.second.Rewind(); - } -} - -std::string XrtSession::GetCacheKey(const std::string& op_name, - const std::string& device) { - return absl::StrCat(op_name, ";", device); -} - -} // namespace runtime -} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/xrt_session.h b/torch_xla/csrc/runtime/xrt_session.h deleted file mode 100644 index e734cbdc6627..000000000000 --- a/torch_xla/csrc/runtime/xrt_session.h +++ /dev/null @@ -1,102 +0,0 @@ -#ifndef XLA_CLIENT_XRT_SESSION_H_ -#define XLA_CLIENT_XRT_SESSION_H_ - -#include -#include -#include -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/cc/client/client_session.h" -#include "tensorflow/cc/framework/ops.h" -#include "tensorflow/cc/framework/scope.h" -#include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/compiler/xla/types.h" -#include "torch_xla/csrc/runtime/debug_macros.h" - -namespace torch_xla { -namespace runtime { - -// Encapsulates an XRT session and its associated node cache. XrtSession are not -// thread safe, but are always accessed by one thread at a time. The -// XrtSessionCache will keep creating new sessions if not enough are available -// to satisfy the threads requests. -class XrtSession { - public: - // A cached node captures that single node, or the mini-graph root node, - // together with the place-holders necessary to feed the node/sub-graph. - // The end-point node can be either a tensorflow Operation or an Output. 
- struct CachedNode { - CachedNode(tensorflow::Output output, - std::vector holders) - : holders(std::move(holders)) { - outputs.push_back(std::move(output)); - } - CachedNode(tensorflow::Operation operation, - std::vector holders) - : holders(std::move(holders)) { - operations.push_back(std::move(operation)); - } - CachedNode(std::vector outputs, - std::vector holders) - : outputs(std::move(outputs)), holders(std::move(holders)) {} - CachedNode(std::vector operations, - std::vector holders) - : operations(std::move(operations)), holders(std::move(holders)) {} - - std::vector outputs; - std::vector operations; - std::vector holders; - }; - - // The node cache holds a set of CachedNode of the same kind (by the means of - // the NodeTypes entries). - // The NodeCache access is not thread safe, but so is XrtSession. - class NodeCache { - public: - bool Empty() const { return position_ >= nodes_.size(); } - - const CachedNode& Get() { - XLA_CHECK_LT(position_, nodes_.size()); - ++position_; - return *nodes_[position_ - 1]; - } - - void Add(std::shared_ptr node) { - nodes_.push_back(std::move(node)); - } - - void Rewind() { position_ = 0; } - - private: - std::vector> nodes_; - size_t position_ = 0; - }; - - explicit XrtSession(const tensorflow::SessionOptions& session_options); - - const std::string& target() const { return target_; } - - tensorflow::Scope* root() { return &root_; } - - tensorflow::ClientSession* session() { return &session_; } - - NodeCache* GetNodeCache(const std::string& key) { return &node_cache_[key]; } - - void Reset(); - - static std::string GetCacheKey(const std::string& op_name, - const std::string& device); - - private: - std::string target_; - tensorflow::Scope root_; - tensorflow::ClientSession session_; - std::map node_cache_; -}; - -} // namespace runtime -} // namespace torch_xla - -#endif // XLA_CLIENT_XRT_SESSION_H_ diff --git a/torch_xla/csrc/runtime/xrt_session_cache.cc b/torch_xla/csrc/runtime/xrt_session_cache.cc deleted file mode 100644 index 030a7945554c..000000000000 --- a/torch_xla/csrc/runtime/xrt_session_cache.cc +++ /dev/null @@ -1,74 +0,0 @@ -#include "torch_xla/csrc/runtime/xrt_session_cache.h" - -#include "torch_xla/csrc/runtime/metrics.h" -#include "torch_xla/csrc/runtime/sys_util.h" - -namespace torch_xla { -namespace runtime { - -XrtSessionCache::XrtSessionCache(tensorflow::ConfigProto config, - std::function initfn, - std::string local_target) - : config_(std::move(config)), - initfn_(std::move(initfn)), - local_target_(std::move(local_target)) {} - -XrtSessionCache::Ref XrtSessionCache::GetSession(const std::string& target) { - std::lock_guard lock(lock_); - auto& session_queue = session_map_[target]; - if (!session_queue.empty()) { - std::shared_ptr session = std::move(session_queue.back()); - session_queue.pop_back(); - session->Reset(); - return Ref(this, std::move(session)); - } - return Ref(this, CreateSession(target)); -} - -XrtSession* XrtSessionCache::GetSession(const std::string& target, - SessionMap* session_map) { - auto it = session_map->find(target); - if (it == session_map->end()) { - it = session_map->emplace(target, GetSession(target)).first; - } - return it->second.get(); -} - -void XrtSessionCache::AddSession(std::shared_ptr session) { - std::lock_guard lock(lock_); - session_map_[session->target()].push_back(std::move(session)); -} - -std::shared_ptr XrtSessionCache::CreateSession( - const std::string& target) const { - XLA_COUNTER("XrtSessionCount", 1); - tensorflow::SessionOptions session_options; - session_options.env = 
tsl::Env::Default(); - session_options.target = target; - if (target != local_target_) { - session_options.config = config_; - } - - tensorflow::RPCOptions* rpc_options = - session_options.config.mutable_rpc_options(); - - std::string compression = sys_util::GetEnvString("XRT_GRPC_COMPRESSION", ""); - if (!compression.empty()) { - rpc_options->set_compression_algorithm(compression); - rpc_options->set_compression_level( - sys_util::GetEnvInt("XRT_GRPC_COMPRESSION_LEVEL", 3)); - } - - bool multi_stream = sys_util::GetEnvBool("XRT_GRPC_MULTISTREAM", true); - rpc_options->set_disable_session_connection_sharing(multi_stream); - - std::shared_ptr session = - std::make_shared(session_options); - if (initfn_ != nullptr) { - initfn_(session.get()); - } - return session; -} - -} // namespace runtime -} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/xrt_session_cache.h b/torch_xla/csrc/runtime/xrt_session_cache.h deleted file mode 100644 index 74a365cf14b6..000000000000 --- a/torch_xla/csrc/runtime/xrt_session_cache.h +++ /dev/null @@ -1,101 +0,0 @@ -#ifndef XLA_CLIENT_XRT_SESSION_CACHE_H_ -#define XLA_CLIENT_XRT_SESSION_CACHE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/types.h" -#include "torch_xla/csrc/runtime/xrt_session.h" - -namespace torch_xla { -namespace runtime { - -// Caches XrtSession objects. The XrtSession objects handed out by this class -// will be at exclusive use of the caller. -class XrtSessionCache { - public: - // A reference to an existing XrtSession. Its destructor will return it to the - // cache. - class Ref { - public: - Ref(XrtSessionCache* cache, std::shared_ptr session) - : cache_(cache), session_(std::move(session)) {} - - Ref(Ref&& ref) { MoveFrom(std::move(ref)); } - - Ref(const Ref&) = delete; - - ~Ref() { ReturnToCache(); } - - Ref& operator=(Ref&& rhs) { - if (&rhs != this) { - MoveFrom(std::move(rhs)); - } - return *this; - } - - Ref& operator=(const Ref&) = delete; - - XrtSession* operator->() const { return get(); } - - XrtSession* get() const { return session_.get(); } - - private: - void MoveFrom(Ref&& rhs) { - ReturnToCache(); - cache_ = rhs.cache_; - rhs.cache_ = nullptr; - session_ = std::move(rhs.session_); - } - - void ReturnToCache() { - if (cache_ != nullptr) { - cache_->AddSession(std::move(session_)); - cache_ = nullptr; - } - } - - XrtSessionCache* cache_ = nullptr; - std::shared_ptr session_; - }; - - // Map from session target to XrtSession reference. - using SessionMap = std::map; - - XrtSessionCache(tensorflow::ConfigProto config, - std::function initfn, - std::string local_target); - - const tensorflow::ConfigProto& GetConfig() const { return config_; } - - // Retrieves a new session reference, for which the caller will have exclusive - // access. Once the reference object is destroyed, the session will be - // returned to the cache. - Ref GetSession(const std::string& target); - - // Retrieves an XRT session by first checking the references already stored in - // the session_map, and, if missing, one will be fetched from the cache and - // added to the session_map. 
- XrtSession* GetSession(const std::string& target, SessionMap* session_map); - - void AddSession(std::shared_ptr session); - - private: - std::shared_ptr CreateSession(const std::string& target) const; - - tensorflow::ConfigProto config_; - std::function initfn_; - std::string local_target_; - std::mutex lock_; - std::map>> session_map_; -}; - -} // namespace runtime -} // namespace torch_xla - -#endif // XLA_CLIENT_XRT_SESSION_CACHE_H_ diff --git a/torch_xla/distributed/_xrt_run_server.py b/torch_xla/distributed/_xrt_run_server.py deleted file mode 100644 index e8ed8feb339e..000000000000 --- a/torch_xla/distributed/_xrt_run_server.py +++ /dev/null @@ -1,45 +0,0 @@ -""" -This script is for starting the xrt_server. It also polls the PID and -checks if it exist. It would kill the server, when the process whose -PID it was tracking dies. -NOTE: This script should be used only by xrt_init.py and not anyone else. -""" -import os -import argparse -import psutil -import time -import signal -import multiprocessing -import torch_xla - - -def _polling(pid_to_track): - - def is_pid_alive(pid): - # The idea behind this is: if the process doesn't exist, - # getting a process status should throw an error. - # If the process exist, then we check if it hasn't gone - # into zombie state. This can happen when we run torchrun - # from neuron_parallel_compile. - try: - return psutil.Process(pid).status() != psutil.STATUS_ZOMBIE - except: - return False - - while is_pid_alive(pid_to_track): - time.sleep(10) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument("--port", required=True) - parser.add_argument("--pid_to_track", default=None) - args = parser.parse_args() - polling_process = multiprocessing.Process( - target=_polling, args=(int(args.pid_to_track),)) - server_process = multiprocessing.Process( - target=torch_xla._XLAC._run_xrt_local_service, args=(int(args.port),)) - polling_process.start() - server_process.start() - polling_process.join() - os.kill(server_process.pid, signal.SIGKILL) diff --git a/torch_xla/distributed/cluster.py b/torch_xla/distributed/cluster.py deleted file mode 100644 index fe4a2eba2ff4..000000000000 --- a/torch_xla/distributed/cluster.py +++ /dev/null @@ -1,477 +0,0 @@ -import cloud_tpu_client -import logging -import multiprocessing -import re -import requests -import subprocess -import time - -from torch_xla.distributed.worker import ClientWorker -from torch_xla.distributed.worker import ServiceWorker -import torch_xla.utils.utils as xu - -try: - from googleapiclient import discovery - from oauth2client.client import GoogleCredentials -except ImportError: - raise ImportError('googleapiclient and oauth2client must be installed ' - 'before using the xla_dist. Execute: ' - '`pip install --upgrade google-api-python-client` ' - 'and `pip install --upgrade oauth2client` to ' - 'install with pip') - -_GCE_METADATA_ENDPOINT = 'http://metadata.google.internal' - -# Silence noisy logging -logging.getLogger('oauth2client').setLevel(logging.ERROR) -logging.getLogger('googleapiclient').setLevel(logging.ERROR) - - -class Cluster(object): - - def __init__(self, - client_workers, - service_workers, - check_client_machine_type=True, - check_service_machine_type=True, - client_master_ip=None): - """Creates a cluster object. - - Args: - client_workers: a list of ClientWorker objects. - service_workers: a list of ServiceWorker objects. - check_client_machine_type: whether to check if client workers all have the - same machine type. 
- check_service_machine_type: whether to check if service workers all have - the same machine type. - client_master_ip: the ip of client worker to set as master. If not - provided, the VM running the current process is the master. - """ - for client_worker in client_workers: - if not isinstance(client_worker, ClientWorker): - raise ValueError( - 'client_workers argument must be a list of ClientWorker') - for service_worker in service_workers: - if not isinstance(service_worker, ServiceWorker): - raise ValueError( - 'service_workers argument must be a list of ServiceWorker') - self._client_workers = list(client_workers) - self._service_workers = list(service_workers) - self._check_client_machine_type = check_client_machine_type - self._check_service_machine_type = check_service_machine_type - - if not client_master_ip: - client_master_ip = ClusterResolver.get_instance_metadata( - 'instance/network-interfaces/0/ip') - self._client_master = next( - filter(lambda cw: cw.get_internal_ip() == client_master_ip, - self._client_workers)) - - # Put client master at front of client worker list. - self._client_workers.remove(self._client_master) - self._client_workers.insert(0, self._client_master) - - def get_client_master(self): - return self._client_master - - def get_client_workers(self): - return self._client_workers - - def get_service_workers(self): - return self._service_workers - - def validate(self): - """Validates the current cluster configuration. - - Raises: - RuntimeError: If the cluster is misconfigured, this validation will - raise an error. For example, if the VMs are in different zones, - or not all of the CPU workers have the same size (number of CPU - cores, RAM size) we raise an exception. For TPUs we similarly - raise an exception if different zones or machine/accelerator_type. - """ - if len(self._client_workers) == 0 or len(self._service_workers) == 0: - raise RuntimeError( - 'Both client_workers and service_workers should not be empty') - - if len(self._client_workers) != len(self._service_workers): - raise RuntimeError( - 'The client_workers and service_workers must have a 1:1 mapping') - - zones = {worker._zone for worker in self._client_workers} - zones.update(worker._zone for worker in self._service_workers) - if len(zones) != 1: - raise RuntimeError( - 'All workers must be in the same zone, got: {}'.format(zones)) - - if self._check_client_machine_type: - client_machine_types = { - worker._machine_type for worker in self._client_workers - } - if len(client_machine_types) != 1: - raise RuntimeError( - 'All client_workers must have the same machine_type, got: {}'. - format(client_machine_types)) - - if self._check_service_machine_type: - server_machine_types = { - worker._machine_type for worker in self._service_workers - } - if len(server_machine_types) != 1: - raise RuntimeError( - 'All service_workers must have the same machine_type, got: {}'. - format(server_machine_types)) - - runtime_versions = { - worker._runtime_version for worker in self._service_workers - } - if len(runtime_versions) != 1: - raise RuntimeError( - 'All service workers must have the same runtime_version, got: {}'. 
-          format(runtime_versions))
-
-  def __eq__(self, other):
-    return (self._client_workers == other._client_workers and
-            self._service_workers == other._service_workers)
-
-  def __ne__(self, other):
-    return not self.__eq__(other)
-
-  def __repr__(self):
-    return ('{{client_workers: {client_workers}, '
-            'service_workers: {service_workers}}}').format(
-                client_workers=self._client_workers,
-                service_workers=self._service_workers)
-
-  def list_tpus_with_health(self, health):
-
-    def _tpu_with_health(tpu_name):
-      ctc = cloud_tpu_client.Client(tpu_name)
-      if ctc.health() == health:
-        return tpu_name
-
-    tpus = set()
-    for service_worker in self._service_workers:
-      tpus.add(service_worker._tpu)
-    results = xu.parallel_work(len(tpus), _tpu_with_health, tpus)
-    return [res for res in results if res]
-
-  def wait_for_healthy_service(self):
-
-    def wait_for_healthy_service_worker(tpu_name):
-      ctc = cloud_tpu_client.Client(tpu=tpu_name)
-      ctc.wait_for_healthy()
-
-    tpus = self.list_tpus_with_health('UNHEALTHY_MAINTENANCE')
-    if tpus:
-      xu.parallel_work(len(tpus), wait_for_healthy_service_worker, tpus)
-
-  def wait_for_healthy_client(self, dist_executor, timeout=1200, interval=10):
-
-    def wait_for_healthy_client_worker(client_worker):
-      heartbeat_check = [
-          'echo', 'client_worker', '$(hostname)', 'is', 'healthy'
-      ]
-      check_timeout = time.time() + timeout
-
-      def _healthy_client_worker():
-        proc = multiprocessing.Process(
-            target=dist_executor._build_and_run_ssh,
-            args=(
-                heartbeat_check,
-                client_worker,
-            ))
-        proc.daemon = True
-        proc.start()
-        proc.join(interval)
-
-        if proc.is_alive():
-          proc.terminate()
-          return False
-
-        return proc.exitcode == 0
-
-      while not _healthy_client_worker():
-        logging.warning(
-            'Waiting for client_worker "{}" to become healthy'.format(
-                client_worker))
-        if time.time() + interval > check_timeout:
-          raise RuntimeError(
-              'Timed out waiting for client_worker {} to become healthy'.format(
-                  client_worker))
-
-      logging.warning('client_worker "{}" is healthy.'.format(client_worker))
-
-    xu.parallel_work(
-        len(self._client_workers), wait_for_healthy_client_worker,
-        self._client_workers)
-
-
-class ClusterResolver(object):
-  """Cluster Resolver for Client VM and Cloud TPU mesh."""
-
-  @staticmethod
-  def get_instance_metadata(metadata):
-    response = requests.get(
-        '{}/computeMetadata/v1/{}'.format(_GCE_METADATA_ENDPOINT, metadata),
-        headers={'Metadata-Flavor': 'Google'})
-    return response.content.decode('utf-8')
-
-  @staticmethod
-  def _parse_resource_url(url, name):
-    parts = url.split('/')
-    idx = parts.index(name)
-    return parts[idx + 1]
-
-  @staticmethod
-  def _get_internal_ip_to_hostname_mapping(tpu_name, zone, num_vm):
-    """Gets TPU VM internal IP to hostname mapping.
-
-    Currently TPU CLH does not expose any TPU host machine name. SSH to each worker and
-    get that instead.
-
-    Returns:
-      A map of TPU VM internal IP to TPU VM hostname.
- """ - ip_to_host_name = {} - - def add_tpuvm_ip_to_hostname_mapping(worker_index): - proc = subprocess.Popen([ - 'gcloud', 'alpha', 'compute', 'tpus', 'tpu-vm', 'ssh', - '--internal-ip', tpu_name, '--zone', zone, '--worker', - str(worker_index), '--command', 'hostname; hostname -i' - ], - stdout=subprocess.PIPE) - hostname = proc.stdout.readline().decode('utf-8').rstrip('\n') - ip = proc.stdout.readline().decode('utf-8').rstrip('\n') - ip_to_host_name[ip] = hostname - - xu.parallel_work(num_vm, add_tpuvm_ip_to_hostname_mapping, - list(range(num_vm))) - return ip_to_host_name - - def __init__(self, tpu, vms=None, zone=None, project=None): - """Creates a new ClusterResolver object.""" - - if not tpu: - raise ValueError('tpu must be a non-empty string') - if vms: - if not isinstance(vms, list) or len(vms) == 0: - raise ValueError('vms must be a non-empty list if provided') - - self._tpus = tpu if isinstance(tpu, list) else [tpu] - self._vms = vms - self._zone = zone - self._project = project - self._tpuvm_mode = None - self._tpuvm_mode_with_remote_coordinator = None - self._set_tpuvm_mode() - - self._compute_service = discovery.build( - 'compute', - 'v1', - credentials=GoogleCredentials.get_application_default(), - cache_discovery=False) - - if project is None: - self._project = ClusterResolver.get_instance_metadata( - 'project/project-id') - if zone is None: - zone_path = ClusterResolver.get_instance_metadata('instance/zone') - self._zone = ClusterResolver._parse_resource_url(zone_path, 'zones') - self._vm_master = ClusterResolver.get_instance_metadata('instance/name') - - def _set_tpuvm_mode(self): - self._tpuvm_mode = False - self._tpuvm_mode_with_remote_coordinator = False - accel_type = ClusterResolver.get_instance_metadata( - 'instance/attributes/accelerator-type') - if re.match(r'v[0-9]+-[0-9]+', accel_type): - # Only VM with TPU attached will carry the accelerator-type metadata - self._tpuvm_mode = True - return - - api_version = cloud_tpu_client.Client( - tpu=self._tpus[0])._get_tpu_property('apiVersion') - if api_version == 'V2_ALPHA1': - # Only TPUVM api version should be V2_ALPHA1 - self._tpuvm_mode = True - # Current vm does not carry the accelerator-type metadata but tpu specified - # is a TPUVM, assume it is a remote coordinator. 
- self._tpuvm_mode_with_remote_coordinator = True - - def _get_instance_group(self): - """Gets the instance group that the current VM belongs to.""" - resp = self._compute_service.instances().get( - project=self._project, - zone=self._zone, - instance=self._vm_master, - fields='metadata').execute() - - if 'metadata' in resp and 'items' in resp['metadata']: - for item in resp['metadata']['items']: - if (item['key'] == 'created-by' and - 'instanceGroupManagers' in item['value']): - return ClusterResolver._parse_resource_url(item['value'], - 'instanceGroupManagers') - - raise RuntimeError(('A vm list must be passed to ClusterResolver ' - 'if not using an instance group')) - - def _get_member_instance_names(self, instance_group): - """Gets all the instance names that belong to the given instance group.""" - resp = self._compute_service.instanceGroups().listInstances( - project=self._project, zone=self._zone, - instanceGroup=instance_group).execute() - - instances = [] - for item in resp.get('items', []): - if 'instance' not in item or 'status' not in item: - continue - instance_path = item['instance'] - instances.append( - ClusterResolver._parse_resource_url(instance_path, 'instances')) - - return instances - - def get_client_workers(self): - """Gets client workers. - - The instance group that the current VM belongs to is picked up from - the GCE instance metadata set of the VM. If a list of VMs was used for - initializing cluster resolver, we use that instead. - - Returns: - A list of ClientWorker. - - Raises: - RuntimeError: If the red VM cluster is not healthy. - """ - if not self._vms: - # Using an instance group - instance_group = self._get_instance_group() - self._vms = self._get_member_instance_names(instance_group) - if len(self._vms) == 0: - raise RuntimeError('Client worker vms is empty in instance group') - - workers = [] - batch = self._compute_service.new_batch_http_request() - - def add_client_worker(request_id, resp, exception): - """Callback for each request in BatchHttpRequest.""" - if exception is not None: - raise exception - hostname = ClusterResolver._parse_resource_url(resp['selfLink'], - 'instances') - if resp['status'] != 'RUNNING': - raise RuntimeError( - ('Instance {hostname} is not running yet. ' - 'Re-run when all VMs are running').format(hostname=hostname)) - worker = ClientWorker( - internal_ip=resp['networkInterfaces'][0]['networkIP'], - machine_type=ClusterResolver._parse_resource_url( - resp['machineType'], 'machineTypes'), - zone=ClusterResolver._parse_resource_url(resp['zone'], 'zones'), - hostname=hostname) - workers.append(worker) - - for vm in self._vms: - req = self._compute_service.instances().get( - project=self._project, - zone=self._zone, - instance=vm, - fields=('machineType,metadata,selfLink,' - 'networkInterfaces/networkIP,status,zone')) - batch.add(req, add_client_worker) - batch.execute() - - return workers - - def get_tpu_workers(self, as_client_worker=False): - """Gets TPU VM cluster info. - - Calls the TPU CLH to get TPU node data and returns list of TPU worker - VMs internal IP addresses. If zone and project are not specified at - ClusterResolver init time, we infer these bits from GCE metadata. - - Returns: - A list of ServiceWorker or a list of ClientWorker. - - Raises: - RuntimeError: If the TPU DNE or the TPU is in not in HEALTHY state. 
- """ - workers = [] - - def add_tpu_worker(tpu_name): - ctc = cloud_tpu_client.Client(tpu=tpu_name) - tpu_name = ctc.name() - if ctc.state() != 'READY': - raise RuntimeError( - ('TPU {tpu_name} is not READY yet. ' - 'Re-run when all TPUs are READY').format(tpu_name=tpu_name)) - if ctc.health() != 'HEALTHY': - raise RuntimeError( - ('TPU {tpu_name} is not HEALTHY yet. ' - 'Re-run when all TPUs are HEALTHY').format(tpu_name=tpu_name)) - - runtime_version = ctc.runtime_version() - machine_type = ctc.accelerator_type() - zone = ClusterResolver._parse_resource_url(ctc._full_name(), 'locations') - network_endpoints = ctc.network_endpoints() - - if as_client_worker: - ip_to_host_name = ClusterResolver._get_internal_ip_to_hostname_mapping( - tpu_name, zone, len(network_endpoints)) - - for endpoint in network_endpoints: - if as_client_worker: - internal_ip = endpoint['ipAddress'] - hostname = ip_to_host_name[internal_ip] - worker = ClientWorker( - internal_ip=internal_ip, - machine_type=machine_type, - zone=zone, - hostname=hostname) - else: - worker = ServiceWorker( - internal_ip=endpoint['ipAddress'], - port=endpoint['port'], - machine_type=machine_type, - zone=zone, - runtime_version=runtime_version, - tpu=tpu_name) - workers.append(worker) - - xu.parallel_work(len(self._tpus), add_tpu_worker, self._tpus) - - return workers - - def get_cluster(self): - """Gets client and server side cluster info. - - If a list of vms is not provided at ClusterResolver crate time the current - VM's instance group is picked up and we use that to resolve the VM mesh. - - Returns: - A Cluster object with both client and server mesh configuration. - - Raises: - RuntimeError: If the VM cluster is not healthy. Also if the TPU - cluster is not healthy. - """ - service_workers = self.get_tpu_workers(as_client_worker=False) - client_workers = self.get_tpu_workers( - as_client_worker=True) if self._tpuvm_mode else self.get_client_workers( - ) - client_master_ip = None - if self._tpuvm_mode_with_remote_coordinator: - # If the script is being run from a remote coordinator with a TPUVM, client_master_ip - # should be TPUVM IP instead of the remote coordinator IP. 
- client_master_ip = client_workers[0].get_internal_ip() - cluster = Cluster( - client_workers, service_workers, client_master_ip=client_master_ip) - cluster.validate() - return cluster - - def get_tpuvm_mode(self): - return self._tpuvm_mode diff --git a/torch_xla/distributed/worker.py b/torch_xla/distributed/worker.py deleted file mode 100644 index 9962125fda10..000000000000 --- a/torch_xla/distributed/worker.py +++ /dev/null @@ -1,114 +0,0 @@ -class Worker(object): - - def __init__(self, internal_ip, machine_type, zone): - if not isinstance(internal_ip, str): - raise ValueError('internal_ip must be of type str') - self._internal_ip = internal_ip - if not isinstance(machine_type, str): - raise ValueError('machine_type must be of type str') - self._machine_type = machine_type - if not isinstance(zone, str): - raise ValueError('zone must be of type str') - self._zone = zone - - def get_internal_ip(self): - return self._internal_ip - - def get_zone(self): - return self._zone - - -class ClientWorker(Worker): - - def __init__(self, internal_ip, machine_type, zone, hostname=None): - super(ClientWorker, self).__init__(internal_ip, machine_type, zone) - if hostname is not None and not isinstance(hostname, str): - raise ValueError('hostname must be of type str') - self._hostname = hostname - - def get_hostname(self): - return self._hostname - - def __repr__(self): - return ('{{{internal_ip}, {machine_type}, {zone},' - ' {hostname}}}').format( - internal_ip=self._internal_ip, - machine_type=self._machine_type, - zone=self._zone, - hostname=self._hostname) - - def __eq__(self, other): - return (self._internal_ip == other._internal_ip and - self._machine_type == other._machine_type and - self._zone == other._zone and self._hostname == other._hostname) - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - return hash(repr(self)) - - def __repr__(self): - return ('{{{internal_ip}, {machine_type}, {zone},' - ' {hostname}}}').format( - internal_ip=self._internal_ip, - machine_type=self._machine_type, - zone=self._zone, - hostname=self._hostname) - - def __eq__(self, other): - return (self._internal_ip == other._internal_ip and - self._machine_type == other._machine_type and - self._zone == other._zone and self._hostname == other._hostname) - - def __ne__(self, other): - return not self.__eq__(self, other) - - def __hash__(self): - return hash(repr(self)) - - -class ServiceWorker(Worker): - - def __init__(self, - internal_ip, - port, - machine_type, - zone, - runtime_version, - tpu=None): - super(ServiceWorker, self).__init__(internal_ip, machine_type, zone) - self._port = int(port) - if not isinstance(runtime_version, str): - raise ValueError('runtime_version must be of type str') - self._runtime_version = runtime_version - if tpu is not None and not isinstance(tpu, str): - raise ValueError('tpu must be of type str') - self._tpu = tpu - - def get_port(self): - return self._port - - def __repr__(self): - return ('{{{internal_ip}, {port}, {machine_type}, {zone},' - ' {runtime_version}, {tpu}}}').format( - internal_ip=self._internal_ip, - port=self._port, - machine_type=self._machine_type, - zone=self._zone, - runtime_version=self._runtime_version, - tpu=self._tpu) - - def __eq__(self, other): - return (self._internal_ip == other._internal_ip and - self._port == other._port and - self._machine_type == other._machine_type and - self._zone == other._zone and - self._runtime_version == other._runtime_version and - self._tpu == other._tpu) - - def __ne__(self, other): - return 
not self.__eq__(other) - - def __hash__(self): - return hash(repr(self)) diff --git a/torch_xla/distributed/xla_dist.py b/torch_xla/distributed/xla_dist.py deleted file mode 100755 index a1c8b2c472b9..000000000000 --- a/torch_xla/distributed/xla_dist.py +++ /dev/null @@ -1,696 +0,0 @@ -#!/usr/bin/env python -"""Tool to distribute training on Cloud TPU Pods.""" - -import argparse -import cloud_tpu_client -import logging -import multiprocessing -import os -import re -import signal -import subprocess -import sys -import time -import threading -from torch_xla import runtime as xr -import torch_xla.core.xla_env_vars as xenv -from torch_xla.distributed.cluster import ClusterResolver -import torch_xla.utils.utils as xu - - -def get_args_parser() -> argparse.ArgumentParser: - """Helper function parsing the command line options.""" - - parser = argparse.ArgumentParser( - description='PyTorch on TPU distributed training launcher.', - epilog=('Usage example: python3 -m' - ' torch_xla.distributed.xla_dist --tpu=[TPU_NAME]' - ' --conda-env torch-xla-nightly -- python3 train.py')) - - cluster_group = parser.add_argument_group('Cluster Setup') - cluster_group.add_argument( - '--tpu', type=str, required=True, help='Name of the Cloud TPU pod.') - cluster_group.add_argument( - '--vm', - action='append', - type=str, - help=('List of single Compute VM instance names. ' - 'If not provided we assume usage of instance groups.')) - - docker_group = parser.add_argument_group('Docker Setup') - docker_group.add_argument( - '--docker-container', - default='', - type=str, - help='Name of docker container if running in docker.') - docker_group.add_argument( - '--docker-image', - default='', - type=str, - help='Name of docker image if running in container.') - docker_group.add_argument( - '--docker-run-flag', - action='append', - type=str, - help='Docker run flags to run container with (ex. 
--shm-size, ...).') - - conda_group = parser.add_argument_group('Conda Setup') - conda_group.add_argument( - '--conda-env', - default='', - type=str, - help='Name of the conda environment if running with conda.') - - parser.add_argument( - '--env', - action='append', - type=str, - help='List of environment variables to distribute.') - parser.add_argument( - '--restart-tpuvm-pod-server', - action='store_true', - help='Restart the long running XRT local service for this training.') - parser.add_argument( - '--tpuvm-server-port', - default=51011, - type=int, - help='Port that XRT local service will be start on.') - parser.add_argument( - 'positional', - nargs='+', - type=str, - help='The python command to launch training including model parameters.') - return parser - - -def parse_args(args): - parser = get_args_parser() - return parser.parse_args(args) - - -def resolve_and_execute(flags): - """Resolves the command line flags and launches a distributed process""" - cluster_resolver = ClusterResolver(flags.tpu, vms=flags.vm) - cluster = cluster_resolver.get_cluster() - tpuvm_mode = cluster_resolver.get_tpuvm_mode() - executor = DistributedExecutor( - cluster, - docker_container=flags.docker_container, - docker_image=flags.docker_image, - docker_run_flags=flags.docker_run_flag, - conda_env=flags.conda_env, - env_vars=flags.env, - restart_server=flags.restart_tpuvm_pod_server, - tpuvm_mode=tpuvm_mode, - tpuvm_server_port=flags.tpuvm_server_port) - executor.run(flags.positional) - - -def concat_cmd_list(cmd_list, delimiter=' ', quote='"'): - concat = '' - for cmd in cmd_list: - if re.match('^{}.*{}$'.format(quote, quote), cmd): - token = cmd - else: - token = quote + cmd + quote - if concat: - concat += delimiter - concat += token - return concat - - -class DistributedExecutor(object): - - SCRIPT_PATH_TMPL = '/tmp/{pid}/dist_training_ptxla_{worker}.sh' - XRT_RUN_SERVER_CMD = 'torch_xla.core.xrt_run_server' - XRT_RUN_SERVER_PROCESS = 'torch_xla.core._xrt_run_server' - MESH_SERVICE_PORT = 8477 # Use single port to disallow concurrent runs - DIST_ENV_VARS = [ - xenv.TPU_CONFIG, - xenv.LOCAL_WORKER, - xenv.SERVICE_ADDRESS, - xenv.WORLD_SIZE, - xenv.ORDINAL, - xenv.TPU_NUM_DEVICES, - 'XLA_EMIT_STEPLOG', - ] - DEFAULT_CONTAINER_NAME = 'pytorchtpudistrunner' - MAX_TPU_RETRY = 50 - HEARTBEAT_CHECK_PERIOD = 30 - - def _get_logger(self): - logger = logging.getLogger(self.__class__.__name__) - logger.setLevel(logging.INFO) - logger.propagate = False - formatter = logging.Formatter( - fmt='%(asctime)-12s %(clientip)s [%(ordinal)s] %(message)s', - datefmt='%Y-%m-%d %H:%M:%S') - sh = logging.StreamHandler() - sh.setLevel(logging.INFO) - sh.setFormatter(formatter) - logger.addHandler(sh) - return logger - - def _initialize(self): - """Initializes members that need to be cleanly initialized for each run.""" - self._last_heartbeats = { - cw.get_internal_ip(): { - 'last_time': time.time(), - 'count': 0, - } for cw in self._cluster.get_client_workers() - } - self._error_queue = multiprocessing.Queue() - self._last_heartbeat_check_time = 0 - - def __init__(self, - cluster, - docker_container=None, - docker_image=None, - docker_run_flags=None, - conda_env=None, - env_vars=None, - restart_server=None, - tpuvm_mode=None, - tpuvm_server_port=None): - self._cluster = cluster - self._initialize() - self.logger = self._get_logger() - self.docker_container = docker_container or self.DEFAULT_CONTAINER_NAME - self.docker_image = docker_image - self.docker_run_flags = list(docker_run_flags) if docker_run_flags else [] - 
self.conda_env = conda_env - self.env_vars = list(env_vars) if env_vars else [] - self.tpuvm_mode = tpuvm_mode - self.restart_server = restart_server - self.tpuvm_server_port = tpuvm_server_port - self.tpu_name = self._cluster.get_service_workers()[0]._tpu - - for env_var in self.env_vars: - if re.match(r'\w*=\w*', env_var) is None: - raise ValueError( - ('Environment variable to distribute ({}) should follow ' - 'the form: X=Y').format(env_var)) - for dist_var in self.DIST_ENV_VARS: - if re.match('{}=.*'.format(dist_var), env_var): - raise ValueError( - ('{} should not be in the training command provided as they' - ' will interfere with the values set for distributed' - ' training'.format(dist_var))) - - def _check_client_mesh_health(self, uneven_health_timeout, - even_health_timeout): - min_delay = max(uneven_health_timeout, even_health_timeout) + 1 - count = None - now = time.time() - if xu.getenv_as('XLA_DEBUG_LOG_HEARTBEATS', bool, False): - self.logger.info( - 'Worker Heartbeats: {}'.format(self._last_heartbeats), - extra={ - 'clientip': '', - 'ordinal': '' - }) - - for cw_hb in self._last_heartbeats.values(): - min_delay = min(min_delay, now - cw_hb['last_time']) - if count is None: - count = cw_hb['count'] - elif count >= 0 and count != cw_hb['count']: - count = -1 - - if count < 0 and min_delay > uneven_health_timeout: - self._error_queue.put( - RuntimeError('Client mesh is unhealthy with uneven heartbeats')) - elif count > 0 and min_delay > even_health_timeout: - self._error_queue.put( - RuntimeError('Client mesh is unhealthy with even heartbeats')) - - def _stream_logs(self, process, client_worker): - client_ip = client_worker.get_internal_ip() - ordinal = self._cluster.get_client_workers().index(client_worker) - - def _stream_output(stream, log_fn): - for std in iter(stream.readline, b''): - std_line = std.decode('utf-8').rstrip('\n') - if 'torch_xla.core.xla_model::mark_step' in std_line: - hb_stream = self._last_heartbeats[client_ip] - # Only single thread updates each of these, so there is no race - hb_stream['last_time'] = time.time() - hb_stream['count'] += 1 - continue - log_fn(std_line, extra={'clientip': client_ip, 'ordinal': ordinal}) - - stdout = threading.Thread( - target=_stream_output, - daemon=True, - args=( - process.stdout, - self.logger.info, - )) - stdout.start() - stderr = threading.Thread( - target=_stream_output, - daemon=True, - args=( - process.stderr, - self.logger.error, - )) - stderr.start() - stdout.join() - stderr.join() - - def _is_retry(self): - return self.trials >= 1 - - def _build_scp_cmd(self, local_path, remote_path, client_worker): - if not self._is_retry(): - if self.tpuvm_mode: - return [ - 'gcloud', - 'alpha', - '-q', - 'compute', - 'tpus', - 'tpu-vm', - 'scp', - '--internal-ip', - '--zone={}'.format(client_worker.get_zone()), - '--worker={}'.format(client_worker.get_hostname().split('-')[-1]), - local_path, - '{}:{}'.format(self.tpu_name, remote_path), - ] - else: - return [ - 'gcloud', - '-q', - 'compute', - 'scp', - '--internal-ip', - '--zone={}'.format(client_worker.get_zone()), - local_path, - '{}:{}'.format(client_worker.get_hostname(), remote_path), - ] - - return [ - 'scp', - '-oStrictHostKeyChecking=no', - '-i', - '~/.ssh/google_compute_engine', - local_path, - '{}@{}:{}'.format(os.getlogin(), client_worker.get_hostname(), - remote_path), - ] - - def _build_ssh_cmd(self, remote_cmd, client_worker): - if isinstance(remote_cmd, list): - remote_cmd = concat_cmd_list(remote_cmd) - if not self._is_retry(): - if self.tpuvm_mode: - 
return [ - 'gcloud', - 'alpha', - '-q', - 'compute', - 'tpus', - 'tpu-vm', - 'ssh', - '--internal-ip', - '{}'.format(self.tpu_name), - '--zone {}'.format(client_worker.get_zone()), - '--worker {}'.format(client_worker.get_hostname().split('-')[-1]), - '--command', - '\'{}\''.format(remote_cmd), - ] - else: - return [ - 'gcloud', - '-q', - 'compute', - 'ssh', - '--internal-ip', - '--zone={}'.format(client_worker.get_zone()), - '{}'.format(client_worker.get_hostname()), - '--command', - '\'{}\''.format(remote_cmd), - ] - return [ - 'ssh', - '-oStrictHostKeyChecking=no', - '-i', - '~/.ssh/google_compute_engine', - '{}@{}'.format(os.getlogin(), client_worker.get_hostname()), - '\'{}\''.format(remote_cmd), - ] - - def _run_remote_cmd(self, cmd, client_worker, shell=True, log=True): - cmd = concat_cmd_list(cmd, quote='') if shell else cmd - proc = subprocess.Popen( - cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell) - if log: - self._stream_logs(proc, client_worker) - proc.wait() - if proc.returncode == 255: - self._error_queue.put( - RuntimeError( - 'Client mesh is unhealthy due to dead client worker: {}'.format( - client_worker))) - return proc.returncode - - def _build_and_run_ssh(self, remote_cmd, client_worker, shell=True, log=True): - cmd = self._build_ssh_cmd(remote_cmd, client_worker) - return self._run_remote_cmd(cmd, client_worker, shell=shell, log=log) - - def _docker_run_cmd(self, cmd): - docker_cmd = [ - 'docker', - 'run', - '--name={}'.format(self.docker_container), - '--network=host', - ] - docker_cmd.extend(self.docker_run_flags) - for env_var in self.DIST_ENV_VARS: - docker_cmd.extend(['-e', env_var]) - for env_kv in self.env_vars: - key = re.match(r'(\w*)=.*', env_kv) - if key: - docker_cmd.extend(['-e', key.group(1)]) - docker_cmd.append(self.docker_image) - docker_cmd.extend(cmd) - return docker_cmd - - def _tpuvm_env_vars_cmd(self, worker_idx): - env_vars = { - xenv.TPU_CHIPS_PER_HOST_BOUNDS: '2,2,1', - xenv.TPUVM_MODE: 1, - xenv.CLOUD_TPU_TASK_ID: worker_idx, - } - accelerator_type = self._cluster.get_service_workers()[0]._machine_type - master_worker_network_endpoints = self._cluster.get_client_workers( - )[0].get_internal_ip() - - accelerator_type_to_host_bounds = { - # v2 - 'v2-8': '1,1,1', - 'v2-32': '2,2,1', - 'v2-128': '4,4,1', - 'v2-256': '4,8,1', - 'v2-512': '8,8,1', - # v3 - 'v3-8': '1,1,1', - 'v3-32': '2,2,1', - 'v3-64': '2,4,1', - 'v3-128': '4,4,1', - 'v3-256': '4,8,1', - 'v3-512': '8,8,1', - 'v3-1024': '8,16,1', - 'v3-2048': '16,16,1', - # v4 - 'v4-8': '1,1,1', - 'v4-16': '1,1,2', - 'v4-32': '1,1,4', - 'v4-64': '1,2,4', - 'v4-128': '2,2,4', - 'v4-256': '2,2,8', - 'v4-512': '2,4,8', - 'v4-1024': '4,4,8', - 'v4-2048': '4,4,16', - 'v4-4096': '4,8,16', - } - - env_vars[xenv.TPU_HOST_BOUNDS] = accelerator_type_to_host_bounds[ - accelerator_type] - env_vars[xenv.TPU_MESH_CTLER_ADDR] = '{}:{}'.format( - master_worker_network_endpoints, '8476') - env_vars[xenv.TPU_MESH_CTLER_PORT] = 8476 - return env_vars - - def _env_vars_cmd(self, worker_idx): - client_worker = self._cluster.get_client_workers()[worker_idx] - accelerator_gen = self._cluster.get_service_workers( - )[0]._machine_type.split('-')[0] - accelerator_gen_to_tpu_num_devices = { - 'v2': 8, - 'v3': 8, - 'v4': 4, - } - worker_name = 'c_localservice' if self.tpuvm_mode else 'c_tpu_worker' - env_vars = { - xenv.LOCAL_WORKER: - '{}:{}'.format(worker_name, worker_idx), - xenv.SERVICE_ADDRESS: - '{}:{}'.format(self._cluster.get_client_master().get_internal_ip(), - self.MESH_SERVICE_PORT), - 
xenv.WORLD_SIZE: - len(self._cluster.get_client_workers()), - xenv.ORDINAL: - worker_idx, - xenv.TPU_NUM_DEVICES: - accelerator_gen_to_tpu_num_devices[accelerator_gen], - 'XLA_EMIT_STEPLOG': - 1, - } - if self.tpuvm_mode: - env_vars.update(self._tpuvm_env_vars_cmd(worker_idx)) - - # Only for master - if client_worker == self._cluster.get_client_master(): - xrt_server_config = [ - '{worker_name};{worker_idx};{worker_ip}:{worker_port}'.format( - worker_name=worker_name, - worker_idx=idx, - worker_ip=service_worker.get_internal_ip(), - worker_port=self.tpuvm_server_port - if self.tpuvm_mode else service_worker.get_port()) for idx, - service_worker in enumerate(self._cluster.get_service_workers()) - ] - xrt_tpu_config = '|'.join(xrt_server_config) - env_vars[xenv.TPU_CONFIG] = '{}'.format(xrt_tpu_config) - - export_cmd = [] - for k in env_vars: - export_cmd.append(['export', '{}={}'.format(k, env_vars[k])]) - for kv in self.env_vars: - export_cmd.append(['export', '{}'.format(kv)]) - return export_cmd - - def _prepare_scripts(self, cmd): - worker_script_map = {} - for i in range(len(self._cluster.get_client_workers())): - script_path = self.SCRIPT_PATH_TMPL.format(pid=os.getpid(), worker=i) - - # ex. script = [['conda', 'activate', 'pytorch'], ['python3', 'train.py']] - script = [] - script.extend(self._env_vars_cmd(i)) - # Setup environment for non-interactive non-login shell over ssh - script.append(['.', '/etc/profile']) - if self.tpuvm_mode: - # Start the local tf server if it is not already running. - script.append([ - 'python3', '-m', self.XRT_RUN_SERVER_CMD, '--port', - str(self.tpuvm_server_port) - ]) - if self.restart_server: - script[-1].append('--restart') - if self.docker_image: - script.append(self._docker_run_cmd(cmd)) - else: - if self.conda_env: - script.append(['conda', 'activate', self.conda_env]) - script.append(cmd) - - # ex. 
script_body = 'conda activate pytorch; python3 train.py' - script_cmd_list = [concat_cmd_list(command) for command in script] - script_body = concat_cmd_list(script_cmd_list, delimiter='; ') - os.makedirs(os.path.dirname(script_path), exist_ok=True) - with open(script_path, 'w') as f: - f.write(script_body) - subprocess.call(['chmod', '+x', script_path]) - worker_script_map[self._cluster.get_client_workers()[i]] = { - 'local_path': - script_path, - 'remote_path': - os.path.join('{}-remote'.format(os.path.dirname(script_path)), - os.path.basename(script_path)), - } - - return worker_script_map - - def _scp_scripts(self, script_map): - - def _gcloud_scp(local_path, remote_path, client_worker): - self._build_and_run_ssh( - ['mkdir', '-p', os.path.dirname(remote_path)], client_worker) - scp_cmd = self._build_scp_cmd(local_path, remote_path, client_worker) - proc = subprocess.Popen( - scp_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - self._stream_logs(proc, client_worker) - - threads = [] - for i, client_worker in enumerate(script_map): - local_path = script_map[client_worker]['local_path'] - remote_path = script_map[client_worker]['remote_path'] - if i == 0: - # ssh keygen single time - _gcloud_scp(local_path, remote_path, client_worker) - continue - thread = threading.Thread( - target=_gcloud_scp, - daemon=True, - args=( - local_path, - remote_path, - client_worker, - )) - thread.start() - threads.append(thread) - - for thread in threads: - thread.join() - - def _cleanup(self, script_map): - - def _cleanup_worker(local_script, remote_script, client_worker): - rm_tmp_dir = ['rm', '-rf', os.path.dirname(remote_script)] - self._build_and_run_ssh(rm_tmp_dir, client_worker, log=False) - subprocess.call(['rm', '-rf', os.path.dirname(local_script)]) - if self.docker_image: - rm_container = ['docker', 'rm', '-f', self.docker_container] - self._build_and_run_ssh(rm_container, client_worker, log=False) - rm_pgroup = ( - 'kill -9 -$(ps xao pid,pgid,cmd | grep "bash -c \\"{}\\""' - r' | grep -v grep | awk "{{print \$2}}")').format(remote_script) - self._build_and_run_ssh(rm_pgroup, client_worker, log=False) - - threads = [] - for client_worker in script_map: - thread = threading.Thread( - target=_cleanup_worker, - args=( - script_map[client_worker]['local_path'], - script_map[client_worker]['remote_path'], - client_worker, - )) - thread.start() - threads.append(thread) - - # Cleanup states in case of restart - self._initialize() - - for thread in threads: - thread.join() - - def _start_run(self, script_map): - - def _run_script(script_paths, client_worker): - script_path = script_paths['remote_path'] - exit_code = self._build_and_run_ssh([script_path], client_worker) - if exit_code != 0: - raise RuntimeError( - 'Remote command exited with code: {}'.format(exit_code)) - - def _regular_health_check(): - uneven_health_timeout = xu.getenv_as('XLA_UNEVEN_HEARTBEAT_TIMEOUT', int, - 900) - even_health_timeout = xu.getenv_as('XLA_EVEN_HEARTBEAT_TIMEOUT', int, - 1800) - while True: - self._check_client_mesh_health(uneven_health_timeout, - even_health_timeout) - time.sleep(self.HEARTBEAT_CHECK_PERIOD) - - threading.Thread(target=_regular_health_check, daemon=True).start() - xu.parallel_work( - len(script_map), _run_script, script_map.values(), script_map.keys()) - - def _run_cmd(self, script_map): - try: - self._scp_scripts(script_map) - self._start_run(script_map) - except KeyboardInterrupt: - self.logger.warning( - 'Child process received Ctrl^C. 
Exiting...', - extra={ - 'clientip': '', - 'ordinal': '' - }) - sys.exit(128 + signal.SIGINT) - - def run(self, cmd): - self.trials = 0 - while self.trials <= self.MAX_TPU_RETRY: - try: - self.logger.info( - 'Command to distribute: {}'.format(concat_cmd_list(cmd)), - extra={ - 'clientip': '', - 'ordinal': '' - }) - self.logger.info( - f'Cluster configuration: {self._cluster}', - extra={ - 'clientip': '', - 'ordinal': '' - }) - - script_map = self._prepare_scripts(cmd) - proc = multiprocessing.Process(target=self._run_cmd, args=(script_map,)) - proc.start() - while True: - if not proc.is_alive(): - sys.exit(proc.exitcode) - if len(self._cluster.list_tpus_with_health( - 'UNHEALTHY_MAINTENANCE')) != 0: - # TPU Maintenance: kill all training, wait for healthy, and restart - break - if not self._error_queue.empty(): - # Potential HostError on GCE VM: kill all, wait, and restart - self.logger.warning( - self._error_queue.get(), extra={ - 'clientip': '', - 'ordinal': '' - }) - break - - proc.join(10) - - # First wait for VMs to come back then cleanup all others - self._cluster.wait_for_healthy_client(self) - self._cleanup(script_map) - proc.terminate() - self._cluster.wait_for_healthy_service() - self.trials += 1 - except KeyboardInterrupt: - self.logger.info( - 'Cleaning up processes (takes a couple of seconds)', - extra={ - 'clientip': '', - 'ordinal': '' - }) - self._cleanup(script_map) - sys.exit(128 + signal.SIGINT) - - self.logger.info( - 'Max number of retries reached.', extra={ - 'clientip': '', - 'ordinal': '' - }) - - -def main(args=None): - os.environ[xenv.PJRT_SELECT_DEFAULT_DEVICE] = '0' - if xr.using_pjrt(): - logging.warning( - 'PJRT runtime detected. `xla_dist` is NOT compatible with PJRT, and you may run into unexpected errors. Unset $PJRT_DEVICE to silence this warning.' - ) - - FLAGS = parse_args(args) - if (FLAGS.docker_container or FLAGS.docker_image or - FLAGS.docker_run_flag) and FLAGS.conda_env: - raise ValueError('Docker Setup arguments and Conda Setup' - ' arguments are mutually exclusive.') - - # Resolve VM and TPU clusters. - resolve_and_execute(FLAGS) - - -if __name__ == '__main__': - main() diff --git a/torch_xla/distributed/xrt_init.py b/torch_xla/distributed/xrt_init.py deleted file mode 100644 index 072d56052b47..000000000000 --- a/torch_xla/distributed/xrt_init.py +++ /dev/null @@ -1,249 +0,0 @@ -import os -import re -import socket -import subprocess -import torch.distributed as dist -import torch_xla.core.xla_model as xm -import torch_xla.core.xla_env_vars as xenv -from torch_xla.utils.utils import get_free_tcp_ports - -XRT_SERVER_REGEX = 'torch_xla.distributed._xrt_run_server' -_TCP_STORE = None -_INIT_XRT_ALREADY_CALLED = False - - -def _create_devices(dev_kind, world_size): - # Create global XLA devices. Adapted from xmp.spawn() to function across nodes - devices = [] - dev_type = 'GPU' - - for gindex in range(0, world_size): - tfdevice = f'{dev_type}:{gindex};/job:localservice/replica:0/task:{gindex}/device:XLA_{dev_type}:0' - devices.append(tfdevice) - os.environ[xenv.DEVICE_MAP] = '|'.join(devices) - - -def _setup_workers(world_size, rank, local_world_size, local_rank): - # Set up workers across nodes. 
xmp.spawn() does this locally by figuring out free ports on the node - # We do this globally by doing an allgather of locally obtained free socket addresses - # Note that this follows the original scheme, in the new scheme only one address per node needs exchange - host = socket.gethostname() - if local_rank == 0: - ports = [str(i) for i in get_free_tcp_ports(local_world_size)] - _TCP_STORE.set(host, ' '.join(ports)) - else: - ports_str = _TCP_STORE.get(host).decode('UTF-8') - ports = list(ports_str.split(' ')) - - my_worker = '{}:{};grpc://{}:{}'.format('localservice', rank, host, - ports[local_rank]) - all_workers = [] - for i in range(0, world_size): - if rank == i: - _TCP_STORE.set(f'worker:{i}', my_worker) - all_workers.append(my_worker) - else: - worker = _TCP_STORE.get(f'worker:{i}').decode('UTF-8') - all_workers.append(worker) - os.environ['XRT_WORKERS'] = '|'.join(all_workers) - - -def _get_address_from_store(key, rank): - if rank == 0: - port = get_free_tcp_ports()[0] - host = socket.getfqdn() - service_addr = '{}:{}'.format(host, port) - _TCP_STORE.set(key, service_addr) - else: - service_addr = _TCP_STORE.get(key).decode('UTF-8') - - return service_addr - - -def _set_mesh_config(rank): - address = _get_address_from_store('xrt_mesh_config', rank) - if not os.environ.get(xenv.SERVICE_ADDRESS, None): - os.environ[xenv.SERVICE_ADDRESS] = address - if not os.environ.get("TPU_MESH_CONTROLLER_ADDRESS", None): - address = _get_address_from_store('tpu_mesh_config', rank) - _, port = address.split(":") - os.environ["TPU_MESH_CONTROLLER_ADDRESS"] = address - os.environ["TPU_MESH_CONTROLLER_PORT"] = port - - -def _set_tpu_xrt_envs(local_rank, rank, group_rank, local_world_size, - world_size): - total_nodes = world_size // local_world_size - - xrt_tpu_config = [] - tpu_config_port = None - for i in range(total_nodes): - key = f'worker_{i}_address' - if group_rank == i and local_rank == 0: - tpu_config_port = get_free_tcp_ports()[0] - host = socket.getfqdn() - address = '{}:{}'.format(host, tpu_config_port) - _TCP_STORE.set(key, address) - else: - address = _TCP_STORE.get(key).decode('UTF-8') - if total_nodes == 1: - xrt_tpu_config.append(f'localservice;{i};{address}') - else: - xrt_tpu_config.append(f'c_localservice;{i};{address}') - - if rank == 0: - os.environ[xenv.TPU_CONFIG] = '|'.join(xrt_tpu_config) - os.environ[xenv.TPU_NUM_DEVICES] = str(local_world_size) - - os.environ[ - xenv. - LOCAL_WORKER] = f'localservice:{group_rank}' if total_nodes == 1 else f'c_localservice:{group_rank}' - os.environ[xenv.WORLD_SIZE] = str(world_size) - os.environ[xenv.HOST_WORLD_SIZE] = str(total_nodes) - os.environ[xenv.ORDINAL] = str(rank) - os.environ[xenv.LOCAL_ORDINAL] = str(local_rank) - os.environ[xenv.MP_DEVICE] = f'TPU:{rank}' - if not os.environ.get('TF_GRPC_DEFAULT_OPTIONS', None): - os.environ['TF_GRPC_DEFAULT_OPTIONS'] = ( - 'grpc.keepalive_time_ms=60000,grpc.keepalive_timeout_ms=14400000,' - 'grpc.http2.max_pings_without_data=0,grpc.http2.min_ping_interval_without_data_ms=300000' - ) - # We don't want torch_xla to start the local server internally. 
- # We are starting the xrt server by ourselves - os.environ['XRT_START_LOCAL_SERVER'] = '0' - - return tpu_config_port - - -def _set_neuron_envs(rank, world_size, local_world_size): - os.environ["NEURON_USE_LOAD_COLLECTIVES"] = '1' - os.environ['NEURON_GLOBAL_DEVICE_ID'] = str(rank) - os.environ['NEURON_GLOBAL_DEVICE_COUNT'] = str(world_size) - if not os.environ.get('NEURON_RT_VISIBLE_CORES', None): - os.environ['NEURON_RT_VISIBLE_CORES'] = ','.join( - [str(i) for i in range(local_world_size)]) - - -def _setup_nccl_service(dev_kind, rank): - # Set up NCCL COMM ID required for NCCL communicator IDs - address = _get_address_from_store('nccl_info', rank) - if dev_kind == 'NEURON': - os.environ['NEURON_RT_ROOT_COMM_ID'] = address - elif dev_kind == 'GPU': - os.environ['NEURON_RT_ROOT_COMM_ID'] = address - os.environ['XRT_MESH_SERVICE_ADDRESS'] = address - else: - raise RuntimeError('NCCL service setup failed!') - - -def set_xrt_envs(world_size, rank, local_rank): - # Set up all the XRT specific env variables, adapted from xmp.spawn() - os.environ[xenv.WORLD_SIZE] = str(world_size) - os.environ[xenv.ORDINAL] = str(rank) - os.environ[xenv.LOCAL_ORDINAL] = str(local_rank) - os.environ[xenv.LOCAL_WORKER] = 'localservice:' + str(rank) - - os.environ[xenv.MP_DEVICE] = f'GPU:{rank}' - gpus_to_use = os.environ.get('CUDA_VISIBLE_DEVICES') - if gpus_to_use is not None: - # If gpu devices are set by a scheduling entity (eg. SLURM) we index into - # comma separated string containing numbered gpu devies - gpus_to_use_list = gpus_to_use.split(',') - os.environ['CUDA_VISIBLE_DEVICES'] = gpus_to_use_list[local_rank] - else: - # If no explicit visible devices are provided, local_rank is used to identify - # the gpu used by this process - os.environ['CUDA_VISIBLE_DEVICES'] = str(local_rank) - - -def init_xrt_context(master_addr=None, master_port=None, store=None): - """Initializes the XLA device depending on the kind of the device. Or is a no-op if init_xrt_context - has already been called. - - Args: - master_addr (string): This is used to set up the TCPStore. If none is provided, it is obtained - from the environment variable. Also not required/used if store argument is passed in. - - master_port (int): This is used to set up the TCPStore. If none is provided, it is obtained from - environment variable. Also not required/used if store argument is passed in. - - store (TCPstore): A TCPstore object to use instead of creating a new one. If None a TCPStore object - will be setup for you. - Default: None - """ - global _INIT_XRT_ALREADY_CALLED - - if _INIT_XRT_ALREADY_CALLED: - return - - # Call this in the actual test case, to work with torch/xla workers - rank = int(os.environ['RANK']) - local_rank = int(os.environ['LOCAL_RANK']) - world_size = int(os.environ['WORLD_SIZE']) - group_rank = int(os.environ['GROUP_RANK']) - local_world_size = int(os.environ['LOCAL_WORLD_SIZE']) - - if master_addr is None: - master_addr = os.environ['MASTER_ADDR'] - - if master_port is None: - master_port = os.environ['MASTER_PORT'] - - dev_list = os.listdir('/dev/') - #checking the dev kind, need similar filter for TPU - neuron_devs = list(filter(lambda v: re.match('neuron', v), dev_list)) - if neuron_devs: - dev_kind = 'NEURON' - else: - dev_kind = 'GPU' - - os.environ.pop(xenv.TPU_CONFIG, None) - os.environ.pop(xenv.TPU_NUM_DEVICES, None) - os.environ.pop(xenv.GPU_NUM_DEVICES, None) - - # This is required if we want to dynamically grab free ports. - # Useful in shared settings when we cannot predetermine what ports are taken. 
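# A minimal sketch of the TCPStore exchange that _get_address_from_store() and
# _setup_workers() rely on: rank 0 publishes a freshly chosen address under a
# key and every other rank blocks in get() until it appears. It assumes the
# launcher already exported MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE;
# port 51011 is just an example value.
import os
import socket
import torch.distributed as dist

rank = int(os.environ['RANK'])
store = dist.TCPStore(os.environ['MASTER_ADDR'], int(os.environ['MASTER_PORT']),
                      int(os.environ['WORLD_SIZE']), rank == 0)
if rank == 0:
  # The real code picks the port with get_free_tcp_ports().
  store.set('service_addr', '{}:{}'.format(socket.getfqdn(), 51011))
service_addr = store.get('service_addr').decode('utf-8')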
- is_server = True if rank == '0' else False - global _TCP_STORE - if store is None: - assert master_addr is not None - assert master_port is not None - _TCP_STORE = dist.TCPStore(master_addr, int(master_port), world_size, - is_server) - else: - _TCP_STORE = store - - node_list = None - - if dev_kind == 'NEURON': #similar check for TPU.. - tpu_config_port = _set_tpu_xrt_envs(local_rank, rank, group_rank, - local_world_size, world_size) - elif dev_kind == 'GPU': - _setup_nccl_service(dev_kind, rank) - set_xrt_envs(world_size, rank, local_rank) - _create_devices(dev_kind, world_size) - _setup_workers(world_size, rank, local_world_size, local_rank) - - _set_mesh_config(rank) - - if dev_kind == 'NEURON': #similar check for TPU.. - _setup_nccl_service(dev_kind, rank) - _set_neuron_envs(rank, world_size, local_world_size) - - total_nodes = world_size // local_world_size - if local_rank == 0: - local_env = os.environ.copy() - subprocess.Popen([ - 'python3', '-m', XRT_SERVER_REGEX, '--port', - str(tpu_config_port), '--pid_to_track', - str(os.getppid()) - ], - env=local_env, - start_new_session=True) - - dev = xm.xla_device() - xm.set_replication(dev, [dev]) - - # if we get to this point, we know the function completed successfully - # and we can switch the flag to True - _INIT_XRT_ALREADY_CALLED = True From 24493994aaf811d020303dec6bf08c4bb91f991b Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 23 Jun 2023 17:34:10 +0000 Subject: [PATCH 02/22] Remove dead import --- torch_xla/distributed/xla_backend.py | 12 - torch_xla/distributed/xla_multiprocessing.py | 360 +------------------ 2 files changed, 2 insertions(+), 370 deletions(-) diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py index dca1594f1339..0faabe62662a 100644 --- a/torch_xla/distributed/xla_backend.py +++ b/torch_xla/distributed/xla_backend.py @@ -1,15 +1,10 @@ -import distutils.util -import os import torch import torch.distributed as dist -import torch_xla import torch_xla.core.xla_model as xm import logging from torch._C._distributed_c10d import ( ProcessGroup, - Work, ) -from .xrt_init import init_xrt_context def _create_xla_process_group(prefix_store, rank, size, timeout): @@ -42,13 +37,6 @@ def __init__(self, prefix_store, rank, size, timeout): self.prefix_store = prefix_store # reserved for future use. self.timeout = timeout self._mesh = [] - # Initialize xrt neuron environment - # Passes in the store created by torch.distributed to avoid - # creating two TCP stores. We only want to call this - # when the user is using torchrun and not xmp.spawn() - # or some other flow. 
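# Rough sketch of the launch detection used below, assuming the standard
# torchrun/torch.elastic environment: torchrun exports TORCHELASTIC_RUN_ID
# together with RANK, LOCAL_RANK, WORLD_SIZE, GROUP_RANK, LOCAL_WORLD_SIZE,
# MASTER_ADDR and MASTER_PORT, which are exactly the variables
# init_xrt_context() reads; xmp.spawn() does not set TORCHELASTIC_RUN_ID.
import os

def _launched_by_torchrun():
  return os.getenv('TORCHELASTIC_RUN_ID') is not None

if _launched_by_torchrun():
  required = ('RANK', 'LOCAL_RANK', 'WORLD_SIZE', 'GROUP_RANK',
              'LOCAL_WORLD_SIZE')
  missing = [k for k in required if k not in os.environ]
  assert not missing, 'torchrun should have set: {}'.format(missing)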
- if os.getenv('TORCHELASTIC_RUN_ID') != None: - init_xrt_context(store=prefix_store) def getBackendName(self): return 'xla' diff --git a/torch_xla/distributed/xla_multiprocessing.py b/torch_xla/distributed/xla_multiprocessing.py index 18de3436ab58..35f7a67fab42 100644 --- a/torch_xla/distributed/xla_multiprocessing.py +++ b/torch_xla/distributed/xla_multiprocessing.py @@ -1,336 +1,6 @@ -import collections -import os -import re -import socket -import sys import torch.multiprocessing -import torch_xla -import torch_xla.core.xla_env_vars as xenv -import torch_xla.core.xla_model as xm from torch_xla import runtime as xr from torch_xla._internal import pjrt -import torch_xla.utils.utils as xu -import traceback - -PreForkConfig = collections.namedtuple('PreForkConfig', 'dev_kind num_devices') -WorkerConfigEntry = collections.namedtuple('WorkerConfigEntry', - 'worker_name ordinal host_port') - -_LOCAL_WORKER = 'localservice' -_CUDA_VISIBLE_DEVICES = 'CUDA_VISIBLE_DEVICES' - - -def _is_xla_config(): - for env in [ - xenv.TPU_CONFIG, xenv.LOCAL_WORKER, xenv.GPU_NUM_DEVICES, - xenv.CPU_NUM_DEVICES - ]: - if os.environ.get(env, None) is not None: - return True - return False - - -# TODO: Some usages of this function are to calculate the number of hosts (a TPU concept), -# and some are to calculate the number of processes within a world (which can span multiple hosts). -# The latter should really be what this function is supposed to do. It's so confusing. We -# should improve it. -def _get_world_size(): - # We cannot use the xla_model.py API here, as the features used in that module - # needs the setup provided by this one. - return int(os.environ.get(xenv.WORLD_SIZE, '1')) - - -def _create_gpu_devices(num_gpus): - devices = [] - for h in range(0, _get_world_size()): - for i in range(0, num_gpus): - gindex = h * num_gpus + i - # We use CUDA_VISIBLE_DEVICES to limit the set of CUDA devices per process - # to 1, and its device index is always 0. We use the task to disambiguate - # TF devices. 
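# Worked example of the XRT device map assembled below, assuming 2 hosts with
# 2 GPUs each. Every global device index becomes its own localservice task,
# and each task exposes a single XLA_GPU:0 because CUDA_VISIBLE_DEVICES pins
# one physical GPU per process.
hosts, gpus_per_host = 2, 2
entries = [
    'GPU:{};/job:localservice/replica:0/task:{}/device:XLA_GPU:0'.format(g, g)
    for g in range(hosts * gpus_per_host)
]
device_map = '|'.join(entries)
# device_map == 'GPU:0;/job:localservice/replica:0/task:0/device:XLA_GPU:0|'
#               'GPU:1;/job:localservice/replica:0/task:1/device:XLA_GPU:0|...'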
- tfdevice = '/job:{}/replica:0/task:{}/device:XLA_GPU:0'.format( - _LOCAL_WORKER, gindex) - devices.append('GPU:{};{}'.format(gindex, tfdevice)) - os.environ[xenv.DEVICE_MAP] = '|'.join(devices) - - -def _parse_workers_config(config): - # XRT_WORKERS='worker:0;ismz9:25822' - workers = collections.OrderedDict() - for worker in config.split('|'): - m = re.match(r'(\w+):(\d+);((grpc://)?[a-zA-Z0-9_\-\.]+:\d+)', worker) - if not m: - raise ValueError('Bad worker syntax: {}'.format(worker)) - workers['{}:{}'.format(m.group(1), m.group(2))] = WorkerConfigEntry( - worker_name=m.group(1), ordinal=int(m.group(2)), host_port=m.group(3)) - return workers - - -def _parse_tpu_config(config): - # XRT_TPU_CONFIG='tpu_worker;0;ismz9:25822' - workers = collections.OrderedDict() - for worker in config.split('|'): - m = re.match(r'(\w+);(\d+);([a-zA-Z0-9_\-\.]+:\d+)', worker) - if not m: - raise ValueError('Bad worker syntax: {}'.format(worker)) - workers['{}:{}'.format(m.group(1), m.group(2))] = WorkerConfigEntry( - worker_name=m.group(1), ordinal=int(m.group(2)), host_port=m.group(3)) - return workers - - -def _get_devices_per_worker(): - num_tpus = os.environ.get(xenv.TPU_NUM_DEVICES, None) - if os.environ.get(xenv.TPU_CONFIG, None) is not None or num_tpus is not None: - return int(num_tpus or '8'), 'TPU' - num_gpus = os.environ.get(xenv.GPU_NUM_DEVICES, None) - if num_gpus is not None: - return int(num_gpus), 'GPU' - num_cpus = os.environ.get(xenv.CPU_NUM_DEVICES, None) - if num_cpus is not None: - return int(num_cpus), 'CPU' - raise RuntimeError('Missing TPU or GPU configuration') - - -def _get_multiprocessing_device(): - return os.environ.get(xenv.MP_DEVICE, None) - - -def _get_local_worker_index(): - host_ordinal = os.environ.get(xenv.HOST_ORDINAL, None) - if host_ordinal is not None: - return int(host_ordinal) - worker = os.environ.get(xenv.LOCAL_WORKER, None) - if worker is None: - return 0 - m = re.match(r'(\w+):(\d+)', worker) - if not m: - raise ValueError('Bad worker syntax: {}'.format(worker)) - return int(m.group(2)) - - -def _local_index_to_global(index, num_devices): - return _get_local_worker_index() * num_devices + index - - -def _setup_torch_distributed(): - import torch.distributed as dist - - ordinal = int(os.environ[xenv.HOST_ORDINAL]) - world_size = int(os.environ[xenv.HOST_WORLD_SIZE]) - method = os.environ.get(xenv.TORCH_DIST_METHOD, 'gloo') - init_method = 'tcp://{}'.format(os.environ[xenv.TORCH_DIST_ROOT]) - dist.init_process_group( - method, init_method=init_method, rank=ordinal, world_size=world_size) - - -def _setup_world_size(pf_cfg): - # We cannot call into xla_model code at this point, as we do not know whether - # the called code would trigger XLA library initializations (which we must - # not do at this point). So we avoid calling into xm.xrt_world_size(). - host_world_size = _get_world_size() - world_size = host_world_size * pf_cfg.num_devices - os.environ[xenv.WORLD_SIZE] = str(world_size) - if pf_cfg.dev_kind == 'CPU': - # Since XLA CPU does not support across device reduces, and support only - # one device per process, we make each CPU device look like if it was a - # single process host, and use torch.distributed for inter-host reductions. 
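# Worked example of the sizing below, assuming 2 hosts with 8 devices each:
# the global world size becomes 2 * 8 = 16 while the host world size stays 2
# on the TPU/GPU path. On the CPU path each device is promoted to its own
# "host", so the host world size is also 16.
host_world_size, num_devices = 2, 8
world_size = host_world_size * num_devices  # 16
tpu_or_gpu_host_world_size = host_world_size  # 2
cpu_host_world_size = world_size  # 16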
- os.environ[xenv.HOST_WORLD_SIZE] = str(world_size) - else: - os.environ[xenv.HOST_WORLD_SIZE] = str(host_world_size) - - -def _get_mp_device_ordinal(index, gindex): - # If xenv.HOST_ORDINAL is set, we are in a multi CPU setup, where devices - # are numbered locally within the single host (but the ordinal/rank is still - # global). - return index if xenv.HOST_ORDINAL in os.environ else gindex - - -# TODO: Consolidate this with _setup_gpu_worker. -def _setup_gpu_workers(num_devices): - world_size = _get_world_size() - workers_env = os.environ.get(xenv.WORKERS, None) - workers = [] - # TODO: Is this path actually being used? This seems to support multi-host GPUs (is this a thing at all?). - if workers_env is not None: - wcfg = _parse_workers_config(workers_env) - assert world_size == len( - wcfg), 'World size ({}) must match the configured workers ({})'.format( - world_size, len(wcfg)) - for key, worker in wcfg.items(): - _, ordinal = key.split(":") - m = re.match(r'(.*):(\d+)$', worker.host_port) - if not m: - raise RuntimeError('Bad worker HOST:PORT format: {}'.format( - worker.host_port)) - for i in range(0, num_devices): - gindex = int(ordinal) * num_devices + i - workers.append('{}:{};grpc://{}:{}'.format(worker.worker_name, gindex, - m.group(1), - int(m.group(2)) + i)) - else: - assert world_size == 1, ('Cannot use more than one host without {} ' - 'configuration: {}').format( - xenv.WORKERS, world_size) - ports = xu.get_free_tcp_ports(num_devices) - host = socket.getfqdn() - for wid in range(0, num_devices): - workers.append('{}:{};grpc://{}:{}'.format(_LOCAL_WORKER, wid, host, - ports[wid])) - os.environ[xenv.WORKERS] = '|'.join(workers) - - -def _pre_fork_setup_torch_distributed(): - if not xenv.TORCH_DIST_ROOT in os.environ: - os.environ[xenv.TORCH_DIST_ROOT] = '{}:{}'.format( - socket.getfqdn(), - xu.get_free_tcp_ports()[0]) - - -def _pre_fork_cpu_setup(num_devices): - if xenv.HOST_ORDINAL not in os.environ: - # CPU multi-processing must use the host ordinal path, which enables the - # torch.distributed reductions across single CPU cores. Since XLA CPU does - # not support multiple devices within the same process, each XLA CPU device - # is isolated within a single process, which is seen as "host" as well. - os.environ[xenv.HOST_ORDINAL] = '0' - - -def _pre_fork_setup(num_devices): - dev_count, dev_kind = _get_devices_per_worker() - if num_devices is None: - num_devices = dev_count - elif num_devices not in [1, dev_count]: - raise ValueError( - 'The number of devices must be either 1 or {}, got {} instead'.format( - dev_count, num_devices)) - total_devices = _get_world_size() * num_devices - if total_devices > 1 and not os.environ.get(xenv.SERVICE_ADDRESS, None): - # In multi-processing mode, even if there is only one XLA host, we still - # bring up the mesh service. - os.environ[xenv.SERVICE_ADDRESS] = '{}:{}'.format( - socket.getfqdn(), - xu.get_free_tcp_ports()[0]) - if dev_kind == 'GPU': - _setup_gpu_workers(num_devices) - _create_gpu_devices(num_devices) - elif dev_kind == 'CPU': - _pre_fork_cpu_setup(num_devices) - _pre_fork_setup_torch_distributed() - return PreForkConfig(dev_kind=dev_kind, num_devices=num_devices) - - -def _setup_gpu_worker(index, gindex): - os.environ[xenv.MP_DEVICE] = 'GPU:{}'.format( - _get_mp_device_ordinal(index, gindex)) - os.environ[xenv.LOCAL_WORKER] = '{}:{}'.format(_LOCAL_WORKER, gindex) - # Every process is restricted to 1 GPU device, which in such process will be - # named XLA_GPU:0. 
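# Sketch of the per-process pinning described above: once CUDA_VISIBLE_DEVICES
# names a single physical device, only that device is visible and it is always
# addressed as index 0 inside the process. The device id is hypothetical.
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '3'  # e.g. physical GPU 3 for local rank 3
import torch

if torch.cuda.is_available():
  assert torch.cuda.device_count() == 1  # only the pinned GPU remains visible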
- os.environ[_CUDA_VISIBLE_DEVICES] = str(index) - # We have expanded the GPU devices in the device map already, in - # _create_gpu_devices(), so delete the key from the environment as it - # otherwise triggers device generation again in computation_client.cc. - os.environ.pop(xenv.GPU_NUM_DEVICES, None) - - -def _setup_cpu_worker(index, gindex): - task_no = 0 - dev_index = _get_mp_device_ordinal(index, gindex) - os.environ[xenv.MP_DEVICE] = 'CPU:{}'.format(dev_index) - os.environ[xenv.LOCAL_WORKER] = '{}:{}'.format(_LOCAL_WORKER, task_no) - os.environ[xenv.WORKERS] = '{}:{};grpc://localhost:{}'.format( - _LOCAL_WORKER, task_no, - xu.get_free_tcp_ports()[0]) - os.environ[ - xenv. - DEVICE_MAP] = 'CPU:{};/job:{}/replica:0/task:{}/device:XLA_CPU:0'.format( - dev_index, _LOCAL_WORKER, task_no) - os.environ.pop(xenv.CPU_NUM_DEVICES, None) - # XLA CPU has no support for cross-replica reduces, so we have to reduce using - # torch.distributed capabilities. Since the logic is to use torch.distributed - # across hosts (with XLA device reduces across devices within the same host), - # we make the single host processes behave like if they were different hosts. - os.environ[xenv.HOST_ORDINAL] = str(gindex) - - -def _wants_tpu_env_config(index, gindex): - return gindex == 0 - - -def _setup_tpu_worker(index, gindex, tpu_env_config): - os.environ[xenv.MP_DEVICE] = 'TPU:{}'.format( - _get_mp_device_ordinal(index, gindex)) - if xenv.LOCAL_WORKER not in os.environ: - # The local worker can be missing for a 1 TPU host setup. Make sure we - # always have one. - assert tpu_env_config, '{} environment must be populated'.format( - xenv.TPU_CONFIG) - tpu_config = _parse_tpu_config(tpu_env_config) - worker = list(tpu_config.values())[0] - os.environ[xenv.LOCAL_WORKER] = '{}:{}'.format(worker.worker_name, - worker.ordinal) - if not _wants_tpu_env_config(index, gindex): - # In multi-processing mode, only the process handling the first device of - # the master worker, will do TPU mesh initialization, so we need to remove - # the environment configs which would prevent the client to be falling in - # the mesh client config path. - os.environ.pop(xenv.TPU_CONFIG, None) - os.environ.pop(xenv.TPU_NUM_DEVICES, None) - - -def _prepare_env_for_index(index, pf_cfg): - _setup_world_size(pf_cfg) - gindex = _local_index_to_global(index, pf_cfg.num_devices) - os.environ[xenv.ORDINAL] = str(gindex) - os.environ[xenv.LOCAL_ORDINAL] = str(index) - - if pf_cfg.dev_kind == 'TPU': - _setup_tpu_worker(index, gindex, os.environ.get(xenv.TPU_CONFIG, None)) - elif pf_cfg.dev_kind == 'GPU': - _setup_gpu_worker(index, gindex) - elif pf_cfg.dev_kind == 'CPU': - _setup_cpu_worker(index, gindex) - _setup_torch_distributed() - return gindex - - -def _setup_replication(): - # At this point xla_model.py APIs are allowed as the setup is already - # completed. - if xm.xrt_world_size() > 1: - device = xm.xla_device() - xm.set_replication(device, [device]) - - -def _start_fn(index, pf_cfg, fn, args): - gindex = _prepare_env_for_index(index, pf_cfg) - # Calling _setup_replication() will trigger XLA library initialization, so the - # environment must be fully setup before doing so. 
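# For context, a minimal user-side program that reaches this _start_fn() path;
# the per-process function receives its global ordinal as the first argument.
# nprocs=8 is just an example value for a single 8-core host.
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
  device = xm.xla_device()
  print('process {} is using {}'.format(index, device))

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=(), nprocs=8)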
- _setup_replication() - fn(gindex, *args) - - -def _mp_start_fn(index, pf_cfg, fn, args): - exit_code = 0 - try: - _start_fn(index, pf_cfg, fn, args) - except Exception as e: - print( - 'Exception in device={}: {}'.format(_get_multiprocessing_device(), - str(e)), - file=sys.stderr) - traceback.print_exc(limit=16, file=sys.stderr) - exit_code = 17 - sys.exit(exit_code) - - -def _run_direct(fn, args, nprocs, join, daemon, start_method): - nprocs = nprocs or 1 - if nprocs == 1 and join: - fn(0, *args) - else: - return torch.multiprocessing.spawn( - fn, args=args, nprocs=nprocs, join=join, daemon=daemon) def spawn(fn, @@ -364,34 +34,8 @@ def spawn(fn, `nprocs` is 1 the `fn` function will be called directly, and the API will return None. """ - if xr.using_pjrt(): - return pjrt.spawn(fn, nprocs, start_method, args) - - if not _is_xla_config(): - # If this is not an XLA setup, jump to normal multi-processing. - return _run_direct(fn, args, nprocs, join, daemon, start_method) - - pf_cfg = _pre_fork_setup(nprocs) - result = None - if pf_cfg.num_devices == 1: - _start_fn(0, pf_cfg, fn, args) - else: - result = torch.multiprocessing.start_processes( - _mp_start_fn, - args=(pf_cfg, fn, args), - nprocs=pf_cfg.num_devices, - join=join, - daemon=daemon, - start_method=start_method) - - # For GPU, xenv.WORKERS are set in the launcher and then get carried to the children. - # However, if the launcher is reused to do another multi-process experiment, _setup_gpu_workers - # would mistake the xenv.WORKERS as configured to enable multi-host experiments. Each worker then - # represents a host. Therefore, reset it after launching all children. - if pf_cfg.dev_kind == 'GPU': - os.environ.pop(xenv.WORKERS) - - return result + assert xr.using_pjrt(), 'PJRT_DEVICE must be set.' + return pjrt.spawn(fn, nprocs, start_method, args) class MpModelWrapper(object): From 1c25f1e06c286bf28dacdba7bcdd71f9e3bea9aa Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 23 Jun 2023 17:38:03 +0000 Subject: [PATCH 03/22] formatting --- torch_xla/distributed/xla_backend.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py index 0faabe62662a..b0a2f0e727f2 100644 --- a/torch_xla/distributed/xla_backend.py +++ b/torch_xla/distributed/xla_backend.py @@ -2,9 +2,7 @@ import torch.distributed as dist import torch_xla.core.xla_model as xm import logging -from torch._C._distributed_c10d import ( - ProcessGroup, -) +from torch._C._distributed_c10d import ProcessGroup def _create_xla_process_group(prefix_store, rank, size, timeout): From fda34e7adc7d74a221c8902bf62c497a9b716719 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 23 Jun 2023 17:42:29 +0000 Subject: [PATCH 04/22] Remove disable_xrt build option --- .bazelrc | 3 --- 1 file changed, 3 deletions(-) diff --git a/.bazelrc b/.bazelrc index aede23601f0b..763e2feb9f90 100644 --- a/.bazelrc +++ b/.bazelrc @@ -75,9 +75,6 @@ build:tpu --define=with_tpu_support=true test:tpu --local_test_jobs=1 test:cuda --local_test_jobs=1 -# Exclude XRT from the build -build:disable_xrt --define=disable_xrt=true - ######################################################################### # RBE config options below. 
# Flag to enable remote config From e86898962299ba3dbe96160baf005f48b3976635 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 23 Jun 2023 17:47:32 +0000 Subject: [PATCH 05/22] Fix runtime init --- torch_xla/__init__.py | 26 -------------------------- torch_xla/runtime.py | 19 +++++++------------ 2 files changed, 7 insertions(+), 38 deletions(-) diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py index 2bb624c508ed..35c82812f524 100644 --- a/torch_xla/__init__.py +++ b/torch_xla/__init__.py @@ -2,33 +2,10 @@ import os import re import tempfile -import subprocess logging.basicConfig() logger = logging.getLogger(__name__) -XRT_RUN_SERVER_PROCESS = 'torch_xla.core._xrt_run_server' -XRT_SERVER_REGEX = '^python3 -m {} [0-9]+$'.format(XRT_RUN_SERVER_PROCESS) - - -def server_is_alive(): - # pgrep returns 0 when at least one running process matches the requested name. - # Otherwise, the exit code is 1. If pgrep is not availiable in the system, it - # will return an exit code 127. - return subprocess.getstatusoutput( - 'pgrep -f "{}"'.format(XRT_SERVER_REGEX))[0] == 0 - - -def _setup_grpc(): - # Setup GRPC options to correctly talk to TPU backends. - options = [ - 'grpc.keepalive_time_ms=60000', # 1 min - 'grpc.keepalive_timeout_ms=14400000', # 4 hrs - 'grpc.http2.max_pings_without_data=0', # unlimited - 'grpc.http2.min_ping_interval_without_data_ms=300000', # 5 min - ] - os.environ['TF_GRPC_DEFAULT_OPTIONS'] = ','.join(options) - def _set_missing_flags(flags, sets): for name, defval in sets: @@ -63,8 +40,6 @@ def _setup_default_env(): _set_missing_env('GRPC_VERBOSITY', 'ERROR') _set_missing_env('ALLOW_MULTIPLE_LIBTPU_LOAD', '1') _set_missing_env('TPU_ML_PLATFORM', 'PyTorch/XLA') - if server_is_alive(): - _set_missing_env('XRT_START_LOCAL_SERVER', '0') _fd, _tmp_fname = -1, '' @@ -114,7 +89,6 @@ def _setup_tpu_vm_library_path() -> bool: # These needs to be called before the _XLAC module is loaded. _setup_default_env() -_setup_grpc() _setup_xla_flags() if int(os.environ.get('PT_XLA_DEBUG', '0')): _fd, _tmp_fname = _setup_debug_env() diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py index c09459959a70..a2195320deda 100644 --- a/torch_xla/runtime.py +++ b/torch_xla/runtime.py @@ -9,6 +9,7 @@ import torch_xla.core.xla_model as xm import torch_xla.distributed.xla_backend import torch_xla.utils.utils as xu +import torch_xla._internal.tpu as tpu R = TypeVar('R') FN = TypeVar('FN') @@ -26,29 +27,23 @@ def set_device_type(pjrt_device: str) -> None: def _maybe_select_default_device(): - # Skip if runtime is already configured - if xu.getenv_as( - xenv.PJRT_SELECT_DEFAULT_DEVICE, str, '1' - ) == '0' or xenv.PJRT_DEVICE in os.environ or xenv.GPU_NUM_DEVICES in os.environ or any( - env.startswith('XRT_') for env in os.environ): + if xu.getenv_as(xenv.PJRT_SELECT_DEFAULT_DEVICE, str, '1') == '0' or xenv.PJRT_DEVICE in os.environ: return - logging.warning( - 'XRT configuration not detected. Defaulting to PJRT runtime. To silence ' - 'this warning and continue using PJRT, explicitly set PJRT_DEVICE to a ' - 'supported device or configure XRT. To disable default device selection, ' - 'set PJRT_SELECT_DEFAULT_DEVICE=0') # TODO: Update this link in the release branch - logging.warning('For more information about the status of PJRT, see ' + logging.warning('PJRT is now the default runtime. 
For more information, see ' 'https://github.com/pytorch/xla/blob/master/docs/xr.md') # Check for libtpu _and_ the TPU device if torch_xla._found_libtpu and os.path.exists('/dev/accel0'): logging.warning('libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.') os.environ[xenv.PJRT_DEVICE] = 'TPU' + # TODO(wcromar): Detect GPU device + elif xu.getenv_as(xenv.GPU_NUM_DEVICES, int, 0) > 0: + logging.warning('GPU_NUM_DEVICES is set. Setting PJRT_DEVICE=GPU') + os.environ[xenv.PJRT_DEVICE] = 'GPU' else: logging.warning('Defaulting to PJRT_DEVICE=CPU') os.environ[xenv.PJRT_DEVICE] = 'CPU' - # TODO(wcromar): Detect GPU device too def device_type() -> Optional[str]: From 21f93133706428403695d12bd68072a970ca26ac Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 23 Jun 2023 17:52:53 +0000 Subject: [PATCH 06/22] Revert "Remove disable_xrt build option" This reverts commit ba312e76e069bef40c8f9803a672b29409862804. --- .bazelrc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.bazelrc b/.bazelrc index 763e2feb9f90..aede23601f0b 100644 --- a/.bazelrc +++ b/.bazelrc @@ -75,6 +75,9 @@ build:tpu --define=with_tpu_support=true test:tpu --local_test_jobs=1 test:cuda --local_test_jobs=1 +# Exclude XRT from the build +build:disable_xrt --define=disable_xrt=true + ######################################################################### # RBE config options below. # Flag to enable remote config From 7f997627c8900a3fe3a3c1d49cff92904121e6be Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 23 Jun 2023 17:54:36 +0000 Subject: [PATCH 07/22] Add disable XRT option back --- .github/workflows/_build.yml | 10 +++++++++- .github/workflows/build_and_test.yml | 1 + infra/ansible/config/env.yaml | 1 + setup.py | 3 --- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/.github/workflows/_build.yml b/.github/workflows/_build.yml index 6a9510b64141..674dc145735a 100644 --- a/.github/workflows/_build.yml +++ b/.github/workflows/_build.yml @@ -48,6 +48,7 @@ jobs: SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 GCLOUD_SERVICE_KEY: ${{ secrets.gcloud-service-key }} XLA_CUDA: ${{ inputs.cuda }} + DISABLE_XRT: ${{ inputs.disable_xrt }} steps: - name: Setup Linux uses: pytorch/test-infra/.github/actions/setup-linux@main @@ -87,6 +88,7 @@ jobs: run: | echo "declare -x SCCACHE_BUCKET=${SCCACHE_BUCKET}" | docker exec -i "${pid}" sh -c "cat >> env" echo "declare -x CC=clang-8 CXX=clang++-8" | docker exec -i "${pid}" sh -c "cat >> xla_env" + echo "declare -x DISABLE_XRT=${DISABLE_XRT}" | docker exec -i "${pid}" sh -c "cat >> xla_env" echo "declare -x XLA_CUDA=${XLA_CUDA}" | docker exec -i "${pid}" sh -c "cat >> xla_env" echo "declare -x BAZEL_REMOTE_CACHE=1" | docker exec -i "${pid}" sh -c "cat >> xla_env" echo "${GCLOUD_SERVICE_KEY}" | docker exec -i "${pid}" sh -c "cat >> default_credentials.json" @@ -105,7 +107,13 @@ jobs: id: upload-docker-image shell: bash run: | - export COMMIT_DOCKER_IMAGE="${ECR_DOCKER_IMAGE_BASE}:latest-${GITHUB_SHA}" + if [[ ${DISABLE_XRT} == 1 ]]; then + image_tag_base=latest + else + image_tag_base=latest-xrt + fi + + export COMMIT_DOCKER_IMAGE="${ECR_DOCKER_IMAGE_BASE}:${image_tag_base}-${GITHUB_SHA}" time docker commit "${pid}" "${COMMIT_DOCKER_IMAGE}" time docker push "${COMMIT_DOCKER_IMAGE}" echo "docker-image=${COMMIT_DOCKER_IMAGE}" >> "${GITHUB_OUTPUT}" diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 6a80fde05d81..e45a7a4d41bb 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -22,6 
+22,7 @@ jobs: with: ecr-docker-image-base: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/xla_base gcr-docker-image: gcr.io/tpu-pytorch/xla_base:latest + disable_xrt: 1 cuda: 1 secrets: gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }} diff --git a/infra/ansible/config/env.yaml b/infra/ansible/config/env.yaml index 3a25c298ae34..d2996f750367 100644 --- a/infra/ansible/config/env.yaml +++ b/infra/ansible/config/env.yaml @@ -32,6 +32,7 @@ build_env: XLA_SANDBOX_BUILD: 1 BAZEL_REMOTE_CACHE: 1 SILO_NAME: "cache-silo-{{ arch }}-{{ accelerator }}" + DISABLE_XRT: "{{ disable_xrt }}" amd64: ARCH: amd64 diff --git a/setup.py b/setup.py index 3ae615a500bc..fd2604a010d6 100644 --- a/setup.py +++ b/setup.py @@ -250,9 +250,6 @@ def bazel_build(self, ext): if _check_env_flag('TPUVM_MODE'): bazel_argv.append('--config=tpu') - if _check_env_flag('DISABLE_XRT'): - bazel_argv.append('--config=disable_xrt') - # Remote cache authentication. if _check_env_flag('BAZEL_REMOTE_CACHE'): bazel_argv.append('--config=remote_cache') From 0cb3ce8cf4f25acc602a15e6f2a7d2594cfb1b80 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 23 Jun 2023 17:57:44 +0000 Subject: [PATCH 08/22] formatting --- torch_xla/runtime.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py index a2195320deda..3b6554654e26 100644 --- a/torch_xla/runtime.py +++ b/torch_xla/runtime.py @@ -27,7 +27,8 @@ def set_device_type(pjrt_device: str) -> None: def _maybe_select_default_device(): - if xu.getenv_as(xenv.PJRT_SELECT_DEFAULT_DEVICE, str, '1') == '0' or xenv.PJRT_DEVICE in os.environ: + if xu.getenv_as(xenv.PJRT_SELECT_DEFAULT_DEVICE, str, + '1') == '0' or xenv.PJRT_DEVICE in os.environ: return # TODO: Update this link in the release branch From b4a44e9122d0613c04dfa29d09b5ed2feb3d0338 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 23 Jun 2023 18:25:35 +0000 Subject: [PATCH 09/22] Prune mesh service --- torch_xla/csrc/init_python_bindings.cpp | 24 -- torch_xla/csrc/runtime/BUILD | 35 -- torch_xla/csrc/runtime/mesh_service.cc | 411 ---------------------- torch_xla/csrc/runtime/mesh_service.h | 60 ---- torch_xla/csrc/runtime/mesh_service.proto | 65 ---- 5 files changed, 595 deletions(-) delete mode 100644 torch_xla/csrc/runtime/mesh_service.cc delete mode 100644 torch_xla/csrc/runtime/mesh_service.h delete mode 100644 torch_xla/csrc/runtime/mesh_service.proto diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 24cf6cdb8948..2293a749f81c 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -38,7 +38,6 @@ #include "torch_xla/csrc/ir.h" #include "torch_xla/csrc/ir_dump_util.h" #include "torch_xla/csrc/ops/device_data.h" -#include "torch_xla/csrc/runtime/mesh_service.h" #include "torch_xla/csrc/runtime/metrics.h" #include "torch_xla/csrc/runtime/metrics_analysis.h" #include "torch_xla/csrc/runtime/metrics_reader.h" @@ -516,24 +515,6 @@ py::object GetRevisions() { return py_dict; } -std::vector Rendezvous(int ordinal, const std::string& tag, - const std::string& payload, - const std::vector& replicas) { - runtime::service::MeshClient* mesh_client = - runtime::service::MeshClient::Get(); - std::vector payloads; - if (mesh_client != nullptr) { - auto rendezvous_payloads = - mesh_client->Rendezvous(ordinal, tag, payload, replicas); - for (auto& rendezvous_payload : rendezvous_payloads) { - payloads.push_back(rendezvous_payload); - } - } else { - XLA_CHECK(replicas.empty() || 
(replicas.size() == 1 && replicas[0] == 0)); - } - return payloads; -} - py::object XlaNms(const at::Tensor& boxes, const at::Tensor& scores, const at::Tensor& score_threshold, const at::Tensor& iou_threshold, int64_t output_size) { @@ -879,11 +860,6 @@ void InitXlaModuleBindings(py::module m) { runtime::GetComputationClient()->GetReplicationDevices(); return replication_devices != nullptr ? replication_devices->size() : 0; }); - m.def("_xla_rendezvous", - [](int ordinal, const std::string& tag, const std::string& payload, - const std::vector& replicas) { - return Rendezvous(ordinal, tag, payload, replicas); - }); py::class_>( m, "IrValue"); diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 384bcc2fb810..372182304752 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -1,15 +1,7 @@ -load( - "@org_tensorflow//tensorflow/tsl/platform/default:build_config.bzl", - "tf_proto_library_cc", -) load( "@org_tensorflow//tensorflow/tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) -load( - "//bazel:tensorflow.bzl", - "if_with_tpu_support", -) licenses(["notice"]) # Apache 2.0 @@ -20,16 +12,6 @@ exports_files([ "tf_exported_symbols.lds", ]) -tf_proto_library_cc( - name = "mesh_service_proto", - srcs = ["mesh_service.proto"], - has_services = 1, - cc_api_version = 2, - cc_grpc_version = 1, - protodeps = [ - "@org_tensorflow//tensorflow/core/protobuf/tpu:topology_proto", - ], -) cc_library( name = "async_task", @@ -156,23 +138,6 @@ cc_library( hdrs = ["env_vars.h"], ) -cc_library( - name = "mesh_service", - srcs = ["mesh_service.cc"], - hdrs = ["mesh_service.h"], - deps = [ - "nccl_distributed", - ":debug_macros", - ":mesh_service_proto_cc", - ":multi_wait", - ":sys_util", - ":thread_pool", - ":util", - "@com_google_absl//absl/strings", - "@org_tensorflow//tensorflow/compiler/xla:statusor", - ], -) - cc_library( name = "metrics_analysis", srcs = ["metrics_analysis.cc"], diff --git a/torch_xla/csrc/runtime/mesh_service.cc b/torch_xla/csrc/runtime/mesh_service.cc deleted file mode 100644 index e24ebeaa81a9..000000000000 --- a/torch_xla/csrc/runtime/mesh_service.cc +++ /dev/null @@ -1,411 +0,0 @@ -#include "torch_xla/csrc/runtime/mesh_service.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/status.h" -#include "torch_xla/csrc/runtime/debug_macros.h" -#include "torch_xla/csrc/runtime/mesh_service.grpc.pb.h" -#include "torch_xla/csrc/runtime/multi_wait.h" -#include "torch_xla/csrc/runtime/nccl_distributed.h" -#include "torch_xla/csrc/runtime/sys_util.h" -#include "torch_xla/csrc/runtime/thread_pool.h" -#include "torch_xla/csrc/runtime/util.h" - -namespace torch_xla { -namespace runtime { -namespace service { -namespace { - -#define GRPC_CHECK_OK(expr) \ - do { \ - auto s = expr; \ - if (!s.ok()) { \ - return ToGrpcStatus(s); \ - } \ - } while (0) - -::grpc::Status ToGrpcStatus(const xla::Status& status) { - return status.ok() - ? 
::grpc::Status::OK - : ::grpc::Status(static_cast<::grpc::StatusCode>(status.code()), - std::string(status.message())); -} - -std::ostream& operator<<(std::ostream& ostrm, const ::grpc::Status& status) { - if (status.ok()) { - ostrm << "OK"; - } else { - ostrm << status.error_message() << " (" - << static_cast(status.error_code()) << ")"; - } - return ostrm; -} - -std::basic_ostringstream& operator<<(std::basic_ostringstream ostrm, - const ::grpc::Status& status) { - if (status.ok()) { - ostrm << "OK"; - } else { - ostrm << status.error_message() << " (" - << static_cast(status.error_code()) << ")"; - } - return ostrm; -} - -class MeshServiceImpl : public grpc::MeshService::Service { - public: - explicit MeshServiceImpl(grpc::Config config); - - ::grpc::Status GetConfig(::grpc::ServerContext* context, - const grpc::GetConfigRequest* request, - grpc::GetConfigResponse* response) override; - - ::grpc::Status SetConfig(::grpc::ServerContext* context, - const grpc::SetConfigRequest* request, - grpc::SetConfigResponse* response) override; - - ::grpc::Status Rendezvous(::grpc::ServerContext* context, - const grpc::RendezvousRequest* request, - grpc::RendezvousResponse* response) override; - - ::grpc::Status GetNcclUniqueUid( - ::grpc::ServerContext* context, - const grpc::GetNcclUniqueUidRequest* request, - grpc::GetNcclUniqueUidResponse* response) override; - - private: - class RendezvousData { - public: - explicit RendezvousData(size_t count, const std::set& replicas) - : count_(count), - replicas_(replicas), - mwait_(count), - release_count_(0) {} - - bool Release() { return release_count_.fetch_add(1) == 0; } - - ::grpc::Status Wait(); - - void Complete(int64_t ordinal, std::string payload, - const std::set& replicas); - - const std::map& Payloads() const { - return payloads_; - }; - - private: - size_t count_; - std::set replicas_; - std::mutex lock_; - util::MultiWait mwait_; - std::atomic release_count_; - std::map payloads_; - ::grpc::Status status_; - }; - - std::shared_ptr GetRendezvous( - const std::string& tag, const std::set& replicas); - - void ReleaseRendezvous(const std::string& tag, - const std::shared_ptr& rendezvous); - - static ::grpc::Status HandleRpc( - const std::function<::grpc::Status()>& rpc_fn); - - std::mutex lock_; - std::map configs_; - std::unordered_map> - rendezvous_map_; -}; - -::grpc::Status MeshServiceImpl::RendezvousData::Wait() { - ::grpc::Status status = ToGrpcStatus( - torch_xla::runtime::util::CheckedCall([&]() { mwait_.Wait(); })); - if (status.ok()) { - std::lock_guard lock(lock_); - status = status_; - } - return status; -} - -void MeshServiceImpl::RendezvousData::Complete( - int64_t ordinal, std::string payload, const std::set& replicas) { - std::lock_guard lock(lock_); - if ((!replicas_.empty() && replicas_.count(ordinal) == 0) || - (replicas_.empty() && ordinal >= count_)) { - status_ = ::grpc::Status(::grpc::StatusCode::INVALID_ARGUMENT, - absl::StrCat("Invalid ordinal: ", ordinal)); - } else if (replicas != replicas_) { - status_ = ::grpc::Status( - ::grpc::StatusCode::INVALID_ARGUMENT, - absl::StrCat("Mismatching replicas: (", absl::StrJoin(replicas_, ", "), - ") vs. 
(", absl::StrJoin(replicas, ", "), ")")); - } else { - auto insert_result = payloads_.emplace(ordinal, std::move(payload)); - if (!insert_result.second) { - status_ = ::grpc::Status(::grpc::StatusCode::INVALID_ARGUMENT, - absl::StrCat("Duplicate ordinal: ", ordinal)); - } - } - mwait_.Done(); -} - -MeshServiceImpl::MeshServiceImpl(grpc::Config config) { - configs_.emplace(0, std::move(config)); -} - -::grpc::Status MeshServiceImpl::GetConfig(::grpc::ServerContext* context, - const grpc::GetConfigRequest* request, - grpc::GetConfigResponse* response) { - auto rpc_fn = [&]() -> ::grpc::Status { - TF_VLOG(3) << "Got config fetch request: peer=" << context->peer(); - response->mutable_config()->CopyFrom(configs_.at(request->ordinal())); - return ::grpc::Status::OK; - }; - return HandleRpc(rpc_fn); -} - -::grpc::Status MeshServiceImpl::SetConfig(::grpc::ServerContext* context, - const grpc::SetConfigRequest* request, - grpc::SetConfigResponse* response) { - auto rpc_fn = [&]() -> ::grpc::Status { - TF_VLOG(3) << "Got config set request: peer=" << context->peer() - << ", ordinal=" << request->ordinal(); - - std::lock_guard lock(lock_); - XLA_CHECK_EQ(configs_.at(0).mesh_size(), request->config().mesh_size()); - configs_.emplace(request->ordinal(), request->config()); - return ::grpc::Status::OK; - }; - return HandleRpc(rpc_fn); -} - -::grpc::Status MeshServiceImpl::Rendezvous( - ::grpc::ServerContext* context, const grpc::RendezvousRequest* request, - grpc::RendezvousResponse* response) { - std::set replicas(request->replicas().begin(), - request->replicas().end()); - auto rendezvous = GetRendezvous(request->tag(), replicas); - rendezvous->Complete(request->ordinal(), request->payload(), replicas); - TF_VLOG(3) << "Entering rendezvous: ordinal=" << request->ordinal() - << ", tag=" << request->tag() << ", peer=" << context->peer(); - ::grpc::Status status = rendezvous->Wait(); - TF_VLOG(3) << "Exiting rendezvous: ordinal=" << request->ordinal() - << ", tag=" << request->tag() << ", peer=" << context->peer() - << ", status=" << status; - if (status.ok()) { - for (auto& ordinal_payload : rendezvous->Payloads()) { - response->add_payloads(ordinal_payload.second); - } - } - ReleaseRendezvous(request->tag(), rendezvous); - return status; -} - -::grpc::Status MeshServiceImpl::GetNcclUniqueUid( - ::grpc::ServerContext* context, - const grpc::GetNcclUniqueUidRequest* request, - grpc::GetNcclUniqueUidResponse* response) { - std::vector replicas; - for (auto& replica : request->replicas()) { - replicas.push_back(replica); - } - TF_VLOG(3) << "Got NCCL UID fetch request: replicas=(" - << absl::StrJoin(replicas, ", ") << "), peer=" << context->peer(); - response->set_uid(nccl_detail::GetNcclUniqueUid(replicas)); - return ::grpc::Status::OK; -} - -std::shared_ptr MeshServiceImpl::GetRendezvous( - const std::string& tag, const std::set& replicas) { - std::lock_guard lock(lock_); - auto it = rendezvous_map_.find(tag); - if (it == rendezvous_map_.end()) { - size_t count = - replicas.empty() ? 
configs_.at(0).mesh_size() : replicas.size(); - it = rendezvous_map_ - .emplace(tag, std::make_shared(count, replicas)) - .first; - } - return it->second; -} - -void MeshServiceImpl::ReleaseRendezvous( - const std::string& tag, const std::shared_ptr& rendezvous) { - if (rendezvous->Release()) { - std::lock_guard lock(lock_); - rendezvous_map_.erase(tag); - } -} - -::grpc::Status MeshServiceImpl::HandleRpc( - const std::function<::grpc::Status()>& rpc_fn) { - try { - return rpc_fn(); - } catch (const std::exception& ex) { - return ::grpc::Status( - ::grpc::StatusCode::ABORTED, - absl::StrCat("Exception while handling RPC: ", ex.what())); - } -} - -} // namespace - -struct MeshService::Impl { - Impl(const std::string& address, grpc::Config config) - : impl(std::move(config)) { - ::grpc::ServerBuilder builder; - int64_t max_msg_size = - sys_util::GetEnvInt("XRT_MESH_MAX_MSGSIZE", 1024 * 1024 * 1024); - builder.SetMaxReceiveMessageSize(max_msg_size); - builder.SetMaxSendMessageSize(max_msg_size); - builder.AddListeningPort(address, ::grpc::InsecureServerCredentials()); - builder.RegisterService(&impl); - server = builder.BuildAndStart(); - } - - MeshServiceImpl impl; - std::unique_ptr<::grpc::Server> server; -}; - -MeshService::MeshService(const std::string& address, grpc::Config config) - : impl_(new Impl(address, std::move(config))) {} - -MeshService::~MeshService() {} - -void MeshService::Shutdown() { - impl_->server->Shutdown(); - impl_->server->Wait(); -} - -struct MeshClient::Impl { - explicit Impl(const std::string& address) : address(address) { - channel = - ::grpc::CreateChannel(address, ::grpc::InsecureChannelCredentials()); - stub = grpc::MeshService::NewStub(channel); - } - - std::shared_ptr<::grpc::Channel> channel; - std::unique_ptr stub; - std::string address; -}; - -MeshClient* MeshClient::Get() { - auto create_client = []() { - std::string mesh_service_address = - sys_util::GetEnvString("XRT_MESH_SERVICE_ADDRESS", ""); - return !mesh_service_address.empty() ? 
new MeshClient(mesh_service_address) - : nullptr; - }; - static MeshClient* client = create_client(); - return client; -} - -MeshClient::MeshClient(const std::string& address) : impl_(new Impl(address)) { - int64_t connect_wait_seconds = - sys_util::GetEnvInt("XRT_MESH_CONNECT_WAIT", 300); - TF_LOG(INFO) << "Waiting to connect to client mesh master (" - << connect_wait_seconds << " seconds) " << address; - XLA_CHECK(impl_->channel->WaitForConnected( - std::chrono::system_clock::now() + - std::chrono::seconds(connect_wait_seconds))) - << "Failed to connect to client mesh master: " << address; -} - -MeshClient::~MeshClient() {} - -const std::string& MeshClient::address() const { return impl_->address; } - -grpc::Config MeshClient::GetConfig(int ordinal) const { - ::grpc::ClientContext context; - grpc::GetConfigRequest request; - grpc::GetConfigResponse response; - request.set_ordinal(ordinal); - ::grpc::Status status = impl_->stub->GetConfig(&context, request, &response); - if (!status.ok()) { - XLA_ERROR() << "Failed to retrieve mesh configuration: " << status; - } - return std::move(*response.mutable_config()); -} - -void MeshClient::SetConfig(int ordinal, const grpc::Config& config) const { - ::grpc::ClientContext context; - grpc::SetConfigRequest request; - grpc::SetConfigResponse response; - request.set_ordinal(ordinal); - request.mutable_config()->CopyFrom(config); - ::grpc::Status status = impl_->stub->SetConfig(&context, request, &response); - if (!status.ok()) { - XLA_ERROR() << "Failed to set configuration: " << status; - } -} - -std::vector MeshClient::Rendezvous( - int ordinal, const std::string& tag, const std::string& payload, - absl::Span replicas) const { - ::grpc::ClientContext context; - grpc::RendezvousRequest request; - grpc::RendezvousResponse response; - request.set_tag(tag); - request.set_payload(payload); - request.set_ordinal(ordinal); - for (auto& replica : replicas) { - request.add_replicas(replica); - } - TF_VLOG(3) << "Waiting for rendezvous: ordinal=" << ordinal << " tag=" << tag; - ::grpc::Status status = impl_->stub->Rendezvous(&context, request, &response); - TF_VLOG(3) << "Rendezvous wait complete: " << tag; - if (!status.ok()) { - XLA_ERROR() << "Failed to meet rendezvous '" << tag << "': " << status; - } - std::vector rv_payloads; - for (auto& rv_payload : response.payloads()) { - rv_payloads.push_back(rv_payload); - } - return rv_payloads; -} - -std::string MeshClient::GetNcclUniqueUid( - absl::Span replicas) const { - ::grpc::ClientContext context; - grpc::GetNcclUniqueUidRequest request; - grpc::GetNcclUniqueUidResponse response; - for (auto& replica : replicas) { - request.add_replicas(replica); - } - - TF_VLOG(3) << "Waiting for NCCL UID: replicas=(" - << absl::StrJoin(replicas, ", ") << ")"; - ::grpc::Status status = - impl_->stub->GetNcclUniqueUid(&context, request, &response); - TF_VLOG(3) << "NCCL UID wait complete: " << absl::StrJoin(replicas, ", ") - << ")"; - if (!status.ok()) { - XLA_ERROR() << "Failed to get NCCL UID (" << absl::StrJoin(replicas, ", ") - << "): " << status; - } - return response.uid(); -} - -} // namespace service -} // namespace runtime -} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/mesh_service.h b/torch_xla/csrc/runtime/mesh_service.h deleted file mode 100644 index 5547742b6684..000000000000 --- a/torch_xla/csrc/runtime/mesh_service.h +++ /dev/null @@ -1,60 +0,0 @@ -#ifndef XLA_CLIENT_XRT_MESH_SERVICE_H_ -#define XLA_CLIENT_XRT_MESH_SERVICE_H_ - -#include -#include -#include - -#include "absl/types/span.h" 
-#include "tensorflow/compiler/xla/types.h" -#include "torch_xla/csrc/runtime/mesh_service.pb.h" - -namespace torch_xla { -namespace runtime { -namespace service { - -class MeshService { - struct Impl; - - public: - MeshService(const std::string& address, grpc::Config config); - - ~MeshService(); - - void Shutdown(); - - private: - std::unique_ptr impl_; -}; - -class MeshClient { - struct Impl; - - public: - static MeshClient* Get(); - - const std::string& address() const; - - grpc::Config GetConfig(int ordinal) const; - - void SetConfig(int ordinal, const grpc::Config& config) const; - - std::vector Rendezvous(int ordinal, const std::string& tag, - const std::string& payload, - absl::Span replicas) const; - - std::string GetNcclUniqueUid(absl::Span replicas) const; - - private: - MeshClient(const std::string& address); - - ~MeshClient(); - - std::unique_ptr impl_; -}; - -} // namespace service -} // namespace runtime -} // namespace torch_xla - -#endif // XLA_CLIENT_XRT_MESH_SERVICE_H_ diff --git a/torch_xla/csrc/runtime/mesh_service.proto b/torch_xla/csrc/runtime/mesh_service.proto deleted file mode 100644 index b51f1b837a6f..000000000000 --- a/torch_xla/csrc/runtime/mesh_service.proto +++ /dev/null @@ -1,65 +0,0 @@ -syntax = "proto2"; - -import "tensorflow/core/protobuf/tpu/topology.proto"; - -package torch_xla.runtime.service.grpc; - -message Device { - required string local_name = 1; - required string global_name = 2; -} - -message Worker { - required string name = 1; - required int32 task_no = 2; - required string address = 3; - repeated Device devices = 4; -} - -message Config { - optional tensorflow.tpu.TopologyProto proto = 1; - repeated Worker workers = 2; - required int64 mesh_size = 3; -} - -message GetConfigRequest { - required uint32 ordinal = 1; -} - -message GetConfigResponse { - required Config config = 1; -} - -message SetConfigRequest { - required uint32 ordinal = 1; - required Config config = 2; -} - -message SetConfigResponse { -} - -message RendezvousRequest { - required string tag = 1; - required bytes payload = 2; - required uint32 ordinal = 3; - repeated uint32 replicas = 4; -} - -message RendezvousResponse { - repeated bytes payloads = 1; -} - -message GetNcclUniqueUidRequest { - repeated uint32 replicas = 1; -} - -message GetNcclUniqueUidResponse { - required bytes uid = 1; -} - -service MeshService { - rpc GetConfig(GetConfigRequest) returns (GetConfigResponse) {} - rpc SetConfig(SetConfigRequest) returns (SetConfigResponse) {} - rpc Rendezvous(RendezvousRequest) returns (RendezvousResponse) {} - rpc GetNcclUniqueUid(GetNcclUniqueUidRequest) returns (GetNcclUniqueUidResponse) {} -} From 13873ddf1c44c36ec222156cc568a0bb10ae7bcc Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 23 Jun 2023 18:31:57 +0000 Subject: [PATCH 10/22] Remove obsolete test --- test/allreduce_torchrun.py | 56 --------------------------------- test/run_tests.sh | 8 ----- test/test_allreduce_torchrun.py | 26 --------------- 3 files changed, 90 deletions(-) delete mode 100644 test/allreduce_torchrun.py delete mode 100644 test/test_allreduce_torchrun.py diff --git a/test/allreduce_torchrun.py b/test/allreduce_torchrun.py deleted file mode 100644 index 173999e61b34..000000000000 --- a/test/allreduce_torchrun.py +++ /dev/null @@ -1,56 +0,0 @@ -import argparse -import os -import torch -import torch_xla -import torch_xla.core.xla_model as xm -import torch.distributed as dist -import torch_xla.distributed.xla_multiprocessing as xmp -from torch_xla.distributed.xrt_init import init_xrt_context 
-import torch_xla.distributed.xla_backend - - -def _mp_fn_xrt_init(): - rank = int(os.environ['RANK']) - size = int(os.environ['WORLD_SIZE']) - - init_xrt_context() - - device = xm.xla_device() - ones = torch.ones((2, 3)) - xones = ones.to(device) - result = xm.all_reduce('sum', xones) - - result_cpu = result.cpu() - expected = torch.ones((2, 3)) * size - assert torch.all(result_cpu == expected), f'{result_cpu} != {expected}' - - -def _mp_fn_xla_backend(): - rank = int(os.environ['RANK']) - size = int(os.environ['WORLD_SIZE']) - - dist.init_process_group('xla') - device = xm.xla_device() - - ones = torch.ones((2, 3)) - xones = ones.to(device) - dist.all_reduce(xones, op=torch.distributed.ReduceOp.SUM) - - result_cpu = xones.cpu() - expected = torch.ones((2, 3)) * size - assert torch.all(xones.cpu() == expected), f'{xones} != {expected}' - - -if __name__ == '__main__': - print( - 'master_port:{}, master_addr:{}, rank:{}, local_rank:{}, size:{}'.format( - os.environ['MASTER_PORT'], os.environ['MASTER_ADDR'], - os.environ['RANK'], os.environ['LOCAL_RANK'], - os.environ['WORLD_SIZE'])) - parser = argparse.ArgumentParser() - parser.add_argument('--use_xla_backend', action="store_true") - args = parser.parse_args() - if args.use_xla_backend: - _mp_fn_xla_backend() - else: - _mp_fn_xrt_init() diff --git a/test/run_tests.sh b/test/run_tests.sh index e898ea6d6c48..e93003808346 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -108,14 +108,6 @@ function run_xla_backend_mp { MASTER_ADDR=localhost MASTER_PORT=6000 run_test "$@" } -function run_xrt { - if [ -x "$(command -v nvidia-smi)" ] && [ "$XLA_CUDA" != "0" ]; then - GPU_NUM_DEVICES=2 run_coverage "$@" - else - XRT_DEVICE_MAP="CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0" XRT_WORKERS="localservice:0;grpc://localhost:$(shuf -i 40701-40999 -n 1)" run_coverage "$@" - fi -} - function run_torch_op_tests { run_dynamic "$CDIR/../../test/test_view_ops.py" "$@" -v TestViewOpsXLA run_test_without_functionalization "$CDIR/../../test/test_view_ops.py" "$@" -v TestViewOpsXLA diff --git a/test/test_allreduce_torchrun.py b/test/test_allreduce_torchrun.py deleted file mode 100644 index c81bb3abf445..000000000000 --- a/test/test_allreduce_torchrun.py +++ /dev/null @@ -1,26 +0,0 @@ -import os -import subprocess -import pathlib - - -def test_local_torchrun_xrt_init(): - # This test launches a allreduce using torchrun launcher, uses native xla_model CCop - ci_dir = pathlib.Path(__file__).parent.resolve() - cmd = f'torchrun --nproc_per_node=2 --master_addr=127.0.0.1 --master_port=2020 {ci_dir}/allreduce_torchrun.py' - proc = subprocess.Popen(cmd, shell=True) - return_code = proc.wait() - assert return_code == 0 - - -def test_local_torchrun_xla_backend(): - # This test launches a allreduce using torchrun launcher, uses xla backend - ci_dir = pathlib.Path(__file__).parent.resolve() - cmd = f'torchrun --nproc_per_node=2 --master_addr=127.0.0.1 --master_port=2020 {ci_dir}/allreduce_torchrun.py --use_xla_backend' - proc = subprocess.Popen(cmd, shell=True) - return_code = proc.wait() - assert return_code == 0 - - -if __name__ == '__main__': - test_local_torchrun_xrt_init() - test_local_torchrun_xla_backend() From acd5424dbdd0241f555ffe1ee6d313e7f234e8d9 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 23 Jun 2023 18:32:54 +0000 Subject: [PATCH 11/22] Remove other run server script --- torch_xla/core/xrt_run_server.py | 81 -------------------------------- 1 file changed, 81 deletions(-) delete mode 100644 torch_xla/core/xrt_run_server.py diff --git 
a/torch_xla/core/xrt_run_server.py b/torch_xla/core/xrt_run_server.py deleted file mode 100644 index a3d70af90023..000000000000 --- a/torch_xla/core/xrt_run_server.py +++ /dev/null @@ -1,81 +0,0 @@ -import argparse -import re -import time -import os -import subprocess -import sys - -from pathlib import Path -from torch_xla.__init__ import server_is_alive, XRT_RUN_SERVER_PROCESS, XRT_SERVER_REGEX - - -def kill_service(): - subprocess.call(['pkill', '-f', XRT_SERVER_REGEX]) - # Wait unitl existing server process gets killed. - found_server_process = False - while server_is_alive(): - found_server_process = True - time.sleep(1) - # Server process might still hold the lock to the tpu device after turing into a zombie - # process with name . Sleep a bit longer to make sure it exit completely. - if found_server_process: - time.sleep(5) - - -def run_service(port, flag_env): - if server_is_alive(): - print('Server is already running, use --restart(--restart-tpuvm-pod-server ' - 'if running with xla_dist) to restart the server.') - return - - local_env = os.environ.copy() - # Enable the basic logging by default - local_env['TF_CPP_MIN_LOG_LEVEL'] = '0' - local_env[ - 'TF_CPP_VMODULE'] = 'tpu_configuration_ops=1,tpu_execute_op=1,tpu_compile_op=1,tpu_compile_op_impl=1,tpu_compile_op_common=1,tpu_compile_ops=1,master=1,computation_client=5' - - env_vars = list(flag_env) if flag_env else [] - for env_var in env_vars: - if re.match(r'\w*=\w*', env_var) is None: - raise ValueError(('Environment variable to distribute ({}) should follow ' - 'the form: X=Y').format(env_var)) - (env, var) = env_var.split('=', 1) - local_env[env] = var - - Path('/tmp/xrt_server_log').mkdir(parents=True, exist_ok=True) - time_str = time.strftime('%Y%m%d-%H%M%S') - log_file = open('/tmp/xrt_server_log/server_{}.log'.format(time_str), 'w') - subprocess.Popen(['python3', '-m', XRT_RUN_SERVER_PROCESS, - str(port)], - stdout=log_file, - stderr=subprocess.STDOUT, - env=local_env, - start_new_session=True) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument( - '--port', type=int, help='Port that XRT local service will be using.') - parser.add_argument( - '--env', - action='append', - type=str, - help='List of environment variables to distribute.') - - server_state_group = parser.add_mutually_exclusive_group() - server_state_group.add_argument( - '--restart', - action='store_true', - help='Restart the long running XRT local server.') - server_state_group.add_argument( - '--stop', - action='store_true', - help='Stop the long running XRT local server.') - - FLAGS = parser.parse_args() - if FLAGS.restart or FLAGS.stop: - kill_service() - - if not FLAGS.stop: - run_service(FLAGS.port, FLAGS.env) From aa8689776d46bd72790c79261c303c61df6126ed Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 23 Jun 2023 18:33:07 +0000 Subject: [PATCH 12/22] Remove XRT config --- configuration.yaml | 88 ---------------------------------------------- 1 file changed, 88 deletions(-) diff --git a/configuration.yaml b/configuration.yaml index 8a8a147d8fdd..96c193a8352b 100644 --- a/configuration.yaml +++ b/configuration.yaml @@ -1,93 +1,5 @@ --- variables: - xrt_variables: - XRT_MESH_SERVICE_ADDRESS: - description: - - The mesh service address used to create the XRT mesh client. - type: string - default_value: "" - XRT_GRPC_COMPRESSION: - description: - - Configures compression for rpc options on the XRT client. 
- type: string - default_value: "" - XRT_TPU_CONFIG: - description: - - Addresses for the TPU services to be used by the XRT client. ";" - separated list of addresses, e.g. localservice;0;localhost:51011. - type: string - XRT_WORKERS: - description: - - Addresses for the XRT workers to be used by the XRT client, for - example localservice:0;grpc://localhost:51011 - type: string - XRT_DEVICE_MAP: - description: - - Maps devices to metadata about the job, replica, task that the device - is responsible for. e.g. - CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0 - type: string - XRT_GRPC_MULTISTREAM: - description: - - Used to disable session conneciton sharing for XRT. - type: bool - default_value: true - XRT_START_LOCAL_SERVER: - description: - - Whether or not XRT should start the local service. If true, and if - the devices are CPU or GPU, XRT will try to start the local server. - type: bool - default_value: false - XRT_SHARD_LOCAL_ORDINAL: - description: - - Ordinal to be appended to the paths used by this thread of the xla - client. - type: int - XRT_GRPC_COMPRESSION_LEVEL: - description: - - Configures compression level for rpc options on the XRT client. - type: int - default_value: 3 - XRT_MESH_MAX_MSGSIZE: - description: - - Max size of the mesh used by the XRT mesh service, product of - dimensions e.g. 1024 * 1024 * 1024 - type: int - default_value: 1073741824 # (1024^3) - XRT_MESH_CONNECT_WAIT: - description: - - Number of seconds to wait for a connection to the XRT mesh service, - particularly the client mesh master - type: int - default_value: 300 - XRT_LOCAL_WORKER: - description: - - Local service address for XRT local worker, e.g. localhost:8000. - type: string - default_value: "" - XRT_SHARD_WORLD_SIZE: - description: - - Total number of XRT shards to consider in this client instance. Does - not have a default because there's special behavior when the flag is - not set. - type: int - XRT_MULTI_PROCESSING_DEVICE: - description: - - Service address of the XRT device to be used as a multi processing - device. - type: string - default_value: "" - XRT_HOST_ORDINAL: - description: - - Sets the host ordinal for the XRT computation client. Used to - identify the rank of current device. Does not have a default because - there's special behavior when the flag is not set. - type: int - XRT_SHARD_ORDINAL: - description: - - Sets the shard ordinal for the XRT computation client. 
- type: int - default_value: -1 pjrt_variables: PJRT_DEVICE: description: From fde6d3978f2a5c69361f240deafb8cb958269763 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 23 Jun 2023 18:34:15 +0000 Subject: [PATCH 13/22] Update PJRT default device test --- test/pjrt/test_runtime.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/test/pjrt/test_runtime.py b/test/pjrt/test_runtime.py index 18373a259a78..be8381c0a415 100644 --- a/test/pjrt/test_runtime.py +++ b/test/pjrt/test_runtime.py @@ -55,14 +55,12 @@ def test_xla_device_error(self): }, False), ('pjrt_cpu', { 'PJRT_DEVICE': 'CPU', 'PJRT_SELECT_DEFAULT_DEVICE': '0' - }, True), ('xrt_tpu', { - 'XRT_TPU_CONFIG': 'localservice;0;localhost:51011' - }, False), ('pjrt_tpu_precedence', { + }, True), ('pjrt_tpu_precedence', { 'PJRT_DEVICE': 'TPU', 'XRT_TPU_CONFIG': 'localservice;0;localhost:51011', - }, True), ('xrt_gpu', { + }, True), ('gpu_num_devives', { 'GPU_NUM_DEVICES': '4' - }, False), ('pjrt_gpu', { + }, True), ('pjrt_gpu', { 'PJRT_DEVICE': 'GPU', 'GPU_NUM_DEVICES': '4' }, True), ('xla_dist_worker', { From 60ab12a49879b25f863a38ab6fb94ab8be00648f Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 23 Jun 2023 18:55:28 +0000 Subject: [PATCH 14/22] Add a file I forgot to save --- torch_xla/csrc/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index abefabb5e7ea..5ceea92377d4 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -263,7 +263,6 @@ ptxla_cc_library( ":tensor", ":version", "//torch_xla/csrc/runtime", - "//torch_xla/csrc/runtime:mesh_service", "//torch_xla/csrc/runtime:metrics", "//torch_xla/csrc/runtime:metrics_analysis", "//torch_xla/csrc/runtime:metrics_reader", From 7132fe40f0de16134241b37ee3dd3220ff31b896 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 23 Jun 2023 18:55:38 +0000 Subject: [PATCH 15/22] if using_pjrt -> @requires_pjrt --- torch_xla/core/xla_model.py | 46 +++++--------------- torch_xla/distributed/xla_multiprocessing.py | 2 +- 2 files changed, 11 insertions(+), 37 deletions(-) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index afcc748e8070..e651be24bc46 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -98,7 +98,7 @@ def get_xla_supported_devices(devkind=None, max_devices=None): if kind_devices: return kind_devices[:max_devices] if max_devices else kind_devices - +@runtime.requires_pjrt def xrt_world_size(defval=1): """Retrieves the number of devices which is taking part of the replication. @@ -114,12 +114,10 @@ def xrt_world_size(defval=1): if _WORLD_SIZE is not None: return _WORLD_SIZE - if runtime.using_pjrt(): - return runtime.world_size() - - return xu.getenv_as(xenv.WORLD_SIZE, int, defval=defval) + return runtime.world_size() +@runtime.requires_pjrt def get_ordinal(defval=0): """Retrieves the replication ordinal of the current thread. @@ -137,12 +135,9 @@ def get_ordinal(defval=0): if _ORDINAL is not None: return _ORDINAL - if runtime.using_pjrt(): - return runtime.global_ordinal() - - return xu.getenv_as(xenv.ORDINAL, int, defval=defval) - + return runtime.global_ordinal() +@runtime.requires_pjrt def get_local_ordinal(defval=0): """Retrieves the replication local ordinal of the current thread. @@ -156,13 +151,7 @@ def get_local_ordinal(defval=0): Returns: The replication local ordinal of the current thread. 
""" - if runtime.using_pjrt(): - return runtime.local_ordinal() - - ordinal = xu.getenv_as(xenv.LOCAL_ORDINAL, int, defval=-1) - if ordinal >= 0: - return ordinal - return getattr(_get_device_context(), 'device_index', defval) + return runtime.local_ordinal() def is_master_ordinal(local=True): @@ -185,7 +174,7 @@ def master_print(*args, fd=sys.stdout, local=False, flush=False): if is_master_ordinal(local=local): print(*args, file=fd, flush=flush) - +@runtime.requires_pjrt def xla_device(n=None, devkind=None): """Returns a given instance of an XLA device. @@ -205,20 +194,7 @@ def xla_device(n=None, devkind=None): torch_xla._XLAC._xla_set_default_device(device) return torch.device(device) - if runtime.using_pjrt(): - return runtime.xla_device(n, devkind) - - if n is None: - devices = get_xla_supported_devices(devkind=devkind) - assert devices, 'No devices of {} kind'.format(devkind or 'ANY') - # This is a utility API mainly called from tests or simple code which wants - # to just have a single device to run on. Set the default device so that - # the tensor barrier can work correctly and avoid growing graphs surprises. - device = devices[0] - else: - device = 'xla:{}'.format(n) - torch_xla._XLAC._xla_set_default_device(device) - return torch.device(device) + return runtime.xla_device(n, devkind) def _xla_real_device(device): @@ -1083,6 +1059,7 @@ def xla_rendezvous(payload: bytes = b'', return [bytes(p.cpu().tolist()) for p in payloads] +@runtime.requires_pjrt def rendezvous(tag, payload=b'', replicas=[]): """Waits for all the mesh clients to reach the named rendezvous. @@ -1100,10 +1077,7 @@ def rendezvous(tag, payload=b'', replicas=[]): The payloads exchanged by all the other cores, with the payload of core ordinal `i` at position `i` in the returned tuple. """ - if runtime.using_pjrt(): - return xla_rendezvous(payload, replicas or None, tag=tag) - - return torch_xla._XLAC._xla_rendezvous(get_ordinal(), tag, payload, replicas) + return xla_rendezvous(payload, replicas or None, tag=tag) def do_on_ordinals(target, data=(), ordinals=(0,)): diff --git a/torch_xla/distributed/xla_multiprocessing.py b/torch_xla/distributed/xla_multiprocessing.py index 35f7a67fab42..f8cf66952171 100644 --- a/torch_xla/distributed/xla_multiprocessing.py +++ b/torch_xla/distributed/xla_multiprocessing.py @@ -3,6 +3,7 @@ from torch_xla._internal import pjrt +@xr.requires_pjrt def spawn(fn, args=(), nprocs=None, @@ -34,7 +35,6 @@ def spawn(fn, `nprocs` is 1 the `fn` function will be called directly, and the API will return None. """ - assert xr.using_pjrt(), 'PJRT_DEVICE must be set.' 
return pjrt.spawn(fn, nprocs, start_method, args) From 99c36d82d89d3408c1e5d96c27132f47bfb1d7f4 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 23 Jun 2023 19:41:39 +0000 Subject: [PATCH 16/22] Remove irrelevant test case --- test/pjrt/test_runtime.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/pjrt/test_runtime.py b/test/pjrt/test_runtime.py index be8381c0a415..a8733be10141 100644 --- a/test/pjrt/test_runtime.py +++ b/test/pjrt/test_runtime.py @@ -63,9 +63,7 @@ def test_xla_device_error(self): }, True), ('pjrt_gpu', { 'PJRT_DEVICE': 'GPU', 'GPU_NUM_DEVICES': '4' - }, True), ('xla_dist_worker', { - 'XRT_LOCAL_WORKER': 'c_localservice:2' - }, False)) + }, True)) def test_pjrt_default_device(self, env_vars, expect_using_pjrt): with mock.patch.dict(os.environ, env_vars, clear=True): # Print a warningif we had to select a default runtime From bca0625fbca93ed2e55da8015b1fbed793bc9b89 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 23 Jun 2023 19:43:13 +0000 Subject: [PATCH 17/22] Remove XRT env vars --- torch_xla/core/xla_env_vars.py | 13 ------------- torch_xla/csrc/runtime/env_vars.cc | 10 ---------- 2 files changed, 23 deletions(-) diff --git a/torch_xla/core/xla_env_vars.py b/torch_xla/core/xla_env_vars.py index 67b070aa6fae..f67ea2d9fb60 100644 --- a/torch_xla/core/xla_env_vars.py +++ b/torch_xla/core/xla_env_vars.py @@ -1,17 +1,4 @@ -MP_DEVICE = 'XRT_MULTI_PROCESSING_DEVICE' -LOCAL_WORKER = 'XRT_LOCAL_WORKER' -TPU_CONFIG = 'XRT_TPU_CONFIG' TPUVM_MODE = 'TPUVM_MODE' -SERVICE_ADDRESS = 'XRT_MESH_SERVICE_ADDRESS' -DEVICE_MAP = 'XRT_DEVICE_MAP' -WORKERS = 'XRT_WORKERS' -LOCAL_ORDINAL = 'XRT_SHARD_LOCAL_ORDINAL' -ORDINAL = 'XRT_SHARD_ORDINAL' -WORLD_SIZE = 'XRT_SHARD_WORLD_SIZE' -HOST_WORLD_SIZE = 'XRT_HOST_WORLD_SIZE' -HOST_ORDINAL = 'XRT_HOST_ORDINAL' -TORCH_DIST_METHOD = 'XRT_TORCH_DIST_METHOD' -TORCH_DIST_ROOT = 'XRT_TORCH_DIST_ROOT' TPU_NUM_DEVICES = 'TPU_NUM_DEVICES' GPU_NUM_DEVICES = 'GPU_NUM_DEVICES' CPU_NUM_DEVICES = 'CPU_NUM_DEVICES' diff --git a/torch_xla/csrc/runtime/env_vars.cc b/torch_xla/csrc/runtime/env_vars.cc index b6fa87d5b9e6..006f1cab1d6b 100644 --- a/torch_xla/csrc/runtime/env_vars.cc +++ b/torch_xla/csrc/runtime/env_vars.cc @@ -7,16 +7,6 @@ namespace env { const char* const kEnvNumTpu = "TPU_NUM_DEVICES"; const char* const kEnvNumGpu = "GPU_NUM_DEVICES"; const char* const kEnvNumCpu = "CPU_NUM_DEVICES"; -const char* const kEnvLocalWorker = "XRT_LOCAL_WORKER"; -const char* const kEnvTpuConfig = "XRT_TPU_CONFIG"; -const char* const kEnvDeviceMap = "XRT_DEVICE_MAP"; -const char* const kEnvWorkers = "XRT_WORKERS"; -const char* const kEnvMeshService = "XRT_MESH_SERVICE_ADDRESS"; -const char* const kEnvWorldSize = "XRT_SHARD_WORLD_SIZE"; -const char* const kEnvMpDevice = "XRT_MULTI_PROCESSING_DEVICE"; -const char* const kEnvHostOrdinal = "XRT_HOST_ORDINAL"; -const char* const kEnvShardOrdinal = "XRT_SHARD_ORDINAL"; -const char* const kEnvStartService = "XRT_START_LOCAL_SERVER"; const char* const kEnvTpuvmMode = "TPUVM_MODE"; const char* const kEnvPjRtDevice = "PJRT_DEVICE"; const char* const kEnvPjRtTpuMaxInflightComputations = From 64da4bbcaaa4212cfd7af2c36db87e543747df94 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 23 Jun 2023 19:45:49 +0000 Subject: [PATCH 18/22] fix md link --- torch_xla/runtime.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py index 3b6554654e26..3e4cb28758ad 100644 --- a/torch_xla/runtime.py +++ b/torch_xla/runtime.py @@ -33,7 +33,7 @@ def 
_maybe_select_default_device(): # TODO: Update this link in the release branch logging.warning('PJRT is now the default runtime. For more information, see ' - 'https://github.com/pytorch/xla/blob/master/docs/xr.md') + 'https://github.com/pytorch/xla/blob/master/docs/pjrt.md') # Check for libtpu _and_ the TPU device if torch_xla._found_libtpu and os.path.exists('/dev/accel0'): logging.warning('libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.') From ecb3d2d1a7cf5fdae0ea511f96b4caff8754e4c9 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 23 Jun 2023 19:45:58 +0000 Subject: [PATCH 19/22] formatting --- torch_xla/core/xla_model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index e651be24bc46..4ee4cbf03ff4 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -98,6 +98,7 @@ def get_xla_supported_devices(devkind=None, max_devices=None): if kind_devices: return kind_devices[:max_devices] if max_devices else kind_devices + @runtime.requires_pjrt def xrt_world_size(defval=1): """Retrieves the number of devices which is taking part of the replication. @@ -137,6 +138,7 @@ def get_ordinal(defval=0): return runtime.global_ordinal() + @runtime.requires_pjrt def get_local_ordinal(defval=0): """Retrieves the replication local ordinal of the current thread. @@ -174,6 +176,7 @@ def master_print(*args, fd=sys.stdout, local=False, flush=False): if is_master_ordinal(local=local): print(*args, file=fd, flush=flush) + @runtime.requires_pjrt def xla_device(n=None, devkind=None): """Returns a given instance of an XLA device. From 0ca70aff099035cb99ae3099e1410a6ad98b494a Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 23 Jun 2023 20:36:18 +0000 Subject: [PATCH 20/22] Remove extra `requires_pjrt` --- torch_xla/core/xla_model.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 4ee4cbf03ff4..d54f793a9b1a 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -99,7 +99,6 @@ def get_xla_supported_devices(devkind=None, max_devices=None): return kind_devices[:max_devices] if max_devices else kind_devices -@runtime.requires_pjrt def xrt_world_size(defval=1): """Retrieves the number of devices which is taking part of the replication. @@ -118,7 +117,6 @@ def xrt_world_size(defval=1): return runtime.world_size() -@runtime.requires_pjrt def get_ordinal(defval=0): """Retrieves the replication ordinal of the current thread. @@ -139,7 +137,6 @@ def get_ordinal(defval=0): return runtime.global_ordinal() -@runtime.requires_pjrt def get_local_ordinal(defval=0): """Retrieves the replication local ordinal of the current thread. @@ -177,7 +174,6 @@ def master_print(*args, fd=sys.stdout, local=False, flush=False): print(*args, file=fd, flush=flush) -@runtime.requires_pjrt def xla_device(n=None, devkind=None): """Returns a given instance of an XLA device. @@ -1062,7 +1058,6 @@ def xla_rendezvous(payload: bytes = b'', return [bytes(p.cpu().tolist()) for p in payloads] -@runtime.requires_pjrt def rendezvous(tag, payload=b'', replicas=[]): """Waits for all the mesh clients to reach the named rendezvous. 
From 6c5e5a1055666ac763bb5d19c1a76cc04cc933c3 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 23 Jun 2023 22:46:58 +0000 Subject: [PATCH 21/22] merge conflicts --- test/run_tests.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/run_tests.sh b/test/run_tests.sh index e93003808346..94e4fb636da3 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -134,8 +134,6 @@ function run_xla_op_tests { run_test "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_test "$CDIR/test_async_closures.py" - run_test "$CDIR/test_autocast.py" - run_test "$CDIR/test_xla_dist.py" run_test "$CDIR/test_profiler.py" run_test "$CDIR/test_ops.py" run_test "$CDIR/test_metrics.py" From f85edc0cb59a70f8da6b19b99280c21c45948c8d Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 23 Jun 2023 23:34:29 +0000 Subject: [PATCH 22/22] Add other autocast back --- test/run_tests.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/test/run_tests.sh b/test/run_tests.sh index 94e4fb636da3..cbc4ccf6d6df 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -134,6 +134,7 @@ function run_xla_op_tests { run_test "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_test "$CDIR/test_async_closures.py" + run_test "$CDIR/test_autocast.py" run_test "$CDIR/test_profiler.py" run_test "$CDIR/test_ops.py" run_test "$CDIR/test_metrics.py"
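
Taken together, the series above leaves PJRT as the only runtime path: xmp.spawn is decorated with @xr.requires_pjrt and forwards directly to pjrt.spawn, while xm.xrt_world_size(), xm.get_ordinal() and xm.xla_device() delegate to torch_xla.runtime. The sketch below is illustrative only and not part of the patches; the PJRT_DEVICE=CPU setting and the body of _mp_fn are assumptions for the example, and on a TPU VM the default-device selection added in torch_xla/runtime.py detects libtpu and sets PJRT_DEVICE=TPU automatically.

    # Minimal usage sketch under the PJRT-only runtime (illustrative, not from the patches).
    import os

    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp

    # Assumption for this example: run on local CPU devices. spawn() now requires PJRT.
    os.environ.setdefault('PJRT_DEVICE', 'CPU')


    def _mp_fn(index):
      device = xm.xla_device()          # delegates to torch_xla.runtime.xla_device()
      world_size = xm.xrt_world_size()  # delegates to torch_xla.runtime.world_size()
      ones = torch.ones((2, 3), device=device)
      result = xm.all_reduce('sum', ones)
      xm.mark_step()
      assert torch.all(result.cpu() == world_size)


    if __name__ == '__main__':
      xmp.spawn(_mp_fn, args=())

Code paths removed in this series (xla_dist, the XRT mesh service and rendezvous bindings, the XRT_* configuration variables) have no equivalent here; the PJRT runtime described in docs/pjrt.md is the supported path going forward.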