Skip to content
This repository was archived by the owner on Sep 26, 2020. It is now read-only.

Commit b032da7

Browse files
committed
Add update_training_progress
1 parent 102f697 commit b032da7

File tree

1 file changed

+70
-35
lines changed

1 file changed

+70
-35
lines changed

axon/client.py

Lines changed: 70 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import json
2+
import tempfile
3+
24
import click
35
import boto3
46
import ipify
@@ -473,76 +475,99 @@ def impl_get_task_ip(cluster_name, task_arn, region):
473475
return nics[0]["Association"]["PublicIp"]
474476

475477

476-
def impl_upload_model_file(local_file_path, bucket_name, region):
478+
def impl_upload_model_file(model_name, bucket_name, region):
477479
"""
478480
Uploads a model to S3.
479481
480-
:param local_file_path: The path to the model file on disk.
482+
:param model_name: The filename of the model to upload (must be in the current directory).
481483
:param bucket_name: The S3 bucket name.
482484
:param region: The region, or `None` to pull the region from the environment.
483485
"""
484486
client = make_client("s3", region)
485-
remote_path = "axon-uploaded-trained-models/" + os.path.basename(local_file_path)
486-
client.upload_file(local_file_path, bucket_name, remote_path)
487+
remote_path = "axon-uploaded-trained-models/" + os.path.basename(model_name)
488+
client.upload_file(model_name, bucket_name, remote_path)
487489
print("Uploaded to: {}\n".format(remote_path))
488490

489491

490-
def impl_download_model_file(local_file_path, bucket_name, region):
492+
def impl_download_model_file(model_name, bucket_name, region):
491493
"""
492494
Downloads a model from S3.
493495
494-
:param local_file_path: The path to the model file on disk.
496+
:param model_name: The filename of the model to download (must be in the current directory).
495497
:param bucket_name: The S3 bucket name.
496498
:param region: The region, or `None` to pull the region from the environment.
497499
"""
498500
client = make_client("s3", region)
499-
remote_path = "axon-uploaded-trained-models/" + os.path.basename(local_file_path)
500-
client.download_file(bucket_name, remote_path, local_file_path)
501+
remote_path = "axon-uploaded-trained-models/" + os.path.basename(model_name)
502+
client.download_file(bucket_name, remote_path, model_name)
501503
print("Downloaded from: {}\n".format(remote_path))
502504

503505

504-
def impl_download_training_script(local_script_path, bucket_name, region):
506+
def impl_download_training_script(script_name, bucket_name, region):
505507
"""
506508
Downloads a training script from S3.
507509
508-
:param local_script_path: The path to the training script on disk.
510+
:param script_name: The filename of the script to download (must be in the current directory).
509511
:param bucket_name: The S3 bucket name.
510512
:param region: The region, or `None` to pull the region from the environment.
511513
"""
512514
client = make_client("s3", region)
513-
remote_path = "axon-uploaded-training-scripts/" + os.path.basename(local_script_path)
514-
client.download_file(bucket_name, remote_path, local_script_path)
515+
remote_path = "axon-uploaded-training-scripts/" + os.path.basename(script_name)
516+
client.download_file(bucket_name, remote_path, script_name)
515517
print("Downloaded from: {}\n".format(remote_path))
516518

517519

518-
def impl_upload_dataset(local_dataset_path, bucket_name, region):
520+
def impl_upload_dataset(dataset_name, bucket_name, region):
519521
"""
520522
Uploads a dataset to S3.
521523
522-
:param local_dataset_path: The path to the dataset on disk.
524+
:param dataset_name: The filename of the dataset to upload (must be in the current directory).
523525
:param bucket_name: The S3 bucket name.
524526
:param region: The region, or `None` to pull the region from the environment.
525527
"""
526528
client = make_client("s3", region)
527-
remote_path = "axon-uploaded-datasets/" + os.path.basename(local_dataset_path)
528-
client.upload_file(local_dataset_path, bucket_name, remote_path)
529+
remote_path = "axon-uploaded-datasets/" + os.path.basename(dataset_name)
530+
client.upload_file(dataset_name, bucket_name, remote_path)
529531
print("Uploaded to: {}\n".format(remote_path))
530532

531533

532-
def impl_download_dataset(local_dataset_path, bucket_name, region):
534+
def impl_download_dataset(dataset_name, bucket_name, region):
533535
"""
534536
Downloads a dataset from S3.
535537
536-
:param local_dataset_path: The path to the dataset on disk.
538+
:param dataset_name: The filename of the dataset to download (must be in the current directory).
537539
:param bucket_name: The S3 bucket name.
538540
:param region: The region, or `None` to pull the region from the environment.
539541
"""
540542
client = make_client("s3", region)
541-
remote_path = "axon-uploaded-datasets/" + os.path.basename(local_dataset_path)
542-
client.download_file(bucket_name, remote_path, local_dataset_path)
543+
remote_path = "axon-uploaded-datasets/" + os.path.basename(dataset_name)
544+
client.download_file(bucket_name, remote_path, dataset_name)
543545
print("Downloaded from: {}\n".format(remote_path))
544546

545547

548+
def impl_update_training_progress(model_name, dataset_name, progress_text, bucket_name, region):
549+
"""
550+
Updates the training progress in S3 for a model specified by its name.
551+
552+
:param model_name: The filename of the model.
553+
:param dataset_name: The filename of the dataset.
554+
:param progress_text: The text to write into the progress file.
555+
:param bucket_name: The S3 bucket name.
556+
:param region: The region, or `None` to pull the region from the environment.
557+
"""
558+
local_file, path = tempfile.mkstemp()
559+
try:
560+
with open(local_file, "w") as f:
561+
f.write(progress_text)
562+
client = make_client("s3", region)
563+
remote_path = "axon-training-progress/" + os.path.basename(model_name) + "/" + \
564+
os.path.basename(dataset_name) + "/progress.txt"
565+
client.upload_file(path, bucket_name, remote_path)
566+
print("Updated progress in: {}\n".format(remote_path))
567+
finally:
568+
os.remove(path)
569+
570+
546571
@click.group()
547572
def cli():
548573
return
@@ -613,40 +638,50 @@ def get_container_ip(cluster_name, task, region):
613638

614639

615640
@cli.command(name="upload-model-file")
616-
@click.argument("local-file-path")
641+
@click.argument("model-name")
617642
@click.argument("bucket-name")
618643
@click.option("--region", default="us-east-1", help="The region to connect to.")
619-
def upload_model_file(local_file_path, bucket_name, region):
620-
impl_upload_model_file(local_file_path, bucket_name, region)
644+
def upload_model_file(model_name, bucket_name, region):
645+
impl_upload_model_file(model_name, bucket_name, region)
621646

622647

623648
@cli.command(name="download-model-file")
624-
@click.argument("local-file-path")
649+
@click.argument("model-name")
625650
@click.argument("bucket-name")
626651
@click.option("--region", default="us-east-1", help="The region to connect to.")
627-
def download_model_file(local_file_path, bucket_name, region):
628-
impl_download_model_file(local_file_path, bucket_name, region)
652+
def download_model_file(model_name, bucket_name, region):
653+
impl_download_model_file(model_name, bucket_name, region)
629654

630655

631656
@cli.command(name="download-training-script")
632-
@click.argument("local-script-path")
657+
@click.argument("script-name")
633658
@click.argument("bucket-name")
634659
@click.option("--region", default="us-east-1", help="The region to connect to.")
635-
def download_training_script(local_script_path, bucket_name, region):
636-
impl_download_training_script(local_script_path, bucket_name, region)
660+
def download_training_script(script_name, bucket_name, region):
661+
impl_download_training_script(script_name, bucket_name, region)
637662

638663

639664
@cli.command(name="download-dataset")
640-
@click.argument("local-dataset-path")
665+
@click.argument("dataset-name")
641666
@click.argument("bucket-name")
642667
@click.option("--region", default="us-east-1", help="The region to connect to.")
643-
def download_dataset(local_dataset_path, bucket_name, region):
644-
impl_download_dataset(local_dataset_path, bucket_name, region)
668+
def download_dataset(dataset_name, bucket_name, region):
669+
impl_download_dataset(dataset_name, bucket_name, region)
645670

646671

647672
@cli.command(name="upload-dataset")
648-
@click.argument("local-dataset-path")
673+
@click.argument("dataset-name")
674+
@click.argument("bucket-name")
675+
@click.option("--region", default="us-east-1", help="The region to connect to.")
676+
def upload_dataset(dataset_name, bucket_name, region):
677+
impl_upload_dataset(dataset_name, bucket_name, region)
678+
679+
680+
@cli.command(name="update-training-progress")
681+
@click.argument("model-name")
682+
@click.argument("dataset-name")
683+
@click.argument("progress-text")
649684
@click.argument("bucket-name")
650685
@click.option("--region", default="us-east-1", help="The region to connect to.")
651-
def upload_dataset(local_dataset_path, bucket_name, region):
652-
impl_upload_dataset(local_dataset_path, bucket_name, region)
686+
def update_training_progress(model_name, dataset_name, progress_text, bucket_name, region):
687+
impl_update_training_progress(model_name, dataset_name, progress_text, bucket_name, region)

0 commit comments

Comments
 (0)