Remove or improve several hardcoded TPU test conditions #5272

Merged 2 commits on Jul 5, 2023
117 changes: 55 additions & 62 deletions test/pjrt/test_runtime_tpu.py
@@ -17,6 +17,53 @@
from torch_xla._internal import tpu
import torch_xla.distributed.xla_multiprocessing as xmp

assert tpu.num_available_chips() > 0, 'Must be run on a TPU!'


def _ordinal_to_device(processes=None,
cores_per_process=None) -> Dict[int, torch.device]:
"""Returns a dict of global ordinals and their expected `torch.device` value.

Example for v4-8 multiprocessing:
{
0: torch.device('xla:0'),
1: torch.device('xla:0'),
2: torch.device('xla:0'),
3: torch.device('xla:0'),
}

Example for v4-8 single-process:
{
0: torch.device('xla:0'),
1: torch.device('xla:1'),
2: torch.device('xla:2'),
3: torch.device('xla:3'),
}

Example for v3-8 multiprocessing:
{
0: torch.device('xla:0'),
1: torch.device('xla:1'),
2: torch.device('xla:0'),
3: torch.device('xla:1'),
4: torch.device('xla:0'),
5: torch.device('xla:1'),
6: torch.device('xla:0'),
7: torch.device('xla:1'),
}
"""
processes = processes or tpu.num_available_chips()
cores_per_process = cores_per_process or tpu.num_logical_cores_per_chip()

ordinal = 0
ordinal_to_device = {}
for _ in range(processes):
for core in range(cores_per_process):
ordinal_to_device[ordinal] = torch.device(f'xla:{core}')
ordinal += 1

return ordinal_to_device
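
A minimal usage sketch of the new helper (illustrative only, not part of the test file, assuming a v3-8-style single-host topology of 4 processes with 2 logical cores each):

# Illustrative sketch: expected ordinal -> device mapping for
# 4 processes x 2 logical cores (v3-8-style).
expected = _ordinal_to_device(processes=4, cores_per_process=2)
assert len(expected) == 8  # one entry per global ordinal
assert expected[0] == torch.device('xla:0')
assert expected[1] == torch.device('xla:1')
assert expected[2] == torch.device('xla:0')  # the device index restarts in each process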


class TestExperimentalPjrtTpu(parameterized.TestCase):

@@ -27,11 +74,7 @@ def setUp(self):
tpu_env = tpu.get_tpu_env()
self.accelerator_type = tpu_env['ACCELERATOR_TYPE']
# Number of logical devices per single-host TPU
self.num_devices = {
'v2-8': 8,
'v3-8': 8,
'v4-8': 4,
}[self.accelerator_type]
self.num_devices = tpu.num_available_devices()
except requests.HTTPError as e:
raise EnvironmentError(
'Failed to get TPU metadata. Are you running on a TPU?') from e
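
The hardcoded per-accelerator device counts are replaced by tpu.num_available_devices(); the sketch below spells out the assumed single-host relationship between the tpu helpers already imported in this file (an illustration of that assumption, not the library's actual implementation):

# Assumption for illustration: on a single host, addressable devices =
# visible chips x logical cores per chip (e.g. v3-8: 4 x 2 = 8, v4-8: 4 x 1 = 4),
# which matches the dict this change removes.
def _num_devices_sketch() -> int:
  return tpu.num_available_chips() * tpu.num_logical_cores_per_chip()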
@@ -45,65 +88,24 @@ def tearDown(self) -> None:
os.environ.pop(xenv.TPU_VISIBLE_CHIPS, None)
os.environ.pop(xenv.TPU_PROCESS_BOUNDS, None)

@absltest.skipIf(
tpu.version() <= 3,
'This test is not currently supported on v3 TPUVMs or earlier.')
def test_xla_devices_multiprocess(self):
accelerator_devices = {
'v4-8': {
0: torch.device('xla:0'),
1: torch.device('xla:0'),
2: torch.device('xla:0'),
3: torch.device('xla:0'),
},
}

if self.accelerator_type not in accelerator_devices:
raise NotImplementedError('Test not implemented for {}'.format(
self.accelerator_type))
expected = accelerator_devices[self.accelerator_type]
expected = _ordinal_to_device()

devices_per_process = pjrt.run_multiprocess(xm.xla_device)
self.assertDictEqual(devices_per_process, expected)

@absltest.skipIf(
tpu.version() <= 2,
'This test is not currently supported on v2 TPUVMs or earlier.')
def test_xla_devices_single_process_all_chips(self):
accelerator_devices = {
'v3-8': {i: torch.device(f'xla:{i}') for i in range(8)},
'v4-8': {i: torch.device(f'xla:{i}') for i in range(4)},
}

if self.accelerator_type not in accelerator_devices:
raise NotImplementedError('Test not implemented for {}'.format(
self.accelerator_type))
expected = accelerator_devices[self.accelerator_type]
expected = _ordinal_to_device(
processes=1, cores_per_process=tpu.num_available_devices())

os.environ[xenv.TPU_VISIBLE_CHIPS] = '0,1,2,3'
os.environ[xenv.TPU_PROCESS_BOUNDS] = '1,1,1'

devices = pjrt.run_multiprocess(xm.xla_device)
self.assertDictEqual(devices, expected)

@absltest.skipIf(
tpu.version() <= 2,
'This test is not currently supported on v2 TPUVMs or earlier.')
def test_xla_devices_single_process_one_chip(self):
accelerator_devices = {
'v3-8': {
0: torch.device('xla:0'),
1: torch.device('xla:1'),
},
'v4-8': {
0: torch.device('xla:0')
},
}

if self.accelerator_type not in accelerator_devices:
raise NotImplementedError('Test not implemented for {}'.format(
self.accelerator_type))
expected = accelerator_devices[self.accelerator_type]
expected = _ordinal_to_device(processes=1)

os.environ[xenv.TPU_VISIBLE_CHIPS] = '0'
os.environ[xenv.TPU_PROCESS_BOUNDS] = '1,1,1'
@@ -124,9 +126,6 @@ def test_xla_devices_single_process_one_chip_one_device_spawn(self):
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
executor.submit(self._fail_on_nonfirst_device).result()

@absltest.skipIf(
tpu.version() <= 2,
'This test is not currently supported on v2 TPUVMs or earlier.')
def test_default_xla_devices(self):
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as e:
f = e.submit(xm.get_xla_supported_devices, 'TPU')
@@ -137,19 +136,13 @@

@parameterized.named_parameters(('xla_model', xm.get_ordinal),
('pjrt', xr.global_ordinal))
@absltest.skipIf(
tpu.version() <= 2,
'This test is not currently supported on v2 TPUVMs or earlier.')
def test_global_ordinal(self, ordinal_func):
results = pjrt.run_multiprocess(ordinal_func)
values = list(results.values())
self.assertListEqual(sorted(values), list(range(self.num_devices)))

@parameterized.named_parameters(('xla_model', xm.get_local_ordinal),
('pjrt', xr.local_ordinal))
@absltest.skipIf(
tpu.version() <= 2,
'This test is not currently supported on v2 TPUVMs or earlier.')
def test_local_ordinal(self, ordinal_func):
results = pjrt.run_multiprocess(ordinal_func)
self.assertCountEqual(results.values(), list(range(self.num_devices)))
@@ -161,16 +154,16 @@ def _local_ordinal_with_discontiguous_global_ordinal_v4():
new_global_ordinal = global_ordinals[xr.global_ordinal()]

with mock.patch.object(
pjrt, 'global_ordinal', return_value=new_global_ordinal):
xr, 'global_ordinal', return_value=new_global_ordinal):
return xr.local_ordinal()
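
The mock target moves from pjrt to xr because, as the helper above exercises, the local ordinal is resolved through xr.global_ordinal. A minimal sketch of the same patching pattern in isolation, with a hypothetical return value:

# Hypothetical illustration: patching the module the code under test actually
# consults (xr) is what makes the fake global ordinal visible to it.
with mock.patch.object(xr, 'global_ordinal', return_value=2):
  assert xr.global_ordinal() == 2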

@absltest.skipIf(tpu.version() < 4, "Not implemented")
@absltest.skipIf(tpu.num_available_devices() != 4, "Not implemented")
def test_local_ordinal_with_discontiguous_global_ordinal_v4(self):
results = pjrt.run_multiprocess(
self._local_ordinal_with_discontiguous_global_ordinal_v4)
self.assertCountEqual(results.values(), [0, 1, 2, 3])

@absltest.skipIf(tpu.version() < 4, "Not implemented")
@absltest.skipIf(tpu.num_available_devices() != 4, "Not implemented")
def test_local_ordinal_with_discontiguous_global_ordinal_v4_threaded(self):
os.environ[xenv.TPU_PROCESS_BOUNDS] = '1,1,1'
os.environ[xenv.TPU_VISIBLE_CHIPS] = '0,1,2,3'