Skip to content

Commit

Permalink
Issue #18 fix
Browse files Browse the repository at this point in the history
  • Loading branch information
larshertel committed Jan 12, 2019
1 parent c9b3053 commit 99de04f
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 19 deletions.
7 changes: 3 additions & 4 deletions sherpa/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,6 @@ def keras_callback(self, trial, objective_name, context_names=[]):
return keras.callbacks.LambdaCallback(on_epoch_end=send_call)



class _Runner(object):
"""
Encapsulates all functionality needed to run a Study in parallel.
Expand All @@ -405,8 +404,8 @@ class _Runner(object):
scheduler (sherpa.schedulers.Scheduler): a scheduler object.
database (sherpa.database._Database): the database.
max_concurrent (int): how many trials to run in parallel.
command (str): the command that runs a trial script e.g.
"python train_nn.py".
command (list[str]): components of the command that runs a trial script
e.g. ["python", "train_nn.py"].
resubmit_failed_trials (bool): whether a failed trial should be
resubmitted.
Expand Down Expand Up @@ -617,7 +616,7 @@ def optimize(parameters, algorithm, lower_is_better,
scheduler=scheduler,
database=db,
max_concurrent=max_concurrent,
command=' '.join(['python', filename]),
command=['python', filename],
resubmit_failed_trials=resubmit_failed_trials)
runner.run_loop()
return study.get_best_result()
Expand Down
8 changes: 4 additions & 4 deletions sherpa/schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def submit_job(self, command, env={}, job_name=''):
Submits a job to the scheduler.
Args:
command (str): the command to run by the scheduler e.g.
``python train.py``
command (list[str]): components to the command to run by the
scheduler e.g. ``["python", "train.py"]``
env (dict): environment variables to pass to the job.
job_name (str): this specifies a name for the job and its output
directory.
Expand Down Expand Up @@ -105,7 +105,7 @@ def __init__(self, submit_options='', output_dir=''):
def submit_job(self, command, env={}, job_name=''):
env.update(os.environ.copy())
optns = self.submit_options.split(' ') if self.submit_options else []
process = subprocess.Popen(optns + command.split(' '), env=env)
process = subprocess.Popen(optns + command, env=env)
self.jobs[process.pid] = process
return process.pid

Expand Down Expand Up @@ -181,7 +181,7 @@ def submit_job(self, command, env={}, job_name=''):
job_script += 'echo "Running from" ${HOSTNAME}\n'
for var_name, var_value in env.items():
job_script += 'export {}={}\n'.format(var_name, var_value)
job_script += command # 'python file.py args...'
job_script += " ".join(command) # 'python file.py args...'

# Submit command to SGE.
# Note: submitting job using drmaa didn't work because we weren't able
Expand Down
4 changes: 2 additions & 2 deletions tests/test_schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_local_scheduler(test_dir):

s = sherpa.schedulers.LocalScheduler()

job_id = s.submit_job("python {}/test.py".format(test_dir),
job_id = s.submit_job(["python", "{}/test.py".format(test_dir)],
env={'SHERPA_TRIAL_ID': '3'})

assert s.get_status(job_id) == sherpa.schedulers._JobStatus.running
Expand All @@ -101,7 +101,7 @@ def test_local_scheduler(test_dir):
testlogger.debug(s.get_status(job_id))
assert s.get_status(job_id) == sherpa.schedulers._JobStatus.finished

job_id = s.submit_job("python {}/test.py".format(test_dir))
job_id = s.submit_job(["python", "{}/test.py".format(test_dir)])
time.sleep(1)
s.kill_job(job_id)
time.sleep(1)
Expand Down
18 changes: 9 additions & 9 deletions tests/test_sherpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def get_test_study():
mock_algorithm.get_suggestion.return_value = {'a': 1, 'b': 2}
mock_stopping_rule = mock.MagicMock()

s = sherpa.Study(parameters=get_test_parameters(),
s = sherpa.Study(parameters=list(get_test_parameters()),
algorithm=mock_algorithm,
stopping_rule=mock_stopping_rule,
lower_is_better=True)
Expand All @@ -194,7 +194,7 @@ def test_runner_update_results():

r = sherpa.core._Runner(study=get_test_study(), scheduler=mock.MagicMock(),
database=mock_db, max_concurrent=1,
command="python test.py")
command=["python", "test.py"])

# new trial
t = get_test_trial()
Expand All @@ -220,7 +220,7 @@ def test_runner_update_active_trials():

r = sherpa.core._Runner(study=mock_study, scheduler=mock_scheduler,
database=mock.MagicMock(), max_concurrent=1,
command="python test.py")
command=["python", "test.py"])

t = get_test_trial()
r._all_trials[t.id] = {'trial': t, 'job_id': None}
Expand All @@ -242,7 +242,7 @@ def test_runner_stop_bad_performers():
scheduler=mock.MagicMock(),
database=mock.MagicMock(),
max_concurrent=1,
command="python test.py")
command=["python", "test.py"])

# setup
t = get_test_trial()
Expand All @@ -262,7 +262,7 @@ def test_runner_stop_bad_performers():

def test_runner_submit_new_trials():
mock_scheduler = mock.MagicMock()
mock_scheduler.submit.side_effect = ['job1', 'job2', 'job3']
mock_scheduler.submit_job.side_effect = ['job1', 'job2', 'job3']
mock_study = mock.MagicMock()
mock_study.get_suggestion.side_effect = [get_test_trial(1),
get_test_trial(2),
Expand All @@ -272,12 +272,12 @@ def test_runner_submit_new_trials():
scheduler=mock_scheduler,
database=mock.MagicMock(),
max_concurrent=3,
command="python test.py")
command=["python", "test.py"])

r.submit_new_trials()

mock_scheduler.submit.has_calls([mock.call("python test.py"),
mock.call("python test.py"),
mock.call("python test.py")])
mock_scheduler.submit_job.has_calls([mock.call(["python", "test.py"]),
mock.call(["python", "test.py"]),
mock.call(["python", "test.py"])])
assert len(r._active_trials) == 3
assert len(r._all_trials) == 3

0 comments on commit 99de04f

Please sign in to comment.