Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 95 additions & 77 deletions torch_xla/distributed/xla_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,94 @@
import torch_xla.utils.utils as xu


def get_args_parser() -> argparse.ArgumentParser:
  """Build the command-line parser for the xla_dist launcher.

  Returns:
    An `argparse.ArgumentParser` with four option groups: cluster setup
    (--tpu, --vm), docker setup, conda setup, and generic launch options,
    plus the positional training command.
  """
  parser = argparse.ArgumentParser(
      # NOTE: fixed typo 'distrubuted' -> 'distributed' in user-facing text.
      description='PyTorch on TPU distributed training launcher.',
      epilog=('Usage example: python3 -m'
              ' torch_xla.distributed.xla_dist --tpu=[TPU_NAME]'
              ' --conda-env torch-xla-nightly -- python3 train.py'))

  cluster_group = parser.add_argument_group('Cluster Setup')
  cluster_group.add_argument(
      '--tpu', type=str, required=True, help='Name of the Cloud TPU pod.')
  cluster_group.add_argument(
      '--vm',
      action='append',
      type=str,
      help=('List of single Compute VM instance names. '
            'If not provided we assume usage of instance groups.'))

  docker_group = parser.add_argument_group('Docker Setup')
  docker_group.add_argument(
      '--docker-container',
      default='',
      type=str,
      help='Name of docker container if running in docker.')
  docker_group.add_argument(
      '--docker-image',
      default='',
      type=str,
      help='Name of docker image if running in container.')
  docker_group.add_argument(
      '--docker-run-flag',
      action='append',
      type=str,
      help='Docker run flags to run container with (ex. --shm-size, ...).')

  conda_group = parser.add_argument_group('Conda Setup')
  conda_group.add_argument(
      '--conda-env',
      default='',
      type=str,
      help='Name of the conda environment if running with conda.')

  parser.add_argument(
      '--env',
      action='append',
      type=str,
      help='List of environment variables to distribute.')
  parser.add_argument(
      '--restart-tpuvm-pod-server',
      action='store_true',
      help='Restart the long running XRT local service for this training.')
  parser.add_argument(
      '--tpuvm-server-port',
      default=51011,
      type=int,
      # NOTE: fixed grammar 'will be start on' -> 'will be started on'.
      help='Port that XRT local service will be started on.')
  parser.add_argument(
      'positional',
      nargs='+',
      type=str,
      help='The python command to launch training including model parameters.')
  return parser


def parse_args(args):
  """Parse *args* (or ``sys.argv[1:]`` when None) with the launcher parser."""
  return get_args_parser().parse_args(args)


def resolve_and_execute(flags):
  """Resolve the TPU/VM cluster described by *flags* and launch training.

  Args:
    flags: parsed command-line namespace (see `get_args_parser`).
  """
  resolver = ClusterResolver(flags.tpu, vms=flags.vm)
  executor = DistributedExecutor(
      resolver.get_cluster(),
      docker_container=flags.docker_container,
      docker_image=flags.docker_image,
      docker_run_flags=flags.docker_run_flag,
      conda_env=flags.conda_env,
      env_vars=flags.env,
      restart_server=flags.restart_tpuvm_pod_server,
      tpuvm_mode=resolver.get_tpuvm_mode(),
      tpuvm_server_port=flags.tpuvm_server_port)
  executor.run(flags.positional)


def concat_cmd_list(cmd_list, delimiter=' ', quote='"'):
concat = ''
for cmd in cmd_list:
Expand Down Expand Up @@ -588,86 +676,16 @@ def run(self, cmd):
})


def main(args=None):
  """Entry point: parse flags, validate them, and launch distributed training.

  Args:
    args: Optional list of command-line argument strings. When None,
      arguments are read from ``sys.argv``.

  Raises:
    ValueError: If any Docker setup flag and --conda-env are both given —
      the two launch modes are mutually exclusive.
  """
  FLAGS = parse_args(args)
  if (FLAGS.docker_container or FLAGS.docker_image or
      FLAGS.docker_run_flag) and FLAGS.conda_env:
    raise ValueError('Docker Setup arguments and Conda Setup'
                     ' arguments are mutually exclusive.')

  # Cluster resolution and executor launch live in resolve_and_execute;
  # the duplicated pre-refactor inline copy (which would have parsed argv
  # and launched the job a second time) is removed.
  resolve_and_execute(FLAGS)


# Script entry point when run directly or via
# `python3 -m torch_xla.distributed.xla_dist`.
if __name__ == '__main__':
  main()