diff --git a/snakemake/__init__.py b/snakemake/__init__.py index 167a0c396..4ba15467a 100644 --- a/snakemake/__init__.py +++ b/snakemake/__init__.py @@ -181,6 +181,7 @@ def snakemake( kubernetes=None, container_image=None, k8s_cpu_scalar=1.0, + k8s_service_account_name=None, flux=False, tibanna=False, tibanna_sfn=None, @@ -325,6 +326,7 @@ def snakemake( kubernetes (str): submit jobs to Kubernetes, using the given namespace. container_image (str): Docker image to use, e.g., for Kubernetes. k8s_cpu_scalar (float): What proportion of each k8s node's CPUs are availabe to snakemake? + k8s_service_account_name (str): Custom k8s service account, needed for workload identity. flux (bool): Launch workflow to flux cluster. default_remote_provider (str): default remote provider to use instead of local files (e.g. S3, GS) default_remote_prefix (str): prefix for default remote provider (e.g. name of the bucket). @@ -753,6 +755,7 @@ def snakemake( kubernetes=kubernetes, container_image=container_image, k8s_cpu_scalar=k8s_cpu_scalar, + k8s_service_account_name=k8s_service_account_name, conda_create_envs_only=conda_create_envs_only, default_remote_provider=default_remote_provider, default_remote_prefix=default_remote_prefix, @@ -819,6 +822,7 @@ def snakemake( kubernetes=kubernetes, container_image=container_image, k8s_cpu_scalar=k8s_cpu_scalar, + k8s_service_account_name=k8s_service_account_name, tibanna=tibanna, tibanna_sfn=tibanna_sfn, az_batch=az_batch, @@ -2473,6 +2477,16 @@ def get_argument_parser(profiles=None): "see the original value, i.e. as the value substituted in {threads}.", ) + group_kubernetes.add_argument( + "--k8s-service-account-name", + metavar="SERVICEACCOUNTNAME", + default=None, + help="This argument allows the use of customer service accounts for " + "kubernetes pods. If specified serviceAccountName will be added to the " + "pod specs. This is needed when using workload identity which is enforced " + "when using Google Cloud GKE Autopilot.", + ) + group_tibanna.add_argument( "--tibanna", action="store_true", @@ -3179,6 +3193,7 @@ def open_browser(): kubernetes=args.kubernetes, container_image=args.container_image, k8s_cpu_scalar=args.k8s_cpu_scalar, + k8s_service_account_name=args.k8s_service_account_name, flux=args.flux, tibanna=args.tibanna, tibanna_sfn=args.tibanna_sfn, diff --git a/snakemake/executors/__init__.py b/snakemake/executors/__init__.py index 2558888ff..f5d40f7cb 100644 --- a/snakemake/executors/__init__.py +++ b/snakemake/executors/__init__.py @@ -1742,6 +1742,7 @@ def __init__( namespace, container_image=None, k8s_cpu_scalar=1.0, + k8s_service_account_name=None, jobname="{rulename}.{jobid}", printreason=False, quiet=False, @@ -1783,6 +1784,7 @@ def __init__( import kubernetes.client self.k8s_cpu_scalar = k8s_cpu_scalar + self.k8s_service_account_name = k8s_service_account_name self.kubeapi = kubernetes.client.CoreV1Api() self.batchapi = kubernetes.client.BatchV1Api() self.namespace = namespace @@ -1969,6 +1971,10 @@ def run( body.spec = kubernetes.client.V1PodSpec( containers=[container], node_selector=node_selector ) + # Add service account name if provided + if self.k8s_service_account_name: + body.spec.service_account_name = self.k8s_service_account_name + # fail on first error body.spec.restart_policy = "Never" diff --git a/snakemake/scheduler.py b/snakemake/scheduler.py index 0241794c8..589bcdc9e 100644 --- a/snakemake/scheduler.py +++ b/snakemake/scheduler.py @@ -80,6 +80,7 @@ def __init__( env_modules=None, kubernetes=None, k8s_cpu_scalar=1.0, + k8s_service_account_name=None, container_image=None, flux=None, tibanna=None, @@ -316,6 +317,7 @@ def __init__( kubernetes, container_image=container_image, k8s_cpu_scalar=k8s_cpu_scalar, + k8s_service_account_name=k8s_service_account_name, printreason=printreason, quiet=quiet, printshellcmds=printshellcmds, diff --git a/snakemake/workflow.py b/snakemake/workflow.py index fc1ce7287..383766441 100644 --- a/snakemake/workflow.py +++ b/snakemake/workflow.py @@ -672,6 +672,7 @@ def execute( drmaa_log_dir=None, kubernetes=None, k8s_cpu_scalar=1.0, + k8s_service_account_name=None, flux=None, tibanna=None, tibanna_sfn=None, @@ -1126,6 +1127,7 @@ def files(items): drmaa_log_dir=drmaa_log_dir, kubernetes=kubernetes, k8s_cpu_scalar=k8s_cpu_scalar, + k8s_service_account_name=k8s_service_account_name, flux=flux, tibanna=tibanna, tibanna_sfn=tibanna_sfn,