
Commit f2cb76f

steventk-groot and root authored
Try bootstrapping tpu env from env vars (#4499)
* Try bootstrapping tpu env from env vars
* Try bootstrapping tpu env from env vars
* Formatting
* Mock env vars
* Use TPU_SKIP_MDS_QUERY

---------

Co-authored-by: root <root@t1v-n-804806aa-w-0.us-central2-b.c.tpu-pytorch.internal>
1 parent 4f4ab8a commit f2cb76f
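
The change below lets a torch_xla process bootstrap its TPU environment from environment variables instead of querying the GCE metadata server. A minimal usage sketch, assuming a torch_xla build that includes this commit; the variable names come from the diff below, the values are illustrative only:

import os

# Setting TPU_SKIP_MDS_QUERY tells tpu.get_tpu_env() to skip the metadata-server
# query and assemble the TPU topology from these variables instead.
os.environ['TPU_SKIP_MDS_QUERY'] = '1'
os.environ['TPU_ACCELERATOR_TYPE'] = 'v4-16'   # illustrative values
os.environ['TPU_HOST_BOUNDS'] = '1,1,2'
os.environ['TPU_CHIPS_PER_HOST_BOUNDS'] = '2,2,1'
os.environ['TPU_WORKER_ID'] = '0'

from torch_xla.experimental import tpu

print(tpu.get_tpu_env())  # built from the environment, no metadata query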

File tree

3 files changed (+86 / -16 lines)


test/pjrt/test_experimental_tpu.py

Lines changed: 37 additions & 1 deletion
@@ -46,14 +46,15 @@ def test_task_id(self, task_id, expected):

     self.assertEqual(i, expected)

-  def test_tpu_env(self):
+  def test_tpu_env_from_gce_metadata(self):
     tpu_env_yaml = textwrap.dedent("""
       ACCELERATOR_TYPE: 'v4-16'
       CHIPS_PER_HOST_BOUNDS: '2,2,1'
       HOST_BOUNDS: '1,1,2'
       TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1'
       TPU_PROCESS_BOUNDS: '1,1,2'
       ZONE: 'us-central2-b'
+      WORKER_ID: '0'
     """)

     with mock.patch.object(tpu, '_get_metadata', return_value=tpu_env_yaml):
@@ -67,8 +68,43 @@ def test_tpu_env(self):
             'TPU_CHIPS_PER_PROCESS_BOUNDS': '2,2,1',
             'TPU_PROCESS_BOUNDS': '1,1,2',
             'ZONE': 'us-central2-b',
+            'WORKER_ID': '0'
         })

+  @parameterized.named_parameters(
+      ('all-vars-set', {
+          xenv.TPU_SKIP_MDS_QUERY: '1',
+          xenv.TPU_ACCELERATOR_TYPE: 'v4-16',
+          xenv.TPU_PROCESS_BOUNDS: '1,2,2',
+          xenv.TPU_HOST_BOUNDS: '1,1,2',
+          xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1',
+          xenv.TPU_CHIPS_PER_HOST_BOUNDS: '2,1,1',
+          xenv.CLOUD_TPU_TASK_ID: '1',
+          xenv.TPU_WORKER_ID: '0'
+      }, {
+          xenv.ACCELERATOR_TYPE: 'v4-16',
+          xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1',
+          xenv.TPU_PROCESS_BOUNDS: '1,2,2',
+          xenv.WORKER_ID: '1'
+      }),
+      ('defaults-only', {
+          xenv.TPU_SKIP_MDS_QUERY: '1',
+          xenv.TPU_ACCELERATOR_TYPE: 'v4-16',
+          xenv.TPU_HOST_BOUNDS: '1,1,2',
+          xenv.TPU_CHIPS_PER_HOST_BOUNDS: '2,1,1',
+          xenv.TPU_WORKER_ID: '0'
+      }, {
+          xenv.ACCELERATOR_TYPE: 'v4-16',
+          xenv.TPU_CHIPS_PER_PROCESS_BOUNDS: '2,1,1',
+          xenv.TPU_PROCESS_BOUNDS: '1,1,2',
+          xenv.WORKER_ID: '0'
+      }),
+  )
+  def test_tpu_env_from_env_vars(self, envs, expected):
+    with mock.patch.dict(os.environ, envs, clear=True):
+      tpu_env = tpu.get_tpu_env()
+    self.assertDictEqual(tpu_env, expected)
+
   @parameterized.named_parameters(
       ('one_host', 't1v-n-ea9d3291-w-0:12345:10.130.0.31', ['localhost']),
       (
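
The 'defaults-only' case above exercises the fallback chain added in torch_xla/experimental/tpu.py further down: TPU_PROCESS_BOUNDS falls back to TPU_HOST_BOUNDS, TPU_CHIPS_PER_PROCESS_BOUNDS falls back to TPU_CHIPS_PER_HOST_BOUNDS, and WORKER_ID is taken from CLOUD_TPU_TASK_ID, else TPU_WORKER_ID. A small sketch of that resolution using plain os.environ, for illustration only (the real code goes through xu.getenv_as):

import os

def resolve(primary: str, fallback: str) -> str:
  # Return the primary variable if set, otherwise the fallback.
  return os.environ.get(primary) or os.environ.get(fallback, '')

os.environ.update({
    'TPU_HOST_BOUNDS': '1,1,2',
    'TPU_CHIPS_PER_HOST_BOUNDS': '2,1,1',
    'TPU_WORKER_ID': '0',
})
print(resolve('TPU_PROCESS_BOUNDS', 'TPU_HOST_BOUNDS'))  # '1,1,2'
print(resolve('TPU_CHIPS_PER_PROCESS_BOUNDS', 'TPU_CHIPS_PER_HOST_BOUNDS'))  # '2,1,1'
print(resolve('CLOUD_TPU_TASK_ID', 'TPU_WORKER_ID'))  # '0'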

torch_xla/core/xla_env_vars.py

Lines changed: 6 additions & 0 deletions
@@ -16,6 +16,12 @@
 GPU_NUM_DEVICES = 'GPU_NUM_DEVICES'
 CPU_NUM_DEVICES = 'CPU_NUM_DEVICES'
 CLOUD_TPU_TASK_ID = 'CLOUD_TPU_TASK_ID'
+ACCELERATOR_TYPE = 'ACCELERATOR_TYPE'
+WORKER_ID = 'WORKER_ID'
+TPU_SKIP_MDS_QUERY = 'TPU_SKIP_MDS_QUERY'
+TPU_ACCELERATOR_TYPE = 'TPU_ACCELERATOR_TYPE'
+TPU_WORKER_ID = 'TPU_WORKER_ID'
+TPU_WORKER_HOSTNAMES = 'TPU_WORKER_HOSTNAMES'
 TPU_HOST_BOUNDS = 'TPU_HOST_BOUNDS'
 TPU_CHIPS_PER_HOST_BOUNDS = 'TPU_CHIPS_PER_HOST_BOUNDS'
 TPU_MESH_CTLER_ADDR = 'TPU_MESH_CONTROLLER_ADDRESS'
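
The new constants mirror the raw variable names one-to-one, so callers can reference them through the xenv module rather than string literals. A hedged sketch of that pattern (import path taken from this file's location; the value is illustrative):

import os
import torch_xla.core.xla_env_vars as xenv

os.environ[xenv.TPU_SKIP_MDS_QUERY] = '1'
os.environ[xenv.TPU_ACCELERATOR_TYPE] = 'v4-16'  # illustrative value
print(os.environ[xenv.TPU_ACCELERATOR_TYPE])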

torch_xla/experimental/tpu.py

Lines changed: 43 additions & 15 deletions
@@ -3,6 +3,7 @@
 import os
 import re
 from typing import Dict, NamedTuple, Optional, List, Tuple
+from typing_extensions import TypedDict
 import requests
 import yaml

@@ -32,6 +33,13 @@
 }


+class TpuEnv(TypedDict):
+  accelerator_type: str
+  tpu_process_bounds: str
+  tpu_chips_per_process_bound: str
+  worker_id: int
+
+
 class MeshShape(NamedTuple):
   """Represents a TPU mesh shape (e.g. '2,2,1' or '1,1,1')"""
   x: int
@@ -65,7 +73,6 @@ def _get_metadata(key: str) -> str:
 def process_bounds_size(default: int = 1) -> int:
   """Returns number of processes across all TPU hosts."""
   process_bounds = xu.getenv_as(xenv.TPU_PROCESS_BOUNDS, str)
-
   return MeshShape.from_string(
       process_bounds).size if process_bounds else default

@@ -81,10 +88,28 @@ def task_id() -> Optional[int]:
   return xu.getenv_as(xenv.CLOUD_TPU_TASK_ID, int)


-def get_tpu_env() -> Dict[str, str]:
+def _using_env_vars() -> bool:
+  return xu.getenv_as(xenv.TPU_SKIP_MDS_QUERY, str, False)
+
+
+def build_tpu_env_from_vars() -> TpuEnv:
+  metadata = dict()
+  metadata[xenv.ACCELERATOR_TYPE] = xu.getenv_as(xenv.TPU_ACCELERATOR_TYPE, str)
+  metadata[xenv.TPU_PROCESS_BOUNDS] = xu.getenv_as(
+      xenv.TPU_PROCESS_BOUNDS, str, xu.getenv_as(xenv.TPU_HOST_BOUNDS, str))
+  metadata[xenv.TPU_CHIPS_PER_PROCESS_BOUNDS] = xu.getenv_as(
+      xenv.TPU_CHIPS_PER_PROCESS_BOUNDS, str,
+      xu.getenv_as(xenv.TPU_CHIPS_PER_HOST_BOUNDS, str))
+  metadata[xenv.WORKER_ID] = xu.getenv_as(xenv.CLOUD_TPU_TASK_ID, str,
+                                          xu.getenv_as(xenv.TPU_WORKER_ID, str))
+  return metadata
+
+
+def get_tpu_env() -> TpuEnv:
   """Fetches and parses `tpu-env` metadata field."""
+  if _using_env_vars():
+    return build_tpu_env_from_vars()
   metadata = _get_metadata('tpu-env')
-
   return yaml.load(metadata, yaml.Loader)


@@ -94,19 +119,22 @@ def version() -> int:
   except requests.HTTPError as e:
     raise EnvironmentError('Failed to get TPU metadata') from e

-  match = re.match(r'^v(\d)-(\d+)$', env['ACCELERATOR_TYPE'])
+  match = re.match(r'^v(\d)-(\d+)$', env[xenv.ACCELERATOR_TYPE])
   return int(match.groups()[0])


 def get_worker_ips() -> List[str]:
   """Returns ordered list of TPU worker IPs from TPU metadata."""
-  metadata = _get_metadata('worker-network-endpoints')
-
-  # Workers have format 'hostname:uid:ip,hostname:uid:ip,...'
-  workers = metadata.split(',')
-  ips = [worker.split(':')[2] for worker in workers]
-
-  return ips if len(ips) > 1 else ['localhost']
+  if _using_env_vars():
+    hostnames_string = xu.getenv_as(xenv.TPU_WORKER_HOSTNAMES, str, '')
+    # String has the format 'host-name-1,host-name-2,...,host-name-n'
+    hostnames = hostnames_string.split(',')
+  else:
+    hostnames_string = _get_metadata('worker-network-endpoints')
+    # Workers have format 'hostname:uid:ip,hostname:uid:ip,...'
+    workers = hostnames_string.split(',')
+    hostnames = [worker.split(':')[2] for worker in workers]
+  return hostnames if len(hostnames) > 1 else ['localhost']


 def configure_one_chip_topology() -> None:
@@ -135,8 +163,8 @@ def configure_topology(local_rank: int,
   """
   tpu_env = get_tpu_env()

-  accelerator_type = tpu_env['ACCELERATOR_TYPE']
-  if tpu_env['ACCELERATOR_TYPE'].startswith('v4'):
+  accelerator_type = tpu_env[xenv.ACCELERATOR_TYPE]
+  if version() == 4:
     # Process bounds with 4 chips per process
     default_process_bounds = MeshShape.from_string(
         tpu_env[xenv.TPU_PROCESS_BOUNDS])
@@ -156,7 +184,7 @@ def configure_topology(local_rank: int,
                         ','.join(str(dim) for dim in process_bounds))

   # Assume each TPU has the same number of local processes with the same ports
-  worker_id = int(tpu_env['WORKER_ID'])
+  worker_id = int(tpu_env[xenv.WORKER_ID])
   os.environ.setdefault(xenv.CLOUD_TPU_TASK_ID,
                         str(worker_id * local_world_size + local_rank))

@@ -186,7 +214,7 @@ def discover_master_worker_ip(use_localhost: bool = True) -> str:
     return 'localhost'

   tpu_env = get_tpu_env()
-  current_worker_id = int(tpu_env['WORKER_ID'])
+  current_worker_id = int(tpu_env[xenv.WORKER_ID])
   t = torch.tensor([current_worker_id], device=xm.xla_device())
   xm.collective_broadcast([t])
   xm.mark_step()
