Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…iments into data_egress
  • Loading branch information
michaelzhiluo committed Dec 3, 2021
2 parents 9577b4d + 4745a16 commit 16ce4da
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 4 deletions.
61 changes: 61 additions & 0 deletions prototype/examples/horovod_distributed_tf_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import json
from typing import Dict, List

import sky
import time_estimators
from sky import clouds

IPAddr = str

with sky.Dag() as dag:
# Total Nodes, INCLUDING Head Node
num_nodes = 2

# The setup command. Will be run under the working directory.
setup = 'pip3 install --upgrade pip && \
pip3 install ray[default] && \
git clone https://github.com/michaelzhiluo/horovod-tf-resnet.git && \
cd horovod-tf-resnet && chmod +x setup.sh && ./setup.sh'

# Post setup function. Run after `ray up *.yml` completes. Returns dictionary of commands to be run on each corresponding node.
# List of IPs, 0th index denoting head worker
def post_setup_fn(ip_list: List[IPAddr]) -> Dict[IPAddr, str]:
command_dict = {}
head_run_str = "ssh-keygen -f ~/.ssh/id_rsa -P \"\" <<< y"
if len(ip_list) > 1:
for i, ip in enumerate(ip_list[1:]):
append_str = f" && cat ~/.ssh/id_rsa.pub | ssh -i ~/ray_bootstrap_key.pem -o StrictHostKeyChecking=no ubuntu@{ip} \"mkdir -p ~/.ssh && chmod 700 ~/.ssh && cat >> ~/.ssh/authorized_keys && chmod 600 ~/.ssh/authorized_keys\""
head_run_str = head_run_str + append_str
return {ip_list[0]: head_run_str}

# The command to run. Will be run under the working directory.
def run_fn(ip_list: List[IPAddr]) -> Dict[IPAddr, str]:
run_dict = {}
ip_str = "localhost:1"
for i, ip in enumerate(ip_list[1:]):
append_str = f",{ip}:1"
ip_str = ip_str + append_str
return {
ip_list[0]: f"cd horovod-tf-resnet && \
horovodrun -np {str(len(ip_list))} -H {ip_str} python3 horovod_mnist.py",
}

run = run_fn

train = sky.Task(
'train',
setup=setup,
post_setup_fn=post_setup_fn,
num_nodes=num_nodes,
run=run,
)

train.set_inputs('gs://cloud-tpu-test-datasets/fake_imagenet',
estimated_size_gigabytes=70)
train.set_outputs('resnet-model-dir', estimated_size_gigabytes=0.1)
train.set_resources({
sky.Resources(clouds.AWS(), 'p3.2xlarge'),
})

dag = sky.Optimizer.optimize(dag)
sky.execute(dag)
10 changes: 6 additions & 4 deletions prototype/sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,10 +520,12 @@ def run_post_setup(self, handle: ResourceHandle, post_setup_fn: PostSetupFn,
ip_list = self._get_node_ips(handle, task.num_nodes)
ip_to_command = post_setup_fn(ip_list)
for ip, cmd in ip_to_command.items():
cmd = (f'mkdir -p {SKY_REMOTE_WORKDIR} && '
f'cd {SKY_REMOTE_WORKDIR} && {cmd}')
backend_utils.run_command_on_ip_via_ssh(ip, cmd, task.private_key,
task.container_name)
if cmd is not None:
cmd = (f'mkdir -p {SKY_REMOTE_WORKDIR} && '
f'cd {SKY_REMOTE_WORKDIR} && {cmd}')
backend_utils.run_command_on_ip_via_ssh(ip, cmd,
task.private_key,
task.container_name)

def _execute_par_task(self, handle: ResourceHandle,
par_task: task_mod.ParTask,
Expand Down

0 comments on commit 16ce4da

Please sign in to comment.