diff --git a/torch_xla/distributed/xla_dist.py b/torch_xla/distributed/xla_dist.py
index 885ecb90392c..b19ccedd1520 100755
--- a/torch_xla/distributed/xla_dist.py
+++ b/torch_xla/distributed/xla_dist.py
@@ -19,6 +19,94 @@
 import torch_xla.utils.utils as xu
 
 
+def get_args_parser() -> argparse.ArgumentParser:
+  """Helper function parsing the command line options."""
+
+  parser = argparse.ArgumentParser(
+      description='PyTorch on TPU distributed training launcher.',
+      epilog=('Usage example: python3 -m'
+              ' torch_xla.distributed.xla_dist --tpu=[TPU_NAME]'
+              ' --conda-env torch-xla-nightly -- python3 train.py'))
+
+  cluster_group = parser.add_argument_group('Cluster Setup')
+  cluster_group.add_argument(
+      '--tpu', type=str, required=True, help='Name of the Cloud TPU pod.')
+  cluster_group.add_argument(
+      '--vm',
+      action='append',
+      type=str,
+      help=('List of single Compute VM instance names. '
+            'If not provided we assume usage of instance groups.'))
+
+  docker_group = parser.add_argument_group('Docker Setup')
+  docker_group.add_argument(
+      '--docker-container',
+      default='',
+      type=str,
+      help='Name of docker container if running in docker.')
+  docker_group.add_argument(
+      '--docker-image',
+      default='',
+      type=str,
+      help='Name of docker image if running in container.')
+  docker_group.add_argument(
+      '--docker-run-flag',
+      action='append',
+      type=str,
+      help='Docker run flags to run container with (ex. --shm-size, ...).')
+
+  conda_group = parser.add_argument_group('Conda Setup')
+  conda_group.add_argument(
+      '--conda-env',
+      default='',
+      type=str,
+      help='Name of the conda environment if running with conda.')
+
+  parser.add_argument(
+      '--env',
+      action='append',
+      type=str,
+      help='List of environment variables to distribute.')
+  parser.add_argument(
+      '--restart-tpuvm-pod-server',
+      action='store_true',
+      help='Restart the long running XRT local service for this training.')
+  parser.add_argument(
+      '--tpuvm-server-port',
+      default=51011,
+      type=int,
+      help='Port that the XRT local service will be started on.')
+  parser.add_argument(
+      'positional',
+      nargs='+',
+      type=str,
+      help='The python command to launch training including model parameters.')
+  return parser
+
+
+def parse_args(args):
+  parser = get_args_parser()
+  return parser.parse_args(args)
+
+
+def resolve_and_execute(flags):
+  """Resolves the command line flags and launches a distributed process."""
+  cluster_resolver = ClusterResolver(flags.tpu, vms=flags.vm)
+  cluster = cluster_resolver.get_cluster()
+  tpuvm_mode = cluster_resolver.get_tpuvm_mode()
+  executor = DistributedExecutor(
+      cluster,
+      docker_container=flags.docker_container,
+      docker_image=flags.docker_image,
+      docker_run_flags=flags.docker_run_flag,
+      conda_env=flags.conda_env,
+      env_vars=flags.env,
+      restart_server=flags.restart_tpuvm_pod_server,
+      tpuvm_mode=tpuvm_mode,
+      tpuvm_server_port=flags.tpuvm_server_port)
+  executor.run(flags.positional)
+
+
 def concat_cmd_list(cmd_list, delimiter=' ', quote='"'):
   concat = ''
   for cmd in cmd_list:
@@ -588,86 +676,16 @@ def run(self, cmd):
         })
 
 
-if __name__ == '__main__':
-  parser = argparse.ArgumentParser(
-      description='PyTorch on TPU distrubuted training launcher.',
-      epilog=('Usage example: python3 -m'
-              ' torch_xla.distributed.xla_dist --tpu=[TPU_NAME]'
-              ' --conda-env torch-xla-nightly -- python3 train.py'))
-
-  cluster_group = parser.add_argument_group('Cluster Setup')
-  cluster_group.add_argument(
-      '--tpu', type=str, required=True, help='Name of the Cloud TPU pod.')
-  cluster_group.add_argument(
-      '--vm',
-      action='append',
-      type=str,
-      help=('List of single Compute VM instance names. '
-            'If not provided we assume usage of instance groups.'))
-
-  docker_group = parser.add_argument_group('Docker Setup')
-  docker_group.add_argument(
-      '--docker-container',
-      default='',
-      type=str,
-      help='Name of docker container if running in docker.')
-  docker_group.add_argument(
-      '--docker-image',
-      default='',
-      type=str,
-      help='Name of docker image if running in container.')
-  docker_group.add_argument(
-      '--docker-run-flag',
-      action='append',
-      type=str,
-      help='Docker run flags to run container with (ex. --shm-size, ...).')
-
-  conda_group = parser.add_argument_group('Conda Setup')
-  conda_group.add_argument(
-      '--conda-env',
-      default='',
-      type=str,
-      help='Name of the conda environment if running with conda.')
-
-  parser.add_argument(
-      '--env',
-      action='append',
-      type=str,
-      help='List of environment variables to distribute.')
-  parser.add_argument(
-      '--restart-tpuvm-pod-server',
-      action='store_true',
-      help='Restart the long running XRT local service for this training.')
-  parser.add_argument(
-      '--tpuvm-server-port',
-      default=51011,
-      type=int,
-      help='Port that XRT local service will be start on.')
-  parser.add_argument(
-      'positional',
-      nargs='+',
-      type=str,
-      help='The python command to launch training including model parameters.')
-
-  FLAGS = parser.parse_args()
-
+def main(args=None):
+  FLAGS = parse_args(args)
   if (FLAGS.docker_container or FLAGS.docker_image or
       FLAGS.docker_run_flag) and FLAGS.conda_env:
     raise ValueError('Docker Setup arguments and Conda Setup'
                      ' arguments are mutually exclusive.')
 
   # Resolve VM and TPU clusters.
-  cluster_resolver = ClusterResolver(FLAGS.tpu, vms=FLAGS.vm)
-  cluster = cluster_resolver.get_cluster()
-  tpuvm_mode = cluster_resolver.get_tpuvm_mode()
-  executor = DistributedExecutor(
-      cluster,
-      docker_container=FLAGS.docker_container,
-      docker_image=FLAGS.docker_image,
-      docker_run_flags=FLAGS.docker_run_flag,
-      conda_env=FLAGS.conda_env,
-      env_vars=FLAGS.env,
-      restart_server=FLAGS.restart_tpuvm_pod_server,
-      tpuvm_mode=tpuvm_mode,
-      tpuvm_server_port=FLAGS.tpuvm_server_port)
-  executor.run(FLAGS.positional)
+  resolve_and_execute(FLAGS)
+
+
+if __name__ == '__main__':
+  main()
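
The refactor moves flag parsing out of the `__main__` block, so it can now be exercised without resolving a cluster or launching anything. A minimal sketch of such a check, assuming the patched module is importable in the test environment; the test name, pod name, and training command below are made-up placeholders, not part of this diff:

# Sketch only: exercises the new parse_args() helper with hypothetical values.
from torch_xla.distributed.xla_dist import parse_args


def test_parse_args_smoke():
  flags = parse_args([
      '--tpu=my-tpu-pod',  # hypothetical Cloud TPU pod name
      '--conda-env', 'torch-xla-nightly',
      '--restart-tpuvm-pod-server',
      '--',  # everything after this is the training command
      'python3', 'train.py', '--fake_data',
  ])
  assert flags.tpu == 'my-tpu-pod'
  assert flags.conda_env == 'torch-xla-nightly'
  assert flags.restart_tpuvm_pod_server
  assert flags.tpuvm_server_port == 51011  # default from get_args_parser()
  assert flags.positional == ['python3', 'train.py', '--fake_data']

Since main(args=None) falls through to parser.parse_args(None), console invocation via sys.argv is unchanged, while callers and tests can inject an explicit argument list.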