Skip to content
57 changes: 57 additions & 0 deletions test/pjrt/test_experimental_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,63 @@ def test_get_worker_ips(self, worker_network_endpoints, expected):
self.assertListEqual(worker_ips, expected)

@parameterized.named_parameters(
('v5-4_process_0', {
'ACCELERATOR_TYPE': 'v5-4',
xenv.TPU_PROCESS_BOUNDS: '2,2,1',
xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '1,1,1',
'WORKER_ID': '0'
}, ['localhost'], 0, 4, {
xenv.TPU_CHIPS_PER_PROCESS_BOUNDS:
'1,1,1',
xenv.TPU_PROCESS_BOUNDS:
'2,2,1',
xenv.CLOUD_TPU_TASK_ID:
'0',
xenv.TPU_PROCESS_PORT:
'8476',
xenv.TPU_PROCESS_ADDRESSES:
'localhost:8476,localhost:8477,localhost:8478,localhost:8479',
xenv.TPU_VISIBLE_CHIPS:
'0',
}),
('v5abcdefg-4_process_0', {
'ACCELERATOR_TYPE': 'v5abcdefg-4',
xenv.TPU_PROCESS_BOUNDS: '2,2,1',
xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '1,1,1',
'WORKER_ID': '0'
}, ['localhost'], 0, 4, {
xenv.TPU_CHIPS_PER_PROCESS_BOUNDS:
'1,1,1',
xenv.TPU_PROCESS_BOUNDS:
'2,2,1',
xenv.CLOUD_TPU_TASK_ID:
'0',
xenv.TPU_PROCESS_PORT:
'8476',
xenv.TPU_PROCESS_ADDRESSES:
'localhost:8476,localhost:8477,localhost:8478,localhost:8479',
xenv.TPU_VISIBLE_CHIPS:
'0',
}),
('v5abcdefg-16_process_0', {
'ACCELERATOR_TYPE': 'v5abcdefg-16',
xenv.TPU_PROCESS_BOUNDS: '2,2,1',
xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '1,1,1',
'WORKER_ID': '0'
}, ['localhost'], 0, 4, {
xenv.TPU_CHIPS_PER_PROCESS_BOUNDS:
'1,1,1',
xenv.TPU_PROCESS_BOUNDS:
'2,2,1',
xenv.CLOUD_TPU_TASK_ID:
'0',
xenv.TPU_PROCESS_PORT:
'8476',
xenv.TPU_PROCESS_ADDRESSES:
'localhost:8476,localhost:8477,localhost:8478,localhost:8479',
xenv.TPU_VISIBLE_CHIPS:
'0',
}),
('v4-8_process_0', {
'ACCELERATOR_TYPE': 'v4-8',
xenv.TPU_PROCESS_BOUNDS: '1,1,1',
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/experimental/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def configure_topology(local_rank: int,
tpu_env = get_tpu_env()

accelerator_type = tpu_env[xenv.ACCELERATOR_TYPE]
if version() == 4:
if version() >= 4:
# Process bounds with 4 chips per process
default_process_bounds = MeshShape.from_string(
tpu_env[xenv.TPU_PROCESS_BOUNDS])
Expand Down