__author__ = "Johannes Köster, Andreas Wilm, Jake VanCampen"
__copyright__ = "Copyright 2022, Johannes Köster"
__email__ = "johannes.koester@uni-due.de"
__license__ = "MIT"
import os
from collections import namedtuple
from urllib.parse import urlparse
import datetime
import uuid
import io
import shutil
import tarfile
import tempfile
import sys
import re
import msrest.authentication as msa
from pprint import pformat
from snakemake.executors import ClusterExecutor, sleep
from snakemake.exceptions import WorkflowError
from snakemake.logging import logger
from snakemake.common import get_container_image, get_file_hash, async_lock
from snakemake.resources import DefaultResources
AzBatchJob = namedtuple("AzBatchJob", "job jobid task_id callback error_callback")
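# Example: one AzBatchJob is tracked per submitted task, e.g.
#
#   AzBatchJob(job, "snakejob-<timestamp>", "<rulename>-<uuid>", callback, error_callback)
#
# where `job` is the snakemake job object and the callbacks fire on task
# success or failure (see run() and _wait_for_jobs() below).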
class AzBatchConfig:
def __init__(self, batch_account_url: str):
# configure defaults
self.batch_account_url = batch_account_url
# parse batch account name
result = urlparse(self.batch_account_url)
self.batch_account_name = str.split(result.hostname, ".")[0]
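        # e.g. a batch account URL like "https://mybatchacct.westeurope.batch.azure.com"
        # (hypothetical account) yields the account name "mybatchacct"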
self.batch_account_key = self.set_or_default("AZ_BATCH_ACCOUNT_KEY", None)
# optional subnet config
self.batch_pool_subnet_id = self.set_or_default("BATCH_POOL_SUBNET_ID", None)
# managed identity resource id configuration
self.managed_identity_resource_id = self.set_or_default(
"BATCH_MANAGED_IDENTITY_RESOURCE_ID", None
)
# parse subscription and resource id
if self.managed_identity_resource_id is not None:
self.subscription_id = self.managed_identity_resource_id.split("/")[2]
self.resource_group = self.managed_identity_resource_id.split("/")[4]
self.managed_identity_client_id = self.set_or_default(
"BATCH_MANAGED_IDENTITY_CLIENT_ID", None
)
if self.batch_pool_subnet_id is not None:
if (
self.managed_identity_client_id is None
or self.managed_identity_resource_id is None
):
sys.exit(
"Error: BATCH_MANAGED_IDENTITY_RESOURCE_ID, BATCH_MANAGED_IDENTITY_CLIENT_ID must be set when deploying batch nodes into a private subnet!"
)
# parse account details necessary for batch client authentication steps
if self.batch_pool_subnet_id.split("/")[2] != self.subscription_id:
raise WorkflowError(
"Error: managed identity must be in the same subscription as the batch pool subnet."
)
if self.batch_pool_subnet_id.split("/")[4] != self.resource_group:
raise WorkflowError(
"Error: managed identity must be in the same resource group as the batch pool subnet."
)
# sas url to a batch node start task bash script
self.batch_node_start_task_sasurl = os.getenv("BATCH_NODE_START_TASK_SAS_URL")
# options configured with env vars or default
self.batch_pool_image_publisher = self.set_or_default(
"BATCH_POOL_IMAGE_PUBLISHER", "microsoft-azure-batch"
)
self.batch_pool_image_offer = self.set_or_default(
"BATCH_POOL_IMAGE_OFFER", "ubuntu-server-container"
)
self.batch_pool_image_sku = self.set_or_default(
"BATCH_POOL_IMAGE_SKU", "20-04-lts"
)
self.batch_pool_vm_container_image = self.set_or_default(
"BATCH_POOL_VM_CONTAINER_IMAGE", "ubuntu"
)
self.batch_pool_vm_node_agent_sku_id = self.set_or_default(
"BATCH_POOL_VM_NODE_AGENT_SKU_ID", "batch.node.ubuntu 20.04"
)
self.batch_pool_vm_size = self.set_or_default(
"BATCH_POOL_VM_SIZE", "Standard_D2_v3"
)
# dedicated pool node count
self.batch_pool_node_count = self.set_or_default("BATCH_POOL_NODE_COUNT", 1)
# default tasks per node
# see https://learn.microsoft.com/en-us/azure/batch/batch-parallel-node-tasks
self.batch_tasks_per_node = self.set_or_default("BATCH_TASKS_PER_NODE", 1)
# possible values "spread" or "pack"
# see https://learn.microsoft.com/en-us/azure/batch/batch-parallel-node-tasks
self.batch_node_fill_type = self.set_or_default(
"BATCH_NODE_FILL_TYPE", "spread"
)
# enables simplified batch node communication if set
# see: https://learn.microsoft.com/en-us/azure/batch/simplified-compute-node-communication
self.batch_node_communication_mode = self.set_or_default(
"BATCH_NODE_COMMUNICATION_SIMPLIFIED", None
)
self.resource_file_prefix = self.set_or_default(
"BATCH_POOL_RESOURCE_FILE_PREFIX", "resource-files"
)
self.container_registry_url = self.set_or_default(
"BATCH_CONTAINER_REGISTRY_URL", None
)
self.container_registry_user = self.set_or_default(
"BATCH_CONTAINER_REGISTRY_USER", None
)
self.container_registry_pass = self.set_or_default(
"BATCH_CONTAINER_REGISTRY_PASS", None
)
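    # A minimal sketch of driving this configuration from the shell before
    # invoking snakemake; all values below are hypothetical:
    #
    #   export AZ_BATCH_ACCOUNT_KEY="<key>"
    #   export BATCH_POOL_VM_SIZE="Standard_D4_v3"
    #   export BATCH_POOL_NODE_COUNT="2"
    #   export BATCH_CONTAINER_REGISTRY_URL="myregistry.azurecr.io"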
@staticmethod
def set_or_default(evar: str, default: str | None):
gotvar = os.getenv(evar)
if gotvar is not None:
return gotvar
else:
return default
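    # Example: with BATCH_POOL_VM_SIZE unset,
    #   set_or_default("BATCH_POOL_VM_SIZE", "Standard_D2_v3") -> "Standard_D2_v3"
    # while an exported value takes precedence over the default.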
# This credential helper is required to authenticate Batch with managed
# identity credentials, because not all Azure SDKs support the azure.identity
# credentials yet, and Batch is one of them.
# ref: https://gist.github.com/lmazuel/cc683d82ea1d7b40208de7c9fc8de59d
class AzureIdentityCredentialAdapter(msa.BasicTokenAuthentication):
def __init__(
self,
credential=None,
resource_id="https://management.azure.com/.default",
**kwargs,
):
"""Adapt any azure-identity credential to work with SDK that needs azure.common.credentials or msrestazure.
Default resource is ARM (syntax of endpoint v2)
:param credential: Any azure-identity credential (DefaultAzureCredential by default)
:param str resource_id: The scope to use to get the token (default ARM)
"""
try:
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
from azure.identity import DefaultAzureCredential
except ImportError:
raise WorkflowError(
"The Python 3 packages 'azure-core' and 'azure-identity' are required"
)
super(AzureIdentityCredentialAdapter, self).__init__(None)
if credential is None:
credential = DefaultAzureCredential()
self._policy = BearerTokenCredentialPolicy(credential, resource_id, **kwargs)
def _make_request(self):
try:
from azure.core.pipeline import PipelineRequest, PipelineContext
from azure.core.pipeline.transport import HttpRequest
except ImportError:
raise WorkflowError("The Python 3 package azure-core is required")
return PipelineRequest(
HttpRequest("AzureIdentityCredentialAdapter", "https://fakeurl"),
PipelineContext(None),
)
def set_token(self):
"""Ask the azure-core BearerTokenCredentialPolicy policy to get a token.
Using the policy gives us for free the caching system of azure-core.
We could make this code simpler by using private method, but by definition
I can't assure they will be there forever, so mocking a fake call to the policy
to extract the token, using 100% public API."""
request = self._make_request()
self._policy.on_request(request)
# Read Authorization, and get the second part after Bearer
token = request.http_request.headers["Authorization"].split(" ", 1)[1]
self.token = {"access_token": token}
def signed_session(self, session=None):
self.set_token()
return super(AzureIdentityCredentialAdapter, self).signed_session(session)
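# A minimal sketch of how the adapter is used further below (mirrors the
# credential wrapping in AzBatchExecutor.__init__; account URL hypothetical):
#
#   from azure.identity import DefaultAzureCredential
#   from azure.batch import BatchServiceClient
#
#   creds = AzureIdentityCredentialAdapter(
#       credential=DefaultAzureCredential(),
#       resource_id="https://batch.core.windows.net/",
#   )
#   client = BatchServiceClient(
#       creds, batch_url="https://mybatchacct.westeurope.batch.azure.com"
#   )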
class AzBatchExecutor(ClusterExecutor):
"Azure Batch Executor"
def __init__(
self,
workflow,
dag,
cores,
jobname="snakejob.{name}.{jobid}.sh",
printreason=False,
quiet=False,
printshellcmds=False,
container_image=None,
regions=None,
location=None,
cache=False,
local_input=None,
restart_times=None,
max_status_checks_per_second=1,
az_batch_account_url=None,
az_batch_enable_autoscale=False,
):
super().__init__(
workflow,
dag,
None,
jobname=jobname,
printreason=printreason,
quiet=quiet,
printshellcmds=printshellcmds,
restart_times=restart_times,
assume_shared_fs=False,
            max_status_checks_per_second=max_status_checks_per_second,
)
try:
from azure.batch import BatchServiceClient
from azure.mgmt.batch import BatchManagementClient
from azure.batch.batch_auth import SharedKeyCredentials
from azure.identity import DefaultAzureCredential
from snakemake.remote.AzBlob import AzureStorageHelper
except ImportError:
raise WorkflowError(
"The Python 3 packages 'azure-batch', 'azure-mgmt-batch', and 'azure-identity'"
" must be installed to use Azure Batch"
)
AZURE_BATCH_RESOURCE_ENDPOINT = "https://batch.core.windows.net/"
        # Validate that the AzBlob credential is a SAS token,
        # a requirement specific to the Azure Batch executor.
self.validate_az_blob_credential_is_sas()
self.azblob_helper = AzureStorageHelper()
# get container from remote prefix
self.prefix_container = str.split(workflow.default_remote_prefix, "/")[0]
        # set up batch configuration; sets self.batch_config
self.batch_config = AzBatchConfig(az_batch_account_url)
logger.debug(f"AzBatchConfig: {self.mask_batch_config_as_string()}")
self.workflow = workflow
# handle case on OSX with /var/ symlinked to /private/var/ causing
# issues with workdir not matching other workflow file dirs
dirname = os.path.dirname(self.workflow.persistence.path)
osxprefix = "/private"
if osxprefix in dirname:
dirname = dirname.removeprefix(osxprefix)
self.workdir = dirname
self.workflow.default_resources = DefaultResources(mode="bare")
# Relative path for running on instance
self._set_snakefile()
# Prepare workflow sources for build package
self._set_workflow_sources()
        # Pool ids may only contain alphanumeric characters, dashes, and underscores
ts = datetime.datetime.now().strftime("%Y-%m%dT%H-%M-%S")
self.pool_id = "snakepool-{:s}".format(ts)
self.job_id = "snakejob-{:s}".format(ts)
self.envvars = list(self.workflow.envvars) or []
self.container_image = container_image or get_container_image()
# enable autoscale flag
self.az_batch_enable_autoscale = az_batch_enable_autoscale
# Package workflow sources files and upload to storage
self._build_packages = set()
targz = self._generate_build_source_package()
        # removed from storage on shutdown, after job failure/success
self.resource_file = self._upload_build_source_package(
targz, resource_prefix=self.batch_config.resource_file_prefix
)
# authenticate batch client from SharedKeyCredentials
if (
self.batch_config.batch_account_key is not None
and self.batch_config.managed_identity_client_id is None
):
logger.debug("Using batch account key for authentication...")
creds = SharedKeyCredentials(
self.batch_config.batch_account_name,
self.batch_config.batch_account_key,
)
        # else authenticate with managed identity client id
elif self.batch_config.managed_identity_client_id is not None:
logger.debug("Using managed identity batch authentication...")
creds = DefaultAzureCredential(
managed_identity_client_id=self.batch_config.managed_identity_client_id
)
creds = AzureIdentityCredentialAdapter(
credential=creds, resource_id=AZURE_BATCH_RESOURCE_ENDPOINT
)
self.batch_client = BatchServiceClient(
creds, batch_url=self.batch_config.batch_account_url
)
if self.batch_config.managed_identity_resource_id is not None:
self.batch_mgmt_client = BatchManagementClient(
credential=DefaultAzureCredential(
managed_identity_client_id=self.batch_config.managed_identity_client_id
),
subscription_id=self.batch_config.subscription_id,
)
try:
self.create_batch_pool()
except WorkflowError:
logger.debug("Error: Failed to create batch pool, shutting down.")
self.shutdown()
try:
self.create_batch_job()
except WorkflowError:
logger.debug("Error: Failed to create batch job, shutting down.")
self.shutdown()
def shutdown(self):
# perform additional steps on shutdown
# if necessary (jobs were cancelled already)
logger.debug("Deleting AzBatch job")
self.batch_client.job.delete(self.job_id)
logger.debug("Deleting AzBatch pool")
self.batch_client.pool.delete(self.pool_id)
logger.debug("Deleting workflow sources from blob")
self.azblob_helper.delete_from_container(
self.prefix_container, self.resource_file.file_path
)
super().shutdown()
def cancel(self):
for task in self.batch_client.task.list(self.job_id):
            # strictly not needed, as deleting the job also deletes its tasks
self.batch_client.task.terminate(self.job_id, task.id)
self.shutdown()
    # mask_dict_vals masks the values of sensitive keys in a dictionary
    # so that dicts containing sensitive information can be logged safely
@staticmethod
def mask_dict_vals(mdict: dict, keys: list):
ret_dict = mdict.copy()
for k in keys:
if k in ret_dict.keys() and ret_dict[k] is not None:
ret_dict[k] = 10 * "*"
return ret_dict
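    # Example:
    #   mask_dict_vals({"batch_account_key": "secret", "pool": "p"}, ["batch_account_key"])
    #   -> {"batch_account_key": "**********", "pool": "p"}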
    # mask_sas_urls masks URL values that may contain SAS token
    # information so that they are not printed to the logs
def mask_sas_urls(self, attrs: dict):
attrs_new = attrs.copy()
        sas_pattern = r"\?s[vp]=.+(\'|\"|$)"
mask = 10 * "*"
for k, value in attrs.items():
if value is not None and re.search(sas_pattern, str(value)):
attrs_new[k] = re.sub(sas_pattern, mask, value)
return attrs_new
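    # Example (hypothetical URL): a value such as
    #   "https://acct.blob.core.windows.net/c/f.txt?sv=2021-06-08&sig=abc"
    # comes back with everything from "?sv=" onward replaced by "**********".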
def mask_batch_config_as_string(self) -> str:
masked_keys = self.mask_dict_vals(
self.batch_config.__dict__,
[
"batch_account_key",
"managed_identity_client_id",
],
)
masked_urls = self.mask_sas_urls(masked_keys)
return pformat(masked_urls, indent=2)
def run(self, job, callback=None, submit_callback=None, error_callback=None):
import azure.batch._batch_service_client as batch
import azure.batch.models as batchmodels
super()._run(job)
envsettings = []
for key in self.envvars:
try:
envsettings.append(
batchmodels.EnvironmentSetting(name=key, value=os.environ[key])
)
except KeyError:
continue
exec_job = self.format_job_exec(job)
exec_job = f"/bin/sh -c 'tar xzf {self.resource_file.file_path} && {exec_job}'"
# A string that uniquely identifies the Task within the Job.
task_uuid = str(uuid.uuid1())
task_id = f"{job.rule.name}-{task_uuid}"
# This is the admin user who runs the command inside the container.
user = batchmodels.AutoUserSpecification(
scope=batchmodels.AutoUserScope.pool,
elevation_level=batchmodels.ElevationLevel.admin,
)
# This is the docker image we want to run
task_container_settings = batchmodels.TaskContainerSettings(
image_name=self.container_image, container_run_options="--rm"
)
# https://docs.microsoft.com/en-us/python/api/azure-batch/azure.batch.models.taskaddparameter?view=azure-python
# all directories recursively below the AZ_BATCH_NODE_ROOT_DIR (the root of Azure Batch directories on the node)
# are mapped into the container, all Task environment variables are mapped into the container,
# and the Task command line is executed in the container
task = batch.models.TaskAddParameter(
id=task_id,
command_line=exec_job,
container_settings=task_container_settings,
resource_files=[self.resource_file], # Snakefile, envs, yml files etc.
user_identity=batchmodels.UserIdentity(auto_user=user),
environment_settings=envsettings,
)
        # submit the task and register the job as active, using our own namedtuple
self.batch_client.task.add(self.job_id, task)
self.active_jobs.append(
AzBatchJob(job, self.job_id, task_id, callback, error_callback)
)
logger.debug(f"Added AzBatch task {task_id}")
logger.debug(
f"Added AzBatch task {pformat(self.mask_sas_urls(task.__dict__), indent=2)}"
)
# from https://github.com/Azure-Samples/batch-python-quickstart/blob/master/src/python_quickstart_client.py
@staticmethod
def _read_stream_as_string(stream, encoding):
"""Read stream as string
:param stream: input stream generator
:param str encoding: The encoding of the file. The default is utf-8.
:return: The file content.
:rtype: str
"""
output = io.BytesIO()
try:
for data in stream:
output.write(data)
if encoding is None:
encoding = "utf-8"
return output.getvalue().decode(encoding)
finally:
output.close()
    # adapted from https://github.com/Azure-Samples/batch-python-quickstart/blob/master/src/python_quickstart_client.py
def _get_task_output(self, job_id, task_id, stdout_or_stderr, encoding=None):
assert stdout_or_stderr in ["stdout", "stderr"]
fname = stdout_or_stderr + ".txt"
try:
stream = self.batch_client.file.get_from_task(job_id, task_id, fname)
content = self._read_stream_as_string(stream, encoding)
except Exception:
content = ""
return content
async def _wait_for_jobs(self):
import azure.batch.models as batchmodels
while True:
# always use self.lock to avoid race conditions
async with async_lock(self.lock):
if not self.wait:
return
active_jobs = self.active_jobs
self.active_jobs = list()
still_running = list()
# Loop through active jobs and act on status
for batch_job in active_jobs:
async with self.status_rate_limiter:
logger.debug(f"Monitoring {len(active_jobs)} active AzBatch tasks")
task = self.batch_client.task.get(self.job_id, batch_job.task_id)
if task.state == batchmodels.TaskState.completed:
dt = (
task.execution_info.end_time
- task.execution_info.start_time
)
rc = task.execution_info.exit_code
rt = task.execution_info.retry_count
stderr = self._get_task_output(
self.job_id, batch_job.task_id, "stderr"
)
stdout = self._get_task_output(
self.job_id, batch_job.task_id, "stdout"
)
logger.debug(
"task {} completed: result={} exit_code={}\n".format(
batch_job.task_id, task.execution_info.result, rc
)
)
logger.debug(
"task {} completed: run_time={}, retry_count={}\n".format(
batch_job.task_id, str(dt), rt
)
)
def print_output():
logger.debug(
"task {}: stderr='{}'\n".format(
batch_job.task_id, stderr
)
)
logger.debug(
"task {}: stdout='{}'\n".format(
batch_job.task_id, stdout
)
)
if (
task.execution_info.result
== batchmodels.TaskExecutionResult.failure
):
logger.error(
f"Azure task failed: code={str(task.execution_info.failure_info.code)}, message={str(task.execution_info.failure_info.message)}"
)
for d in task.execution_info.failure_info.details:
logger.error(f"Error Details: {str(d)}")
print_output()
batch_job.error_callback(batch_job.job)
elif (
task.execution_info.result
== batchmodels.TaskExecutionResult.success
):
batch_job.callback(batch_job.job)
else:
logger.error(
"Unknown Azure task execution result: {}".format(
task.execution_info.result
)
)
print_output()
batch_job.error_callback(batch_job.job)
# The operation is still running
else:
logger.debug(
f"task {batch_job.task_id}: creation_time={task.creation_time} state={task.state} node_info={task.node_info}\n"
)
still_running.append(batch_job)
                    # fail if the start task fails on a node or a node becomes unusable,
                    # and surface the start task's stdout and stderr in the logs
node_list = self.batch_client.compute_node.list(self.pool_id)
for n in node_list:
# error on unusable node (this occurs if your container image fails to pull)
if n.state == "unusable":
if n.errors is not None:
for e in n.errors:
logger.error(
f"Azure task error: {e.message}, {e.error_details[0].__dict__}"
)
logger.error(
"A node entered an unusable state, quitting."
)
return
if n.start_task_info is not None and (
n.start_task_info.result
== batchmodels.TaskExecutionResult.failure
):
try:
stderr_file = (
self.batch_client.file.get_from_compute_node(
self.pool_id, n.id, "/startup/stderr.txt"
)
)
stderr_stream = self._read_stream_as_string(
stderr_file, "utf-8"
)
except Exception:
stderr_stream = ""
try:
stdout_file = (
self.batch_client.file.get_from_compute_node(
self.pool_id, n.id, "/startup/stdout.txt"
)
)
stdout_stream = self._read_stream_as_string(
stdout_file, "utf-8"
)
except Exception:
stdout_stream = ""
logger.error(
"Azure start task execution failed on node: {}.\nSTART_TASK_STDERR:{}\nSTART_TASK_STDOUT: {}".format(
n.start_task_info.failure_info.message,
stdout_stream,
stderr_stream,
)
)
return
async with async_lock(self.lock):
self.active_jobs.extend(still_running)
await sleep()
def create_batch_pool(self):
"""Creates a pool of compute nodes"""
import azure.batch._batch_service_client as bsc
import azure.batch.models as batchmodels
import azure.mgmt.batch.models as mgmtbatchmodels
image_ref = bsc.models.ImageReference(
publisher=self.batch_config.batch_pool_image_publisher,
offer=self.batch_config.batch_pool_image_offer,
sku=self.batch_config.batch_pool_image_sku,
version="latest",
)
# optional subnet network configuration
        # requires AAD batch auth instead of batch key auth
network_config = None
if self.batch_config.batch_pool_subnet_id is not None:
network_config = batchmodels.NetworkConfiguration(
subnet_id=self.batch_config.batch_pool_subnet_id
)
# configure a container registry
# Specify container configuration, fetching an image
# https://docs.microsoft.com/en-us/azure/batch/batch-docker-container-workloads#prefetch-images-for-container-configuration
container_config = batchmodels.ContainerConfiguration(
type="dockerCompatible", container_image_names=[self.container_image]
)
user = None
passw = None
identity_ref = None
registry_conf = None
if self.batch_config.container_registry_url is not None:
if (
self.batch_config.container_registry_user is not None
and self.batch_config.container_registry_pass is not None
):
user = self.batch_config.container_registry_user
passw = self.batch_config.container_registry_pass
elif self.batch_config.managed_identity_resource_id is not None:
identity_ref = batchmodels.ComputeNodeIdentityReference(
resource_id=self.batch_config.managed_identity_resource_id
)
else:
raise WorkflowError(
"No container registry authentication scheme set. Please set the BATCH_CONTAINER_REGISTRY_USER and BATCH_CONTAINER_REGISTRY_PASS or set MANAGED_IDENTITY_CLIENT_ID and MANAGED_IDENTITY_RESOURCE_ID."
)
registry_conf = [
batchmodels.ContainerRegistry(
registry_server=self.batch_config.container_registry_url,
identity_reference=identity_ref,
user_name=str(user),
password=str(passw),
)
]
            # re-specify the container configuration with registry credentials attached
            # https://docs.microsoft.com/en-us/azure/batch/batch-docker-container-workloads#prefetch-images-for-container-configuration
container_config = batchmodels.ContainerConfiguration(
type="dockerCompatible",
container_image_names=[self.container_image],
container_registries=registry_conf,
)
# default to no start task
start_task = None
        # if configured, use a start task bash script fetched from the SAS URL
if self.batch_config.batch_node_start_task_sasurl is not None:
_SIMPLE_TASK_NAME = "start_task.sh"
start_task_admin = batchmodels.UserIdentity(
auto_user=batchmodels.AutoUserSpecification(
elevation_level=batchmodels.ElevationLevel.admin,
scope=batchmodels.AutoUserScope.pool,
)
)
start_task = batchmodels.StartTask(
command_line=f"bash {_SIMPLE_TASK_NAME}",
resource_files=[
batchmodels.ResourceFile(
file_path=_SIMPLE_TASK_NAME,
http_url=self.batch_config.batch_node_start_task_sasurl,
)
],
user_identity=start_task_admin,
)
# autoscale requires the initial dedicated node count to be zero
if self.az_batch_enable_autoscale:
self.batch_config.batch_pool_node_count = 0
node_communication_strategy = None
if self.batch_config.batch_node_communication_mode is not None:
node_communication_strategy = batchmodels.NodeCommunicationMode.simplified
new_pool = batchmodels.PoolAddParameter(
id=self.pool_id,
virtual_machine_configuration=batchmodels.VirtualMachineConfiguration(
image_reference=image_ref,
container_configuration=container_config,
node_agent_sku_id=self.batch_config.batch_pool_vm_node_agent_sku_id,
),
network_configuration=network_config,
vm_size=self.batch_config.batch_pool_vm_size,
target_dedicated_nodes=self.batch_config.batch_pool_node_count,
target_node_communication_mode=node_communication_strategy,
target_low_priority_nodes=0,
start_task=start_task,
task_slots_per_node=self.batch_config.batch_tasks_per_node,
task_scheduling_policy=batchmodels.TaskSchedulingPolicy(
node_fill_type=self.batch_config.batch_node_fill_type
),
)
# create pool if not exists
try:
logger.debug(f"Creating pool: {self.pool_id}")
self.batch_client.pool.add(new_pool)
if self.az_batch_enable_autoscale:
# define the autoscale formula
formula = """$samples = $PendingTasks.GetSamplePercent(TimeInterval_Minute * 5);
$tasks = $samples < 70 ? max(0,$PendingTasks.GetSample(1)) : max( $PendingTasks.GetSample(1), avg($PendingTasks.GetSample(TimeInterval_Minute * 5)));
$targetVMs = $tasks > 0? $tasks:max(0, $TargetDedicatedNodes/2);
$TargetDedicatedNodes = max(0, min($targetVMs, 10));
$NodeDeallocationOption = taskcompletion;"""
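                # In words: sample $PendingTasks over the last 5 minutes; with
                # fewer than 70% of samples available, use the latest sample,
                # otherwise the max of the latest sample and the 5-minute
                # average. Target that many nodes (halving toward zero when no
                # tasks are pending), capped at 10 dedicated nodes, and only
                # deallocate nodes once their running tasks complete.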
# Enable autoscale; specify the formula
self.batch_client.pool.enable_auto_scale(
self.pool_id,
auto_scale_formula=formula,
# the minimum allowed autoscale interval is 5 minutes
auto_scale_evaluation_interval=datetime.timedelta(minutes=5),
pool_enable_auto_scale_options=None,
custom_headers=None,
raw=False,
)
            # update the pool with the managed identity; this enables batch nodes to act as the managed identity
if self.batch_config.managed_identity_resource_id is not None:
mid = mgmtbatchmodels.BatchPoolIdentity(
type=mgmtbatchmodels.PoolIdentityType.user_assigned,
user_assigned_identities={
self.batch_config.managed_identity_resource_id: mgmtbatchmodels.UserAssignedIdentities()
},
)
params = mgmtbatchmodels.Pool(identity=mid)
self.batch_mgmt_client.pool.update(
resource_group_name=self.batch_config.resource_group,
account_name=self.batch_config.batch_account_name,
pool_name=self.pool_id,
parameters=params,
)
except batchmodels.BatchErrorException as err:
if err.error.code != "PoolExists":
raise WorkflowError(
f"Error: Failed to create pool: {err.error.message}"
)
else:
logger.debug(f"Pool {self.pool_id} exists.")
def create_batch_job(self):
"""Creates a job with the specified ID, associated with the specified pool"""
import azure.batch._batch_service_client as bsc
logger.debug(f"Creating job {self.job_id}")
self.batch_client.job.add(
bsc.models.JobAddParameter(
id=self.job_id,
constraints=bsc.models.JobConstraints(max_task_retry_count=0),
pool_info=bsc.models.PoolInformation(pool_id=self.pool_id),
)
)
@staticmethod
def validate_az_blob_credential_is_sas():
"""
Validates that the AZ_BLOB_CREDENTIAL is a valid storage account SAS
token, required when using --az-batch with AzBlob remote.
"""
cred = os.environ.get("AZ_BLOB_CREDENTIAL")
if cred is not None:
# regex pattern for Storage Account SAS Token
rgx = r"\?sv=.*&ss=.*&srt=.*&sp=.*&se=.*&st=.*&spr=.*&sig=.*"
if re.compile(rgx).match(cred) is None:
raise WorkflowError(
"AZ_BLOB_CREDENTIAL is not a valid storage account SAS token."
)
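    # An account SAS token matching the pattern above looks like (all values
    # hypothetical):
    #   ?sv=2021-06-08&ss=b&srt=sco&sp=rwdlacx&se=2023-01-01T00:00:00Z&st=2022-01-01T00:00:00Z&spr=https&sig=<signature>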
def _set_snakefile(self):
"""The snakefile must be a relative path, which cannot be reliably
derived from the self.workflow.snakefile as we might have moved
execution into a temporary directory, and the initial Snakefile
was somewhere else on the system.
"""
self.snakefile = os.path.basename(self.workflow.main_snakefile)
# from google_lifesciences.py
def _set_workflow_sources(self):
"""We only add files from the working directory that are config related
(e.g., the Snakefile or a config.yml equivalent), or checked into git.
"""
self.workflow_sources = []
for wfs in self.dag.get_sources():
if os.path.isdir(wfs):
for dirpath, dirnames, filenames in os.walk(wfs):
self.workflow_sources.extend(
[
self.workflow.check_source_sizes(os.path.join(dirpath, f))
for f in filenames
]
)
else:
self.workflow_sources.append(
self.workflow.check_source_sizes(os.path.abspath(wfs))
)
# from google_lifesciences.py
def _generate_build_source_package(self):
"""in order for the instance to access the working directory in storage,
we need to upload it. This file is cleaned up at the end of the run.
We do this, and then obtain from the instance and extract.
"""
# Workflow sources for cloud executor must all be under same workdir root
for filename in self.workflow_sources:
if self.workdir not in filename:
raise WorkflowError(
"All source files must be present in the working directory, "
"{workdir} to be uploaded to a build package that respects "
"relative paths, but {filename} was found outside of this "
"directory. Please set your working directory accordingly, "
"and the path of your Snakefile to be relative to it.".format(
workdir=self.workdir, filename=filename
)
)
# We will generate a tar.gz package, renamed by hash
tmpname = next(tempfile._get_candidate_names())
targz = os.path.join(tempfile.gettempdir(), f"snakemake-{tmpname}.tar.gz")
tar = tarfile.open(targz, "w:gz")
# Add all workflow_sources files
for filename in self.workflow_sources:
arcname = filename.replace(self.workdir + os.path.sep, "")
tar.add(filename, arcname=arcname)
logger.debug(
f"Created {targz} with the following contents: {self.workflow_sources}"
)
tar.close()
# Rename based on hash, in case user wants to save cache
sha256 = get_file_hash(targz)
hash_tar = os.path.join(
self.workflow.persistence.aux_path, f"workdir-{sha256}.tar.gz"
)
# Only copy if we don't have it yet, clean up if we do
if not os.path.exists(hash_tar):
shutil.move(targz, hash_tar)
else:
os.remove(targz)
# We will clean these all up at shutdown
self._build_packages.add(hash_tar)
return hash_tar
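    # Example (hypothetical paths): with workdir "/home/user/project", the
    # source "/home/user/project/rules/align.smk" is archived under the
    # relative arcname "rules/align.smk", and the archive is renamed to
    # "workdir-<sha256>.tar.gz" under the persistence aux path.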
def _upload_build_source_package(self, targz, resource_prefix=""):
"""given a .tar.gz created for a workflow, upload it to the blob
storage account, only if the blob doesn't already exist.
"""
import azure.batch.models as batchmodels
blob_name = os.path.join(resource_prefix, os.path.basename(targz))
# upload blob to storage using storage helper
bc = self.azblob_helper.upload_to_azure_storage(
self.prefix_container, targz, blob_name=blob_name
)
# return resource file
return batchmodels.ResourceFile(http_url=bc.url, file_path=blob_name)
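# A minimal sketch of invoking this executor from the command line (flag names
# as used by snakemake's Azure Batch support; account URLs, container image,
# and remote prefix are hypothetical):
#
#   export AZ_BLOB_ACCOUNT_URL="https://<acct>.blob.core.windows.net/"
#   export AZ_BLOB_CREDENTIAL="?sv=...&ss=...&srt=...&sp=...&se=...&st=...&spr=...&sig=..."
#   snakemake --jobs 3 \
#       --default-remote-provider AzBlob \
#       --default-remote-prefix snakemake-container/results \
#       --envvars AZ_BLOB_ACCOUNT_URL AZ_BLOB_CREDENTIAL \
#       --container-image snakemake/snakemake \
#       --az-batch \
#       --az-batch-account-url "https://<acct>.<region>.batch.azure.com"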