Skip to content

Commit

Permalink
Update grpc_tpu_worker.py file
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 609469881
  • Loading branch information
chandrasekhard2 authored and tensorflow-jenkins committed Feb 28, 2024
1 parent 4bdc149 commit 1a19b6e
Showing 1 changed file with 17 additions and 23 deletions.
40 changes: 17 additions & 23 deletions tensorflow/python/tools/grpc_tpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ def setup_env_vars():
os.environ['TPU_STDERR_LOG_LEVEL'] = '0'
os.environ['CLOUD_TPU_TASK_ID'] = worker_id
os.environ['TPU_LOCK_DEVICE'] = 'true'
os.environ['TPU_CHIPS_PER_HOST_BOUNDS'] = '2,2,1'
os.environ['TPU_MESH_CONTROLLER_ADDRESS'] = (
worker_network_endpoints.split(',')[0].split(':')[2] + ':8476'
)
os.environ['TPU_MESH_CONTROLLER_PORT'] = '8476'

accelerator_type_to_host_bounds = {
# v2
'v2-8': '1,1,1',
Expand All @@ -78,29 +82,16 @@ def setup_env_vars():
'v3-512': '8,8,1',
'v3-1024': '8,16,1',
'v3-2048': '16,16,1',
# v4
'v4-8': '1,1,1',
'v4-16': '1,1,2',
'v4-32': '1,1,4',
'v4-64': '1,2,4',
'v4-128': '2,2,4',
'v4-256': '2,2,8',
'v4-512': '2,4,8',
'v4-1024': '4,4,8',
'v4-2048': '4,4,16',
'v4-4096': '4,8,16',
}

os.environ['TPU_HOST_BOUNDS'] = accelerator_type_to_host_bounds[
accelerator_type]
os.environ['TPU_MESH_CONTROLLER_ADDRESS'] = worker_network_endpoints.split(
',')[0].split(':')[2] + ':8476'
os.environ['TPU_MESH_CONTROLLER_PORT'] = '8476'

os.environ['TPU_STDERR_LOG_LEVEL'] = '0'
}

if accelerator_type not in ['v4-8', 'v4-16', 'v4-32', 'v4-64']:
os.environ['TPU_TOPOLOGY_WRAP'] = 'true,true,true'
# For v4 and v5 TPUs, don't set any topology-related flags;
# libtpu will set these values.
if not (accelerator_type.startswith('v4-') or
accelerator_type.startswith('v5')):
os.environ['TPU_CHIPS_PER_HOST_BOUNDS'] = '2,2,1'
os.environ['TPU_HOST_BOUNDS'] = accelerator_type_to_host_bounds[
accelerator_type]

# Set the hostname override.
os.environ['TPU_HOSTNAME_OVERRIDE'] = get_host_ip()
Expand All @@ -111,7 +102,10 @@ def main(unused_args):
server_def = tensorflow_server_pb2.ServerDef(protocol='grpc')
job_def = server_def.cluster.job.add()
job_def.name = 'tpu_worker'
job_def.tasks[0] = 'localhost:8470'
tpu_task_port = os.getenv('TPU_TASK_PORT')
if tpu_task_port is None or not tpu_task_port:
tpu_task_port = '8470' # If TPU task port is not available, use 8470.
job_def.tasks[0] = 'localhost:' + tpu_task_port
server_def.job_name = 'tpu_worker'
server_def.task_index = 0

Expand Down

0 comments on commit 1a19b6e

Please sign in to comment.