@@ -3,6 +3,7 @@
import os
import re
from typing import Dict, NamedTuple, Optional, List, Tuple
+from typing_extensions import TypedDict
import requests
import yaml

@@ -32,6 +33,13 @@
}


+class TpuEnv(TypedDict):
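+  """TPU topology values, parsed from the `tpu-env` metadata field or built
+  from TPU_* environment variables."""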
+  accelerator_type: str
+  tpu_process_bounds: str
+  tpu_chips_per_process_bound: str
+  worker_id: int
+
+
class MeshShape(NamedTuple):
  """Represents a TPU mesh shape (e.g. '2,2,1' or '1,1,1')"""
  x: int
@@ -65,7 +73,6 @@ def _get_metadata(key: str) -> str:
def process_bounds_size(default: int = 1) -> int:
  """Returns number of processes across all TPU hosts."""
  process_bounds = xu.getenv_as(xenv.TPU_PROCESS_BOUNDS, str)
-
  return MeshShape.from_string(
      process_bounds).size if process_bounds else default

@@ -81,10 +88,28 @@ def task_id() -> Optional[int]:
  return xu.getenv_as(xenv.CLOUD_TPU_TASK_ID, int)


-def get_tpu_env() -> Dict[str, str]:
+def _using_env_vars() -> bool:
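+  # True if xenv.TPU_SKIP_MDS_QUERY is set, in which case TPU topology is read
+  # from environment variables rather than the instance metadata server.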
+  return xu.getenv_as(xenv.TPU_SKIP_MDS_QUERY, str, False)
+
+
+def build_tpu_env_from_vars() -> TpuEnv:
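+  # Build the same mapping that `get_tpu_env` parses from the `tpu-env`
+  # metadata field, preferring the TPU_*_PROCESS_BOUNDS variables and falling
+  # back to the *_HOST_BOUNDS variants when they are unset.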
+  metadata = dict()
+  metadata[xenv.ACCELERATOR_TYPE] = xu.getenv_as(xenv.TPU_ACCELERATOR_TYPE, str)
+  metadata[xenv.TPU_PROCESS_BOUNDS] = xu.getenv_as(
+      xenv.TPU_PROCESS_BOUNDS, str, xu.getenv_as(xenv.TPU_HOST_BOUNDS, str))
+  metadata[xenv.TPU_CHIPS_PER_PROCESS_BOUNDS] = xu.getenv_as(
+      xenv.TPU_CHIPS_PER_PROCESS_BOUNDS, str,
+      xu.getenv_as(xenv.TPU_CHIPS_PER_HOST_BOUNDS, str))
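+  # Prefer CLOUD_TPU_TASK_ID for the worker ID, falling back to TPU_WORKER_ID.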
+  metadata[xenv.WORKER_ID] = xu.getenv_as(xenv.CLOUD_TPU_TASK_ID, str,
+                                          xu.getenv_as(xenv.TPU_WORKER_ID, str))
+  return metadata
+
+
+def get_tpu_env() -> TpuEnv:
  """Fetches and parses `tpu-env` metadata field."""
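+  # Skip the metadata server query when the topology is supplied via env vars.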
+  if _using_env_vars():
+    return build_tpu_env_from_vars()
  metadata = _get_metadata('tpu-env')
-
  return yaml.load(metadata, yaml.Loader)


@@ -94,19 +119,22 @@ def version() -> int:
  except requests.HTTPError as e:
    raise EnvironmentError('Failed to get TPU metadata') from e

-  match = re.match(r'^v(\d)-(\d+)$', env['ACCELERATOR_TYPE'])
+  match = re.match(r'^v(\d)-(\d+)$', env[xenv.ACCELERATOR_TYPE])
  return int(match.groups()[0])


def get_worker_ips() -> List[str]:
  """Returns ordered list of TPU worker IPs from TPU metadata."""
-  metadata = _get_metadata('worker-network-endpoints')
-
-  # Workers have format 'hostname:uid:ip,hostname:uid:ip,...'
-  workers = metadata.split(',')
-  ips = [worker.split(':')[2] for worker in workers]
-
-  return ips if len(ips) > 1 else ['localhost']
+  if _using_env_vars():
+    hostnames_string = xu.getenv_as(xenv.TPU_WORKER_HOSTNAMES, str, '')
+    # String has the format 'host-name-1,host-name-2,...,host-name-n'
+    hostnames = hostnames_string.split(',')
+  else:
+    hostnames_string = _get_metadata('worker-network-endpoints')
+    # Workers have format 'hostname:uid:ip,hostname:uid:ip,...'
+    workers = hostnames_string.split(',')
+    hostnames = [worker.split(':')[2] for worker in workers]
+  return hostnames if len(hostnames) > 1 else ['localhost']


def configure_one_chip_topology() -> None:
@@ -135,8 +163,8 @@ def configure_topology(local_rank: int,
  """
  tpu_env = get_tpu_env()

-  accelerator_type = tpu_env['ACCELERATOR_TYPE']
-  if tpu_env['ACCELERATOR_TYPE'].startswith('v4'):
+  accelerator_type = tpu_env[xenv.ACCELERATOR_TYPE]
+  if version() == 4:
    # Process bounds with 4 chips per process
    default_process_bounds = MeshShape.from_string(
        tpu_env[xenv.TPU_PROCESS_BOUNDS])
@@ -156,7 +184,7 @@ def configure_topology(local_rank: int,
                        ','.join(str(dim) for dim in process_bounds))

  # Assume each TPU has the same number of local processes with the same ports
-  worker_id = int(tpu_env['WORKER_ID'])
+  worker_id = int(tpu_env[xenv.WORKER_ID])
  os.environ.setdefault(xenv.CLOUD_TPU_TASK_ID,
                        str(worker_id * local_world_size + local_rank))

@@ -186,7 +214,7 @@ def discover_master_worker_ip(use_localhost: bool = True) -> str:
    return 'localhost'

  tpu_env = get_tpu_env()
-  current_worker_id = int(tpu_env['WORKER_ID'])
+  current_worker_id = int(tpu_env[xenv.WORKER_ID])
  t = torch.tensor([current_worker_id], device=xm.xla_device())
  xm.collective_broadcast([t])
  xm.mark_step()