feat: implement precommand (#2482)
### Description

<!--Add a description of your PR here-->

### QC
<!-- Make sure that you can tick the boxes below. -->

* [x] The PR contains a test case for the changes or the changes are
already covered by an existing test case.
* [x] The documentation (`docs/`) is updated to reflect the changes, or
this is not necessary (e.g. if the change modifies neither the language
nor the behavior or functionality of Snakemake).
johanneskoester committed Oct 16, 2023
1 parent 83f5d7e commit ff0f979
Showing 13 changed files with 72 additions and 164 deletions.
80 changes: 5 additions & 75 deletions snakemake/cli.py
@@ -807,6 +807,11 @@ def get_argument_parser(profiles=None):
action="store_true",
help=("Do not evaluate or execute subworkflows."),
)
group_exec.add_argument(
"--precommand",
help="Only used in case of remote execution. Command to be executed before "
"Snakemake executes each job on the remote compute node.",
)

group_group = parser.add_argument_group("GROUPING")
group_group.add_argument(
@@ -1304,16 +1309,6 @@ def get_argument_parser(profiles=None):
default="",
help="Specify prefix for default storage provider. E.g. a bucket name.",
)
group_behavior.add_argument(
"--default-storage-provider-auto-deploy",
action="store_true",
help="Automatically deploy the default storage provider if it is not present "
"in the environment. This uses pip and will modify your current environment "
"by installing the storage plugin and all its dependencies if not present. "
"Use this if you run Snakemake with a remote executor plugin like "
"kubernetes where the jobs will run in a container that might not have the "
"required storage plugin installed.",
)
group_behavior.add_argument(
"--no-shared-fs",
action="store_true",
@@ -1385,32 +1380,6 @@ def get_argument_parser(profiles=None):
"Currently slack and workflow management system (wms) are supported.",
)

group_slurm = parser.add_argument_group("SLURM")
slurm_mode_group = group_slurm.add_mutually_exclusive_group()

slurm_mode_group.add_argument(
"--slurm",
action="store_true",
help=(
"Execute snakemake rules as SLURM batch jobs according"
" to their 'resources' definition. SLURM resources as "
" 'partition', 'ntasks', 'cpus', etc. need to be defined"
" per rule within the 'resources' definition. Note, that"
" memory can only be defined as 'mem_mb' or 'mem_mb_per_cpu'"
" as analogous to the SLURM 'mem' and 'mem-per-cpu' flags"
" to sbatch, respectively. Here, the unit is always 'MiB'."
" In addition '--default_resources' should contain the"
" SLURM account."
),
),
slurm_mode_group.add_argument(
"--slurm-jobstep",
action="store_true",
help=configargparse.SUPPRESS, # this should be hidden and only be used
# for snakemake to be working in jobscript-
# mode
)

group_cluster = parser.add_argument_group("REMOTE EXECUTION")

group_cluster.add_argument(
@@ -1456,44 +1425,6 @@
)

group_flux = parser.add_argument_group("FLUX")
group_google_life_science = parser.add_argument_group("GOOGLE_LIFE_SCIENCE")
group_tes = parser.add_argument_group("TES")
group_tibanna = parser.add_argument_group("TIBANNA")

group_tibanna.add_argument(
"--tibanna",
action="store_true",
help="Execute workflow on AWS cloud using Tibanna. This requires "
"--default-storage-prefix to be set to S3 bucket name and prefix"
" (e.g. 'bucketname/subdirectory') where input is already stored"
" and output will be sent to. Using --tibanna implies --default-resources"
" is set as default. Optionally, use --precommand to"
" specify any preparation command to run before snakemake command"
" on the cloud (inside snakemake container on Tibanna VM)."
" Also, --use-conda, --use-singularity, --config, --configfile are"
" supported and will be carried over.",
)
group_tibanna.add_argument(
"--tibanna-sfn",
help="Name of Tibanna Unicorn step function (e.g. tibanna_unicorn_monty)."
"This works as serverless scheduler/resource allocator and must be "
"deployed first using tibanna cli. (e.g. tibanna deploy_unicorn --usergroup="
"monty --buckets=bucketname)",
)
group_tibanna.add_argument(
"--precommand",
help="Any command to execute before snakemake command on AWS cloud "
"such as wget, git clone, unzip, etc. This is used with --tibanna."
"Do not include input/output download/upload commands - file transfer"
" between S3 bucket and the run environment (container) is automatically"
" handled by Tibanna.",
)
group_tibanna.add_argument(
"--tibanna-config",
nargs="+",
help="Additional tibanna config e.g. --tibanna-config spot_instance=true subnet="
"<subnet_id> security group=<security_group_id>",
)

group_flux.add_argument(
"--flux",
@@ -1892,7 +1823,6 @@ def args_to_api(args, parser):
conda_not_block_search_path_envvars=args.conda_not_block_search_path_envvars,
apptainer_args=args.apptainer_args,
apptainer_prefix=args.apptainer_prefix,
default_storage_provider_auto_deploy=args.default_storage_provider_auto_deploy,
),
snakefile=args.snakefile,
workdir=args.directory,
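
For orientation, here is a minimal, self-contained sketch of how a `--precommand`-style flag can be parsed and carried into a remote-execution settings object. It uses plain `argparse` and an illustrative `RemoteExecutionSettingsSketch` class; the real CLI uses configargparse and the settings classes shown further down in `snakemake/settings.py`.

```python
# Minimal sketch of wiring a --precommand flag; illustrative names only,
# not the actual Snakemake CLI.
import argparse
from dataclasses import dataclass
from typing import Optional


@dataclass
class RemoteExecutionSettingsSketch:
    # Command run on the remote compute node before each job (None if unset).
    precommand: Optional[str] = None


def parse(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--precommand",
        help="Only used in case of remote execution. Command to be executed "
        "before each job on the remote compute node.",
    )
    args = parser.parse_args(argv)
    return RemoteExecutionSettingsSketch(precommand=args.precommand)


if __name__ == "__main__":
    settings = parse(["--precommand", "module load snakemake"])
    print(settings.precommand)  # -> module load snakemake
```
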
22 changes: 12 additions & 10 deletions snakemake/dag.py
@@ -281,7 +281,7 @@ def update_conda_envs(self):
(job.conda_env_spec, job.container_img_url)
for job in self.jobs
if job.conda_env_spec
and (self.workflow.storage_settings.assume_shared_fs or job.is_local)
and (job.is_local or self.workflow.global_or_node_local_shared_fs)
}

# Then based on md5sum values
@@ -305,21 +305,23 @@ def update_conda_envs(self):
self.conda_envs[key] = env

async def retrieve_storage_inputs(self):
if self.is_main_process or (
self.workflow.exec_mode == ExecMode.REMOTE
and not self.workflow.storage_settings.assume_shared_fs
):
if self.is_main_process or self.workflow.remote_exec_no_shared_fs:
async with asyncio.TaskGroup() as tg:
for job in self.jobs:
for f in job.input:
if f.is_storage and self.is_external_input(f, job):
tg.create_task(f.retrieve_from_storage())

async def store_storage_outputs(self):
if self.workflow.remote_exec_no_shared_fs:
async with asyncio.TaskGroup() as tg:
for job in self.jobs:
for f in job.output:
if f.is_storage:
tg.create_task(f.store_in_storage())

def cleanup_storage_objects(self):
if self.is_main_process or (
self.workflow.exec_mode == ExecMode.REMOTE
and not self.workflow.storage_settings.assume_shared_fs
):
if self.is_main_process or self.workflow.remote_exec_no_shared_fs:
cleaned = set()
for job in self.jobs:
for f in chain(job.input, job.output):
@@ -1546,7 +1548,7 @@ async def update_checkpoint_dependencies(self, jobs=None):
updated = True
if updated:
await self.postprocess()
if self.workflow.storage_settings.assume_shared_fs:
if self.workflow.global_or_node_local_shared_fs:
await self.retrieve_storage_inputs()
return updated

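
The storage retrieval and upload methods above rely on `asyncio.TaskGroup` to run transfers concurrently, guarded by whether the node can see a shared filesystem. Below is a standalone sketch of that pattern with simplified stand-ins (`StorageFile`, `retrieve_from_storage`, and the `remote_exec_no_shared_fs` flag are illustrative, not the real Snakemake interfaces); note that `asyncio.TaskGroup` requires Python 3.11+.

```python
# Sketch of the TaskGroup pattern used in retrieve_storage_inputs /
# store_storage_outputs. All names are simplified stand-ins.
import asyncio
from dataclasses import dataclass


@dataclass
class StorageFile:
    name: str

    async def retrieve_from_storage(self):
        # Placeholder for a real download; sleep simulates I/O latency.
        await asyncio.sleep(0.1)
        print(f"retrieved {self.name}")


async def retrieve_storage_inputs(files, remote_exec_no_shared_fs: bool):
    # Only transfer files when the node cannot rely on a shared filesystem.
    if remote_exec_no_shared_fs:
        async with asyncio.TaskGroup() as tg:  # Python >= 3.11
            for f in files:
                tg.create_task(f.retrieve_from_storage())


if __name__ == "__main__":
    files = [StorageFile("a.txt"), StorageFile("b.txt")]
    asyncio.run(retrieve_storage_inputs(files, remote_exec_no_shared_fs=True))
```
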
2 changes: 2 additions & 0 deletions snakemake/executors/dryrun.py
@@ -18,6 +18,8 @@
non_local_exec=False,
dryrun_exec=True,
implies_no_shared_fs=False,
pass_envvar_declarations_to_cmd=False,
auto_deploy_default_storage_provider=False,
)


14 changes: 3 additions & 11 deletions snakemake/executors/local.py
@@ -42,6 +42,8 @@
common_settings = CommonSettings(
non_local_exec=False,
implies_no_shared_fs=False,
pass_envvar_declarations_to_cmd=False,
auto_deploy_default_storage_provider=False,
)


@@ -55,17 +57,7 @@


class Executor(RealExecutor):
def __init__(
self,
workflow: WorkflowExecutorInterface,
logger: LoggerExecutorInterface,
):
super().__init__(
workflow,
logger,
pass_envvar_declarations_to_cmd=False,
)

def __post_init__(self):
self.use_threads = self.workflow.execution_settings.use_threads
self.keepincomplete = self.workflow.execution_settings.keep_incomplete
cores = self.workflow.resource_settings.cores
13 changes: 2 additions & 11 deletions snakemake/executors/touch.py
@@ -22,21 +22,12 @@
non_local_exec=False,
implies_no_shared_fs=False,
touch_exec=True,
pass_envvar_declarations_to_cmd=False,
auto_deploy_default_storage_provider=False,
)


class Executor(RealExecutor):
def __init__(
self,
workflow: WorkflowExecutorInterface,
logger: LoggerExecutorInterface,
):
super().__init__(
workflow,
logger,
pass_envvar_declarations_to_cmd=False,
)

def run_job(
self,
job: JobExecutorInterface,
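
The executor changes above drop per-class `__init__` overrides in favor of flags in the shared `CommonSettings` plus a `__post_init__` hook. A simplified standalone sketch of that pattern follows; the class and attribute names are illustrative and not the actual executor-plugin interface.

```python
# Sketch: move per-executor flags into a shared settings object and use
# __post_init__ for executor-specific setup. Illustrative names only.
from dataclasses import dataclass, field


@dataclass
class CommonSettings:
    non_local_exec: bool = False
    implies_no_shared_fs: bool = False
    pass_envvar_declarations_to_cmd: bool = False
    auto_deploy_default_storage_provider: bool = False


@dataclass
class ExecutorSketch:
    workflow: object
    common_settings: CommonSettings = field(default_factory=CommonSettings)

    def __post_init__(self):
        # Executor-specific initialization that previously lived in __init__.
        self.use_threads = getattr(self.workflow, "use_threads", False)


if __name__ == "__main__":
    class DummyWorkflow:
        use_threads = True

    ex = ExecutorSketch(DummyWorkflow())
    print(ex.use_threads, ex.common_settings.pass_envvar_declarations_to_cmd)
```
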
8 changes: 7 additions & 1 deletion snakemake/io.py
@@ -533,7 +533,13 @@ def is_fifo(self):
@iocache
async def size(self):
if self.is_storage:
return await self.storage_object.managed_size()
try:
return await self.storage_object.managed_size()
except WorkflowError as e:
try:
return await self.size_local()
except IOError:
raise e
else:
return await self.size_local()

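
The new `size()` logic tries the storage backend first, falls back to the local file size, and re-raises the original storage error only if the local lookup also fails. A minimal standalone sketch of that fallback pattern, with a simplified `WorkflowError` and a storage lookup that always fails, purely for illustration:

```python
# Sketch of the "try storage, fall back to local, re-raise original" pattern.
import asyncio
import os


class WorkflowError(Exception):
    pass


async def managed_size(path: str) -> int:
    # Simulate a storage backend that cannot determine the size.
    raise WorkflowError(f"cannot query storage size for {path}")


async def size(path: str) -> int:
    try:
        return await managed_size(path)
    except WorkflowError as e:
        try:
            return os.path.getsize(path)  # local fallback
        except OSError:
            # Neither storage nor local lookup worked: surface the original error.
            raise e


if __name__ == "__main__":
    with open("example.txt", "w") as fh:
        fh.write("hello")
    print(asyncio.run(size("example.txt")))  # -> 5
    os.remove("example.txt")
```
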
2 changes: 1 addition & 1 deletion snakemake/jobs.py
@@ -1059,7 +1059,7 @@ async def postprocess(
# No postprocessing necessary, we have just created the skeleton notebook and
# execution will anyway stop afterwards.
return
if self.dag.workflow.storage_settings.assume_shared_fs:
if self.dag.workflow.global_or_node_local_shared_fs:
if not error and handle_touch:
self.dag.handle_touch(self)
if handle_log:
2 changes: 1 addition & 1 deletion snakemake/settings.py
@@ -220,7 +220,6 @@ class DeploymentSettings(SettingsBase, DeploymentSettingsExecutorInterface):
conda_not_block_search_path_envvars: bool = False
apptainer_args: str = ""
apptainer_prefix: Optional[Path] = None
default_storage_provider_auto_deploy: bool = False

def imply_deployment_method(self, method: DeploymentMethod):
self.deployment_method = set(self.deployment_method)
@@ -350,6 +349,7 @@ class RemoteExecutionSettings(SettingsBase, RemoteExecutionSettingsExecutorInter
preemptible_rules: PreemptibleRules = field(default_factory=PreemptibleRules)
envvars: Sequence[str] = tuple()
immediate_submit: bool = False
precommand: Optional[str] = None


@dataclass
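
`RemoteExecutionSettings` gains `precommand: Optional[str] = None`, sitting alongside defaults like `envvars: Sequence[str] = tuple()` and `field(default_factory=PreemptibleRules)`. A small sketch of why immutable defaults can be assigned directly while mutable ones need `default_factory` (types here are illustrative, not the real settings classes):

```python
# Sketch: immutable defaults can be assigned directly; mutable defaults need
# default_factory so each instance gets its own object. Illustrative only.
from dataclasses import dataclass, field
from typing import Optional, Sequence


@dataclass
class RemoteExecutionSettingsSketch:
    envvars: Sequence[str] = tuple()           # immutable default: fine as-is
    rules: list = field(default_factory=list)  # mutable default: needs a factory
    immediate_submit: bool = False
    precommand: Optional[str] = None           # new field, unset by default


if __name__ == "__main__":
    a = RemoteExecutionSettingsSketch()
    b = RemoteExecutionSettingsSketch(precommand="module load snakemake")
    a.rules.append("rule_a")
    print(b.precommand, b.rules)  # -> module load snakemake []  (no shared list)
```
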
24 changes: 20 additions & 4 deletions snakemake/spawn_jobs.py
@@ -1,7 +1,7 @@
from dataclasses import dataclass, fields
import os
import sys
from typing import TypeVar, TYPE_CHECKING, Any
from typing import Mapping, TypeVar, TYPE_CHECKING, Any
from snakemake_interface_executor_plugins.utils import format_cli_arg, join_cli_args
from snakemake_interface_storage_plugins.registry import StoragePluginRegistry

@@ -117,19 +117,33 @@ def workflow_property_to_arg(

return format_cli_arg(flag, value, quote=quote)

def envvars(self):
def envvars(self) -> Mapping[str, str]:
envvars = {
var: os.environ[var]
for var in self.workflow.remote_execution_settings.envvars
}
envvars.update(self.get_storage_provider_envvars())
return envvars

def precommand(self, auto_deploy_default_storage_provider: bool) -> str:
precommand = self.workflow.remote_execution_settings.precommand or ""
if (
auto_deploy_default_storage_provider
and self.workflow.storage_settings.default_storage_provider is not None
):
package_name = StoragePluginRegistry().get_plugin_package_name(
self.workflow.storage_settings.default_storage_provider
)
if precommand:
precommand += " && "
precommand += f"pip install {package_name}"
return precommand

def general_args(
self,
pass_default_storage_provider_args: bool = True,
pass_default_resources_args: bool = False,
):
) -> str:
"""Return a string to add to self.exec_job that includes additional
arguments from the command line. This is currently used in the
ClusterExecutor and CPUExecutor, as both were using the same
@@ -164,8 +178,10 @@ def general_args(
),
w2a("deployment_settings.apptainer_prefix"),
w2a("deployment_settings.apptainer_args"),
w2a("deployment_settings.default_storage_provider_auto_deploy"),
w2a("resource_settings.max_threads"),
w2a(
"storage_settings.assume_shared_fs", flag="--no-shared-fs", invert=True
),
w2a(
"execution_settings.keep_metadata", flag="--drop-metadata", invert=True
),
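
The `precommand()` helper above chains the user-supplied precommand with an optional `pip install` of the default storage plugin, so remote nodes can auto-deploy it before the job runs. A standalone sketch of that assembly; the package-name lookup via `StoragePluginRegistry()` is replaced here with the conventional `snakemake-storage-plugin-<name>` naming as an assumption:

```python
# Sketch of assembling the remote-node precommand: user command first,
# then an optional auto-deploy of the default storage plugin via pip.
from typing import Optional


def build_precommand(
    user_precommand: Optional[str],
    auto_deploy_default_storage_provider: bool,
    default_storage_provider: Optional[str],
) -> str:
    precommand = user_precommand or ""
    if auto_deploy_default_storage_provider and default_storage_provider is not None:
        # In Snakemake this name comes from StoragePluginRegistry();
        # here we assume the conventional package naming.
        package_name = f"snakemake-storage-plugin-{default_storage_provider}"
        if precommand:
            precommand += " && "
        precommand += f"pip install {package_name}"
    return precommand


if __name__ == "__main__":
    print(build_precommand("wget https://example.com/data.tar.gz", True, "s3"))
    # -> wget https://example.com/data.tar.gz && pip install snakemake-storage-plugin-s3
```
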
29 changes: 4 additions & 25 deletions snakemake/storage.py
@@ -34,31 +34,10 @@ def __init__(self, workflow: Workflow):
self._default_storage_provider = None

if self.workflow.storage_settings.default_storage_provider is not None:
self._register_default_storage()

def _register_default_storage(self):
plugin_name = self.workflow.storage_settings.default_storage_provider
if (
not StoragePluginRegistry().is_installed(plugin_name)
and self.workflow.deployment_settings.default_storage_provider_auto_deploy
):
try:
subprocess.run(
["pip", "install", f"snakemake-storage-plugin-{plugin_name}"],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
check=True,
)
except subprocess.CalledProcessError as e:
raise WorkflowError(
f"Failed to install storage plugin {plugin_name} via pip: {e.stdout.decode()}",
e,
)
StoragePluginRegistry().collect_plugins()
self._default_storage_provider = self.register_storage(
plugin_name,
is_default=True,
)
self._default_storage_provider = self.register_storage(
self.workflow.storage_settings.default_storage_provider,
is_default=True,
)

@property
def default_storage_provider(self):
