@@ -674,12 +674,11 @@ def impl_download_dataset(dataset_path, bucket_name, region):
674
674
print ("Downloaded from: {}\n " .format (key ))
675
675
676
676
677
- def impl_update_training_progress (model_name , dataset_name , progress_text , bucket_name , region ):
677
+ def impl_update_training_progress (job_id , progress_text , bucket_name , region ):
678
678
"""
679
679
Updates the training progress in S3 for the training job identified by its Job ID.
680
680
681
- :param model_name: The filename of the model.
682
- :param dataset_name: The filename of the dataset.
681
+ :param job_id: The unique Job ID.
683
682
:param progress_text: The text to write into the progress file.
684
683
:param bucket_name: The S3 bucket name.
685
684
:param region: The region, or `None` to pull the region from the environment.
@@ -689,46 +688,43 @@ def impl_update_training_progress(model_name, dataset_name, progress_text, bucke
689
688
with open (local_file , "w" ) as f :
690
689
f .write (progress_text )
691
690
client = make_client ("s3" , region )
692
- remote_path = create_progress_prefix (model_name , dataset_name ) + "/progress.txt"
691
+ remote_path = create_progress_prefix (job_id ) + "/progress.txt"
693
692
client .upload_file (path , bucket_name , remote_path )
694
693
print ("Updated progress in: {}\n " .format (remote_path ))
695
694
finally :
696
695
os .remove (path )
697
696
698
697
699
def impl_create_heartbeat(job_id, bucket_name, region):
    """
    Creates a heartbeat that Axon uses to check if the training script is running properly.

    The heartbeat is the object `heartbeat.txt` under the job's progress prefix (see
    `create_progress_prefix`). This function writes the body "1" to that object.

    :param job_id: The unique Job ID.
    :param bucket_name: The S3 bucket name.
    :param region: The region, or `None` to pull the region from the environment.
    """
    client = make_client("s3", region)
    remote_path = create_progress_prefix(job_id) + "/heartbeat.txt"
    # "1" is the value written here; impl_remove_heartbeat writes "0" to the same key.
    client.put_object(Body="1", Bucket=bucket_name, Key=remote_path)
    print("Created heartbeat file in: {}\n ".format(remote_path))
712
710
713
711
714
def impl_remove_heartbeat(job_id, bucket_name, region):
    """
    Clears the heartbeat that Axon uses to check if the training script is running properly.

    Despite the name, the heartbeat object is not deleted: `heartbeat.txt` under the job's
    progress prefix is overwritten with the body "0".

    :param job_id: The unique Job ID.
    :param bucket_name: The S3 bucket name.
    :param region: The region, or `None` to pull the region from the environment.
    """
    client = make_client("s3", region)
    remote_path = create_progress_prefix(job_id) + "/heartbeat.txt"
    # Overwrite rather than delete: the key keeps existing, only its content changes to "0".
    client.put_object(Body="0", Bucket=bucket_name, Key=remote_path)
    print("Removed heartbeat file in: {}\n ".format(remote_path))
727
724
728
725
729
def create_progress_prefix(job_id):
    """
    Builds the S3 key prefix under which a job's progress files are stored.

    :param job_id: The unique Job ID (a string; it is concatenated as-is).
    :return: The key prefix without a trailing slash, i.e. "axon-training-progress/<job_id>".
    """
    return "axon-training-progress/" + job_id
732
728
733
729
734
730
@click .group ()
@@ -918,53 +914,44 @@ def download_dataset(dataset_path, region):
918
914
919
915
920
916
@cli.command(name="update-training-progress")
@click.argument("job-id")
@click.argument("progress-text")
@click.option("--region", help="The region to connect to.",
              type=click.Choice(region_choices))
def update_training_progress(job_id, progress_text, region):
    """
    Updates the training progress. Meant to be used while a training script is running to provide
    progress updates to Axon.

    JOB_ID The unique Job ID.

    PROGRESS_TEXT The text to write to the progress file.
    """
    # The bucket name is resolved from the region via ensure_s3_bucket before delegating.
    impl_update_training_progress(job_id, progress_text, ensure_s3_bucket(region),
                                  region)
939
932
940
933
941
934
@cli.command(name="create-heartbeat")
@click.argument("job-id")
@click.option("--region", help="The region to connect to.",
              type=click.Choice(region_choices))
def create_heartbeat(job_id, region):
    """
    Creates a heartbeat that Axon uses to check if the training script is running properly.

    JOB_ID The unique Job ID.
    """
    # The bucket name is resolved from the region via ensure_s3_bucket before delegating.
    impl_create_heartbeat(job_id, ensure_s3_bucket(region), region)
955
945
956
946
957
947
@cli.command(name="remove-heartbeat")
@click.argument("job-id")
@click.option("--region", help="The region to connect to.",
              type=click.Choice(region_choices))
def remove_heartbeat(job_id, region):
    """
    Removes a heartbeat that Axon uses to check if the training script is running properly.

    JOB_ID The unique Job ID.
    """
    # NOTE: the implementation overwrites heartbeat.txt with "0"; it does not delete the object.
    # The bucket name is resolved from the region via ensure_s3_bucket before delegating.
    impl_remove_heartbeat(job_id, ensure_s3_bucket(region), region)
0 commit comments