From 5fb8b20fe1b3e318742b86e04dd45db65dc9332c Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Mon, 25 Jul 2022 21:49:20 +0000 Subject: [PATCH 01/19] Configure TPU topology for pods --- torch_xla/experimental/pjrt.py | 26 ++++---------- torch_xla/experimental/tpu.py | 66 ++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 19 deletions(-) create mode 100644 torch_xla/experimental/tpu.py diff --git a/torch_xla/experimental/pjrt.py b/torch_xla/experimental/pjrt.py index 606998bcf743..31ba8bf0c745 100644 --- a/torch_xla/experimental/pjrt.py +++ b/torch_xla/experimental/pjrt.py @@ -11,6 +11,7 @@ import torch_xla.core.xla_env_vars as xenv import torch_xla.core.xla_model as xm import torch_xla.utils.utils as xu +from torch_xla.experimental import tpu _PJRT_ORDINALS = threading.local() @@ -46,19 +47,6 @@ def num_visible_tpu_chips(default: int = 4) -> int: return len(visible_devices.split(',')) if visible_devices else default -def configure_tpu_topology(rank: int, processes: int, base_port=8476) -> None: - """Sets default TPU topology environment variables for a single TPU host.""" - ports = list(range(base_port, base_port + processes)) - os.environ.setdefault(xenv.TPU_CHIPS_PER_PROCESS_BOUNDS, '1,1,1') - os.environ.setdefault(xenv.TPU_PROCESS_BOUNDS, '2,2,1') - os.environ.setdefault(xenv.TPU_PROCESS_ADDRESSES, - ','.join(f'localhost:{port}' for port in ports)) - - os.environ.setdefault(xenv.TPU_VISIBLE_DEVICES, str(rank)) - os.environ.setdefault(xenv.TPU_PROCESS_PORT, str(ports[rank])) - os.environ.setdefault(xenv.CLOUD_TPU_TASK_ID, str(rank)) - - def requires_pjrt(fn: FN) -> FN: """Wraps `fn` and checks if this process is using PjRt. @@ -141,13 +129,13 @@ def addressable_device_count() -> int: @requires_pjrt -def run_thread_per_device(rank: int, processes: int, +def run_thread_per_device(process: int, local_processes: int, fn: Callable[..., R]) -> Dict[int, R]: """Runs `fn` in a separate thread on each visible device. Args: - rank: rank of current process - processes: number of processes on this host + process: rank of current process + local_processes: number of processes on this host fn: Function to run on all devices Returns: @@ -155,7 +143,7 @@ def run_thread_per_device(rank: int, processes: int, result of calling `fn`. 
""" if device_type() == 'TPU': - configure_tpu_topology(rank, processes) + tpu.configure_topology(process, local_processes) xm.set_replication(xm.xla_device(), xm.get_xla_supported_devices()) threads = len(xm.get_xla_supported_devices()) @@ -165,8 +153,8 @@ def _thread_fn(fn, device_index): @functools.wraps(fn) def wrapper(*args, **kwargs): # Assumes same number of threads per process - set_global_ordinal(rank * threads + device_index) - set_local_ordinal(rank * threads + device_index) + set_global_ordinal(tpu.task_id() * threads + device_index) + set_local_ordinal(process * threads + device_index) return fn(*args, **kwargs) diff --git a/torch_xla/experimental/tpu.py b/torch_xla/experimental/tpu.py new file mode 100644 index 000000000000..52f423962a1e --- /dev/null +++ b/torch_xla/experimental/tpu.py @@ -0,0 +1,66 @@ +import cloud_tpu_client +import os +from typing import Optional, Iterable, Tuple +import numpy as np +import numpy.typing as npt +import requests +import yaml + +import torch_xla.utils.utils as xu +import torch_xla.core.xla_env_vars as xenv + +_GCE_METADATA_ROOT_URL = 'http://metadata.google.internal/computeMetadata/v1' + +MeshShape = Tuple[int, int, int] + +def _parse_mesh_shape(mesh: str) -> MeshShape: + dims = tuple(int(d) for d in mesh.split(',')) + if len(dims) != 3: + raise ValueError("Mesh shape '{}' should be length 3".format(mesh)) + + return dims + +def _multiple_mesh_shapes(mesh1: MeshShape, mesh2: MeshShape) -> MeshShape: + return tuple(d1 * d2 for d1, d2 in zip(mesh1, mesh2)) + +def _get_metadata(key: str) -> str: + path = os.path.join(_GCE_METADATA_ROOT_URL, 'instance/attributes', key) + resp = requests.get(path, headers={'Metadata-Flavor': 'Google'}) + resp.raise_for_status() + + return resp.text + +def task_id() -> Optional[int]: + return xu.getenv_as(xenv.CLOUD_TPU_TASK_ID, int) + +def get_tpu_env(): + metadata = _get_metadata('tpu-env') + + return yaml.load(metadata, yaml.Loader) + +def configure_topology(local_rank: int, local_world_size: int, base_port: int = 8476): + tpu_env = get_tpu_env() + + # Process bounds with 4 chips per process + default_process_bounds = _parse_mesh_shape(tpu_env[xenv.TPU_PROCESS_BOUNDS]) + chips_per_process = _parse_mesh_shape(tpu_env[xenv.TPU_CHIPS_PER_PROCESS_BOUNDS]) + + # Process bounds with 1 chip per process + process_bounds = _multiple_mesh_shapes(default_process_bounds, chips_per_process) + + os.environ.setdefault(xenv.TPU_CHIPS_PER_PROCESS_BOUNDS, '1,1,1') + os.environ.setdefault(xenv.TPU_PROCESS_BOUNDS, ','.join(str(dim) for dim in process_bounds)) + + # Assume each TPU has the same number of local processes with the same ports + worker_id = int(tpu_env['WORKER_ID']) + os.environ.setdefault(xenv.CLOUD_TPU_TASK_ID, str(worker_id * local_world_size + local_rank)) + + client = cloud_tpu_client.Client(tpu=tpu_env['NODE_ID']) + host_ips = [e['ipAddress'] for e in client.network_endpoints()] + + ports = list(range(base_port, base_port + local_world_size)) + process_endpoints = [','.join(f'{ip}:{port}' for port in ports) for ip in host_ips] + os.environ.setdefault(xenv.TPU_PROCESS_ADDRESSES, ','.join(process_endpoints)) + + os.environ.setdefault(xenv.TPU_VISIBLE_DEVICES, str(local_rank)) + os.environ.setdefault(xenv.TPU_PROCESS_PORT, str(ports[local_rank])) From 5d1a67e70fb14637eab73082a26278d574f48c77 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 29 Jul 2022 19:56:50 +0000 Subject: [PATCH 02/19] Use metadata for network endpoints --- torch_xla/experimental/tpu.py | 23 ++++++++++++++--------- 1 file changed, 14 
insertions(+), 9 deletions(-) diff --git a/torch_xla/experimental/tpu.py b/torch_xla/experimental/tpu.py index 52f423962a1e..c9ea84292789 100644 --- a/torch_xla/experimental/tpu.py +++ b/torch_xla/experimental/tpu.py @@ -1,8 +1,5 @@ -import cloud_tpu_client import os -from typing import Optional, Iterable, Tuple -import numpy as np -import numpy.typing as npt +from typing import Optional, List, Tuple import requests import yaml @@ -20,7 +17,7 @@ def _parse_mesh_shape(mesh: str) -> MeshShape: return dims -def _multiple_mesh_shapes(mesh1: MeshShape, mesh2: MeshShape) -> MeshShape: +def _multiply_mesh_shapes(mesh1: MeshShape, mesh2: MeshShape) -> MeshShape: return tuple(d1 * d2 for d1, d2 in zip(mesh1, mesh2)) def _get_metadata(key: str) -> str: @@ -38,6 +35,15 @@ def get_tpu_env(): return yaml.load(metadata, yaml.Loader) +def get_worker_ips() -> List[str]: + metadata = _get_metadata('worker-network-endpoints') + + # Workers have format 'hostname:uid:ip,hostname:uid:ip,...' + workers = metadata.split(',') + ips = [worker.split(':')[2] for worker in workers] + + return ips if len(ips) > 1 else ['localhost'] + def configure_topology(local_rank: int, local_world_size: int, base_port: int = 8476): tpu_env = get_tpu_env() @@ -46,7 +52,7 @@ def configure_topology(local_rank: int, local_world_size: int, base_port: int = chips_per_process = _parse_mesh_shape(tpu_env[xenv.TPU_CHIPS_PER_PROCESS_BOUNDS]) # Process bounds with 1 chip per process - process_bounds = _multiple_mesh_shapes(default_process_bounds, chips_per_process) + process_bounds = _multiply_mesh_shapes(default_process_bounds, chips_per_process) os.environ.setdefault(xenv.TPU_CHIPS_PER_PROCESS_BOUNDS, '1,1,1') os.environ.setdefault(xenv.TPU_PROCESS_BOUNDS, ','.join(str(dim) for dim in process_bounds)) @@ -55,11 +61,10 @@ def configure_topology(local_rank: int, local_world_size: int, base_port: int = worker_id = int(tpu_env['WORKER_ID']) os.environ.setdefault(xenv.CLOUD_TPU_TASK_ID, str(worker_id * local_world_size + local_rank)) - client = cloud_tpu_client.Client(tpu=tpu_env['NODE_ID']) - host_ips = [e['ipAddress'] for e in client.network_endpoints()] + worker_ips = get_worker_ips() ports = list(range(base_port, base_port + local_world_size)) - process_endpoints = [','.join(f'{ip}:{port}' for port in ports) for ip in host_ips] + process_endpoints = [','.join(f'{ip}:{port}' for port in ports) for ip in worker_ips] os.environ.setdefault(xenv.TPU_PROCESS_ADDRESSES, ','.join(process_endpoints)) os.environ.setdefault(xenv.TPU_VISIBLE_DEVICES, str(local_rank)) From 653d75753011bf616260c94b95d0be4b7ef4bfb9 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 29 Jul 2022 19:57:11 +0000 Subject: [PATCH 03/19] Add TPU tests --- test/pjrt/test_experimental_pjrt_tpu.py | 151 ++++++++++++++++++++++++ 1 file changed, 151 insertions(+) create mode 100644 test/pjrt/test_experimental_pjrt_tpu.py diff --git a/test/pjrt/test_experimental_pjrt_tpu.py b/test/pjrt/test_experimental_pjrt_tpu.py new file mode 100644 index 000000000000..677d107efe04 --- /dev/null +++ b/test/pjrt/test_experimental_pjrt_tpu.py @@ -0,0 +1,151 @@ +import concurrent.futures +import functools +import itertools +import os +import time +import requests + +import torch +import torch_xla +from absl.testing import absltest, parameterized +import torch_xla.core.xla_env_vars as xenv +import torch_xla.core.xla_model as xm +from torch_xla.experimental import pjrt +from torch_xla.experimental import tpu + + +def _get_real_devices(): + """Wraps `_xla_get_devices` to make it pickle-able""" + 
return torch_xla._XLAC._xla_get_devices() + +def _get_all_real_devices(): + """Wraps `_xla_get_all_devices` to make it pickle-able""" + return torch_xla._XLAC._xla_get_all_devices() + +class TestExperimentalPjrtTpu(parameterized.TestCase): + def setUp(self): + time.sleep(1) + pjrt.set_device_type('TPU') + + os.environ.pop(xenv.TPU_VISIBLE_DEVICES, None) + os.environ.pop(xenv.TPU_PROCESS_BOUNDS, None) + + try: + tpu_env = tpu.get_tpu_env() + self.accelerator_type = tpu_env['ACCELERATOR_TYPE'] + except requests.HTTPError as e: + raise EnvironmentError('Failed to get TPU metadata. Are you running on a TPU?') from e + + # TODO: assert ComputationClient is not initialized + # The main process must not initialize the ComputationClient, otherwise + # sub-processes will not be able to initialize the client witht the correct + # settings. + + def test_xla_devices_multiprocess(self): + accelerator_devices = { + 'v3-8': { + 0: { + 0: torch.device('xla:0'), + 1: torch.device('xla:1'), + }, + 1: { + 0: torch.device('xla:0'), + 1: torch.device('xla:1'), + }, + 2: { + 0: torch.device('xla:0'), + 1: torch.device('xla:1'), + }, + 3: { + 0: torch.device('xla:0'), + 1: torch.device('xla:1'), + }, + }, + 'v4-8': { + 0: {0: torch.device('xla:0')}, + 1: {0: torch.device('xla:0')}, + 2: {0: torch.device('xla:0')}, + 3: {0: torch.device('xla:0')}, + }, + } + + if self.accelerator_type not in accelerator_devices: + raise NotImplementedError('Test not implemented for {}'.format(self.accelerator_type)) + expected = accelerator_devices[self.accelerator_type] + + devices_per_process = pjrt.run_multiprocess(xm.xla_device) + self.assertDictEqual(devices_per_process, expected) + + def test_real_devices_multiprocess(self): + accelerator_devices = { + 'v3-8': { + 0: { + 0: ['TPU:0', 'TPU:1'], + 1: ['TPU:0', 'TPU:1'], + }, + 1: { + 0: ['TPU:2', 'TPU:3'], + 1: ['TPU:2', 'TPU:3'], + }, + 2: { + 0: ['TPU:4', 'TPU:5'], + 1: ['TPU:4', 'TPU:5'], + }, + 3: { + 0: ['TPU:6', 'TPU:7'], + 1: ['TPU:6', 'TPU:7'], + }, + }, + 'v4-8': { + 0: {0: ['TPU:0']}, + 1: {0: ['TPU:2']}, + 2: {0: ['TPU:3']}, + 3: {0: ['TPU:1']}, + }, + } + + if self.accelerator_type not in accelerator_devices: + raise NotImplementedError('Test not implemented for {}'.format(self.accelerator_type)) + expected = accelerator_devices[self.accelerator_type] + + + devices_per_process = pjrt.run_multiprocess(_get_real_devices) + self.assertDictEqual(devices_per_process, expected) + + all_devices = sorted(itertools.chain.from_iterable(process_devices[0] for process_devices in expected.values())) + expected_all_devices = { + rank: {thread: all_devices for thread in expected[0].keys()} for rank in expected.keys() + } + + all_devices_per_process = pjrt.run_multiprocess(_get_all_real_devices) + self.assertDictEqual(all_devices_per_process, expected_all_devices) + + def test_single_process_all_chips(self): + pass + + def test_single_process_one_chip(self): + accelerator_devices = { + 'v3-8': { + 0: { + 0: torch.device('xla:0'), + 1: torch.device('xla:1'), + }, + }, + 'v4-8': { + 0: {0: torch.device('xla:0')}, + }, + } + + if self.accelerator_type not in accelerator_devices: + raise NotImplementedError('Test not implemented for {}'.format(self.accelerator_type)) + expected = accelerator_devices[self.accelerator_type] + + os.environ[xenv.TPU_VISIBLE_DEVICES] = '0' + os.environ[xenv.TPU_PROCESS_BOUNDS] = '1,1,1' + + devices = pjrt.run_multiprocess(xm.xla_device) + self.assertDictEqual(devices, expected) + + +if __name__ == '__main__': + absltest.main() From 
fb5155f2d020979de718fa08cc1bf54ff8504631 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 29 Jul 2022 21:20:51 +0000 Subject: [PATCH 04/19] Fix computing num local processes. --- test/pjrt/test_experimental_pjrt_tpu.py | 43 ++++++++++++++++++++++--- torch_xla/experimental/pjrt.py | 9 +----- torch_xla/experimental/tpu.py | 16 +++++++++ 3 files changed, 56 insertions(+), 12 deletions(-) diff --git a/test/pjrt/test_experimental_pjrt_tpu.py b/test/pjrt/test_experimental_pjrt_tpu.py index 677d107efe04..a2b0bf4de624 100644 --- a/test/pjrt/test_experimental_pjrt_tpu.py +++ b/test/pjrt/test_experimental_pjrt_tpu.py @@ -24,7 +24,6 @@ def _get_all_real_devices(): class TestExperimentalPjrtTpu(parameterized.TestCase): def setUp(self): - time.sleep(1) pjrt.set_device_type('TPU') os.environ.pop(xenv.TPU_VISIBLE_DEVICES, None) @@ -120,10 +119,31 @@ def test_real_devices_multiprocess(self): all_devices_per_process = pjrt.run_multiprocess(_get_all_real_devices) self.assertDictEqual(all_devices_per_process, expected_all_devices) - def test_single_process_all_chips(self): - pass + def test_xla_devices_single_process_all_chips(self): + accelerator_devices = { + 'v3-8': { + 0: { + i: torch.device(f'xla:{i}') for i in range(8) + }, + }, + 'v4-8': { + 0: { + i: torch.device(f'xla:{i}') for i in range(4) + }, + }, + } + + if self.accelerator_type not in accelerator_devices: + raise NotImplementedError('Test not implemented for {}'.format(self.accelerator_type)) + expected = accelerator_devices[self.accelerator_type] + + os.environ[xenv.TPU_VISIBLE_DEVICES] = '0,1,2,3' + os.environ[xenv.TPU_PROCESS_BOUNDS] = '1,1,1' + + devices = pjrt.run_multiprocess(xm.xla_device) + self.assertDictEqual(devices, expected) - def test_single_process_one_chip(self): + def test_xla_devices_single_process_one_chip(self): accelerator_devices = { 'v3-8': { 0: { @@ -146,6 +166,21 @@ def test_single_process_one_chip(self): devices = pjrt.run_multiprocess(xm.xla_device) self.assertDictEqual(devices, expected) + def test_default_xla_devices(self): + accelerator_num_devices = { + 'v3-8': 8, + 'v4-8': 4, + } + + if self.accelerator_type not in accelerator_num_devices: + raise NotImplementedError('Test not implemented for {}'.format(self.accelerator_type)) + expected_num_devices = accelerator_num_devices[self.accelerator_type] + + with concurrent.futures.ProcessPoolExecutor(max_workers=1) as e: + f = e.submit(xm.get_xla_supported_devices, 'TPU') + devices = [torch.device(d) for d in f.result()] + + self.assertListEqual(devices, [torch.device(f'xla:{i}') for i in range(expected_num_devices)]) if __name__ == '__main__': absltest.main() diff --git a/torch_xla/experimental/pjrt.py b/torch_xla/experimental/pjrt.py index 31ba8bf0c745..fcf316a9fc96 100644 --- a/torch_xla/experimental/pjrt.py +++ b/torch_xla/experimental/pjrt.py @@ -40,13 +40,6 @@ def using_pjrt() -> bool: return device_type() is not None -def num_visible_tpu_chips(default: int = 4) -> int: - """Returns number of TPU chips visible to current process.""" - visible_devices = xu.getenv_as(xenv.TPU_VISIBLE_DEVICES, str) - - return len(visible_devices.split(',')) if visible_devices else default - - def requires_pjrt(fn: FN) -> FN: """Wraps `fn` and checks if this process is using PjRt. @@ -187,7 +180,7 @@ def run_multiprocess(fn: Callable[..., R], *args, return_value is the result of calling `fn`. 
""" if device_type() == 'TPU': - processes = num_visible_tpu_chips() + processes = tpu.num_local_processes() else: processes = 1 diff --git a/torch_xla/experimental/tpu.py b/torch_xla/experimental/tpu.py index c9ea84292789..3a9b330a7848 100644 --- a/torch_xla/experimental/tpu.py +++ b/torch_xla/experimental/tpu.py @@ -1,3 +1,8 @@ +import functools + + +import functools +import operator import os from typing import Optional, List, Tuple import requests @@ -20,6 +25,9 @@ def _parse_mesh_shape(mesh: str) -> MeshShape: def _multiply_mesh_shapes(mesh1: MeshShape, mesh2: MeshShape) -> MeshShape: return tuple(d1 * d2 for d1, d2 in zip(mesh1, mesh2)) +def _mesh_size(mesh: MeshShape) -> int: + return functools.reduce(operator.mul, mesh) + def _get_metadata(key: str) -> str: path = os.path.join(_GCE_METADATA_ROOT_URL, 'instance/attributes', key) resp = requests.get(path, headers={'Metadata-Flavor': 'Google'}) @@ -27,6 +35,14 @@ def _get_metadata(key: str) -> str: return resp.text +def num_processes(default: int = 4) -> Optional[int]: + process_bounds = xu.getenv_as(xenv.TPU_PROCESS_BOUNDS, str) + + return _mesh_size(_parse_mesh_shape(process_bounds)) if process_bounds else default + +def num_local_processes() -> Optional[int]: + return min(4, num_processes()) + def task_id() -> Optional[int]: return xu.getenv_as(xenv.CLOUD_TPU_TASK_ID, int) From 22706096cbde066684d7fe9cda29eaac574f34b1 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Mon, 1 Aug 2022 17:30:13 +0000 Subject: [PATCH 05/19] Add unit tests for tpu.py --- test/pjrt/test_experimental_tpu.py | 160 +++++++++++++++++++++++++++++ torch_xla/experimental/tpu.py | 8 +- 2 files changed, 163 insertions(+), 5 deletions(-) create mode 100644 test/pjrt/test_experimental_tpu.py diff --git a/test/pjrt/test_experimental_tpu.py b/test/pjrt/test_experimental_tpu.py new file mode 100644 index 000000000000..cf9a448f1b25 --- /dev/null +++ b/test/pjrt/test_experimental_tpu.py @@ -0,0 +1,160 @@ +from cmath import exp +import os +import textwrap + +from absl.testing import absltest, parameterized +import torch_xla.core.xla_env_vars as xenv +from torch_xla.experimental import tpu + +from unittest import mock + +class TestExperimentalPjrtTpu(parameterized.TestCase): + @parameterized.named_parameters( + ('default_one_host', None, 4), + ('one_process_one_host', '1,1,1', 1), + ('multi_process_one_host', '2,2,1', 4), + ('multi_process_v4-16', '2,2,2', 8), + ('multi_process_v4-32', '2,2,4', 16), + ) + def test_num_processes(self, process_bounds, expected): + envs = {xenv.TPU_PROCESS_BOUNDS: process_bounds} if process_bounds else {} + with mock.patch.dict(os.environ, envs, clear=True): + n = tpu.num_processes() + + self.assertEqual(n, expected) + + @parameterized.named_parameters( + ('default_one_host', None, 4), + ('one_process_one_host', '1,1,1', 1), + ('multi_process_one_host', '2,2,1', 4), + ('multi_process_v4-16', '2,2,2', 4), + ('multi_process_v4-32', '2,2,4', 4), + ) + def test_num_local_processes(self, process_bounds, expected): + envs = {xenv.TPU_PROCESS_BOUNDS: process_bounds} if process_bounds else {} + with mock.patch.dict(os.environ, envs, clear=True): + n = tpu.num_local_processes() + + self.assertEqual(n, expected) + + + @parameterized.parameters( + (None, None), + ('0', 0), + ('1', 1), + ('15', 15) + ) + def test_task_id(self, task_id, expected): + envs = {xenv.CLOUD_TPU_TASK_ID: task_id} if task_id else {} + with mock.patch.dict(os.environ, envs, clear=True): + i = tpu.task_id() + + self.assertEqual(i, expected) + + def test_tpu_env(self): + 
tpu_env_yaml = textwrap.dedent(""" + ACCELERATOR_TYPE: 'v4-16' + CHIPS_PER_HOST_BOUNDS: '2,2,1' + HOST_BOUNDS: '1,1,2' + TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1' + TPU_PROCESS_BOUNDS: '1,1,2' + ZONE: 'us-central2-b' + """) + + with mock.patch.object(tpu, '_get_metadata', return_value=tpu_env_yaml): + tpu_env = tpu.get_tpu_env() + + self.assertDictEqual(tpu_env, { + 'ACCELERATOR_TYPE': 'v4-16', + 'CHIPS_PER_HOST_BOUNDS': '2,2,1', + 'HOST_BOUNDS': '1,1,2', + 'TPU_CHIPS_PER_PROCESS_BOUNDS': '2,2,1', + 'TPU_PROCESS_BOUNDS': '1,1,2', + 'ZONE': 'us-central2-b', + }) + + @parameterized.named_parameters( + ('one_host', 't1v-n-ea9d3291-w-0:12345:10.130.0.31', ['localhost']), + ( + 'four_hosts', + 't1v-n-0f996b37-w-0:12345:10.130.0.26,t1v-n-0f996b37-w-1:12346:10.130.0.27,t1v-n-0f996b37-w-2:12347:10.130.0.25,t1v-n-0f996b37-w-3:12348:10.130.0.28', + ['10.130.0.26', '10.130.0.27', '10.130.0.25', '10.130.0.28'], + ), + ) + def test_get_worker_ips(self, worker_network_endpoints, expected): + with mock.patch.object(tpu, '_get_metadata', return_value=worker_network_endpoints): + worker_ips = tpu.get_worker_ips() + + self.assertListEqual(worker_ips, expected) + + @parameterized.named_parameters( + ( + 'v4-8_process_0', + { + xenv.TPU_PROCESS_BOUNDS: '1,1,1', + xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1', + 'WORKER_ID': '0' + }, + ['localhost'], + 0, + 4, + { + xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '1,1,1', + xenv.TPU_PROCESS_BOUNDS: '2,2,1', + xenv.CLOUD_TPU_TASK_ID: '0', + xenv.TPU_PROCESS_PORT: '8476', + xenv.TPU_PROCESS_ADDRESSES: 'localhost:8476,localhost:8477,localhost:8478,localhost:8479', + xenv.TPU_VISIBLE_DEVICES: '0', + } + ), + ( + 'v4-8_process_3', + { + xenv.TPU_PROCESS_BOUNDS: '1,1,1', + xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1', + 'WORKER_ID': '0' + }, + ['localhost'], + 3, + 4, + { + xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '1,1,1', + xenv.TPU_PROCESS_BOUNDS: '2,2,1', + xenv.CLOUD_TPU_TASK_ID: '3', + xenv.TPU_PROCESS_PORT: '8479', + xenv.TPU_PROCESS_ADDRESSES: 'localhost:8476,localhost:8477,localhost:8478,localhost:8479', + xenv.TPU_VISIBLE_DEVICES: '3', + } + ), + ( + 'v4-16_worker_1_process_0', + { + xenv.TPU_PROCESS_BOUNDS: '1,1,2', + xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1', + 'WORKER_ID': '1' + }, + ['10.130.0.31', '10.130.0.30'], + 0, + 4, + { + xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '1,1,1', + xenv.TPU_PROCESS_BOUNDS: '2,2,2', + xenv.CLOUD_TPU_TASK_ID: '4', + xenv.TPU_PROCESS_PORT: '8476', + xenv.TPU_PROCESS_ADDRESSES: '10.130.0.31:8476,10.130.0.31:8477,10.130.0.31:8478,10.130.0.31:8479,10.130.0.30:8476,10.130.0.30:8477,10.130.0.30:8478,10.130.0.30:8479', + xenv.TPU_VISIBLE_DEVICES: '0', + } + ), + ) + def test_configure_tpu_topology(self, tpu_env, worker_ips, local_rank, local_world_size, expected): + with mock.patch.object(tpu, 'get_tpu_env', return_value=tpu_env), \ + mock.patch.object(tpu, 'get_worker_ips', return_value=worker_ips), \ + mock.patch.dict(os.environ, clear=True) as mock_env: + + tpu.configure_topology(local_rank, local_world_size) + + self.assertDictContainsSubset(expected, mock_env) + + +if __name__ == '__main__': + absltest.main() diff --git a/torch_xla/experimental/tpu.py b/torch_xla/experimental/tpu.py index 3a9b330a7848..2e6a3c846480 100644 --- a/torch_xla/experimental/tpu.py +++ b/torch_xla/experimental/tpu.py @@ -1,10 +1,7 @@ -import functools - - import functools import operator import os -from typing import Optional, List, Tuple +from typing import Dict, Optional, List, Tuple import requests import yaml @@ -41,12 +38,13 @@ def num_processes(default: int = 4) -> 
Optional[int]: return _mesh_size(_parse_mesh_shape(process_bounds)) if process_bounds else default def num_local_processes() -> Optional[int]: + # Don't create more processes than local chips (4) return min(4, num_processes()) def task_id() -> Optional[int]: return xu.getenv_as(xenv.CLOUD_TPU_TASK_ID, int) -def get_tpu_env(): +def get_tpu_env() -> Dict[str, str]: metadata = _get_metadata('tpu-env') return yaml.load(metadata, yaml.Loader) From 0ef59d0f97cf187e6f5fcf76f619173d9972ed74 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Mon, 1 Aug 2022 17:31:40 +0000 Subject: [PATCH 06/19] Remove unused imports --- test/pjrt/test_experimental_pjrt_tpu.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/pjrt/test_experimental_pjrt_tpu.py b/test/pjrt/test_experimental_pjrt_tpu.py index a2b0bf4de624..583965563abc 100644 --- a/test/pjrt/test_experimental_pjrt_tpu.py +++ b/test/pjrt/test_experimental_pjrt_tpu.py @@ -1,8 +1,6 @@ import concurrent.futures -import functools import itertools import os -import time import requests import torch From cf64ce3300082dd85486f7edd128c668dd590491 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Mon, 1 Aug 2022 17:32:24 +0000 Subject: [PATCH 07/19] formatting --- test/pjrt/test_experimental_pjrt_tpu.py | 172 +++++++++++++---------- test/pjrt/test_experimental_tpu.py | 176 ++++++++++++------------ torch_xla/experimental/tpu.py | 33 ++++- 3 files changed, 212 insertions(+), 169 deletions(-) diff --git a/test/pjrt/test_experimental_pjrt_tpu.py b/test/pjrt/test_experimental_pjrt_tpu.py index 583965563abc..d15f4671c93f 100644 --- a/test/pjrt/test_experimental_pjrt_tpu.py +++ b/test/pjrt/test_experimental_pjrt_tpu.py @@ -16,11 +16,14 @@ def _get_real_devices(): """Wraps `_xla_get_devices` to make it pickle-able""" return torch_xla._XLAC._xla_get_devices() + def _get_all_real_devices(): """Wraps `_xla_get_all_devices` to make it pickle-able""" return torch_xla._XLAC._xla_get_all_devices() + class TestExperimentalPjrtTpu(parameterized.TestCase): + def setUp(self): pjrt.set_device_type('TPU') @@ -31,7 +34,8 @@ def setUp(self): tpu_env = tpu.get_tpu_env() self.accelerator_type = tpu_env['ACCELERATOR_TYPE'] except requests.HTTPError as e: - raise EnvironmentError('Failed to get TPU metadata. Are you running on a TPU?') from e + raise EnvironmentError( + 'Failed to get TPU metadata. 
Are you running on a TPU?') from e # TODO: assert ComputationClient is not initialized # The main process must not initialize the ComputationClient, otherwise @@ -40,34 +44,43 @@ def setUp(self): def test_xla_devices_multiprocess(self): accelerator_devices = { - 'v3-8': { - 0: { - 0: torch.device('xla:0'), - 1: torch.device('xla:1'), - }, - 1: { - 0: torch.device('xla:0'), - 1: torch.device('xla:1'), - }, - 2: { - 0: torch.device('xla:0'), - 1: torch.device('xla:1'), + 'v3-8': { + 0: { + 0: torch.device('xla:0'), + 1: torch.device('xla:1'), + }, + 1: { + 0: torch.device('xla:0'), + 1: torch.device('xla:1'), + }, + 2: { + 0: torch.device('xla:0'), + 1: torch.device('xla:1'), + }, + 3: { + 0: torch.device('xla:0'), + 1: torch.device('xla:1'), + }, }, - 3: { - 0: torch.device('xla:0'), - 1: torch.device('xla:1'), + 'v4-8': { + 0: { + 0: torch.device('xla:0') + }, + 1: { + 0: torch.device('xla:0') + }, + 2: { + 0: torch.device('xla:0') + }, + 3: { + 0: torch.device('xla:0') + }, }, - }, - 'v4-8': { - 0: {0: torch.device('xla:0')}, - 1: {0: torch.device('xla:0')}, - 2: {0: torch.device('xla:0')}, - 3: {0: torch.device('xla:0')}, - }, } if self.accelerator_type not in accelerator_devices: - raise NotImplementedError('Test not implemented for {}'.format(self.accelerator_type)) + raise NotImplementedError('Test not implemented for {}'.format( + self.accelerator_type)) expected = accelerator_devices[self.accelerator_type] devices_per_process = pjrt.run_multiprocess(xm.xla_device) @@ -75,43 +88,54 @@ def test_xla_devices_multiprocess(self): def test_real_devices_multiprocess(self): accelerator_devices = { - 'v3-8': { - 0: { - 0: ['TPU:0', 'TPU:1'], - 1: ['TPU:0', 'TPU:1'], + 'v3-8': { + 0: { + 0: ['TPU:0', 'TPU:1'], + 1: ['TPU:0', 'TPU:1'], + }, + 1: { + 0: ['TPU:2', 'TPU:3'], + 1: ['TPU:2', 'TPU:3'], + }, + 2: { + 0: ['TPU:4', 'TPU:5'], + 1: ['TPU:4', 'TPU:5'], + }, + 3: { + 0: ['TPU:6', 'TPU:7'], + 1: ['TPU:6', 'TPU:7'], + }, }, - 1: { - 0: ['TPU:2', 'TPU:3'], - 1: ['TPU:2', 'TPU:3'], + 'v4-8': { + 0: { + 0: ['TPU:0'] + }, + 1: { + 0: ['TPU:2'] + }, + 2: { + 0: ['TPU:3'] + }, + 3: { + 0: ['TPU:1'] + }, }, - 2: { - 0: ['TPU:4', 'TPU:5'], - 1: ['TPU:4', 'TPU:5'], - }, - 3: { - 0: ['TPU:6', 'TPU:7'], - 1: ['TPU:6', 'TPU:7'], - }, - }, - 'v4-8': { - 0: {0: ['TPU:0']}, - 1: {0: ['TPU:2']}, - 2: {0: ['TPU:3']}, - 3: {0: ['TPU:1']}, - }, } if self.accelerator_type not in accelerator_devices: - raise NotImplementedError('Test not implemented for {}'.format(self.accelerator_type)) + raise NotImplementedError('Test not implemented for {}'.format( + self.accelerator_type)) expected = accelerator_devices[self.accelerator_type] - devices_per_process = pjrt.run_multiprocess(_get_real_devices) self.assertDictEqual(devices_per_process, expected) - all_devices = sorted(itertools.chain.from_iterable(process_devices[0] for process_devices in expected.values())) + all_devices = sorted( + itertools.chain.from_iterable( + process_devices[0] for process_devices in expected.values())) expected_all_devices = { - rank: {thread: all_devices for thread in expected[0].keys()} for rank in expected.keys() + rank: {thread: all_devices for thread in expected[0].keys() + } for rank in expected.keys() } all_devices_per_process = pjrt.run_multiprocess(_get_all_real_devices) @@ -119,20 +143,17 @@ def test_real_devices_multiprocess(self): def test_xla_devices_single_process_all_chips(self): accelerator_devices = { - 'v3-8': { - 0: { - i: torch.device(f'xla:{i}') for i in range(8) + 'v3-8': { + 0: {i: torch.device(f'xla:{i}') for i in 
range(8)}, }, - }, - 'v4-8': { - 0: { - i: torch.device(f'xla:{i}') for i in range(4) + 'v4-8': { + 0: {i: torch.device(f'xla:{i}') for i in range(4)}, }, - }, } if self.accelerator_type not in accelerator_devices: - raise NotImplementedError('Test not implemented for {}'.format(self.accelerator_type)) + raise NotImplementedError('Test not implemented for {}'.format( + self.accelerator_type)) expected = accelerator_devices[self.accelerator_type] os.environ[xenv.TPU_VISIBLE_DEVICES] = '0,1,2,3' @@ -143,19 +164,22 @@ def test_xla_devices_single_process_all_chips(self): def test_xla_devices_single_process_one_chip(self): accelerator_devices = { - 'v3-8': { - 0: { - 0: torch.device('xla:0'), - 1: torch.device('xla:1'), + 'v3-8': { + 0: { + 0: torch.device('xla:0'), + 1: torch.device('xla:1'), + }, + }, + 'v4-8': { + 0: { + 0: torch.device('xla:0') + }, }, - }, - 'v4-8': { - 0: {0: torch.device('xla:0')}, - }, } if self.accelerator_type not in accelerator_devices: - raise NotImplementedError('Test not implemented for {}'.format(self.accelerator_type)) + raise NotImplementedError('Test not implemented for {}'.format( + self.accelerator_type)) expected = accelerator_devices[self.accelerator_type] os.environ[xenv.TPU_VISIBLE_DEVICES] = '0' @@ -166,19 +190,23 @@ def test_xla_devices_single_process_one_chip(self): def test_default_xla_devices(self): accelerator_num_devices = { - 'v3-8': 8, - 'v4-8': 4, + 'v3-8': 8, + 'v4-8': 4, } if self.accelerator_type not in accelerator_num_devices: - raise NotImplementedError('Test not implemented for {}'.format(self.accelerator_type)) + raise NotImplementedError('Test not implemented for {}'.format( + self.accelerator_type)) expected_num_devices = accelerator_num_devices[self.accelerator_type] with concurrent.futures.ProcessPoolExecutor(max_workers=1) as e: f = e.submit(xm.get_xla_supported_devices, 'TPU') devices = [torch.device(d) for d in f.result()] - self.assertListEqual(devices, [torch.device(f'xla:{i}') for i in range(expected_num_devices)]) + self.assertListEqual( + devices, + [torch.device(f'xla:{i}') for i in range(expected_num_devices)]) + if __name__ == '__main__': absltest.main() diff --git a/test/pjrt/test_experimental_tpu.py b/test/pjrt/test_experimental_tpu.py index cf9a448f1b25..bd21cba304ac 100644 --- a/test/pjrt/test_experimental_tpu.py +++ b/test/pjrt/test_experimental_tpu.py @@ -8,13 +8,15 @@ from unittest import mock + class TestExperimentalPjrtTpu(parameterized.TestCase): + @parameterized.named_parameters( - ('default_one_host', None, 4), - ('one_process_one_host', '1,1,1', 1), - ('multi_process_one_host', '2,2,1', 4), - ('multi_process_v4-16', '2,2,2', 8), - ('multi_process_v4-32', '2,2,4', 16), + ('default_one_host', None, 4), + ('one_process_one_host', '1,1,1', 1), + ('multi_process_one_host', '2,2,1', 4), + ('multi_process_v4-16', '2,2,2', 8), + ('multi_process_v4-32', '2,2,4', 16), ) def test_num_processes(self, process_bounds, expected): envs = {xenv.TPU_PROCESS_BOUNDS: process_bounds} if process_bounds else {} @@ -24,11 +26,11 @@ def test_num_processes(self, process_bounds, expected): self.assertEqual(n, expected) @parameterized.named_parameters( - ('default_one_host', None, 4), - ('one_process_one_host', '1,1,1', 1), - ('multi_process_one_host', '2,2,1', 4), - ('multi_process_v4-16', '2,2,2', 4), - ('multi_process_v4-32', '2,2,4', 4), + ('default_one_host', None, 4), + ('one_process_one_host', '1,1,1', 1), + ('multi_process_one_host', '2,2,1', 4), + ('multi_process_v4-16', '2,2,2', 4), + ('multi_process_v4-32', '2,2,4', 4), ) def 
test_num_local_processes(self, process_bounds, expected): envs = {xenv.TPU_PROCESS_BOUNDS: process_bounds} if process_bounds else {} @@ -37,13 +39,7 @@ def test_num_local_processes(self, process_bounds, expected): self.assertEqual(n, expected) - - @parameterized.parameters( - (None, None), - ('0', 0), - ('1', 1), - ('15', 15) - ) + @parameterized.parameters((None, None), ('0', 0), ('1', 1), ('15', 15)) def test_task_id(self, task_id, expected): envs = {xenv.CLOUD_TPU_TASK_ID: task_id} if task_id else {} with mock.patch.dict(os.environ, envs, clear=True): @@ -64,89 +60,89 @@ def test_tpu_env(self): with mock.patch.object(tpu, '_get_metadata', return_value=tpu_env_yaml): tpu_env = tpu.get_tpu_env() - self.assertDictEqual(tpu_env, { - 'ACCELERATOR_TYPE': 'v4-16', - 'CHIPS_PER_HOST_BOUNDS': '2,2,1', - 'HOST_BOUNDS': '1,1,2', - 'TPU_CHIPS_PER_PROCESS_BOUNDS': '2,2,1', - 'TPU_PROCESS_BOUNDS': '1,1,2', - 'ZONE': 'us-central2-b', - }) + self.assertDictEqual( + tpu_env, { + 'ACCELERATOR_TYPE': 'v4-16', + 'CHIPS_PER_HOST_BOUNDS': '2,2,1', + 'HOST_BOUNDS': '1,1,2', + 'TPU_CHIPS_PER_PROCESS_BOUNDS': '2,2,1', + 'TPU_PROCESS_BOUNDS': '1,1,2', + 'ZONE': 'us-central2-b', + }) @parameterized.named_parameters( - ('one_host', 't1v-n-ea9d3291-w-0:12345:10.130.0.31', ['localhost']), - ( - 'four_hosts', - 't1v-n-0f996b37-w-0:12345:10.130.0.26,t1v-n-0f996b37-w-1:12346:10.130.0.27,t1v-n-0f996b37-w-2:12347:10.130.0.25,t1v-n-0f996b37-w-3:12348:10.130.0.28', - ['10.130.0.26', '10.130.0.27', '10.130.0.25', '10.130.0.28'], - ), + ('one_host', 't1v-n-ea9d3291-w-0:12345:10.130.0.31', ['localhost']), + ( + 'four_hosts', + 't1v-n-0f996b37-w-0:12345:10.130.0.26,t1v-n-0f996b37-w-1:12346:10.130.0.27,t1v-n-0f996b37-w-2:12347:10.130.0.25,t1v-n-0f996b37-w-3:12348:10.130.0.28', + ['10.130.0.26', '10.130.0.27', '10.130.0.25', '10.130.0.28'], + ), ) def test_get_worker_ips(self, worker_network_endpoints, expected): - with mock.patch.object(tpu, '_get_metadata', return_value=worker_network_endpoints): + with mock.patch.object( + tpu, '_get_metadata', return_value=worker_network_endpoints): worker_ips = tpu.get_worker_ips() self.assertListEqual(worker_ips, expected) @parameterized.named_parameters( - ( - 'v4-8_process_0', - { - xenv.TPU_PROCESS_BOUNDS: '1,1,1', - xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1', - 'WORKER_ID': '0' - }, - ['localhost'], - 0, - 4, - { - xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '1,1,1', - xenv.TPU_PROCESS_BOUNDS: '2,2,1', - xenv.CLOUD_TPU_TASK_ID: '0', - xenv.TPU_PROCESS_PORT: '8476', - xenv.TPU_PROCESS_ADDRESSES: 'localhost:8476,localhost:8477,localhost:8478,localhost:8479', - xenv.TPU_VISIBLE_DEVICES: '0', - } - ), - ( - 'v4-8_process_3', - { - xenv.TPU_PROCESS_BOUNDS: '1,1,1', - xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1', - 'WORKER_ID': '0' - }, - ['localhost'], - 3, - 4, - { - xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '1,1,1', - xenv.TPU_PROCESS_BOUNDS: '2,2,1', - xenv.CLOUD_TPU_TASK_ID: '3', - xenv.TPU_PROCESS_PORT: '8479', - xenv.TPU_PROCESS_ADDRESSES: 'localhost:8476,localhost:8477,localhost:8478,localhost:8479', - xenv.TPU_VISIBLE_DEVICES: '3', - } - ), - ( - 'v4-16_worker_1_process_0', - { - xenv.TPU_PROCESS_BOUNDS: '1,1,2', - xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1', - 'WORKER_ID': '1' - }, - ['10.130.0.31', '10.130.0.30'], - 0, - 4, - { - xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '1,1,1', - xenv.TPU_PROCESS_BOUNDS: '2,2,2', - xenv.CLOUD_TPU_TASK_ID: '4', - xenv.TPU_PROCESS_PORT: '8476', - xenv.TPU_PROCESS_ADDRESSES: 
'10.130.0.31:8476,10.130.0.31:8477,10.130.0.31:8478,10.130.0.31:8479,10.130.0.30:8476,10.130.0.30:8477,10.130.0.30:8478,10.130.0.30:8479', - xenv.TPU_VISIBLE_DEVICES: '0', - } - ), + ('v4-8_process_0', { + xenv.TPU_PROCESS_BOUNDS: '1,1,1', + xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1', + 'WORKER_ID': '0' + }, ['localhost'], 0, 4, { + xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: + '1,1,1', + xenv.TPU_PROCESS_BOUNDS: + '2,2,1', + xenv.CLOUD_TPU_TASK_ID: + '0', + xenv.TPU_PROCESS_PORT: + '8476', + xenv.TPU_PROCESS_ADDRESSES: + 'localhost:8476,localhost:8477,localhost:8478,localhost:8479', + xenv.TPU_VISIBLE_DEVICES: + '0', + }), + ('v4-8_process_3', { + xenv.TPU_PROCESS_BOUNDS: '1,1,1', + xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1', + 'WORKER_ID': '0' + }, ['localhost'], 3, 4, { + xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: + '1,1,1', + xenv.TPU_PROCESS_BOUNDS: + '2,2,1', + xenv.CLOUD_TPU_TASK_ID: + '3', + xenv.TPU_PROCESS_PORT: + '8479', + xenv.TPU_PROCESS_ADDRESSES: + 'localhost:8476,localhost:8477,localhost:8478,localhost:8479', + xenv.TPU_VISIBLE_DEVICES: + '3', + }), + ('v4-16_worker_1_process_0', { + xenv.TPU_PROCESS_BOUNDS: '1,1,2', + xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1', + 'WORKER_ID': '1' + }, ['10.130.0.31', '10.130.0.30'], 0, 4, { + xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: + '1,1,1', + xenv.TPU_PROCESS_BOUNDS: + '2,2,2', + xenv.CLOUD_TPU_TASK_ID: + '4', + xenv.TPU_PROCESS_PORT: + '8476', + xenv.TPU_PROCESS_ADDRESSES: + '10.130.0.31:8476,10.130.0.31:8477,10.130.0.31:8478,10.130.0.31:8479,10.130.0.30:8476,10.130.0.30:8477,10.130.0.30:8478,10.130.0.30:8479', + xenv.TPU_VISIBLE_DEVICES: + '0', + }), ) - def test_configure_tpu_topology(self, tpu_env, worker_ips, local_rank, local_world_size, expected): + def test_configure_tpu_topology(self, tpu_env, worker_ips, local_rank, + local_world_size, expected): with mock.patch.object(tpu, 'get_tpu_env', return_value=tpu_env), \ mock.patch.object(tpu, 'get_worker_ips', return_value=worker_ips), \ mock.patch.dict(os.environ, clear=True) as mock_env: diff --git a/torch_xla/experimental/tpu.py b/torch_xla/experimental/tpu.py index 2e6a3c846480..dbdc29068c29 100644 --- a/torch_xla/experimental/tpu.py +++ b/torch_xla/experimental/tpu.py @@ -12,6 +12,7 @@ MeshShape = Tuple[int, int, int] + def _parse_mesh_shape(mesh: str) -> MeshShape: dims = tuple(int(d) for d in mesh.split(',')) if len(dims) != 3: @@ -19,12 +20,15 @@ def _parse_mesh_shape(mesh: str) -> MeshShape: return dims + def _multiply_mesh_shapes(mesh1: MeshShape, mesh2: MeshShape) -> MeshShape: return tuple(d1 * d2 for d1, d2 in zip(mesh1, mesh2)) + def _mesh_size(mesh: MeshShape) -> int: return functools.reduce(operator.mul, mesh) + def _get_metadata(key: str) -> str: path = os.path.join(_GCE_METADATA_ROOT_URL, 'instance/attributes', key) resp = requests.get(path, headers={'Metadata-Flavor': 'Google'}) @@ -32,23 +36,29 @@ def _get_metadata(key: str) -> str: return resp.text + def num_processes(default: int = 4) -> Optional[int]: process_bounds = xu.getenv_as(xenv.TPU_PROCESS_BOUNDS, str) - return _mesh_size(_parse_mesh_shape(process_bounds)) if process_bounds else default + return _mesh_size( + _parse_mesh_shape(process_bounds)) if process_bounds else default + def num_local_processes() -> Optional[int]: # Don't create more processes than local chips (4) return min(4, num_processes()) + def task_id() -> Optional[int]: return xu.getenv_as(xenv.CLOUD_TPU_TASK_ID, int) + def get_tpu_env() -> Dict[str, str]: metadata = _get_metadata('tpu-env') return yaml.load(metadata, yaml.Loader) + def get_worker_ips() -> 
List[str]: metadata = _get_metadata('worker-network-endpoints') @@ -58,27 +68,36 @@ def get_worker_ips() -> List[str]: return ips if len(ips) > 1 else ['localhost'] -def configure_topology(local_rank: int, local_world_size: int, base_port: int = 8476): + +def configure_topology(local_rank: int, + local_world_size: int, + base_port: int = 8476): tpu_env = get_tpu_env() # Process bounds with 4 chips per process default_process_bounds = _parse_mesh_shape(tpu_env[xenv.TPU_PROCESS_BOUNDS]) - chips_per_process = _parse_mesh_shape(tpu_env[xenv.TPU_CHIPS_PER_PROCESS_BOUNDS]) + chips_per_process = _parse_mesh_shape( + tpu_env[xenv.TPU_CHIPS_PER_PROCESS_BOUNDS]) # Process bounds with 1 chip per process - process_bounds = _multiply_mesh_shapes(default_process_bounds, chips_per_process) + process_bounds = _multiply_mesh_shapes(default_process_bounds, + chips_per_process) os.environ.setdefault(xenv.TPU_CHIPS_PER_PROCESS_BOUNDS, '1,1,1') - os.environ.setdefault(xenv.TPU_PROCESS_BOUNDS, ','.join(str(dim) for dim in process_bounds)) + os.environ.setdefault(xenv.TPU_PROCESS_BOUNDS, + ','.join(str(dim) for dim in process_bounds)) # Assume each TPU has the same number of local processes with the same ports worker_id = int(tpu_env['WORKER_ID']) - os.environ.setdefault(xenv.CLOUD_TPU_TASK_ID, str(worker_id * local_world_size + local_rank)) + os.environ.setdefault(xenv.CLOUD_TPU_TASK_ID, + str(worker_id * local_world_size + local_rank)) worker_ips = get_worker_ips() ports = list(range(base_port, base_port + local_world_size)) - process_endpoints = [','.join(f'{ip}:{port}' for port in ports) for ip in worker_ips] + process_endpoints = [ + ','.join(f'{ip}:{port}' for port in ports) for ip in worker_ips + ] os.environ.setdefault(xenv.TPU_PROCESS_ADDRESSES, ','.join(process_endpoints)) os.environ.setdefault(xenv.TPU_VISIBLE_DEVICES, str(local_rank)) From 00c50cb15a65f608b8b466d0a1f08c39b1b31ea7 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Mon, 1 Aug 2022 21:13:07 +0000 Subject: [PATCH 08/19] Don't use metadata for process bounds on v3 --- test/pjrt/test_experimental_tpu.py | 25 ++++++++++++++++++++--- torch_xla/experimental/tpu.py | 32 ++++++++++++++++++++++++++---- 2 files changed, 50 insertions(+), 7 deletions(-) diff --git a/test/pjrt/test_experimental_tpu.py b/test/pjrt/test_experimental_tpu.py index bd21cba304ac..c53d531c6a1e 100644 --- a/test/pjrt/test_experimental_tpu.py +++ b/test/pjrt/test_experimental_tpu.py @@ -1,4 +1,3 @@ -from cmath import exp import os import textwrap @@ -9,7 +8,7 @@ from unittest import mock -class TestExperimentalPjrtTpu(parameterized.TestCase): +class TestExperimentalTpu(parameterized.TestCase): @parameterized.named_parameters( ('default_one_host', None, 4), @@ -87,6 +86,7 @@ def test_get_worker_ips(self, worker_network_endpoints, expected): @parameterized.named_parameters( ('v4-8_process_0', { + 'ACCELERATOR_TYPE': 'v4-8', xenv.TPU_PROCESS_BOUNDS: '1,1,1', xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1', 'WORKER_ID': '0' @@ -105,6 +105,7 @@ def test_get_worker_ips(self, worker_network_endpoints, expected): '0', }), ('v4-8_process_3', { + 'ACCELERATOR_TYPE': 'v4-8', xenv.TPU_PROCESS_BOUNDS: '1,1,1', xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1', 'WORKER_ID': '0' @@ -123,6 +124,7 @@ def test_get_worker_ips(self, worker_network_endpoints, expected): '3', }), ('v4-16_worker_1_process_0', { + 'ACCELERATOR_TYPE': 'v4-16', xenv.TPU_PROCESS_BOUNDS: '1,1,2', xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1', 'WORKER_ID': '1' @@ -140,7 +142,24 @@ def test_get_worker_ips(self, worker_network_endpoints, 
expected): xenv.TPU_VISIBLE_DEVICES: '0', }), - ) + # TODO: remove this case when process bounds are added to metadata + ('v3-8_process_0', { + 'ACCELERATOR_TYPE': 'v3-8', + 'WORKER_ID': '0' + }, ['localhost'], 0, 4, { + xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: + '1,1,1', + xenv.TPU_PROCESS_BOUNDS: + '2,2,1', + xenv.CLOUD_TPU_TASK_ID: + '0', + xenv.TPU_PROCESS_PORT: + '8476', + xenv.TPU_PROCESS_ADDRESSES: + 'localhost:8476,localhost:8477,localhost:8478,localhost:8479', + xenv.TPU_VISIBLE_DEVICES: + '0', + })) def test_configure_tpu_topology(self, tpu_env, worker_ips, local_rank, local_world_size, expected): with mock.patch.object(tpu, 'get_tpu_env', return_value=tpu_env), \ diff --git a/torch_xla/experimental/tpu.py b/torch_xla/experimental/tpu.py index dbdc29068c29..87d082939f3f 100644 --- a/torch_xla/experimental/tpu.py +++ b/torch_xla/experimental/tpu.py @@ -9,6 +9,23 @@ import torch_xla.core.xla_env_vars as xenv _GCE_METADATA_ROOT_URL = 'http://metadata.google.internal/computeMetadata/v1' +_ACCELERATOR_TYPE_TO_HOST_BOUNDS = { + # v2 + 'v2-8': '1,1,1', + 'v2-32': '2,2,1', + 'v2-128': '4,4,1', + 'v2-256': '4,8,1', + 'v2-512': '8,8,1', + # v3 + 'v3-8': '1,1,1', + 'v3-32': '2,2,1', + 'v3-64': '2,4,1', + 'v3-128': '4,4,1', + 'v3-256': '4,8,1', + 'v3-512': '8,8,1', + 'v3-1024': '8,16,1', + 'v3-2048': '16,16,1', +} MeshShape = Tuple[int, int, int] @@ -74,10 +91,17 @@ def configure_topology(local_rank: int, base_port: int = 8476): tpu_env = get_tpu_env() - # Process bounds with 4 chips per process - default_process_bounds = _parse_mesh_shape(tpu_env[xenv.TPU_PROCESS_BOUNDS]) - chips_per_process = _parse_mesh_shape( - tpu_env[xenv.TPU_CHIPS_PER_PROCESS_BOUNDS]) + accelerator_type = tpu_env['ACCELERATOR_TYPE'] + if tpu_env['ACCELERATOR_TYPE'].startswith('v4'): + # Process bounds with 4 chips per process + default_process_bounds = _parse_mesh_shape(tpu_env[xenv.TPU_PROCESS_BOUNDS]) + chips_per_process = _parse_mesh_shape( + tpu_env[xenv.TPU_CHIPS_PER_PROCESS_BOUNDS]) + else: + # TODO: merge with TPU v4 case when bounds are added to metadata + default_process_bounds = _parse_mesh_shape( + _ACCELERATOR_TYPE_TO_HOST_BOUNDS[accelerator_type]) + chips_per_process = _parse_mesh_shape('2,2,1') # Process bounds with 1 chip per process process_bounds = _multiply_mesh_shapes(default_process_bounds, From c902cc6e2ca80b40d3c87ac52958ec52206802b3 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Mon, 1 Aug 2022 21:15:49 +0000 Subject: [PATCH 09/19] Fix real device indices for v3 --- test/pjrt/test_experimental_pjrt_tpu.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/pjrt/test_experimental_pjrt_tpu.py b/test/pjrt/test_experimental_pjrt_tpu.py index d15f4671c93f..00c134e88096 100644 --- a/test/pjrt/test_experimental_pjrt_tpu.py +++ b/test/pjrt/test_experimental_pjrt_tpu.py @@ -87,6 +87,7 @@ def test_xla_devices_multiprocess(self): self.assertDictEqual(devices_per_process, expected) def test_real_devices_multiprocess(self): + # Real devices unfortunately don't correspond to indices in TPU_VISIBLE_DEVICES accelerator_devices = { 'v3-8': { 0: { @@ -94,12 +95,12 @@ def test_real_devices_multiprocess(self): 1: ['TPU:0', 'TPU:1'], }, 1: { - 0: ['TPU:2', 'TPU:3'], - 1: ['TPU:2', 'TPU:3'], + 0: ['TPU:4', 'TPU:5'], + 1: ['TPU:4', 'TPU:5'] }, 2: { - 0: ['TPU:4', 'TPU:5'], - 1: ['TPU:4', 'TPU:5'], + 0: ['TPU:2', 'TPU:3'], + 1: ['TPU:2', 'TPU:3'] }, 3: { 0: ['TPU:6', 'TPU:7'], From 792e03c98ba404c474f555a1df391506f3ebf57a Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Mon, 1 Aug 2022 21:20:39 
+0000 Subject: [PATCH 10/19] Add TPU unit test to run_tests.sh --- test/run_tests.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/test/run_tests.sh b/test/run_tests.sh index 4078eeef9e07..de5e75b1a693 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -109,6 +109,7 @@ function run_op_tests { run_test python3 "$CDIR/test_torch_distributed_xla_backend.py" run_xla_ir_debug python3 "$CDIR/test_env_var_mapper.py" run_pjrt python3 "$CDIR/pjrt/test_experimental_pjrt.py" + run_pjrt python3 "$CDIR/pjrt/test_experimental_tpu.py" } function run_mp_op_tests { From 43c64630dcc1936accaf501c4e7362180ef73cd8 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 2 Aug 2022 18:05:51 +0000 Subject: [PATCH 11/19] Update docstrings --- torch_xla/experimental/tpu.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/torch_xla/experimental/tpu.py b/torch_xla/experimental/tpu.py index 87d082939f3f..c6d05eedcca3 100644 --- a/torch_xla/experimental/tpu.py +++ b/torch_xla/experimental/tpu.py @@ -55,6 +55,7 @@ def _get_metadata(key: str) -> str: def num_processes(default: int = 4) -> Optional[int]: + """Returns number of processes across all TPU hosts.""" process_bounds = xu.getenv_as(xenv.TPU_PROCESS_BOUNDS, str) return _mesh_size( @@ -62,21 +63,25 @@ def num_processes(default: int = 4) -> Optional[int]: def num_local_processes() -> Optional[int]: + """Returns number of processes to create on this host.""" # Don't create more processes than local chips (4) return min(4, num_processes()) def task_id() -> Optional[int]: + """Returns index of this process within all TPU worker processes, if any.""" return xu.getenv_as(xenv.CLOUD_TPU_TASK_ID, int) def get_tpu_env() -> Dict[str, str]: + """Fetches and parses `tpu-env` metadata field.""" metadata = _get_metadata('tpu-env') return yaml.load(metadata, yaml.Loader) def get_worker_ips() -> List[str]: + """Returns ordered list of TPU worker IPs from TPU metadata.""" metadata = _get_metadata('worker-network-endpoints') # Workers have format 'hostname:uid:ip,hostname:uid:ip,...' @@ -88,7 +93,17 @@ def get_worker_ips() -> List[str]: def configure_topology(local_rank: int, local_world_size: int, - base_port: int = 8476): + base_port: int = 8476) -> None: + """Configures TPU topology environment variables based on TPU metadata. + + Must be run before using any XLA devices. + + Args: + local_rank: rank of this process within this host. + local_world_size: number of processes on this host. + base_port: starting port for TPU clients on each host. Ports in the range + [base_port, base_port + local_world_size) must be free on each host. + """ tpu_env = get_tpu_env() accelerator_type = tpu_env['ACCELERATOR_TYPE'] From 57b9b0b2b9ed7af87a6909aea30b1d4bbbbeefb8 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 2 Aug 2022 18:21:50 +0000 Subject: [PATCH 12/19] Fix signatures of `num_processes` and `num_local_processes`. 
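After this change `num_processes` reports the global process count implied by TPU_PROCESS_BOUNDS (defaulting to 1 when the variable is unset), while `num_local_processes` never exceeds the number of chips attached to the host. A rough standalone sketch of the intended behavior — illustration only, it reads the raw TPU_PROCESS_BOUNDS variable directly rather than going through xenv/xu as the real module does:

    # Sketch, not the module itself: mirrors the new signatures below.
    import functools
    import operator
    import os

    def num_processes(default: int = 1) -> int:
        bounds = os.environ.get('TPU_PROCESS_BOUNDS')
        if not bounds:
            return default
        dims = [int(d) for d in bounds.split(',')]
        return functools.reduce(operator.mul, dims)

    def num_local_processes(local_chips: int = 4) -> int:
        # Never spawn more processes than chips attached to this host.
        return min(local_chips, num_processes(default=local_chips))

    os.environ.pop('TPU_PROCESS_BOUNDS', None)
    assert num_processes() == 1 and num_local_processes() == 4
    os.environ['TPU_PROCESS_BOUNDS'] = '2,2,2'   # e.g. a v4-16 slice
    assert num_processes() == 8 and num_local_processes() == 4

So an unset topology still spawns one process per local chip, while a v4-16 ('2,2,2') reports 8 processes globally but only 4 on each host.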
--- test/pjrt/test_experimental_tpu.py | 2 +- torch_xla/experimental/tpu.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/pjrt/test_experimental_tpu.py b/test/pjrt/test_experimental_tpu.py index c53d531c6a1e..93fe064b8332 100644 --- a/test/pjrt/test_experimental_tpu.py +++ b/test/pjrt/test_experimental_tpu.py @@ -11,7 +11,7 @@ class TestExperimentalTpu(parameterized.TestCase): @parameterized.named_parameters( - ('default_one_host', None, 4), + ('default_one_host', None, 1), ('one_process_one_host', '1,1,1', 1), ('multi_process_one_host', '2,2,1', 4), ('multi_process_v4-16', '2,2,2', 8), diff --git a/torch_xla/experimental/tpu.py b/torch_xla/experimental/tpu.py index c6d05eedcca3..5195b5411569 100644 --- a/torch_xla/experimental/tpu.py +++ b/torch_xla/experimental/tpu.py @@ -54,7 +54,7 @@ def _get_metadata(key: str) -> str: return resp.text -def num_processes(default: int = 4) -> Optional[int]: +def num_processes(default: int = 1) -> int: """Returns number of processes across all TPU hosts.""" process_bounds = xu.getenv_as(xenv.TPU_PROCESS_BOUNDS, str) @@ -62,10 +62,10 @@ def num_processes(default: int = 4) -> Optional[int]: _parse_mesh_shape(process_bounds)) if process_bounds else default -def num_local_processes() -> Optional[int]: +def num_local_processes(local_chips: int = 4) -> int: """Returns number of processes to create on this host.""" - # Don't create more processes than local chips (4) - return min(4, num_processes()) + # Don't create more processes than local chips + return min(local_chips, num_processes(default=local_chips)) def task_id() -> Optional[int]: From 2ea3017ecd49046d87b57c5316da6ec87e3006c4 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 2 Aug 2022 18:45:23 +0000 Subject: [PATCH 13/19] Make MeshShape more pythonic --- torch_xla/experimental/tpu.py | 46 +++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/torch_xla/experimental/tpu.py b/torch_xla/experimental/tpu.py index 5195b5411569..f6fe85903030 100644 --- a/torch_xla/experimental/tpu.py +++ b/torch_xla/experimental/tpu.py @@ -1,7 +1,7 @@ import functools import operator import os -from typing import Dict, Optional, List, Tuple +from typing import Dict, NamedTuple, Optional, List, Tuple import requests import yaml @@ -27,23 +27,27 @@ 'v3-2048': '16,16,1', } -MeshShape = Tuple[int, int, int] +class MeshShape(NamedTuple): + """Represents a TPU mesh shape (e.g. 
'2,2,1' or '1,1,1')""" + x: int + y: int + z: int -def _parse_mesh_shape(mesh: str) -> MeshShape: - dims = tuple(int(d) for d in mesh.split(',')) - if len(dims) != 3: - raise ValueError("Mesh shape '{}' should be length 3".format(mesh)) + @classmethod + def from_string(cls, mesh: str): + dims = tuple(int(d) for d in mesh.split(',')) + if len(dims) != 3: + raise ValueError("Mesh shape '{}' should be length 3".format(mesh)) - return dims + return MeshShape(*dims) + @property + def size(self) -> int: + return functools.reduce(operator.mul, self) -def _multiply_mesh_shapes(mesh1: MeshShape, mesh2: MeshShape) -> MeshShape: - return tuple(d1 * d2 for d1, d2 in zip(mesh1, mesh2)) - - -def _mesh_size(mesh: MeshShape) -> int: - return functools.reduce(operator.mul, mesh) + def __mul__(self, other): + return MeshShape(*(d1 * d2 for d1, d2 in zip(self, other))) def _get_metadata(key: str) -> str: @@ -58,8 +62,8 @@ def num_processes(default: int = 1) -> int: """Returns number of processes across all TPU hosts.""" process_bounds = xu.getenv_as(xenv.TPU_PROCESS_BOUNDS, str) - return _mesh_size( - _parse_mesh_shape(process_bounds)) if process_bounds else default + return MeshShape.from_string( + process_bounds).size if process_bounds else default def num_local_processes(local_chips: int = 4) -> int: @@ -109,18 +113,18 @@ def configure_topology(local_rank: int, accelerator_type = tpu_env['ACCELERATOR_TYPE'] if tpu_env['ACCELERATOR_TYPE'].startswith('v4'): # Process bounds with 4 chips per process - default_process_bounds = _parse_mesh_shape(tpu_env[xenv.TPU_PROCESS_BOUNDS]) - chips_per_process = _parse_mesh_shape( + default_process_bounds = MeshShape.from_string( + tpu_env[xenv.TPU_PROCESS_BOUNDS]) + chips_per_process = MeshShape.from_string( tpu_env[xenv.TPU_CHIPS_PER_PROCESS_BOUNDS]) else: # TODO: merge with TPU v4 case when bounds are added to metadata - default_process_bounds = _parse_mesh_shape( + default_process_bounds = MeshShape.from_string( _ACCELERATOR_TYPE_TO_HOST_BOUNDS[accelerator_type]) - chips_per_process = _parse_mesh_shape('2,2,1') + chips_per_process = MeshShape.from_string('2,2,1') # Process bounds with 1 chip per process - process_bounds = _multiply_mesh_shapes(default_process_bounds, - chips_per_process) + process_bounds = default_process_bounds * chips_per_process os.environ.setdefault(xenv.TPU_CHIPS_PER_PROCESS_BOUNDS, '1,1,1') os.environ.setdefault(xenv.TPU_PROCESS_BOUNDS, From 3ce02c7c3bfe3d95360006f9f74707dcb8f6e541 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 2 Aug 2022 18:49:24 +0000 Subject: [PATCH 14/19] Clarify naming of `num_processes` --- test/pjrt/test_experimental_tpu.py | 4 ++-- torch_xla/experimental/tpu.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/pjrt/test_experimental_tpu.py b/test/pjrt/test_experimental_tpu.py index 93fe064b8332..4f4b109daa7f 100644 --- a/test/pjrt/test_experimental_tpu.py +++ b/test/pjrt/test_experimental_tpu.py @@ -17,10 +17,10 @@ class TestExperimentalTpu(parameterized.TestCase): ('multi_process_v4-16', '2,2,2', 8), ('multi_process_v4-32', '2,2,4', 16), ) - def test_num_processes(self, process_bounds, expected): + def test_process_bounds_size(self, process_bounds, expected): envs = {xenv.TPU_PROCESS_BOUNDS: process_bounds} if process_bounds else {} with mock.patch.dict(os.environ, envs, clear=True): - n = tpu.num_processes() + n = tpu.process_bounds_size() self.assertEqual(n, expected) diff --git a/torch_xla/experimental/tpu.py b/torch_xla/experimental/tpu.py index f6fe85903030..c02a396df50c 100644 --- 
a/torch_xla/experimental/tpu.py +++ b/torch_xla/experimental/tpu.py @@ -58,7 +58,7 @@ def _get_metadata(key: str) -> str: return resp.text -def num_processes(default: int = 1) -> int: +def process_bounds_size(default: int = 1) -> int: """Returns number of processes across all TPU hosts.""" process_bounds = xu.getenv_as(xenv.TPU_PROCESS_BOUNDS, str) @@ -69,7 +69,7 @@ def num_processes(default: int = 1) -> int: def num_local_processes(local_chips: int = 4) -> int: """Returns number of processes to create on this host.""" # Don't create more processes than local chips - return min(local_chips, num_processes(default=local_chips)) + return min(local_chips, process_bounds_size(default=local_chips)) def task_id() -> Optional[int]: From 5747bd1c32f5c25d51f629fc0bd6af88124b74de Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 2 Aug 2022 21:08:18 +0000 Subject: [PATCH 15/19] Don't try to get TPU task id on CPU --- torch_xla/experimental/pjrt.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/torch_xla/experimental/pjrt.py b/torch_xla/experimental/pjrt.py index fcf316a9fc96..388e2c4415a9 100644 --- a/torch_xla/experimental/pjrt.py +++ b/torch_xla/experimental/pjrt.py @@ -122,13 +122,13 @@ def addressable_device_count() -> int: @requires_pjrt -def run_thread_per_device(process: int, local_processes: int, +def run_thread_per_device(local_process: int, local_world_size: int, fn: Callable[..., R]) -> Dict[int, R]: """Runs `fn` in a separate thread on each visible device. Args: - process: rank of current process - local_processes: number of processes on this host + local_process: rank of current process within this host + local_world_size: number of processes on this host fn: Function to run on all devices Returns: @@ -136,7 +136,7 @@ def run_thread_per_device(process: int, local_processes: int, result of calling `fn`. 
""" if device_type() == 'TPU': - tpu.configure_topology(process, local_processes) + tpu.configure_topology(local_process, local_world_size) xm.set_replication(xm.xla_device(), xm.get_xla_supported_devices()) threads = len(xm.get_xla_supported_devices()) @@ -146,8 +146,13 @@ def _thread_fn(fn, device_index): @functools.wraps(fn) def wrapper(*args, **kwargs): # Assumes same number of threads per process - set_global_ordinal(tpu.task_id() * threads + device_index) - set_local_ordinal(process * threads + device_index) + if device_type() == 'TPU': + set_global_ordinal(tpu.task_id() * threads + device_index) + else: + # TODO: support multiple hosts with CPU/GPU + set_global_ordinal(local_process * threads + device_index) + + set_local_ordinal(local_process * threads + device_index) return fn(*args, **kwargs) From 215c9c5ebdfca2373ff7be18323888f50fb1712d Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 3 Aug 2022 16:14:29 +0000 Subject: [PATCH 16/19] Don't use value of `mock.patch.dict` because CI uses 3.7 https://docs.python.org/3/library/unittest.mock.html#unittest.mock.patch.dict --- test/pjrt/test_experimental_tpu.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/pjrt/test_experimental_tpu.py b/test/pjrt/test_experimental_tpu.py index 4f4b109daa7f..ad40edab85a7 100644 --- a/test/pjrt/test_experimental_tpu.py +++ b/test/pjrt/test_experimental_tpu.py @@ -164,11 +164,11 @@ def test_configure_tpu_topology(self, tpu_env, worker_ips, local_rank, local_world_size, expected): with mock.patch.object(tpu, 'get_tpu_env', return_value=tpu_env), \ mock.patch.object(tpu, 'get_worker_ips', return_value=worker_ips), \ - mock.patch.dict(os.environ, clear=True) as mock_env: + mock.patch.dict(os.environ, clear=True): tpu.configure_topology(local_rank, local_world_size) - self.assertDictContainsSubset(expected, mock_env) + self.assertDictContainsSubset(expected, os.environ) if __name__ == '__main__': From c392ce36fb96d5066f7324a7aa3f17bfb5e06e55 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 10 Aug 2022 10:02:35 -0700 Subject: [PATCH 17/19] Remove real devices test Real devices are not deterministic. 
--- test/pjrt/test_experimental_pjrt_tpu.py | 56 ------------------------- 1 file changed, 56 deletions(-) diff --git a/test/pjrt/test_experimental_pjrt_tpu.py b/test/pjrt/test_experimental_pjrt_tpu.py index 00c134e88096..dc8e3182c242 100644 --- a/test/pjrt/test_experimental_pjrt_tpu.py +++ b/test/pjrt/test_experimental_pjrt_tpu.py @@ -86,62 +86,6 @@ def test_xla_devices_multiprocess(self): devices_per_process = pjrt.run_multiprocess(xm.xla_device) self.assertDictEqual(devices_per_process, expected) - def test_real_devices_multiprocess(self): - # Real devices unfortunately don't correspond to indices in TPU_VISIBLE_DEVICES - accelerator_devices = { - 'v3-8': { - 0: { - 0: ['TPU:0', 'TPU:1'], - 1: ['TPU:0', 'TPU:1'], - }, - 1: { - 0: ['TPU:4', 'TPU:5'], - 1: ['TPU:4', 'TPU:5'] - }, - 2: { - 0: ['TPU:2', 'TPU:3'], - 1: ['TPU:2', 'TPU:3'] - }, - 3: { - 0: ['TPU:6', 'TPU:7'], - 1: ['TPU:6', 'TPU:7'], - }, - }, - 'v4-8': { - 0: { - 0: ['TPU:0'] - }, - 1: { - 0: ['TPU:2'] - }, - 2: { - 0: ['TPU:3'] - }, - 3: { - 0: ['TPU:1'] - }, - }, - } - - if self.accelerator_type not in accelerator_devices: - raise NotImplementedError('Test not implemented for {}'.format( - self.accelerator_type)) - expected = accelerator_devices[self.accelerator_type] - - devices_per_process = pjrt.run_multiprocess(_get_real_devices) - self.assertDictEqual(devices_per_process, expected) - - all_devices = sorted( - itertools.chain.from_iterable( - process_devices[0] for process_devices in expected.values())) - expected_all_devices = { - rank: {thread: all_devices for thread in expected[0].keys() - } for rank in expected.keys() - } - - all_devices_per_process = pjrt.run_multiprocess(_get_all_real_devices) - self.assertDictEqual(all_devices_per_process, expected_all_devices) - def test_xla_devices_single_process_all_chips(self): accelerator_devices = { 'v3-8': { From 695e5df42d8460d65e143c9d65eb8a4771f0a6bc Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 12 Aug 2022 16:15:57 +0000 Subject: [PATCH 18/19] Address review comments --- torch_xla/experimental/pjrt.py | 18 +++++++++--------- torch_xla/experimental/tpu.py | 1 + 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/torch_xla/experimental/pjrt.py b/torch_xla/experimental/pjrt.py index 388e2c4415a9..81c18b4ffb60 100644 --- a/torch_xla/experimental/pjrt.py +++ b/torch_xla/experimental/pjrt.py @@ -122,7 +122,7 @@ def addressable_device_count() -> int: @requires_pjrt -def run_thread_per_device(local_process: int, local_world_size: int, +def run_thread_per_device(local_rank: int, local_world_size: int, fn: Callable[..., R]) -> Dict[int, R]: """Runs `fn` in a separate thread on each visible device. @@ -136,7 +136,7 @@ def run_thread_per_device(local_process: int, local_world_size: int, result of calling `fn`. 
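Related to the earlier mock.patch.dict change: when the patched dictionary cannot be taken from the context manager's return value (the CI's Python 3.7 does not provide it), the portable pattern is to patch os.environ with clear=True and assert on os.environ afterwards. In the sketch below, `configure` is a hypothetical stand-in for tpu.configure_topology, used only to keep the example self-contained.

import os
import unittest
from unittest import mock


def configure(rank: int) -> None:
  # Stand-in for tpu.configure_topology: writes one env var for the demo.
  os.environ.setdefault('TPU_VISIBLE_DEVICES', str(rank))


class ConfigureEnvTest(unittest.TestCase):

  def test_sets_visible_devices(self):
    # clear=True empties os.environ for the duration of the block; assertions
    # read os.environ directly instead of the context manager's return value.
    with mock.patch.dict(os.environ, clear=True):
      configure(rank=3)
      self.assertEqual(os.environ['TPU_VISIBLE_DEVICES'], '3')


if __name__ == '__main__':
  unittest.main()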
""" if device_type() == 'TPU': - tpu.configure_topology(local_process, local_world_size) + tpu.configure_topology(local_rank, local_world_size) xm.set_replication(xm.xla_device(), xm.get_xla_supported_devices()) threads = len(xm.get_xla_supported_devices()) @@ -150,9 +150,9 @@ def wrapper(*args, **kwargs): set_global_ordinal(tpu.task_id() * threads + device_index) else: # TODO: support multiple hosts with CPU/GPU - set_global_ordinal(local_process * threads + device_index) + set_global_ordinal(local_rank * threads + device_index) - set_local_ordinal(local_process * threads + device_index) + set_local_ordinal(local_rank * threads + device_index) return fn(*args, **kwargs) @@ -185,16 +185,16 @@ def run_multiprocess(fn: Callable[..., R], *args, return_value is the result of calling `fn`. """ if device_type() == 'TPU': - processes = tpu.num_local_processes() + num_processes = tpu.num_local_processes() else: - processes = 1 + num_processes = 1 with concurrent.futures.ProcessPoolExecutor( - max_workers=processes) as executor: + max_workers=num_processes) as executor: futures = { - executor.submit(run_thread_per_device, i, processes, + executor.submit(run_thread_per_device, i, num_processes, functools.partial(fn, *args, **kwargs)): i - for i in range(processes) + for i in range(num_processes) } results = { diff --git a/torch_xla/experimental/tpu.py b/torch_xla/experimental/tpu.py index c02a396df50c..57e0e7b872a5 100644 --- a/torch_xla/experimental/tpu.py +++ b/torch_xla/experimental/tpu.py @@ -25,6 +25,7 @@ 'v3-512': '8,8,1', 'v3-1024': '8,16,1', 'v3-2048': '16,16,1', + # Get v4 host bounds from TPU metadata } From d2272cd0869342732f61f11ba8b42ad8279b2959 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 12 Aug 2022 16:17:14 +0000 Subject: [PATCH 19/19] fix docstring --- torch_xla/experimental/pjrt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/pjrt.py b/torch_xla/experimental/pjrt.py index 81c18b4ffb60..4b0a4ec98140 100644 --- a/torch_xla/experimental/pjrt.py +++ b/torch_xla/experimental/pjrt.py @@ -127,7 +127,7 @@ def run_thread_per_device(local_rank: int, local_world_size: int, """Runs `fn` in a separate thread on each visible device. Args: - local_process: rank of current process within this host + local_rank: rank of current process within this host local_world_size: number of processes on this host fn: Function to run on all devices