feat: add ability to inject conda environments into running Snakefile (#2479)

### Description

Adds a global `conda:` directive that injects the given conda environment into the running Snakemake process (via the new `conda-inject` dependency), moves `DeploymentSettings` from the DAG API to the workflow API, and adds a `--default-storage-provider-auto-deploy` flag that pip-installs a missing default storage plugin at runtime.

### QC
<!-- Make sure that you can tick the boxes below. -->

* [x] The PR contains a test case for the changes or the changes are
already covered by an existing test case.
* [ ] The documentation (`docs/`) is updated to reflect the changes or
this is not necessary (e.g. if the change modifies neither the
language nor the behavior or functionality of Snakemake).
johanneskoester committed Oct 13, 2023
1 parent fb65c33 commit 6140e29
Showing 16 changed files with 214 additions and 55 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/main.yml
@@ -106,7 +106,10 @@ jobs:
#PYTHONTRACEMALLOC: 10
shell: bash -el {0}
run: |
pytest --show-capture=stderr --splits 10 --group ${{ matrix.test_group }} -v -x tests/test_expand.py tests/test_io.py tests/test_schema.py tests/test_linting.py tests/test_schema.py tests/test_linting.py tests/test_executor_test_suite.py tests/tests.py
# long tests
pytest --show-capture=stderr --splits 10 --group ${{ matrix.test_group }} -v -x tests/tests.py
# other tests
pytest --show-capture=stderr -v -x tests/test_expand.py tests/test_io.py tests/test_schema.py tests/test_linting.py tests/test_executor_test_suite.py
build-container-image:
runs-on: ubuntu-latest
1 change: 1 addition & 0 deletions setup.cfg
@@ -60,6 +60,7 @@ install_requires =
toposort >=1.10
wrapt
yte >=1.5.1,<2.0
conda-inject >=1.1.1,<2.0

[options.extras_require]
google-cloud =
41 changes: 24 additions & 17 deletions snakemake/api.py
@@ -38,7 +38,6 @@

from snakemake_interface_executor_plugins.settings import ExecMode, ExecutorSettingsBase
from snakemake_interface_executor_plugins.registry import ExecutorPluginRegistry
from snakemake_interface_storage_plugins.settings import StorageProviderSettingsBase
from snakemake_interface_common.exceptions import ApiError
from snakemake_interface_common.plugin_registry.plugin import TaggedSettings

@@ -100,6 +99,7 @@ def workflow(
config_settings: Optional[ConfigSettings] = None,
storage_settings: Optional[StorageSettings] = None,
workflow_settings: Optional[WorkflowSettings] = None,
deployment_settings: Optional[DeploymentSettings] = None,
storage_provider_settings: Optional[Mapping[str, TaggedSettings]] = None,
snakefile: Optional[Path] = None,
workdir: Optional[Path] = None,
@@ -124,6 +124,8 @@
storage_settings = StorageSettings()
if workflow_settings is None:
workflow_settings = WorkflowSettings()
if deployment_settings is None:
deployment_settings = DeploymentSettings()
if storage_provider_settings is None:
storage_provider_settings = dict()

@@ -141,6 +143,7 @@
resource_settings=resource_settings,
storage_settings=storage_settings,
workflow_settings=workflow_settings,
deployment_settings=deployment_settings,
storage_provider_settings=storage_provider_settings,
)
return self._workflow_api
@@ -151,11 +154,11 @@ def _cleanup(self):
logger.cleanup()
if self._workflow_api is not None:
self._workflow_api._workdir_handler.change_back()
if (
self._workflow_api._workflow_store is not None
and self._workflow_api._workflow._workdir_handler is not None
):
self._workflow_api._workflow._workdir_handler.change_back()
if self._workflow_api._workflow_store is not None:
for conda_env in self._workflow_api._workflow_store.injected_conda_envs:
conda_env.remove()
if self._workflow_api._workflow._workdir_handler is not None:
self._workflow_api._workflow._workdir_handler.change_back()

def print_exception(self, ex: Exception):
"""Print an exception during workflow execution in a human readable way
@@ -231,14 +234,15 @@ class WorkflowApi(ApiBase):
resource_settings: ResourceSettings
storage_settings: StorageSettings
workflow_settings: WorkflowSettings
deployment_settings: DeploymentSettings
storage_provider_settings: Mapping[str, TaggedSettings]

_workflow_store: Optional[Workflow] = field(init=False, default=None)
_workdir_handler: Optional[WorkdirHandler] = field(init=False)

def dag(
self,
dag_settings: Optional[DAGSettings] = None,
deployment_settings: Optional[DeploymentSettings] = None,
):
"""Create a DAG API.
@@ -248,14 +252,11 @@
"""
if dag_settings is None:
dag_settings = DAGSettings()
if deployment_settings is None:
deployment_settings = DeploymentSettings()

return DAGApi(
self.snakemake_api,
self,
dag_settings=dag_settings,
deployment_settings=deployment_settings,
)

def lint(self, json: bool = False):
@@ -312,6 +313,7 @@ def _get_workflow(self, **kwargs):
config_settings=self.config_settings,
resource_settings=self.resource_settings,
workflow_settings=self.workflow_settings,
deployment_settings=self.deployment_settings,
storage_settings=self.storage_settings,
output_settings=self.snakemake_api.output_settings,
overwrite_workdir=self.workdir,
@@ -344,11 +346,9 @@ class DAGApi(ApiBase):
snakemake_api: SnakemakeApi
workflow_api: WorkflowApi
dag_settings: DAGSettings
deployment_settings: DeploymentSettings

def __post_init__(self):
self.workflow_api._workflow.dag_settings = self.dag_settings
self.workflow_api._workflow.deployment_settings = self.deployment_settings

def execute_workflow(
self,
@@ -367,7 +367,6 @@ def execute_workflow(
executor: str -- The executor to use.
execution_settings: ExecutionSettings -- The execution settings for the workflow.
resource_settings: ResourceSettings -- The resource settings for the workflow.
deployment_settings: DeploymentSettings -- The deployment settings for the workflow.
remote_execution_settings: RemoteExecutionSettings -- The remote execution settings for the workflow.
executor_settings: Optional[ExecutorSettingsBase] -- The executor settings for the workflow.
updated_files: Optional[List[str]] -- An optional list where Snakemake will put all updated files.
@@ -522,17 +521,23 @@ def cleanup_metadata(self, paths: List[Path]):

def conda_cleanup_envs(self):
"""Cleanup the conda environments of the workflow."""
self.deployment_settings.imply_deployment_method(DeploymentMethod.CONDA)
self.workflow_api.deployment_settings.imply_deployment_method(
DeploymentMethod.CONDA
)
self.workflow_api._workflow.conda_cleanup_envs()

def conda_create_envs(self):
"""Only create the conda environments of the workflow."""
self.deployment_settings.imply_deployment_method(DeploymentMethod.CONDA)
self.workflow_api.deployment_settings.imply_deployment_method(
DeploymentMethod.CONDA
)
self.workflow_api._workflow.conda_create_envs()

def conda_list_envs(self):
"""List the conda environments of the workflow."""
self.deployment_settings.imply_deployment_method(DeploymentMethod.CONDA)
self.workflow_api.deployment_settings.imply_deployment_method(
DeploymentMethod.CONDA
)
self.workflow_api._workflow.conda_list_envs()

def cleanup_shadow(self):
Expand All @@ -541,7 +546,9 @@ def cleanup_shadow(self):

def container_cleanup_images(self):
"""Cleanup the container images of the workflow."""
self.deployment_settings.imply_deployment_method(DeploymentMethod.APPTAINER)
self.workflow_api.deployment_settings.imply_deployment_method(
DeploymentMethod.APPTAINER
)
self.workflow_api._workflow.container_cleanup_images()

def list_changes(self, change_type: ChangeType):
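Taken together, the api.py changes move deployment configuration from the DAG level to the workflow level. A minimal sketch of the resulting call shape, assuming default values for every setting not shown and the "local" executor name (all other names are taken from this diff):

```python
from pathlib import Path

from snakemake.api import SnakemakeApi
from snakemake.settings import (
    DAGSettings,
    DeploymentSettings,
    OutputSettings,
    ResourceSettings,
    StorageSettings,
)

with SnakemakeApi(OutputSettings()) as snakemake_api:
    workflow_api = snakemake_api.workflow(
        resource_settings=ResourceSettings(),
        storage_settings=StorageSettings(),
        # New in this commit: deployment settings belong to the workflow ...
        deployment_settings=DeploymentSettings(),
        snakefile=Path("Snakefile"),
    )
    # ... so dag() no longer accepts them.
    dag_api = workflow_api.dag(dag_settings=DAGSettings())
    dag_api.execute_workflow(executor="local")
```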
47 changes: 29 additions & 18 deletions snakemake/cli.py
@@ -1305,6 +1305,16 @@ def get_argument_parser(profiles=None):
default="",
help="Specify prefix for default storage provider. E.g. a bucket name.",
)
group_behavior.add_argument(
"--default-storage-provider-auto-deploy",
action="store_true",
help="Automatically deploy the default storage provider if it is not present "
"in the environment. This uses pip and will modify your current environment "
"by installing the storage plugin and all its dependencies if not present. "
"Use this if you run Snakemake with a remote executor plugin like "
"kubernetes where the jobs will run in a container that might not have the "
"required storage plugin installed.",
)
group_behavior.add_argument(
"--no-shared-fs",
action="store_true",
@@ -1836,6 +1846,14 @@ def args_to_api(args, parser):
keep_logger=False,
)
) as snakemake_api:
deployment_method = args.software_deployment_method
if args.use_conda:
deployment_method.add(DeploymentMethod.CONDA)
if args.use_apptainer:
deployment_method.add(DeploymentMethod.APPTAINER)
if args.use_envmodules:
deployment_method.add(DeploymentMethod.ENV_MODULES)

try:
workflow_api = snakemake_api.workflow(
resource_settings=ResourceSettings(
@@ -1866,6 +1884,17 @@
workflow_settings=WorkflowSettings(
wrapper_prefix=args.wrapper_prefix,
),
deployment_settings=DeploymentSettings(
deployment_method=deployment_method,
conda_prefix=args.conda_prefix,
conda_cleanup_pkgs=args.conda_cleanup_pkgs,
conda_base_path=args.conda_base_path,
conda_frontend=args.conda_frontend,
conda_not_block_search_path_envvars=args.conda_not_block_search_path_envvars,
apptainer_args=args.apptainer_args,
apptainer_prefix=args.apptainer_prefix,
default_storage_provider_auto_deploy=args.default_storage_provider_auto_deploy,
),
snakefile=args.snakefile,
workdir=args.directory,
)
@@ -1882,14 +1911,6 @@
elif args.print_compilation:
workflow_api.print_compilation()
else:
deployment_method = args.software_deployment_method
if args.use_conda:
deployment_method.add(DeploymentMethod.CONDA)
if args.use_apptainer:
deployment_method.add(DeploymentMethod.APPTAINER)
if args.use_envmodules:
deployment_method.add(DeploymentMethod.ENV_MODULES)

dag_api = workflow_api.dag(
dag_settings=DAGSettings(
targets=args.targets,
Expand All @@ -1907,16 +1928,6 @@ def args_to_api(args, parser):
max_inventory_wait_time=args.max_inventory_time,
cache=args.cache,
),
deployment_settings=DeploymentSettings(
deployment_method=deployment_method,
conda_prefix=args.conda_prefix,
conda_cleanup_pkgs=args.conda_cleanup_pkgs,
conda_base_path=args.conda_base_path,
conda_frontend=args.conda_frontend,
conda_not_block_search_path_envvars=args.conda_not_block_search_path_envvars,
apptainer_args=args.apptainer_args,
apptainer_prefix=args.apptainer_prefix,
),
)

if args.preemptible_rules is not None:
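In cli.py, the deployment method set is now computed up front and handed to the workflow as a `DeploymentSettings` object, which also carries the new flag. A sketch of the programmatic equivalent of `--use-conda --default-storage-provider-auto-deploy`, using only names from this diff:

```python
from snakemake.settings import DeploymentSettings
from snakemake_interface_executor_plugins.settings import DeploymentMethod

deployment_settings = DeploymentSettings(
    deployment_method={DeploymentMethod.CONDA},  # --use-conda
    # Allow Snakemake to pip-install a missing default storage plugin
    # at runtime (--default-storage-provider-auto-deploy):
    default_storage_provider_auto_deploy=True,
)
```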
25 changes: 16 additions & 9 deletions snakemake/common/tests/__init__.py
@@ -57,6 +57,19 @@ def get_default_storage_provider_settings(
) -> Optional[Mapping[str, TaggedSettings]]:
...

def get_remote_execution_settings(self) -> settings.RemoteExecutionSettings:
return settings.RemoteExecutionSettings(
seconds_between_status_checks=0,
envvars=self.get_envvars(),
)

def get_deployment_settings(
self, deployment_method=frozenset()
) -> settings.DeploymentSettings:
return settings.DeploymentSettings(
deployment_method=deployment_method,
)

def get_assume_shared_fs(self) -> bool:
return True

@@ -101,26 +114,20 @@ def _run_workflow(self, test_name, tmp_path, deployment_method=frozenset()):
default_storage_prefix=self.get_default_storage_prefix(),
assume_shared_fs=self.get_assume_shared_fs(),
),
deployment_settings=self.get_deployment_settings(deployment_method),
storage_provider_settings=self.get_default_storage_provider_settings(),
workdir=Path(tmp_path),
snakefile=tmp_path / "Snakefile",
)

dag_api = workflow_api.dag(
deployment_settings=settings.DeploymentSettings(
deployment_method=deployment_method,
),
)
dag_api = workflow_api.dag()
dag_api.execute_workflow(
executor=self.get_executor(),
executor_settings=self.get_executor_settings(),
execution_settings=settings.ExecutionSettings(
latency_wait=self.latency_wait,
),
remote_execution_settings=settings.RemoteExecutionSettings(
seconds_between_status_checks=0,
envvars=self.get_envvars(),
),
remote_execution_settings=self.get_remote_execution_settings(),
)

@handle_testcase
7 changes: 7 additions & 0 deletions snakemake/parser.py
@@ -366,6 +366,12 @@ def keyword(self):
return "global_containerized"


class GlobalConda(GlobalKeywordState):
@property
def keyword(self):
return "global_conda"


class Localrules(GlobalKeywordState):
def block_content(self, token):
if is_comma(token):
@@ -1180,6 +1186,7 @@ class Python(TokenAutomaton):
singularity=GlobalSingularity,
container=GlobalContainer,
containerized=GlobalContainerized,
conda=GlobalConda,
scattergather=Scattergather,
storage=Storage,
resource_scopes=ResourceScope,
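The parser changes register `conda` as a global keyword, which is what lets a Snakefile declare a workflow-wide environment to inject. A hypothetical Snakefile sketch; the diff does not show the directive's argument, so the environment-file path below mirrors the rule-level `conda:` keyword:

```python
# Hypothetical Snakefile: a global conda directive for the whole workflow.
conda:
    "envs/global.yaml"

rule all:
    input:
        "results/summary.txt"
```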
1 change: 1 addition & 0 deletions snakemake/settings.py
@@ -220,6 +220,7 @@ class DeploymentSettings(SettingsBase, DeploymentSettingsExecutorInterface):
conda_not_block_search_path_envvars: bool = False
apptainer_args: str = ""
apptainer_prefix: Optional[Path] = None
default_storage_provider_auto_deploy: bool = False

def imply_deployment_method(self, method: DeploymentMethod):
self.deployment_method = set(self.deployment_method)
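`imply_deployment_method` is what the API methods above (e.g. `conda_list_envs`) call so that conda- or apptainer-specific operations work even when the user did not enable the method explicitly. A minimal usage sketch, with imports as used elsewhere in this diff:

```python
from snakemake.settings import DeploymentSettings
from snakemake_interface_executor_plugins.settings import DeploymentMethod

settings = DeploymentSettings()
# Listing conda environments implies conda deployment, even without --use-conda:
settings.imply_deployment_method(DeploymentMethod.CONDA)
assert DeploymentMethod.CONDA in settings.deployment_method
```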
1 change: 1 addition & 0 deletions snakemake/spawn_jobs.py
@@ -164,6 +164,7 @@ def general_args(
),
w2a("deployment_settings.apptainer_prefix"),
w2a("deployment_settings.apptainer_args"),
w2a("deployment_settings.default_storage_provider_auto_deploy"),
w2a("resource_settings.max_threads"),
w2a(
"execution_settings.keep_metadata", flag="--drop-metadata", invert=True
34 changes: 30 additions & 4 deletions snakemake/storage.py
@@ -1,14 +1,17 @@
import copy, sys
import subprocess
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
from snakemake.workflow import Workflow
from snakemake_interface_common.exceptions import WorkflowError
from snakemake_interface_common.exceptions import WorkflowError, InvalidPluginException
from snakemake_interface_storage_plugins.registry import StoragePluginRegistry
from snakemake_interface_storage_plugins.storage_provider import StorageProviderBase
from snakemake_interface_storage_plugins.storage_object import (
StorageObjectWrite,
StorageObjectRead,
)
from snakemake_interface_executor_plugins.settings import DeploymentMethod
from snakemake.common import __version__


class StorageRegistry:
@@ -17,6 +20,7 @@ class StorageRegistry:
"_storages",
"_default_storage_provider",
"default_storage_provider",
"_register_default_storage",
"register_storage",
"infer_provider",
"_storage_object",
@@ -30,9 +34,31 @@ def __init__(self, workflow: Workflow):
self._default_storage_provider = None

if self.workflow.storage_settings.default_storage_provider is not None:
self._default_storage_provider = self.register_storage(
self.workflow.storage_settings.default_storage_provider, is_default=True
)
self._register_default_storage()

def _register_default_storage(self):
plugin_name = self.workflow.storage_settings.default_storage_provider
if (
not StoragePluginRegistry().is_installed(plugin_name)
and self.workflow.deployment_settings.default_storage_provider_auto_deploy
):
try:
subprocess.run(
["pip", "install", f"snakemake-storage-plugin-{plugin_name}"],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
check=True,
)
except subprocess.CalledProcessError as e:
raise WorkflowError(
f"Failed to install storage plugin {plugin_name} via pip: {e.stdout.decode()}",
e,
)
StoragePluginRegistry().collect_plugins()
self._default_storage_provider = self.register_storage(
plugin_name,
is_default=True,
)

@property
def default_storage_provider(self):
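The auto-deploy branch in `_register_default_storage` relies on two registry calls, `is_installed` and `collect_plugins`. A standalone sketch of that check; the plugin name "s3" is purely illustrative:

```python
import subprocess

from snakemake_interface_storage_plugins.registry import StoragePluginRegistry

plugin_name = "s3"  # illustrative; resolves to snakemake-storage-plugin-s3
if not StoragePluginRegistry().is_installed(plugin_name):
    # Mirrors the auto-deploy branch above: install, then rescan the registry.
    subprocess.run(
        ["pip", "install", f"snakemake-storage-plugin-{plugin_name}"],
        check=True,
    )
    StoragePluginRegistry().collect_plugins()
```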
