Skip to content

Commit 96fd3ce

Browse files
steventk-groot
andauthored
Modify tpu version regex to match new names (#4557)
* Match litepod and lite version names * Formatting with yapf --------- Co-authored-by: root <root@t1v-n-804806aa-w-0.us-central2-b.c.tpu-pytorch.internal>
1 parent 0830b58 commit 96fd3ce

File tree

2 files changed

+44
-23
lines changed

2 files changed

+44
-23
lines changed

test/pjrt/test_experimental_tpu.py

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -46,30 +46,51 @@ def test_task_id(self, task_id, expected):
4646

4747
self.assertEqual(i, expected)
4848

49-
def test_tpu_env_from_gce_metadata(self):
50-
tpu_env_yaml = textwrap.dedent("""
51-
ACCELERATOR_TYPE: 'v4-16'
52-
CHIPS_PER_HOST_BOUNDS: '2,2,1'
53-
HOST_BOUNDS: '1,1,2'
54-
TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1'
55-
TPU_PROCESS_BOUNDS: '1,1,2'
56-
ZONE: 'us-central2-b'
57-
WORKER_ID: '0'
58-
""")
59-
49+
@parameterized.named_parameters(
50+
('v4',
51+
textwrap.dedent("""
52+
ACCELERATOR_TYPE: 'v4-16'
53+
CHIPS_PER_HOST_BOUNDS: '2,2,1'
54+
HOST_BOUNDS: '1,1,2'
55+
TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1'
56+
TPU_PROCESS_BOUNDS: '1,1,2'
57+
ZONE: 'us-central2-b'
58+
WORKER_ID: '0'
59+
"""), {
60+
'ACCELERATOR_TYPE': 'v4-16',
61+
'CHIPS_PER_HOST_BOUNDS': '2,2,1',
62+
'HOST_BOUNDS': '1,1,2',
63+
'TPU_CHIPS_PER_PROCESS_BOUNDS': '2,2,1',
64+
'TPU_PROCESS_BOUNDS': '1,1,2',
65+
'ZONE': 'us-central2-b',
66+
'WORKER_ID': '0'
67+
}, 4),
68+
('v5',
69+
textwrap.dedent("""
70+
ACCELERATOR_TYPE: 'v5abcdefg-16'
71+
CHIPS_PER_HOST_BOUNDS: '2,2,1'
72+
HOST_BOUNDS: '1,1,2'
73+
TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1'
74+
TPU_PROCESS_BOUNDS: '1,1,2'
75+
ZONE: 'us-central2-b'
76+
WORKER_ID: '0'
77+
"""), {
78+
'ACCELERATOR_TYPE': 'v5abcdefg-16',
79+
'CHIPS_PER_HOST_BOUNDS': '2,2,1',
80+
'HOST_BOUNDS': '1,1,2',
81+
'TPU_CHIPS_PER_PROCESS_BOUNDS': '2,2,1',
82+
'TPU_PROCESS_BOUNDS': '1,1,2',
83+
'ZONE': 'us-central2-b',
84+
'WORKER_ID': '0'
85+
}, 5),
86+
)
87+
def test_tpu_env_from_gce_metadata(self, tpu_env_yaml, expected_env,
88+
expected_version):
6089
with mock.patch.object(tpu, '_get_metadata', return_value=tpu_env_yaml):
6190
tpu_env = tpu.get_tpu_env()
62-
63-
self.assertDictEqual(
64-
tpu_env, {
65-
'ACCELERATOR_TYPE': 'v4-16',
66-
'CHIPS_PER_HOST_BOUNDS': '2,2,1',
67-
'HOST_BOUNDS': '1,1,2',
68-
'TPU_CHIPS_PER_PROCESS_BOUNDS': '2,2,1',
69-
'TPU_PROCESS_BOUNDS': '1,1,2',
70-
'ZONE': 'us-central2-b',
71-
'WORKER_ID': '0'
72-
})
91+
version = tpu.version()
92+
self.assertDictEqual(tpu_env, expected_env)
93+
self.assertEqual(version, expected_version)
7394

7495
@parameterized.named_parameters(
7596
('all-vars-set', {

torch_xla/experimental/tpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def version() -> int:
119119
except requests.HTTPError as e:
120120
raise EnvironmentError('Failed to get TPU metadata') from e
121121

122-
match = re.match(r'^v(\d)-(\d+)$', env[xenv.ACCELERATOR_TYPE])
122+
match = re.match(r'^v(\d)([A-Za-z]?){7}-(\d+)$', env[xenv.ACCELERATOR_TYPE])
123123
return int(match.groups()[0])
124124

125125

0 commit comments

Comments
 (0)