Skip to content
This repository has been archived by the owner on Jan 6, 2023. It is now read-only.

implemented torchelastic.distributed.launch for oss #65

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions test/agent/server/local_elastic_agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,18 @@ def _distributed_sum(wait):
raise RuntimeError(f"Expected rank sum {expected}, got {actual}")


def _check_env_function():
# just check these env vars exist, os.environ[...] will naturally throw
# if the variable does not exist
os.environ["RANK"]
os.environ["LOCAL_RANK"]
os.environ["WORLD_SIZE"]
os.environ["MASTER_ADDR"]
os.environ["MASTER_PORT"]
os.environ["TORCHELASTIC_RESTART_COUNT"]
os.environ["TORCHELASTIC_MAX_RESTARTS"]


def _run_agent(run_id, etcd_host, etcd_port, min_size, max_size, wait=0):
rdzv_handler = dist.rendezvous(
f"etcd://{etcd_host}:{etcd_port}/{run_id}"
Expand Down Expand Up @@ -130,6 +142,11 @@ def test_run_bipolar_function(self):
self.assertEqual(WorkerState.FAILED, agent.get_worker_group().state)
self.assertEqual(0, agent._remaining_restarts)

def test_run_check_env_function(self):
spec = self._get_worker_spec(fn=_check_env_function, max_restarts=2)
agent = LocalElasticAgent(spec, start_method="fork")
agent.run()

def test_double_agent_happy(self):
host = self._etcd_server.get_host()
port = self._etcd_server.get_port()
Expand Down
54 changes: 54 additions & 0 deletions test/distributed/bin/test_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
import sys
from pathlib import Path


def parse_args():
parser = argparse.ArgumentParser(
description="test script, parses --local_rank,"
" used to test launcher with no --use_env flag"
)

# make this required so that if --local_rank is not passed
# when --use_env is used, the script fails as it fails to
# parse arguments
parser.add_argument(
"--local_rank", required=True, type=int, help="set by launch.py"
)

# file is used for assertions
parser.add_argument(
"--touch_file_dir",
type=str,
help="dir to touch a file with global rank as the filename",
)
return parser.parse_args()


def main():
args = parse_args()
print(f"Running {sys.argv}")

env_vars = ["RANK", "LOCAL_RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"]

print("Distributed env vars set by agent:")
for env_var in env_vars:
value = os.environ[env_var]
print(f"{env_var} = {value}")

file = os.path.join(args.touch_file_dir, os.environ["RANK"])
Path(file).touch()
print(f"Success, created {file}")


if __name__ == "__main__":
main()
11 changes: 11 additions & 0 deletions test/distributed/bin/test_script.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#!/bin/bash

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

FILE="$1/$RANK"
echo "creating $FILE"
touch "$FILE"
55 changes: 55 additions & 0 deletions test/distributed/bin/test_script_use_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
from pathlib import Path


def parse_args():
parser = argparse.ArgumentParser(
description="test script, does not parse --local_rank,"
" use to test launcher with --use_env flag"
)

parser.add_argument(
"--fail",
default=False,
action="store_true",
help="forces the script to throw a RuntimeError",
)

# file is used for assertions
parser.add_argument(
"--touch_file_dir",
type=str,
help="dir to touch a file with global rank as the filename",
)
return parser.parse_args()


def main():
args = parse_args()

env_vars = ["RANK", "LOCAL_RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"]

print("Distributed env vars set by agent:")
for env_var in env_vars:
value = os.environ[env_var]
print(f"{env_var} = {value}")

if args.fail:
raise RuntimeError("raising exception since --fail flag was set")
else:
file = os.path.join(args.touch_file_dir, os.environ["RANK"])
Path(file).touch()
print(f"Success, created {file}")


if __name__ == "__main__":
main()
182 changes: 182 additions & 0 deletions test/distributed/launch_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import shutil
import tempfile
import unittest
import uuid
kiukchung marked this conversation as resolved.
Show resolved Hide resolved

import torchelastic.distributed.launch as launch
import torchelastic.rendezvous.etcd_rendezvous # noqa: F401
from p2p.etcd_server_fixture import EtcdServerFixture


def path(script):
return os.path.join(os.path.dirname(__file__), script)


class LaunchTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
# start a standalone, single process etcd server to use for all tests
cls._etcd_server = EtcdServerFixture()
cls._etcd_server.start()
host = cls._etcd_server.get_host()
port = cls._etcd_server.get_port()
cls._etcd_endpoint = f"{host}:{port}"

@classmethod
def tearDownClass(cls):
# stop the standalone etcd server
cls._etcd_server.stop()

def setUp(self):
self.test_dir = tempfile.mkdtemp()

def tearDown(self):
shutil.rmtree(self.test_dir)

def test_get_rdzv_url(self):
actual_url = launch.get_rdzv_url(
"etcd",
"localhost:8081",
"1234",
1,
4,
"timeout=60,protocol=https,key=/etc/kubernetes/certs/client.key",
)

expected_url = (
"etcd://localhost:8081/1234"
"?min_workers=1"
"&max_workers=4"
"&timeout=60"
"&protocol=https"
"&key=/etc/kubernetes/certs/client.key"
)

self.assertEqual(expected_url, actual_url)

def test_get_rdzv_url_no_conf(self):
actual_url = launch.get_rdzv_url(
"etcd", "localhost:8081", "1234", 1, 4, conf=""
)

expected_url = "etcd://localhost:8081/1234" "?min_workers=1" "&max_workers=4"

self.assertEqual(expected_url, actual_url)

def test_launch_user_script_python(self):
run_id = str(uuid.uuid4().int)
nnodes = 1
nproc_per_node = 4
world_size = nnodes * nproc_per_node
args = [
f"--nnodes={nnodes}",
f"--nproc_per_node={nproc_per_node}",
f"--rdzv_backend=etcd",
f"--rdzv_endpoint={self._etcd_endpoint}",
f"--rdzv_id={run_id}",
f"--monitor_interval=1",
f"--start_method=fork",
path("bin/test_script.py"),
f"--touch_file_dir={self.test_dir}",
]
launch.main(args)

# make sure all the workers ran
# each worker touches a file with its global rank as the name
self.assertSetEqual(
{str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
)

def test_launch_user_script_python_use_env(self):
run_id = str(uuid.uuid4().int)
nnodes = 1
nproc_per_node = 4
world_size = nnodes * nproc_per_node
args = [
f"--nnodes={nnodes}",
f"--nproc_per_node={nproc_per_node}",
f"--use_env",
f"--rdzv_backend=etcd",
f"--rdzv_endpoint={self._etcd_endpoint}",
f"--rdzv_id={run_id}",
f"--monitor_interval=1",
f"--start_method=fork",
path("bin/test_script_use_env.py"),
f"--touch_file_dir={self.test_dir}",
]
launch.main(args)

# make sure all the workers ran
# each worker touches a file with its global rank as the name
self.assertSetEqual(
{str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
)

def test_launch_user_script_bash(self):
run_id = str(uuid.uuid4().int)
nnodes = 1
nproc_per_node = 4
world_size = nnodes * nproc_per_node

args = [
f"--nnodes={nnodes}",
f"--nproc_per_node={nproc_per_node}",
f"--rdzv_backend=etcd",
f"--rdzv_endpoint={self._etcd_endpoint}",
f"--rdzv_id={run_id}",
f"--monitor_interval=1",
f"--start_method=fork",
f"--no_python",
]

script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]

with self.assertRaises(ValueError):
# --no_python also requires --use_env
launch.main(args + script_args)

with self.assertRaises(ValueError):
# --no_python cannot be used with --module
launch.main(args + ["--module"] + script_args)

launch.main(args + ["--use_env"] + script_args)

# make sure all the workers ran
# each worker touches a file with its global rank as the name
self.assertSetEqual(
{str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
)

def test_launch_elastic(self):
run_id = str(uuid.uuid4().int)
min_nodes = 1
max_nodes = 2
nproc_per_node = 4
# we are only launching 1 node (even though max = 2)
world_size = nproc_per_node
args = [
f"--nnodes={min_nodes}:{max_nodes}",
f"--nproc_per_node={nproc_per_node}",
f"--rdzv_backend=etcd",
f"--rdzv_endpoint={self._etcd_endpoint}",
f"--rdzv_id={run_id}",
f"--monitor_interval=1",
f"--start_method=fork",
path("bin/test_script.py"),
f"--touch_file_dir={self.test_dir}",
]
launch.main(args)

# make sure all the workers ran
# each worker touches a file with its global rank as the name
self.assertSetEqual(
{str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
)
2 changes: 1 addition & 1 deletion torchelastic/agent/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torchelastic.rendezvous as rdzv


DEFAULT_ROLE = "trainer"
DEFAULT_ROLE = "default"

log = logging.getLogger(__name__)

Expand Down
Loading