Skip to content
This repository was archived by the owner on Sep 26, 2020. It is now read-only.

Commit 3ac8c4a

Browse files
committed
Let the region be None to select from the env
1 parent 3d51693 commit 3ac8c4a

File tree

1 file changed

+42
-38
lines changed

1 file changed

+42
-38
lines changed

axon/client.py

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import ipify
55
import webbrowser
66
import os.path
7+
import axon.progress_reporter
78

89
all_perm = {
910
"FromPort": -1,
@@ -22,16 +23,23 @@
2223
}
2324

2425

26+
def make_client(name, region):
27+
if region is None:
28+
return boto3.client(name)
29+
else:
30+
return boto3.client(name, region_name=region)
31+
32+
2533
def ensure_log_group(group_name, region):
2634
"""
2735
Ensures that a log group is present. If there is a matching log group, nothing is created.
2836
If there is no matching log group, one is created.
2937
3038
:param group_name: The name of the log group.
31-
:param region: The region.
39+
:param region: The region, or `None` to pull the region from the environment.
3240
:return: Nothing.
3341
"""
34-
client = boto3.client("logs", region_name=region)
42+
client = make_client("logs", region)
3543
matching_log_groups = client.describe_log_groups(
3644
logGroupNamePrefix=group_name
3745
)
@@ -66,7 +74,7 @@ def ensure_ecs_gress(sg_id, region):
6674
egress permissions are revoked. The permissions that Axon needs are authorized.
6775
6876
:param sg_id: The SecurityGroup's GroupId.
69-
:param region: The region.
77+
:param region: The region, or `None` to pull the region from the environment.
7078
:return: Nothing.
7179
"""
7280
ec2 = boto3.resource('ec2', region_name=region)
@@ -92,7 +100,7 @@ def ensure_ec2_gress(sg_id, region):
92100
egress permissions are revoked. The permissions that Axon needs are authorized.
93101
94102
:param sg_id: The SecurityGroup's GroupId.
95-
:param region: The region.
103+
:param region: The region, or `None` to pull the region from the environment.
96104
:return: Nothing.
97105
"""
98106
ec2 = boto3.resource('ec2', region_name=region)
@@ -144,12 +152,11 @@ def get_single_security_group(client, sg_name, desc):
144152
def ensure_ecs_security_group(region):
145153
"""
146154
Ensures that the ECS SecurityGroup exists.
147-
148-
:param region: The region.
155+
:param region: The region, or `None` to pull the region from the environment.
149156
:return: The GroupId of the SecurityGroup.
150157
"""
151158
sg_name = "axon-ecs-autogenerated"
152-
client = boto3.client("ec2", region_name=region)
159+
client = make_client("ec2", region)
153160
sg_id = get_single_security_group(client, sg_name, "Axon autogenerated for ECS.")
154161
ensure_ecs_gress(sg_id, region)
155162
return sg_id
@@ -158,12 +165,11 @@ def ensure_ecs_security_group(region):
158165
def ensure_ec2_security_group(region):
159166
"""
160167
Ensures that the EC2 SecurityGroup exists.
161-
162-
:param region: The region.
168+
:param region: The region, or `None` to pull the region from the environment.
163169
:return: The GroupId of the SecurityGroup.
164170
"""
165171
sg_name = "axon-ec2-autogenerated"
166-
client = boto3.client("ec2", region_name=region)
172+
client = make_client("ec2", region)
167173
sg_id = get_single_security_group(client, sg_name, "Axon autogenerated for EC2.")
168174
ensure_ec2_gress(sg_id, region)
169175
return sg_id
@@ -172,12 +178,11 @@ def ensure_ec2_security_group(region):
172178
def select_subnet(region):
173179
"""
174180
Picks the first available subnet.
175-
176-
:param region: The region.
181+
:param region: The region, or `None` to pull the region from the environment.
177182
:return: The SubnetId.
178183
"""
179-
ec2 = boto3.client("ec2", region_name=region)
180-
return ec2.describe_subnets(Filters=[])["Subnets"][0]["SubnetId"]
184+
client = make_client("ec2", region)
185+
return client.describe_subnets(Filters=[])["Subnets"][0]["SubnetId"]
181186

182187

183188
def ensure_role(client, role_name):
@@ -207,12 +212,11 @@ def ensure_task_role(region):
207212
208213
TODO: Fix this.
209214
This method does not check that a matching role has the correct policies.
210-
211-
:param region: The region.
215+
:param region: The region, or `None` to pull the region from the environment.
212216
:return: The role Arn.
213217
"""
214218
role_name = "axon-ecs-autogenerated-task-role"
215-
client = boto3.client("iam", region_name=region)
219+
client = make_client("iam", region)
216220
role_arn = ensure_role(client, role_name)
217221
if role_arn is None:
218222
# Need to create the role
@@ -250,7 +254,7 @@ def ensure_task_role(region):
250254

251255
def ensure_ec2_role(region):
252256
role_name = "axon-ec2-role-manual"
253-
client = boto3.client("iam", region_name=region)
257+
client = make_client("iam", region)
254258
role_arn = ensure_role(client, role_name)
255259
if role_arn is None:
256260
# Need to create the role
@@ -301,7 +305,7 @@ def ensure_task(ecs_client, task_family, region, vcpu, memory):
301305
302306
:param ecs_client: The ECS client to use.
303307
:param task_family: The task family name.
304-
:param region: The region.
308+
:param region: The region, or `None` to pull the region from the environment.
305309
:param vcpu: The amount of cpu in vcpu units.
306310
:param memory: The amount of memory in MB.
307311
:return: The task definition's Arn.
@@ -357,10 +361,10 @@ def wait_for_task_to_start(task_arn, cluster, region):
357361
358362
:param task_arn: The Arn of the task to wait for.
359363
:param cluster: The simple name of the cluster the task is in.
360-
:param region: The region.
364+
:param region: The region, or `None` to pull the region from the environment.
361365
:return: Nothing.
362366
"""
363-
client = boto3.client("ecs", region_name=region)
367+
client = make_client("ecs", region)
364368
waiter = client.get_waiter("tasks_running")
365369
waiter.wait(cluster=cluster, tasks=[task_arn])
366370

@@ -371,9 +375,9 @@ def impl_ensure_configuration(cluster_name, task_family, region):
371375
372376
:param cluster_name: The simple name of the cluster to start the task in.
373377
:param task_family: The family of the task to start.
374-
:param region: The region.
378+
:param region: The region, or `None` to pull the region from the environment.
375379
"""
376-
client = boto3.client("ecs", region_name=region)
380+
client = make_client("ecs", region)
377381
ensure_cluster(client, cluster_name)
378382
ensure_task(client, task_family, region, 2048, 4096)
379383
ensure_ec2_security_group(region)
@@ -392,10 +396,10 @@ def impl_start_task(cluster_name, task_family, revision, region):
392396
:param cluster_name: The simple name of the cluster to start the task in.
393397
:param task_family: The family of the task to start.
394398
:param revision: A task definition revision number, or None to use the latest revision.
395-
:param region: The region.
399+
:param region: The region, or `None` to pull the region from the environment.
396400
:return: The started task's Arn.
397401
"""
398-
client = boto3.client("ecs", region_name=region)
402+
client = make_client("ecs", region)
399403

400404
impl_ensure_configuration(cluster_name, task_family, region)
401405

@@ -431,10 +435,10 @@ def impl_stop_task(cluster_name, task_arn, region):
431435
432436
:param cluster_name: The simple name of the cluster.
433437
:param task_arn: The Arn of the task.
434-
:param region: The region.
438+
:param region: The region, or `None` to pull the region from the environment.
435439
:return: Nothing.
436440
"""
437-
client = boto3.client("ecs", region_name=region)
441+
client = make_client("ecs", region)
438442
client.stop_task(cluster=cluster_name, task=task_arn)
439443

440444

@@ -444,10 +448,10 @@ def impl_get_task_ip(cluster_name, task_arn, region):
444448
445449
:param cluster_name: The simple name of the cluster.
446450
:param task_arn: The task's Arn.
447-
:param region: The region.
451+
:param region: The region, or `None` to pull the region from the environment.
448452
:return: The public IP of the task.
449453
"""
450-
client = boto3.client("ecs", region_name=region)
454+
client = make_client("ecs", region)
451455
task_arn = client.describe_tasks(
452456
cluster=cluster_name,
453457
tasks=[task_arn]
@@ -458,7 +462,7 @@ def impl_get_task_ip(cluster_name, task_arn, region):
458462
eni = next(
459463
x["value"] for x in interface_attachment["details"] if x["name"] == "networkInterfaceId")
460464

461-
ec2 = boto3.client("ec2", region_name=region)
465+
ec2 = make_client("ec2", region)
462466
nics = ec2.describe_network_interfaces(
463467
Filters=[
464468
{
@@ -476,9 +480,9 @@ def impl_upload_model_file(local_file_path, bucket_name, region):
476480
477481
:param local_file_path: The path to the model file on disk.
478482
:param bucket_name: The S3 bucket name.
479-
:param region: The region.
483+
:param region: The region, or `None` to pull the region from the environment.
480484
"""
481-
client = boto3.client("s3", region_name=region)
485+
client = make_client("s3", region)
482486
remote_path = "axon-uploaded-trained-models/" + os.path.basename(local_file_path)
483487
client.upload_file(local_file_path, bucket_name, remote_path)
484488
print("Uploaded to: {}\n".format(remote_path))
@@ -490,9 +494,9 @@ def impl_download_model_file(local_file_path, bucket_name, region):
490494
491495
:param local_file_path: The path to the model file on disk.
492496
:param bucket_name: The S3 bucket name.
493-
:param region: The region.
497+
:param region: The region, or `None` to pull the region from the environment.
494498
"""
495-
client = boto3.client("s3", region_name=region)
499+
client = make_client("s3", region)
496500
remote_path = "axon-uploaded-trained-models/" + os.path.basename(local_file_path)
497501
client.download_file(bucket_name, remote_path, local_file_path)
498502
print("Downloaded from: {}\n".format(remote_path))
@@ -504,9 +508,9 @@ def impl_download_training_script(local_script_path, bucket_name, region):
504508
505509
:param local_script_path: The path to the training script on disk.
506510
:param bucket_name: The S3 bucket name.
507-
:param region: The region.
511+
:param region: The region, or `None` to pull the region from the environment.
508512
"""
509-
client = boto3.client("s3", region_name=region)
513+
client = make_client("s3", region)
510514
remote_path = "axon-uploaded-training-scripts/" + os.path.basename(local_script_path)
511515
client.download_file(bucket_name, remote_path, local_script_path)
512516
print("Downloaded from: {}\n".format(remote_path))
@@ -518,9 +522,9 @@ def impl_download_dataset(local_dataset_path, bucket_name, region):
518522
519523
:param local_dataset_path: The path to the dataset on disk.
520524
:param bucket_name: The S3 bucket name.
521-
:param region: The region.
525+
:param region: The region, or `None` to pull the region from the environment.
522526
"""
523-
client = boto3.client("s3", region_name=region)
527+
client = make_client("s3", region)
524528
remote_path = "axon-uploaded-datasets/" + os.path.basename(local_dataset_path)
525529
client.download_file(bucket_name, remote_path, local_dataset_path)
526530
print("Downloaded from: {}\n".format(remote_path))

0 commit comments

Comments
 (0)