Merged
157 changes: 157 additions & 0 deletions test/pjrt/test_experimental_pjrt_tpu.py
@@ -0,0 +1,157 @@
import concurrent.futures
import itertools
import os
import requests

import torch
import torch_xla
from absl.testing import absltest, parameterized
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt
from torch_xla.experimental import tpu


def _get_real_devices():
"""Wraps `_xla_get_devices` to make it pickle-able"""
return torch_xla._XLAC._xla_get_devices()


def _get_all_real_devices():
"""Wraps `_xla_get_all_devices` to make it pickle-able"""
return torch_xla._XLAC._xla_get_all_devices()


class TestExperimentalPjrtTpu(parameterized.TestCase):

def setUp(self):
pjrt.set_device_type('TPU')

os.environ.pop(xenv.TPU_VISIBLE_DEVICES, None)
os.environ.pop(xenv.TPU_PROCESS_BOUNDS, None)

try:
tpu_env = tpu.get_tpu_env()
self.accelerator_type = tpu_env['ACCELERATOR_TYPE']
except requests.HTTPError as e:
raise EnvironmentError(
'Failed to get TPU metadata. Are you running on a TPU?') from e

# TODO: assert ComputationClient is not initialized
# The main process must not initialize the ComputationClient, otherwise
    # sub-processes will not be able to initialize the client with the correct
# settings.
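    # A minimal sketch of that assertion, assuming a hypothetical helper
    # (no such introspection API is confirmed in `torch_xla._XLAC`):
    #   self.assertFalse(torch_xla._XLAC._xla_computation_client_initialized())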

def test_xla_devices_multiprocess(self):
accelerator_devices = {
'v3-8': {
0: {
0: torch.device('xla:0'),
1: torch.device('xla:1'),
},
1: {
0: torch.device('xla:0'),
1: torch.device('xla:1'),
},
2: {
0: torch.device('xla:0'),
1: torch.device('xla:1'),
},
3: {
0: torch.device('xla:0'),
1: torch.device('xla:1'),
},
},
'v4-8': {
0: {
0: torch.device('xla:0')
},
1: {
0: torch.device('xla:0')
},
2: {
0: torch.device('xla:0')
},
3: {
0: torch.device('xla:0')
},
},
}
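    # Fixture shape: a v3-8 host has 4 chips with 2 cores each, so each of the
    # 4 processes sees devices xla:0 and xla:1; a v4-8 chip exposes a single
    # (megacore) device, so each of its 4 processes sees only xla:0.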

if self.accelerator_type not in accelerator_devices:
raise NotImplementedError('Test not implemented for {}'.format(
self.accelerator_type))
expected = accelerator_devices[self.accelerator_type]

devices_per_process = pjrt.run_multiprocess(xm.xla_device)
self.assertDictEqual(devices_per_process, expected)

def test_xla_devices_single_process_all_chips(self):
accelerator_devices = {
'v3-8': {
0: {i: torch.device(f'xla:{i}') for i in range(8)},
},
'v4-8': {
0: {i: torch.device(f'xla:{i}') for i in range(4)},
},
}

if self.accelerator_type not in accelerator_devices:
raise NotImplementedError('Test not implemented for {}'.format(
self.accelerator_type))
expected = accelerator_devices[self.accelerator_type]

os.environ[xenv.TPU_VISIBLE_DEVICES] = '0,1,2,3'
os.environ[xenv.TPU_PROCESS_BOUNDS] = '1,1,1'

devices = pjrt.run_multiprocess(xm.xla_device)
self.assertDictEqual(devices, expected)

def test_xla_devices_single_process_one_chip(self):
accelerator_devices = {
'v3-8': {
0: {
0: torch.device('xla:0'),
1: torch.device('xla:1'),
},
},
'v4-8': {
0: {
0: torch.device('xla:0')
},
},
}

if self.accelerator_type not in accelerator_devices:
raise NotImplementedError('Test not implemented for {}'.format(
self.accelerator_type))
expected = accelerator_devices[self.accelerator_type]

os.environ[xenv.TPU_VISIBLE_DEVICES] = '0'
os.environ[xenv.TPU_PROCESS_BOUNDS] = '1,1,1'

devices = pjrt.run_multiprocess(xm.xla_device)
self.assertDictEqual(devices, expected)

def test_default_xla_devices(self):
accelerator_num_devices = {
'v3-8': 8,
'v4-8': 4,
}

if self.accelerator_type not in accelerator_num_devices:
raise NotImplementedError('Test not implemented for {}'.format(
self.accelerator_type))
expected_num_devices = accelerator_num_devices[self.accelerator_type]

with concurrent.futures.ProcessPoolExecutor(max_workers=1) as e:
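      # Run in a fresh subprocess so the TPU client is initialized there, not
      # in the main test process (see the note in setUp).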
f = e.submit(xm.get_xla_supported_devices, 'TPU')
devices = [torch.device(d) for d in f.result()]

self.assertListEqual(
devices,
[torch.device(f'xla:{i}') for i in range(expected_num_devices)])


if __name__ == '__main__':
absltest.main()
175 changes: 175 additions & 0 deletions test/pjrt/test_experimental_tpu.py
@@ -0,0 +1,175 @@
import os
import textwrap

from absl.testing import absltest, parameterized
import torch_xla.core.xla_env_vars as xenv
from torch_xla.experimental import tpu

from unittest import mock


class TestExperimentalTpu(parameterized.TestCase):

@parameterized.named_parameters(
('default_one_host', None, 1),
('one_process_one_host', '1,1,1', 1),
('multi_process_one_host', '2,2,1', 4),
('multi_process_v4-16', '2,2,2', 8),
('multi_process_v4-32', '2,2,4', 16),
)
def test_process_bounds_size(self, process_bounds, expected):
envs = {xenv.TPU_PROCESS_BOUNDS: process_bounds} if process_bounds else {}
with mock.patch.dict(os.environ, envs, clear=True):
n = tpu.process_bounds_size()

self.assertEqual(n, expected)
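    # A sketch of what these cases imply (an inference; the real
    # implementation may differ): the size is the product of the
    # comma-separated bounds, defaulting to '1,1,1'.
    #   bounds = os.environ.get(xenv.TPU_PROCESS_BOUNDS, '1,1,1')
    #   size = math.prod(int(b) for b in bounds.split(','))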

@parameterized.named_parameters(
('default_one_host', None, 4),
('one_process_one_host', '1,1,1', 1),
('multi_process_one_host', '2,2,1', 4),
('multi_process_v4-16', '2,2,2', 4),
('multi_process_v4-32', '2,2,4', 4),
)
def test_num_local_processes(self, process_bounds, expected):
envs = {xenv.TPU_PROCESS_BOUNDS: process_bounds} if process_bounds else {}
with mock.patch.dict(os.environ, envs, clear=True):
n = tpu.num_local_processes()

self.assertEqual(n, expected)
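    # Inferred from these cases: by default one process per local TPU chip
    # (4 on these hosts); with explicit bounds, the local count is the bounds
    # volume divided by the number of hosts. The implementation may compute
    # this differently.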

@parameterized.parameters((None, None), ('0', 0), ('1', 1), ('15', 15))
def test_task_id(self, task_id, expected):
envs = {xenv.CLOUD_TPU_TASK_ID: task_id} if task_id else {}
with mock.patch.dict(os.environ, envs, clear=True):
i = tpu.task_id()

self.assertEqual(i, expected)

def test_tpu_env(self):
tpu_env_yaml = textwrap.dedent("""
ACCELERATOR_TYPE: 'v4-16'
CHIPS_PER_HOST_BOUNDS: '2,2,1'
HOST_BOUNDS: '1,1,2'
TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1'
TPU_PROCESS_BOUNDS: '1,1,2'
ZONE: 'us-central2-b'
""")

with mock.patch.object(tpu, '_get_metadata', return_value=tpu_env_yaml):
tpu_env = tpu.get_tpu_env()

self.assertDictEqual(
tpu_env, {
'ACCELERATOR_TYPE': 'v4-16',
'CHIPS_PER_HOST_BOUNDS': '2,2,1',
'HOST_BOUNDS': '1,1,2',
'TPU_CHIPS_PER_PROCESS_BOUNDS': '2,2,1',
'TPU_PROCESS_BOUNDS': '1,1,2',
'ZONE': 'us-central2-b',
})

@parameterized.named_parameters(
('one_host', 't1v-n-ea9d3291-w-0:12345:10.130.0.31', ['localhost']),
(
'four_hosts',
't1v-n-0f996b37-w-0:12345:10.130.0.26,t1v-n-0f996b37-w-1:12346:10.130.0.27,t1v-n-0f996b37-w-2:12347:10.130.0.25,t1v-n-0f996b37-w-3:12348:10.130.0.28',
['10.130.0.26', '10.130.0.27', '10.130.0.25', '10.130.0.28'],
),
)
def test_get_worker_ips(self, worker_network_endpoints, expected):
with mock.patch.object(
tpu, '_get_metadata', return_value=worker_network_endpoints):
worker_ips = tpu.get_worker_ips()

self.assertListEqual(worker_ips, expected)

@parameterized.named_parameters(
('v4-8_process_0', {
'ACCELERATOR_TYPE': 'v4-8',
xenv.TPU_PROCESS_BOUNDS: '1,1,1',
xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1',
'WORKER_ID': '0'
}, ['localhost'], 0, 4, {
xenv.TPU_CHIPS_PER_PROCESS_BOUNDS:
'1,1,1',
xenv.TPU_PROCESS_BOUNDS:
'2,2,1',
xenv.CLOUD_TPU_TASK_ID:
'0',
xenv.TPU_PROCESS_PORT:
'8476',
xenv.TPU_PROCESS_ADDRESSES:
'localhost:8476,localhost:8477,localhost:8478,localhost:8479',
xenv.TPU_VISIBLE_DEVICES:
'0',
}),
('v4-8_process_3', {
'ACCELERATOR_TYPE': 'v4-8',
xenv.TPU_PROCESS_BOUNDS: '1,1,1',
xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1',
'WORKER_ID': '0'
}, ['localhost'], 3, 4, {
xenv.TPU_CHIPS_PER_PROCESS_BOUNDS:
'1,1,1',
xenv.TPU_PROCESS_BOUNDS:
'2,2,1',
xenv.CLOUD_TPU_TASK_ID:
'3',
xenv.TPU_PROCESS_PORT:
'8479',
xenv.TPU_PROCESS_ADDRESSES:
'localhost:8476,localhost:8477,localhost:8478,localhost:8479',
xenv.TPU_VISIBLE_DEVICES:
'3',
}),
('v4-16_worker_1_process_0', {
'ACCELERATOR_TYPE': 'v4-16',
xenv.TPU_PROCESS_BOUNDS: '1,1,2',
xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1',
'WORKER_ID': '1'
}, ['10.130.0.31', '10.130.0.30'], 0, 4, {
xenv.TPU_CHIPS_PER_PROCESS_BOUNDS:
'1,1,1',
xenv.TPU_PROCESS_BOUNDS:
'2,2,2',
xenv.CLOUD_TPU_TASK_ID:
'4',
xenv.TPU_PROCESS_PORT:
'8476',
xenv.TPU_PROCESS_ADDRESSES:
'10.130.0.31:8476,10.130.0.31:8477,10.130.0.31:8478,10.130.0.31:8479,10.130.0.30:8476,10.130.0.30:8477,10.130.0.30:8478,10.130.0.30:8479',
xenv.TPU_VISIBLE_DEVICES:
'0',
}),
# TODO: remove this case when process bounds are added to metadata
('v3-8_process_0', {
'ACCELERATOR_TYPE': 'v3-8',
'WORKER_ID': '0'
}, ['localhost'], 0, 4, {
xenv.TPU_CHIPS_PER_PROCESS_BOUNDS:
'1,1,1',
xenv.TPU_PROCESS_BOUNDS:
'2,2,1',
xenv.CLOUD_TPU_TASK_ID:
'0',
xenv.TPU_PROCESS_PORT:
'8476',
xenv.TPU_PROCESS_ADDRESSES:
'localhost:8476,localhost:8477,localhost:8478,localhost:8479',
xenv.TPU_VISIBLE_DEVICES:
'0',
}))
def test_configure_tpu_topology(self, tpu_env, worker_ips, local_rank,
local_world_size, expected):
with mock.patch.object(tpu, 'get_tpu_env', return_value=tpu_env), \
mock.patch.object(tpu, 'get_worker_ips', return_value=worker_ips), \
mock.patch.dict(os.environ, clear=True):

tpu.configure_topology(local_rank, local_world_size)

self.assertDictContainsSubset(expected, os.environ)
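    # The expected values can be read off the fixtures above (a sketch with
    # hypothetical names; the real computation lives in
    # `tpu.configure_topology`). With the base port 8476 used here:
    #   task_id = worker_id * local_world_size + local_rank
    #   process_port = 8476 + local_rank
    #   visible_devices = str(local_rank)
    #   process_addresses = ','.join(
    #       f'{ip}:{8476 + rank}' for ip in worker_ips
    #       for rank in range(local_world_size))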


if __name__ == '__main__':
absltest.main()
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -109,6 +109,7 @@ function run_op_tests {
run_test python3 "$CDIR/test_torch_distributed_xla_backend.py"
run_xla_ir_debug python3 "$CDIR/test_env_var_mapper.py"
run_pjrt python3 "$CDIR/pjrt/test_experimental_pjrt.py"
run_pjrt python3 "$CDIR/pjrt/test_experimental_tpu.py"
}

function run_mp_op_tests {