Skip to content

Commit

Permalink
Allow multiple config files when cloud training.
Browse files Browse the repository at this point in the history
  • Loading branch information
IanTayler authored and vierja committed Oct 20, 2017
1 parent d6324b6 commit 01e82d3
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions luminoth/tools/cloud/gcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,15 @@ def cloud_service(credentials, service, version='v1'):
@click.option('--bucket', 'bucket_name', help='Where to save models and logs.') # noqa
@click.option('--region', default='us-central1', help='Region in which to run the job.') # noqa
@click.option('--dataset', required=True, help='Bucket where the dataset is located.') # noqa
@click.option('--config', help='Path to config to use in training.')
@click.option('config_files', '--config', '-c', multiple=True, required=True, help='Path to config to use in training.') # noqa
@click.option('--scale-tier', default=DEFAULT_SCALE_TIER, type=click.Choice(SCALE_TIERS)) # noqa
@click.option('--master-type', default=DEFAULT_MASTER_TYPE, type=click.Choice(MACHINE_TYPES)) # noqa
@click.option('--worker-type', default=DEFAULT_WORKER_TYPE, type=click.Choice(MACHINE_TYPES)) # noqa
@click.option('--worker-count', default=DEFAULT_WORKER_COUNT, type=int)
@click.option('--parameter-server-type', default=DEFAULT_PS_TYPE, type=click.Choice(MACHINE_TYPES)) # noqa
@click.option('--parameter-server-count', default=DEFAULT_PS_COUNT, type=int)
def train(job_id, service_account_json, bucket_name, region, config, dataset,
scale_tier, master_type, worker_type, worker_count,
def train(job_id, service_account_json, bucket_name, region, config_files,
dataset, scale_tier, master_type, worker_type, worker_count,
parameter_server_type, parameter_server_count):

args = []
Expand Down Expand Up @@ -195,7 +195,7 @@ def train(job_id, service_account_json, bucket_name, region, config, dataset,
'--override', 'dataset.data_augmentation=false'
])

if config:
for config in config_files:
# Upload config file to be used by the training job.
path = upload_file(bucket, base_path, config)
args.extend(['--config', 'gs://{}/{}'.format(bucket_name, path)])
Expand Down

0 comments on commit 01e82d3

Please sign in to comment.