diff --git a/test/pjrt/test_experimental_tpu.py b/test/pjrt/test_experimental_tpu.py
index d80b75e0c37c..662e31062252 100644
--- a/test/pjrt/test_experimental_tpu.py
+++ b/test/pjrt/test_experimental_tpu.py
@@ -145,6 +145,63 @@ def test_get_worker_ips(self, worker_network_endpoints, expected):
     self.assertListEqual(worker_ips, expected)
 
   @parameterized.named_parameters(
+      ('v5-4_process_0', {
+          'ACCELERATOR_TYPE': 'v5-4',
+          xenv.TPU_PROCESS_BOUNDS: '2,2,1',
+          xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '1,1,1',
+          'WORKER_ID': '0'
+      }, ['localhost'], 0, 4, {
+          xenv.TPU_CHIPS_PER_PROCESS_BOUNDS:
+              '1,1,1',
+          xenv.TPU_PROCESS_BOUNDS:
+              '2,2,1',
+          xenv.CLOUD_TPU_TASK_ID:
+              '0',
+          xenv.TPU_PROCESS_PORT:
+              '8476',
+          xenv.TPU_PROCESS_ADDRESSES:
+              'localhost:8476,localhost:8477,localhost:8478,localhost:8479',
+          xenv.TPU_VISIBLE_CHIPS:
+              '0',
+      }),
+      ('v5abcdefg-4_process_0', {
+          'ACCELERATOR_TYPE': 'v5abcdefg-4',
+          xenv.TPU_PROCESS_BOUNDS: '2,2,1',
+          xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '1,1,1',
+          'WORKER_ID': '0'
+      }, ['localhost'], 0, 4, {
+          xenv.TPU_CHIPS_PER_PROCESS_BOUNDS:
+              '1,1,1',
+          xenv.TPU_PROCESS_BOUNDS:
+              '2,2,1',
+          xenv.CLOUD_TPU_TASK_ID:
+              '0',
+          xenv.TPU_PROCESS_PORT:
+              '8476',
+          xenv.TPU_PROCESS_ADDRESSES:
+              'localhost:8476,localhost:8477,localhost:8478,localhost:8479',
+          xenv.TPU_VISIBLE_CHIPS:
+              '0',
+      }),
+      ('v5abcdefg-16_process_0', {
+          'ACCELERATOR_TYPE': 'v5abcdefg-16',
+          xenv.TPU_PROCESS_BOUNDS: '2,2,1',
+          xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '1,1,1',
+          'WORKER_ID': '0'
+      }, ['localhost'], 0, 4, {
+          xenv.TPU_CHIPS_PER_PROCESS_BOUNDS:
+              '1,1,1',
+          xenv.TPU_PROCESS_BOUNDS:
+              '2,2,1',
+          xenv.CLOUD_TPU_TASK_ID:
+              '0',
+          xenv.TPU_PROCESS_PORT:
+              '8476',
+          xenv.TPU_PROCESS_ADDRESSES:
+              'localhost:8476,localhost:8477,localhost:8478,localhost:8479',
+          xenv.TPU_VISIBLE_CHIPS:
+              '0',
+      }),
       ('v4-8_process_0', {
           'ACCELERATOR_TYPE': 'v4-8',
           xenv.TPU_PROCESS_BOUNDS: '1,1,1',
diff --git a/torch_xla/experimental/tpu.py b/torch_xla/experimental/tpu.py
index 0a660a1faeda..99d9cf9e5428 100644
--- a/torch_xla/experimental/tpu.py
+++ b/torch_xla/experimental/tpu.py
@@ -172,7 +172,7 @@ def configure_topology(local_rank: int,
   tpu_env = get_tpu_env()
 
   accelerator_type = tpu_env[xenv.ACCELERATOR_TYPE]
-  if version() == 4:
+  if version() >= 4:
     # Process bounds with 4 chips per process
     default_process_bounds = MeshShape.from_string(
         tpu_env[xenv.TPU_PROCESS_BOUNDS])