diff --git a/torch_xla/_internal/tpu.py b/torch_xla/_internal/tpu.py index 182c8675fbd..2993a033c6e 100644 --- a/torch_xla/_internal/tpu.py +++ b/torch_xla/_internal/tpu.py @@ -52,6 +52,8 @@ # Testing only '0x0056', '0x0062', + # TPU 7x + '0x0076' ] @@ -188,7 +190,10 @@ def version() -> int: except requests.HTTPError as e: raise EnvironmentError('Failed to get TPU metadata') from e - match = re.match(r'^v(\d)([A-Za-z]?){7}-(\d+)$', env[xenv.ACCELERATOR_TYPE]) + match = re.match(r'^(?:v|tpu)(\d)([A-Za-z]?){7}-(\d+)$', + env[xenv.ACCELERATOR_TYPE]) + if not match: + raise EnvironmentError('Failed to parse TPU version from metadata') return int(match.groups()[0]) @@ -254,7 +259,8 @@ def configure_topology(local_rank: int, tpu_env = get_tpu_env() accelerator_type = tpu_env[xenv.ACCELERATOR_TYPE] - if version() >= 4: + tpu_version = version() + if tpu_version >= 4: # Process bounds with 4 chips per process default_process_bounds = MeshShape.from_string( tpu_env[xenv.TPU_PROCESS_BOUNDS]) @@ -270,8 +276,11 @@ def configure_topology(local_rank: int, process_bounds = default_process_bounds * chips_per_process os.environ.setdefault(xenv.TPU_CHIPS_PER_PROCESS_BOUNDS, '1,1,1') - os.environ.setdefault(xenv.TPU_PROCESS_BOUNDS, - ','.join(str(dim) for dim in process_bounds)) + process_bounds_str = ','.join(str(dim) for dim in process_bounds) + if tpu_version == 7: + process_bounds_str += ',2' + + os.environ.setdefault(xenv.TPU_PROCESS_BOUNDS, process_bounds_str) # Assume each TPU has the same number of local processes with the same ports worker_id = int(tpu_env[xenv.WORKER_ID])