Skip to content
This repository was archived by the owner on Sep 26, 2020. It is now read-only.

Commit 7229fa8

Browse files
committed
Add {up/down}loading of (un)trained models
1 parent 7123ded commit 7229fa8

File tree

1 file changed

+74
-30
lines changed

1 file changed

+74
-30
lines changed

axon/client.py

Lines changed: 74 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -577,32 +577,60 @@ def impl_get_task_ip(cluster_name, task_arn, region):
577577
return nics[0]["Association"]["PublicIp"]
578578

579579

580-
def impl_upload_model_file(model_name, bucket_name, region):
580+
def impl_upload_untrained_model_file(model_path, bucket_name, region):
581581
"""
582-
Uploads a model to S3.
582+
Uploads an untrained model to S3.
583583
584-
:param model_name: The filename of the model to upload (must be in the current directory).
584+
:param model_path: The file path to the model to upload, ending with the name of the model.
585585
:param bucket_name: The S3 bucket name.
586586
:param region: The region, or `None` to pull the region from the environment.
587587
"""
588588
client = make_client("s3", region)
589-
remote_path = "axon-uploaded-trained-models/" + os.path.basename(model_name)
590-
client.upload_file(model_name, bucket_name, remote_path)
591-
print("Uploaded to: {}\n".format(remote_path))
589+
key = "axon-untrained-models/" + os.path.basename(model_path)
590+
client.upload_file(model_path, bucket_name, key)
591+
print("Uploaded to: {}\n".format(key))
592592

593593

594-
def impl_download_model_file(model_name, bucket_name, region):
594+
def impl_download_untrained_model_file(model_path, bucket_name, region):
595595
"""
596-
Downloads a model from S3.
596+
Downloads an untrained model from S3.
597597
598-
:param model_name: The filename of the model to download (must be in the current directory).
598+
:param model_path: The file path to download to, ending with the name of the model.
599599
:param bucket_name: The S3 bucket name.
600600
:param region: The region, or `None` to pull the region from the environment.
601601
"""
602602
client = make_client("s3", region)
603-
remote_path = "axon-uploaded-trained-models/" + os.path.basename(model_name)
604-
client.download_file(bucket_name, remote_path, model_name)
605-
print("Downloaded from: {}\n".format(remote_path))
603+
key = "axon-untrained-models/" + os.path.basename(model_path)
604+
client.download_file(bucket_name, key, model_path)
605+
print("Downloaded from: {}\n".format(key))
606+
607+
608+
def impl_upload_trained_model_file(model_path, bucket_name, region):
609+
"""
610+
Uploads an trained model to S3.
611+
612+
:param model_path: The file path to the model to upload, ending with the name of the model.
613+
:param bucket_name: The S3 bucket name.
614+
:param region: The region, or `None` to pull the region from the environment.
615+
"""
616+
client = make_client("s3", region)
617+
key = "axon-trained-models/" + os.path.basename(model_path)
618+
client.upload_file(model_path, bucket_name, key)
619+
print("Uploaded to: {}\n".format(key))
620+
621+
622+
def impl_download_trained_model_file(model_path, bucket_name, region):
623+
"""
624+
Downloads an trained model from S3.
625+
626+
:param model_path: The file path to download to, ending with the name of the model.
627+
:param bucket_name: The S3 bucket name.
628+
:param region: The region, or `None` to pull the region from the environment.
629+
"""
630+
client = make_client("s3", region)
631+
key = "axon-trained-models/" + os.path.basename(model_path)
632+
client.download_file(bucket_name, key, model_path)
633+
print("Downloaded from: {}\n".format(key))
606634

607635

608636
def impl_download_training_script(script_name, bucket_name, region):
@@ -614,7 +642,7 @@ def impl_download_training_script(script_name, bucket_name, region):
614642
:param region: The region, or `None` to pull the region from the environment.
615643
"""
616644
client = make_client("s3", region)
617-
remote_path = "axon-uploaded-training-scripts/" + os.path.basename(script_name)
645+
remote_path = "axon-training-scripts/" + os.path.basename(script_name)
618646
client.download_file(bucket_name, remote_path, script_name)
619647
print("Downloaded from: {}\n".format(remote_path))
620648

@@ -682,7 +710,7 @@ def cli():
682710
@click.argument("task-family")
683711
@click.option("--revision", default=None,
684712
help="The revision of the task. Set to None to use the latest revision.")
685-
@click.option("--region", default="us-east-1", help="The region to connect to.")
713+
@click.option("--region", help="The region to connect to.")
686714
def start_axon(cluster_name, task_family, revision, region):
687715
impl_ensure_configuration(cluster_name, task_family, region)
688716
task_arn = impl_start_task(cluster_name, task_family, revision, region)
@@ -698,7 +726,7 @@ def start_axon(cluster_name, task_family, revision, region):
698726
@cli.command(name="ensure-configuration")
699727
@click.argument("cluster-name")
700728
@click.argument("task-family")
701-
@click.option("--region", default="us-east-1", help="The region to connect to.")
729+
@click.option("--region", help="The region to connect to.")
702730
def ensure_configuration(cluster_name, task_family, region):
703731
impl_ensure_configuration(cluster_name, task_family, region)
704732

@@ -708,7 +736,7 @@ def ensure_configuration(cluster_name, task_family, region):
708736
@click.argument("task-family")
709737
@click.option("--revision", default=None,
710738
help="The revision of the task. Set to None to use the latest revision.")
711-
@click.option("--region", default="us-east-1", help="The region to connect to.")
739+
@click.option("--region", help="The region to connect to.")
712740
@click.option("--stop-after/--no-stop-after", default=False,
713741
help="Whether to stop the task immediately after creating it.")
714742
def start_task(cluster_name, task_family, revision, region, stop_after):
@@ -726,55 +754,71 @@ def start_task(cluster_name, task_family, revision, region, stop_after):
726754
@cli.command(name="stop-task")
727755
@click.argument("cluster-name")
728756
@click.argument("task")
729-
@click.option("--region", default="us-east-1", help="The region to connect to.")
757+
@click.option("--region", help="The region to connect to.")
730758
def stop_task(cluster_name, task, region):
731759
impl_stop_task(cluster_name, task, region)
732760

733761

734762
@cli.command(name="get-container-ip")
735763
@click.argument("cluster-name")
736764
@click.argument("task")
737-
@click.option("--region", default="us-east-1", help="The region to connect to.")
765+
@click.option("--region", help="The region to connect to.")
738766
def get_container_ip(cluster_name, task, region):
739767
print(impl_get_task_ip(cluster_name, task, region))
740768

741769

742-
@cli.command(name="upload-model-file")
770+
@cli.command(name="upload-untrained-model-file")
771+
@click.argument("model-name")
772+
@click.argument("bucket-name")
773+
@click.option("--region", help="The region to connect to.")
774+
def upload_untrained_model_file(model_name, bucket_name, region):
775+
impl_upload_untrained_model_file(model_name, bucket_name, region)
776+
777+
778+
@cli.command(name="download-untrained-model-file")
779+
@click.argument("model-name")
780+
@click.argument("bucket-name")
781+
@click.option("--region", help="The region to connect to.")
782+
def download_untrained_model_file(model_name, bucket_name, region):
783+
impl_download_untrained_model_file(model_name, bucket_name, region)
784+
785+
786+
@cli.command(name="upload-trained-model-file")
743787
@click.argument("model-name")
744788
@click.argument("bucket-name")
745-
@click.option("--region", default="us-east-1", help="The region to connect to.")
746-
def upload_model_file(model_name, bucket_name, region):
747-
impl_upload_model_file(model_name, bucket_name, region)
789+
@click.option("--region", help="The region to connect to.")
790+
def upload_trained_model_file(model_name, bucket_name, region):
791+
impl_upload_trained_model_file(model_name, bucket_name, region)
748792

749793

750-
@cli.command(name="download-model-file")
794+
@cli.command(name="download-trained-model-file")
751795
@click.argument("model-name")
752796
@click.argument("bucket-name")
753-
@click.option("--region", default="us-east-1", help="The region to connect to.")
754-
def download_model_file(model_name, bucket_name, region):
755-
impl_download_model_file(model_name, bucket_name, region)
797+
@click.option("--region", help="The region to connect to.")
798+
def download_trained_model_file(model_name, bucket_name, region):
799+
impl_download_trained_model_file(model_name, bucket_name, region)
756800

757801

758802
@cli.command(name="download-training-script")
759803
@click.argument("script-name")
760804
@click.argument("bucket-name")
761-
@click.option("--region", default="us-east-1", help="The region to connect to.")
805+
@click.option("--region", help="The region to connect to.")
762806
def download_training_script(script_name, bucket_name, region):
763807
impl_download_training_script(script_name, bucket_name, region)
764808

765809

766810
@cli.command(name="download-dataset")
767811
@click.argument("dataset-name")
768812
@click.argument("bucket-name")
769-
@click.option("--region", default="us-east-1", help="The region to connect to.")
813+
@click.option("--region", help="The region to connect to.")
770814
def download_dataset(dataset_name, bucket_name, region):
771815
impl_download_dataset(dataset_name, bucket_name, region)
772816

773817

774818
@cli.command(name="upload-dataset")
775819
@click.argument("dataset-name")
776820
@click.argument("bucket-name")
777-
@click.option("--region", default="us-east-1", help="The region to connect to.")
821+
@click.option("--region", help="The region to connect to.")
778822
def upload_dataset(dataset_name, bucket_name, region):
779823
impl_upload_dataset(dataset_name, bucket_name, region)
780824

@@ -784,6 +828,6 @@ def upload_dataset(dataset_name, bucket_name, region):
784828
@click.argument("dataset-name")
785829
@click.argument("progress-text")
786830
@click.argument("bucket-name")
787-
@click.option("--region", default="us-east-1", help="The region to connect to.")
831+
@click.option("--region", help="The region to connect to.")
788832
def update_training_progress(model_name, dataset_name, progress_text, bucket_name, region):
789833
impl_update_training_progress(model_name, dataset_name, progress_text, bucket_name, region)

0 commit comments

Comments
 (0)