Skip to content

Commit

Permalink
Merge pull request #173 from zuoxingdong/step_info_trajectory
Browse files Browse the repository at this point in the history
update run_experiment: support run serially
  • Loading branch information
zuoxingdong committed May 8, 2019
2 parents 17b375f + 08abc63 commit 5e2f6b6
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion baselines/cem/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def run(config, seed, device, logdir):
config=config,
seeds=[1770966829, 1500925526, 2054191100],
log_dir='logs/default',
max_workers=5,
max_workers=None, # no parallelization
chunksize=1,
use_gpu=False,
gpu_ids=None)
11 changes: 7 additions & 4 deletions lagom/experiment/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def run_experiment(run, config, seeds, log_dir, max_workers, chunksize=1, use_gp
config (Config): a :class:`Config` object defining all configuration settings
seeds (list): a list of random seeds
log_dir (str): a string to indicate the path to store loggings.
max_workers (int): argument for ProcessPoolExecutor.
max_workers (int): argument for ProcessPoolExecutor. if `None`, then all experiments run serially.
chunksize (int): argument for Executor.map()
use_gpu (bool): if `True`, then use CUDA. Otherwise, use CPU.
gpu_ids (list): if `None`, then use all available GPUs. Otherwise, only use the
Expand Down Expand Up @@ -124,7 +124,10 @@ def _run(args):
result = run(config, seed, device, logdir)
return result

with ProcessPoolExecutor(max_workers=max_workers) as executor:
args = list(product(configs, seeds))
results = list(executor.map(CloudpickleWrapper(_run), args, chunksize=chunksize))
args = list(product(configs, seeds))
if max_workers is None:
results = [_run(x) for x in args]
else:
with ProcessPoolExecutor(max_workers=max_workers) as executor:
results = list(executor.map(CloudpickleWrapper(_run), args, chunksize=chunksize))
return results
2 changes: 1 addition & 1 deletion test/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_config(num_sample, keep_dict_order):


@pytest.mark.parametrize('num_sample', [1, 5])
@pytest.mark.parametrize('max_workers', [1, 5, 100])
@pytest.mark.parametrize('max_workers', [None, 1, 5, 100])
@pytest.mark.parametrize('chunksize', [1, 7, 40])
def test_run_experiment(num_sample, max_workers, chunksize):
def run(config, seed, device, logdir):
Expand Down

0 comments on commit 5e2f6b6

Please sign in to comment.