Skip to content
This repository was archived by the owner on Sep 26, 2020. It is now read-only.

Commit 342aaca

Browse files
committed
Use job id instead of model name and dataset name
1 parent 15c1d96 commit 342aaca

File tree

1 file changed

+23
-36
lines changed

1 file changed

+23
-36
lines changed

axon/client.py

Lines changed: 23 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -674,12 +674,11 @@ def impl_download_dataset(dataset_path, bucket_name, region):
674674
print("Downloaded from: {}\n".format(key))
675675

676676

677-
def impl_update_training_progress(model_name, dataset_name, progress_text, bucket_name, region):
677+
def impl_update_training_progress(job_id, progress_text, bucket_name, region):
678678
"""
679679
Updates the training progress in S3 for a model specified by its name.
680680
681-
:param model_name: The filename of the model.
682-
:param dataset_name: The filename of the dataset.
681+
:param job_id: The unique Job ID.
683682
:param progress_text: The text to write into the progress file.
684683
:param bucket_name: The S3 bucket name.
685684
:param region: The region, or `None` to pull the region from the environment.
@@ -689,46 +688,43 @@ def impl_update_training_progress(model_name, dataset_name, progress_text, bucke
689688
with open(local_file, "w") as f:
690689
f.write(progress_text)
691690
client = make_client("s3", region)
692-
remote_path = create_progress_prefix(model_name, dataset_name) + "/progress.txt"
691+
remote_path = create_progress_prefix(job_id) + "/progress.txt"
693692
client.upload_file(path, bucket_name, remote_path)
694693
print("Updated progress in: {}\n".format(remote_path))
695694
finally:
696695
os.remove(path)
697696

698697

699-
def impl_create_heartbeat(model_name, dataset_name, bucket_name, region):
698+
def impl_create_heartbeat(job_id, bucket_name, region):
700699
"""
701700
Creates a heartbeat that Axon uses to check if the training script is running properly.
702701
703-
:param model_name: The filename of the model.
704-
:param dataset_name: The filename of the dataset.
702+
:param job_id: The unique Job ID.
705703
:param bucket_name: The S3 bucket name.
706704
:param region: The region, or `None` to pull the region from the environment.
707705
"""
708706
client = make_client("s3", region)
709-
remote_path = create_progress_prefix(model_name, dataset_name) + "/heartbeat.txt"
707+
remote_path = create_progress_prefix(job_id) + "/heartbeat.txt"
710708
client.put_object(Body="1", Bucket=bucket_name, Key=remote_path)
711709
print("Created heartbeat file in: {}\n".format(remote_path))
712710

713711

714-
def impl_remove_heartbeat(model_name, dataset_name, bucket_name, region):
712+
def impl_remove_heartbeat(job_id, bucket_name, region):
715713
"""
716714
Removes a heartbeat that Axon uses to check if the training script is running properly.
717715
718-
:param model_name: The filename of the model.
719-
:param dataset_name: The filename of the dataset.
716+
:param job_id: The unique Job ID.
720717
:param bucket_name: The S3 bucket name.
721718
:param region: The region, or `None` to pull the region from the environment.
722719
"""
723720
client = make_client("s3", region)
724-
remote_path = create_progress_prefix(model_name, dataset_name) + "/heartbeat.txt"
721+
remote_path = create_progress_prefix(job_id) + "/heartbeat.txt"
725722
client.put_object(Body="0", Bucket=bucket_name, Key=remote_path)
726723
print("Removed heartbeat file in: {}\n".format(remote_path))
727724

728725

729-
def create_progress_prefix(model_name, dataset_name):
730-
return "axon-training-progress/" + os.path.basename(model_name) + "/" + \
731-
os.path.basename(dataset_name)
726+
def create_progress_prefix(job_id):
727+
return "axon-training-progress/" + job_id
732728

733729

734730
@click.group()
@@ -918,53 +914,44 @@ def download_dataset(dataset_path, region):
918914

919915

920916
@cli.command(name="update-training-progress")
921-
@click.argument("model-name")
922-
@click.argument("dataset-name")
917+
@click.argument("job-id")
923918
@click.argument("progress-text")
924919
@click.option("--region", help="The region to connect to.",
925920
type=click.Choice(region_choices))
926-
def update_training_progress(model_name, dataset_name, progress_text, region):
921+
def update_training_progress(job_id, progress_text, region):
927922
"""
928923
Updates the training progress. Meant to be used while a training script is running to provide
929924
progress updates to Axon.
930925
931-
MODEL_NAME The filename of the model currently being trained.
932-
933-
DATASET_NAME The name of the dataset currently being trained on.
926+
JOB_ID The unique Job ID.
934927
935928
PROGRESS_TEXT The text to write to the progress file.
936929
"""
937-
impl_update_training_progress(model_name, dataset_name, progress_text, ensure_s3_bucket(region),
930+
impl_update_training_progress(job_id, progress_text, ensure_s3_bucket(region),
938931
region)
939932

940933

941934
@cli.command(name="create-heartbeat")
942-
@click.argument("model-name")
943-
@click.argument("dataset-name")
935+
@click.argument("job-id")
944936
@click.option("--region", help="The region to connect to.",
945937
type=click.Choice(region_choices))
946-
def create_heartbeat(model_name, dataset_name, region):
938+
def create_heartbeat(job_id, region):
947939
"""
948940
Creates a heartbeat that Axon uses to check if the training script is running properly.
949941
950-
MODEL_NAME The filename of the model currently being trained.
951-
952-
DATASET_NAME The name of the dataset currently being trained on.
942+
JOB_ID The unique Job ID.
953943
"""
954-
impl_create_heartbeat(model_name, dataset_name, ensure_s3_bucket(region), region)
944+
impl_create_heartbeat(job_id, ensure_s3_bucket(region), region)
955945

956946

957947
@cli.command(name="remove-heartbeat")
958-
@click.argument("model-name")
959-
@click.argument("dataset-name")
948+
@click.argument("job-id")
960949
@click.option("--region", help="The region to connect to.",
961950
type=click.Choice(region_choices))
962-
def remove_heartbeat(model_name, dataset_name, region):
951+
def remove_heartbeat(job_id, region):
963952
"""
964953
Removes a heartbeat that Axon uses to check if the training script is running properly.
965954
966-
MODEL_NAME The filename of the model currently being trained.
967-
968-
DATASET_NAME The name of the dataset currently being trained on.
955+
JOB_ID The unique Job ID.
969956
"""
970-
impl_remove_heartbeat(model_name, dataset_name, ensure_s3_bucket(region), region)
957+
impl_remove_heartbeat(job_id, ensure_s3_bucket(region), region)

0 commit comments

Comments
 (0)