Skip to content

Commit

Permalink
Merge pull request #3 from pypr/dont-change-commands
Browse files Browse the repository at this point in the history
Don't change commands
  • Loading branch information
prabhuramachandran committed Aug 30, 2018
2 parents 0dffcf1 + ec9646a commit 380cd20
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 60 deletions.
89 changes: 51 additions & 38 deletions automan/cluster_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,6 @@
from urllib.request import urlopen


def prompt(msg):
try:
return raw_input(msg)
except NameError:
return input(msg)


class ClusterManager(object):
"""The cluster manager class.
Expand Down Expand Up @@ -110,7 +103,8 @@ class ClusterManager(object):
#######################################################

def __init__(self, root='automan', sources=None,
config_fname='config.json', exclude_paths=None):
config_fname='config.json', exclude_paths=None,
testing=False):
"""Create a cluster manager instance.
**Parameters**
Expand All @@ -125,12 +119,17 @@ def __init__(self, root='automan', sources=None,
exclude_paths: list
A list of paths to exclude while syncing. This is in a form suitable
to pass to rsync.
testing: bool
Use this while testing. This allows us to run unit tests for remotes
on the local machine.
"""
self.root = root
self.workers = []
self.sources = sources
self.scripts_dir = os.path.abspath('.' + self.root)
self.exclude_paths = exclude_paths if exclude_paths else []
self.testing = testing

# This is setup by the config and is the name of
# the project directory.
Expand All @@ -144,37 +143,38 @@ def __init__(self, root='automan', sources=None,
os.makedirs(self.scripts_dir)

# ### Private Protocol ########################################

def _bootstrap(self, host, home):
venv_script = self._get_virtualenv()
base_cmd = ("cd {home}; mkdir -p {root}/envs; "
"mkdir -p {root}/{project_name}/.{root}").format(
home=home, root=self.root,
project_name=self.project_name
)
self._ssh_run_command(host, base_cmd)

cmd = ("ssh {host} 'cd {home}; mkdir -p {root}/envs'; " +
"mkdir -p {root}/{project_name}/.{root}").format(
home=home, host=host, root=self.root,
project_name=self.project_name
)
self._run_command(cmd)

root = os.path.join(home, self.root)
cmd = "scp {venv_script} {host}:{root}".format(
host=host, root=root, venv_script=venv_script
)
self._run_command(cmd)
abs_root = os.path.join(home, self.root)
if venv_script:
real_host = '' if self.testing else '{host}:'.format(host=host)
cmd = "scp {venv_script} {host}{root}".format(
host=real_host, root=abs_root, venv_script=venv_script
)
self._run_command(cmd)

self._update_sources(host, home)

cmd = "ssh {host} 'cd {root}; ./{project_name}/.{root}/bootstrap.sh'"
cmd = cmd.format(host=host, root=root, project_name=self.project_name)
cmd = "cd {abs_root}; ./{project_name}/.{root}/bootstrap.sh".format(
abs_root=abs_root, root=self.root, project_name=self.project_name
)
try:
self._run_command(cmd)
self._ssh_run_command(host, cmd)
except subprocess.CalledProcessError:
msg = dedent("""
******************************************************************
Bootstrapping of remote host {host} failed.
All files have been copied to the host.
Please take a look at
{root}/{project_name}/.{root}/bootstrap.sh
{abs_root}/{project_name}/.{root}/bootstrap.sh
and try to fix it.
You should run it from within the {root} directory as:
Expand All @@ -189,7 +189,8 @@ def _bootstrap(self, host, home):
and can be edited by you. These will be used for any new hosts
you add.
******************************************************************
""".format(root=root, host=host, scripts_dir=self.scripts_dir,
""".format(abs_root=abs_root, root=self.root, host=host,
scripts_dir=self.scripts_dir,
project_name=self.project_name)
)
print(msg)
Expand Down Expand Up @@ -224,16 +225,25 @@ def _read_config(self):
self.scripts_dir = os.path.abspath('.' + self.root)

def _rebuild(self, host, home):
root = os.path.join(home, self.root)
command = "ssh {host} 'cd {root}; ./{project_name}/.{root}/update.sh'"
command = command.format(host=host, root=root,
project_name=self.project_name)
self._run_command(command)
abs_root = os.path.join(home, self.root)
base_cmd = "cd {abs_root}; ./{project_name}/.{root}/update.sh".format(
abs_root=abs_root, root=self.root, project_name=self.project_name
)
self._ssh_run_command(host, base_cmd)

def _run_command(self, cmd, **kw):
print(cmd)
subprocess.check_call(shlex.split(cmd), **kw)

def _ssh_run_command(self, host, base_cmd):
if self.testing:
command = base_cmd
print(command)
subprocess.check_call(command, shell=True)
else:
command = "ssh {host} '{cmd}'".format(host=host, cmd=base_cmd)
self._run_command(command)

def _sync_dir(self, host, src, dest):
options = ""
kwargs = dict()
Expand All @@ -250,8 +260,9 @@ def _sync_dir(self, host, src, dest):
for path in self.exclude_paths:
options += ' --exclude="%s"' % path

command = "rsync -a {options} {src} {host}:{dest} ".format(
options=options, src=src, host=host, dest=dest
real_host = '' if self.testing else '{host}:'.format(host=host)
command = "rsync -a {options} {src} {host}{dest} ".format(
options=options, src=src, host=real_host, dest=dest
)
self._run_command(command, **kwargs)

Expand Down Expand Up @@ -279,8 +290,9 @@ def _update_sources(self, host, home):

path = os.path.join(home, self.root, self.project_name,
'.' + self.root)
cmd = "scp {script_files} {host}:{path}".format(
host=host, path=path, script_files=' '.join(script_files)
real_host = '' if self.testing else '{host}:'.format(host=host)
cmd = "scp {script_files} {host}{path}".format(
host=real_host, path=path, script_files=' '.join(script_files)
)
self._run_command(cmd)

Expand Down Expand Up @@ -346,9 +358,10 @@ def create_scheduler(self):
else:
python = worker.get('python')
chdir = worker.get('chdir')
scheduler.add_worker(
dict(host=host, python=python, chdir=chdir, nfs=nfs)
)
config = dict(host=host, python=python, chdir=chdir, nfs=nfs)
if self.testing:
config['testing'] = True
scheduler.add_worker(config)
return scheduler

def cli(self, argv=None):
Expand Down
16 changes: 7 additions & 9 deletions automan/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ def __init__(self, command, output_dir, n_core=1, n_thread=1, env=None):
"""
self.command = _make_command_list(command)
self.orig_command = self.command
self.substitute_in_command('python', sys.executable)

self._given_env = env
self.env = dict(os.environ)
if env is not None:
Expand All @@ -51,10 +48,11 @@ def __init__(self, command, output_dir, n_core=1, n_thread=1, env=None):
def substitute_in_command(self, basename, substitute):
"""Replace occurrence of given basename with the substitute.
This is useful where the user asks to run ['python', 'script.py'].
Here, we need to make sure the right python is used. Typically a remote
machine will need to use a particular Python and not just the vanilla
Python.
This is useful where the user asks to run ['python', 'script.py'] and
we wish to change the 'python' to a specific Python. Normally this is
not needed as the PATH is set to pick up the right Python. However, in
the rare cases where this rewriting is needed, this method is
available.
"""
args = []
Expand All @@ -63,7 +61,7 @@ def substitute_in_command(self, basename, substitute):
args.append(substitute)
else:
args.append(arg)
self.commands = args
self.command = args

def to_dict(self):
state = dict()
Expand All @@ -73,7 +71,7 @@ def to_dict(self):
return state

def pretty_command(self):
return ' '.join(self.orig_command)
return ' '.join(self.command)

def get_stderr(self):
return open(self.stderr).read()
Expand Down
5 changes: 0 additions & 5 deletions automan/tests/test_automation.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,11 +422,6 @@ def test_filter_cases_works_with_predicate():
class TestAutomator(TestAutomationBase):
def setUp(self):
super(TestAutomator, self).setUp()
patch = mock.patch(
'automan.cluster_manager.prompt', return_value=''
)
patch.start()
self.addCleanup(patch.stop)

@mock.patch.object(TaskRunner, 'run')
def test_automator(self, mock_run):
Expand Down
93 changes: 88 additions & 5 deletions automan/tests/test_cluster_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,47 @@

import json
import os
from os.path import dirname
import shutil
import sys
import tempfile
from textwrap import dedent
import unittest

try:
from unittest import mock
except ImportError:
import mock

from automan.jobs import Job
from automan.cluster_manager import ClusterManager
from .test_jobs import wait_until


ROOT_DIR = dirname(dirname(dirname(__file__)))


class MyClusterManager(ClusterManager):

BOOTSTRAP = dedent("""\
#!/bin/bash
set -e
python -m venv envs/{project_name}
source envs/{project_name}/bin/activate
cd %s
python -m pip install execnet psutil
python setup.py install
""" % ROOT_DIR)

UPDATE = dedent("""\
#!/bin/bash
echo "update"
""")

def _get_virtualenv(self):
return None


class TestClusterManager(unittest.TestCase):
Expand All @@ -20,11 +51,6 @@ def setUp(self):
self.cwd = os.getcwd()
self.root = tempfile.mkdtemp()
os.chdir(self.root)
patch = mock.patch(
'automan.cluster_manager.prompt', return_value=''
)
patch.start()
self.addCleanup(patch.stop)

def tearDown(self):
os.chdir(self.cwd)
Expand All @@ -44,6 +70,8 @@ def test_new_config_created_on_creation(self):
config = self._get_config()

self.assertEqual(config.get('root'), 'automan')
self.assertEqual(config.get('project_name'),
os.path.basename(self.root))
self.assertEqual(os.path.realpath(config.get('sources')[0]),
os.path.realpath(self.root))
workers = config.get('workers')
Expand Down Expand Up @@ -106,3 +134,58 @@ def test_cli(self, mock_update, mock_add_worker):

# Then
mock_add_worker.assert_called_with('host', 'home', True)

@unittest.skipIf((sys.version_info < (3, 3)) or
sys.platform.startswith('win'),
'Test requires Python 3.x and a non-Windows system.')
def test_remote_bootstrap_and_sync(self):
# Given
cm = MyClusterManager(exclude_paths=['outputs/'], testing=True)
output_dir = os.path.join(self.root, 'outputs')
os.makedirs(output_dir)

# Remove the default localhost worker.
cm.workers = []

# When
cm.add_worker('host', home=self.root, nfs=False)

# Then
self.assertEqual(len(cm.workers), 1)
worker = cm.workers[0]
self.assertEqual(worker['host'], 'host')
project_name = cm.project_name
self.assertEqual(project_name, os.path.basename(self.root))
py = os.path.join(self.root, 'automan', 'envs', project_name,
'bin', 'python')
self.assertEqual(worker['python'], py)
chdir = os.path.join(self.root, 'automan', project_name)
self.assertEqual(worker['chdir'], chdir)

# Given
cmd = ['python', '-c', 'import sys; print(sys.executable)']
job = Job(command=cmd, output_dir=output_dir)

s = cm.create_scheduler()

# When
proxy = s.submit(job)

# Then
wait_until(lambda: proxy.status() != 'done')

self.assertEqual(proxy.status(), 'done')
output = proxy.get_stdout().strip()
self.assertEqual(os.path.realpath(output), os.path.realpath(py))

# Test to see if updating works.

# When
with open(os.path.join(self.root, 'script.py'), 'w') as f:
f.write('print("hello")\n')

cm.update()

# Then
dest = os.path.join(self.root, 'automan', project_name, 'script.py')
self.assertTrue(os.path.exists(dest))
12 changes: 12 additions & 0 deletions automan/tests/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,18 @@ def test_simple_job(self):
expect['command'][0] = sys.executable
self.assertDictEqual(state, expect)

def test_job_substitute_in_command(self):
# Given
j = jobs.Job(command=['python', '-c', 'print(123)'],
output_dir=self.root)

# When
sub = '/usr/bin/python'
j.substitute_in_command('python', sub)

# Then
self.assertEqual(j.command[0], sub)

def test_job_status(self):
# Given/When
j = jobs.Job(
Expand Down

0 comments on commit 380cd20

Please sign in to comment.