|
1 | 1 | import json
|
| 2 | +import tempfile |
| 3 | + |
2 | 4 | import click
|
3 | 5 | import boto3
|
4 | 6 | import ipify
|
@@ -473,76 +475,99 @@ def impl_get_task_ip(cluster_name, task_arn, region):
|
473 | 475 | return nics[0]["Association"]["PublicIp"]
|
474 | 476 |
|
475 | 477 |
|
476 |
| -def impl_upload_model_file(local_file_path, bucket_name, region): |
| 478 | +def impl_upload_model_file(model_name, bucket_name, region): |
477 | 479 | """
|
478 | 480 | Uploads a model to S3.
|
479 | 481 |
|
480 |
| - :param local_file_path: The path to the model file on disk. |
| 482 | + :param model_name: The filename of the model to upload (must be in the current directory). |
481 | 483 | :param bucket_name: The S3 bucket name.
|
482 | 484 | :param region: The region, or `None` to pull the region from the environment.
|
483 | 485 | """
|
484 | 486 | client = make_client("s3", region)
|
485 |
| - remote_path = "axon-uploaded-trained-models/" + os.path.basename(local_file_path) |
486 |
| - client.upload_file(local_file_path, bucket_name, remote_path) |
| 487 | + remote_path = "axon-uploaded-trained-models/" + os.path.basename(model_name) |
| 488 | + client.upload_file(model_name, bucket_name, remote_path) |
487 | 489 | print("Uploaded to: {}\n".format(remote_path))
|
488 | 490 |
|
489 | 491 |
|
490 |
| -def impl_download_model_file(local_file_path, bucket_name, region): |
| 492 | +def impl_download_model_file(model_name, bucket_name, region): |
491 | 493 | """
|
492 | 494 | Downloads a model from S3.
|
493 | 495 |
|
494 |
| - :param local_file_path: The path to the model file on disk. |
| 496 | + :param model_name: The filename of the model to download (must be in the current directory). |
495 | 497 | :param bucket_name: The S3 bucket name.
|
496 | 498 | :param region: The region, or `None` to pull the region from the environment.
|
497 | 499 | """
|
498 | 500 | client = make_client("s3", region)
|
499 |
| - remote_path = "axon-uploaded-trained-models/" + os.path.basename(local_file_path) |
500 |
| - client.download_file(bucket_name, remote_path, local_file_path) |
| 501 | + remote_path = "axon-uploaded-trained-models/" + os.path.basename(model_name) |
| 502 | + client.download_file(bucket_name, remote_path, model_name) |
501 | 503 | print("Downloaded from: {}\n".format(remote_path))
|
502 | 504 |
|
503 | 505 |
|
504 |
| -def impl_download_training_script(local_script_path, bucket_name, region): |
| 506 | +def impl_download_training_script(script_name, bucket_name, region): |
505 | 507 | """
|
506 | 508 | Downloads a training script from S3.
|
507 | 509 |
|
508 |
| - :param local_script_path: The path to the training script on disk. |
| 510 | + :param script_name: The filename of the script to download (must be in the current directory). |
509 | 511 | :param bucket_name: The S3 bucket name.
|
510 | 512 | :param region: The region, or `None` to pull the region from the environment.
|
511 | 513 | """
|
512 | 514 | client = make_client("s3", region)
|
513 |
| - remote_path = "axon-uploaded-training-scripts/" + os.path.basename(local_script_path) |
514 |
| - client.download_file(bucket_name, remote_path, local_script_path) |
| 515 | + remote_path = "axon-uploaded-training-scripts/" + os.path.basename(script_name) |
| 516 | + client.download_file(bucket_name, remote_path, script_name) |
515 | 517 | print("Downloaded from: {}\n".format(remote_path))
|
516 | 518 |
|
517 | 519 |
|
518 |
| -def impl_upload_dataset(local_dataset_path, bucket_name, region): |
| 520 | +def impl_upload_dataset(dataset_name, bucket_name, region): |
519 | 521 | """
|
520 | 522 | Uploads a dataset to S3.
|
521 | 523 |
|
522 |
| - :param local_dataset_path: The path to the dataset on disk. |
| 524 | + :param dataset_name: The filename of the dataset to upload (must be in the current directory). |
523 | 525 | :param bucket_name: The S3 bucket name.
|
524 | 526 | :param region: The region, or `None` to pull the region from the environment.
|
525 | 527 | """
|
526 | 528 | client = make_client("s3", region)
|
527 |
| - remote_path = "axon-uploaded-datasets/" + os.path.basename(local_dataset_path) |
528 |
| - client.upload_file(local_dataset_path, bucket_name, remote_path) |
| 529 | + remote_path = "axon-uploaded-datasets/" + os.path.basename(dataset_name) |
| 530 | + client.upload_file(dataset_name, bucket_name, remote_path) |
529 | 531 | print("Uploaded to: {}\n".format(remote_path))
|
530 | 532 |
|
531 | 533 |
|
532 |
| -def impl_download_dataset(local_dataset_path, bucket_name, region): |
| 534 | +def impl_download_dataset(dataset_name, bucket_name, region): |
533 | 535 | """
|
534 | 536 | Downloads a dataset from S3.
|
535 | 537 |
|
536 |
| - :param local_dataset_path: The path to the dataset on disk. |
| 538 | + :param dataset_name: The filename of the dataset to download (must be in the current directory). |
537 | 539 | :param bucket_name: The S3 bucket name.
|
538 | 540 | :param region: The region, or `None` to pull the region from the environment.
|
539 | 541 | """
|
540 | 542 | client = make_client("s3", region)
|
541 |
| - remote_path = "axon-uploaded-datasets/" + os.path.basename(local_dataset_path) |
542 |
| - client.download_file(bucket_name, remote_path, local_dataset_path) |
| 543 | + remote_path = "axon-uploaded-datasets/" + os.path.basename(dataset_name) |
| 544 | + client.download_file(bucket_name, remote_path, dataset_name) |
543 | 545 | print("Downloaded from: {}\n".format(remote_path))
|
544 | 546 |
|
545 | 547 |
|
| 548 | +def impl_update_training_progress(model_name, dataset_name, progress_text, bucket_name, region): |
| 549 | + """ |
| 550 | + Updates the training progress in S3 for a model specified by its name. |
| 551 | +
|
| 552 | + :param model_name: The filename of the model. |
| 553 | + :param dataset_name: The filename of the dataset. |
| 554 | + :param progress_text: The text to write into the progress file. |
| 555 | + :param bucket_name: The S3 bucket name. |
| 556 | + :param region: The region, or `None` to pull the region from the environment. |
| 557 | + """ |
| 558 | + local_file, path = tempfile.mkstemp() |
| 559 | + try: |
| 560 | + with open(local_file, "w") as f: |
| 561 | + f.write(progress_text) |
| 562 | + client = make_client("s3", region) |
| 563 | + remote_path = "axon-training-progress/" + os.path.basename(model_name) + "/" + \ |
| 564 | + os.path.basename(dataset_name) + "/progress.txt" |
| 565 | + client.upload_file(path, bucket_name, remote_path) |
| 566 | + print("Updated progress in: {}\n".format(remote_path)) |
| 567 | + finally: |
| 568 | + os.remove(path) |
| 569 | + |
| 570 | + |
546 | 571 | @click.group()
|
547 | 572 | def cli():
|
548 | 573 | return
|
@@ -613,40 +638,50 @@ def get_container_ip(cluster_name, task, region):
|
613 | 638 |
|
614 | 639 |
|
615 | 640 | @cli.command(name="upload-model-file")
|
616 |
| -@click.argument("local-file-path") |
| 641 | +@click.argument("model-name") |
617 | 642 | @click.argument("bucket-name")
|
618 | 643 | @click.option("--region", default="us-east-1", help="The region to connect to.")
|
619 |
| -def upload_model_file(local_file_path, bucket_name, region): |
620 |
| - impl_upload_model_file(local_file_path, bucket_name, region) |
| 644 | +def upload_model_file(model_name, bucket_name, region): |
| 645 | + impl_upload_model_file(model_name, bucket_name, region) |
621 | 646 |
|
622 | 647 |
|
623 | 648 | @cli.command(name="download-model-file")
|
624 |
| -@click.argument("local-file-path") |
| 649 | +@click.argument("model-name") |
625 | 650 | @click.argument("bucket-name")
|
626 | 651 | @click.option("--region", default="us-east-1", help="The region to connect to.")
|
627 |
| -def download_model_file(local_file_path, bucket_name, region): |
628 |
| - impl_download_model_file(local_file_path, bucket_name, region) |
| 652 | +def download_model_file(model_name, bucket_name, region): |
| 653 | + impl_download_model_file(model_name, bucket_name, region) |
629 | 654 |
|
630 | 655 |
|
631 | 656 | @cli.command(name="download-training-script")
|
632 |
| -@click.argument("local-script-path") |
| 657 | +@click.argument("script-name") |
633 | 658 | @click.argument("bucket-name")
|
634 | 659 | @click.option("--region", default="us-east-1", help="The region to connect to.")
|
635 |
| -def download_training_script(local_script_path, bucket_name, region): |
636 |
| - impl_download_training_script(local_script_path, bucket_name, region) |
| 660 | +def download_training_script(script_name, bucket_name, region): |
| 661 | + impl_download_training_script(script_name, bucket_name, region) |
637 | 662 |
|
638 | 663 |
|
639 | 664 | @cli.command(name="download-dataset")
|
640 |
| -@click.argument("local-dataset-path") |
| 665 | +@click.argument("dataset-name") |
641 | 666 | @click.argument("bucket-name")
|
642 | 667 | @click.option("--region", default="us-east-1", help="The region to connect to.")
|
643 |
| -def download_dataset(local_dataset_path, bucket_name, region): |
644 |
| - impl_download_dataset(local_dataset_path, bucket_name, region) |
| 668 | +def download_dataset(dataset_name, bucket_name, region): |
| 669 | + impl_download_dataset(dataset_name, bucket_name, region) |
645 | 670 |
|
646 | 671 |
|
647 | 672 | @cli.command(name="upload-dataset")
|
648 |
| -@click.argument("local-dataset-path") |
| 673 | +@click.argument("dataset-name") |
| 674 | +@click.argument("bucket-name") |
| 675 | +@click.option("--region", default="us-east-1", help="The region to connect to.") |
| 676 | +def upload_dataset(dataset_name, bucket_name, region): |
| 677 | + impl_upload_dataset(dataset_name, bucket_name, region) |
| 678 | + |
| 679 | + |
| 680 | +@cli.command(name="update-training-progress") |
| 681 | +@click.argument("model-name") |
| 682 | +@click.argument("dataset-name") |
| 683 | +@click.argument("progress-text") |
649 | 684 | @click.argument("bucket-name")
|
650 | 685 | @click.option("--region", default="us-east-1", help="The region to connect to.")
|
651 |
| -def upload_dataset(local_dataset_path, bucket_name, region): |
652 |
| - impl_upload_dataset(local_dataset_path, bucket_name, region) |
| 686 | +def update_training_progress(model_name, dataset_name, progress_text, bucket_name, region): |
| 687 | + impl_update_training_progress(model_name, dataset_name, progress_text, bucket_name, region) |
0 commit comments