From 2be3bfa4841967928069a2a024554b8a86b699f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20K=C3=B6ster?= Date: Fri, 15 Sep 2023 11:45:17 +0200 Subject: [PATCH] feat!: redesigned Snakemake API. It now uses a modern, dataclass based approach (#2403) ### Description ### QC * [x] The PR contains a test case for the changes or the changes are already covered by an existing test case. * [ ] The documentation (`docs/`) is updated to reflect the changes or this is not necessary (e.g. if the change does neither modify the language nor the behavior or functionalities of Snakemake). --- .github/workflows/main.yml | 164 +- .github/workflows/test-flux.yaml | 65 - motoState.p | 1 + setup.cfg | 3 +- snakemake/__init__.py | 7 + snakemake/api.py | 1368 +++++------- snakemake/caching/hash.py | 20 +- snakemake/caching/remote.py | 3 +- snakemake/cli.py | 1325 ++++------- snakemake/common/__init__.py | 37 +- snakemake/common/argparse.py | 57 + snakemake/common/configfile.py | 43 + snakemake/common/git.py | 85 + snakemake/common/tests/__init__.py | 122 + .../tests/testcases}/__init__.py | 0 .../common/tests/testcases/groups/Snakefile | 32 + .../common/tests/testcases/groups/__init__.py | 0 .../common/tests/testcases/simple/Snakefile | 30 + .../common/tests/testcases/simple/__init__.py | 0 snakemake/common/workdir_handler.py | 24 + snakemake/cwl.py | 7 +- snakemake/dag.py | 215 +- snakemake/deployment/conda.py | 4 +- snakemake/deployment/singularity.py | 4 +- snakemake/exceptions.py | 15 +- snakemake/executors/__init__.py | 1988 ----------------- snakemake/executors/azure_batch.py | 933 -------- snakemake/executors/dryrun.py | 77 + snakemake/executors/flux.py | 7 +- snakemake/executors/ga4gh_tes.py | 323 --- snakemake/executors/google_lifesciences.py | 1054 --------- .../executors/google_lifesciences_helper.py | 113 - snakemake/executors/local.py | 475 ++++ snakemake/executors/slurm/slurm_jobstep.py | 118 - snakemake/executors/slurm/slurm_submit.py | 457 ---- snakemake/executors/touch.py | 73 + snakemake/io.py | 120 - snakemake/jobs.py | 77 +- snakemake/linting/rules.py | 9 - snakemake/logging.py | 29 +- snakemake/modules.py | 3 +- snakemake/notebook.py | 7 - snakemake/parser.py | 118 +- snakemake/path_modifier.py | 6 +- snakemake/persistence.py | 38 +- snakemake/remote/FTP.py | 2 +- snakemake/remote/HTTP.py | 117 +- snakemake/report/__init__.py | 2 +- snakemake/resources.py | 7 +- snakemake/ruleinfo.py | 1 - snakemake/rules.py | 117 +- snakemake/scheduler.py | 677 +++--- snakemake/settings.py | 381 ++++ snakemake/shell.py | 5 +- snakemake/snakemake.code-workspace | 8 + snakemake/sourcecache.py | 2 +- snakemake/spawn_jobs.py | 151 ++ snakemake/stats.py | 82 - snakemake/target_jobs.py | 6 +- snakemake/unit_tests/__init__.py | 4 +- .../unit_tests/templates/ruletest.py.jinja2 | 2 +- snakemake/utils.py | 11 +- snakemake/workflow.py | 1628 ++++++-------- test-environment.yml | 6 +- test.py | 19 + tests/common.py | 194 +- tests/test01/Snakefile | 3 +- tests/test14/Snakefile.nonstandard | 50 - tests/test14/expected-results/test.1.inter | 2 - tests/test14/expected-results/test.1.inter2 | 2 - tests/test14/expected-results/test.2.inter | 2 - tests/test14/expected-results/test.2.inter2 | 2 - tests/test14/expected-results/test.3.inter | 2 - tests/test14/expected-results/test.3.inter2 | 2 - .../test14/expected-results/test.predictions | 10 - tests/test14/qsub | 7 - tests/test14/qsub.py | 14 - tests/test14/raw.10.txt | 0 tests/test14/raw.11.txt | 0 tests/test14/raw.12.txt | 0 tests/test14/raw.13.txt | 0 
tests/test14/raw.14.txt | 0 tests/test14/raw.15.txt | 0 tests/test14/raw.16.txt | 0 tests/test14/raw.17.txt | 0 tests/test14/raw.18.txt | 0 tests/test14/raw.19.txt | 0 tests/test14/raw.2.txt | 0 tests/test14/raw.20.txt | 0 tests/test14/raw.21.txt | 0 tests/test14/raw.3.txt | 0 tests/test14/raw.4.txt | 0 tests/test14/raw.5.txt | 0 tests/test14/raw.6.txt | 0 tests/test14/raw.7.txt | 0 tests/test14/raw.8.txt | 0 tests/test14/raw.9.txt | 0 tests/test14/test.in | 1 - tests/test_azure_batch_executor.py | 26 - tests/test_cluster_sidecar/Snakefile | 21 - .../test_cluster_sidecar/expected-results/f.1 | 0 .../test_cluster_sidecar/expected-results/f.2 | 0 .../expected-results/launched.txt | 2 - .../expected-results/sidecar.txt | 2 - tests/test_cluster_sidecar/sbatch | 11 - tests/test_cluster_sidecar/sidecar.sh | 20 - tests/test_cluster_sidecar/test.in | 1 - .../Snakefile.nonstandard | 32 - .../expected-results/test.1.inter | 2 - .../expected-results/test.1.inter2 | 2 - .../expected-results/test.2.inter | 2 - .../expected-results/test.2.inter2 | 2 - .../expected-results/test.3.inter | 2 - .../expected-results/test.3.inter2 | 2 - .../expected-results/test.predictions | 10 - tests/test_cluster_statusscript/qsub | 11 - tests/test_cluster_statusscript/status.sh | 1 - tests/test_cluster_statusscript/test.in | 1 - .../Snakefile.nonstandard | 13 - .../expected-results/output.txt | 0 tests/test_cluster_statusscript_multi/sbatch | 8 - .../test_cluster_statusscript_multi/status.sh | 9 - tests/test_conda_function/Snakefile | 4 +- tests/test_conda_named/Snakefile | 6 +- tests/test_conda_pin_file/Snakefile | 2 +- .../Snakefile | 2 + .../env.yaml | 4 + tests/test_google_lifesciences.py | 191 -- tests/test_kubernetes.py | 102 - tests/test_list_untracked/Snakefile | 2 +- .../expected-results/leftover_files | 1 - .../expected-results/leftover_files_WIN | 1 - tests/test_pipes/Snakefile | 2 +- tests/test_slurm.py | 99 - tests/test_srcdir/Snakefile | 7 - tests/test_srcdir/expected-results/test.out | 1 - tests/test_srcdir/script.sh | 2 - tests/test_temp/qsub | 6 - tests/test_tes.py | 84 - tests/test_tibanna.py | 24 - tests/testapi.py | 12 +- tests/tests.py | 552 ++--- 142 files changed, 4397 insertions(+), 10062 deletions(-) delete mode 100644 .github/workflows/test-flux.yaml create mode 100644 motoState.p create mode 100644 snakemake/common/argparse.py create mode 100644 snakemake/common/configfile.py create mode 100644 snakemake/common/git.py create mode 100644 snakemake/common/tests/__init__.py rename snakemake/{executors/slurm => common/tests/testcases}/__init__.py (100%) create mode 100644 snakemake/common/tests/testcases/groups/Snakefile rename tests/test14/raw.0.txt => snakemake/common/tests/testcases/groups/__init__.py (100%) create mode 100644 snakemake/common/tests/testcases/simple/Snakefile rename tests/test14/raw.1.txt => snakemake/common/tests/testcases/simple/__init__.py (100%) create mode 100644 snakemake/common/workdir_handler.py delete mode 100644 snakemake/executors/azure_batch.py create mode 100644 snakemake/executors/dryrun.py delete mode 100644 snakemake/executors/ga4gh_tes.py delete mode 100644 snakemake/executors/google_lifesciences.py delete mode 100755 snakemake/executors/google_lifesciences_helper.py create mode 100644 snakemake/executors/local.py delete mode 100644 snakemake/executors/slurm/slurm_jobstep.py delete mode 100644 snakemake/executors/slurm/slurm_submit.py create mode 100644 snakemake/executors/touch.py create mode 100644 snakemake/settings.py create mode 100644 
snakemake/snakemake.code-workspace create mode 100644 snakemake/spawn_jobs.py delete mode 100644 snakemake/stats.py create mode 100644 test.py delete mode 100644 tests/test14/Snakefile.nonstandard delete mode 100644 tests/test14/expected-results/test.1.inter delete mode 100644 tests/test14/expected-results/test.1.inter2 delete mode 100644 tests/test14/expected-results/test.2.inter delete mode 100644 tests/test14/expected-results/test.2.inter2 delete mode 100644 tests/test14/expected-results/test.3.inter delete mode 100644 tests/test14/expected-results/test.3.inter2 delete mode 100644 tests/test14/expected-results/test.predictions delete mode 100755 tests/test14/qsub delete mode 100755 tests/test14/qsub.py delete mode 100644 tests/test14/raw.10.txt delete mode 100644 tests/test14/raw.11.txt delete mode 100644 tests/test14/raw.12.txt delete mode 100644 tests/test14/raw.13.txt delete mode 100644 tests/test14/raw.14.txt delete mode 100644 tests/test14/raw.15.txt delete mode 100644 tests/test14/raw.16.txt delete mode 100644 tests/test14/raw.17.txt delete mode 100644 tests/test14/raw.18.txt delete mode 100644 tests/test14/raw.19.txt delete mode 100644 tests/test14/raw.2.txt delete mode 100644 tests/test14/raw.20.txt delete mode 100644 tests/test14/raw.21.txt delete mode 100644 tests/test14/raw.3.txt delete mode 100644 tests/test14/raw.4.txt delete mode 100644 tests/test14/raw.5.txt delete mode 100644 tests/test14/raw.6.txt delete mode 100644 tests/test14/raw.7.txt delete mode 100644 tests/test14/raw.8.txt delete mode 100644 tests/test14/raw.9.txt delete mode 100644 tests/test14/test.in delete mode 100644 tests/test_azure_batch_executor.py delete mode 100644 tests/test_cluster_sidecar/Snakefile delete mode 100644 tests/test_cluster_sidecar/expected-results/f.1 delete mode 100644 tests/test_cluster_sidecar/expected-results/f.2 delete mode 100644 tests/test_cluster_sidecar/expected-results/launched.txt delete mode 100644 tests/test_cluster_sidecar/expected-results/sidecar.txt delete mode 100755 tests/test_cluster_sidecar/sbatch delete mode 100755 tests/test_cluster_sidecar/sidecar.sh delete mode 100644 tests/test_cluster_sidecar/test.in delete mode 100644 tests/test_cluster_statusscript/Snakefile.nonstandard delete mode 100644 tests/test_cluster_statusscript/expected-results/test.1.inter delete mode 100644 tests/test_cluster_statusscript/expected-results/test.1.inter2 delete mode 100644 tests/test_cluster_statusscript/expected-results/test.2.inter delete mode 100644 tests/test_cluster_statusscript/expected-results/test.2.inter2 delete mode 100644 tests/test_cluster_statusscript/expected-results/test.3.inter delete mode 100644 tests/test_cluster_statusscript/expected-results/test.3.inter2 delete mode 100644 tests/test_cluster_statusscript/expected-results/test.predictions delete mode 100755 tests/test_cluster_statusscript/qsub delete mode 100755 tests/test_cluster_statusscript/status.sh delete mode 100644 tests/test_cluster_statusscript/test.in delete mode 100644 tests/test_cluster_statusscript_multi/Snakefile.nonstandard delete mode 100644 tests/test_cluster_statusscript_multi/expected-results/output.txt delete mode 100755 tests/test_cluster_statusscript_multi/sbatch delete mode 100755 tests/test_cluster_statusscript_multi/status.sh create mode 100644 tests/test_converting_path_for_r_script/env.yaml delete mode 100644 tests/test_google_lifesciences.py delete mode 100644 tests/test_kubernetes.py delete mode 100644 tests/test_slurm.py delete mode 100644 tests/test_srcdir/Snakefile delete mode 100644 
tests/test_srcdir/expected-results/test.out delete mode 100644 tests/test_srcdir/script.sh delete mode 100755 tests/test_temp/qsub delete mode 100644 tests/test_tes.py delete mode 100644 tests/test_tibanna.py diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index ccde92195..e4d337280 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -32,8 +32,7 @@ jobs: source activate black - black --check --diff snakemake tests/tests.py tests/test_tes.py - tests/test_io.py tests/common.py tests/test_google_lifesciences.py + black --check --diff . - name: Comment PR if: github.event_name == 'pull_request' && failure() uses: marocchino/sticky-pull-request-comment@v2.6.2 @@ -60,103 +59,6 @@ jobs: with: fetch-depth: 0 - ###### slurm setup ##### - # prior to slurm-setup we need the podmand-correct command - # see https://github.com/containers/podman/issues/13338 - - name: Download slurm ansible roles - run: | - ansible-galaxy install galaxyproject.slurm,1.0.1 - - name: Update apt cache - run: | - sudo apt-get update - - name: Define slurm playbook - uses: 1arp/create-a-file-action@0.2 - with: - file: slurm-playbook.yml - content: | - - name: Slurm all in One - hosts: localhost - roles: - - role: galaxyproject.slurm - become: true - vars: - slurm_upgrade: true - slurm_roles: ['controller', 'exec', 'dbd'] - slurm_config_dir: /etc/slurm - #slurm_cgroup_config: - # CgroupMountpoint: "/sys/fs/cgroup" - # CgroupAutomount: yes - # ConstrainCores: yes - # TaskAffinity: no - # ConstrainRAMSpace: yes - # ConstrainSwapSpace: no - # ConstrainDevices: no - # AllowedRamSpace: 100 - # AllowedSwapSpace: 0 - # MaxRAMPercent: 100 - # MaxSwapPercent: 100 - # MinRAMSpace: 30 - slurm_config: - ClusterName: cluster - #ProctrackType: proctrack/pgid - #SlurmctldHost: localhost # TODO try if we need this - SlurmctldLogFile: /var/log/slurm/slurmctld.log - SlurmctldPidFile: /run/slurmctld.pid - SlurmdLogFile: /var/log/slurm/slurmd.log - SlurmdPidFile: /run/slurmd.pid - SlurmdSpoolDir: /tmp/slurmd # the default /var/lib/slurm/slurmd does not work because of noexec mounting in github actions - StateSaveLocation: /var/lib/slurm/slurmctld - #TaskPlugin: "task/affinity,task/cgroup" - AccountingStorageType: accounting_storage/slurmdbd - slurmdbd_config: - StorageType: accounting_storage/mysql - PidFile: /run/slurmdbd.pid - LogFile: /var/log/slurm/slurmdbd.log - StoragePass: root - StorageUser: root - StorageHost: 127.0.0.1 # see https://stackoverflow.com/questions/58222386/github-actions-using-mysql-service-throws-access-denied-for-user-rootlocalh - StoragePort: 8888 - DbdHost: localhost - slurm_create_user: yes - #slurm_munge_key: "../../../munge.key" - slurm_nodes: - - name: localhost - State: UNKNOWN - Sockets: 1 - CoresPerSocket: 2 - slurm_user: - comment: "Slurm Workload Manager" - gid: 1002 - group: slurm - home: "/var/lib/slurm" - name: slurm - shell: "/bin/bash" - uid: 1002 - - name: Set XDG_RUNTIME_DIR - run: | - mkdir -p /tmp/1002-runtime # work around podman issue (https://github.com/containers/podman/issues/13338) - echo XDG_RUNTIME_DIR=/tmp/1002-runtime >> $GITHUB_ENV - - name: Setup slurm - run: | - ansible-playbook slurm-playbook.yml || (journalctl -xe && exit 1) - - name: Add Slurm Account - run: | - echo "Waiting 5 seconds for slurm cluster to be fully initialized." 
- sleep 5 - sudo sacctmgr -i create account "Name=runner" - sudo sacctmgr -i create user "Name=runner" "Account=runner" - - name: Configure proxy for sacct - run: | - # By using this script instead of the real sacct, we avoid the need to install - # a full slurmdbd in the CI. - echo 'alias sacct=.github/workflows/scripts/sacct-proxy.py' >> ~/.bashrc - - name: Test slurm submission - run: | - srun -vvvv echo "hello world" - sudo cat /var/log/slurm/slurmd.log - - name: Indicate supported MPI types - run: | - srun --mpi=list - name: Setup mamba uses: conda-incubator/setup-miniconda@v2 with: @@ -173,13 +75,14 @@ jobs: # TODO remove and add as regular dependency once released pip install git+https://github.com/snakemake/snakemake-interface-executor-plugins.git + pip install git+https://github.com/snakemake/snakemake-executor-plugin-cluster-generic.git + pip install -e . - # additionally add singularity - + # additionally add singularity (not necessary anymore, included in the test env now) # TODO remove version constraint: needed because 3.8.7 fails with missing libz: # bin/unsquashfs: error while loading shared libraries: libz.so.1: # cannot open shared object file: No such file or directory - mamba install -n snakemake "singularity<=3.8.6" + # mamba install -n snakemake "singularity<=3.8.6" - name: Setup apt dependencies run: | sudo gem install apt-spy2 @@ -193,28 +96,6 @@ jobs: sleep 10 docker exec -u irods provider iput /incoming/infile cp -r tests/test_remote_irods/setup-data ~/.irods - #- name: Setup Gcloud - # uses: GoogleCloudPlatform/github-actions/setup-gcloud@v0.2.1 - # if: env.GCP_AVAILABLE - # with: - # project_id: "${{ secrets.GCP_PROJECT_ID }}" - # service_account_email: "${{ secrets.GCP_SA_EMAIL }}" - # service_account_key: "${{ secrets.GCP_SA_KEY }}" - # export_default_credentials: true - #- name: Setup AWS - # uses: aws-actions/configure-aws-credentials@v1 - # if: env.AWS_AVAILABLE - # with: - # aws-access-key-id: "${{ secrets.AWS_ACCESS_KEY_ID }}" - # aws-secret-access-key: "${{ secrets.AWS_SECRET_ACCESS_KEY }}" - # aws-region: us-east-1 - - - name: Test Slurm - env: - CI: true - shell: bash -el {0} - run: | - pytest --show-capture=stderr -v tests/test_slurm.py - name: Test local env: @@ -223,6 +104,7 @@ jobs: shell: bash -el {0} run: | pytest --show-capture=stderr -v -x tests/test_expand.py tests/test_io.py tests/test_schema.py tests/test_linting.py tests/tests.py tests/test_schema.py tests/test_linting.py tests/tests.py + - name: Build and publish docker image if: >- contains(github.event.pull_request.labels.*.name, @@ -240,37 +122,6 @@ jobs: 'update-container-image') run: | echo CONTAINER_IMAGE=snakemake/snakemake:$GITHUB_SHA >> $GITHUB_ENV - #- name: Test Google Life Sciences Executor - # if: env.GCP_AVAILABLE - # shell: bash -el {0} - # run: | - # pytest -s -v -x tests/test_google_lifesciences.py - #- name: Test Kubernetes execution - # if: env.GCP_AVAILABLE - # env: - # CI: true - # shell: bash -el {0} - # run: | - # pytest -s -v -x tests/test_kubernetes.py - - # TODO temporarily disable testing of the azure batch executor as our azure - # account is currently disabled for unknown reason. - # Reactivate once that is fixed. 
- # - name: Test Azure Batch Executor - # shell: bash -el {0} - # env: - # AZ_BLOB_PREFIX: "${{ secrets.AZ_BLOB_PREFIX }}" - # AZ_BLOB_ACCOUNT_URL: "${{ secrets.AZ_STORAGE_ACCOUNT_URL }}" - # AZ_BLOB_CREDENTIAL: "${{ secrets.AZ_STORAGE_KEY }}" - # AZ_BATCH_ACCOUNT_URL: "${{ secrets.AZ_BATCH_ACCOUNT_URL }}" - # AZ_BATCH_ACCOUNT_KEY: "${{ secrets.AZ_BATCH_KEY }}" - # run: | - # pytest -s -v -x tests/test_azure_batch_executor.py - - - name: Test GA4GH TES executor - shell: bash -el {0} - run: | - pytest --show-capture=stderr -s -v -x tests/test_tes.py - name: Delete container image if: >- @@ -298,7 +149,7 @@ jobs: shell: python run: | import fileinput - excluded_on_win = ["environment-modules", "cwltool", "cwl-utils"] + excluded_on_win = ["environment-modules", "cwltool", "cwl-utils", "apptainer", "squashfuse"] for line in fileinput.input("test-environment.yml", inplace=True): if all(pkg not in line for pkg in excluded_on_win): print(line) @@ -316,6 +167,7 @@ jobs: # TODO remove and add as regular dependency once released pip install git+https://github.com/snakemake/snakemake-interface-executor-plugins.git + pip install -e . - name: Run tests env: CI: true diff --git a/.github/workflows/test-flux.yaml b/.github/workflows/test-flux.yaml deleted file mode 100644 index 2bc5b77aa..000000000 --- a/.github/workflows/test-flux.yaml +++ /dev/null @@ -1,65 +0,0 @@ -name: Test Flux Executor -# TODO temporarily disable flux executor tests as the test container is based on Python 3.8 but Snakemake now requires Python 3.9 -# on: -# pull_request: [] - -jobs: - build: - runs-on: ubuntu-latest - permissions: - packages: read - strategy: - fail-fast: false - matrix: - container: ["fluxrm/flux-sched:focal"] - - container: - image: ${{ matrix.container }} - options: "--platform=linux/amd64 --user root -it" - - name: ${{ matrix.container }} - steps: - - name: Make Space - run: | - rm -rf /usr/share/dotnet - rm -rf /opt/ghc - - - name: Checkout - uses: actions/checkout@v3 - - - name: Setup miniconda - uses: conda-incubator/setup-miniconda@v2 - with: - activate-environment: snakemake - channels: "conda-forge, bioconda" - miniforge-variant: Mambaforge - miniforge-version: latest - - - name: Install Snakemake - shell: bash -el {0} - run: | - conda config --set channel_priority strict - mamba install python>=3.9 pip - # TODO remove and add as regular dependency once released - pip install git+https://github.com/snakemake/snakemake-interface-executor-plugins.git - pip install . - - - name: Start Flux and Test Workflow - shell: bash -el {0} - run: | - # We must have python3->python accessible for this to work - ln -s /bin/python3 /usr/local/bin/python - su fluxuser - cd examples/flux - # This run does not use conda - which snakemake - flux start snakemake --show-failed-logs --verbose --flux --jobs=1 --no-shared-fs - - - name: Test Flux with Conda - shell: bash -el {0} - run: | - which python - which conda - cp -R ./tests/test_conda /tmp/test_conda - cd /tmp/test_conda - flux start snakemake --show-failed-logs --verbose --flux --jobs=1 --use-conda --conda-frontend=conda diff --git a/motoState.p b/motoState.p new file mode 100644 index 000000000..e2ecf720d --- /dev/null +++ b/motoState.p @@ -0,0 +1 @@ +€}”. 
\ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 5fa20aa34..6ebb2ce0d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,6 +34,7 @@ packages = find: python_requires = >=3.7 install_requires = appdirs + immutables configargparse connection_pool >=0.0.3 datrie @@ -86,4 +87,4 @@ console_scripts = include = snakemake, snakemake.* [options.package_data] -* = *.css, *.sh, *.html, *.jinja2, *.js, *.svg +* = *.css, *.sh, *.html, *.jinja2, *.js, *.svg, Snakefile diff --git a/snakemake/__init__.py b/snakemake/__init__.py index 7db1142ae..0b47770e2 100644 --- a/snakemake/__init__.py +++ b/snakemake/__init__.py @@ -7,3 +7,10 @@ # Reexports that are part of the public API: from snakemake.shell import shell + + +if __name__ == "__main__": + from snakemake.cli import main + import sys + + main(sys.argv) diff --git a/snakemake/api.py b/snakemake/api.py index 6b9361efc..3e63c70c0 100644 --- a/snakemake/api.py +++ b/snakemake/api.py @@ -3,857 +3,595 @@ __email__ = "johannes.koester@uni-due.de" __license__ = "MIT" +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from pathlib import Path import sys +from typing import Dict, List, Optional, Set +import os +from functools import partial +import importlib -from snakemake.common import MIN_PY_VERSION +from snakemake.common import MIN_PY_VERSION, SNAKEFILE_CHOICES +from snakemake.settings import ( + ChangeType, + GroupSettings, + SchedulingSettings, + WorkflowSettings, +) if sys.version_info < MIN_PY_VERSION: raise ValueError(f"Snakemake requires at least Python {'.'.join(MIN_PY_VERSION)}.") -import os -from functools import partial -import importlib +from snakemake.common.workdir_handler import WorkdirHandler +from snakemake.settings import ( + DAGSettings, + DeploymentMethod, + DeploymentSettings, + ExecutionSettings, + OutputSettings, + ConfigSettings, + RemoteExecutionSettings, + ResourceSettings, + StorageSettings, +) -from snakemake_interface_executor_plugins.utils import ExecMode +from snakemake_interface_executor_plugins.settings import ExecMode +from snakemake_interface_executor_plugins import ExecutorSettingsBase +from snakemake_interface_executor_plugins.registry import ExecutorPluginRegistry +from snakemake_interface_common.exceptions import ApiError from snakemake.workflow import Workflow -from snakemake.exceptions import ( - print_exception, - WorkflowError, -) +from snakemake.exceptions import print_exception from snakemake.logging import setup_logger, logger -from snakemake.io import load_configfile from snakemake.shell import shell -from snakemake.utils import update_config from snakemake.common import ( MIN_PY_VERSION, - RERUN_TRIGGERS, __version__, - dict_to_key_value_args, ) from snakemake.resources import DefaultResources -def snakemake( - snakefile, - batch=None, - cache=None, - report=None, - report_stylesheet=None, - containerize=False, - lint=None, - generate_unit_tests=None, - listrules=False, - list_target_rules=False, - cores=1, - nodes=None, - local_cores=1, - max_threads=None, - resources=dict(), - overwrite_threads=None, - overwrite_scatter=None, - overwrite_resource_scopes=None, - default_resources=None, - overwrite_resources=None, - config=dict(), - configfiles=None, - config_args=None, - workdir=None, - targets=None, - target_jobs=None, - dryrun=False, - touch=False, - forcetargets=False, - forceall=False, - forcerun=[], - until=[], - omit_from=[], - prioritytargets=[], - stats=None, - printshellcmds=False, - debug_dag=False, - printdag=False, - printrulegraph=False, - 
printfilegraph=False, - printd3dag=False, - nocolor=False, - quiet=False, - keepgoing=False, - slurm=None, - slurm_jobstep=None, - rerun_triggers=RERUN_TRIGGERS, - cluster=None, - cluster_sync=None, - drmaa=None, - drmaa_log_dir=None, - jobname="snakejob.{rulename}.{jobid}.sh", - immediate_submit=False, - standalone=False, - ignore_ambiguity=False, - snakemakepath=None, - lock=True, - unlock=False, - cleanup_metadata=None, - conda_cleanup_envs=False, - cleanup_shadow=False, - cleanup_scripts=True, - cleanup_containers=False, - force_incomplete=False, - ignore_incomplete=False, - list_version_changes=False, - list_code_changes=False, - list_input_changes=False, - list_params_changes=False, - list_untracked=False, - list_resources=False, - summary=False, - archive=None, - delete_all_output=False, - delete_temp_output=False, - detailed_summary=False, - latency_wait=3, - wait_for_files=None, - print_compilation=False, - debug=False, - notemp=False, - all_temp=False, - keep_remote_local=False, - nodeps=False, - keep_target_files=False, - allowed_rules=None, - jobscript=None, - greediness=None, - no_hooks=False, - overwrite_shellcmd=None, - updated_files=None, - log_handler=[], - keep_logger=False, - max_jobs_per_second=None, - max_status_checks_per_second=100, - restart_times=0, - attempt=1, - verbose=False, - force_use_threads=False, - use_conda=False, - use_singularity=False, - use_env_modules=False, - singularity_args="", - conda_frontend="conda", - conda_prefix=None, - conda_cleanup_pkgs=None, - list_conda_envs=False, - singularity_prefix=None, - shadow_prefix=None, - scheduler="ilp", - scheduler_ilp_solver=None, - conda_create_envs_only=False, - mode=ExecMode.default, - wrapper_prefix=None, - kubernetes=None, - container_image=None, - k8s_cpu_scalar=1.0, - k8s_service_account_name=None, - flux=False, - tibanna=False, - tibanna_sfn=None, - az_batch=False, - az_batch_enable_autoscale=False, - az_batch_account_url=None, - google_lifesciences=False, - google_lifesciences_regions=None, - google_lifesciences_location=None, - google_lifesciences_cache=False, - google_lifesciences_service_account_email=None, - google_lifesciences_network=None, - google_lifesciences_subnetwork=None, - tes=None, - preemption_default=None, - preemptible_rules=None, - precommand="", - default_remote_provider=None, - default_remote_prefix="", - tibanna_config=False, - assume_shared_fs=True, - cluster_status=None, - cluster_cancel=None, - cluster_cancel_nargs=None, - cluster_sidecar=None, - export_cwl=None, - show_failed_logs=False, - keep_incomplete=False, - keep_metadata=True, - messaging=None, - edit_notebook=None, - envvars=None, - overwrite_groups=None, - group_components=None, - max_inventory_wait_time=20, - execute_subworkflows=True, - conda_not_block_search_path_envvars=False, - scheduler_solver_path=None, - conda_base_path=None, - local_groupid="local", - executor_args=None, -): - """Run snakemake on a given snakefile. - - This function provides access to the whole snakemake functionality. It is not thread-safe. 
- - Args: - snakefile (str): the path to the snakefile - batch (Batch): whether to compute only a partial DAG, defined by the given Batch object (default None) - report (str): create an HTML report for a previous run at the given path - lint (str): print lints instead of executing (None, "plain" or "json", default None) - listrules (bool): list rules (default False) - list_target_rules (bool): list target rules (default False) - cores (int): the number of provided cores (ignored when using cluster support) (default 1) - nodes (int): the number of provided cluster nodes (ignored without cluster support) (default 1) - local_cores (int): the number of provided local cores if in cluster mode (ignored without cluster support) (default 1) - resources (dict): provided resources, a dictionary assigning integers to resource names, e.g. {gpu=1, io=5} (default {}) - default_resources (DefaultResources): default values for resources not defined in rules (default None) - config (dict): override values for workflow config - workdir (str): path to the working directory (default None) - targets (list): list of targets, e.g. rule or file names (default None) - target_jobs (dict): list of snakemake.target_jobs.TargetSpec objects directly targeting specific jobs (default None) - dryrun (bool): only dry-run the workflow (default False) - touch (bool): only touch all output files if present (default False) - forcetargets (bool): force given targets to be re-created (default False) - forceall (bool): force all output files to be re-created (default False) - forcerun (list): list of files and rules that shall be re-created/re-executed (default []) - execute_subworkflows (bool): execute subworkflows if present (default True) - prioritytargets (list): list of targets that shall be run with maximum priority (default []) - stats (str): path to file that shall contain stats about the workflow execution (default None) - printshellcmds (bool): print the shell command of each job (default False) - printdag (bool): print the dag in the graphviz dot language (default False) - printrulegraph (bool): print the graph of rules in the graphviz dot language (default False) - printfilegraph (bool): print the graph of rules with their input and output files in the graphviz dot language (default False) - printd3dag (bool): print a D3.js compatible JSON representation of the DAG (default False) - nocolor (bool): do not print colored output (default False) - quiet (bool): do not print any default job information (default False) - keepgoing (bool): keep going upon errors (default False) - cluster (str): submission command of a cluster or batch system to use, e.g. qsub (default None) - cluster_sync (str): blocking cluster submission command (like SGE 'qsub -sync y') (default None) - drmaa (str): if not None use DRMAA for cluster support, str specifies native args passed to the cluster when submitting a job - drmaa_log_dir (str): the path to stdout and stderr output of DRMAA jobs (default None) - jobname (str): naming scheme for cluster job scripts (default "snakejob.{rulename}.{jobid}.sh") - immediate_submit (bool): immediately submit all cluster jobs, regardless of dependencies (default False) - standalone (bool): kill all processes very rudely in case of failure (do not use this if you use this API) (default False) (deprecated) - ignore_ambiguity (bool): ignore ambiguous rules and always take the first possible one (default False) - snakemakepath (str): deprecated parameter whose value is ignored. Do not use. 
- lock (bool): lock the working directory when executing the workflow (default True) - unlock (bool): just unlock the working directory (default False) - cleanup_metadata (list): just cleanup metadata of given list of output files (default None) - drop_metadata (bool): drop metadata file tracking information after job finishes (--report and --list_x_changes information will be incomplete) (default False) - conda_cleanup_envs (bool): just cleanup unused conda environments (default False) - cleanup_shadow (bool): just cleanup old shadow directories (default False) - cleanup_scripts (bool): delete wrapper scripts used for execution (default True) - cleanup_containers (bool): delete unused (singularity) containers (default False) - force_incomplete (bool): force the re-creation of incomplete files (default False) - ignore_incomplete (bool): ignore incomplete files (default False) - list_version_changes (bool): list output files with changed rule version (default False) - list_code_changes (bool): list output files with changed rule code (default False) - list_input_changes (bool): list output files with changed input files (default False) - list_params_changes (bool): list output files with changed params (default False) - list_untracked (bool): list files in the workdir that are not used in the workflow (default False) - summary (bool): list summary of all output files and their status (default False) - archive (str): archive workflow into the given tarball - delete_all_output (bool): remove all files generated by the workflow (default False) - delete_temp_output (bool): remove all temporary files generated by the workflow (default False) - latency_wait (int): how many seconds to wait for an output file to appear after the execution of a job, e.g. to handle filesystem latency (default 3) - wait_for_files (list): wait for given files to be present before executing the workflow - list_resources (bool): list resources used in the workflow (default False) - summary (bool): list summary of all output files and their status (default False). If no option is specified a basic summary will be output. If 'detailed' is added as an option e.g --summary detailed, extra info about the input and shell commands will be included - detailed_summary (bool): list summary of all input and output files and their status (default False) - print_compilation (bool): print the compilation of the snakefile (default False) - debug (bool): allow to use the debugger within rules - notemp (bool): ignore temp file flags, e.g. do not delete output files marked as a temp after use (default False) - keep_remote_local (bool): keep local copies of remote files (default False) - nodeps (bool): ignore dependencies (default False) - keep_target_files (bool): do not adjust the paths of given target files relative to the working directory. - allowed_rules (set): restrict allowed rules to the given set. If None or empty, all rules are used. - jobscript (str): path to a custom shell script template for cluster jobs (default None) - greediness (float): set the greediness of scheduling. This value between 0 and 1 determines how careful jobs are selected for execution. The default value (0.5 if prioritytargets are used, 1.0 else) provides the best speed and still acceptable scheduling quality. - overwrite_shellcmd (str): a shell command that shall be executed instead of those given in the workflow. This is for debugging purposes only. 
- updated_files(list): a list that will be filled with the files that are updated or created during the workflow execution - verbose (bool): show additional debug output (default False) - max_jobs_per_second (int): maximal number of cluster/drmaa jobs per second, None to impose no limit (default None) - restart_times (int): number of times to restart failing jobs (default 0) - attempt (int): initial value of Job.attempt. This is intended for internal use only (default 1). - force_use_threads: whether to force the use of threads over processes. helpful if shared memory is full or unavailable (default False) - use_conda (bool): use conda environments for each job (defined with conda directive of rules) - use_singularity (bool): run jobs in singularity containers (if defined with singularity directive) - use_env_modules (bool): load environment modules if defined in rules - singularity_args (str): additional arguments to pass to a singularity - conda_prefix (str): the directory in which conda environments will be created (default None) - conda_cleanup_pkgs (snakemake.deployment.conda.CondaCleanupMode): - whether to clean up conda tarballs after env creation (default None), valid values: "tarballs", "cache" - singularity_prefix (str): the directory to which singularity images will be pulled (default None) - shadow_prefix (str): prefix for shadow directories. The job-specific shadow directories will be created in $SHADOW_PREFIX/shadow/ (default None) - conda_create_envs_only (bool): if specified, only builds the conda environments specified for each job, then exits. - list_conda_envs (bool): list conda environments and their location on disk. - mode (snakemake.common.Mode): execution mode - wrapper_prefix (str): prefix for wrapper script URLs (default None) - kubernetes (str): submit jobs to Kubernetes, using the given namespace. - container_image (str): Docker image to use, e.g., for Kubernetes. - k8s_cpu_scalar (float): What proportion of each k8s node's CPUs are availabe to snakemake? - k8s_service_account_name (str): Custom k8s service account, needed for workload identity. - flux (bool): Launch workflow to flux cluster. - default_remote_provider (str): default remote provider to use instead of local files (e.g. S3, GS) - default_remote_prefix (str): prefix for default remote provider (e.g. name of the bucket). - tibanna (bool): submit jobs to AWS cloud using Tibanna. - tibanna_sfn (str): Step function (Unicorn) name of Tibanna (e.g. tibanna_unicorn_monty). This must be deployed first using tibanna cli. - az_batch (bool): Submit jobs to azure batch. - az_batch_enable_autoscale (bool): Enable autoscaling of the azure batch pool nodes. This sets the initial dedicated node pool count to zero and resizes the pool only after 5 minutes. So this flag is only recommended for relatively long running jobs., - az_batch_account_url (str): Azure batch account url. - google_lifesciences (bool): submit jobs to Google Cloud Life Sciences (pipelines API). - google_lifesciences_regions (list): a list of regions (e.g., us-east1) - google_lifesciences_location (str): Life Sciences API location (e.g., us-central1) - google_lifesciences_cache (bool): save a cache of the compressed working directories in Google Cloud Storage for later usage. - google_lifesciences_service_account_email (str): Service account to install on Google pipelines API VM instance. - google_lifesciences_network (str): Network name for Google VM instances. - google_lifesciences_subnetwork (str): Subnetwork name for Google VM instances. 
- tes (str): Execute workflow tasks on GA4GH TES server given by URL. - precommand (str): commands to run on AWS cloud before the snakemake command (e.g. wget, git clone, unzip, etc). Use with --tibanna. - preemption_default (int): set a default number of preemptible instance retries (for Google Life Sciences executor only) - preemptible_rules (list): define custom preemptible instance retries for specific rules (for Google Life Sciences executor only) - tibanna_config (list): Additional tibanna config e.g. --tibanna-config spot_instance=true subnet= security group= - assume_shared_fs (bool): assume that cluster nodes share a common filesystem (default true). - cluster_status (str): status command for cluster execution. If None, Snakemake will rely on flag files. Otherwise, it expects the command to return "success", "failure" or "running" when executing with a cluster jobid as a single argument. - cluster_cancel (str): command to cancel multiple job IDs (like SLURM 'scancel') (default None) - cluster_cancel_nargs (int): maximal number of job ids to pass to cluster_cancel (default 1000) - cluster_sidecar (str): command that starts a sidecar process, see cluster documentation (default None) - export_cwl (str): Compile workflow to CWL and save to given file - log_handler (function): redirect snakemake output to this custom log handler, a function that takes a log message dictionary (see below) as its only argument (default None). The log message dictionary for the log handler has to following entries: - keep_incomplete (bool): keep incomplete output files of failed jobs - edit_notebook (object): "notebook.EditMode" object to configure notebook server for interactive editing of a rule notebook. If None, do not edit. - scheduler (str): Select scheduling algorithm (default ilp) - scheduler_ilp_solver (str): Set solver for ilp scheduler. - overwrite_groups (dict): Rule to group assignments (default None) - group_components (dict): Number of connected components given groups shall span before being split up (1 by default if empty) - conda_not_block_search_path_envvars (bool): Do not block search path envvars (R_LIBS, PYTHONPATH, ...) when using conda environments. - scheduler_solver_path (str): Path to Snakemake environment (this can be used to e.g. overwrite the search path for the ILP solver used during scheduling). - conda_base_path (str): Path to conda base environment (this can be used to overwrite the search path for conda, mamba, and activate). - local_groupid (str): Local groupid to use as a placeholder for groupid-referrring input functions of local jobs (internal use only, default: local). - log_handler (list): redirect snakemake output to this list of custom log handlers, each a function that takes a log message dictionary (see below) as its only argument (default []). The log message dictionary for the log handler has to following entries: - executor_args (dataclasses.Dataclass): custom Data class to pass to custom executors for more flexibility - :level: - the log level ("info", "error", "debug", "progress", "job_info") - - :level="info", "error" or "debug": - :msg: - the log message - :level="progress": - :done: - number of already executed jobs - - :total: - number of total jobs - - :level="job_info": - :input: - list of input files of a job - - :output: - list of output files of a job - - :log: - path to log file of a job - - :local: - whether a job is executed locally (i.e. 
ignoring cluster) - - :msg: - the job message - - :reason: - the job reason - - :priority: - the job priority - - :threads: - the threads of the job - - - Returns: - bool: True if workflow execution was successful. +class ApiBase(ABC): + def __post_init__(self): + self._check() + + def _check(self): + # nothing to check by default + # override in subclasses if needed + pass + + +def resolve_snakefile(path: Optional[Path]): + """Get path to the snakefile. + Arguments + --------- + path: Optional[Path] -- The path to the snakefile. If not provided, default locations will be tried. """ - assert not immediate_submit or ( - immediate_submit and notemp - ), "immediate_submit has to be combined with notemp (it does not support temp file handling)" - - if tibanna: - assume_shared_fs = False - default_remote_provider = "S3" - default_remote_prefix = default_remote_prefix.rstrip("/") - assert ( - default_remote_prefix - ), "default_remote_prefix needed if tibanna is specified" - assert tibanna_sfn, "tibanna_sfn needed if tibanna is specified" - if tibanna_config: - tibanna_config_dict = dict() - for cf in tibanna_config: - k, v = cf.split("=") - if v == "true": - v = True - elif v == "false": - v = False - elif v.isnumeric(): - v = int(v) - else: - try: - v = float(v) - except ValueError: - pass - tibanna_config_dict.update({k: v}) - tibanna_config = tibanna_config_dict - - # Azure batch uses compute engine and storage - if az_batch: - assume_shared_fs = False - default_remote_provider = "AzBlob" - - # Google Cloud Life Sciences API uses compute engine and storage - if google_lifesciences: - assume_shared_fs = False - default_remote_provider = "GS" - default_remote_prefix = default_remote_prefix.rstrip("/") - if kubernetes: - assume_shared_fs = False - - # Currently preemptible instances only supported for Google LifeSciences Executor - if preemption_default or preemptible_rules and not google_lifesciences: - logger.warning( - "Preemptible instances are only available for the Google Life Sciences Executor." + if path is None: + for p in SNAKEFILE_CHOICES: + if p.exists(): + return p + raise ApiError( + f"No Snakefile found, tried {', '.join(map(str, SNAKEFILE_CHOICES))}." ) + return path + + +@dataclass +class SnakemakeApi(ApiBase): + """The Snakemake API. + + Arguments + --------- - if updated_files is None: - updated_files = list() - - run_local = not ( - cluster - or cluster_sync - or drmaa - or kubernetes - or tibanna - or az_batch - or google_lifesciences - or tes - or slurm - or slurm_jobstep - ) - if run_local: - if not dryrun: - # clean up all previously recorded jobids. - shell.cleanup() - else: - if default_resources is None: - # use full default resources if in cluster or cloud mode - default_resources = DefaultResources(mode="full") - if edit_notebook: - raise WorkflowError( - "Notebook edit mode is only allowed with local execution." + output_settings: OutputSettings -- The output settings for the Snakemake API. + """ + + output_settings: OutputSettings = field(default_factory=OutputSettings) + _workflow_api: Optional["WorkflowApi"] = field(init=False, default=None) + _is_in_context: bool = field(init=False, default=False) + + def workflow( + self, + resource_settings: ResourceSettings, + config_settings: Optional[ConfigSettings] = None, + storage_settings: Optional[StorageSettings] = None, + workflow_settings: Optional[WorkflowSettings] = None, + snakefile: Optional[Path] = None, + workdir: Optional[Path] = None, + ): + """Create the workflow API. 
+ + Note that if provided, this also changes to the provided workdir. + It will change back to the previous working directory when the workflow API object is deleted. + + Arguments + --------- + config_settings: ConfigSettings -- The config settings for the workflow. + resource_settings: ResourceSettings -- The resource settings for the workflow. + storage_settings: StorageSettings -- The storage settings for the workflow. + snakefile: Optional[Path] -- The path to the snakefile. If not provided, default locations will be tried. + workdir: Optional[Path] -- The path to the working directory. If not provided, the current working directory will be used. + """ + + if config_settings is None: + config_settings = ConfigSettings() + if storage_settings is None: + storage_settings = StorageSettings() + if workflow_settings is None: + workflow_settings = WorkflowSettings() + + self._check_is_in_context() + + self._setup_logger() + + snakefile = resolve_snakefile(snakefile) + + self._workflow_api = WorkflowApi( + snakemake_api=self, + snakefile=snakefile, + workdir=workdir, + config_settings=config_settings, + resource_settings=resource_settings, + storage_settings=storage_settings, + workflow_settings=workflow_settings, + ) + return self._workflow_api + + def _cleanup(self): + """Cleanup the workflow.""" + if not self.output_settings.keep_logger: + logger.cleanup() + if self._workflow_api is not None: + self._workflow_api._workdir_handler.change_back() + if ( + self._workflow_api._workflow_store is not None + and self._workflow_api._workflow._workdir_handler is not None + ): + self._workflow_api._workflow._workdir_handler.change_back() + + def print_exception(self, ex: Exception): + """Print an exception during workflow execution in a human readable way + (with adjusted line numbers for exceptions raised in Snakefiles and stack + traces that hide Snakemake internals for better readability). + + Arguments + --------- + ex: Exception -- The exception to print. + """ + linemaps = ( + self._workflow_api._workflow.linemaps + if self._workflow_api is not None + else dict() + ) + print_exception(ex, linemaps) + + def _setup_logger( + self, + stdout: bool = False, + mode: ExecMode = ExecMode.DEFAULT, + dryrun: bool = False, + ): + if not self.output_settings.keep_logger: + setup_logger( + handler=self.output_settings.log_handlers, + quiet=self.output_settings.quiet, + nocolor=self.output_settings.nocolor, + debug=self.output_settings.verbose, + printshellcmds=self.output_settings.printshellcmds, + debug_dag=self.output_settings.debug_dag, + stdout=stdout, + mode=mode, + show_failed_logs=self.output_settings.show_failed_logs, + dryrun=dryrun, ) - shell.conda_block_conflicting_envvars = not conda_not_block_search_path_envvars - - # force thread use for any kind of cluster - use_threads = ( - force_use_threads - or (os.name not in ["posix", "nt"]) - or cluster - or cluster_sync - or drmaa - ) - - if not keep_logger: - stdout = ( - ( - dryrun - and not (printdag or printd3dag or printrulegraph or printfilegraph) + def _check_is_in_context(self): + if not self._is_in_context: + raise ApiError( + "This method can only be called when SnakemakeApi is used within a with " + "statement." ) - or listrules - or list_target_rules - or list_resources + + def __enter__(self): + self._is_in_context = True + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._is_in_context = False + self._cleanup() + + +@dataclass +class WorkflowApi(ApiBase): + """The workflow API. 
+ + Arguments + --------- + snakemake_api: SnakemakeApi -- The Snakemake API. + snakefile: Path -- The path to the snakefile. + config_settings: ConfigSettings -- The config settings for the workflow. + resource_settings: ResourceSettings -- The resource settings for the workflow. + """ + + snakemake_api: SnakemakeApi + snakefile: Path + workdir: Optional[Path] + config_settings: ConfigSettings + resource_settings: ResourceSettings + storage_settings: StorageSettings + workflow_settings: WorkflowSettings + _workflow_store: Optional[Workflow] = field(init=False, default=None) + _workdir_handler: Optional[WorkdirHandler] = field(init=False) + + def dag( + self, + dag_settings: Optional[DAGSettings] = None, + deployment_settings: Optional[DeploymentSettings] = None, + ): + """Create a DAG API. + + Arguments + --------- + dag_settings: DAGSettings -- The DAG settings for the DAG API. + """ + if dag_settings is None: + dag_settings = DAGSettings() + if deployment_settings is None: + deployment_settings = DeploymentSettings() + + return DAGApi( + self.snakemake_api, + self, + dag_settings=dag_settings, + deployment_settings=deployment_settings, ) - setup_logger( - handler=log_handler, - quiet=quiet, - printshellcmds=printshellcmds, - debug_dag=debug_dag, - nocolor=nocolor, - stdout=stdout, - debug=verbose, - use_threads=use_threads, - mode=mode, - show_failed_logs=show_failed_logs, - dryrun=dryrun, + def lint(self, json: bool = False): + """Lint the workflow. + + Arguments + --------- + json: bool -- Whether to print the linting results as JSON. + + Returns + ------- + True if any lints were printed + """ + workflow = self._get_workflow(check_envvars=False) + workflow.include( + self.snakefile, overwrite_default_target=True, print_compilation=False + ) + workflow.check() + return workflow.lint(json=json) + + def list_rules(self, only_targets: bool = False): + """List the rules of the workflow. + + Arguments + --------- + only_targets: bool -- Whether to only list target rules. 
+ """ + self._workflow.list_rules(only_targets=only_targets) + + def list_resources(self): + """List the resources of the workflow.""" + self._workflow.list_resources() + + def print_compilation(self): + """Print the pure python compilation of the workflow.""" + workflow = self._get_workflow() + workflow.include(self.snakefile, print_compilation=True) + + @property + def _workflow(self): + if self._workflow_store is None: + workflow = self._get_workflow() + workflow.include( + self.snakefile, overwrite_default_target=True, print_compilation=False + ) + workflow.check() + self._workflow_store = workflow + return self._workflow_store + + def _get_workflow(self, **kwargs): + from snakemake.workflow import Workflow + + return Workflow( + config_settings=self.config_settings, + resource_settings=self.resource_settings, + workflow_settings=self.workflow_settings, + storage_settings=self.storage_settings, + output_settings=self.snakemake_api.output_settings, + overwrite_workdir=self.workdir, + **kwargs, ) - if greediness is None: - greediness = 0.5 if prioritytargets else 1.0 - else: - if not (0 <= greediness <= 1.0): - logger.error("Error: greediness must be a float between 0 and 1.") - return False - - if not os.path.exists(snakefile): - logger.error(f'Error: Snakefile "{snakefile}" not found.') - return False - snakefile = os.path.abspath(snakefile) - - cluster_mode = ( - (cluster is not None) + (cluster_sync is not None) + (drmaa is not None) - ) - if cluster_mode > 1: - logger.error("Error: cluster and drmaa args are mutually exclusive") - return False - - if debug and (cluster_mode or cores is not None and cores > 1): - logger.error( - "Error: debug mode cannot be used with more than one core or cluster execution." + def __post_init__(self): + super().__post_init__() + self.snakefile = self.snakefile.absolute() + self._workdir_handler = WorkdirHandler(self.workdir) + self._workdir_handler.change_to() + + def _check(self): + if not self.snakefile.exists(): + raise ApiError(f'Snakefile "{self.snakefile}" not found.') + + +@dataclass +class DAGApi(ApiBase): + """The DAG API. + + Arguments + --------- + snakemake_api: SnakemakeApi -- The Snakemake API. + workflow_api: WorkflowApi -- The workflow API. + dag_settings: DAGSettings -- The DAG settings for the DAG API. + """ + + snakemake_api: SnakemakeApi + workflow_api: WorkflowApi + dag_settings: DAGSettings + deployment_settings: DeploymentSettings + + def __post_init__(self): + self.workflow_api._workflow.dag_settings = self.dag_settings + self.workflow_api._workflow.deployment_settings = self.deployment_settings + + def execute_workflow( + self, + executor: str = "local", + execution_settings: Optional[ExecutionSettings] = None, + remote_execution_settings: Optional[RemoteExecutionSettings] = None, + scheduling_settings: Optional[SchedulingSettings] = None, + group_settings: Optional[GroupSettings] = None, + executor_settings: Optional[ExecutorSettingsBase] = None, + updated_files: Optional[List[str]] = None, + ): + """Execute the workflow. + + Arguments + --------- + executor: str -- The executor to use. + execution_settings: ExecutionSettings -- The execution settings for the workflow. + resource_settings: ResourceSettings -- The resource settings for the workflow. + deployment_settings: DeploymentSettings -- The deployment settings for the workflow. + remote_execution_settings: RemoteExecutionSettings -- The remote execution settings for the workflow. 
+ executor_settings: Optional[ExecutorSettingsBase] -- The executor settings for the workflow. + updated_files: Optional[List[str]] -- An optional list where Snakemake will put all updated files. + """ + + if execution_settings is None: + execution_settings = ExecutionSettings() + if remote_execution_settings is None: + remote_execution_settings = RemoteExecutionSettings() + if scheduling_settings is None: + scheduling_settings = SchedulingSettings() + if group_settings is None: + group_settings = GroupSettings() + + if ( + remote_execution_settings.immediate_submit + and not self.workflow_api.storage_settings.notemp + ): + raise ApiError( + "immediate_submit has to be combined with notemp (it does not support temp file handling)" + ) + + executor_plugin_registry = _get_executor_plugin_registry() + executor_plugin = executor_plugin_registry.get_plugin(executor) + + if executor_plugin.common_settings.implies_no_shared_fs: + self.workflow_api.storage_settings.assume_shared_fs = False + + self.snakemake_api._setup_logger( + stdout=executor_plugin.common_settings.dryrun_exec, + mode=execution_settings.mode, + dryrun=executor_plugin.common_settings.dryrun_exec, ) - return False - - overwrite_config = dict() - if configfiles is None: - configfiles = [] - for f in configfiles: - # get values to override. Later configfiles override earlier ones. - update_config(overwrite_config, load_configfile(f)) - # convert provided paths to absolute paths - configfiles = list(map(os.path.abspath, configfiles)) - - # directly specified elements override any configfiles - if config: - update_config(overwrite_config, config) - if config_args is None: - config_args = dict_to_key_value_args(config) - - if workdir: - olddir = os.getcwd() - if not os.path.exists(workdir): - logger.info(f"Creating specified working directory {workdir}.") - os.makedirs(workdir) - workdir = os.path.abspath(workdir) - os.chdir(workdir) - - logger.setup_logfile() - - try: - # handle default remote provider - _default_remote_provider = None - if default_remote_provider is not None: - try: - rmt = importlib.import_module( - "snakemake.remote." + default_remote_provider + + if executor_plugin.common_settings.local_exec: + if ( + not executor_plugin.common_settings.dryrun_exec + and not executor_plugin.common_settings.touch_exec + ): + if self.workflow_api.resource_settings.cores is None: + raise ApiError( + "cores have to be specified for local execution " + "(use --cores N with N being a number >= 1 or 'all')" + ) + # clean up all previously recorded jobids. 
+ shell.cleanup() + else: + # set cores if that is not done yet + if self.workflow_api.resource_settings.cores is None: + self.workflow_api.resource_settings.cores = 1 + if ( + execution_settings.debug + and self.workflow_api.resource_settings.cores > 1 + ): + raise ApiError( + "debug mode cannot be used with multi-core execution, " + "please enforce a single core by setting --cores 1" + ) + else: + if self.workflow_api.resource_settings.nodes is None: + raise ApiError( + "maximum number of parallel jobs/used nodes has to be specified for remote execution " + "(use --jobs N with N being a number >= 1)" ) - except ImportError as e: - raise WorkflowError("Unknown default remote provider.") - if rmt.RemoteProvider.supports_default: - _default_remote_provider = rmt.RemoteProvider( - keep_local=keep_remote_local, is_default=True + # non local execution + if self.workflow_api.resource_settings.default_resources is None: + # use full default resources if in cluster or cloud mode + self.workflow_api.resource_settings.default_resources = ( + DefaultResources(mode="full") ) - else: - raise WorkflowError( - "Remote provider {} does not (yet) support to " - "be used as default provider." + if execution_settings.edit_notebook is not None: + raise ApiError( + "notebook edit mode is only allowed with local execution." ) + if execution_settings.debug: + raise ApiError("debug mode cannot be used with non-local execution") - workflow = Workflow( - snakefile=snakefile, - rerun_triggers=rerun_triggers, - jobscript=jobscript, - overwrite_shellcmd=overwrite_shellcmd, - overwrite_config=overwrite_config, - overwrite_workdir=workdir, - overwrite_configfiles=configfiles, - overwrite_threads=overwrite_threads, - max_threads=max_threads, - overwrite_scatter=overwrite_scatter, - overwrite_groups=overwrite_groups, - overwrite_resources=overwrite_resources, - overwrite_resource_scopes=overwrite_resource_scopes, - group_components=group_components, - config_args=config_args, - debug=debug, - verbose=verbose, - use_conda=use_conda or list_conda_envs or conda_cleanup_envs, - use_singularity=use_singularity, - use_env_modules=use_env_modules, - conda_frontend=conda_frontend, - conda_prefix=conda_prefix, - conda_cleanup_pkgs=conda_cleanup_pkgs, - singularity_prefix=singularity_prefix, - shadow_prefix=shadow_prefix, - singularity_args=singularity_args, - scheduler_type=scheduler, - scheduler_ilp_solver=scheduler_ilp_solver, - mode=mode, - wrapper_prefix=wrapper_prefix, - printshellcmds=printshellcmds, - restart_times=restart_times, - attempt=attempt, - default_remote_provider=_default_remote_provider, - default_remote_prefix=default_remote_prefix, - run_local=run_local, - assume_shared_fs=assume_shared_fs, - default_resources=default_resources, - cache=cache, - cores=cores, - nodes=nodes, - resources=resources, - edit_notebook=edit_notebook, - envvars=envvars, - max_inventory_wait_time=max_inventory_wait_time, - conda_not_block_search_path_envvars=conda_not_block_search_path_envvars, - execute_subworkflows=execute_subworkflows, - scheduler_solver_path=scheduler_solver_path, - conda_base_path=conda_base_path, - check_envvars=not lint, # for linting, we do not need to check whether requested envvars exist - all_temp=all_temp, - local_groupid=local_groupid, - keep_metadata=keep_metadata, - latency_wait=latency_wait, - executor_args=executor_args, - cleanup_scripts=cleanup_scripts, - immediate_submit=immediate_submit, - quiet=quiet, + execution_settings.use_threads = ( + execution_settings.use_threads + or (os.name not in 
["posix"]) + or not executor_plugin.common_settings.local_exec ) - success = True - workflow.include( - snakefile, - overwrite_default_target=True, - print_compilation=print_compilation, + logger.setup_logfile() + + workflow = self.workflow_api._workflow + workflow.execution_settings = execution_settings + workflow.remote_execution_settings = remote_execution_settings + workflow.scheduling_settings = scheduling_settings + workflow.group_settings = group_settings + + workflow.execute( + executor_plugin=executor_plugin, + executor_settings=executor_settings, + updated_files=updated_files, ) - workflow.check() - if not print_compilation: - if lint: - success = not workflow.lint(json=lint == "json") - elif listrules: - workflow.list_rules() - elif list_target_rules: - workflow.list_rules(only_targets=True) - elif list_resources: - workflow.list_resources() - else: - # if not printdag and not printrulegraph: - # handle subworkflows - subsnakemake = partial( - snakemake, - local_cores=local_cores, - max_threads=max_threads, - cache=cache, - overwrite_threads=overwrite_threads, - overwrite_scatter=overwrite_scatter, - overwrite_resources=overwrite_resources, - overwrite_resource_scopes=overwrite_resource_scopes, - default_resources=default_resources, - dryrun=dryrun, - touch=touch, - printshellcmds=printshellcmds, - debug_dag=debug_dag, - nocolor=nocolor, - quiet=quiet, - keepgoing=keepgoing, - cluster=cluster, - cluster_sync=cluster_sync, - drmaa=drmaa, - drmaa_log_dir=drmaa_log_dir, - jobname=jobname, - immediate_submit=immediate_submit, - standalone=standalone, - ignore_ambiguity=ignore_ambiguity, - restart_times=restart_times, - attempt=attempt, - lock=lock, - unlock=unlock, - cleanup_metadata=cleanup_metadata, - conda_cleanup_envs=conda_cleanup_envs, - cleanup_containers=cleanup_containers, - cleanup_shadow=cleanup_shadow, - cleanup_scripts=cleanup_scripts, - force_incomplete=force_incomplete, - ignore_incomplete=ignore_incomplete, - latency_wait=latency_wait, - verbose=verbose, - notemp=notemp, - all_temp=all_temp, - keep_remote_local=keep_remote_local, - nodeps=nodeps, - jobscript=jobscript, - greediness=greediness, - no_hooks=no_hooks, - overwrite_shellcmd=overwrite_shellcmd, - config=config, - config_args=config_args, - keep_logger=True, - force_use_threads=use_threads, - use_conda=use_conda, - use_singularity=use_singularity, - use_env_modules=use_env_modules, - conda_prefix=conda_prefix, - conda_cleanup_pkgs=conda_cleanup_pkgs, - conda_frontend=conda_frontend, - singularity_prefix=singularity_prefix, - shadow_prefix=shadow_prefix, - singularity_args=singularity_args, - scheduler=scheduler, - scheduler_ilp_solver=scheduler_ilp_solver, - list_conda_envs=list_conda_envs, - kubernetes=kubernetes, - container_image=container_image, - k8s_cpu_scalar=k8s_cpu_scalar, - k8s_service_account_name=k8s_service_account_name, - conda_create_envs_only=conda_create_envs_only, - default_remote_provider=default_remote_provider, - default_remote_prefix=default_remote_prefix, - tibanna=tibanna, - tibanna_sfn=tibanna_sfn, - az_batch=az_batch, - az_batch_enable_autoscale=az_batch_enable_autoscale, - az_batch_account_url=az_batch_account_url, - google_lifesciences=google_lifesciences, - google_lifesciences_regions=google_lifesciences_regions, - google_lifesciences_location=google_lifesciences_location, - google_lifesciences_cache=google_lifesciences_cache, - google_lifesciences_service_account_email=google_lifesciences_service_account_email, - google_lifesciences_network=google_lifesciences_network, - 
google_lifesciences_subnetwork=google_lifesciences_subnetwork, - flux=flux, - tes=tes, - precommand=precommand, - preemption_default=preemption_default, - preemptible_rules=preemptible_rules, - tibanna_config=tibanna_config, - assume_shared_fs=assume_shared_fs, - cluster_status=cluster_status, - cluster_cancel=cluster_cancel, - cluster_cancel_nargs=cluster_cancel_nargs, - cluster_sidecar=cluster_sidecar, - max_jobs_per_second=max_jobs_per_second, - max_status_checks_per_second=max_status_checks_per_second, - overwrite_groups=overwrite_groups, - group_components=group_components, - max_inventory_wait_time=max_inventory_wait_time, - conda_not_block_search_path_envvars=conda_not_block_search_path_envvars, - local_groupid=local_groupid, - ) - success = workflow.execute( - targets=targets, - target_jobs=target_jobs, - dryrun=dryrun, - generate_unit_tests=generate_unit_tests, - touch=touch, - scheduler_type=scheduler, - scheduler_ilp_solver=scheduler_ilp_solver, - local_cores=local_cores, - forcetargets=forcetargets, - forceall=forceall, - forcerun=forcerun, - prioritytargets=prioritytargets, - until=until, - omit_from=omit_from, - keepgoing=keepgoing, - printrulegraph=printrulegraph, - printfilegraph=printfilegraph, - printdag=printdag, - slurm=slurm, - slurm_jobstep=slurm_jobstep, - cluster=cluster, - cluster_sync=cluster_sync, - jobname=jobname, - drmaa=drmaa, - drmaa_log_dir=drmaa_log_dir, - kubernetes=kubernetes, - container_image=container_image, - k8s_cpu_scalar=k8s_cpu_scalar, - k8s_service_account_name=k8s_service_account_name, - tibanna=tibanna, - tibanna_sfn=tibanna_sfn, - az_batch=az_batch, - az_batch_enable_autoscale=az_batch_enable_autoscale, - az_batch_account_url=az_batch_account_url, - google_lifesciences=google_lifesciences, - google_lifesciences_regions=google_lifesciences_regions, - google_lifesciences_location=google_lifesciences_location, - google_lifesciences_cache=google_lifesciences_cache, - google_lifesciences_service_account_email=google_lifesciences_service_account_email, - google_lifesciences_network=google_lifesciences_network, - google_lifesciences_subnetwork=google_lifesciences_subnetwork, - tes=tes, - flux=flux, - precommand=precommand, - preemption_default=preemption_default, - preemptible_rules=preemptible_rules, - tibanna_config=tibanna_config, - max_jobs_per_second=max_jobs_per_second, - max_status_checks_per_second=max_status_checks_per_second, - printd3dag=printd3dag, - ignore_ambiguity=ignore_ambiguity, - stats=stats, - force_incomplete=force_incomplete, - ignore_incomplete=ignore_incomplete, - list_version_changes=list_version_changes, - list_code_changes=list_code_changes, - list_input_changes=list_input_changes, - list_params_changes=list_params_changes, - list_untracked=list_untracked, - list_conda_envs=list_conda_envs, - summary=summary, - archive=archive, - delete_all_output=delete_all_output, - delete_temp_output=delete_temp_output, - wait_for_files=wait_for_files, - detailed_summary=detailed_summary, - nolock=not lock, - unlock=unlock, - notemp=notemp, - keep_remote_local=keep_remote_local, - nodeps=nodeps, - keep_target_files=keep_target_files, - cleanup_metadata=cleanup_metadata, - conda_cleanup_envs=conda_cleanup_envs, - cleanup_containers=cleanup_containers, - cleanup_shadow=cleanup_shadow, - subsnakemake=subsnakemake, - updated_files=updated_files, - allowed_rules=allowed_rules, - greediness=greediness, - no_hooks=no_hooks, - force_use_threads=use_threads, - conda_create_envs_only=conda_create_envs_only, - cluster_status=cluster_status, - 
cluster_cancel=cluster_cancel, -            cluster_cancel_nargs=cluster_cancel_nargs, -            cluster_sidecar=cluster_sidecar, -            report=report, -            report_stylesheet=report_stylesheet, -            export_cwl=export_cwl, -            batch=batch, -            keepincomplete=keep_incomplete, -            containerize=containerize, -        ) +    def generate_unit_tests(self, path: Path): +        """Generate unit tests for the workflow. + +        Arguments +        --------- +        path: Path -- The path to store the unit tests. +        """ +        self.workflow_api._workflow.generate_unit_tests(path=path) + +    def containerize(self): +        """Containerize the workflow.""" +        self.workflow_api._workflow.containerize() + +    def create_report( +        self, +        path: Path, +        stylesheet: Optional[Path] = None, +    ): +        """Create a report for the workflow. + +        Arguments +        --------- +        path: Path -- The path to the report. +        stylesheet: Optional[Path] -- The path to the report stylesheet. +        """ +        self.workflow_api._workflow.create_report( +            path=path, +            stylesheet=stylesheet, +        ) -    except BrokenPipeError: -        # ignore this exception and stop. It occurs if snakemake output is piped into less and less quits before reading the whole output. -        # in such a case, snakemake shall stop scheduling and quit with error 1 -        success = False -    except BaseException as ex: -        if "workflow" in locals(): -            print_exception(ex, workflow.linemaps) -        else: -            print_exception(ex, dict()) -        success = False - -    if workdir: -        os.chdir(olddir) -    if "workflow" in locals() and workflow.persistence: -        workflow.persistence.unlock() -    if not keep_logger: -        logger.cleanup() -    return success +    def printdag(self): +        """Print the DAG of the workflow.""" +        self.workflow_api._workflow.printdag() + +    def printrulegraph(self): +        """Print the rule graph of the workflow.""" +        self.workflow_api._workflow.printrulegraph() + +    def printfilegraph(self): +        """Print the file graph of the workflow.""" +        self.workflow_api._workflow.printfilegraph() + +    def printd3dag(self): +        """Print the DAG of the workflow in D3.js compatible JSON.""" +        self.workflow_api._workflow.printd3dag() + +    def unlock(self): +        """Unlock the workflow.""" +        self.workflow_api._workflow.unlock() + +    def cleanup_metadata(self, paths: List[Path]): +        """Clean up the metadata of the workflow.""" +        self.workflow_api._workflow.cleanup_metadata(paths) + +    def conda_cleanup_envs(self): +        """Clean up the conda environments of the workflow.""" +        self.deployment_settings.imply_deployment_method(DeploymentMethod.CONDA) +        self.workflow_api._workflow.conda_cleanup_envs() + +    def conda_create_envs(self): +        """Only create the conda environments of the workflow.""" +        self.deployment_settings.imply_deployment_method(DeploymentMethod.CONDA) +        self.workflow_api._workflow.conda_create_envs() + +    def conda_list_envs(self): +        """List the conda environments of the workflow.""" +        self.deployment_settings.imply_deployment_method(DeploymentMethod.CONDA) +        self.workflow_api._workflow.conda_list_envs() + +    def cleanup_shadow(self): +        """Clean up the shadow directories of the workflow.""" +        self.workflow_api._workflow.cleanup_shadow() + +    def container_cleanup_images(self): +        """Clean up the container images of the workflow.""" +        self.deployment_settings.imply_deployment_method(DeploymentMethod.APPTAINER) +        self.workflow_api._workflow.container_cleanup_images() + +    def list_changes(self, change_type: ChangeType): +        """List the changes of the workflow. + +        Arguments +        --------- +        change_type: ChangeType -- The type of changes to list.
+ """ + self.workflow_api._workflow.list_changes(change_type=change_type) + + def list_untracked(self): + """List the untracked files of the workflow.""" + self.workflow_api._workflow.list_untracked() + + def summary(self, detailed: bool = False): + """Summarize the workflow. + + Arguments + --------- + detailed: bool -- Whether to print a detailed summary. + """ + self.workflow_api._workflow.summary(detailed=detailed) + + def archive(self, path: Path): + """Archive the workflow. + + Arguments + --------- + path: Path -- The path to the archive. + """ + self.workflow_api._workflow.archive(path=path) + + def delete_output(self, only_temp: bool = False, dryrun: bool = False): + """Delete the output of the workflow. + + Arguments + --------- + only_temp: bool -- Whether to only delete temporary output. + dryrun: bool -- Whether to only dry-run the deletion. + """ + self.workflow_api._workflow.delete_output(only_temp=only_temp, dryrun=dryrun) + + def export_to_cwl(self, path: Path): + """Export the workflow to CWL. + + Arguments + --------- + path: Path -- The path to the CWL file. + """ + self.workflow_api._workflow.export_to_cwl(path=path) + + +def _get_executor_plugin_registry(): + from snakemake.executors import local as local_executor + from snakemake.executors import dryrun as dryrun_executor + from snakemake.executors import touch as touch_executor + + registry = ExecutorPluginRegistry() + registry.register_plugin("local", local_executor) + registry.register_plugin("dryrun", dryrun_executor) + registry.register_plugin("touch", touch_executor) + + return registry diff --git a/snakemake/caching/hash.py b/snakemake/caching/hash.py index 5a82b9481..843d99eb2 100644 --- a/snakemake/caching/hash.py +++ b/snakemake/caching/hash.py @@ -11,6 +11,7 @@ from snakemake import script from snakemake import wrapper from snakemake.exceptions import WorkflowError +from snakemake.settings import DeploymentMethod # ATTENTION: increase version number whenever the hashing algorithm below changes! __version__ = "0.1" @@ -74,7 +75,7 @@ def _get_provenance_hash(self, job: Job, cache_mode: str): wrapper.get_script( job.rule.wrapper, sourcecache=job.rule.workflow.sourcecache, - prefix=workflow.wrapper_prefix, + prefix=workflow.workflow_settings.wrapper_prefix, ), job.rule.workflow.sourcecache, basedir=job.rule.basedir, @@ -114,11 +115,22 @@ def _get_provenance_hash(self, job: Job, cache_mode: str): # Hash used containers or conda environments. if cache_mode != "omit-software": - if workflow.use_conda and job.conda_env: - if workflow.use_singularity and job.conda_env.container_img_url: + if ( + DeploymentMethod.CONDA in workflow.deployment_settings.deployment_method + and job.conda_env + ): + if ( + DeploymentMethod.APPTAINER + in workflow.deployment_settings.deployment_method + and job.conda_env.container_img_url + ): h.update(job.conda_env.container_img_url.encode()) h.update(job.conda_env.content) - elif workflow.use_singularity and job.container_img_url: + elif ( + DeploymentMethod.APPTAINER + in workflow.deployment_settings.deployment_method + and job.container_img_url + ): h.update(job.container_img_url.encode()) # Generate hashes of dependencies, and add them in a blockchain fashion (as input to the current hash, sorted by hash value). 
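The API methods added above all follow the same thin delegation pattern onto the underlying Workflow object. As an illustration only: the sketch below shows how a caller might drive the inspection-style methods once it has obtained such an API handle. The construction of that handle (via the new SnakemakeApi/WorkflowApi entry points) is not part of this hunk, so the name dag_api and the way it is obtained are assumptions; the method calls themselves mirror the signatures introduced above.

# Sketch only: dag_api stands for an instance of the API class whose methods
# are defined in the hunk above; obtaining it is not shown here.
from pathlib import Path


def inspect_workflow(dag_api):
    # Print the rule graph instead of executing anything.
    dag_api.printrulegraph()
    # Show which outputs exist and which are outdated.
    dag_api.summary(detailed=True)
    # Write an HTML report, optionally with a custom stylesheet.
    dag_api.create_report(path=Path("report.html"))
    # Generate unit test stubs for the rules of the workflow.
    dag_api.generate_unit_tests(path=Path(".tests/unit"))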
diff --git a/snakemake/caching/remote.py b/snakemake/caching/remote.py index 249ee0a66..1f8e6911a 100644 --- a/snakemake/caching/remote.py +++ b/snakemake/caching/remote.py @@ -9,6 +9,7 @@ from snakemake.exceptions import WorkflowError from snakemake.jobs import Job from snakemake.io import get_flag_value +from snakemake.remote import AbstractRemoteProvider class OutputFileCache(AbstractOutputFileCache): @@ -18,7 +19,7 @@ class OutputFileCache(AbstractOutputFileCache): each output file. This is the remote version. """ - def __init__(self, remote_provider): + def __init__(self, remote_provider: AbstractRemoteProvider): super().__init__() self.remote_provider = remote_provider diff --git a/snakemake/cli.py b/snakemake/cli.py index 05560f17f..c8bbd45ff 100644 --- a/snakemake/cli.py +++ b/snakemake/cli.py @@ -3,55 +3,66 @@ __email__ = "johannes.koester@uni-due.de" __license__ = "MIT" +import argparse import sys +from typing import Set + +import configargparse from snakemake import logging -from snakemake.api import snakemake +import snakemake.common.argparse +from snakemake.api import SnakemakeApi, _get_executor_plugin_registry, resolve_snakefile import os import glob from argparse import ArgumentDefaultsHelpFormatter -import logging as _logging from pathlib import Path import re -import threading -import webbrowser -from functools import partial import shlex from importlib.machinery import SourceFileLoader +from snakemake.settings import ( + ChangeType, + ConfigSettings, + DAGSettings, + DeploymentMethod, + DeploymentSettings, + ExecutionSettings, + NotebookEditMode, + OutputSettings, + PreemptibleRules, + Quietness, + RemoteExecutionSettings, + ResourceSettings, + SchedulingSettings, + StorageSettings, + WorkflowSettings, +) -from snakemake_interface_executor_plugins.utils import url_can_parse, ExecMode -from snakemake_interface_executor_plugins.registry import ExecutorPluginRegistry - +from snakemake_interface_executor_plugins.settings import ExecMode from snakemake.target_jobs import parse_target_jobs_cli_args + from snakemake.workflow import Workflow from snakemake.dag import Batch from snakemake.exceptions import ( CliException, ResourceScopesException, + print_exception, ) -from snakemake.io import wait_for_files from snakemake.utils import update_config, available_cpu_count from snakemake.common import ( - RERUN_TRIGGERS, + SNAKEFILE_CHOICES, __version__, - MIN_PY_VERSION, get_appdirs, + get_container_image, parse_key_value_arg, ) from snakemake.resources import ResourceScopes, parse_resources, DefaultResources - -SNAKEFILE_CHOICES = [ - "Snakefile", - "snakefile", - "workflow/Snakefile", - "workflow/snakefile", -] +from snakemake.settings import RerunTrigger def parse_set_threads(args): return parse_set_ints( - args.set_threads, + args, "Invalid threads definition: entries have to be defined as RULE=THREADS pairs " "(with THREADS being a positive integer).", ) @@ -66,8 +77,8 @@ def parse_set_resources(args): from collections import defaultdict assignments = defaultdict(dict) - if args.set_resources is not None: - for entry in args.set_resources: + if args is not None: + for entry in args: key, value = parse_key_value_arg(entry, errmsg=errmsg) key = key.split(":") if len(key) != 2: @@ -86,7 +97,7 @@ def parse_set_resources(args): def parse_set_scatter(args): return parse_set_ints( - args.set_scatter, + args, "Invalid scatter definition: entries have to be defined as NAME=SCATTERITEMS pairs " "(with SCATTERITEMS being a positive integer).", ) @@ -97,11 +108,10 @@ def 
parse_set_resource_scope(args): "Invalid resource scopes: entries must be defined as RESOURCE=SCOPE pairs, " "where SCOPE is either 'local', 'global', or 'excluded'" ) - if args.set_resource_scopes is not None: + if args is not None: try: return ResourceScopes( - parse_key_value_arg(entry, errmsg=err_msg) - for entry in args.set_resource_scopes + parse_key_value_arg(entry, errmsg=err_msg) for entry in args ) except ResourceScopesException as err: invalid_resources = ", ".join( @@ -177,16 +187,16 @@ def _bool_parser(value): raise ValueError -def parse_config(args): +def parse_config(entries): """Parse config from args.""" import yaml yaml_base_load = lambda s: yaml.load(s, Loader=yaml.loader.BaseLoader) parsers = [int, float, _bool_parser, yaml_base_load, str] config = dict() - if args.config is not None: + if entries: valid = re.compile(r"[a-zA-Z_]\w*$") - for entry in args.config: + for entry in entries: key, val = parse_key_value_arg( entry, errmsg="Invalid config definition: Config entries have to be defined as name=value pairs.", @@ -212,36 +222,18 @@ def parse_config(args): return config -def parse_cores(cores, allow_none=False): - if cores is None: - if allow_none: - return cores - raise CliException( - "Error: you need to specify the maximum number of CPU cores to " - "be used at the same time. If you want to use N cores, say --cores N " - "or -cN. For all cores on your system (be sure that this is " - "appropriate) use --cores all. For no parallelization use --cores 1 or " - "-c1." - ) +def parse_cores(cores): if cores == "all": return available_cpu_count() try: return int(cores) except ValueError: raise CliException( - "Error parsing number of cores (--cores, -c, -j): must be integer, " - "empty, or 'all'." + "Error parsing number of cores (--cores, -c): must be integer or 'all'." ) -def parse_jobs(jobs, allow_none=False): - if jobs is None: - if allow_none: - return jobs - raise CliException( - "Error: you need to specify the maximum number of jobs to " - "be queued or executed at the same time with --jobs or -j." - ) +def parse_jobs(jobs): if jobs == "unlimited": return sys.maxsize try: @@ -252,20 +244,6 @@ def parse_jobs(jobs, allow_none=False): ) -def parse_cores_jobs(cores, jobs, no_exec, non_local_exec, dryrun): - if no_exec or dryrun: - cores = parse_cores(cores, allow_none=True) or 1 - jobs = parse_jobs(jobs, allow_none=True) or 1 - elif non_local_exec: - cores = parse_cores(cores, allow_none=True) - jobs = parse_jobs(jobs) - else: - cores = parse_cores(cores or jobs) - jobs = None - - return cores, jobs - - def get_profile_file(profile, file, return_default=False): dirs = get_appdirs() if os.path.exists(profile): @@ -292,7 +270,6 @@ def get_profile_file(profile, file, return_default=False): def get_argument_parser(profiles=None): """Generate and return argument parser.""" - import configargparse from snakemake.profiles import ProfileConfigFileParser dirs = get_appdirs() @@ -317,7 +294,7 @@ def get_argument_parser(profiles=None): exit(1) config_files.append(config_file) - parser = configargparse.ArgumentParser( + parser = snakemake.common.argparse.ArgumentParser( description="Snakemake is a Python based language and execution " "environment for GNU Make-like workflows.", formatter_class=ArgumentDefaultsHelpFormatter, @@ -328,9 +305,9 @@ def get_argument_parser(profiles=None): group_exec = parser.add_argument_group("EXECUTION") group_exec.add_argument( - "target", + "targets", nargs="*", - default=None, + default=set(), help="Targets to build. 
May be rules or files.", ) @@ -380,7 +357,7 @@ def get_argument_parser(profiles=None): The profile folder has to contain a file 'config.yaml'. This file can be used to set default values for command line options in YAML format. For example, - '--cluster qsub' becomes 'cluster: qsub' in the YAML + '--executor slurm' becomes 'executor: slurm' in the YAML file. It is advisable to use the workflow profile to set or overwrite e.g. workflow specific resources like the amount of threads of a particular rule or the amount of memory needed. @@ -404,6 +381,7 @@ def get_argument_parser(profiles=None): "--snakefile", "-s", metavar="FILE", + type=Path, help=( "The workflow definition in form of a snakefile." "Usually, you should not need to specify this. " @@ -418,9 +396,8 @@ def get_argument_parser(profiles=None): "--cores", "-c", action="store", - const=available_cpu_count(), - nargs="?", metavar="N", + type=parse_cores, help=( "Use at most N CPU cores/jobs in parallel. " "If N is omitted or 'all', the limit is set to the number of " @@ -436,12 +413,12 @@ def get_argument_parser(profiles=None): "--jobs", "-j", metavar="N", - nargs="?", - const=available_cpu_count(), action="store", + type=parse_jobs, help=( "Use at most N CPU cluster/cloud jobs in parallel. For local execution this is " - "an alias for --cores. Note: Set to 'unlimited' in case, this does not play a role." + "an alias for --cores (it is though recommended to use --cores in that case). " + "Note: Set to 'unlimited' to allow any number of parallel jobs." ), ) group_exec.add_argument( @@ -459,8 +436,10 @@ def get_argument_parser(profiles=None): group_exec.add_argument( "--resources", "--res", - nargs="*", + nargs="+", metavar="NAME=INT", + default=dict(), + parse_func=parse_resources, help=( "Define additional resources that shall constrain the scheduling " "analogously to --cores (see above). A resource is defined as " @@ -478,6 +457,8 @@ def get_argument_parser(profiles=None): "--set-threads", metavar="RULE=THREADS", nargs="+", + default=dict(), + parse_func=parse_set_threads, help="Overwrite thread usage of rules. This allows to fine-tune workflow " "parallelization. In particular, this is helpful to target certain cluster nodes " "by e.g. shifting a rule to use more, or less threads than defined in the workflow. " @@ -496,6 +477,8 @@ def get_argument_parser(profiles=None): "--set-resources", metavar="RULE:RESOURCE=VALUE", nargs="+", + default=dict(), + parse_func=parse_set_resources, help="Overwrite resource usage of rules. This allows to fine-tune workflow " "resources. In particular, this is helpful to target certain cluster nodes " "by e.g. defining a certain partition for a rule, or overriding a temporary directory. " @@ -506,6 +489,8 @@ def get_argument_parser(profiles=None): "--set-scatter", metavar="NAME=SCATTERITEMS", nargs="+", + default=dict(), + parse_func=parse_set_scatter, help="Overwrite number of scatter items of scattergather processes. This allows to fine-tune " "workflow parallelization. Thereby, SCATTERITEMS has to be a positive integer, and NAME has to be " "the name of the scattergather process defined via a scattergather directive in the workflow.", @@ -514,6 +499,8 @@ def get_argument_parser(profiles=None): "--set-resource-scopes", metavar="RESOURCE=[global|local]", nargs="+", + default=dict(), + parse_func=parse_set_resource_scope, help="Overwrite resource scopes. A scope determines how a constraint is " "reckoned in cluster execution. 
With RESOURCE=local, a constraint applied to " "RESOURCE using --resources will be considered the limit for each group " @@ -530,6 +517,7 @@ def get_argument_parser(profiles=None): "--default-res", nargs="*", metavar="NAME=INT", +        parse_func=DefaultResources, help=( "Define default values of resources for rules that do not define their own values. " "In addition to plain integers, python expressions over inputsize are allowed (e.g. '2*input.size_mb'). " @@ -547,30 +535,21 @@ def get_argument_parser(profiles=None): ) group_exec.add_argument( -        "--preemption-default", -        type=int, -        default=None, +        "--preemptible-rules", +        nargs="*", +        parse_func=set, help=( -            "A preemptible instance can be requested when using the Google Life Sciences API. If you set a --preemption-default," -            "all rules will be subject to the default. Specifically, this integer is the number of restart attempts that will be " -            "made given that the instance is killed unexpectedly. Note that preemptible instances have a maximum running time of 24 " -            "hours. If you want to set preemptible instances for only a subset of rules, use --preemptible-rules instead." +            "Define which rules shall use a preemptible machine which can be prematurely killed by e.g. a cloud provider (also called spot instances). " +            "This is currently only supported by the Google Life Sciences executor and ignored by all other executors. " +            "If no rule names are provided, all rules are considered to be preemptible. " +            "The number of retries for jobs of preemptible rules can be set via --preemptible-retries." ), ) group_exec.add_argument( -        "--preemptible-rules", -        nargs="+", -        default=None, -        help=( -            "A preemptible instance can be requested when using the Google Life Sciences API. If you want to use these instances " -            "for a subset of your rules, you can use --preemptible-rules and then specify a list of rule and integer pairs, where " -            "each integer indicates the number of restarts to use for the rule's instance in the case that the instance is " -            "terminated unexpectedly. --preemptible-rules can be used in combination with --preemption-default, and will take " -            "priority. Note that preemptible instances have a maximum running time of 24. If you want to apply a consistent " -            "number of retries across all your rules, use --preemption-default instead. " -            "Example: snakemake --preemption-default 10 --preemptible-rules map_reads=3 call_variants=0" -        ), +        "--preemptible-retries", +        type=int, +        help="Number of retries that shall be made in order to finish a job from a rule that has been marked as preemptible via the --preemptible-rules setting.", ) group_exec.add_argument( @@ -578,6 +557,8 @@ def get_argument_parser(profiles=None): "-C", nargs="*", metavar="KEY=VALUE", +        default=dict(), +        parse_func=parse_config, help=( "Set or overwrite values in the workflow config object. " "The workflow config object is accessible as variable config inside " @@ -590,6 +571,8 @@ def get_argument_parser(profiles=None): "--configfiles", nargs="+", metavar="FILE", +        default=list(), +        type=Path, help=( "Specify or overwrite the config file of the workflow (see the docs). " "Values specified in JSON or YAML format are available in the global config " @@ -603,13 +586,14 @@ def get_argument_parser(profiles=None): "--envvars", nargs="+", metavar="VARNAME", +        parse_func=set, help="Environment variables to pass to cloud jobs.", ) group_exec.add_argument( "--directory", "-d", metavar="DIR", -        action="store", +        type=Path, help=( "Specify working directory (relative paths in " "the snakefile will use this as their origin)."
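The parse_func keyword passed to add_argument throughout this file is not part of standard argparse; it comes from the snakemake.common.argparse.ArgumentParser wrapper imported at the top of cli.py, whose implementation is not included in this excerpt. As a rough, hypothetical sketch of the idea only (the real wrapper builds on configargparse and may differ in its details), a parse_func-aware parser could look like this: each option registers a function that converts the raw command line value into a ready-to-use object after parsing.

# Illustration only: not the real snakemake.common.argparse module; the class
# name and the simplified --set-threads parser below are hypothetical.
import argparse


class ParseFuncArgumentParser(argparse.ArgumentParser):
    """ArgumentParser whose add_argument accepts a parse_func keyword."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._parse_funcs = {}

    def add_argument(self, *args, parse_func=None, **kwargs):
        action = super().add_argument(*args, **kwargs)
        if parse_func is not None:
            # Remember the default so it is not run through parse_func again.
            self._parse_funcs[action.dest] = (parse_func, action.default)
        return action

    def parse_args(self, *args, **kwargs):
        namespace = super().parse_args(*args, **kwargs)
        for dest, (parse_func, default) in self._parse_funcs.items():
            value = getattr(namespace, dest, None)
            if value is not None and value is not default:
                setattr(namespace, dest, parse_func(value))
        return namespace


if __name__ == "__main__":
    parser = ParseFuncArgumentParser()
    parser.add_argument(
        "--set-threads",
        nargs="+",
        default=dict(),
        # Simplified stand-in for parse_set_threads.
        parse_func=lambda entries: {
            k: int(v) for k, v in (e.split("=", 1) for e in entries)
        },
    )
    print(parser.parse_args(["--set-threads", "align=8"]))
    # -> Namespace(set_threads={'align': 8})

Recording the default next to the function lets the wrapper skip options that were never given on the command line, so already-parsed defaults such as dict() or set() are passed through unchanged.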
@@ -641,8 +625,9 @@ def get_argument_parser(profiles=None): group_exec.add_argument( "--rerun-triggers", nargs="+", - choices=RERUN_TRIGGERS, - default=RERUN_TRIGGERS, + choices=RerunTrigger.choices(), + default=RerunTrigger.all(), + parse_func=RerunTrigger.parse_choices_set, help="Define what triggers the rerunning of a job. By default, " "all triggers are used, which guarantees that results are " "consistent with the workflow code and configuration. If you " @@ -662,7 +647,7 @@ def get_argument_parser(profiles=None): "--executor", "-e", help="Specify a custom executor, available via an executor plugin: snakemake_executor_", - choices=ExecutorPluginRegistry().plugins, + choices=_get_executor_plugin_registry().plugins.keys(), ) group_exec.add_argument( "--forceall", @@ -679,6 +664,8 @@ def get_argument_parser(profiles=None): "-R", nargs="*", metavar="TARGET", + parse_func=set, + default=set(), help=( "Force the re-execution or creation of the given rules or files." " Use this option if you changed a rule and want to have all its " @@ -690,14 +677,17 @@ def get_argument_parser(profiles=None): "-P", nargs="+", metavar="TARGET", + parse_func=set, + default=set(), help=( "Tell the scheduler to assign creation of given targets " - "(and all their dependencies) highest priority. (EXPERIMENTAL)" + "(and all their dependencies) highest priority." ), ) group_exec.add_argument( "--batch", metavar="RULE=BATCH/BATCHES", + type=parse_batch, help=( "Only create the given BATCH of the input files of the given RULE. " "This can be used to iteratively run parts of very large workflows. " @@ -714,6 +704,8 @@ def get_argument_parser(profiles=None): "-U", nargs="+", metavar="TARGET", + parse_func=set, + default=set(), help=( "Runs the pipeline until it reaches the specified rules or " "files. Only runs jobs that are dependencies of the specified " @@ -725,6 +717,8 @@ def get_argument_parser(profiles=None): "-O", nargs="+", metavar="TARGET", + parse_func=set, + default=set(), help=( "Prevent the execution or creation of the given rules or files " "as well as any rules or files that are downstream of these targets " @@ -811,26 +805,18 @@ def get_argument_parser(profiles=None): help=("Do not evaluate or execute subworkflows."), ) - # TODO add group_partitioning, allowing to define --group rulename=groupname. - # i.e. setting groups via the CLI for improving cluster performance given - # available resources. - # TODO add an additional flag --group-components groupname=3, allowing to set the - # number of connected components a group is allowed to span. By default, this is 1 - # (as now), but the flag allows to extend this. This can be used to run e.g. - # 3 jobs of the same rule in the same group, although they are not connected. - # Can be helpful for putting together many small jobs or benefitting of shared memory - # setups. - group_group = parser.add_argument_group("GROUPING") group_group.add_argument( "--groups", nargs="+", + parse_func=parse_groups, help="Assign rules to groups (this overwrites any " "group definitions from the workflow).", ) group_group.add_argument( "--group-components", nargs="+", + parse_func=parse_group_components, help="Set the number of connected components a group is " "allowed to span. By default, this is 1, but this flag " "allows to extend this. This can be used to run e.g. 3 " @@ -845,6 +831,7 @@ def get_argument_parser(profiles=None): nargs="?", const="report.html", metavar="FILE", + type=Path, help="Create an HTML report with results and statistics. 
" "This can be either a .html file or a .zip file. " "In the former case, all results are embedded into the .html (this only works for small data). " @@ -854,6 +841,7 @@ def get_argument_parser(profiles=None): group_report.add_argument( "--report-stylesheet", metavar="CSSFILE", + type=Path, help="Custom stylesheet to use for report. In particular, this can be used for " "branding the report with e.g. a custom logo, see docs.", ) @@ -899,6 +887,7 @@ def get_argument_parser(profiles=None): nargs="?", const=".tests/unit", metavar="TESTPATH", + type=Path, help="Automatically generate unit tests for each workflow rule. " "This assumes that all input files of each job are already present. " "Rules without a job with present input files will be skipped (a warning will be issued). " @@ -920,6 +909,7 @@ def get_argument_parser(profiles=None): help="Compile workflow to CWL and store it in given FILE.", ) group_utils.add_argument( + "--list-rules", "--list", "-l", action="store_true", @@ -1003,6 +993,7 @@ def get_argument_parser(profiles=None): group_utils.add_argument( "--archive", metavar="FILE", + type=Path, help="Archive the workflow into the given tar archive FILE. The archive " "will be created such that the workflow can be re-executed on a vanilla " "system. The function needs conda and git to be installed. " @@ -1019,6 +1010,7 @@ def get_argument_parser(profiles=None): "--cm", nargs="+", metavar="FILE", + type=Path, help="Cleanup the metadata " "of given files. That means that snakemake removes any tracked " "version info, and any marks that files are incomplete.", @@ -1038,16 +1030,10 @@ def get_argument_parser(profiles=None): "--unlock", action="store_true", help="Remove a lock on the working directory." ) group_utils.add_argument( - "--list-version-changes", - "--lv", - action="store_true", - help="List all output files that have been created with " - "a different version (as determined by the version keyword).", - ) - group_utils.add_argument( - "--list-code-changes", + "--list-changes", "--lc", - action="store_true", + choices=ChangeType.all(), + type=ChangeType.parse_choice, help="List all output files for which the rule body (run or shell) have " "changed in the Snakefile.", ) @@ -1075,7 +1061,8 @@ def get_argument_parser(profiles=None): "workflow. This can be used e.g. for identifying leftover files. Hidden files " "and directories are ignored.", ) - group_utils.add_argument( + group_delete_output = group_utils.add_mutually_exclusive_group() + group_delete_output.add_argument( "--delete-all-output", action="store_true", help="Remove all files generated by the workflow. Use together with --dry-run " @@ -1083,7 +1070,7 @@ def get_argument_parser(profiles=None): "not recurse into subworkflows. Write-protected files are not removed. " "Nevertheless, use with care!", ) - group_utils.add_argument( + group_delete_output.add_argument( "--delete-temp-output", action="store_true", help="Remove all temporary files generated by the workflow. Use together " @@ -1113,20 +1100,6 @@ def get_argument_parser(profiles=None): group_utils.add_argument("--version", "-v", action="version", version=__version__) group_output = parser.add_argument_group("OUTPUT") - group_output.add_argument( - "--gui", - nargs="?", - const="8000", - metavar="PORT", - type=str, - help="Serve an HTML based user interface to the given network and " - "port e.g. 168.129.10.15:8000. By default Snakemake is only " - "available in the local network (default port: 8000). 
To make " - "Snakemake listen to all ip addresses add the special host address " - "0.0.0.0 to the url (0.0.0.0:8000). This is important if Snakemake " - "is used in a virtualised environment like Docker. If possible, a " - "browser window is opened.", - ) group_output.add_argument( "--printshellcmds", "-p", @@ -1139,11 +1112,6 @@ def get_argument_parser(profiles=None): help="Print candidate and selected jobs (including their wildcards) while " "inferring DAG. This can help to debug unexpected DAG topology or errors.", ) - group_output.add_argument( - "--stats", - metavar="FILE", - help="Write stats about Snakefile execution in JSON format to the given file.", - ) group_output.add_argument( "--nocolor", action="store_true", help="Do not use a colored output." ) @@ -1151,8 +1119,9 @@ def get_argument_parser(profiles=None): "--quiet", "-q", nargs="*", - choices=["progress", "rules", "all"], + choices=Quietness.choices(), default=None, + parse_func=parse_quietness, help="Do not output certain information. " "If used without arguments, do not output any progress or rule " "information. Defining 'all' results in no information being " @@ -1220,6 +1189,7 @@ def get_argument_parser(profiles=None): "--wait-for-files", nargs="*", metavar="FILE", + parse_func=set, help="Wait --latency-wait seconds for these " "files to be present before executing the workflow. " "This option is used internally to handle filesystem latency in cluster " @@ -1253,7 +1223,7 @@ def get_argument_parser(profiles=None): help="Keep local copies of remote input files.", ) group_behavior.add_argument( - "--keep-target-files", + "--target-files-omit-workdir-adjustment", action="store_true", help="Do not adjust the paths of given target files relative to the working directory.", ) @@ -1267,6 +1237,8 @@ def get_argument_parser(profiles=None): group_behavior.add_argument( "--target-jobs", nargs="+", + parse_func=parse_target_jobs_cli_args, + default=set(), help="Target particular jobs by RULE:WILDCARD1=VALUE,WILDCARD2=VALUE,... " "This is meant for internal use by Snakemake itself only.", ) @@ -1290,9 +1262,15 @@ def get_argument_parser(profiles=None): "fractions allowed.", ) group_behavior.add_argument( - "-T", + "--seconds-between-status-checks", + default=10, + type=int, + help="Number of seconds to wait between two rounds of status checks.", + ) + group_behavior.add_argument( "--retries", "--restart-times", + "-T", default=0, type=int, help="Number of times to restart failing jobs (defaults to 0).", @@ -1346,12 +1324,12 @@ def get_argument_parser(profiles=None): "separately. Further, it won't take special measures " "to deal with filesystem latency issues. This option " "will in most cases only make sense in combination with " - "--default-remote-provider. Further, when using --cluster " - "you will have to also provide --cluster-status. " + "--default-remote-provider. " "Only activate this if you " "know what you are doing.", ) group_behavior.add_argument( + "--scheduler-greediness", "--greediness", type=float, default=None, @@ -1365,12 +1343,6 @@ def get_argument_parser(profiles=None): action="store_true", help="Do not invoke onstart, onsuccess or onerror hooks after execution.", ) - group_behavior.add_argument( - "--overwrite-shellcmd", - help="Provide a shell command that shall be executed instead of those " - "given in the workflow. 
" - "This is for debugging purposes only.", - ) group_behavior.add_argument( "--debug", action="store_true", @@ -1385,9 +1357,9 @@ def get_argument_parser(profiles=None): ) group_behavior.add_argument( "--mode", - choices=[ExecMode.default, ExecMode.subprocess, ExecMode.remote], - default=ExecMode.default, - type=int, + choices=ExecMode.all(), + default=ExecMode.DEFAULT, + type=ExecMode.parse_choice, help="Set execution mode of Snakemake (internal use only).", ) group_behavior.add_argument( @@ -1438,49 +1410,19 @@ def get_argument_parser(profiles=None): # mode ) - group_cluster = parser.add_argument_group("CLUSTER") + group_cluster = parser.add_argument_group("REMOTE EXECUTION") - # TODO extend below description to explain the wildcards that can be used - cluster_mode_group = group_cluster.add_mutually_exclusive_group() - cluster_mode_group.add_argument( - "--cluster", - metavar="CMD", - help=( - "Execute snakemake rules with the given submit command, " - "e.g. qsub. Snakemake compiles jobs into scripts that are " - "submitted to the cluster with the given command, once all input " - "files for a particular job are present.\n" - "The submit command can be decorated to make it aware of certain " - "job properties (name, rulename, input, output, params, wildcards, log, threads " - "and dependencies (see the argument below)), e.g.:\n" - "$ snakemake --cluster 'qsub -pe threaded {threads}'." - ), - ), - cluster_mode_group.add_argument( - "--cluster-sync", - metavar="CMD", - help=( - "cluster submission command will block, returning the remote exit" - "status upon remote termination (for example, this should be used" - "if the cluster command is 'qsub -sync y' (SGE)" - ), - ), - cluster_mode_group.add_argument( - "--drmaa", - nargs="?", - const="", - metavar="ARGS", - help="Execute snakemake on a cluster accessed via DRMAA, " - "Snakemake compiles jobs into scripts that are " - "submitted to the cluster with the given command, once all input " - "files for a particular job are present. ARGS can be used to " - "specify options of the underlying cluster system, " - "thereby using the job properties name, rulename, input, output, params, wildcards, log, " - "threads and dependencies, e.g.: " - "--drmaa ' -pe threaded {threads}'. Note that ARGS must be given in quotes and " - "with a leading whitespace.", + group_cluster.add_argument( + "--container-image", + metavar="IMAGE", + help="Docker image to use, e.g., when submitting jobs to kubernetes. " + "Defaults to 'https://hub.docker.com/r/snakemake/snakemake', tagged with " + "the same version as the currently running Snakemake instance. " + "Note that overwriting this value is up to your responsibility. " + "Any used image has to contain a working snakemake installation " + "that is compatible with (or ideally the same as) the currently " + "running version.", ) - group_cluster.add_argument( "--immediate-submit", "--is", @@ -1511,103 +1453,12 @@ def get_argument_parser(profiles=None): 'cluster (see --cluster). NAME is "snakejob.{name}.{jobid}.sh" ' "per default. The wildcard {jobid} has to be present in the name.", ) - group_cluster.add_argument( - "--cluster-status", - help="Status command for cluster execution. This is only considered " - "in combination with the --cluster flag. If provided, Snakemake will " - "use the status command to determine if a job has finished successfully " - "or failed. For this it is necessary that the submit command provided " - "to --cluster returns the cluster job id. 
Then, the status command " - "will be invoked with the job id. Snakemake expects it to return " - "'success' if the job was successful, 'failed' if the job failed and " - "'running' if the job still runs.", - ) - group_cluster.add_argument( - "--cluster-cancel", - default=None, - help="Specify a command that allows to stop currently running jobs. " - "The command will be passed a single argument, the job id.", - ) - group_cluster.add_argument( - "--cluster-cancel-nargs", - type=int, - default=1000, - help="Specify maximal number of job ids to pass to --cluster-cancel " - "command, defaults to 1000.", - ) - group_cluster.add_argument( - "--cluster-sidecar", - default=None, - help="Optional command to start a sidecar process during cluster " - "execution. Only active when --cluster is given as well.", - ) - group_cluster.add_argument( - "--drmaa-log-dir", - metavar="DIR", - help="Specify a directory in which stdout and stderr files of DRMAA" - " jobs will be written. The value may be given as a relative path," - " in which case Snakemake will use the current invocation directory" - " as the origin. If given, this will override any given '-o' and/or" - " '-e' native specification. If not given, all DRMAA stdout and" - " stderr files are written to the current working directory.", - ) group_flux = parser.add_argument_group("FLUX") - group_kubernetes = parser.add_argument_group("KUBERNETES") group_google_life_science = parser.add_argument_group("GOOGLE_LIFE_SCIENCE") - group_kubernetes = parser.add_argument_group("KUBERNETES") group_tes = parser.add_argument_group("TES") group_tibanna = parser.add_argument_group("TIBANNA") - group_kubernetes.add_argument( - "--kubernetes", - metavar="NAMESPACE", - nargs="?", - const="default", - help="Execute workflow in a kubernetes cluster (in the cloud). " - "NAMESPACE is the namespace you want to use for your job (if nothing " - "specified: 'default'). " - "Usually, this requires --default-remote-provider and " - "--default-remote-prefix to be set to a S3 or GS bucket where your . " - "data shall be stored. It is further advisable to activate conda " - "integration via --use-conda.", - ) - group_kubernetes.add_argument( - "--container-image", - metavar="IMAGE", - help="Docker image to use, e.g., when submitting jobs to kubernetes " - "Defaults to 'https://hub.docker.com/r/snakemake/snakemake', tagged with " - "the same version as the currently running Snakemake instance. " - "Note that overwriting this value is up to your responsibility. " - "Any used image has to contain a working snakemake installation " - "that is compatible with (or ideally the same as) the currently " - "running version.", - ) - group_kubernetes.add_argument( - "--k8s-cpu-scalar", - metavar="FLOAT", - default=0.95, - type=float, - help="K8s reserves some proportion of available CPUs for its own use. " - "So, where an underlying node may have 8 CPUs, only e.g. 7600 milliCPUs " - "are allocatable to k8s pods (i.e. snakemake jobs). As 8 > 7.6, k8s can't " - "find a node with enough CPU resource to run such jobs. This argument acts " - "as a global scalar on each job's CPU request, so that e.g. a job whose " - "rule definition asks for 8 CPUs will request 7600m CPUs from k8s, " - "allowing it to utilise one entire node. N.B: the job itself would still " - "see the original value, i.e. 
as the value substituted in {threads}.", - ) - - group_kubernetes.add_argument( - "--k8s-service-account-name", - metavar="SERVICEACCOUNTNAME", - default=None, - help="This argument allows the use of customer service accounts for " - "kubernetes pods. If specified serviceAccountName will be added to the " - "pod specs. This is needed when using workload identity which is enforced " - "when using Google Cloud GKE Autopilot.", - ) - group_tibanna.add_argument( "--tibanna", action="store_true", @@ -1642,72 +1493,6 @@ def get_argument_parser(profiles=None): help="Additional tibanna config e.g. --tibanna-config spot_instance=true subnet=" " security group=", ) - group_google_life_science.add_argument( - "--google-lifesciences", - action="store_true", - help="Execute workflow on Google Cloud cloud using the Google Life. " - " Science API. This requires default application credentials (json) " - " to be created and export to the environment to use Google Cloud " - " Storage, Compute Engine, and Life Sciences. The credential file " - " should be exported as GOOGLE_APPLICATION_CREDENTIALS for snakemake " - " to discover. Also, --use-conda, --use-singularity, --config, " - "--configfile are supported and will be carried over.", - ) - group_google_life_science.add_argument( - "--google-lifesciences-regions", - nargs="+", - default=["us-east1", "us-west1", "us-central1"], - help="Specify one or more valid instance regions (defaults to US)", - ) - group_google_life_science.add_argument( - "--google-lifesciences-location", - help="The Life Sciences API service used to schedule the jobs. " - " E.g., us-centra1 (Iowa) and europe-west2 (London) " - " Watch the terminal output to see all options found to be available. " - " If not specified, defaults to the first found with a matching prefix " - " from regions specified with --google-lifesciences-regions.", - ) - group_google_life_science.add_argument( - "--google-lifesciences-keep-cache", - action="store_true", - help="Cache workflows in your Google Cloud Storage Bucket specified " - "by --default-remote-prefix/{source}/{cache}. Each workflow working " - "directory is compressed to a .tar.gz, named by the hash of the " - "contents, and kept in Google Cloud Storage. 
By default, the caches " - "are deleted at the shutdown step of the workflow.", - ) - group_google_life_science.add_argument( - "--google-lifesciences-service-account-email", - help="Specify a service account email address", - ) - group_google_life_science.add_argument( - "--google-lifesciences-network", - help="Specify a network for a Google Compute Engine VM instance", - ) - group_google_life_science.add_argument( - "--google-lifesciences-subnetwork", - help="Specify a subnetwork for a Google Compute Engine VM instance", - ) - - group_azure_batch = parser.add_argument_group("AZURE_BATCH") - - group_azure_batch.add_argument( - "--az-batch", - action="store_true", - help="Execute workflow on azure batch", - ) - - group_azure_batch.add_argument( - "--az-batch-enable-autoscale", - action="store_true", - help="Enable autoscaling of the azure batch pool nodes, this option will set the initial dedicated node count to zero, and requires five minutes to resize the cluster, so is only recommended for longer running jobs.", - ) - - group_azure_batch.add_argument( - "--az-batch-account-url", - nargs="?", - help="Azure batch account url, requires AZ_BATCH_ACCOUNT_KEY environment variable to be set.", - ) group_flux.add_argument( "--flux", @@ -1717,10 +1502,21 @@ def get_argument_parser(profiles=None): "If you don't have a shared filesystem, additionally specify --no-shared-fs.", ) - group_tes.add_argument( - "--tes", - metavar="URL", - help="Send workflow tasks to GA4GH TES server specified by url.", + group_deployment = parser.add_argument_group("SOFTWARE DEPLOYMENT") + group_deployment.add_argument( + "--software-deployment-method", + "--deployment-method", + "--deployment", + nargs="+", + choices=DeploymentMethod.choices(), + parse_func=DeploymentMethod.parse_choices_set, + default=set(), + help="Specify software environment deployment method.", + ) + group_deployment.add_argument( + "--container-cleanup-images", + action="store_true", + help="Remove unused containers", ) group_conda = parser.add_argument_group("CONDA") @@ -1789,34 +1585,32 @@ def get_argument_parser(profiles=None): "Mamba is much faster and highly recommended.", ) - group_singularity = parser.add_argument_group("SINGULARITY") + group_singularity = parser.add_argument_group("APPTAINER/SINGULARITY") group_singularity.add_argument( + "--use-apptainer", "--use-singularity", action="store_true", - help="If defined in the rule, run job within a singularity container. " + help="If defined in the rule, run job within a apptainer/singularity container. " "If this flag is not set, the singularity directive is ignored.", ) group_singularity.add_argument( + "--apptainer-prefix", "--singularity-prefix", metavar="DIR", - help="Specify a directory in which singularity images will be stored." + help="Specify a directory in which apptainer/singularity images will be stored." "If not supplied, the value is set " "to the '.snakemake' directory relative to the invocation directory. " - "If supplied, the `--use-singularity` flag must also be set. The value " + "If supplied, the `--use-apptainer` flag must also be set. 
The value " "may be given as a relative path, which will be extrapolated to the " "invocation directory, or as an absolute path.", ) group_singularity.add_argument( + "--apptainer-args", "--singularity-args", default="", metavar="ARGS", - help="Pass additional args to singularity.", - ) - group_singularity.add_argument( - "--cleanup-containers", - action="store_true", - help="Remove unused (singularity) containers", + help="Pass additional args to apptainer/singularity.", ) group_env_modules = parser.add_argument_group("ENVIRONMENT MODULES") @@ -1831,7 +1625,7 @@ def get_argument_parser(profiles=None): ) # Add namespaced arguments to parser for each plugin - ExecutorPluginRegistry().register_cli_args(parser) + _get_executor_plugin_registry().register_cli_args(parser) return parser @@ -1847,46 +1641,11 @@ def generate_parser_metadata(parser, args): return metadata -def main(argv=None): - """Main entry point.""" - - if sys.version_info < MIN_PY_VERSION: - print( - f"Snakemake requires at least Python {MIN_PY_VERSION}.", - file=sys.stderr, - ) - exit(1) - +def parse_args(argv): parser = get_argument_parser() args = parser.parse_args(argv) - snakefile = args.snakefile - if snakefile is None: - for p in SNAKEFILE_CHOICES: - if os.path.exists(p): - snakefile = p - break - if snakefile is None: - print( - "Error: no Snakefile found, tried {}.".format( - ", ".join(SNAKEFILE_CHOICES) - ), - file=sys.stderr, - ) - sys.exit(1) - - # Custom argument parsing based on chosen executor - # We also only validate an executor plugin when it's selected - executor_args = None - if args.executor: - plugin = ExecutorPluginRegistry().plugins[args.executor] - - # This is the dataclass prepared by the executor - executor_args = plugin.get_executor_settings(args) - - # Hold a handle to the plugin class - executor_args._executor = plugin - + snakefile = resolve_snakefile(args.snakefile) workflow_profile = None if args.workflow_profile != "none": if args.workflow_profile: @@ -1906,9 +1665,9 @@ def main(argv=None): if args.profile == "none": args.profile = None - if (args.profile or workflow_profile) and args.mode == ExecMode.default: + if (args.profile or workflow_profile) and args.mode == ExecMode.DEFAULT: # Reparse args while inferring config file from profile. 
- # But only do this if the user has invoked Snakemake (ExecMode.default) + # But only do this if the user has invoked Snakemake (ExecMode.DEFAULT) profiles = [] if args.profile: profiles.append(args.profile) @@ -1927,503 +1686,344 @@ def main(argv=None): parser = get_argument_parser(profiles=profiles) args = parser.parse_args(argv) - def adjust_path(f): - if os.path.exists(f) or os.path.isabs(f): - return f + def adjust_path(path_or_value): + if isinstance(path_or_value, str): + adjusted = get_profile_file( + args.profile, path_or_value, return_default=False + ) + if adjusted is None: + return path_or_value + else: + return adjusted else: - return get_profile_file(args.profile, f, return_default=True) - - # update file paths to be relative to the profile - # (if they do not exist relative to CWD) - if args.jobscript: - args.jobscript = adjust_path(args.jobscript) - if args.cluster: - args.cluster = adjust_path(args.cluster) - if args.cluster_sync: - args.cluster_sync = adjust_path(args.cluster_sync) - for key in "cluster_status", "cluster_cancel", "cluster_sidecar": - if getattr(args, key): - setattr(args, key, adjust_path(getattr(args, key))) - if args.report_stylesheet: - args.report_stylesheet = adjust_path(args.report_stylesheet) - - if args.quiet is not None and len(args.quiet) == 0: - # default case, set quiet to progress and rule - args.quiet = ["progress", "rules"] + return path_or_value - if args.bash_completion: - cmd = b"complete -o bashdefault -C snakemake-bash-completion snakemake" - sys.stdout.buffer.write(cmd) - sys.exit(0) - - if args.batch is not None and args.forceall: - print( - "--batch may not be combined with --forceall, because recomputed upstream " - "jobs in subsequent batches may render already obtained results outdated." - ) + # Update file paths to be relative to the profile if profile + # contains them. 
+    for key, _ in list(args._get_kwargs()): +        setattr(args, key, adjust_path(getattr(args, key))) -    try: -        resources = parse_resources(args.resources) -        config = parse_config(args) +    return parser, args -        if args.default_resources is not None: -            default_resources = DefaultResources(args.default_resources) -        else: -            default_resources = None -        batch = parse_batch(args) -        overwrite_threads = parse_set_threads(args) -        overwrite_resources = parse_set_resources(args) -        overwrite_resource_scopes = parse_set_resource_scope(args) - -        overwrite_scatter = parse_set_scatter(args) - -        overwrite_groups = parse_groups(args) -        group_components = parse_group_components(args) -    except ValueError as e: -        print(e, file=sys.stderr) -        print("", file=sys.stderr) -        sys.exit(1) - -    non_local_exec = ( -        args.cluster -        or args.slurm -        or args.slurm_jobstep -        or args.cluster_sync -        or args.tibanna -        or args.kubernetes -        or args.tes -        or args.az_batch -        or args.google_lifesciences -        or args.drmaa -        or args.flux -    ) -    no_exec = ( -        args.print_compilation -        or args.list_code_changes -        or args.list_conda_envs -        or args.list_input_changes -        or args.list_params_changes -        or args.list -        or args.list_target_rules -        or args.list_untracked -        or args.list_version_changes -        or args.export_cwl -        or args.generate_unit_tests -        or args.dag -        or args.d3dag -        or args.filegraph -        or args.rulegraph -        or args.summary -        or args.detailed_summary -        or args.lint -        or args.containerize -        or args.report -        or args.gui -        or args.archive -        or args.unlock -        or args.cleanup_metadata -    ) - -    try: -        cores, jobs = parse_cores_jobs( -            args.cores, args.jobs, no_exec, non_local_exec, args.dryrun -        ) -        args.cores = cores -        args.jobs = jobs -    except CliException as err: -        print(err.msg, sys.stderr) -        sys.exit(1) - -    if args.drmaa_log_dir is not None and not os.path.isabs(args.drmaa_log_dir): -        args.drmaa_log_dir = os.path.abspath(os.path.expanduser(args.drmaa_log_dir)) - -    if args.runtime_profile: -        import yappi - -        yappi.start() +def parse_quietness(quietness) -> Set[Quietness]: +    if quietness is not None and len(quietness) == 0: +        # default case, set quiet to progress and rule +        quietness = [Quietness.PROGRESS, Quietness.RULES] +    else: +        quietness = Quietness.parse_choices_set(quietness) +    return quietness -    if args.immediate_submit and not args.notemp: -        print( -            "Error: --immediate-submit has to be combined with --notemp, " -            "because temp file handling is not supported in this mode.", -            file=sys.stderr, -        ) -        sys.exit(1) -    if (args.conda_prefix or args.conda_create_envs_only) and not args.use_conda: -        if args.conda_prefix and os.environ.get("SNAKEMAKE_CONDA_PREFIX", False):
Expect python function "log_handler(msg)".'.format( + args.log_handler_script + ), file=sys.stderr, ) sys.exit(1) - if args.singularity_prefix and not args.use_singularity: - print( - "Error: --use_singularity must be set if --singularity-prefix is set.", - file=sys.stderr, - ) - sys.exit(1) + if args.log_service == "slack": + slack_logger = logging.SlackLogger() + log_handler.append(slack_logger.log_handler) - if args.kubernetes and ( - not args.default_remote_provider or not args.default_remote_prefix - ): - print( - "Error: --kubernetes must be combined with " - "--default-remote-provider and --default-remote-prefix, see " - "https://snakemake.readthedocs.io/en/stable/executing/cloud.html" - "#executing-a-snakemake-workflow-via-kubernetes", - file=sys.stderr, + elif args.wms_monitor or args.log_service == "wms": + # Generate additional metadata for server + metadata = generate_parser_metadata(parser, args) + wms_logger = logging.WMSLogger( + args.wms_monitor, args.wms_monitor_arg, metadata=metadata ) - sys.exit(1) + log_handler.append(wms_logger.log_handler) - if args.tibanna: - if not args.default_remote_prefix: - print( - "Error: --tibanna must be combined with --default-remote-prefix " - "to provide bucket name and subdirectory (prefix) " - "(e.g. 'bucketname/projectname'", - file=sys.stderr, - ) - sys.exit(1) - args.default_remote_prefix = args.default_remote_prefix.rstrip("/") - if not args.tibanna_sfn: - args.tibanna_sfn = os.environ.get("TIBANNA_DEFAULT_STEP_FUNCTION_NAME", "") - if not args.tibanna_sfn: - print( - "Error: to use --tibanna, either --tibanna-sfn or environment variable " - "TIBANNA_DEFAULT_STEP_FUNCTION_NAME must be set and exported " - "to provide name of the tibanna unicorn step function " - "(e.g. 'tibanna_unicorn_monty'). The step function must be deployed first " - "using tibanna cli (e.g. tibanna deploy_unicorn --usergroup=monty " - "--buckets=bucketname)", - file=sys.stderr, - ) - sys.exit(1) + return log_handler - if args.az_batch: - if not args.default_remote_provider or not args.default_remote_prefix: - print( - "Error: --az-batch must be combined with " - "--default-remote-provider AzBlob and --default-remote-prefix to " - "provide a blob container name\n", - file=sys.stderr, - ) - sys.exit(1) - elif args.az_batch_account_url is None: - print( - "Error: --az-batch-account-url must be set when --az-batch is used\n", - file=sys.stderr, - ) - sys.exit(1) - elif not url_can_parse(args.az_batch_account_url): - print( - "Error: invalide azure batch account url, please use format: https://{account_name}.{location}.batch.azure.com." 
- ) - sys.exit(1) - elif os.getenv("AZ_BATCH_ACCOUNT_KEY") is None: - print( - "Error: environment variable AZ_BATCH_ACCOUNT_KEY must be set when --az-batch is used\n", - file=sys.stderr, - ) - sys.exit(1) - if args.google_lifesciences: - if ( - not os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") - and not args.google_lifesciences_service_account_email - ): - print( - "Error: Either the GOOGLE_APPLICATION_CREDENTIALS environment variable " - "or --google-lifesciences-service-account-email must be available " - "for --google-lifesciences", - file=sys.stderr, - ) - sys.exit(1) +def parse_edit_notebook(args): + edit_notebook = None + if args.draft_notebook: + args.targets = {args.draft_notebook} + edit_notebook = NotebookEditMode(draft_only=True) + elif args.edit_notebook: + args.targets = {args.edit_notebook} + args.force = True + edit_notebook = NotebookEditMode(args.notebook_listen) + return edit_notebook - if not args.default_remote_prefix: - print( - "Error: --google-lifesciences must be combined with " - " --default-remote-prefix to provide bucket name and " - "subdirectory (prefix) (e.g. 'bucketname/projectname'", - file=sys.stderr, - ) - sys.exit(1) - if args.delete_all_output and args.delete_temp_output: - print( - "Error: --delete-all-output and --delete-temp-output are mutually exclusive.", - file=sys.stderr, - ) - sys.exit(1) - - if args.gui is not None: - try: - import snakemake.gui as gui - except ImportError: - print( - "Error: GUI needs Flask to be installed. Install " - "with easy_install or contact your administrator.", - file=sys.stderr, - ) - sys.exit(1) +def parse_wait_for_files(args): + from snakemake.io import wait_for_files - _logging.getLogger("werkzeug").setLevel(_logging.ERROR) + aggregated_wait_for_files = args.wait_for_files + if args.wait_for_files_file is not None: + wait_for_files([args.wait_for_files_file], latency_wait=args.latency_wait) - _snakemake = partial(snakemake, os.path.abspath(snakefile)) - gui.register(_snakemake, args) + with open(args.wait_for_files_file) as fd: + extra_wait_files = [line.strip() for line in fd.readlines()] - if ":" in args.gui: - host, port = args.gui.split(":") + if aggregated_wait_for_files is None: + aggregated_wait_for_files = extra_wait_files else: - port = args.gui - host = "127.0.0.1" + aggregated_wait_for_files.update(extra_wait_files) + return aggregated_wait_for_files - url = f"http://{host}:{port}" - print(f"Listening on {url}.", file=sys.stderr) - def open_browser(): - try: - webbrowser.open(url) - except: - pass +def parse_rerun_triggers(values): + return {RerunTrigger[x] for x in values} - print("Open this address in your browser to access the GUI.", file=sys.stderr) - threading.Timer(0.5, open_browser).start() - success = True - try: - gui.app.run(debug=False, threaded=True, port=int(port), host=host) +def args_to_api(args, parser): + """Convert argparse args to API calls.""" - except (KeyboardInterrupt, SystemExit): - # silently close - pass - else: - log_handler = [] - if args.log_handler_script is not None: - if not os.path.exists(args.log_handler_script): - print( - "Error: no log handler script found, {}.".format( - args.log_handler_script - ), - file=sys.stderr, - ) - sys.exit(1) - log_script = SourceFileLoader("log", args.log_handler_script).load_module() - try: - log_handler.append(log_script.log_handler) - except: - print( - 'Error: Invalid log handler script, {}. 
Expect python function "log_handler(msg)".'.format( - args.log_handler_script - ), - file=sys.stderr, - ) - sys.exit(1) - - if args.log_service == "slack": - slack_logger = logging.SlackLogger() - log_handler.append(slack_logger.log_handler) - - elif args.wms_monitor or args.log_service == "wms": - # Generate additional metadata for server - metadata = generate_parser_metadata(parser, args) - wms_logger = logging.WMSLogger( - args.wms_monitor, args.wms_monitor_arg, metadata=metadata - ) - log_handler.append(wms_logger.log_handler) + if args.bash_completion: + cmd = b"complete -o bashdefault -C snakemake-bash-completion snakemake" + sys.stdout.buffer.write(cmd) + sys.exit(0) - if args.draft_notebook: - from snakemake import notebook + # handle legacy executor names + if args.dryrun: + args.executor = "dryrun" + elif args.touch: + args.executor = "touch" + elif args.executor is None: + args.executor = "local" + + executor_plugin = _get_executor_plugin_registry().get_plugin(args.executor) + executor_settings = executor_plugin.get_executor_settings(args) + + if args.cores is None: + if executor_plugin.common_settings.local_exec: + # use --jobs as an alias for --cores + args.cores = args.jobs + elif executor_plugin.common_settings.dryrun_exec: + args.cores = 1 + + # start profiler if requested + if args.runtime_profile: + import yappi - args.target = [args.draft_notebook] - args.edit_notebook = notebook.EditMode(draft_only=True) - elif args.edit_notebook: - from snakemake import notebook + yappi.start() - args.target = [args.edit_notebook] - args.force = True - args.edit_notebook = notebook.EditMode(args.notebook_listen) + log_handlers = setup_log_handlers(args, parser) - aggregated_wait_for_files = args.wait_for_files - if args.wait_for_files_file is not None: - wait_for_files([args.wait_for_files_file], latency_wait=args.latency_wait) + edit_notebook = parse_edit_notebook(args) - with open(args.wait_for_files_file) as fd: - extra_wait_files = [line.strip() for line in fd.readlines()] + wait_for_files = parse_wait_for_files(args) - if aggregated_wait_for_files is None: - aggregated_wait_for_files = extra_wait_files - else: - aggregated_wait_for_files.extend(extra_wait_files) - - success = snakemake( - snakefile, - batch=batch, - cache=args.cache, - report=args.report, - report_stylesheet=args.report_stylesheet, - lint=args.lint, - containerize=args.containerize, - generate_unit_tests=args.generate_unit_tests, - listrules=args.list, - list_target_rules=args.list_target_rules, - cores=args.cores, - local_cores=args.local_cores, - nodes=args.jobs, - resources=resources, - overwrite_threads=overwrite_threads, - max_threads=args.max_threads, - overwrite_scatter=overwrite_scatter, - default_resources=default_resources, - overwrite_resources=overwrite_resources, - overwrite_resource_scopes=overwrite_resource_scopes, - config=config, - configfiles=args.configfile, - config_args=args.config, - workdir=args.directory, - targets=args.target, - target_jobs=parse_target_jobs_cli_args(args), - dryrun=args.dryrun, + with SnakemakeApi( + OutputSettings( printshellcmds=args.printshellcmds, - debug_dag=args.debug_dag, - printdag=args.dag, - printrulegraph=args.rulegraph, - printfilegraph=args.filegraph, - printd3dag=args.d3dag, - touch=args.touch, - forcetargets=args.force, - forceall=args.forceall, - forcerun=args.forcerun, - prioritytargets=args.prioritize, - until=args.until, - omit_from=args.omit_from, - stats=args.stats, nocolor=args.nocolor, quiet=args.quiet, - keepgoing=args.keep_going, - 
slurm=args.slurm, - slurm_jobstep=args.slurm_jobstep, - rerun_triggers=args.rerun_triggers, - cluster=args.cluster, - cluster_sync=args.cluster_sync, - drmaa=args.drmaa, - drmaa_log_dir=args.drmaa_log_dir, - kubernetes=args.kubernetes, - container_image=args.container_image, - k8s_cpu_scalar=args.k8s_cpu_scalar, - k8s_service_account_name=args.k8s_service_account_name, - flux=args.flux, - tibanna=args.tibanna, - tibanna_sfn=args.tibanna_sfn, - az_batch=args.az_batch, - az_batch_enable_autoscale=args.az_batch_enable_autoscale, - az_batch_account_url=args.az_batch_account_url, - google_lifesciences=args.google_lifesciences, - google_lifesciences_regions=args.google_lifesciences_regions, - google_lifesciences_location=args.google_lifesciences_location, - google_lifesciences_cache=args.google_lifesciences_keep_cache, - google_lifesciences_service_account_email=args.google_lifesciences_service_account_email, - google_lifesciences_network=args.google_lifesciences_network, - google_lifesciences_subnetwork=args.google_lifesciences_subnetwork, - tes=args.tes, - precommand=args.precommand, - preemption_default=args.preemption_default, - preemptible_rules=args.preemptible_rules, - tibanna_config=args.tibanna_config, - jobname=args.jobname, - immediate_submit=args.immediate_submit, - standalone=True, - ignore_ambiguity=args.allow_ambiguity, - lock=not args.nolock, - unlock=args.unlock, - cleanup_metadata=args.cleanup_metadata, - conda_cleanup_envs=args.conda_cleanup_envs, - cleanup_containers=args.cleanup_containers, - cleanup_shadow=args.cleanup_shadow, - force_incomplete=args.rerun_incomplete, - ignore_incomplete=args.ignore_incomplete, - list_version_changes=args.list_version_changes, - list_code_changes=args.list_code_changes, - list_input_changes=args.list_input_changes, - list_params_changes=args.list_params_changes, - list_untracked=args.list_untracked, - summary=args.summary, - detailed_summary=args.detailed_summary, - archive=args.archive, - delete_all_output=args.delete_all_output, - delete_temp_output=args.delete_temp_output, - print_compilation=args.print_compilation, + debug_dag=args.debug_dag, verbose=args.verbose, - debug=args.debug, - jobscript=args.jobscript, - notemp=args.notemp, - all_temp=args.all_temp, - keep_remote_local=args.keep_remote, - greediness=args.greediness, - no_hooks=args.no_hooks, - overwrite_shellcmd=args.overwrite_shellcmd, - latency_wait=args.latency_wait, - wait_for_files=aggregated_wait_for_files, - keep_target_files=args.keep_target_files, - allowed_rules=args.allowed_rules, - max_jobs_per_second=args.max_jobs_per_second, - max_status_checks_per_second=args.max_status_checks_per_second, - restart_times=args.retries, - attempt=args.attempt, - force_use_threads=args.force_use_threads, - use_conda=args.use_conda, - conda_frontend=args.conda_frontend, - conda_prefix=args.conda_prefix, - conda_cleanup_pkgs=args.conda_cleanup_pkgs, - list_conda_envs=args.list_conda_envs, - use_singularity=args.use_singularity, - use_env_modules=args.use_envmodules, - singularity_prefix=args.singularity_prefix, - shadow_prefix=args.shadow_prefix, - singularity_args=args.singularity_args, - scheduler=args.scheduler, - scheduler_ilp_solver=args.scheduler_ilp_solver, - conda_create_envs_only=args.conda_create_envs_only, - mode=args.mode, - wrapper_prefix=args.wrapper_prefix, - default_remote_provider=args.default_remote_provider, - default_remote_prefix=args.default_remote_prefix, - assume_shared_fs=not args.no_shared_fs, - cluster_status=args.cluster_status, - 
cluster_cancel=args.cluster_cancel, - cluster_cancel_nargs=args.cluster_cancel_nargs, - cluster_sidecar=args.cluster_sidecar, - export_cwl=args.export_cwl, show_failed_logs=args.show_failed_logs, - keep_incomplete=args.keep_incomplete, - keep_metadata=not args.drop_metadata, - edit_notebook=args.edit_notebook, - envvars=args.envvars, - overwrite_groups=overwrite_groups, - group_components=group_components, - max_inventory_wait_time=args.max_inventory_time, - log_handler=log_handler, - execute_subworkflows=not args.no_subworkflows, - conda_not_block_search_path_envvars=args.conda_not_block_search_path_envvars, - scheduler_solver_path=args.scheduler_solver_path, - conda_base_path=args.conda_base_path, - local_groupid=args.local_groupid, - executor_args=executor_args, - cleanup_scripts=not args.skip_script_cleanup, + log_handlers=log_handlers, + keep_logger=False, ) + ) as snakemake_api: + try: + workflow_api = snakemake_api.workflow( + resource_settings=ResourceSettings( + cores=args.cores, + nodes=args.jobs, + local_cores=args.local_cores, + max_threads=args.max_threads, + resources=args.resources, + overwrite_threads=args.set_threads, + overwrite_scatter=args.set_scatter, + overwrite_resource_scopes=args.set_resource_scopes, + overwrite_resources=args.set_resources, + default_resources=args.default_resources, + ), + config_settings=ConfigSettings( + config=args.config, + configfiles=args.configfile, + ), + storage_settings=StorageSettings( + default_remote_provider=args.default_remote_provider, + default_remote_prefix=args.default_remote_prefix, + assume_shared_fs=not args.no_shared_fs, + keep_remote_local=args.keep_remote, + notemp=args.notemp, + all_temp=args.all_temp, + ), + workflow_settings=WorkflowSettings( + wrapper_prefix=args.wrapper_prefix, + ), + snakefile=args.snakefile, + workdir=args.directory, + ) + + if args.lint: + any_lint = workflow_api.lint() + if any_lint: + # trigger exit code 1 + return False + elif args.list_target_rules: + workflow_api.list_rules(only_targets=True) + elif args.list_rules: + workflow_api.list_rules(only_targets=False) + elif args.print_compilation: + workflow_api.print_compilation() + else: + deployment_method = args.software_deployment_method + if args.use_conda: + deployment_method.add(DeploymentMethod.CONDA) + if args.use_apptainer: + deployment_method.add(DeploymentMethod.APPTAINER) + if args.use_envmodules: + deployment_method.add(DeploymentMethod.ENV_MODULES) + + dag_api = workflow_api.dag( + dag_settings=DAGSettings( + targets=args.targets, + target_jobs=args.target_jobs, + target_files_omit_workdir_adjustment=args.target_files_omit_workdir_adjustment, + batch=args.batch, + forcetargets=args.force, + forceall=args.forceall, + forcerun=args.forcerun, + until=args.until, + omit_from=args.omit_from, + force_incomplete=args.rerun_incomplete, + allowed_rules=args.allowed_rules, + rerun_triggers=args.rerun_triggers, + max_inventory_wait_time=args.max_inventory_time, + cache=args.cache, + ), + deployment_settings=DeploymentSettings( + deployment_method=deployment_method, + conda_prefix=args.conda_prefix, + conda_cleanup_pkgs=args.conda_cleanup_pkgs, + conda_base_path=args.conda_base_path, + conda_frontend=args.conda_frontend, + conda_not_block_search_path_envvars=args.conda_not_block_search_path_envvars, + apptainer_args=args.apptainer_args, + apptainer_prefix=args.apptainer_prefix, + ), + ) + if args.preemptible_rules is not None: + if not preemptible_rules: + # no specific rule given, consider all to be made preemptible + preemptible_rules = 
PreemptibleRules(all=True) + else: + preemptible_rules = PreemptibleRules( + rules=args.preemptible_rules + ) + else: + preemptible_rules = PreemptibleRules() + + if args.containerize: + dag_api.containerize() + elif args.report: + dag_api.create_report( + path=args.report, + stylesheet=args.report_stylesheet, + ) + elif args.generate_unit_tests: + dag_api.generate_unit_tests(args.generate_unit_tests) + elif args.dag: + dag_api.printdag() + elif args.rulegraph: + dag_api.printrulegraph() + elif args.filegraph: + dag_api.printfilegraph() + elif args.d3dag: + dag_api.printd3dag() + elif args.unlock: + dag_api.unlock() + elif args.cleanup_metadata: + dag_api.cleanup_metadata(args.cleanup_metadata) + elif args.conda_cleanup_envs: + dag_api.conda_cleanup_envs() + elif args.conda_create_envs_only: + dag_api.conda_create_envs() + elif args.list_conda_envs: + dag_api.conda_list_envs() + elif args.cleanup_shadow: + dag_api.cleanup_shadow() + elif args.container_cleanup_images: + dag_api.container_cleanup_images() + elif args.list_changes: + dag_api.list_changes(args.list_changes) + elif args.list_untracked: + dag_api.list_untracked() + elif args.summary: + dag_api.summary() + elif args.detailed_summary: + dag_api.summary(detailed=True) + elif args.archive: + dag_api.archive(args.archive) + elif args.delete_all_output: + dag_api.delete_output() + elif args.delete_temp_output: + dag_api.delete_output(only_temp=True, dryrun=args.dryrun) + else: + dag_api.execute_workflow( + executor=args.executor, + execution_settings=ExecutionSettings( + keep_going=args.keep_going, + debug=args.debug, + standalone=True, + ignore_ambiguity=args.allow_ambiguity, + lock=not args.nolock, + ignore_incomplete=args.ignore_incomplete, + latency_wait=args.latency_wait, + wait_for_files=wait_for_files, + no_hooks=args.no_hooks, + retries=args.retries, + attempt=args.attempt, + use_threads=args.force_use_threads, + shadow_prefix=args.shadow_prefix, + mode=args.mode, + keep_incomplete=args.keep_incomplete, + keep_metadata=not args.drop_metadata, + edit_notebook=edit_notebook, + cleanup_scripts=not args.skip_script_cleanup, + ), + remote_execution_settings=RemoteExecutionSettings( + jobname=args.jobname, + jobscript=args.jobscript, + max_status_checks_per_second=args.max_status_checks_per_second, + seconds_between_status_checks=args.seconds_between_status_checks, + container_image=args.container_image, + preemptible_retries=args.preemptible_retries, + preemptible_rules=preemptible_rules, + envvars=args.envvars, + immediate_submit=args.immediate_submit, + ), + scheduling_settings=SchedulingSettings( + prioritytargets=args.prioritize, + scheduler=args.scheduler, + ilp_solver=args.scheduler_ilp_solver, + solver_path=args.scheduler_solver_path, + greediness=args.scheduler_greediness, + max_jobs_per_second=args.max_jobs_per_second, + ), + executor_settings=executor_settings, + ) + + except Exception as e: + snakemake_api.print_exception(e) + return False + + # store profiler results if requested if args.runtime_profile: with open(args.runtime_profile, "w") as out: profile = yappi.get_func_stats() @@ -2438,7 +2038,18 @@ def open_browser(): 4: ("tavg", 8), }, ) + return True + +def main(argv=None): + """Main entry point.""" + logging.setup_logger() + try: + parser, args = parse_args(argv) + success = args_to_api(args, parser) + except Exception as e: + print_exception(e) + sys.exit(1) sys.exit(0 if success else 1) diff --git a/snakemake/common/__init__.py b/snakemake/common/__init__.py index b6be147e8..4f79a0c61 100644 --- 
a/snakemake/common/__init__.py +++ b/snakemake/common/__init__.py @@ -3,8 +3,7 @@ __email__ = "johannes.koester@protonmail.com" __license__ = "MIT" -import concurrent.futures -import contextlib +from enum import Enum import itertools import math import operator @@ -32,17 +31,27 @@ NOTHING_TO_BE_DONE_MSG = ( "Nothing to be done (all requested files are present and up to date)." ) -RERUN_TRIGGERS = ["mtime", "params", "input", "software-env", "code"] ON_WINDOWS = platform.system() == "Windows" # limit the number of input/output files list in job properties # see https://github.com/snakemake/snakemake/issues/2097 IO_PROP_LIMIT = 100 +SNAKEFILE_CHOICES = list( + map( + Path, + ( + "Snakefile", + "snakefile", + "workflow/Snakefile", + "workflow/snakefile", + ), + ) +) def get_snakemake_searchpaths(): paths = [str(Path(__file__).parent.parent.parent)] + [ - path for path in sys.path if path.endswith("site-packages") + path for path in sys.path if os.path.isdir(path) ] return list(unique_justseen(paths)) @@ -56,6 +65,7 @@ def parse_key_value_arg(arg, errmsg): key, val = arg.split("=", 1) except ValueError: raise ValueError(errmsg + f" (Unparseable value: {repr(arg)})") + val = val.strip("'\"") return key, val @@ -267,25 +277,6 @@ def get_input_function_aux_params(func, candidate_params): return {k: v for k, v in candidate_params.items() if k in func_params} -_pool = concurrent.futures.ThreadPoolExecutor() - - -@contextlib.asynccontextmanager -async def async_lock(_lock: threading.Lock): - """Use a threaded lock form threading.Lock in an async context - - Necessary because asycio.Lock is not threadsafe, so only one thread can safely use - it at a time. - Source: https://stackoverflow.com/a/63425191 - """ - loop = asyncio.get_event_loop() - await loop.run_in_executor(_pool, _lock.acquire) - try: - yield # the lock is held - finally: - _lock.release() - - def unique_justseen(iterable, key=None): """ List unique elements, preserving order. Remember only the element just seen. diff --git a/snakemake/common/argparse.py b/snakemake/common/argparse.py new file mode 100644 index 000000000..5b8ee1735 --- /dev/null +++ b/snakemake/common/argparse.py @@ -0,0 +1,57 @@ +import argparse + +import configargparse + + +class ArgumentParser(configargparse.ArgumentParser): + def add_argument( + self, + *args, + parse_func=None, + **kwargs, + ): + if parse_func is not None: + register_parser_action(parse_func, kwargs) + super().add_argument(*args, **kwargs) + + def add_argument_group(self, *args, **kwargs): + group = ArgumentGroup(self, *args, **kwargs) + self._action_groups.append(group) + return group + + +class ArgumentGroup(argparse._ArgumentGroup): + def add_argument( + self, + *args, + parse_func=None, + **kwargs, + ): + if parse_func is not None: + register_parser_action(parse_func, kwargs) + super().add_argument(*args, **kwargs) + + +def register_parser_action(parse_func, kwargs): + if "action" in kwargs: + raise ValueError( + "Cannot specify action if parser argument is provided to add_argument." 
+ ) + + class ParserAction(argparse._StoreAction): + def __init__(self, *args, **kwargs): + if "parser" in kwargs: + del kwargs["parse_func"] + super().__init__(*args, **kwargs) + + def __call__( + self, + parser, + namespace, + values, + option_string=None, + ): + parsed = parse_func(values) + setattr(namespace, self.dest, parsed) + + kwargs["action"] = ParserAction diff --git a/snakemake/common/configfile.py b/snakemake/common/configfile.py new file mode 100644 index 000000000..3b2fd7664 --- /dev/null +++ b/snakemake/common/configfile.py @@ -0,0 +1,43 @@ +import collections +import json +from pathlib import Path +from snakemake_interface_common.exceptions import WorkflowError + + +def _load_configfile(configpath_or_obj, filetype="Config"): + "Tries to load a configfile first as JSON, then as YAML, into a dict." + import yaml + + if isinstance(configpath_or_obj, str) or isinstance(configpath_or_obj, Path): + obj = open(configpath_or_obj, encoding="utf-8") + else: + obj = configpath_or_obj + + try: + with obj as f: + try: + return json.load(f, object_pairs_hook=collections.OrderedDict) + except ValueError: + f.seek(0) # try again + try: + import yte + + return yte.process_yaml(f, require_use_yte=True) + except yaml.YAMLError: + raise WorkflowError( + f"{filetype} file is not valid JSON or YAML. " + "In case of YAML, make sure to not mix " + "whitespace and tab indentation." + ) + except FileNotFoundError: + raise WorkflowError(f"{filetype} file {configpath_or_obj} not found.") + + +def load_configfile(configpath): + "Loads a JSON or YAML configfile as a dict, then checks that it's a dict." + config = _load_configfile(configpath) + if not isinstance(config, dict): + raise WorkflowError( + "Config file must be given as JSON or YAML with keys at top level." + ) + return config diff --git a/snakemake/common/git.py b/snakemake/common/git.py new file mode 100644 index 000000000..1e0ce6499 --- /dev/null +++ b/snakemake/common/git.py @@ -0,0 +1,85 @@ +import os +import re + +from snakemake.exceptions import WorkflowError + + +def split_git_path(path): + file_sub = re.sub(r"^git\+file:/+", "/", path) + (file_path, version) = file_sub.split("@") + file_path = os.path.realpath(file_path) + root_path = get_git_root(file_path) + if file_path.startswith(root_path): + file_path = file_path[len(root_path) :].lstrip("/") + return (root_path, file_path, version) + + +def get_git_root(path): + """ + Args: + path: (str) Path a to a directory/file that is located inside the repo + Returns: + path to the root folder for git repo + """ + import git + + try: + git_repo = git.Repo(path, search_parent_directories=True) + return git_repo.git.rev_parse("--show-toplevel") + except git.exc.NoSuchPathError: + tail, _ = os.path.split(path) + return get_git_root_parent_directory(tail, path) + + +def get_git_root_parent_directory(path, input_path): + """ + This function will recursively go through parent directories until a git + repository is found or until no parent directories are left, in which case + an error will be raised. This is needed when providing a path to a + file/folder that is located on a branch/tag not currently checked out. 
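The split_git_path() helper added above strips the git+file prefix with a regex and separates the version tag at the '@' sign before resolving the repository root. A simplified, self-contained illustration of just that split, with a made-up address (the gitpython-based root lookup that the real function performs afterwards is omitted):

# Simplified sketch of the address split performed by split_git_path() above.
# The address is hypothetical; the repository-root resolution via gitpython is
# intentionally left out, so this only shows the regex + '@' handling.
import re

address = "git+file:///home/user/snakemake-wrappers/bio/fastqc/wrapper.py@v1.0.0"
file_sub = re.sub(r"^git\+file:/+", "/", address)   # drop the git+file:// scheme
file_path, version = file_sub.split("@")            # exactly one '@' is expected
assert file_path == "/home/user/snakemake-wrappers/bio/fastqc/wrapper.py"
assert version == "v1.0.0"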
+ + Args: + path: (str) Path a to a directory that is located inside the repo + input_path: (str) origin path, used when raising WorkflowError + Returns: + path to the root folder for git repo + """ + import git + + try: + git_repo = git.Repo(path, search_parent_directories=True) + return git_repo.git.rev_parse("--show-toplevel") + except git.exc.NoSuchPathError: + tail, _ = os.path.split(path) + if tail is None: + raise WorkflowError( + f"Neither provided git path ({input_path}) " + + "or parent directories contain a valid git repo." + ) + else: + return get_git_root_parent_directory(tail, input_path) + + +def git_content(git_file): + """ + This function will extract a file from a git repository, one located on + the filesystem. + The expected format is git+file:///path/to/your/repo/path_to_file@version + + Args: + env_file (str): consist of path to repo, @, version, and file information + Ex: git+file:///home/smeds/snakemake-wrappers/bio/fastqc/wrapper.py@0.19.3 + Returns: + file content or None if the expected format isn't meet + """ + import git + + if git_file.startswith("git+file:"): + (root_path, file_path, version) = split_git_path(git_file) + return git.Repo(root_path).git.show(f"{version}:{file_path}") + else: + raise WorkflowError( + "Provided git path ({}) doesn't meet the " + "expected format:".format(git_file) + ", expected format is " + "git+file://PATH_TO_REPO/PATH_TO_FILE_INSIDE_REPO@VERSION" + ) diff --git a/snakemake/common/tests/__init__.py b/snakemake/common/tests/__init__.py new file mode 100644 index 000000000..aab6983d4 --- /dev/null +++ b/snakemake/common/tests/__init__.py @@ -0,0 +1,122 @@ +from abc import ABC, abstractmethod +from pathlib import Path +import shutil +from typing import List, Optional + +import pytest +from snakemake import api, settings + +from snakemake_interface_executor_plugins import ExecutorSettingsBase +from snakemake_interface_executor_plugins.registry import ExecutorPluginRegistry + + +def handle_testcase(func): + def wrapper(self, tmp_path): + if self.expect_exception is None: + try: + return func(self, tmp_path) + finally: + self.cleanup_test() + else: + with pytest.raises(self.expect_exception): + try: + return func(self, tmp_path) + finally: + self.cleanup_test() + + return wrapper + + +class TestWorkflowsBase(ABC): + __test__ = False + expect_exception = None + + @abstractmethod + def get_executor(self) -> str: + ... + + @abstractmethod + def get_executor_settings(self) -> Optional[ExecutorSettingsBase]: + ... + + @abstractmethod + def get_default_remote_provider(self) -> Optional[str]: + ... + + @abstractmethod + def get_default_remote_prefix(self) -> Optional[str]: + ... + + def get_assume_shared_fs(self) -> bool: + return True + + def get_envvars(self) -> List[str]: + return [] + + def cleanup_test(self): + """This method is called after every testcase, also in case of exceptions. + + Override to clean up any test files (e.g. in remote storage). 
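TestWorkflowsBase above is meant to be subclassed by executor plugins so that they inherit the shared simple/groups test workflows. A hedged sketch of such a subclass for the built-in local executor, to be collected by pytest (the class name and the chosen return values are illustrative, not part of this PR):

# Hypothetical pytest module in an executor plugin's test suite, reusing the
# shared TestWorkflowsBase defined above. Names and return values are
# illustrative assumptions.
from typing import Optional

from snakemake.common.tests import TestWorkflowsBase
from snakemake_interface_executor_plugins import ExecutorSettingsBase


class TestLocalWorkflows(TestWorkflowsBase):
    __test__ = True  # re-enable collection; the base class sets __test__ = False

    def get_executor(self) -> str:
        return "local"

    def get_executor_settings(self) -> Optional[ExecutorSettingsBase]:
        return None  # assume the local executor needs no plugin-specific settings

    def get_default_remote_provider(self) -> Optional[str]:
        return None  # run against the local filesystem

    def get_default_remote_prefix(self) -> Optional[str]:
        return None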
+ """ + pass + + def _run_workflow(self, test_name, tmp_path, deployment_method=frozenset()): + test_path = Path(__file__).parent / "testcases" / test_name + tmp_path = Path(tmp_path) / test_name + self._copy_test_files(test_path, tmp_path) + + if self._common_settings().local_exec: + cores = 3 + nodes = None + else: + cores = 1 + nodes = 3 + + snakemake_api = api.SnakemakeApi( + settings.OutputSettings( + verbose=True, + ), + ) + workflow_api = snakemake_api.workflow( + resource_settings=settings.ResourceSettings( + cores=cores, + nodes=nodes, + ), + storage_settings=settings.StorageSettings( + default_remote_provider=self.get_default_remote_provider(), + default_remote_prefix=self.get_default_remote_prefix(), + assume_shared_fs=self.get_assume_shared_fs(), + ), + workdir=Path(tmp_path), + snakefile=test_path / "Snakefile", + ) + + dag_api = workflow_api.dag( + deployment_settings=settings.DeploymentSettings( + deployment_method=deployment_method, + ), + ) + dag_api.execute_workflow( + executor=self.get_executor(), + executor_settings=self.get_executor_settings(), + remote_execution_settings=settings.RemoteExecutionSettings( + seconds_between_status_checks=0, + envvars=self.get_envvars(), + ), + ) + snakemake_api.cleanup() + + @handle_testcase + def test_simple_workflow(self, tmp_path): + self._run_workflow("simple", tmp_path) + + @handle_testcase + def test_group_workflow(self, tmp_path): + self._run_workflow("groups", tmp_path) + + def _copy_test_files(self, test_path, tmp_path): + shutil.copytree(test_path, tmp_path) + + def _common_settings(self): + registry = ExecutorPluginRegistry() + return registry.get_plugin(self.get_executor()).common_settings diff --git a/snakemake/executors/slurm/__init__.py b/snakemake/common/tests/testcases/__init__.py similarity index 100% rename from snakemake/executors/slurm/__init__.py rename to snakemake/common/tests/testcases/__init__.py diff --git a/snakemake/common/tests/testcases/groups/Snakefile b/snakemake/common/tests/testcases/groups/Snakefile new file mode 100644 index 000000000..eb8eaf393 --- /dev/null +++ b/snakemake/common/tests/testcases/groups/Snakefile @@ -0,0 +1,32 @@ +rule all: + input: + "test3.out" + +rule a: + output: + "test1.{sample}.out" + group: "foo" + shell: + "touch {output}" + + +rule b: + input: + "test1.{sample}.out" + output: + "test2.{sample}.out" + group: "foo" + threads: 2 + shell: + "cp {input} {output}" + + +rule c: + input: + expand("test2.{sample}.out", sample=[1, 2, 3]) + output: + "test3.out" + resources: + mem="5MB" + shell: + "cat {input} > {output}" \ No newline at end of file diff --git a/tests/test14/raw.0.txt b/snakemake/common/tests/testcases/groups/__init__.py similarity index 100% rename from tests/test14/raw.0.txt rename to snakemake/common/tests/testcases/groups/__init__.py diff --git a/snakemake/common/tests/testcases/simple/Snakefile b/snakemake/common/tests/testcases/simple/Snakefile new file mode 100644 index 000000000..4f2143f33 --- /dev/null +++ b/snakemake/common/tests/testcases/simple/Snakefile @@ -0,0 +1,30 @@ +rule all: + input: + "test3.out" + +rule a: + output: + "test1.out" + shell: + "touch {output}" + + +rule b: + input: + "test1.out" + output: + "test2.out" + threads: 2 + shell: + "cp {input} {output}" + + +rule c: + input: + "test2.out" + output: + "test3.out" + resources: + mem="5MB" + shell: + "cp {input} {output}" \ No newline at end of file diff --git a/tests/test14/raw.1.txt b/snakemake/common/tests/testcases/simple/__init__.py similarity index 100% rename from 
tests/test14/raw.1.txt rename to snakemake/common/tests/testcases/simple/__init__.py diff --git a/snakemake/common/workdir_handler.py b/snakemake/common/workdir_handler.py new file mode 100644 index 000000000..d05bec254 --- /dev/null +++ b/snakemake/common/workdir_handler.py @@ -0,0 +1,24 @@ +from dataclasses import dataclass, field +import os +from pathlib import Path +from typing import Optional + +from snakemake.logging import logger + + +@dataclass +class WorkdirHandler: + workdir: Optional[Path] = None + olddir: Optional[Path] = field(init=False, default=None) + + def change_to(self): + if self.workdir is not None: + self.olddir = Path.cwd() + if not self.workdir.exists(): + logger.info(f"Creating specified working directory {self.workdir}.") + self.workdir.mkdir(parents=True) + os.chdir(self.workdir) + + def change_back(self): + if self.workdir is not None: + os.chdir(self.olddir) diff --git a/snakemake/cwl.py b/snakemake/cwl.py index 0389e9b37..1cc1f0cdd 100644 --- a/snakemake/cwl.py +++ b/snakemake/cwl.py @@ -14,6 +14,7 @@ from snakemake.exceptions import WorkflowError from snakemake.shell import shell from snakemake.common import get_container_image +from snakemake_interface_executor_plugins.settings import ExecMode def cwl( @@ -199,18 +200,18 @@ def dag_to_cwl(dag): "requirements": {"ResourceRequirement": {"coresMin": "$(inputs.cores)"}}, "arguments": [ "--force", - "--keep-target-files", + "--target-files-omit-workdir-adjustment", "--keep-remote", "--force-use-threads", "--wrapper-prefix", - dag.workflow.wrapper_prefix, + dag.workflow.workflow_settings.wrapper_prefix, "--notemp", "--quiet", "--use-conda", "--no-hooks", "--nolock", "--mode", - str(Mode.subprocess), + str(ExecMode.SUBPROCESS.item_to_choice()), ], "inputs": { "snakefile": { diff --git a/snakemake/dag.py b/snakemake/dag.py index 5e9dbb123..977e306da 100755 --- a/snakemake/dag.py +++ b/snakemake/dag.py @@ -7,7 +7,6 @@ import os import shutil import subprocess -import sys import tarfile import textwrap import time @@ -18,12 +17,14 @@ from itertools import chain, filterfalse, groupby from operator import attrgetter from pathlib import Path +from snakemake.settings import DeploymentMethod from snakemake_interface_executor_plugins.dag import DAGExecutorInterface from snakemake import workflow from snakemake import workflow as _workflow from snakemake.common import DYNAMIC_FILL, ON_WINDOWS, group_into_chunks, is_local_file +from snakemake.settings import RerunTrigger from snakemake.deployment import singularity from snakemake.exceptions import ( AmbiguousRuleException, @@ -51,61 +52,11 @@ from snakemake.logging import logger from snakemake.output_index import OutputIndex from snakemake.sourcecache import LocalSourceFile, SourceFile +from snakemake.settings import ChangeType, Batch PotentialDependency = namedtuple("PotentialDependency", ["file", "jobs", "known"]) -class Batch: - """Definition of a batch for calculating only a partial DAG.""" - - def __init__(self, rulename: str, idx: int, batches: int): - assert idx <= batches - assert idx > 0 - self.rulename = rulename - self.idx = idx - self.batches = batches - - def get_batch(self, items: list): - """Return the defined batch of the given items. - Items are usually input files.""" - # make sure that we always consider items in the same order - if len(items) < self.batches: - raise WorkflowError( - "Batching rule {} has less input files than batches. 
" - "Please choose a smaller number of batches.".format(self.rulename) - ) - items = sorted(items) - - # we can equally split items using divmod: - # len(items) = (self.batches * quotient) + remainder - # Because remainder always < divisor (self.batches), - # each batch will be equal to quotient + (1 or 0 item) - # from the remainder - k, m = divmod(len(items), self.batches) - - # self.batch is one-based, hence we have to subtract 1 - idx = self.idx - 1 - - # First n batches will have k (quotient) items + - # one item from the remainder (m). Once we consume all items - # from the remainder, last batches only contain k items. - i = idx * k + min(idx, m) - batch_len = (idx + 1) * k + min(idx + 1, m) - - if self.is_final: - # extend the last batch to cover rest of list - return items[i:] - else: - return items[i:batch_len] - - @property - def is_final(self): - return self.idx == self.batches - - def __str__(self): - return f"{self.idx}/{self.batches} (rule {self.rulename})" - - class DAG(DAGExecutorInterface): """Directed acyclic graph of jobs.""" @@ -113,10 +64,8 @@ def __init__( self, workflow, rules=None, - dryrun=False, targetfiles=None, targetrules=None, - target_jobs_def=None, forceall=False, forcerules=None, forcefiles=None, @@ -126,14 +75,8 @@ def __init__( untilrules=None, omitfiles=None, omitrules=None, - ignore_ambiguity=False, - force_incomplete=False, ignore_incomplete=False, - notemp=False, - keep_remote_local=False, - batch=None, ): - self.dryrun = dryrun self.dependencies = defaultdict(partial(defaultdict, set)) self.depending = defaultdict(partial(defaultdict, set)) self._needrun = set() @@ -144,20 +87,16 @@ def __init__( self._len = 0 self.workflow: _workflow.Workflow = workflow self.rules = set(rules) - self.ignore_ambiguity = ignore_ambiguity self.targetfiles = targetfiles self.targetrules = targetrules - self.target_jobs_def = target_jobs_def - self.target_jobs_rules = ( - {spec.rulename for spec in target_jobs_def} if target_jobs_def else set() - ) + self.target_jobs_rules = { + spec.rulename for spec in self.workflow.dag_settings.target_jobs + } self.priorityfiles = priorityfiles self.priorityrules = priorityrules self.targetjobs = set() self.prioritytargetjobs = set() self._ready_jobs = set() - self.notemp = notemp - self.keep_remote_local = keep_remote_local self._jobid = dict() self.job_cache = dict() self.conda_envs = dict() @@ -176,7 +115,6 @@ def __init__( self.untilfiles = set() self.omitrules = set() self.omitfiles = set() - self.updated_subworkflow_files = set() if forceall: self.forcerules.update(self.rules) elif forcerules: @@ -196,20 +134,22 @@ def __init__( self.omitforce = set() - self.batch = batch - if batch is not None and not batch.is_final: + if self.batch is not None and not self.batch.is_final: # Since not all input files of a batching rule are considered, we cannot run # beyond that rule. # For the final batch, we do not need to omit anything. 
- self.omitrules.add(batch.rulename) + self.omitrules.add(self.batch.rulename) - self.force_incomplete = force_incomplete self.ignore_incomplete = ignore_incomplete self.periodic_wildcard_detector = PeriodicityDetector() self.update_output_index() + @property + def batch(self): + return self.workflow.dag_settings.batch + def init(self, progress=False): """Initialise the DAG.""" for job in map(self.rule2job, self.targetrules): @@ -225,19 +165,18 @@ def init(self, progress=False): ) self.targetjobs.add(job) - if self.target_jobs_def: - for spec in self.target_jobs_def: - job = self.update( - [ - self.new_job( - self.workflow.get_rule(spec.rulename), - wildcards_dict=spec.wildcards_dict, - ) - ], - progress=progress, - create_inventory=True, - ) - self.targetjobs.add(job) + for spec in self.workflow.dag_settings.target_jobs: + job = self.update( + [ + self.new_job( + self.workflow.get_rule(spec.rulename), + wildcards_dict=spec.wildcards_dict, + ) + ], + progress=progress, + create_inventory=True, + ) + self.targetjobs.add(job) self.cleanup() @@ -334,13 +273,17 @@ def update_conda_envs(self): env_set = { (job.conda_env_spec, job.container_img_url) for job in self.jobs - if job.conda_env_spec and (self.workflow.assume_shared_fs or job.is_local) + if job.conda_env_spec + and (self.workflow.storage_settings.assume_shared_fs or job.is_local) } # Then based on md5sum values for env_spec, simg_url in env_set: simg = None - if simg_url and self.workflow.use_singularity: + if simg_url and ( + DeploymentMethod.APPTAINER + in self.workflow.deployment_settings.deployment_method + ): assert ( simg_url in self.container_imgs ), "bug: must first pull singularity images" @@ -350,14 +293,15 @@ def update_conda_envs(self): env = env_spec.get_conda_env( self.workflow, container_img=simg, - cleanup=self.workflow.conda_cleanup_pkgs, + cleanup=self.workflow.deployment_settings.conda_cleanup_pkgs, ) self.conda_envs[key] = env def create_conda_envs(self, dryrun=False, quiet=False): + dryrun |= self.workflow.dryrun for env in self.conda_envs.values(): if (not dryrun or not quiet) and not env.is_named: - env.create(dryrun) + env.create(self.workflow.dryrun) def update_container_imgs(self): # First deduplicate based on job.conda_env_spec @@ -372,10 +316,10 @@ def update_container_imgs(self): img = singularity.Image(img_url, self, is_containerized) self.container_imgs[img_url] = img - def pull_container_imgs(self, dryrun=False, quiet=False): + def pull_container_imgs(self, quiet=False): for img in self.container_imgs.values(): - if not dryrun or not quiet: - img.pull(dryrun) + if not self.workflow.dryrun or not quiet: + img.pull(self.workflow.dryrun) def update_output_index(self): """Update the OutputIndex.""" @@ -387,7 +331,7 @@ def check_incomplete(self): if not self.ignore_incomplete: incomplete = self.incomplete_files if incomplete: - if self.force_incomplete: + if self.workflow.dag_settings.force_incomplete: logger.debug("Forcing incomplete files:") logger.debug("\t" + "\n\t".join(incomplete)) self.forcefiles.update(incomplete) @@ -400,7 +344,7 @@ def incomplete_external_jobid(self, job): Returns None, if job is not incomplete, or if no external jobid has been registered or if force_incomplete is True. 
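The Batch class removed from dag.py in the hunk above is now imported from snakemake.settings, and the inline comments describe its divmod-based split of a rule's sorted input files. A standalone sketch of that partitioning, with made-up file names, for readers who want to check the arithmetic:

# Minimal sketch of the divmod-based batching described above (illustrative
# only; the real implementation is the Batch class now living in
# snakemake.settings).
def get_batch(items, idx, batches):
    """Return the idx-th (1-based) of `batches` near-equal slices of sorted(items)."""
    assert 0 < idx <= batches
    items = sorted(items)
    # len(items) = batches * k + m, so the first m batches receive one extra item.
    k, m = divmod(len(items), batches)
    i = idx - 1  # zero-based batch index
    start = i * k + min(i, m)
    end = (i + 1) * k + min(i + 1, m)
    # The final batch is extended to cover the rest of the list.
    return items[start:] if idx == batches else items[start:end]


# Example: 7 inputs split into 3 batches -> sizes 3, 2, 2.
files = [f"sample{n}.txt" for n in range(7)]
assert get_batch(files, 1, 3) == sorted(files)[:3]
assert get_batch(files, 3, 3) == sorted(files)[5:]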
""" - if self.force_incomplete: + if self.workflow.dag_settings.force_incomplete: return None jobids = self.workflow.persistence.external_jobids(job) if len(jobids) == 1: @@ -424,7 +368,10 @@ def check_dynamic(self): self.postprocess() def is_edit_notebook_job(self, job): - return self.workflow.edit_notebook and job.targetfile in self.targetfiles + return ( + self.workflow.execution_settings.edit_notebook + and job.targetfile in self.targetfiles + ) def get_job_group(self, job): return self._group.get(job) @@ -717,7 +664,7 @@ def temp_size(self, job): def handle_temp(self, job): """Remove temp files if they are no longer needed. Update temp_mtimes.""" - if self.notemp: + if self.workflow.storage_settings.notemp: return if job.is_group(): @@ -757,7 +704,7 @@ def unneeded_files(): yield from filterfalse(partial(needed, job), tempfiles) for f in unneeded_files(): - if self.dryrun: + if self.workflow.dryrun: logger.info(f"Would remove temporary output {f}") else: logger.info(f"Removing temporary output {f}.") @@ -795,7 +742,7 @@ def handle_remote(self, job, upload=True): "read AND write permissions." ) - if not self.keep_remote_local: + if not self.workflow.storage_settings.keep_remote_local: if not any(f.is_remote for f in job.input): return @@ -937,16 +884,17 @@ def is_strictly_higher_ordered(pivot_job): ambiguities = list( filter(lambda x: not x < producer and not producer < x, producers[1:]) ) - if ambiguities and not self.ignore_ambiguity: + if ambiguities and not self.workflow.execution_settings.ignore_ambiguity: raise AmbiguousRuleException(file, producer, ambiguities[0]) logger.dag_debug(dict(status="selected", job=producer)) - logger.dag_debug( - dict( - file=file, - msg="Producer found, hence exceptions are ignored.", - exception=WorkflowError(*exceptions), + if exceptions: + logger.dag_debug( + dict( + file=file, + msg="Producer found, hence exceptions are ignored.", + exception=WorkflowError(*exceptions), + ) ) - ) return producer def update_( @@ -1099,9 +1047,6 @@ def is_same_checksum(f, job): def update_needrun(job): reason = self.reason(job) noinitreason = not reason - updated_subworkflow_input = self.updated_subworkflow_files.intersection( - job.input - ) if ( job not in self.omitforce @@ -1109,8 +1054,6 @@ def update_needrun(job): or not self.forcefiles.isdisjoint(job.output) ): reason.forced = True - elif updated_subworkflow_input: - reason.updated_input.update(updated_subworkflow_input) elif job in self.targetjobs: # TODO find a way to handle added/removed input files here? if not job.has_products(include_logfiles=False): @@ -1127,7 +1070,7 @@ def update_needrun(job): if job.rule in self.targetrules: files = set(job.products(include_logfiles=False)) elif ( - self.target_jobs_def is not None + self.workflow.dag_settings.target_jobs and job.rule.name in self.target_jobs_rules ): files = set(job.products(include_logfiles=False)) @@ -1164,19 +1107,19 @@ def update_needrun(job): # The first pass (with depends_on_checkpoint_target == True) is not informative # for determining any other changes than file modification dates, as it will # change after evaluating the input function of the job in the second pass. 
- if "params" in self.workflow.rerun_triggers: + if RerunTrigger.PARAMS in self.workflow.rerun_triggers: reason.params_changed = any( self.workflow.persistence.params_changed(job) ) - if "input" in self.workflow.rerun_triggers: + if RerunTrigger.INPUT in self.workflow.rerun_triggers: reason.input_changed = any( self.workflow.persistence.input_changed(job) ) - if "code" in self.workflow.rerun_triggers: + if RerunTrigger.CODE in self.workflow.rerun_triggers: reason.code_changed = any( job.outputs_older_than_script_or_notebook() ) or any(self.workflow.persistence.code_changed(job)) - if "software-env" in self.workflow.rerun_triggers: + if RerunTrigger.SOFTWARE_ENV in self.workflow.rerun_triggers: reason.software_stack_changed = any( self.workflow.persistence.conda_env_changed(job) ) or any(self.workflow.persistence.container_changed(job)) @@ -1351,7 +1294,7 @@ def _update_group_components(self): for group in self._group.values(): groups_by_id[group.groupid].add(group) for groupid, conn_components in groups_by_id.items(): - n_components = self.workflow.group_components.get(groupid, 1) + n_components = self.workflow.group_settings.group_components.get(groupid, 1) if n_components > 1: for chunk in group_into_chunks(n_components, conn_components): if len(chunk) > 1: @@ -1674,9 +1617,15 @@ def finish(self, job, update_dynamic=True): if updated_dag: # We might have new jobs, so we need to ensure that all conda envs # and singularity images are set up. - if self.workflow.use_singularity: + if ( + DeploymentMethod.APPTAINER + in self.workflow.deployment_settings.deployment_method + ): self.pull_container_imgs() - if self.workflow.use_conda: + if ( + DeploymentMethod.CONDA + in self.workflow.deployment_settings.deployment_method + ): self.create_conda_envs() potential_new_ready_jobs = True @@ -1869,10 +1818,6 @@ def collect_potential_dependencies(self, job, known_producers): input_files = input_batch for file in input_files: - # omit the file if it comes from a subworkflow - if file in job.subworkflow_input: - continue - try: yield PotentialDependency(file, known_producers[file], True) except KeyError: @@ -2306,27 +2251,27 @@ def summary(self, detailed=False): else: yield "\t".join((f, date, rule, version, log, status, pending)) - def archive(self, path): + def archive(self, path: Path): """Archives workflow such that it can be re-run on a different system. Archiving includes git versioned files (i.e. Snakefiles, config files, ...), ancestral input files and conda environments. 
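Several dag.py hunks here replace boolean flags such as use_conda and use_singularity with membership tests on deployment_settings.deployment_method. A small hedged sketch of the flag-to-set translation, mirroring what args_to_api() does earlier in this patch (the boolean values are made up):

# Sketch of the flag -> set translation used by args_to_api() above.
# DeploymentMethod comes from the new snakemake.settings module, exactly as
# dag.py now imports it; the three booleans are made-up stand-ins for CLI flags.
from snakemake.settings import DeploymentMethod

use_conda, use_apptainer, use_envmodules = True, False, True

deployment_method = set()
if use_conda:
    deployment_method.add(DeploymentMethod.CONDA)
if use_apptainer:
    deployment_method.add(DeploymentMethod.APPTAINER)
if use_envmodules:
    deployment_method.add(DeploymentMethod.ENV_MODULES)

# Consumers now test membership instead of reading workflow.use_conda etc.
assert DeploymentMethod.CONDA in deployment_method
assert DeploymentMethod.APPTAINER not in deployment_method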
""" - if path.endswith(".tar"): + if path.suffix == ".tar": mode = "x" - elif path.endswith("tar.bz2"): + elif path.suffixes == [".tar", ".bz2"]: mode = "x:bz2" - elif path.endswith("tar.xz"): + elif path.suffixes == [".tar", ".xz"]: mode = "x:xz" - elif path.endswith("tar.gz"): + elif path.suffixes == [".tar", ".gz"]: mode = "x:gz" else: raise WorkflowError( "Unsupported archive format " "(supported: .tar, .tar.gz, .tar.bz2, .tar.xz)" ) - if os.path.exists(path): - raise WorkflowError("Archive already exists:\n" + path) + if path.exists(): + raise WorkflowError(f"Archive already exists:\n{path}") self.create_conda_envs() @@ -2415,7 +2360,7 @@ def list_untracked(self): ) dirs[:] = [d for d in dirs if not d[0] == "."] for f in sorted(list(files_in_cwd - used_files)): - logger.info(f) + print(f) def d3dag(self, max_jobs=10000): def node(job): @@ -2554,14 +2499,14 @@ def toposorted(self, jobs=None, inherit_pipe_dependencies=False): yield sorted_layer - def get_outputs_with_changes(self, change_type, include_needrun=True): + def get_outputs_with_changes(self, change_type: ChangeType, include_needrun=True): is_changed = lambda job: ( getattr(self.workflow.persistence, f"{change_type}_changed")(job) if not job.is_group() and (include_needrun or not self.needrun(job)) else [] ) changed = list(chain(*map(is_changed, self.jobs))) - if change_type == "code": + if change_type == ChangeType.CODE: for job in self.jobs: if not job.is_group() and (include_needrun or not self.needrun(job)): changed.extend(list(job.outputs_older_than_script_or_notebook())) @@ -2569,17 +2514,17 @@ def get_outputs_with_changes(self, change_type, include_needrun=True): def warn_about_changes(self, quiet=False): if not quiet: - for change_type in ["code", "input", "params"]: + for change_type in ChangeType.all(): changed = self.get_outputs_with_changes( change_type, include_needrun=False ) if changed: rerun_trigger = "" if not ON_WINDOWS: - rerun_trigger = f"\n To trigger a re-run, use 'snakemake -R $(snakemake --list-{change_type}-changes)'." + rerun_trigger = f"\n To trigger a re-run, use 'snakemake -R $(snakemake --list-changes {change_type})'." logger.warning( f"The {change_type} used to generate one or several output files has changed:\n" - f" To inspect which output files have changes, run 'snakemake --list-{change_type}-changes'." + f" To inspect which output files have changes, run 'snakemake --list-changes {change_type}'." 
f"{rerun_trigger}" ) diff --git a/snakemake/deployment/conda.py b/snakemake/deployment/conda.py index 5a6f53479..5e875806d 100644 --- a/snakemake/deployment/conda.py +++ b/snakemake/deployment/conda.py @@ -72,7 +72,7 @@ def __init__( if env_name is not None: assert env_file is None, "bug: both env_file and env_name specified" - self.frontend = workflow.conda_frontend + self.frontend = workflow.deployment_settings.conda_frontend self.workflow = workflow self._container_img = container_img @@ -89,7 +89,7 @@ def __init__( self._path = None self._archive_file = None self._cleanup = cleanup - self._singularity_args = workflow.singularity_args + self._singularity_args = workflow.deployment_settings.apptainer_args @lazy_property def conda(self): diff --git a/snakemake/deployment/singularity.py b/snakemake/deployment/singularity.py index 4c0a8c2f5..f043500e4 100644 --- a/snakemake/deployment/singularity.py +++ b/snakemake/deployment/singularity.py @@ -188,8 +188,8 @@ def check(self): if not self.checked: if not shutil.which("singularity"): raise WorkflowError( - "The singularity command has to be " - "available in order to use singularity " + "The apptainer or singularity command has to be " + "available in order to use apptainer/singularity " "integration." ) try: diff --git a/snakemake/exceptions.py b/snakemake/exceptions.py index 91dd612cc..e2b6be79d 100644 --- a/snakemake/exceptions.py +++ b/snakemake/exceptions.py @@ -7,8 +7,7 @@ import traceback import textwrap from tokenize import TokenError -from snakemake.logging import logger -from snakemake_interface_executor_plugins.exceptions import WorkflowError +from snakemake_interface_common.exceptions import WorkflowError, ApiError def format_error( @@ -67,11 +66,13 @@ def format_traceback(tb, linemaps): def log_verbose_traceback(ex): + from snakemake.logging import logger + tb = "Full " + "".join(traceback.format_exception(type(ex), ex, ex.__traceback__)) logger.debug(tb) -def print_exception(ex, linemaps): +def print_exception(ex, linemaps=None): """ Print an error message for a given exception. @@ -80,6 +81,8 @@ def print_exception(ex, linemaps): linemaps -- a dict of a dict that maps for each snakefile the compiled lines to source code lines in the snakefile. 
""" + from snakemake.logging import logger + log_verbose_traceback(ex) if isinstance(ex, SyntaxError) or isinstance(ex, IndentationError): logger.error( @@ -92,7 +95,7 @@ def print_exception(ex, linemaps): ) ) return - origin = get_exception_origin(ex, linemaps) + origin = get_exception_origin(ex, linemaps) if linemaps is not None else None if origin is not None: lineno, file = origin logger.error( @@ -142,6 +145,10 @@ def print_exception(ex, linemaps): rule=ex.rule, ) ) + elif isinstance(ex, ApiError): + logger.error(f"Error: {ex}") + elif isinstance(ex, CliException): + logger.error(f"Error: {ex}") elif isinstance(ex, KeyboardInterrupt): logger.info("Cancelling snakemake on user request.") else: diff --git a/snakemake/executors/__init__.py b/snakemake/executors/__init__.py index 5dfa909ba..e54f550df 100644 --- a/snakemake/executors/__init__.py +++ b/snakemake/executors/__init__.py @@ -3,1040 +3,10 @@ __email__ = "johannes.koester@uni-due.de" __license__ = "MIT" -from abc import ABC, abstractmethod -import asyncio import os -import sys import contextlib -import time -import json -import stat -import shutil -import shlex -import threading -import concurrent.futures -import subprocess -import tempfile -from functools import partial -from collections import namedtuple -import base64 -from typing import List -import uuid -import re -import math -from snakemake_interface_executor_plugins.executors.base import AbstractExecutor -from snakemake_interface_executor_plugins.executors.real import RealExecutor -from snakemake_interface_executor_plugins.executors.remote import RemoteExecutor -from snakemake_interface_executor_plugins.dag import DAGExecutorInterface -from snakemake_interface_executor_plugins.workflow import WorkflowExecutorInterface -from snakemake_interface_executor_plugins.persistence import StatsExecutorInterface -from snakemake_interface_executor_plugins.logging import LoggerExecutorInterface -from snakemake_interface_executor_plugins.jobs import ( - ExecutorJobInterface, - SingleJobExecutorInterface, - GroupJobExecutorInterface, -) -from snakemake_interface_executor_plugins.utils import sleep -from snakemake_interface_executor_plugins.utils import ExecMode - -from snakemake.shell import shell from snakemake.logging import logger -from snakemake.stats import Stats -from snakemake.utils import makedirs -from snakemake.io import get_wildcard_names, Wildcards -from snakemake.exceptions import print_exception, get_exception_origin -from snakemake.exceptions import format_error, RuleException, log_verbose_traceback -from snakemake.exceptions import ( - WorkflowError, - SpawnedJobError, - CacheMissException, -) -from snakemake.common import ( - get_container_image, - get_uuid, - async_lock, -) - - -class DryrunExecutor(AbstractExecutor): - def get_exec_mode(self): - raise NotImplementedError() - - def printjob(self, job: ExecutorJobInterface): - super().printjob(job) - if job.is_group(): - for j in job.jobs: - self.printcache(j) - else: - self.printcache(job) - - def printcache(self, job: ExecutorJobInterface): - cache_mode = self.workflow.get_cache_mode(job.rule) - if cache_mode: - if self.workflow.output_file_cache.exists(job, cache_mode): - logger.info( - "Output file {} will be obtained from global between-workflow cache.".format( - job.output[0] - ) - ) - else: - logger.info( - "Output file {} will be written to global between-workflow cache.".format( - job.output[0] - ) - ) - - def cancel(self): - pass - - def shutdown(self): - pass - - def handle_job_success(self, job: 
ExecutorJobInterface): - pass - - def handle_job_error(self, job: ExecutorJobInterface): - pass - - -class TouchExecutor(RealExecutor): - def __init__( - self, - workflow: WorkflowExecutorInterface, - dag: DAGExecutorInterface, - stats: StatsExecutorInterface, - logger: LoggerExecutorInterface, - ): - super().__init__( - workflow, - dag, - stats, - logger, - executor_settings=None, - ) - - def run( - self, - job: ExecutorJobInterface, - callback=None, - submit_callback=None, - error_callback=None, - ): - super()._run(job) - try: - # Touching of output files will be done by handle_job_success - time.sleep(0.1) - callback(job) - except OSError as ex: - print_exception(ex, self.workflow.linemaps) - error_callback(job) - - def get_exec_mode(self): - raise NotImplementedError() - - def handle_job_success(self, job: ExecutorJobInterface): - super().handle_job_success(job, ignore_missing_output=True) - - def cancel(self): - pass - - def shutdown(self): - pass - - def get_python_executable(self): - raise NotImplementedError() - - -_ProcessPoolExceptions = (KeyboardInterrupt,) -try: - from concurrent.futures.process import BrokenProcessPool - - _ProcessPoolExceptions = (KeyboardInterrupt, BrokenProcessPool) -except ImportError: - pass - - -class CPUExecutor(RealExecutor): - def __init__( - self, - workflow: WorkflowExecutorInterface, - dag: DAGExecutorInterface, - stats: StatsExecutorInterface, - logger: LoggerExecutorInterface, - cores: int, - use_threads=False, - ): - super().__init__( - workflow, - dag, - stats, - logger, - executor_settings=None, - job_core_limit=cores, - ) - - self.use_threads = use_threads - - # Zero thread jobs do not need a thread, but they occupy additional workers. - # Hence we need to reserve additional workers for them. - workers = cores + 5 if cores is not None else 5 - self.workers = workers - self.pool = concurrent.futures.ThreadPoolExecutor(max_workers=self.workers) - - def get_exec_mode(self): - return ExecMode.subprocess - - @property - def job_specific_local_groupid(self): - return False - - def get_job_exec_prefix(self, job: ExecutorJobInterface): - return f"cd {shlex.quote(self.workflow.workdir_init)}" - - def get_python_executable(self): - return sys.executable - - def get_envvar_declarations(self): - return "" - - def get_job_args(self, job: ExecutorJobInterface, **kwargs): - return f"{super().get_job_args(job, **kwargs)} --quiet" - - def run( - self, - job: ExecutorJobInterface, - callback=None, - submit_callback=None, - error_callback=None, - ): - super()._run(job) - - if job.is_group(): - # if we still don't have enough workers for this group, create a new pool here - missing_workers = max(len(job) - self.workers, 0) - if missing_workers: - self.workers += missing_workers - self.pool = concurrent.futures.ThreadPoolExecutor( - max_workers=self.workers - ) - - # the future waits for the entire group job - future = self.pool.submit(self.run_group_job, job) - else: - future = self.run_single_job(job) - - future.add_done_callback(partial(self._callback, job, callback, error_callback)) - - def job_args_and_prepare(self, job: ExecutorJobInterface): - job.prepare() - - conda_env = ( - job.conda_env.address if self.workflow.use_conda and job.conda_env else None - ) - container_img = ( - job.container_img_path if self.workflow.use_singularity else None - ) - env_modules = job.env_modules if self.workflow.use_env_modules else None - - benchmark = None - benchmark_repeats = job.benchmark_repeats or 1 - if job.benchmark is not None: - benchmark = str(job.benchmark) - 
return ( - job.rule, - job.input._plainstrings(), - job.output._plainstrings(), - job.params, - job.wildcards, - job.threads, - job.resources, - job.log._plainstrings(), - benchmark, - benchmark_repeats, - conda_env, - container_img, - self.workflow.singularity_args, - env_modules, - self.workflow.use_singularity, - self.workflow.linemaps, - self.workflow.debug, - self.workflow.cleanup_scripts, - job.shadow_dir, - job.jobid, - self.workflow.edit_notebook if self.dag.is_edit_notebook_job(job) else None, - self.workflow.conda_base_path, - job.rule.basedir, - self.workflow.sourcecache.runtime_cache_path, - ) - - def run_single_job(self, job: SingleJobExecutorInterface): - if ( - self.use_threads - or (not job.is_shadow and not job.is_run) - or job.is_template_engine - ): - future = self.pool.submit( - self.cached_or_run, job, run_wrapper, *self.job_args_and_prepare(job) - ) - else: - # run directive jobs are spawned into subprocesses - future = self.pool.submit(self.cached_or_run, job, self.spawn_job, job) - return future - - def run_group_job(self, job: GroupJobExecutorInterface): - """Run a pipe or service group job. - - This lets all items run simultaneously.""" - # we only have to consider pipe or service groups because in local running mode, - # these are the only groups that will occur - - futures = [self.run_single_job(j) for j in job] - n_non_service = sum(1 for j in job if not j.is_service) - - while True: - n_finished = 0 - for f in futures: - if f.done(): - ex = f.exception() - if ex is not None: - # kill all shell commands of the other group jobs - # there can be only shell commands because the - # run directive is not allowed for pipe jobs - for j in job: - shell.kill(j.jobid) - raise ex - else: - n_finished += 1 - if n_finished >= n_non_service: - # terminate all service jobs since all consumers are done - for j in job: - if j.is_service: - logger.info( - f"Terminating service job {j.jobid} since all consuming jobs are finished." - ) - shell.terminate(j.jobid) - logger.info( - f"Service job {j.jobid} has been successfully terminated." - ) - - return - time.sleep(1) - - def spawn_job(self, job: SingleJobExecutorInterface): - cmd = self.format_job_exec(job) - try: - subprocess.check_call(cmd, shell=True) - except subprocess.CalledProcessError as e: - raise SpawnedJobError() - - def cached_or_run(self, job: SingleJobExecutorInterface, run_func, *args): - """ - Either retrieve result from cache, or run job with given function. 
- """ - cache_mode = self.workflow.get_cache_mode(job.rule) - try: - if cache_mode: - self.workflow.output_file_cache.fetch(job, cache_mode) - return - except CacheMissException: - pass - run_func(*args) - if cache_mode: - self.workflow.output_file_cache.store(job, cache_mode) - - def shutdown(self): - self.pool.shutdown() - - def cancel(self): - self.pool.shutdown() - - def _callback( - self, job: SingleJobExecutorInterface, callback, error_callback, future - ): - try: - ex = future.exception() - if ex is not None: - raise ex - callback(job) - except _ProcessPoolExceptions: - self.handle_job_error(job) - # no error callback, just silently ignore the interrupt as the main scheduler is also killed - except SpawnedJobError: - # don't print error message, this is done by the spawned subprocess - error_callback(job) - except BaseException as ex: - self.print_job_error(job) - if self.workflow.verbose or (not job.is_group() and not job.is_shell): - print_exception(ex, self.workflow.linemaps) - error_callback(job) - - def handle_job_success(self, job: ExecutorJobInterface): - super().handle_job_success(job) - - def handle_job_error(self, job: ExecutorJobInterface): - super().handle_job_error(job) - if not self.keepincomplete: - job.cleanup() - self.workflow.persistence.cleanup(job) - - -GenericClusterJob = namedtuple( - "GenericClusterJob", - "job jobid callback error_callback jobscript jobfinished jobfailed", -) - - -class GenericClusterExecutor(RemoteExecutor): - def __init__( - self, - workflow: WorkflowExecutorInterface, - dag: DAGExecutorInterface, - stats: StatsExecutorInterface, - logger: LoggerExecutorInterface, - submitcmd="qsub", - statuscmd=None, - cancelcmd=None, - cancelnargs=None, - sidecarcmd=None, - jobname="snakejob.{rulename}.{jobid}.sh", - max_status_checks_per_second=1, - ): - self.submitcmd = submitcmd - if not workflow.assume_shared_fs and statuscmd is None: - raise WorkflowError( - "When no shared filesystem can be assumed, a " - "status command must be given." - ) - - self.statuscmd = statuscmd - self.cancelcmd = cancelcmd - self.sidecarcmd = sidecarcmd - self.cancelnargs = cancelnargs - self.external_jobid = dict() - # We need to collect all external ids so we can properly cancel even if - # the status update queue is running. - self.all_ext_jobids = list() - - super().__init__( - workflow, - dag, - stats, - logger, - None, - jobname=jobname, - max_status_checks_per_second=max_status_checks_per_second, - ) - - self.sidecar_vars = None - if self.sidecarcmd: - self._launch_sidecar() - - if not statuscmd and not self.assume_shared_fs: - raise WorkflowError( - "If no shared filesystem is used, you have to " - "specify a cluster status command." 
- ) - - def get_job_exec_prefix(self, job: ExecutorJobInterface): - if self.assume_shared_fs: - return f"cd {shlex.quote(self.workflow.workdir_init)}" - else: - return "" - - def get_job_exec_suffix(self, job: ExecutorJobInterface): - if self.statuscmd: - return "exit 0 || exit 1" - elif self.assume_shared_fs: - # TODO wrap with watch and touch {jobrunning} - # check modification date of {jobrunning} in the wait_for_job method - - return ( - f"touch {repr(self.get_jobfinished_marker(job))} || " - f"(touch {repr(self.get_jobfailed_marker(job))}; exit 1)" - ) - assert False, "bug: neither statuscmd defined nor shared FS" - - def get_jobfinished_marker(self, job: ExecutorJobInterface): - return os.path.join(self.tmpdir, f"{job.jobid}.jobfinished") - - def get_jobfailed_marker(self, job: ExecutorJobInterface): - return os.path.join(self.tmpdir, f"{job.jobid}.jobfailed") - - def _launch_sidecar(self): - def copy_stdout(executor, process): - """Run sidecar process and copy it's stdout to our stdout.""" - while process.poll() is None and executor.wait: - buf = process.stdout.readline() - if buf: - sys.stdout.write(buf) - # one final time ... - buf = process.stdout.readline() - if buf: - sys.stdout.write(buf) - - def wait(executor, process): - while executor.wait: - time.sleep(0.5) - process.terminate() - process.wait() - logger.info( - "Cluster sidecar process has terminated (retcode=%d)." - % process.returncode - ) - - logger.info("Launch sidecar process and read first output line.") - process = subprocess.Popen( - self.sidecarcmd, stdout=subprocess.PIPE, shell=False, encoding="utf-8" - ) - self.sidecar_vars = process.stdout.readline() - while self.sidecar_vars and self.sidecar_vars[-1] in "\n\r": - self.sidecar_vars = self.sidecar_vars[:-1] - logger.info("Done reading first output line.") - - thread_stdout = threading.Thread( - target=copy_stdout, name="sidecar_stdout", args=(self, process) - ) - thread_stdout.start() - thread_wait = threading.Thread( - target=wait, name="sidecar_stdout", args=(self, process) - ) - thread_wait.start() - - def cancel(self): - def _chunks(lst, n): - """Yield successive n-sized chunks from lst.""" - for i in range(0, len(lst), n): - yield lst[i : i + n] - - if self.cancelcmd: # We have --cluster-cancel - # Enumerate job IDs and create chunks. If cancelnargs evaluates to false (0/None) - # then pass all job ids at once - jobids = list(self.all_ext_jobids) - chunks = list(_chunks(jobids, self.cancelnargs or len(jobids))) - # Go through the chunks and cancel the jobs, warn in case of failures. - failures = 0 - for chunk in chunks: - try: - cancel_timeout = 2 # rather fail on timeout than miss canceling all - env = dict(os.environ) - if self.sidecar_vars: - env["SNAKEMAKE_CLUSTER_SIDECAR_VARS"] = self.sidecar_vars - subprocess.check_call( - [self.cancelcmd] + chunk, - shell=False, - timeout=cancel_timeout, - env=env, - ) - except subprocess.SubprocessError: - failures += 1 - if failures: - logger.info( - ( - "{} out of {} calls to --cluster-cancel failed. This is safe to " - "ignore in most cases." - ).format(failures, len(chunks)) - ) - else: - logger.info( - "No --cluster-cancel given. Will exit after finishing currently running jobs." - ) - self.shutdown() - - def register_job(self, job: ExecutorJobInterface): - # Do not register job here. - # Instead do it manually once the jobid is known. 
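For reference, the chunked --cluster-cancel handling above batches external job ids into fixed-size argument lists before each call to the cancel command. Below is a minimal standalone sketch of that batching pattern, assuming a generic cancel command; the command name ("scancel") and job ids are hypothetical placeholders, not taken from the removed code:

    import subprocess

    def cancel_in_chunks(cancelcmd, jobids, nargs=None, env=None):
        # pass at most `nargs` external job ids per invocation; 0/None means all at once
        step = nargs or len(jobids) or 1
        failures = 0
        for i in range(0, len(jobids), step):
            try:
                # short timeout: better to fail fast than to miss cancelling the rest
                subprocess.check_call([cancelcmd] + jobids[i : i + step], timeout=2, env=env)
            except subprocess.SubprocessError:
                failures += 1
        return failures

    # hypothetical usage: cancel_in_chunks("scancel", ["123", "124", "125"], nargs=2)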
- pass - - def run( - self, - job: ExecutorJobInterface, - callback=None, - submit_callback=None, - error_callback=None, - ): - super()._run(job) - jobid = job.jobid - - jobscript = self.get_jobscript(job) - self.write_jobscript(job, jobscript) - - jobfinished = self.get_jobfinished_marker(job) - jobfailed = self.get_jobfailed_marker(job) - - if self.statuscmd: - ext_jobid = self.dag.incomplete_external_jobid(job) - if ext_jobid: - # Job is incomplete and still running. - # We simply register it and wait for completion or failure. - logger.info( - "Resuming incomplete job {} with external jobid '{}'.".format( - jobid, ext_jobid - ) - ) - submit_callback(job) - with self.lock: - self.all_ext_jobids.append(ext_jobid) - self.active_jobs.append( - GenericClusterJob( - job, - ext_jobid, - callback, - error_callback, - jobscript, - jobfinished, - jobfailed, - ) - ) - return - - deps = " ".join( - self.external_jobid[f] for f in job.input if f in self.external_jobid - ) - try: - submitcmd = job.format_wildcards(self.submitcmd, dependencies=deps) - except AttributeError as e: - raise WorkflowError(str(e), rule=job.rule if not job.is_group() else None) - - try: - env = dict(os.environ) - if self.sidecar_vars: - env["SNAKEMAKE_CLUSTER_SIDECAR_VARS"] = self.sidecar_vars - - # Remove SNAKEMAKE_PROFILE from environment as the snakemake call inside - # of the cluster job must run locally (or complains about missing -j). - env.pop("SNAKEMAKE_PROFILE", None) - - ext_jobid = ( - subprocess.check_output( - '{submitcmd} "{jobscript}"'.format( - submitcmd=submitcmd, jobscript=jobscript - ), - shell=True, - env=env, - ) - .decode() - .split("\n") - ) - except subprocess.CalledProcessError as ex: - logger.error( - "Error submitting jobscript (exit code {}):\n{}".format( - ex.returncode, ex.output.decode() - ) - ) - error_callback(job) - return - if ext_jobid and ext_jobid[0]: - ext_jobid = ext_jobid[0] - self.external_jobid.update((f, ext_jobid) for f in job.output) - logger.info( - "Submitted {} {} with external jobid '{}'.".format( - "group job" if job.is_group() else "job", jobid, ext_jobid - ) - ) - self.workflow.persistence.started(job, external_jobid=ext_jobid) - - submit_callback(job) - - with self.lock: - self.all_ext_jobids.append(ext_jobid) - self.active_jobs.append( - GenericClusterJob( - job, - ext_jobid, - callback, - error_callback, - jobscript, - jobfinished, - jobfailed, - ) - ) - - async def _wait_for_jobs(self): - success = "success" - failed = "failed" - running = "running" - status_cmd_kills = [] - if self.statuscmd is not None: - - def job_status(job, valid_returns=["running", "success", "failed"]): - try: - # this command shall return "success", "failed" or "running" - env = dict(os.environ) - if self.sidecar_vars: - env["SNAKEMAKE_CLUSTER_SIDECAR_VARS"] = self.sidecar_vars - ret = subprocess.check_output( - "{statuscmd} '{jobid}'".format( - jobid=job.jobid, statuscmd=self.statuscmd - ), - shell=True, - env=env, - ).decode() - except subprocess.CalledProcessError as e: - if e.returncode < 0: - # Ignore SIGINT and all other issues due to signals - # because it will be caused by hitting e.g. - # Ctrl-C on the main process or sending killall to - # snakemake. - # Snakemake will handle the signal in - # the main process. 
- status_cmd_kills.append(-e.returncode) - if len(status_cmd_kills) > 10: - logger.info( - "Cluster status command {} was killed >10 times with signal(s) {} " - "(if this happens unexpectedly during your workflow execution, " - "have a closer look.).".format( - self.statuscmd, ",".join(status_cmd_kills) - ) - ) - status_cmd_kills.clear() - else: - raise WorkflowError( - "Failed to obtain job status. " - "See above for error message." - ) - - ret = ret.strip().split("\n") - if len(ret) != 1 or ret[0] not in valid_returns: - raise WorkflowError( - "Cluster status command {} returned {} but just a single line with one of {} is expected.".format( - self.statuscmd, "\\n".join(ret), ",".join(valid_returns) - ) - ) - return ret[0] - - else: - - def job_status(job): - if os.path.exists(active_job.jobfinished): - os.remove(active_job.jobfinished) - os.remove(active_job.jobscript) - return success - if os.path.exists(active_job.jobfailed): - os.remove(active_job.jobfailed) - os.remove(active_job.jobscript) - return failed - return running - - while True: - async with async_lock(self.lock): - if not self.wait: - return - active_jobs = self.active_jobs - self.active_jobs = list() - still_running = list() - # logger.debug("Checking status of {} jobs.".format(len(active_jobs))) - for active_job in active_jobs: - async with self.status_rate_limiter: - status = job_status(active_job) - - if status == success: - active_job.callback(active_job.job) - elif status == failed: - self.print_job_error( - active_job.job, - cluster_jobid=active_job.jobid - if active_job.jobid - else "unknown", - ) - self.print_cluster_job_error( - active_job, self.dag.jobid(active_job.job) - ) - active_job.error_callback(active_job.job) - else: - still_running.append(active_job) - async with async_lock(self.lock): - self.active_jobs.extend(still_running) - await sleep() - - -SynchronousClusterJob = namedtuple( - "SynchronousClusterJob", "job jobid callback error_callback jobscript process" -) - - -class SynchronousClusterExecutor(RemoteExecutor): - """ - invocations like "qsub -sync y" (SGE) or "bsub -K" (LSF) are - synchronous, blocking the foreground thread and returning the - remote exit code at remote exit. 
- """ - - def __init__( - self, - workflow: WorkflowExecutorInterface, - dag: DAGExecutorInterface, - stats: StatsExecutorInterface, - logger: LoggerExecutorInterface, - submitcmd="qsub", - jobname="snakejob.{rulename}.{jobid}.sh", - ): - super().__init__( - workflow, - dag, - stats, - logger, - None, - jobname=jobname, - max_status_checks_per_second=10, - ) - self.submitcmd = submitcmd - self.external_jobid = dict() - - def get_job_exec_prefix(self, job): - if self.assume_shared_fs: - return f"cd {shlex.quote(self.workflow.workdir_init)}" - else: - return "" - - def cancel(self): - logger.info("Will exit after finishing currently running jobs.") - self.shutdown() - - def run( - self, - job: ExecutorJobInterface, - callback=None, - submit_callback=None, - error_callback=None, - ): - super()._run(job) - - jobscript = self.get_jobscript(job) - self.write_jobscript(job, jobscript) - - deps = " ".join( - self.external_jobid[f] for f in job.input if f in self.external_jobid - ) - try: - submitcmd = job.format_wildcards(self.submitcmd, dependencies=deps) - except AttributeError as e: - raise WorkflowError(str(e), rule=job.rule if not job.is_group() else None) - - process = subprocess.Popen( - '{submitcmd} "{jobscript}"'.format( - submitcmd=submitcmd, jobscript=jobscript - ), - shell=True, - ) - submit_callback(job) - - with self.lock: - self.active_jobs.append( - SynchronousClusterJob( - job, process.pid, callback, error_callback, jobscript, process - ) - ) - - async def _wait_for_jobs(self): - while True: - async with async_lock(self.lock): - if not self.wait: - return - active_jobs = self.active_jobs - self.active_jobs = list() - still_running = list() - for active_job in active_jobs: - async with self.status_rate_limiter: - exitcode = active_job.process.poll() - if exitcode is None: - # job not yet finished - still_running.append(active_job) - elif exitcode == 0: - # job finished successfully - os.remove(active_job.jobscript) - active_job.callback(active_job.job) - else: - # job failed - os.remove(active_job.jobscript) - self.print_job_error(active_job.job) - self.print_cluster_job_error( - active_job, self.dag.jobid(active_job.job) - ) - active_job.error_callback(active_job.job) - async with async_lock(self.lock): - self.active_jobs.extend(still_running) - await sleep() - - -DRMAAClusterJob = namedtuple( - "DRMAAClusterJob", "job jobid callback error_callback jobscript" -) - - -class DRMAAExecutor(RemoteExecutor): - def __init__( - self, - workflow: WorkflowExecutorInterface, - dag: DAGExecutorInterface, - stats: StatsExecutorInterface, - logger: LoggerExecutorInterface, - jobname="snakejob.{rulename}.{jobid}.sh", - drmaa_args="", - drmaa_log_dir=None, - max_status_checks_per_second=1, - ): - super().__init__( - workflow, - dag, - stats, - logger, - None, - jobname=jobname, - max_status_checks_per_second=max_status_checks_per_second, - ) - try: - import drmaa - except ImportError: - raise WorkflowError( - "Python support for DRMAA is not installed. " - "Please install it, e.g. 
with easy_install3 --user drmaa" - ) - except RuntimeError as e: - raise WorkflowError(f"Error loading drmaa support:\n{e}") - self.session = drmaa.Session() - self.drmaa_args = drmaa_args - self.drmaa_log_dir = drmaa_log_dir - self.session.initialize() - self.submitted = list() - - def get_job_exec_prefix(self, job: ExecutorJobInterface): - if self.assume_shared_fs: - return f"cd {shlex.quote(self.workflow.workdir_init)}" - else: - return "" - - def cancel(self): - from drmaa.const import JobControlAction - from drmaa.errors import InvalidJobException, InternalException - - for jobid in self.submitted: - try: - self.session.control(jobid, JobControlAction.TERMINATE) - except (InvalidJobException, InternalException): - # This is common - logging a warning would probably confuse the user. - pass - self.shutdown() - - def run( - self, - job: ExecutorJobInterface, - callback=None, - submit_callback=None, - error_callback=None, - ): - super()._run(job) - jobscript = self.get_jobscript(job) - self.write_jobscript(job, jobscript) - - try: - drmaa_args = job.format_wildcards(self.drmaa_args) - except AttributeError as e: - raise WorkflowError(str(e), rule=job.rule) - - import drmaa - - if self.drmaa_log_dir: - makedirs(self.drmaa_log_dir) - - try: - jt = self.session.createJobTemplate() - jt.remoteCommand = jobscript - jt.nativeSpecification = drmaa_args - if self.drmaa_log_dir: - jt.outputPath = ":" + self.drmaa_log_dir - jt.errorPath = ":" + self.drmaa_log_dir - jt.jobName = os.path.basename(jobscript) - - jobid = self.session.runJob(jt) - except ( - drmaa.DeniedByDrmException, - drmaa.InternalException, - drmaa.InvalidAttributeValueException, - ) as e: - print_exception(WorkflowError(f"DRMAA Error: {e}"), self.workflow.linemaps) - error_callback(job) - return - logger.info(f"Submitted DRMAA job {job.jobid} with external jobid {jobid}.") - self.submitted.append(jobid) - self.session.deleteJobTemplate(jt) - - submit_callback(job) - - with self.lock: - self.active_jobs.append( - DRMAAClusterJob(job, jobid, callback, error_callback, jobscript) - ) - - def shutdown(self): - super().shutdown() - self.session.exit() - - async def _wait_for_jobs(self): - import drmaa - - suspended_msg = set() - - while True: - async with async_lock(self.lock): - if not self.wait: - return - active_jobs = self.active_jobs - self.active_jobs = list() - still_running = list() - for active_job in active_jobs: - async with self.status_rate_limiter: - try: - retval = self.session.jobStatus(active_job.jobid) - except drmaa.ExitTimeoutException as e: - # job still active - still_running.append(active_job) - continue - except (drmaa.InternalException, Exception) as e: - print_exception( - WorkflowError(f"DRMAA Error: {e}"), - self.workflow.linemaps, - ) - os.remove(active_job.jobscript) - active_job.error_callback(active_job.job) - continue - if retval == drmaa.JobState.DONE: - os.remove(active_job.jobscript) - active_job.callback(active_job.job) - elif retval == drmaa.JobState.FAILED: - os.remove(active_job.jobscript) - self.print_job_error(active_job.job) - self.print_cluster_job_error( - active_job, self.dag.jobid(active_job.job) - ) - active_job.error_callback(active_job.job) - else: - # still running - still_running.append(active_job) - - def handle_suspended(by): - if active_job.job.jobid not in suspended_msg: - logger.warning( - "Job {} (DRMAA id: {}) was suspended by {}.".format( - active_job.job.jobid, active_job.jobid, by - ) - ) - suspended_msg.add(active_job.job.jobid) - - if retval == 
drmaa.JobState.USER_SUSPENDED: - handle_suspended("user") - elif retval == drmaa.JobState.SYSTEM_SUSPENDED: - handle_suspended("system") - else: - try: - suspended_msg.remove(active_job.job.jobid) - except KeyError: - # there was nothing to remove - pass - - async with async_lock(self.lock): - self.active_jobs.extend(still_running) - await sleep() @contextlib.contextmanager @@ -1052,961 +22,3 @@ def change_working_directory(directory=None): os.chdir(saved_directory) else: yield - - -KubernetesJob = namedtuple( - "KubernetesJob", "job jobid callback error_callback kubejob jobscript" -) - - -class KubernetesExecutor(RemoteExecutor): - def __init__( - self, - workflow: WorkflowExecutorInterface, - dag: DAGExecutorInterface, - stats: StatsExecutorInterface, - logger: LoggerExecutorInterface, - namespace, - container_image=None, - k8s_cpu_scalar=1.0, - k8s_service_account_name=None, - jobname="{rulename}.{jobid}", - ): - self.workflow = workflow - - super().__init__( - workflow, - dag, - stats, - logger, - None, - jobname=jobname, - max_status_checks_per_second=10, - disable_envvar_declarations=True, - ) - # use relative path to Snakefile - self.snakefile = os.path.relpath(workflow.main_snakefile) - - try: - from kubernetes import config - except ImportError: - raise WorkflowError( - "The Python 3 package 'kubernetes' " - "must be installed to use Kubernetes" - ) - config.load_kube_config() - - import kubernetes.client - - self.k8s_cpu_scalar = k8s_cpu_scalar - self.k8s_service_account_name = k8s_service_account_name - self.kubeapi = kubernetes.client.CoreV1Api() - self.batchapi = kubernetes.client.BatchV1Api() - self.namespace = namespace - self.envvars = workflow.envvars - self.secret_files = {} - self.run_namespace = str(uuid.uuid4()) - self.secret_envvars = {} - self.register_secret() - self.container_image = container_image or get_container_image() - logger.info(f"Using {self.container_image} for Kubernetes jobs.") - - def get_job_exec_prefix(self, job: ExecutorJobInterface): - return "cp -rf /source/. ." - - def register_secret(self): - import kubernetes.client - - secret = kubernetes.client.V1Secret() - secret.metadata = kubernetes.client.V1ObjectMeta() - # create a random uuid - secret.metadata.name = self.run_namespace - secret.type = "Opaque" - secret.data = {} - for i, f in enumerate(self.dag.get_sources()): - if f.startswith(".."): - logger.warning( - "Ignoring source file {}. Only files relative " - "to the working directory are allowed.".format(f) - ) - continue - - # The kubernetes API can't create secret files larger than 1MB. - source_file_size = os.path.getsize(f) - max_file_size = 1048576 - if source_file_size > max_file_size: - logger.warning( - "Skipping the source file {f}. Its size {source_file_size} exceeds " - "the maximum file size (1MB) that can be passed " - "from host to kubernetes.".format( - f=f, source_file_size=source_file_size - ) - ) - continue - - with open(f, "br") as content: - key = f"f{i}" - - # Some files are smaller than 1MB, but grows larger after being base64 encoded - # We should exclude them as well, otherwise Kubernetes APIs will complain - encoded_contents = base64.b64encode(content.read()).decode() - encoded_size = len(encoded_contents) - if encoded_size > 1048576: - logger.warning( - "Skipping the source file {f} for secret key {key}. 
" - "Its base64 encoded size {encoded_size} exceeds " - "the maximum file size (1MB) that can be passed " - "from host to kubernetes.".format( - f=f, - key=key, - encoded_size=encoded_size, - ) - ) - continue - - self.secret_files[key] = f - secret.data[key] = encoded_contents - - for e in self.envvars: - try: - key = e.lower() - secret.data[key] = base64.b64encode(os.environ[e].encode()).decode() - self.secret_envvars[key] = e - except KeyError: - continue - - # Test if the total size of the configMap exceeds 1MB - config_map_size = sum( - [len(base64.b64decode(v)) for k, v in secret.data.items()] - ) - if config_map_size > 1048576: - logger.warning( - "The total size of the included files and other Kubernetes secrets " - "is {}, exceeding the 1MB limit.\n".format(config_map_size) - ) - logger.warning( - "The following are the largest files. Consider removing some of them " - "(you need remove at least {} bytes):".format(config_map_size - 1048576) - ) - - entry_sizes = { - self.secret_files[k]: len(base64.b64decode(v)) - for k, v in secret.data.items() - if k in self.secret_files - } - for k, v in sorted(entry_sizes.items(), key=lambda item: item[1])[:-6:-1]: - logger.warning(f" * File: {k}, original size: {v}") - - raise WorkflowError("ConfigMap too large") - - self.kubeapi.create_namespaced_secret(self.namespace, secret) - - def unregister_secret(self): - import kubernetes.client - - safe_delete_secret = lambda: self.kubeapi.delete_namespaced_secret( - self.run_namespace, self.namespace, body=kubernetes.client.V1DeleteOptions() - ) - self._kubernetes_retry(safe_delete_secret) - - # In rare cases, deleting a pod may rais 404 NotFound error. - def safe_delete_pod(self, jobid, ignore_not_found=True): - import kubernetes.client - - body = kubernetes.client.V1DeleteOptions() - try: - self.kubeapi.delete_namespaced_pod(jobid, self.namespace, body=body) - except kubernetes.client.rest.ApiException as e: - if e.status == 404 and ignore_not_found: - # Can't find the pod. Maybe it's already been - # destroyed. Proceed with a warning message. - logger.warning( - "[WARNING] 404 not found when trying to delete the pod: {jobid}\n" - "[WARNING] Ignore this error\n".format(jobid=jobid) - ) - else: - raise e - - def shutdown(self): - self.unregister_secret() - super().shutdown() - - def cancel(self): - import kubernetes.client - - body = kubernetes.client.V1DeleteOptions() - with self.lock: - for j in self.active_jobs: - func = lambda: self.safe_delete_pod(j.jobid, ignore_not_found=True) - self._kubernetes_retry(func) - - self.shutdown() - - def run( - self, - job: ExecutorJobInterface, - callback=None, - submit_callback=None, - error_callback=None, - ): - import kubernetes.client - - super()._run(job) - exec_job = self.format_job_exec(job) - - # Kubernetes silently does not submit a job if the name is too long - # therefore, we ensure that it is not longer than snakejob+uuid. 
- jobid = "snakejob-{}".format( - get_uuid(f"{self.run_namespace}-{job.jobid}-{job.attempt}") - ) - - body = kubernetes.client.V1Pod() - body.metadata = kubernetes.client.V1ObjectMeta(labels={"app": "snakemake"}) - - body.metadata.name = jobid - - # container - container = kubernetes.client.V1Container(name=jobid) - container.image = self.container_image - container.command = shlex.split("/bin/sh") - container.args = ["-c", exec_job] - container.working_dir = "/workdir" - container.volume_mounts = [ - kubernetes.client.V1VolumeMount(name="workdir", mount_path="/workdir"), - kubernetes.client.V1VolumeMount(name="source", mount_path="/source"), - ] - - node_selector = {} - if "machine_type" in job.resources.keys(): - # Kubernetes labels a node by its instance type using this node_label. - node_selector["node.kubernetes.io/instance-type"] = job.resources[ - "machine_type" - ] - - body.spec = kubernetes.client.V1PodSpec( - containers=[container], node_selector=node_selector - ) - # Add service account name if provided - if self.k8s_service_account_name: - body.spec.service_account_name = self.k8s_service_account_name - - # fail on first error - body.spec.restart_policy = "Never" - - # source files as a secret volume - # we copy these files to the workdir before executing Snakemake - too_large = [ - path - for path in self.secret_files.values() - if os.path.getsize(path) > 1000000 - ] - if too_large: - raise WorkflowError( - "The following source files exceed the maximum " - "file size (1MB) that can be passed from host to " - "kubernetes. These are likely not source code " - "files. Consider adding them to your " - "remote storage instead or (if software) use " - "Conda packages or container images:\n{}".format("\n".join(too_large)) - ) - secret_volume = kubernetes.client.V1Volume(name="source") - secret_volume.secret = kubernetes.client.V1SecretVolumeSource() - secret_volume.secret.secret_name = self.run_namespace - secret_volume.secret.items = [ - kubernetes.client.V1KeyToPath(key=key, path=path) - for key, path in self.secret_files.items() - ] - # workdir as an emptyDir volume of undefined size - workdir_volume = kubernetes.client.V1Volume(name="workdir") - workdir_volume.empty_dir = kubernetes.client.V1EmptyDirVolumeSource() - body.spec.volumes = [secret_volume, workdir_volume] - - # env vars - container.env = [] - for key, e in self.secret_envvars.items(): - envvar = kubernetes.client.V1EnvVar(name=e) - envvar.value_from = kubernetes.client.V1EnvVarSource() - envvar.value_from.secret_key_ref = kubernetes.client.V1SecretKeySelector( - key=key, name=self.run_namespace - ) - container.env.append(envvar) - - # request resources - logger.debug(f"job resources: {dict(job.resources)}") - container.resources = kubernetes.client.V1ResourceRequirements() - container.resources.requests = {} - container.resources.requests["cpu"] = "{}m".format( - int(job.resources["_cores"] * self.k8s_cpu_scalar * 1000) - ) - if "mem_mb" in job.resources.keys(): - container.resources.requests["memory"] = "{}M".format( - job.resources["mem_mb"] - ) - if "disk_mb" in job.resources.keys(): - disk_mb = int(job.resources.get("disk_mb", 1024)) - container.resources.requests["ephemeral-storage"] = f"{disk_mb}M" - - logger.debug(f"k8s pod resources: {container.resources.requests}") - - # capabilities - if job.needs_singularity and self.workflow.use_singularity: - # TODO this should work, but it doesn't currently because of - # missing loop devices - # singularity inside docker requires SYS_ADMIN capabilities - # see 
https://groups.google.com/a/lbl.gov/forum/#!topic/singularity/e9mlDuzKowc - # container.capabilities = kubernetes.client.V1Capabilities() - # container.capabilities.add = ["SYS_ADMIN", - # "DAC_OVERRIDE", - # "SETUID", - # "SETGID", - # "SYS_CHROOT"] - - # Running in priviledged mode always works - container.security_context = kubernetes.client.V1SecurityContext( - privileged=True - ) - - pod = self._kubernetes_retry( - lambda: self.kubeapi.create_namespaced_pod(self.namespace, body) - ) - - logger.info( - "Get status with:\n" - "kubectl describe pod {jobid}\n" - "kubectl logs {jobid}".format(jobid=jobid) - ) - self.active_jobs.append( - KubernetesJob(job, jobid, callback, error_callback, pod, None) - ) - - # Sometimes, certain k8s requests throw kubernetes.client.rest.ApiException - # Solving this issue requires reauthentication, as _kubernetes_retry shows - # However, reauthentication itself, under rare conditions, may also throw - # errors such as: - # kubernetes.client.exceptions.ApiException: (409), Reason: Conflict - # - # This error doesn't mean anything wrong with the k8s cluster, and users can safely - # ignore it. - def _reauthenticate_and_retry(self, func=None): - import kubernetes - - # Unauthorized. - # Reload config in order to ensure token is - # refreshed. Then try again. - logger.info("Trying to reauthenticate") - kubernetes.config.load_kube_config() - subprocess.run(["kubectl", "get", "nodes"]) - - self.kubeapi = kubernetes.client.CoreV1Api() - self.batchapi = kubernetes.client.BatchV1Api() - - try: - self.register_secret() - except kubernetes.client.rest.ApiException as e: - if e.status == 409 and e.reason == "Conflict": - logger.warning("409 conflict ApiException when registering secrets") - logger.warning(e) - else: - raise WorkflowError( - e, - "This is likely a bug in " - "https://github.com/kubernetes-client/python.", - ) - - if func: - return func() - - def _kubernetes_retry(self, func): - import kubernetes - import urllib3 - - with self.lock: - try: - return func() - except kubernetes.client.rest.ApiException as e: - if e.status == 401: - # Unauthorized. - # Reload config in order to ensure token is - # refreshed. Then try again. - return self._reauthenticate_and_retry(func) - # Handling timeout that may occur in case of GKE master upgrade - except urllib3.exceptions.MaxRetryError as e: - logger.info( - "Request time out! " - "check your connection to Kubernetes master" - "Workflow will pause for 5 minutes to allow any update operations to complete" - ) - time.sleep(300) - try: - return func() - except: - # Still can't reach the server after 5 minutes - raise WorkflowError( - e, - "Error 111 connection timeout, please check" - " that the k8 cluster master is reachable!", - ) - - async def _wait_for_jobs(self): - import kubernetes - - while True: - async with async_lock(self.lock): - if not self.wait: - return - active_jobs = self.active_jobs - self.active_jobs = list() - still_running = list() - for j in active_jobs: - async with self.status_rate_limiter: - logger.debug(f"Checking status for pod {j.jobid}") - job_not_found = False - try: - res = self._kubernetes_retry( - lambda: self.kubeapi.read_namespaced_pod_status( - j.jobid, self.namespace - ) - ) - except kubernetes.client.rest.ApiException as e: - if e.status == 404: - # Jobid not found - # The job is likely already done and was deleted on - # the server. 
- j.callback(j.job) - continue - except WorkflowError as e: - print_exception(e, self.workflow.linemaps) - j.error_callback(j.job) - continue - - if res is None: - msg = ( - "Unknown pod {jobid}. " - "Has the pod been deleted " - "manually?" - ).format(jobid=j.jobid) - self.print_job_error(j.job, msg=msg, jobid=j.jobid) - j.error_callback(j.job) - elif res.status.phase == "Failed": - msg = ( - "For details, please issue:\n" - "kubectl describe pod {jobid}\n" - "kubectl logs {jobid}" - ).format(jobid=j.jobid) - # failed - self.print_job_error(j.job, msg=msg, jobid=j.jobid) - j.error_callback(j.job) - elif res.status.phase == "Succeeded": - # finished - j.callback(j.job) - - func = lambda: self.safe_delete_pod( - j.jobid, ignore_not_found=True - ) - self._kubernetes_retry(func) - else: - # still active - still_running.append(j) - async with async_lock(self.lock): - self.active_jobs.extend(still_running) - await sleep() - - -TibannaJob = namedtuple( - "TibannaJob", "job jobname jobid exec_arn callback error_callback" -) - - -class TibannaExecutor(RemoteExecutor): - def __init__( - self, - workflow: WorkflowExecutorInterface, - dag: DAGExecutorInterface, - stats: StatsExecutorInterface, - logger: LoggerExecutorInterface, - tibanna_sfn, - precommand="", - tibanna_config=False, - container_image=None, - max_status_checks_per_second=1, - ): - super().__init__( - workflow, - dag, - stats, - logger, - None, - max_status_checks_per_second=max_status_checks_per_second, - disable_default_remote_provider_args=True, - disable_default_resources_args=True, - disable_envvar_declarations=True, - ) - self.workflow = workflow - self.workflow_sources = [] - for wfs in dag.get_sources(): - if os.path.isdir(wfs): - for dirpath, dirnames, filenames in os.walk(wfs): - self.workflow_sources.extend( - [os.path.join(dirpath, f) for f in filenames] - ) - else: - self.workflow_sources.append(os.path.abspath(wfs)) - - log = "sources=" - for f in self.workflow_sources: - log += f - logger.debug(log) - self.snakefile = workflow.main_snakefile - self.envvars = {e: os.environ[e] for e in workflow.envvars} - if self.envvars: - logger.debug("envvars = %s" % str(self.envvars)) - self.tibanna_sfn = tibanna_sfn - if precommand: - self.precommand = precommand - else: - self.precommand = "" - self.s3_bucket = workflow.default_remote_prefix.split("/")[0] - self.s3_subdir = re.sub( - f"^{self.s3_bucket}/", "", workflow.default_remote_prefix - ) - logger.debug("precommand= " + self.precommand) - logger.debug("bucket=" + self.s3_bucket) - logger.debug("subdir=" + self.s3_subdir) - self.quiet = workflow.quiet - - self.container_image = container_image or get_container_image() - logger.info(f"Using {self.container_image} for Tibanna jobs.") - self.tibanna_config = tibanna_config - - def shutdown(self): - # perform additional steps on shutdown if necessary - logger.debug("shutting down Tibanna executor") - super().shutdown() - - def cancel(self): - from tibanna.core import API - - for j in self.active_jobs: - logger.info(f"killing job {j.jobname}") - while True: - try: - res = API().kill(j.exec_arn) - if not self.quiet: - print(res) - break - except KeyboardInterrupt: - pass - self.shutdown() - - def split_filename(self, filename, checkdir=None): - f = os.path.abspath(filename) - if checkdir: - checkdir = checkdir.rstrip("/") - if f.startswith(checkdir): - fname = re.sub(f"^{checkdir}/", "", f) - fdir = checkdir - else: - direrrmsg = ( - "All source files including Snakefile, " - + "conda env files, and rule script files " - + "must be 
in the same working directory: {} vs {}" - ) - raise WorkflowError(direrrmsg.format(checkdir, f)) - else: - fdir, fname = os.path.split(f) - return fname, fdir - - def remove_prefix(self, s): - return re.sub(f"^{self.s3_bucket}/{self.s3_subdir}/", "", s) - - def get_snakefile(self): - return os.path.basename(self.snakefile) - - def add_command(self, job: ExecutorJobInterface, tibanna_args, tibanna_config): - # format command - command = self.format_job_exec(job) - - if self.precommand: - command = self.precommand + "; " + command - logger.debug("command = " + str(command)) - tibanna_args.command = command - - def add_workflow_files(self, job: ExecutorJobInterface, tibanna_args): - snakefile_fname, snakemake_dir = self.split_filename(self.snakefile) - snakemake_child_fnames = [] - for src in self.workflow_sources: - src_fname, _ = self.split_filename(src, snakemake_dir) - if src_fname != snakefile_fname: # redundant - snakemake_child_fnames.append(src_fname) - # change path for config files - # TODO - this is a hacky way to do this - self.workflow.overwrite_configfiles = [ - self.split_filename(cf, snakemake_dir)[0] - for cf in self.workflow.overwrite_configfiles - ] - tibanna_args.snakemake_directory_local = snakemake_dir - tibanna_args.snakemake_main_filename = snakefile_fname - tibanna_args.snakemake_child_filenames = list(set(snakemake_child_fnames)) - - def adjust_filepath(self, f): - if not hasattr(f, "remote_object"): - rel = self.remove_prefix(f) # log/benchmark - elif ( - hasattr(f.remote_object, "provider") and f.remote_object.provider.is_default - ): - rel = self.remove_prefix(f) - else: - rel = f - return rel - - def make_tibanna_input(self, job: ExecutorJobInterface): - from tibanna import ec2_utils, core as tibanna_core - - # input & output - # Local snakemake command here must be run with --default-remote-prefix - # and --default-remote-provider (forced) but on VM these options will be removed. - # The snakemake on the VM will consider these input and output as not remote. - # They files are transferred to the container by Tibanna before running snakemake. - # In short, the paths on VM must be consistent with what's in Snakefile. - # but the actual location of the files is on the S3 bucket/prefix. - # This mapping info must be passed to Tibanna. 
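The comment block above describes how paths seen inside the Tibanna VM container are paired with their actual S3 locations. A minimal sketch of that remapping, assuming a simple bucket/prefix layout; the bucket, prefix, and file names are hypothetical, while the /data1/snakemake prefix comes from the code below:

    import os
    import re

    FILE_PREFIX = "file:///data1/snakemake"  # working directory inside the container on the VM

    def map_to_s3(path, bucket, subdir):
        # strip the default remote prefix to recover the path as the Snakefile sees it,
        # then pair the VM-local location with the actual S3 URL
        rel = re.sub(f"^{bucket}/{subdir}/", "", path)
        return os.path.join(FILE_PREFIX, rel), f"s3://{path}"

    # hypothetical usage:
    # map_to_s3("mybucket/run1/data/a.txt", "mybucket", "run1")
    # -> ("file:///data1/snakemake/data/a.txt", "s3://mybucket/run1/data/a.txt")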
- for i in job.input: - logger.debug("job input " + str(i)) - logger.debug("job input is remote= " + ("true" if i.is_remote else "false")) - if hasattr(i.remote_object, "provider"): - logger.debug( - " is remote default= " - + ("true" if i.remote_object.provider.is_default else "false") - ) - for o in job.expanded_output: - logger.debug("job output " + str(o)) - logger.debug( - "job output is remote= " + ("true" if o.is_remote else "false") - ) - if hasattr(o.remote_object, "provider"): - logger.debug( - " is remote default= " - + ("true" if o.remote_object.provider.is_default else "false") - ) - file_prefix = ( - "file:///data1/snakemake" # working dir inside snakemake container on VM - ) - input_source = dict() - for ip in job.input: - ip_rel = self.adjust_filepath(ip) - input_source[os.path.join(file_prefix, ip_rel)] = "s3://" + ip - output_target = dict() - output_all = [eo for eo in job.expanded_output] - if job.log: - if isinstance(job.log, list): - output_all.extend([str(_) for _ in job.log]) - else: - output_all.append(str(job.log)) - if hasattr(job, "benchmark") and job.benchmark: - if isinstance(job.benchmark, list): - output_all.extend([str(_) for _ in job.benchmark]) - else: - output_all.append(str(job.benchmark)) - for op in output_all: - op_rel = self.adjust_filepath(op) - output_target[os.path.join(file_prefix, op_rel)] = "s3://" + op - - # mem & cpu - mem = job.resources["mem_mb"] / 1024 if "mem_mb" in job.resources.keys() else 1 - cpu = job.threads - - # jobid, grouping, run_name - jobid = tibanna_core.create_jobid() - if job.is_group(): - run_name = f"snakemake-job-{str(jobid)}-group-{str(job.groupid)}" - else: - run_name = f"snakemake-job-{str(jobid)}-rule-{str(job.rule)}" - - # tibanna input - tibanna_config = { - "run_name": run_name, - "mem": mem, - "cpu": cpu, - "ebs_size": math.ceil(job.resources["disk_mb"] / 1024), - "log_bucket": self.s3_bucket, - } - logger.debug("additional tibanna config: " + str(self.tibanna_config)) - if self.tibanna_config: - tibanna_config.update(self.tibanna_config) - tibanna_args = ec2_utils.Args( - output_S3_bucket=self.s3_bucket, - language="snakemake", - container_image=self.container_image, - input_files=input_source, - output_target=output_target, - input_env=self.envvars, - ) - self.add_workflow_files(job, tibanna_args) - self.add_command(job, tibanna_args, tibanna_config) - tibanna_input = { - "jobid": jobid, - "config": tibanna_config, - "args": tibanna_args.as_dict(), - } - logger.debug(json.dumps(tibanna_input, indent=4)) - return tibanna_input - - def run( - self, - job: ExecutorJobInterface, - callback=None, - submit_callback=None, - error_callback=None, - ): - logger.info("running job using Tibanna...") - from tibanna.core import API - - super()._run(job) - - # submit job here, and obtain job ids from the backend - tibanna_input = self.make_tibanna_input(job) - jobid = tibanna_input["jobid"] - exec_info = API().run_workflow( - tibanna_input, - sfn=self.tibanna_sfn, - verbose=not self.quiet, - jobid=jobid, - open_browser=False, - sleep=0, - ) - exec_arn = exec_info.get("_tibanna", {}).get("exec_arn", "") - jobname = tibanna_input["config"]["run_name"] - jobid = tibanna_input["jobid"] - - # register job as active, using your own namedtuple. - # The namedtuple must at least contain the attributes - # job, jobid, callback, error_callback. 
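The comment above spells out the minimal per-job record a remote executor keeps between submission and completion. A small sketch of that bookkeeping pattern in isolation; the function name, job object, and external id are hypothetical:

    from collections import namedtuple

    # minimal record kept per submitted job; the scheduler's callbacks are fired once
    # the backend reports success or failure for the external job id
    ActiveJob = namedtuple("ActiveJob", "job jobid callback error_callback")

    def register_active(active_jobs, job, external_id, on_success, on_error):
        active_jobs.append(ActiveJob(job, external_id, on_success, on_error))

    # hypothetical usage:
    # register_active(active_jobs=[], job=my_job, external_id="exec-arn-abc123",
    #                 on_success=print, on_error=print)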
- self.active_jobs.append( - TibannaJob(job, jobname, jobid, exec_arn, callback, error_callback) - ) - - async def _wait_for_jobs(self): - # busy wait on job completion - # This is only needed if your backend does not allow to use callbacks - # for obtaining job status. - from tibanna.core import API - - while True: - # always use self.lock to avoid race conditions - async with async_lock(self.lock): - if not self.wait: - return - active_jobs = self.active_jobs - self.active_jobs = list() - still_running = list() - for j in active_jobs: - # use self.status_rate_limiter to avoid too many API calls. - async with self.status_rate_limiter: - if j.exec_arn: - status = API().check_status(j.exec_arn) - else: - status = "FAILED_AT_SUBMISSION" - if not self.quiet or status != "RUNNING": - logger.debug(f"job {j.jobname}: {status}") - if status == "RUNNING": - still_running.append(j) - elif status == "SUCCEEDED": - j.callback(j.job) - else: - j.error_callback(j.job) - async with async_lock(self.lock): - self.active_jobs.extend(still_running) - await sleep() - - -def run_wrapper( - job_rule, - input, - output, - params, - wildcards, - threads, - resources, - log, - benchmark, - benchmark_repeats, - conda_env, - container_img, - singularity_args, - env_modules, - use_singularity, - linemaps, - debug, - cleanup_scripts, - shadow_dir, - jobid, - edit_notebook, - conda_base_path, - basedir, - runtime_sourcecache_path, -): - """ - Wrapper around the run method that handles exceptions and benchmarking. - - Arguments - job_rule -- the ``job.rule`` member - input -- a list of input files - output -- a list of output files - wildcards -- so far processed wildcards - threads -- usable threads - log -- a list of log files - shadow_dir -- optional shadow directory root - """ - # get shortcuts to job_rule members - run = job_rule.run_func - version = job_rule.version - rule = job_rule.name - is_shell = job_rule.shellcmd is not None - - if os.name == "posix" and debug: - sys.stdin = open("/dev/stdin") - - if benchmark is not None: - from snakemake.benchmark import ( - BenchmarkRecord, - benchmarked, - write_benchmark_records, - ) - - # Change workdir if shadow defined and not using singularity. - # Otherwise, we do the change from inside the container. - passed_shadow_dir = None - if use_singularity and container_img: - passed_shadow_dir = shadow_dir - shadow_dir = None - - try: - with change_working_directory(shadow_dir): - if benchmark: - bench_records = [] - for bench_iteration in range(benchmark_repeats): - # Determine whether to benchmark this process or do not - # benchmarking at all. We benchmark this process unless the - # execution is done through the ``shell:``, ``script:``, or - # ``wrapper:`` stanza. - is_sub = ( - job_rule.shellcmd - or job_rule.script - or job_rule.wrapper - or job_rule.cwl - ) - if is_sub: - # The benchmarking through ``benchmarked()`` is started - # in the execution of the shell fragment, script, wrapper - # etc, as the child PID is available there. - bench_record = BenchmarkRecord() - run( - input, - output, - params, - wildcards, - threads, - resources, - log, - version, - rule, - conda_env, - container_img, - singularity_args, - use_singularity, - env_modules, - bench_record, - jobid, - is_shell, - bench_iteration, - cleanup_scripts, - passed_shadow_dir, - edit_notebook, - conda_base_path, - basedir, - runtime_sourcecache_path, - ) - else: - # The benchmarking is started here as we have a run section - # and the generated Python function is executed in this - # process' thread. 
- with benchmarked() as bench_record: - run( - input, - output, - params, - wildcards, - threads, - resources, - log, - version, - rule, - conda_env, - container_img, - singularity_args, - use_singularity, - env_modules, - bench_record, - jobid, - is_shell, - bench_iteration, - cleanup_scripts, - passed_shadow_dir, - edit_notebook, - conda_base_path, - basedir, - runtime_sourcecache_path, - ) - # Store benchmark record for this iteration - bench_records.append(bench_record) - else: - run( - input, - output, - params, - wildcards, - threads, - resources, - log, - version, - rule, - conda_env, - container_img, - singularity_args, - use_singularity, - env_modules, - None, - jobid, - is_shell, - None, - cleanup_scripts, - passed_shadow_dir, - edit_notebook, - conda_base_path, - basedir, - runtime_sourcecache_path, - ) - except (KeyboardInterrupt, SystemExit) as e: - # Re-raise the keyboard interrupt in order to record an error in the - # scheduler but ignore it - raise e - except BaseException as ex: - # this ensures that exception can be re-raised in the parent thread - origin = get_exception_origin(ex, linemaps) - if origin is not None: - log_verbose_traceback(ex) - lineno, file = origin - raise RuleException( - format_error( - ex, lineno, linemaps=linemaps, snakefile=file, show_traceback=True - ) - ) - else: - # some internal bug, just reraise - raise ex - - if benchmark is not None: - try: - write_benchmark_records(bench_records, benchmark) - except BaseException as ex: - raise WorkflowError(ex) diff --git a/snakemake/executors/azure_batch.py b/snakemake/executors/azure_batch.py deleted file mode 100644 index ec89390f7..000000000 --- a/snakemake/executors/azure_batch.py +++ /dev/null @@ -1,933 +0,0 @@ -__author__ = "Johannes Köster, Andreas Wilm, Jake VanCampen" -__copyright__ = "Copyright 2022, Johannes Köster" -__email__ = "johannes.koester@uni-due.de" -__license__ = "MIT" - -import datetime -import io -import os -import re -import shutil -import shlex -import sys -import tarfile -import tempfile -import uuid -from collections import namedtuple -from pprint import pformat -from typing import Optional -from urllib.parse import urlparse - -from snakemake_interface_executor_plugins.executors.remote import RemoteExecutor -from snakemake_interface_executor_plugins.persistence import StatsExecutorInterface -from snakemake_interface_executor_plugins.logging import LoggerExecutorInterface -from snakemake_interface_executor_plugins.dag import DAGExecutorInterface -from snakemake_interface_executor_plugins.jobs import ExecutorJobInterface -from snakemake_interface_executor_plugins.workflow import WorkflowExecutorInterface -from snakemake_interface_executor_plugins.utils import sleep - -from snakemake.exceptions import WorkflowError -import msrest.authentication as msa - -from snakemake.common import async_lock, bytesto, get_container_image, get_file_hash -from snakemake.exceptions import WorkflowError -from snakemake.executors import sleep -from snakemake.logging import logger - -AzBatchJob = namedtuple("AzBatchJob", "job jobid task_id callback error_callback") - - -def check_source_size(filename, warning_size_gb=0.2): - """A helper function to check the filesize, and return the file - to the calling function Additionally, given that we encourage these - packages to be small, we set a warning at 200MB (0.2GB). 
- """ - gb = bytesto(os.stat(filename).st_size, "g") - if gb > warning_size_gb: - logger.warning( - f"File {filename} (size {gb} GB) is greater than the {warning_size_gb} GB " - f"suggested size. Consider uploading larger files to storage first." - ) - return filename - - -class AzBatchConfig: - def __init__(self, batch_account_url: str): - # configure defaults - self.batch_account_url = batch_account_url - - # parse batch account name - result = urlparse(self.batch_account_url) - self.batch_account_name = str.split(result.hostname, ".")[0] - - self.batch_account_key = self.set_or_default("AZ_BATCH_ACCOUNT_KEY", None) - - # optional subnet config - self.batch_pool_subnet_id = self.set_or_default("BATCH_POOL_SUBNET_ID", None) - - # managed identity resource id configuration - self.managed_identity_resource_id = self.set_or_default( - "BATCH_MANAGED_IDENTITY_RESOURCE_ID", None - ) - - # parse subscription and resource id - if self.managed_identity_resource_id is not None: - self.subscription_id = self.managed_identity_resource_id.split("/")[2] - self.resource_group = self.managed_identity_resource_id.split("/")[4] - - self.managed_identity_client_id = self.set_or_default( - "BATCH_MANAGED_IDENTITY_CLIENT_ID", None - ) - - if self.batch_pool_subnet_id is not None: - if ( - self.managed_identity_client_id is None - or self.managed_identity_resource_id is None - ): - sys.exit( - "Error: BATCH_MANAGED_IDENTITY_RESOURCE_ID, BATCH_MANAGED_IDENTITY_CLIENT_ID must be set when deploying batch nodes into a private subnet!" - ) - - # parse account details necessary for batch client authentication steps - if self.batch_pool_subnet_id.split("/")[2] != self.subscription_id: - raise WorkflowError( - "Error: managed identity must be in the same subscription as the batch pool subnet." - ) - - if self.batch_pool_subnet_id.split("/")[4] != self.resource_group: - raise WorkflowError( - "Error: managed identity must be in the same resource group as the batch pool subnet." 
- ) - - # sas url to a batch node start task bash script - self.batch_node_start_task_sasurl = os.getenv("BATCH_NODE_START_TASK_SAS_URL") - - # options configured with env vars or default - self.batch_pool_image_publisher = self.set_or_default( - "BATCH_POOL_IMAGE_PUBLISHER", "microsoft-azure-batch" - ) - self.batch_pool_image_offer = self.set_or_default( - "BATCH_POOL_IMAGE_OFFER", "ubuntu-server-container" - ) - self.batch_pool_image_sku = self.set_or_default( - "BATCH_POOL_IMAGE_SKU", "20-04-lts" - ) - self.batch_pool_vm_container_image = self.set_or_default( - "BATCH_POOL_VM_CONTAINER_IMAGE", "ubuntu" - ) - self.batch_pool_vm_node_agent_sku_id = self.set_or_default( - "BATCH_POOL_VM_NODE_AGENT_SKU_ID", "batch.node.ubuntu 20.04" - ) - self.batch_pool_vm_size = self.set_or_default( - "BATCH_POOL_VM_SIZE", "Standard_D2_v3" - ) - - # dedicated pool node count - self.batch_pool_node_count = self.set_or_default("BATCH_POOL_NODE_COUNT", 1) - - # default tasks per node - # see https://learn.microsoft.com/en-us/azure/batch/batch-parallel-node-tasks - self.batch_tasks_per_node = self.set_or_default("BATCH_TASKS_PER_NODE", 1) - - # possible values "spread" or "pack" - # see https://learn.microsoft.com/en-us/azure/batch/batch-parallel-node-tasks - self.batch_node_fill_type = self.set_or_default( - "BATCH_NODE_FILL_TYPE", "spread" - ) - - # enables simplified batch node communication if set - # see: https://learn.microsoft.com/en-us/azure/batch/simplified-compute-node-communication - self.batch_node_communication_mode = self.set_or_default( - "BATCH_NODE_COMMUNICATION_SIMPLIFIED", None - ) - - self.resource_file_prefix = self.set_or_default( - "BATCH_POOL_RESOURCE_FILE_PREFIX", "resource-files" - ) - - self.container_registry_url = self.set_or_default( - "BATCH_CONTAINER_REGISTRY_URL", None - ) - - self.container_registry_user = self.set_or_default( - "BATCH_CONTAINER_REGISTRY_USER", None - ) - - self.container_registry_pass = self.set_or_default( - "BATCH_CONTAINER_REGISTRY_PASS", None - ) - - @staticmethod - def set_or_default(evar: str, default: Optional[str]): - gotvar = os.getenv(evar) - if gotvar is not None: - return gotvar - else: - return default - - -# the usage of this credential helper is required to authenitcate batch with managed identity credentials -# because not all Azure SDKs support the azure.identity credentials yet, and batch is one of them. -# ref1: https://gist.github.com/lmazuel/cc683d82ea1d7b40208de7c9fc8de59d -# ref2: https://gist.github.com/lmazuel/cc683d82ea1d7b40208de7c9fc8de59d -class AzureIdentityCredentialAdapter(msa.BasicTokenAuthentication): - def __init__( - self, - credential=None, - resource_id="https://management.azure.com/.default", - **kwargs, - ): - """Adapt any azure-identity credential to work with SDK that needs azure.common.credentials or msrestazure. 
- Default resource is ARM (syntax of endpoint v2) - :param credential: Any azure-identity credential (DefaultAzureCredential by default) - :param str resource_id: The scope to use to get the token (default ARM) - """ - try: - from azure.core.pipeline.policies import BearerTokenCredentialPolicy - from azure.identity import DefaultAzureCredential - - except ImportError: - raise WorkflowError( - "The Python 3 packages 'azure-core' and 'azure-identity' are required" - ) - - super(AzureIdentityCredentialAdapter, self).__init__(None) - if credential is None: - credential = DefaultAzureCredential() - self._policy = BearerTokenCredentialPolicy(credential, resource_id, **kwargs) - - def _make_request(self): - try: - from azure.core.pipeline import PipelineContext, PipelineRequest - from azure.core.pipeline.transport import HttpRequest - except ImportError: - raise WorkflowError("The Python 3 package azure-core is required") - - return PipelineRequest( - HttpRequest("AzureIdentityCredentialAdapter", "https://fakeurl"), - PipelineContext(None), - ) - - def set_token(self): - """Ask the azure-core BearerTokenCredentialPolicy policy to get a token. - Using the policy gives us for free the caching system of azure-core. - We could make this code simpler by using private method, but by definition - I can't assure they will be there forever, so mocking a fake call to the policy - to extract the token, using 100% public API.""" - request = self._make_request() - self._policy.on_request(request) - # Read Authorization, and get the second part after Bearer - token = request.http_request.headers["Authorization"].split(" ", 1)[1] - self.token = {"access_token": token} - - def signed_session(self, session=None): - self.set_token() - return super(AzureIdentityCredentialAdapter, self).signed_session(session) - - -class AzBatchExecutor(RemoteExecutor): - "Azure Batch Executor" - - def __init__( - self, - workflow: WorkflowExecutorInterface, - dag: DAGExecutorInterface, - stats: StatsExecutorInterface, - logger: LoggerExecutorInterface, - jobname="snakejob.{name}.{jobid}.sh", - container_image=None, - regions=None, - location=None, - cache=False, - max_status_checks_per_second=1, - az_batch_account_url=None, - az_batch_enable_autoscale=False, - ): - super().__init__( - workflow, - dag, - stats, - logger, - None, - jobname=jobname, - max_status_checks_per_second=max_status_checks_per_second, - ) - - try: - from azure.batch import BatchServiceClient - from azure.batch.batch_auth import SharedKeyCredentials - from azure.identity import DefaultAzureCredential - from azure.mgmt.batch import BatchManagementClient - - from snakemake.remote.AzBlob import AzureStorageHelper - - except ImportError: - raise WorkflowError( - "The Python 3 packages 'azure-batch', 'azure-mgmt-batch', and 'azure-identity'" - " must be installed to use Azure Batch" - ) - - AZURE_BATCH_RESOURCE_ENDPOINT = "https://batch.core.windows.net/" - - # Here we validate that az blob credential is SAS - # token because it is specific to azure batch executor - self.validate_az_blob_credential_is_sas() - self.azblob_helper = AzureStorageHelper() - - # get container from remote prefix - self.prefix_container = str.split(workflow.default_remote_prefix, "/")[0] - - # setup batch configuration sets self.az_batch_config - self.batch_config = AzBatchConfig(az_batch_account_url) - logger.debug(f"AzBatchConfig: {self.mask_batch_config_as_string()}") - - self.workflow = workflow - - # handle case on OSX with /var/ symlinked to /private/var/ causing - # issues with 
workdir not matching other workflow file dirs - dirname = os.path.dirname(self.workflow.persistence.path) - osxprefix = "/private" - if osxprefix in dirname: - dirname = dirname.removeprefix(osxprefix) - - self.workdir = dirname - - # Prepare workflow sources for build package - self._set_workflow_sources() - - # Pool ids can only contain any combination of alphanumeric characters along with dash and underscore - ts = datetime.datetime.now().strftime("%Y-%m%dT%H-%M-%S") - self.pool_id = f"snakepool-{ts:s}" - self.job_id = f"snakejob-{ts:s}" - - self.envvars = list(self.workflow.envvars) or [] - - self.container_image = container_image or get_container_image() - - # enable autoscale flag - self.az_batch_enable_autoscale = az_batch_enable_autoscale - - # Package workflow sources files and upload to storage - self._build_packages = set() - targz = self._generate_build_source_package() - - # removed after job failure/success - self.resource_file = self._upload_build_source_package( - targz, resource_prefix=self.batch_config.resource_file_prefix - ) - - # authenticate batch client from SharedKeyCredentials - if ( - self.batch_config.batch_account_key is not None - and self.batch_config.managed_identity_client_id is None - ): - logger.debug("Using batch account key for authentication...") - creds = SharedKeyCredentials( - self.batch_config.batch_account_name, - self.batch_config.batch_account_key, - ) - # else authenticate with managed indentity client id - elif self.batch_config.managed_identity_client_id is not None: - logger.debug("Using managed identity batch authentication...") - creds = DefaultAzureCredential( - managed_identity_client_id=self.batch_config.managed_identity_client_id - ) - creds = AzureIdentityCredentialAdapter( - credential=creds, resource_id=AZURE_BATCH_RESOURCE_ENDPOINT - ) - - self.batch_client = BatchServiceClient( - creds, batch_url=self.batch_config.batch_account_url - ) - - if self.batch_config.managed_identity_resource_id is not None: - self.batch_mgmt_client = BatchManagementClient( - credential=DefaultAzureCredential( - managed_identity_client_id=self.batch_config.managed_identity_client_id - ), - subscription_id=self.batch_config.subscription_id, - ) - - try: - self.create_batch_pool() - except WorkflowError: - logger.debug("Error: Failed to create batch pool, shutting down.") - self.shutdown() - - try: - self.create_batch_job() - except WorkflowError: - logger.debug("Error: Failed to create batch job, shutting down.") - self.shutdown() - - def shutdown(self): - # perform additional steps on shutdown - # if necessary (jobs were cancelled already) - - logger.debug("Deleting AzBatch job") - self.batch_client.job.delete(self.job_id) - - logger.debug("Deleting AzBatch pool") - self.batch_client.pool.delete(self.pool_id) - - logger.debug("Deleting workflow sources from blob") - - self.azblob_helper.delete_from_container( - self.prefix_container, self.resource_file.file_path - ) - - super().shutdown() - - def cancel(self): - for task in self.batch_client.task.list(self.job_id): - # strictly not need as job deletion also deletes task - self.batch_client.task.terminate(self.job_id, task.id) - self.shutdown() - - # mask_dict_vals masks sensitive keys from a dictionary of values for - # logging used to mask dicts with sensitive information from logging - @staticmethod - def mask_dict_vals(mdict: dict, keys: list): - ret_dict = mdict.copy() - for k in keys: - if k in ret_dict.keys() and ret_dict[k] is not None: - ret_dict[k] = 10 * "*" - return ret_dict - - # mask blob url 
is used to mask url values that may contain SAS - # token information from being printed to the logs - def mask_sas_urls(self, attrs: dict): - attrs_new = attrs.copy() - sas_pattern = r"\?[^=]+=([^?'\"]+)" - mask = 10 * "*" - - for k, value in attrs.items(): - if value is not None and re.search(sas_pattern, str(value)): - attrs_new[k] = re.sub(sas_pattern, mask, value) - - return attrs_new - - def mask_batch_config_as_string(self) -> str: - masked_keys = self.mask_dict_vals( - self.batch_config.__dict__, - [ - "batch_account_key", - "managed_identity_client_id", - ], - ) - masked_urls = self.mask_sas_urls(masked_keys) - return pformat(masked_urls, indent=2) - - def run( - self, - job: ExecutorJobInterface, - callback=None, - submit_callback=None, - error_callback=None, - ): - import azure.batch._batch_service_client as batch - import azure.batch.models as batchmodels - - super()._run(job) - - envsettings = [] - for key in self.envvars: - try: - envsettings.append( - batchmodels.EnvironmentSetting(name=key, value=os.environ[key]) - ) - except KeyError: - continue - - exec_job = self.format_job_exec(job) - exec_job = f"/bin/bash -c 'tar xzf {self.resource_file.file_path} && {shlex.quote(exec_job)}'" - - # A string that uniquely identifies the Task within the Job. - task_uuid = str(uuid.uuid1()) - task_id = f"{job.rule.name}-{task_uuid}" - - # This is the admin user who runs the command inside the container. - user = batchmodels.AutoUserSpecification( - scope=batchmodels.AutoUserScope.pool, - elevation_level=batchmodels.ElevationLevel.admin, - ) - - # This is the docker image we want to run - task_container_settings = batchmodels.TaskContainerSettings( - image_name=self.container_image, container_run_options="--rm" - ) - - # https://docs.microsoft.com/en-us/python/api/azure-batch/azure.batch.models.taskaddparameter?view=azure-python - # all directories recursively below the AZ_BATCH_NODE_ROOT_DIR (the root of Azure Batch directories on the node) - # are mapped into the container, all Task environment variables are mapped into the container, - # and the Task command line is executed in the container - task = batch.models.TaskAddParameter( - id=task_id, - command_line=exec_job, - container_settings=task_container_settings, - resource_files=[self.resource_file], # Snakefile, envs, yml files etc. - user_identity=batchmodels.UserIdentity(auto_user=user), - environment_settings=envsettings, - ) - - # register job as active, using your own namedtuple. - self.batch_client.task.add(self.job_id, task) - self.active_jobs.append( - AzBatchJob(job, self.job_id, task_id, callback, error_callback) - ) - logger.debug(f"Added AzBatch task {task_id}") - logger.debug( - f"Added AzBatch task {pformat(self.mask_sas_urls(task.__dict__), indent=2)}" - ) - - # from https://github.com/Azure-Samples/batch-python-quickstart/blob/master/src/python_quickstart_client.py - @staticmethod - def _read_stream_as_string(stream, encoding): - """Read stream as string - :param stream: input stream generator - :param str encoding: The encoding of the file. The default is utf-8. - :return: The file content. 
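The masking helpers above exist so that SAS tokens never reach the logs. A self-contained sketch of the same idea, reusing the query-string pattern from mask_sas_urls; the example URL and signature are made up:

import re

SAS_PATTERN = r"\?[^=]+=([^?'\"]+)"
MASK = 10 * "*"


def mask_sas_urls(attrs: dict) -> dict:
    # Replace anything that looks like a SAS query string with a fixed mask.
    masked = attrs.copy()
    for key, value in attrs.items():
        if value is not None and re.search(SAS_PATTERN, str(value)):
            masked[key] = re.sub(SAS_PATTERN, MASK, str(value))
    return masked


print(mask_sas_urls({"url": "https://acct.blob.core.windows.net/c?sv=2021&sig=abc"}))
# -> {'url': 'https://acct.blob.core.windows.net/c**********'}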
- :rtype: str - """ - output = io.BytesIO() - try: - for data in stream: - output.write(data) - if encoding is None: - encoding = "utf-8" - return output.getvalue().decode(encoding) - finally: - output.close() - - # adopted from https://github.com/Azure-Samples/batch-python-quickstart/blob/master/src/python_quickstart_client.py - def _get_task_output(self, job_id, task_id, stdout_or_stderr, encoding=None): - assert stdout_or_stderr in ["stdout", "stderr"] - fname = stdout_or_stderr + ".txt" - try: - stream = self.batch_client.file.get_from_task(job_id, task_id, fname) - content = self._read_stream_as_string(stream, encoding) - except Exception: - content = "" - - return content - - async def _wait_for_jobs(self): - import azure.batch.models as batchmodels - - while True: - # always use self.lock to avoid race conditions - async with async_lock(self.lock): - if not self.wait: - return - active_jobs = self.active_jobs - self.active_jobs = list() - still_running = list() - - # Loop through active jobs and act on status - for batch_job in active_jobs: - async with self.status_rate_limiter: - logger.debug(f"Monitoring {len(active_jobs)} active AzBatch tasks") - task = self.batch_client.task.get(self.job_id, batch_job.task_id) - - if task.state == batchmodels.TaskState.completed: - dt = ( - task.execution_info.end_time - - task.execution_info.start_time - ) - rc = task.execution_info.exit_code - rt = task.execution_info.retry_count - stderr = self._get_task_output( - self.job_id, batch_job.task_id, "stderr" - ) - stdout = self._get_task_output( - self.job_id, batch_job.task_id, "stdout" - ) - logger.debug( - "task {} completed: result={} exit_code={}\n".format( - batch_job.task_id, task.execution_info.result, rc - ) - ) - logger.debug( - "task {} completed: run_time={}, retry_count={}\n".format( - batch_job.task_id, str(dt), rt - ) - ) - - def print_output(): - logger.debug( - "task {}: stderr='{}'\n".format( - batch_job.task_id, stderr - ) - ) - logger.debug( - "task {}: stdout='{}'\n".format( - batch_job.task_id, stdout - ) - ) - - if ( - task.execution_info.result - == batchmodels.TaskExecutionResult.failure - ): - logger.error( - f"Azure task failed: code={str(task.execution_info.failure_info.code)}, message={str(task.execution_info.failure_info.message)}" - ) - for d in task.execution_info.failure_info.details: - logger.error(f"Error Details: {str(d)}") - print_output() - batch_job.error_callback(batch_job.job) - elif ( - task.execution_info.result - == batchmodels.TaskExecutionResult.success - ): - batch_job.callback(batch_job.job) - else: - logger.error( - "Unknown Azure task execution result: {}".format( - task.execution_info.result - ) - ) - print_output() - batch_job.error_callback(batch_job.job) - - # The operation is still running - else: - logger.debug( - f"task {batch_job.task_id}: creation_time={task.creation_time} state={task.state} node_info={task.node_info}\n" - ) - still_running.append(batch_job) - - # fail if start task fails on a node or node state becomes unusable - # and stream stderr stdout to stream - node_list = self.batch_client.compute_node.list(self.pool_id) - for n in node_list: - # error on unusable node (this occurs if your container image fails to pull) - if n.state == "unusable": - if n.errors is not None: - for e in n.errors: - logger.error( - f"Azure task error: {e.message}, {e.error_details[0].__dict__}" - ) - logger.error( - "A node entered an unusable state, quitting." 
- ) - return - - if n.start_task_info is not None and ( - n.start_task_info.result - == batchmodels.TaskExecutionResult.failure - ): - try: - stderr_file = ( - self.batch_client.file.get_from_compute_node( - self.pool_id, n.id, "/startup/stderr.txt" - ) - ) - stderr_stream = self._read_stream_as_string( - stderr_file, "utf-8" - ) - except Exception: - stderr_stream = "" - - try: - stdout_file = ( - self.batch_client.file.get_from_compute_node( - self.pool_id, n.id, "/startup/stdout.txt" - ) - ) - stdout_stream = self._read_stream_as_string( - stdout_file, "utf-8" - ) - except Exception: - stdout_stream = "" - - logger.error( - "Azure start task execution failed on node: {}.\nSTART_TASK_STDERR:{}\nSTART_TASK_STDOUT: {}".format( - n.start_task_info.failure_info.message, - stdout_stream, - stderr_stream, - ) - ) - return - - async with async_lock(self.lock): - self.active_jobs.extend(still_running) - await sleep() - - def create_batch_pool(self): - """Creates a pool of compute nodes""" - - import azure.batch._batch_service_client as bsc - import azure.batch.models as batchmodels - import azure.mgmt.batch.models as mgmtbatchmodels - - image_ref = bsc.models.ImageReference( - publisher=self.batch_config.batch_pool_image_publisher, - offer=self.batch_config.batch_pool_image_offer, - sku=self.batch_config.batch_pool_image_sku, - version="latest", - ) - - # optional subnet network configuration - # requires AAD batch auth insead of batch key auth - network_config = None - if self.batch_config.batch_pool_subnet_id is not None: - network_config = batchmodels.NetworkConfiguration( - subnet_id=self.batch_config.batch_pool_subnet_id - ) - - # configure a container registry - - # Specify container configuration, fetching an image - # https://docs.microsoft.com/en-us/azure/batch/batch-docker-container-workloads#prefetch-images-for-container-configuration - container_config = batchmodels.ContainerConfiguration( - type="dockerCompatible", container_image_names=[self.container_image] - ) - - user = None - passw = None - identity_ref = None - registry_conf = None - - if self.batch_config.container_registry_url is not None: - if ( - self.batch_config.container_registry_user is not None - and self.batch_config.container_registry_pass is not None - ): - user = self.batch_config.container_registry_user - passw = self.batch_config.container_registry_pass - elif self.batch_config.managed_identity_resource_id is not None: - identity_ref = batchmodels.ComputeNodeIdentityReference( - resource_id=self.batch_config.managed_identity_resource_id - ) - else: - raise WorkflowError( - "No container registry authentication scheme set. Please set the BATCH_CONTAINER_REGISTRY_USER and BATCH_CONTAINER_REGISTRY_PASS or set MANAGED_IDENTITY_CLIENT_ID and MANAGED_IDENTITY_RESOURCE_ID." 
- ) - - registry_conf = [ - batchmodels.ContainerRegistry( - registry_server=self.batch_config.container_registry_url, - identity_reference=identity_ref, - user_name=str(user), - password=str(passw), - ) - ] - - # Specify container configuration, fetching an image - # https://docs.microsoft.com/en-us/azure/batch/batch-docker-container-workloads#prefetch-images-for-container-configuration - container_config = batchmodels.ContainerConfiguration( - type="dockerCompatible", - container_image_names=[self.container_image], - container_registries=registry_conf, - ) - - # default to no start task - start_task = None - - # if configured us start task bash script from sas url - if self.batch_config.batch_node_start_task_sasurl is not None: - _SIMPLE_TASK_NAME = "start_task.sh" - start_task_admin = batchmodels.UserIdentity( - auto_user=batchmodels.AutoUserSpecification( - elevation_level=batchmodels.ElevationLevel.admin, - scope=batchmodels.AutoUserScope.pool, - ) - ) - start_task = batchmodels.StartTask( - command_line=f"bash {_SIMPLE_TASK_NAME}", - resource_files=[ - batchmodels.ResourceFile( - file_path=_SIMPLE_TASK_NAME, - http_url=self.batch_config.batch_node_start_task_sasurl, - ) - ], - user_identity=start_task_admin, - ) - - # autoscale requires the initial dedicated node count to be zero - if self.az_batch_enable_autoscale: - self.batch_config.batch_pool_node_count = 0 - - node_communication_strategy = None - if self.batch_config.batch_node_communication_mode is not None: - node_communication_strategy = batchmodels.NodeCommunicationMode.simplified - - new_pool = batchmodels.PoolAddParameter( - id=self.pool_id, - virtual_machine_configuration=batchmodels.VirtualMachineConfiguration( - image_reference=image_ref, - container_configuration=container_config, - node_agent_sku_id=self.batch_config.batch_pool_vm_node_agent_sku_id, - ), - network_configuration=network_config, - vm_size=self.batch_config.batch_pool_vm_size, - target_dedicated_nodes=self.batch_config.batch_pool_node_count, - target_node_communication_mode=node_communication_strategy, - target_low_priority_nodes=0, - start_task=start_task, - task_slots_per_node=self.batch_config.batch_tasks_per_node, - task_scheduling_policy=batchmodels.TaskSchedulingPolicy( - node_fill_type=self.batch_config.batch_node_fill_type - ), - ) - - # create pool if not exists - try: - logger.debug(f"Creating pool: {self.pool_id}") - self.batch_client.pool.add(new_pool) - - if self.az_batch_enable_autoscale: - # define the autoscale formula - formula = """$samples = $PendingTasks.GetSamplePercent(TimeInterval_Minute * 5); - $tasks = $samples < 70 ? max(0,$PendingTasks.GetSample(1)) : max( $PendingTasks.GetSample(1), avg($PendingTasks.GetSample(TimeInterval_Minute * 5))); - $targetVMs = $tasks > 0? 
$tasks:max(0, $TargetDedicatedNodes/2); - $TargetDedicatedNodes = max(0, min($targetVMs, 10)); - $NodeDeallocationOption = taskcompletion;""" - - # Enable autoscale; specify the formula - self.batch_client.pool.enable_auto_scale( - self.pool_id, - auto_scale_formula=formula, - # the minimum allowed autoscale interval is 5 minutes - auto_scale_evaluation_interval=datetime.timedelta(minutes=5), - pool_enable_auto_scale_options=None, - custom_headers=None, - raw=False, - ) - - # update pool with managed identity, enables batch nodes to act as managed identity - if self.batch_config.managed_identity_resource_id is not None: - mid = mgmtbatchmodels.BatchPoolIdentity( - type=mgmtbatchmodels.PoolIdentityType.user_assigned, - user_assigned_identities={ - self.batch_config.managed_identity_resource_id: mgmtbatchmodels.UserAssignedIdentities() - }, - ) - params = mgmtbatchmodels.Pool(identity=mid) - self.batch_mgmt_client.pool.update( - resource_group_name=self.batch_config.resource_group, - account_name=self.batch_config.batch_account_name, - pool_name=self.pool_id, - parameters=params, - ) - - except batchmodels.BatchErrorException as err: - if err.error.code != "PoolExists": - raise WorkflowError( - f"Error: Failed to create pool: {err.error.message}" - ) - else: - logger.debug(f"Pool {self.pool_id} exists.") - - def create_batch_job(self): - """Creates a job with the specified ID, associated with the specified pool""" - import azure.batch._batch_service_client as bsc - - logger.debug(f"Creating job {self.job_id}") - - self.batch_client.job.add( - bsc.models.JobAddParameter( - id=self.job_id, - constraints=bsc.models.JobConstraints(max_task_retry_count=0), - pool_info=bsc.models.PoolInformation(pool_id=self.pool_id), - ) - ) - - @staticmethod - def validate_az_blob_credential_is_sas(): - """ - Validates that the AZ_BLOB_CREDENTIAL is a valid storage account SAS - token, required when using --az-batch with AzBlob remote. - """ - cred = os.environ.get("AZ_BLOB_CREDENTIAL") - if cred is not None: - # regex pattern for Storage Account SAS Token - rgx = r"\?sv=.*&ss=.*&srt=.*&sp=.*&se=.*&st=.*&spr=.*&sig=.*" - if re.compile(rgx).match(cred) is None: - raise WorkflowError( - "AZ_BLOB_CREDENTIAL is not a valid storage account SAS token." - ) - - # from google_lifesciences.py - def _set_workflow_sources(self): - """We only add files from the working directory that are config related - (e.g., the Snakefile or a config.yml equivalent), or checked into git. - """ - self.workflow_sources = [] - - for wfs in self.dag.get_sources(): - if os.path.isdir(wfs): - for dirpath, dirnames, filenames in os.walk(wfs): - self.workflow_sources.extend( - [check_source_size(os.path.join(dirpath, f)) for f in filenames] - ) - else: - self.workflow_sources.append(check_source_size(os.path.abspath(wfs))) - - # from google_lifesciences.py - def _generate_build_source_package(self): - """in order for the instance to access the working directory in storage, - we need to upload it. This file is cleaned up at the end of the run. - We do this, and then obtain from the instance and extract. - """ - # Workflow sources for cloud executor must all be under same workdir root - for filename in self.workflow_sources: - if self.workdir not in filename: - raise WorkflowError( - "All source files must be present in the working directory, " - "{workdir} to be uploaded to a build package that respects " - "relative paths, but {filename} was found outside of this " - "directory. 
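validate_az_blob_credential_is_sas above is just a regular-expression gate on the AZ_BLOB_CREDENTIAL value. The same check as a standalone function; the token below is a fake placeholder, not a real credential:

import re

ACCOUNT_SAS_RGX = r"\?sv=.*&ss=.*&srt=.*&sp=.*&se=.*&st=.*&spr=.*&sig=.*"


def looks_like_account_sas(credential: str) -> bool:
    # Account-level SAS tokens carry sv/ss/srt/sp/se/st/spr/sig query parameters in this order.
    return re.compile(ACCOUNT_SAS_RGX).match(credential) is not None


fake = "?sv=2021-08-06&ss=b&srt=co&sp=rwdlac&se=2023-01-01&st=2022-01-01&spr=https&sig=XXXX"
assert looks_like_account_sas(fake)
assert not looks_like_account_sas("not-a-sas-token")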
Please set your working directory accordingly, " - "and the path of your Snakefile to be relative to it.".format( - workdir=self.workdir, filename=filename - ) - ) - - # We will generate a tar.gz package, renamed by hash - tmpname = next(tempfile._get_candidate_names()) - targz = os.path.join(tempfile.gettempdir(), f"snakemake-{tmpname}.tar.gz") - tar = tarfile.open(targz, "w:gz") - - # Add all workflow_sources files - for filename in self.workflow_sources: - arcname = filename.replace(self.workdir + os.path.sep, "") - tar.add(filename, arcname=arcname) - logger.debug( - f"Created {targz} with the following contents: {self.workflow_sources}" - ) - tar.close() - - # Rename based on hash, in case user wants to save cache - sha256 = get_file_hash(targz) - hash_tar = os.path.join( - self.workflow.persistence.aux_path, f"workdir-{sha256}.tar.gz" - ) - - # Only copy if we don't have it yet, clean up if we do - if not os.path.exists(hash_tar): - shutil.move(targz, hash_tar) - else: - os.remove(targz) - - # We will clean these all up at shutdown - self._build_packages.add(hash_tar) - - return hash_tar - - def _upload_build_source_package(self, targz, resource_prefix=""): - """given a .tar.gz created for a workflow, upload it to the blob - storage account, only if the blob doesn't already exist. - """ - - import azure.batch.models as batchmodels - - blob_name = os.path.join(resource_prefix, os.path.basename(targz)) - - # upload blob to storage using storage helper - bc = self.azblob_helper.upload_to_azure_storage( - self.prefix_container, targz, blob_name=blob_name - ) - - # return resource file - return batchmodels.ResourceFile(http_url=bc.url, file_path=blob_name) diff --git a/snakemake/executors/dryrun.py b/snakemake/executors/dryrun.py new file mode 100644 index 000000000..973b6a008 --- /dev/null +++ b/snakemake/executors/dryrun.py @@ -0,0 +1,77 @@ +__author__ = "Johannes Köster" +__copyright__ = "Copyright 2023, Johannes Köster" +__email__ = "johannes.koester@uni-due.de" +__license__ = "MIT" + +from snakemake_interface_executor_plugins.executors.base import AbstractExecutor +from snakemake_interface_executor_plugins.jobs import ( + ExecutorJobInterface, +) +from snakemake_interface_executor_plugins import CommonSettings +from snakemake_interface_executor_plugins.executors.base import SubmittedJobInfo + +from snakemake.logging import logger + + +common_settings = CommonSettings( + non_local_exec=False, + dryrun_exec=True, + implies_no_shared_fs=False, +) + + +class Executor(AbstractExecutor): + def run_job( + self, + job: ExecutorJobInterface, + ): + job_info = SubmittedJobInfo(job=job) + self.report_job_submission(job_info) + self.report_job_success(job_info) + + def get_exec_mode(self): + raise NotImplementedError() + + def printjob(self, job: ExecutorJobInterface): + super().printjob(job) + if job.is_group(): + for j in job.jobs: + self.printcache(j) + else: + self.printcache(job) + + def printcache(self, job: ExecutorJobInterface): + cache_mode = self.workflow.get_cache_mode(job.rule) + if cache_mode: + if self.workflow.output_file_cache.exists(job, cache_mode): + logger.info( + "Output file {} will be obtained from global between-workflow cache.".format( + job.output[0] + ) + ) + else: + logger.info( + "Output file {} will be written to global between-workflow cache.".format( + job.output[0] + ) + ) + + def cancel(self): + # nothing to do + pass + + def shutdown(self): + # nothing to do + pass + + def handle_job_success(self, job: ExecutorJobInterface): + # nothing to do + pass + + def 
handle_job_error(self, job: ExecutorJobInterface): + # nothing to do + pass + + @property + def cores(self): + return self.workflow.resource_settings.cores diff --git a/snakemake/executors/flux.py b/snakemake/executors/flux.py index 6b0ec1575..d5f011bf9 100644 --- a/snakemake/executors/flux.py +++ b/snakemake/executors/flux.py @@ -12,7 +12,6 @@ from snakemake_interface_executor_plugins.workflow import WorkflowExecutorInterface from snakemake_interface_executor_plugins.utils import sleep from snakemake_interface_executor_plugins.executors.remote import RemoteExecutor -from snakemake_interface_executor_plugins.persistence import StatsExecutorInterface from snakemake_interface_executor_plugins.logging import LoggerExecutorInterface from snakemake.exceptions import WorkflowError @@ -42,22 +41,22 @@ def __init__( self, workflow: WorkflowExecutorInterface, dag: DAGExecutorInterface, - stats: StatsExecutorInterface, logger: LoggerExecutorInterface, jobname="snakejob.{name}.{jobid}.sh", ): super().__init__( workflow, dag, - stats, logger, None, jobname=jobname, max_status_checks_per_second=10, + pass_envvar_declarations_to_cmd=True, ) # Attach variables for easy access self.workdir = os.path.realpath(os.path.dirname(self.workflow.persistence.path)) + # TODO unused self.envvars = list(self.workflow.envvars) or [] # Quit early if we can't access the flux api @@ -83,7 +82,7 @@ def _set_job_resources(self, job: ExecutorJobInterface): including default regions and the virtual machine configuration """ self.default_resources = DefaultResources( - from_other=self.workflow.default_resources + from_other=self.workflow.resource_settings.default_resources ) def get_snakefile(self): diff --git a/snakemake/executors/ga4gh_tes.py b/snakemake/executors/ga4gh_tes.py deleted file mode 100644 index 5588f303f..000000000 --- a/snakemake/executors/ga4gh_tes.py +++ /dev/null @@ -1,323 +0,0 @@ -__author__ = "Sven Twardziok, Alex Kanitz, Valentin Schneider-Lunitz, Johannes Köster" -__copyright__ = "Copyright 2022, Johannes Köster" -__email__ = "johannes.koester@uni-due.de" -__license__ = "MIT" - -import asyncio -import math -import os -from collections import namedtuple - -from snakemake_interface_executor_plugins.dag import DAGExecutorInterface -from snakemake_interface_executor_plugins.jobs import ExecutorJobInterface -from snakemake_interface_executor_plugins.workflow import WorkflowExecutorInterface -from snakemake_interface_executor_plugins.utils import sleep -from snakemake_interface_executor_plugins.executors.remote import RemoteExecutor -from snakemake_interface_executor_plugins.persistence import StatsExecutorInterface -from snakemake_interface_executor_plugins.logging import LoggerExecutorInterface - -from snakemake.logging import logger -from snakemake.exceptions import WorkflowError -from snakemake.common import get_container_image, async_lock - -TaskExecutionServiceJob = namedtuple( - "TaskExecutionServiceJob", "job jobid callback error_callback" -) - - -class TaskExecutionServiceExecutor(RemoteExecutor): - def __init__( - self, - workflow: WorkflowExecutorInterface, - dag: DAGExecutorInterface, - stats: StatsExecutorInterface, - logger: LoggerExecutorInterface, - jobname="snakejob.{name}.{jobid}.sh", - max_status_checks_per_second=0.5, - tes_url=None, - container_image=None, - ): - super().__init__( - workflow, - dag, - stats, - logger, - None, - jobname=jobname, - max_status_checks_per_second=max_status_checks_per_second, - ) - try: - import tes - except ImportError: - raise WorkflowError( - "Unable to 
import Python package tes. TES backend requires py-tes to be installed. Please install py-tes, e.g. via Conda or Pip." - ) - - self.container_image = container_image or get_container_image() - logger.info(f"Using {self.container_image} for TES jobs.") - self.container_workdir = "/tmp" - self.max_status_checks_per_second = max_status_checks_per_second - self.tes_url = tes_url - self.tes_client = tes.HTTPClient( - url=self.tes_url, - token=os.environ.get("TES_TOKEN"), - user=os.environ.get("FUNNEL_SERVER_USER"), - password=os.environ.get("FUNNEL_SERVER_PASSWORD"), - ) - logger.info(f"[TES] Job execution on TES: {self.tes_url}") - - def get_job_exec_prefix(self, job: ExecutorJobInterface): - return "mkdir /tmp/conda && cd /tmp" - - def shutdown(self): - # perform additional steps on shutdown if necessary - super().shutdown() - - def cancel(self): - for job in self.active_jobs: - try: - self.tes_client.cancel_task(job.jobid) - logger.info(f"[TES] Task canceled: {job.jobid}") - except Exception: - logger.info( - "[TES] Canceling task failed. This may be because the job is " - "already in a terminal state." - ) - self.shutdown() - - def run( - self, - job: ExecutorJobInterface, - callback=None, - submit_callback=None, - error_callback=None, - ): - super()._run(job) - - jobscript = self.get_jobscript(job) - self.write_jobscript(job, jobscript) - - # submit job here, and obtain job ids from the backend - try: - task = self._get_task(job, jobscript) - tes_id = self.tes_client.create_task(task) - logger.info(f"[TES] Task submitted: {tes_id}") - except Exception as e: - raise WorkflowError(str(e)) - - self.active_jobs.append( - TaskExecutionServiceJob(job, tes_id, callback, error_callback) - ) - - async def _wait_for_jobs(self): - UNFINISHED_STATES = ["UNKNOWN", "INITIALIZING", "QUEUED", "RUNNING", "PAUSED"] - ERROR_STATES = [ - "EXECUTOR_ERROR", - "SYSTEM_ERROR", - "CANCELED", # TODO: really call `error_callback` on this? - ] - - while True: - async with async_lock(self.lock): - if not self.wait: - return - active_jobs = self.active_jobs - self.active_jobs = list() - still_running = list() - - for j in active_jobs: - async with self.status_rate_limiter: # TODO: this doesn't seem to do anything? 
- res = self.tes_client.get_task(j.jobid, view="MINIMAL") - logger.debug( - "[TES] State of task '{id}': {state}".format( - id=j.jobid, state=res.state - ) - ) - if res.state in UNFINISHED_STATES: - still_running.append(j) - elif res.state in ERROR_STATES: - logger.info(f"[TES] Task errored: {j.jobid}") - j.error_callback(j.job) - elif res.state == "COMPLETE": - logger.info(f"[TES] Task completed: {j.jobid}") - j.callback(j.job) - - async with async_lock(self.lock): - self.active_jobs.extend(still_running) - await asyncio.sleep(1 / self.max_status_checks_per_second) - - def _check_file_in_dir(self, checkdir, f): - if checkdir: - checkdir = checkdir.rstrip("/") - if not f.startswith(checkdir): - direrrmsg = ( - "All files including Snakefile, " - + "conda env files, rule script files, output files " - + "must be in the same working directory: {} vs {}" - ) - raise WorkflowError(direrrmsg.format(checkdir, f)) - - def _get_members_path(self, overwrite_path, f): - if overwrite_path: - members_path = overwrite_path - else: - members_path = os.path.join(self.container_workdir, str(os.path.relpath(f))) - return members_path - - def _prepare_file( - self, - filename, - overwrite_path=None, - checkdir=None, - pass_content=False, - type="Input", - ): - import tes - - # TODO: handle FTP files - max_file_size = 131072 - if type not in ["Input", "Output"]: - raise ValueError("Value for 'model' has to be either 'Input' or 'Output'.") - - members = {} - - # Handle remote files - if hasattr(filename, "is_remote") and filename.is_remote: - return None - - # Handle local files - else: - f = os.path.abspath(filename) - - self._check_file_in_dir(checkdir, f) - - members["path"] = self._get_members_path(overwrite_path, f) - - members["url"] = "file://" + f - if pass_content: - source_file_size = os.path.getsize(f) - if source_file_size > max_file_size: - logger.warning( - "Will not pass file '{f}' by content, as it exceeds the " - "minimum supported file size of {max_file_size} bytes " - "defined in the TES specification. 
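_prepare_file above inlines small local files into the TES task and references anything larger by URL. A rough pure-Python sketch of that decision without the py-tes models; the returned dict only approximates a tes Input, and 131072 bytes is the limit quoted above:

import os

MAX_INLINE_SIZE = 131072  # bytes; the TES content-size limit quoted above


def prepare_input(path: str) -> dict:
    # Roughly what a TES Input carries: a container path plus either inline content or a URL.
    path = os.path.abspath(path)
    if os.path.getsize(path) <= MAX_INLINE_SIZE:
        with open(path) as stream:
            return {"path": path, "content": stream.read(), "url": None}
    # Too large to pass by content; reference it by URL instead.
    return {"path": path, "content": None, "url": "file://" + path}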
Will try to upload " - "file instead.".format(f=f, max_file_size=max_file_size) - ) - else: - with open(f) as stream: - members["content"] = stream.read() - members["url"] = None - - model = getattr(tes.models, type) - logger.warning(members) - return model(**members) - - def _get_task_description(self, job: ExecutorJobInterface): - description = "" - if job.is_group(): - msgs = [i.message for i in job.jobs if i.message] - if msgs: - description = " & ".join(msgs) - else: - if job.message: - description = job.message - - return description - - def _get_task_inputs(self, job: ExecutorJobInterface, jobscript, checkdir): - inputs = [] - - # add workflow sources to inputs - for src in self.dag.get_sources(): - # exclude missing, hidden, empty and build files - if ( - not os.path.exists(src) - or os.path.basename(src).startswith(".") - or os.path.getsize(src) == 0 - or src.endswith(".pyc") - ): - continue - inputs.append( - self._prepare_file(filename=src, checkdir=checkdir, pass_content=True) - ) - - # add input files to inputs - for i in job.input: - obj = self._prepare_file(filename=i, checkdir=checkdir) - if obj: - inputs.append(obj) - - # add jobscript to inputs - inputs.append( - self._prepare_file( - filename=jobscript, - overwrite_path=os.path.join(self.container_workdir, "run_snakemake.sh"), - checkdir=checkdir, - pass_content=True, - ) - ) - - return inputs - - def _append_task_outputs(self, outputs, files, checkdir): - for file in files: - obj = self._prepare_file(filename=file, checkdir=checkdir, type="Output") - if obj: - outputs.append(obj) - return outputs - - def _get_task_outputs(self, job: ExecutorJobInterface, checkdir): - outputs = [] - # add output files to outputs - outputs = self._append_task_outputs(outputs, job.output, checkdir) - - # add log files to outputs - if job.log: - outputs = self._append_task_outputs(outputs, job.log, checkdir) - - # add benchmark files to outputs - if hasattr(job, "benchmark") and job.benchmark: - outputs = self._append_task_outputs(outputs, job.benchmark, checkdir) - - return outputs - - def _get_task_executors(self): - import tes - - executors = [] - executors.append( - tes.models.Executor( - image=self.container_image, - command=[ # TODO: info about what is executed is opaque - "/bin/bash", - os.path.join(self.container_workdir, "run_snakemake.sh"), - ], - workdir=self.container_workdir, - ) - ) - return executors - - def _get_task(self, job: ExecutorJobInterface, jobscript): - import tes - - checkdir, _ = os.path.split(self.snakefile) - - task = {} - task["name"] = job.format_wildcards(self.jobname) - task["description"] = self._get_task_description(job) - task["inputs"] = self._get_task_inputs(job, jobscript, checkdir) - task["outputs"] = self._get_task_outputs(job, checkdir) - task["executors"] = self._get_task_executors() - task["resources"] = tes.models.Resources() - - # define resources - if job.resources.get("_cores") is not None: - task["resources"].cpu_cores = job.resources["_cores"] - if job.resources.get("mem_mb") is not None: - task["resources"].ram_gb = math.ceil(job.resources["mem_mb"] / 1000) - if job.resources.get("disk_mb") is not None: - task["resources"].disk_gb = math.ceil(job.resources["disk_mb"] / 1000) - - tes_task = tes.Task(**task) - logger.debug(f"[TES] Built task: {tes_task}") - return tes_task diff --git a/snakemake/executors/google_lifesciences.py b/snakemake/executors/google_lifesciences.py deleted file mode 100644 index 9b3bf4185..000000000 --- a/snakemake/executors/google_lifesciences.py +++ /dev/null @@ 
-1,1054 +0,0 @@ -__author__ = "Johannes Köster" -__copyright__ = "Copyright 2022, Johannes Köster" -__email__ = "johannes.koester@uni-due.de" -__license__ = "MIT" - -import logging -import os -import time -import shutil -import tarfile - -import tempfile -from collections import namedtuple -import uuid -import re -import math - -from snakemake_interface_executor_plugins.dag import DAGExecutorInterface -from snakemake_interface_executor_plugins.jobs import ExecutorJobInterface -from snakemake_interface_executor_plugins.workflow import WorkflowExecutorInterface -from snakemake_interface_executor_plugins.executors.remote import RemoteExecutor -from snakemake_interface_executor_plugins.persistence import StatsExecutorInterface -from snakemake_interface_executor_plugins.logging import LoggerExecutorInterface - -from snakemake.logging import logger -from snakemake.exceptions import print_exception -from snakemake.exceptions import log_verbose_traceback -from snakemake.exceptions import WorkflowError -from snakemake.common import bytesto, get_container_image, get_file_hash, async_lock -from snakemake.resources import DefaultResources - - -# https://github.com/googleapis/google-api-python-client/issues/299#issuecomment-343255309 -logging.getLogger("googleapiclient.discovery_cache").setLevel(logging.ERROR) - -GoogleLifeSciencesJob = namedtuple( - "GoogleLifeSciencesJob", "job jobname jobid callback error_callback" -) - - -def check_source_size(filename, warning_size_gb=0.2): - """A helper function to check the filesize, and return the file - to the calling function Additionally, given that we encourage these - packages to be small, we set a warning at 200MB (0.2GB). - """ - gb = bytesto(os.stat(filename).st_size, "g") - if gb > warning_size_gb: - logger.warning( - f"File {filename} (size {gb} GB) is greater than the {warning_size_gb} GB " - f"suggested size. Consider uploading larger files to storage first." - ) - return filename - - -class GoogleLifeSciencesExecutor(RemoteExecutor): - """ - The GoogleLifeSciences executor uses Google Cloud Storage, and - Compute Engine paired with the Google Life Sciences API. 
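check_source_size above only warns, it never blocks the upload. A stdlib-only sketch of the same check, with bytesto replaced by an explicit conversion (an assumption made for illustration):

import os
import warnings


def check_source_size(filename: str, warning_size_gb: float = 0.2) -> str:
    # Warn when a workflow source file exceeds the suggested package size.
    gb = os.stat(filename).st_size / (1024 ** 3)
    if gb > warning_size_gb:
        warnings.warn(
            f"File {filename} ({gb:.2f} GB) exceeds the {warning_size_gb} GB suggested size; "
            "consider uploading large files to storage first."
        )
    return filename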
- https://cloud.google.com/life-sciences/docs/quickstart - """ - - def __init__( - self, - workflow: WorkflowExecutorInterface, - dag: DAGExecutorInterface, - stats: StatsExecutorInterface, - logger: LoggerExecutorInterface, - jobname="snakejob.{name}.{jobid}.sh", - container_image=None, - regions=None, - location=None, - cache=False, - service_account_email=None, - network=None, - subnetwork=None, - max_status_checks_per_second=10, - preemption_default=None, - preemptible_rules=None, - ): - super().__init__( - workflow, - dag, - stats, - logger, - None, - jobname=jobname, - max_status_checks_per_second=max_status_checks_per_second, - ) - # Prepare workflow sources for build package - self._set_workflow_sources() - - # Attach variables for easy access - self.quiet = workflow.quiet - self.workdir = os.path.realpath(os.path.dirname(self.workflow.persistence.path)) - self._save_storage_cache = cache - - # Set preemptible instances - self._set_preemptible_rules(preemption_default, preemptible_rules) - - # IMPORTANT: using Compute Engine API and not k8s == no support for secrets - self.envvars = list(self.workflow.envvars) or [] - - # Quit early if we can't authenticate - self._get_services() - self._get_bucket() - - # Akin to Kubernetes, create a run namespace, default container image - self.run_namespace = str(uuid.uuid4()) - self.container_image = container_image or get_container_image() - logger.info(f"Using {self.container_image} for Google Life Science jobs.") - self.regions = regions or ["us-east1", "us-west1", "us-central1"] - - # The project name is required, either from client or environment - self.project = ( - os.environ.get("GOOGLE_CLOUD_PROJECT") or self._bucket_service.project - ) - # Determine API location based on user preference, and then regions - self._set_location(location) - # Tell the user right away the regions, location, and container - logger.debug("regions=%s" % self.regions) - logger.debug("location=%s" % self.location) - logger.debug("container=%s" % self.container_image) - - # If specified, capture service account and GCE VM network configuration - self.service_account_email = service_account_email - self.network = network - self.subnetwork = subnetwork - - # Log service account and VM network configuration - logger.debug("service_account_email=%s" % self.service_account_email) - logger.debug("network=%s" % self.network) - logger.debug("subnetwork=%s" % self.subnetwork) - - # Keep track of build packages to clean up shutdown, and generate - self._build_packages = set() - targz = self._generate_build_source_package() - self._upload_build_source_package(targz) - - # we need to add custom - # default resources depending on the instance requested - self.default_resources = None - - def get_default_resources_args(self, default_resources=None): - assert default_resources is None - return super().get_default_resources_args( - default_resources=self.default_resources - ) - - def _get_services(self): - """ - Use the Google Discovery Build to generate API clients - for Life Sciences, and use the google storage python client - for storage. - """ - from googleapiclient.discovery import build as discovery_build - from google.cloud import storage - import google.auth - import google_auth_httplib2 - import httplib2 - import googleapiclient - - # Credentials may be exported to the environment or from a service account on a GCE VM instance. 
- try: - # oauth2client is deprecated, see: https://google-auth.readthedocs.io/en/master/oauth2client-deprecation.html - # google.auth is replacement - # not sure about scopes here. this cover all cloud services - creds, _ = google.auth.default( - scopes=["https://www.googleapis.com/auth/cloud-platform"] - ) - except google.auth.DefaultCredentialsError as ex: - log_verbose_traceback(ex) - raise ex - - def build_request(http, *args, **kwargs): - """ - See https://googleapis.github.io/google-api-python-client/docs/thread_safety.html - """ - new_http = google_auth_httplib2.AuthorizedHttp(creds, http=httplib2.Http()) - return googleapiclient.http.HttpRequest(new_http, *args, **kwargs) - - # Discovery clients for Google Cloud Storage and Life Sciences API - # create authorized http for building services - authorized_http = google_auth_httplib2.AuthorizedHttp( - creds, http=httplib2.Http() - ) - self._storage_cli = discovery_build( - "storage", - "v1", - cache_discovery=False, - requestBuilder=build_request, - http=authorized_http, - ) - self._compute_cli = discovery_build( - "compute", - "v1", - cache_discovery=False, - requestBuilder=build_request, - http=authorized_http, - ) - self._api = discovery_build( - "lifesciences", - "v2beta", - cache_discovery=False, - requestBuilder=build_request, - http=authorized_http, - ) - self._bucket_service = storage.Client() - - def _get_bucket(self): - """ - Get a connection to the storage bucket (self.bucket) and exit - if the name is taken or otherwise invalid. - - Parameters - ========== - workflow: the workflow object to derive the prefix from - """ - import google - - # Hold path to requested subdirectory and main bucket - bucket_name = self.workflow.default_remote_prefix.split("/")[0] - self.gs_subdir = re.sub( - f"^{bucket_name}/", "", self.workflow.default_remote_prefix - ) - self.gs_logs = os.path.join(self.gs_subdir, "google-lifesciences-logs") - - # Case 1: The bucket already exists - try: - self.bucket = self._bucket_service.get_bucket(bucket_name) - - # Case 2: The bucket needs to be created - except google.cloud.exceptions.NotFound: - self.bucket = self._bucket_service.create_bucket(bucket_name) - - # Case 2: The bucket name is already taken - except Exception as ex: - logger.error( - "Cannot get or create {} (exit code {}):\n{}".format( - bucket_name, ex.returncode, ex.output.decode() - ) - ) - log_verbose_traceback(ex) - raise ex - - logger.debug("bucket=%s" % self.bucket.name) - logger.debug("subdir=%s" % self.gs_subdir) - logger.debug("logs=%s" % self.gs_logs) - - def _set_location(self, location=None): - """ - The location is where the Google Life Sciences API is located. - This can be meaningful if the requester has data residency - requirements or multi-zone needs. To determine this value, - we first use the locations API to determine locations available, - and then compare them against: - - 1. user specified location or prefix - 2. regions having the same prefix - 3. if cannot be satisifed, we throw an error. 
- """ - # Derive available locations - # See https://cloud.google.com/life-sciences/docs/concepts/locations - locations = ( - self._api.projects() - .locations() - .list(name=f"projects/{self.project}") - .execute() - ) - - locations = {x["locationId"]: x["name"] for x in locations.get("locations", [])} - - # Alert the user about locations available - logger.debug("locations-available:\n%s" % "\n".join(locations)) - - # If no locations, there is something wrong - if not locations: - raise WorkflowError("No locations found for Google Life Sciences API.") - - # First pass, attempt to match the user-specified location (or prefix) - if location: - if location in locations: - self.location = locations[location] - return - - # It could be that a prefix was provided - for contender in locations: - if contender.startswith(location): - self.location = locations[contender] - return - - # If we get here and no match, alert user. - raise WorkflowError( - "Location or prefix requested %s is not available." % location - ) - - # If we get here, we need to select location from regions - for region in self.regions: - if region in locations: - self.location = locations[region] - return - - # If we get here, choose based on prefix - prefixes = set([r.split("-")[0] for r in self.regions]) - regexp = "^(%s)" % "|".join(prefixes) - for location in locations: - if re.search(regexp, location): - self.location = locations[location] - return - - # If we get here, total failure of finding location - raise WorkflowError( - " No locations available for regions!" - " Please specify a location with --google-lifesciences-location " - " or extend --google-lifesciences-regions to find a Life Sciences location." - ) - - def shutdown(self): - """ - Shutdown deletes build packages if the user didn't request to clean - up the cache. At this point we've already cancelled running jobs. - """ - from google.api_core import retry - from snakemake.remote.GS import google_cloud_retry_predicate - - @retry.Retry(predicate=google_cloud_retry_predicate) - def _shutdown(): - # Delete build source packages only if user regooglquested no cache - if self._save_storage_cache: - logger.debug("Requested to save workflow sources, skipping cleanup.") - else: - for package in self._build_packages: - blob = self.bucket.blob(package) - if blob.exists(): - logger.debug("Deleting blob %s" % package) - blob.delete() - - # perform additional steps on shutdown if necessary - - _shutdown() - - super().shutdown() - - def cancel(self): - """cancel execution, usually by way of control+c. Cleanup is done in - shutdown (deleting cached workdirs in Google Cloud Storage - """ - import googleapiclient - - # projects.locations.operations/cancel - operations = self._api.projects().locations().operations() - - for job in self.active_jobs: - request = operations.cancel(name=job.jobname) - logger.debug(f"Cancelling operation {job.jobid}") - try: - self._retry_request(request) - except (Exception, BaseException, googleapiclient.errors.HttpError): - continue - - self.shutdown() - - def get_available_machine_types(self): - """ - Using the regions available at self.regions, use the GCP API - to retrieve a lookup dictionary of all available machine types. 
- """ - # Regular expression to determine if zone in region - regexp = "^(%s)" % "|".join(self.regions) - - # Retrieve zones, filter down to selected regions - zones = self._retry_request( - self._compute_cli.zones().list(project=self.project) - ) - zones = [z for z in zones["items"] if re.search(regexp, z["name"])] - - # Retrieve machine types available across zones - # https://cloud.google.com/compute/docs/regions-zones/ - lookup = {} - for zone in zones: - request = self._compute_cli.machineTypes().list( - project=self.project, zone=zone["name"] - ) - lookup[zone["name"]] = self._retry_request(request)["items"] - - # Only keep those that are shared, use last zone as a base - machine_types = {mt["name"]: mt for mt in lookup[zone["name"]]} - del lookup[zone["name"]] - - # Update final list based on the remaining - to_remove = set() - for zone, types in lookup.items(): - names = [x["name"] for x in types] - names = [name for name in names if "micro" not in name] - names = [name for name in names if not re.search("^(e2|m1)", name)] - for machine_type in list(machine_types.keys()): - if machine_type not in names: - to_remove.add(machine_type) - - for machine_type in to_remove: - del machine_types[machine_type] - return machine_types - - def _add_gpu(self, gpu_count): - """ - Add a number of NVIDIA gpus to the current executor. This works - by way of adding nvidia_gpu to the job default resources, and also - changing the default machine type prefix to be n1, which is - the currently only supported instance type for using GPUs for LHS. - """ - if not gpu_count or gpu_count == 0: - return - - logger.debug( - "found resource request for {} GPUs. This will limit to n1 " - "instance types.".format(gpu_count) - ) - self.default_resources.set_resource("nvidia_gpu", gpu_count) - - self._machine_type_prefix = self._machine_type_prefix or "" - if not self._machine_type_prefix.startswith("n1"): - self._machine_type_prefix = "n1" - - # TODO: move this to workflow itself, as it can be used by other executors as well. - # Just provide a mechanism to set the restart times at the command line (and a default) - # Namely: --set-restart-times = --default-restart-times - # --preemptible-rules --default-preemptible-retries - def _set_preemptible_rules(self, preemption_default=None, preemptible_rules=None): - """ - Define a lookup dictionary for preemptible instance retries, which - is supported by the Google Life Science API. The user can set a default - for all steps, specify per step, or define a default for all steps - that aren't individually customized. 
- """ - self.preemptible_rules = {} - - # If a default is defined, we apply it to all the rules - if preemption_default is not None: - self.preemptible_rules = { - rule.name: preemption_default for rule in self.workflow.rules - } - - # Now update custom set rules - if preemptible_rules is not None: - for rule in preemptible_rules: - rule_name, restart_times = rule.strip().split("=") - self.preemptible_rules[rule_name] = int(restart_times) - - # Ensure we set the number of restart times for each rule - for rule_name, restart_times in self.preemptible_rules.items(): - rule = self.workflow.get_rule(rule_name) - rule.restart_times = restart_times - - def _generate_job_resources(self, job: ExecutorJobInterface): - """ - Given a particular job, generate the resources that it needs, - including default regions and the virtual machine configuration - """ - # Right now, do a best effort mapping of resources to instance types - cores = job.resources.get("_cores", 1) - mem_mb = job.resources.get("mem_mb", 15360) - - # IOPS performance proportional to disk size - disk_mb = job.resources.get("disk_mb", 512000) - - # Convert mb to gb - disk_gb = math.ceil(disk_mb / 1024) - - # Look for if the user wants an nvidia gpu - gpu_count = job.resources.get("nvidia_gpu") or job.resources.get("gpu") - gpu_model = job.resources.get("gpu_model") - - # If a gpu model is specified without a count, we assume 1 - if gpu_model and not gpu_count: - gpu_count = 1 - - # Update default resources using decided memory and disk - # TODO why is this needed?? - self.default_resources = DefaultResources( - from_other=self.workflow.default_resources - ) - self.default_resources.set_resource("mem_mb", mem_mb) - self.default_resources.set_resource("disk_mb", disk_mb) - - # Job resource specification can be overridden by gpu preferences - self.machine_type_prefix = job.resources.get("machine_type") - - # If gpu wanted, limit to N1 general family, and update arguments - if gpu_count: - self._add_gpu(gpu_count) - - machine_types = self.get_available_machine_types() - - # Alert the user of machine_types available before filtering - # https://cloud.google.com/compute/docs/machine-types - logger.debug( - "found {} machine types across regions {} before filtering " - "to increase selection, define fewer regions".format( - len(machine_types), self.regions - ) - ) - - # First pass - eliminate anything that too low in cpu/memory - keepers = dict() - - # Also keep track of max cpus and memory, in case none available - max_cpu = 1 - max_mem = 15360 - - for name, machine_type in machine_types.items(): - max_cpu = max(max_cpu, machine_type["guestCpus"]) - max_mem = max(max_mem, machine_type["memoryMb"]) - if machine_type["guestCpus"] < cores or machine_type["memoryMb"] < mem_mb: - continue - keepers[name] = machine_type - - # If a prefix is set, filter down to it - if self.machine_type_prefix: - machine_types = keepers - keepers = dict() - for name, machine_type in machine_types.items(): - if name.startswith(self.machine_type_prefix): - keepers[name] = machine_type - - # If we don't have any contenders, workflow error - if not keepers: - if self.machine_type_prefix: - raise WorkflowError( - "Machine prefix {prefix} is too strict, or the resources cannot " - " be satisfied, so there are no options " - "available.".format(prefix=self.machine_type_prefix) - ) - else: - raise WorkflowError( - "You requested {requestMemory} MB memory, {requestCpu} cores. " - "The maximum available are {availableMemory} MB memory and " - "{availableCpu} cores. 
These resources cannot be satisfied. " - "Please consider reducing the resource requirements of the " - "corresponding rule.".format( - requestMemory=mem_mb, - requestCpu=cores, - availableCpu=max_cpu, - availableMemory=max_mem, - ) - ) - - # Now find (quasi) minimal to satisfy constraints - machine_types = keepers - - # Select the first as the "smallest" - smallest = list(machine_types.keys())[0] - min_cores = machine_types[smallest]["guestCpus"] - min_mem = machine_types[smallest]["memoryMb"] - - for name, machine_type in machine_types.items(): - if ( - machine_type["guestCpus"] < min_cores - and machine_type["memoryMb"] < min_mem - ): - smallest = name - min_cores = machine_type["guestCpus"] - min_mem = machine_type["memoryMb"] - - selected = machine_types[smallest] - logger.debug( - "Selected machine type {}:{}".format(smallest, selected["description"]) - ) - - if job.is_group(): - preemptible = all(rule in self.preemptible_rules for rule in job.rules) - if not preemptible and any( - rule in self.preemptible_rules for rule in job.rules - ): - raise WorkflowError( - "All grouped rules should be homogenously set as preemptible rules" - "(see Defining groups for execution in snakemake documentation)" - ) - else: - preemptible = job.rule.name in self.preemptible_rules - - # We add the size for the image itself (10 GB) to bootDiskSizeGb - virtual_machine = { - "machineType": smallest, - "labels": {"app": "snakemake"}, - "bootDiskSizeGb": disk_gb + 10, - "preemptible": preemptible, - } - - # Add custom GCE VM configuration - if self.network and self.subnetwork: - virtual_machine["network"] = { - "network": self.network, - "usePrivateAddress": False, - "subnetwork": self.subnetwork, - } - - if self.service_account_email: - virtual_machine["service_account"] = { - "email": self.service_account_email, - "scopes": ["https://www.googleapis.com/auth/cloud-platform"], - } - - # If the user wants gpus, add accelerators here - if gpu_count: - accelerator = self._get_accelerator( - gpu_count, zone=selected["zone"], gpu_model=gpu_model - ) - virtual_machine["accelerators"] = [ - {"type": accelerator["name"], "count": gpu_count} - ] - - resources = {"regions": self.regions, "virtualMachine": virtual_machine} - return resources - - def _get_accelerator(self, gpu_count, zone, gpu_model=None): - """ - Get an appropriate accelerator for a GPU given a zone selection. - Currently Google offers NVIDIA Tesla T4 (likely the best), - NVIDIA P100, and the same T4 for a graphical workstation. 
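The machine-type selection above first drops every type below the requested cores and memory and then keeps a (quasi) minimal survivor. Condensed into a pure function over plain dicts, using a simple (cpus, memory) ordering instead of the pairwise comparison above; the sample data is made up:

def pick_machine_type(machine_types: dict, cores: int, mem_mb: int) -> str:
    # Keep only the types that satisfy the request, then prefer the smallest of those.
    keepers = {
        name: mt
        for name, mt in machine_types.items()
        if mt["guestCpus"] >= cores and mt["memoryMb"] >= mem_mb
    }
    if not keepers:
        raise ValueError(f"No machine type satisfies {cores} cores / {mem_mb} MB.")
    return min(keepers, key=lambda name: (keepers[name]["guestCpus"], keepers[name]["memoryMb"]))


sample = {
    "n1-standard-2": {"guestCpus": 2, "memoryMb": 7680},
    "n1-standard-4": {"guestCpus": 4, "memoryMb": 15360},
}
print(pick_machine_type(sample, 2, 4000))  # -> n1-standard-2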
Since - this isn't a graphical workstation use case, we choose the - accelerator that has >= to the maximumCardsPerInstace - """ - if not gpu_count or gpu_count == 0: - return - - accelerators = self._retry_request( - self._compute_cli.acceleratorTypes().list(project=self.project, zone=zone) - ) - - # Filter down to those with greater than or equal to needed gpus - keepers = {} - for accelerator in accelerators.get("items", []): - # Eliminate virtual workstations (vws) and models that don't match user preference - if (gpu_model and accelerator["name"] != gpu_model) or accelerator[ - "name" - ].endswith("vws"): - continue - - if accelerator["maximumCardsPerInstance"] >= gpu_count: - keepers[accelerator["name"]] = accelerator - - # If no matches available, exit early - if not keepers: - if gpu_model: - raise WorkflowError( - "An accelerator in zone {zone} with model {model} cannot " - " be satisfied, so there are no options " - "available.".format(zone=zone, model=gpu_model) - ) - else: - raise WorkflowError( - "An accelerator in zone {zone} cannot be satisifed, so " - "there are no options available.".format(zone=zone) - ) - - # Find smallest (in future the user might have preference for the type) - smallest = list(keepers.keys())[0] - max_gpu = keepers[smallest]["maximumCardsPerInstance"] - - # This should usually return P-100, which would be preference (cheapest) - for name, accelerator in keepers.items(): - if accelerator["maximumCardsPerInstance"] < max_gpu: - smallest = name - max_gpu = accelerator["maximumCardsPerInstance"] - - return keepers[smallest] - - def get_snakefile(self): - assert os.path.exists(self.workflow.main_snakefile) - return self.workflow.main_snakefile.removeprefix(self.workdir).strip(os.sep) - - def _set_workflow_sources(self): - """ - We only add files from the working directory that are config related - (e.g., the Snakefile or a config.yml equivalent), or checked into git. - """ - self.workflow_sources = [] - - for wfs in self.dag.get_sources(): - if os.path.isdir(wfs): - for dirpath, dirnames, filenames in os.walk(wfs): - self.workflow_sources.extend( - [check_source_size(os.path.join(dirpath, f)) for f in filenames] - ) - else: - self.workflow_sources.append(check_source_size(os.path.abspath(wfs))) - - def _generate_build_source_package(self): - """ - In order for the instance to access the working directory in storage, - we need to upload it. This file is cleaned up at the end of the run. - We do this, and then obtain from the instance and extract. - """ - # Workflow sources for cloud executor must all be under same workdir root - for filename in self.workflow_sources: - if self.workdir not in os.path.realpath(filename): - raise WorkflowError( - "All source files must be present in the working directory, " - "{workdir} to be uploaded to a build package that respects " - "relative paths, but {filename} was found outside of this " - "directory. 
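_set_workflow_sources above simply walks every source directory reported by the DAG and flattens the result into a file list. The same traversal as a standalone helper; the size warning sketched earlier is omitted here:

import os


def collect_workflow_sources(sources: list) -> list:
    # Expand directories recursively; keep single files as absolute paths.
    collected = []
    for src in sources:
        if os.path.isdir(src):
            for dirpath, _dirnames, filenames in os.walk(src):
                collected.extend(os.path.join(dirpath, f) for f in filenames)
        else:
            collected.append(os.path.abspath(src))
    return collected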
Please set your working directory accordingly, " - "and the path of your Snakefile to be relative to it.".format( - workdir=self.workdir, filename=filename - ) - ) - - # We will generate a tar.gz package, renamed by hash - tmpname = next(tempfile._get_candidate_names()) - targz = os.path.join(tempfile.gettempdir(), "snakemake-%s.tar.gz" % tmpname) - tar = tarfile.open(targz, "w:gz") - - # Add all workflow_sources files - for filename in self.workflow_sources: - arcname = filename.replace(self.workdir + os.path.sep, "") - tar.add(filename, arcname=arcname) - - tar.close() - - # Rename based on hash, in case user wants to save cache - sha256 = get_file_hash(targz) - hash_tar = os.path.join( - self.workflow.persistence.aux_path, f"workdir-{sha256}.tar.gz" - ) - - # Only copy if we don't have it yet, clean up if we do - if not os.path.exists(hash_tar): - shutil.move(targz, hash_tar) - else: - os.remove(targz) - - # We will clean these all up at shutdown - self._build_packages.add(hash_tar) - - return hash_tar - - def _upload_build_source_package(self, targz): - """ - Given a .tar.gz created for a workflow, upload it to source/cache - of Google storage, only if the blob doesn't already exist. - """ - from google.api_core import retry - from snakemake.remote.GS import google_cloud_retry_predicate - - @retry.Retry(predicate=google_cloud_retry_predicate) - def _upload(): - # Upload to temporary storage, only if doesn't exist - self.pipeline_package = "source/cache/%s" % os.path.basename(targz) - blob = self.bucket.blob(self.pipeline_package) - logger.debug("build-package=%s" % self.pipeline_package) - if not blob.exists(): - blob.upload_from_filename(targz, content_type="application/gzip") - - _upload() - - def _generate_log_action(self, job: ExecutorJobInterface): - """generate an action to save the pipeline logs to storage.""" - # script should be changed to this when added to version control! - # https://raw.githubusercontent.com/snakemake/snakemake/main/snakemake/executors/google_lifesciences_helper.py - # Save logs from /google/logs/output to source/logs in bucket - commands = [ - "/bin/bash", - "-c", - f"wget -O /gls.py https://raw.githubusercontent.com/snakemake/snakemake/main/snakemake/executors/google_lifesciences_helper.py && chmod +x /gls.py && source activate snakemake || true && python /gls.py save {self.bucket.name} /google/logs {self.gs_logs}/{job.name}/jobid_{job.jobid}", - ] - - # Always run the action to generate log output - action = { - "containerName": f"snakelog-{job.name}-{job.jobid}", - "imageUri": self.container_image, - "commands": commands, - "labels": self._generate_pipeline_labels(job), - "alwaysRun": True, - } - - return action - - def _generate_job_action(self, job: ExecutorJobInterface): - """ - Generate a single action to execute the job. 
- """ - exec_job = self.format_job_exec(job) - - # The full command to download the archive, extract, and run - # For snakemake bases, we must activate the conda environment, but - # for custom images we must allow this to fail (hence || true) - commands = [ - "/bin/bash", - "-c", - "mkdir -p /workdir && " - "cd /workdir && " - "wget -O /download.py " - "https://raw.githubusercontent.com/snakemake/snakemake/main/snakemake/executors/google_lifesciences_helper.py && " - "chmod +x /download.py && " - "source activate snakemake || true && " - f"python /download.py download {self.bucket.name} {self.pipeline_package} " - "/tmp/workdir.tar.gz && " - f"tar -xzvf /tmp/workdir.tar.gz && {exec_job}", - ] - - # We are only generating one action, one job per run - # https://cloud.google.com/life-sciences/docs/reference/rest/v2beta/projects.locations.pipelines/run#Action - action = { - "containerName": f"snakejob-{job.name}-{job.jobid}", - "imageUri": self.container_image, - "commands": commands, - "environment": self._generate_environment(), - "labels": self._generate_pipeline_labels(job), - } - return action - - def _get_jobname(self, job: ExecutorJobInterface): - # Use a dummy job name (human readable and also namespaced) - return f"snakejob-{self.run_namespace}-{job.name}-{job.jobid}" - - def _generate_pipeline_labels(self, job: ExecutorJobInterface): - """ - Generate basic labels to identify the job, namespace, and that - snakemake is running the show! - """ - jobname = self._get_jobname(job) - labels = {"name": jobname, "app": "snakemake"} - return labels - - def _generate_environment(self): - """loop through envvars (keys to host environment) and add - any that are requested for the container environment. - """ - envvars = {} - for key in self.envvars: - try: - envvars[key] = os.environ[key] - except KeyError: - continue - - # Warn the user that we cannot support secrets - if envvars: - logger.warning("This API does not support environment secrets.") - return envvars - - def _generate_pipeline(self, job: ExecutorJobInterface): - """ - Based on the job details, generate a google Pipeline object - to pass to pipelines.run. This includes actions, resources, - environment, and timeout. - """ - # Generate actions (one per job step) and log saving action (runs no matter what) and resources - resources = self._generate_job_resources(job) - action = self._generate_job_action(job) - log_action = self._generate_log_action(job) - - pipeline = { - # Ordered list of actions to execute - "actions": [action, log_action], - # resources required for execution - "resources": resources, - # Technical question - difference between resource and action environment - # For now we will set them to be the same. 
- "environment": self._generate_environment(), - } - - # "timeout": string in seconds (3.5s) is not included (defaults to 7 days) - return pipeline - - def run( - self, - job: ExecutorJobInterface, - callback=None, - submit_callback=None, - error_callback=None, - ): - super()._run(job) - - # https://cloud.google.com/life-sciences/docs/reference/rest/v2beta/projects.locations.pipelines - pipelines = self._api.projects().locations().pipelines() - - # pipelines.run - # https://cloud.google.com/life-sciences/docs/reference/rest/v2beta/projects.locations.pipelines/run - - labels = self._generate_pipeline_labels(job) - pipeline = self._generate_pipeline(job) - - # The body of the request is a Pipeline and labels - body = {"pipeline": pipeline, "labels": labels} - - # capabilities - this won't currently work (Singularity in Docker) - # We either need to add CAPS or run in privileged mode (ehh) - if job.needs_singularity and self.workflow.use_singularity: - raise WorkflowError( - "Singularity requires additional capabilities that " - "aren't yet supported for standard Docker runs, and " - "is not supported for the Google Life Sciences executor." - ) - - # location looks like: "projects//locations/" - operation = pipelines.run(parent=self.location, body=body) - - # 403 will result if no permission to use pipelines or project - result = self._retry_request(operation) - - # The jobid is the last number of the full name - jobid = result["name"].split("/")[-1] - - # Give some logging for how to get status - logger.info( - "Get status with:\n" - "gcloud config set project {project}\n" - "gcloud beta lifesciences operations describe {location}/operations/{jobid}\n" - "gcloud beta lifesciences operations list\n" - "Logs will be saved to: {bucket}/{logdir}\n".format( - project=self.project, - jobid=jobid, - location=self.location, - bucket=self.bucket.name, - logdir=self.gs_logs, - ) - ) - - self.active_jobs.append( - GoogleLifeSciencesJob(job, result["name"], jobid, callback, error_callback) - ) - - def _job_was_successful(self, status): - """ - Based on a status response (a [pipeline].projects.locations.operations.get - debug print the list of events, return True if all return codes 0 - and False otherwise (indication of failure). In that a nonzero exit - status is found, we also debug print it for the user. - """ - success = True - - # https://cloud.google.com/life-sciences/docs/reference/rest/v2beta/Event - for event in status["metadata"]["events"]: - logger.debug(event["description"]) - - # Does it always result in fail for other failure reasons? - if "failed" in event: - success = False - action = event.get("failed") - logger.debug("{}: {}".format(action["code"], action["cause"])) - - elif "unexpectedExitStatus" in event: - action = event.get("unexpectedExitStatus") - - if action["exitStatus"] != 0: - success = False - - # Provide reason for the failure (desc includes exit code) - msg = "%s" % event["description"] - if "stderr" in action: - msg += ": %s" % action["stderr"] - logger.debug(msg) - - return success - - def _retry_request(self, request, timeout=2, attempts=3): - """ - The Google Python API client frequently has BrokenPipe errors. This - function takes a request, and executes it up to number of retry, - each time with a 2* increase in timeout. 
- - Parameters - ========== - request: the Google Cloud request that needs to be executed - timeout: time to sleep (in seconds) before trying again - attempts: remaining attempts, throw error when hit 0 - """ - import googleapiclient - - try: - return request.execute() - except BrokenPipeError as ex: - if attempts > 0: - time.sleep(timeout) - return self._retry_request( - request, timeout=timeout * 2, attempts=attempts - 1 - ) - raise ex - except googleapiclient.errors.HttpError as ex: - if attempts > 0: - time.sleep(timeout) - return self._retry_request( - request, timeout=timeout * 2, attempts=attempts - 1 - ) - log_verbose_traceback(ex) - raise ex - except Exception as ex: - if attempts > 0: - time.sleep(timeout) - return self._retry_request( - request, timeout=timeout * 2, attempts=attempts - 1 - ) - log_verbose_traceback(ex) - raise ex - - async def _wait_for_jobs(self): - """ - Wait for jobs to complete. This means requesting their status, - and then marking them as finished when a "done" parameter - shows up. Even for finished jobs, the status should still return - """ - import googleapiclient - - while True: - # always use self.lock to avoid race conditions - async with async_lock(self.lock): - if not self.wait: - return - active_jobs = self.active_jobs - self.active_jobs = list() - still_running = list() - - # Loop through active jobs and act on status - for j in active_jobs: - # use self.status_rate_limiter to avoid too many API calls. - async with self.status_rate_limiter: - # https://cloud.google.com/life-sciences/docs/reference/rest/v2beta/projects.locations.operations/get - # Get status from projects.locations.operations/get - operations = self._api.projects().locations().operations() - request = operations.get(name=j.jobname) - logger.debug(f"Checking status for operation {j.jobid}") - - try: - status = self._retry_request(request) - except googleapiclient.errors.HttpError as ex: - # Operation name not found, even finished should be found - if ex.status == 404: - j.error_callback(j.job) - continue - - # Unpredictable server (500) error - elif ex.status == 500: - logger.error(ex["content"].decode("utf-8")) - j.error_callback(j.job) - - except WorkflowError as ex: - print_exception(ex, self.workflow.linemaps) - j.error_callback(j.job) - continue - - # The operation is done - if status.get("done", False) == True: - # Derive success/failure from status codes (prints too) - if self._job_was_successful(status): - j.callback(j.job) - else: - self.print_job_error(j.job, jobid=j.jobid) - j.error_callback(j.job) - - # The operation is still running - else: - still_running.append(j) - - async with async_lock(self.lock): - self.active_jobs.extend(still_running) - await sleep() diff --git a/snakemake/executors/google_lifesciences_helper.py b/snakemake/executors/google_lifesciences_helper.py deleted file mode 100755 index 6be63a150..000000000 --- a/snakemake/executors/google_lifesciences_helper.py +++ /dev/null @@ -1,113 +0,0 @@ -#!/usr/bin/env python - -# This is a helper script for the Google Life Sciences instance to be able to: -# 1. download a blob from storage, which is required at the onset of the Snakemake -# gls.py download -# workflow step to obtain the working directory. -# 2. 
Upload logs back to storage (or some specified directory of files) -# gls.py save -# gls.py save /google/logs/output source/logs - -import argparse - -from google.cloud import storage -from glob import glob -import sys -import os - - -def download_blob(bucket_name, source_blob_name, destination_file_name): - """Downloads a blob from the bucket.""" - storage_client = storage.Client() - bucket = storage_client.get_bucket(bucket_name) - blob = bucket.blob(source_blob_name) - - blob.download_to_filename(destination_file_name) - - print(f"Blob {source_blob_name} downloaded to {destination_file_name}.") - - -def save_files(bucket_name, source_path, destination_path): - """given a directory path, save all files recursively to storage""" - storage_client = storage.Client() - bucket = storage_client.get_bucket(bucket_name) - - # destination path should be stripped of path indicators too - bucket_name = bucket_name.strip("/") - destination_path = destination_path.strip("/") - - # These are fullpaths - filenames = get_source_files(source_path) - print("\nThe following files will be uploaded: %s" % "\n".join(filenames)) - - if not filenames: - print("Did not find any filenames under %s" % source_path) - - # Do the upload! - for filename in filenames: - # The relative path of the filename from the source path - relative_path = filename.replace(source_path, "", 1).strip("/") - # The path in storage includes relative path from destination_path - storage_path = os.path.join(destination_path, relative_path) - full_path = os.path.join(bucket_name, storage_path) - print(f"{filename} -> {full_path}") - blob = bucket.blob(storage_path) - print(f"Uploading {filename} to {full_path}") - blob.upload_from_filename(filename, content_type=".txt") - - -def get_source_files(source_path): - """Given a directory, return a listing of files to upload""" - filenames = [] - if not os.path.exists(source_path): - print("%s does not exist!" 
% source_path) - sys.exit(0) - - for x in os.walk(source_path): - for name in glob(os.path.join(x[0], "*")): - if not os.path.isdir(name): - filenames.append(name) - return filenames - - -def add_ending_slash(filename): - """Since we want to replace based on having an ending slash, ensure it's there""" - if not filename.endswith("/"): - filename = "%s/" % filename - return filename - - -def blob_commands(args): - if args.command == "download": - download_blob( - args.bucket_name, args.source_blob_name, args.destination_file_name - ) - elif args.command == "save": - save_files(args.bucket_name, args.source_path, args.destination_path) - - -def main(): - parser = argparse.ArgumentParser( - formatter_class=argparse.RawDescriptionHelpFormatter - ) - - subparsers = parser.add_subparsers(dest="command") - - # Download file from storage - download_parser = subparsers.add_parser("download", help=download_blob.__doc__) - download_parser.add_argument("bucket_name", help="Your cloud storage bucket.") - download_parser.add_argument("source_blob_name") - download_parser.add_argument("destination_file_name") - - # Save logs to storage - save_parser = subparsers.add_parser("save", help=save_files.__doc__) - save_parser.add_argument("bucket_name", help="Your cloud storage bucket.") - save_parser.add_argument("source_path") - save_parser.add_argument("destination_path") - - args = parser.parse_args() - blob_commands(args) - - -if __name__ == "__main__": - main() diff --git a/snakemake/executors/local.py b/snakemake/executors/local.py new file mode 100644 index 000000000..a47c14062 --- /dev/null +++ b/snakemake/executors/local.py @@ -0,0 +1,475 @@ +__author__ = "Johannes Köster" +__copyright__ = "Copyright 2023, Johannes Köster" +__email__ = "johannes.koester@uni-due.de" +__license__ = "MIT" + + +import os +from pathlib import Path +import sys +import time +import shlex +import concurrent.futures +import subprocess +from functools import partial +from snakemake.executors import change_working_directory +from snakemake.settings import DeploymentMethod + +from snakemake_interface_executor_plugins.executors.base import SubmittedJobInfo +from snakemake_interface_executor_plugins.executors.real import RealExecutor +from snakemake_interface_executor_plugins.dag import DAGExecutorInterface +from snakemake_interface_executor_plugins.workflow import WorkflowExecutorInterface +from snakemake_interface_executor_plugins.logging import LoggerExecutorInterface +from snakemake_interface_executor_plugins.jobs import ( + ExecutorJobInterface, + SingleJobExecutorInterface, + GroupJobExecutorInterface, +) +from snakemake_interface_executor_plugins.settings import ExecMode +from snakemake_interface_executor_plugins import CommonSettings + +from snakemake.shell import shell +from snakemake.logging import logger +from snakemake.exceptions import print_exception, get_exception_origin +from snakemake.exceptions import format_error, RuleException, log_verbose_traceback +from snakemake.exceptions import ( + WorkflowError, + SpawnedJobError, + CacheMissException, +) + + +common_settings = CommonSettings( + non_local_exec=False, + implies_no_shared_fs=False, +) + + +_ProcessPoolExceptions = (KeyboardInterrupt,) +try: + from concurrent.futures.process import BrokenProcessPool + + _ProcessPoolExceptions = (KeyboardInterrupt, BrokenProcessPool) +except ImportError: + pass + + +class Executor(RealExecutor): + def __init__( + self, + workflow: WorkflowExecutorInterface, + logger: LoggerExecutorInterface, + ): + super().__init__( + 
workflow, + logger, + pass_envvar_declarations_to_cmd=False, + ) + + self.use_threads = self.workflow.execution_settings.use_threads + self.keepincomplete = self.workflow.execution_settings.keep_incomplete + cores = self.workflow.resource_settings.cores + + # Zero thread jobs do not need a thread, but they occupy additional workers. + # Hence we need to reserve additional workers for them. + workers = cores + 5 if cores is not None else 5 + self.workers = workers + self.pool = concurrent.futures.ThreadPoolExecutor(max_workers=self.workers) + + def get_exec_mode(self): + return ExecMode.SUBPROCESS + + @property + def job_specific_local_groupid(self): + return False + + def get_job_exec_prefix(self, job: ExecutorJobInterface): + return f"cd {shlex.quote(self.workflow.workdir_init)}" + + def get_python_executable(self): + return sys.executable + + def get_envvar_declarations(self): + return "" + + def get_job_args(self, job: ExecutorJobInterface, **kwargs): + return f"{super().get_job_args(job, **kwargs)} --quiet" + + def run_job( + self, + job: ExecutorJobInterface, + ): + if job.is_group(): + # if we still don't have enough workers for this group, create a new pool here + missing_workers = max(len(job) - self.workers, 0) + if missing_workers: + self.workers += missing_workers + self.pool = concurrent.futures.ThreadPoolExecutor( + max_workers=self.workers + ) + + # the future waits for the entire group job + future = self.pool.submit(self.run_group_job, job) + else: + future = self.run_single_job(job) + + job_info = SubmittedJobInfo(job=job) + + future.add_done_callback(partial(self._callback, job_info)) + self.report_job_submission(job_info) + + def job_args_and_prepare(self, job: ExecutorJobInterface): + job.prepare() + + conda_env = ( + job.conda_env.address + if DeploymentMethod.CONDA + in self.workflow.deployment_settings.deployment_method + and job.conda_env + else None + ) + container_img = ( + job.container_img_path + if DeploymentMethod.APPTAINER + in self.workflow.deployment_settings.deployment_method + else None + ) + env_modules = ( + job.env_modules + if DeploymentMethod.ENV_MODULES + in self.workflow.deployment_settings.deployment_method + else None + ) + + benchmark = None + benchmark_repeats = job.benchmark_repeats or 1 + if job.benchmark is not None: + benchmark = str(job.benchmark) + return ( + job.rule, + job.input._plainstrings(), + job.output._plainstrings(), + job.params, + job.wildcards, + job.threads, + job.resources, + job.log._plainstrings(), + benchmark, + benchmark_repeats, + conda_env, + container_img, + self.workflow.deployment_settings.apptainer_args, + env_modules, + DeploymentMethod.APPTAINER + in self.workflow.deployment_settings.deployment_method, + self.workflow.linemaps, + self.workflow.execution_settings.debug, + self.workflow.execution_settings.cleanup_scripts, + job.shadow_dir, + job.jobid, + self.workflow.execution_settings.edit_notebook + if self.dag.is_edit_notebook_job(job) + else None, + self.workflow.conda_base_path, + job.rule.basedir, + self.workflow.sourcecache.runtime_cache_path, + ) + + def run_single_job(self, job: SingleJobExecutorInterface): + if ( + self.use_threads + or (not job.is_shadow and not job.is_run) + or job.is_template_engine + ): + future = self.pool.submit( + self.cached_or_run, job, run_wrapper, *self.job_args_and_prepare(job) + ) + else: + # run directive jobs are spawned into subprocesses + future = self.pool.submit(self.cached_or_run, job, self.spawn_job, job) + return future + + def run_group_job(self, job: 
GroupJobExecutorInterface): + """Run a pipe or service group job. + + This lets all items run simultaneously.""" + # we only have to consider pipe or service groups because in local running mode, + # these are the only groups that will occur + + service_futures = [self.run_single_job(j) for j in job if j.is_service] + normal_futures = [self.run_single_job(j) for j in job if not j.is_service] + + while normal_futures: + for f in list(normal_futures): + if f.done(): + logger.debug("Job inside group is finished.") + ex = f.exception() + if ex is not None: + logger.debug(f"Job inside group failed with exception {ex}.") + # kill all shell commands of the other group jobs + # there can be only shell commands because the + # run directive is not allowed for pipe jobs + for j in job: + shell.kill(j.jobid) + raise ex + normal_futures.remove(f) + time.sleep(1) + + if service_futures: + # terminate all service jobs since all consumers are done + for j in job: + if j.is_service: + logger.info( + f"Terminating service job {j.jobid} since all consuming jobs are finished." + ) + shell.terminate(j.jobid) + logger.info( + f"Service job {j.jobid} has been successfully terminated." + ) + + def spawn_job(self, job: SingleJobExecutorInterface): + cmd = self.format_job_exec(job) + + try: + subprocess.check_call(cmd, shell=True) + except subprocess.CalledProcessError: + raise SpawnedJobError() + + def cached_or_run(self, job: SingleJobExecutorInterface, run_func, *args): + """ + Either retrieve result from cache, or run job with given function. + """ + cache_mode = self.workflow.get_cache_mode(job.rule) + try: + if cache_mode: + self.workflow.output_file_cache.fetch(job, cache_mode) + return + except CacheMissException: + pass + run_func(*args) + if cache_mode: + self.workflow.output_file_cache.store(job, cache_mode) + + def shutdown(self): + self.pool.shutdown() + + def cancel(self): + self.pool.shutdown() + + def _callback(self, job_info: SubmittedJobInfo, future): + try: + ex = future.exception() + if ex is not None: + raise ex + self.report_job_success(job_info) + except _ProcessPoolExceptions: + self.handle_job_error(job_info.job) + # no error callback, just silently ignore the interrupt as the main scheduler is also killed + except SpawnedJobError: + # don't print error message, this is done by the spawned subprocess + self.report_job_error(job_info) + except BaseException as ex: + if self.workflow.output_settings.verbose or ( + not job_info.job.is_group() and not job_info.job.is_shell + ): + print_exception(ex, self.workflow.linemaps) + self.report_job_error(job_info) + + def handle_job_error(self, job: ExecutorJobInterface): + super().handle_job_error(job) + if not self.keepincomplete: + job.cleanup() + self.workflow.persistence.cleanup(job) + + @property + def cores(self): + return self.workflow.resource_settings.cores + + +def run_wrapper( + job_rule, + input, + output, + params, + wildcards, + threads, + resources, + log, + benchmark, + benchmark_repeats, + conda_env, + container_img, + singularity_args, + env_modules, + use_singularity, + linemaps, + debug, + cleanup_scripts, + shadow_dir, + jobid, + edit_notebook, + conda_base_path, + basedir, + runtime_sourcecache_path, +): + """ + Wrapper around the run method that handles exceptions and benchmarking. 
+ + Arguments + job_rule -- the ``job.rule`` member + input -- a list of input files + output -- a list of output files + wildcards -- so far processed wildcards + threads -- usable threads + log -- a list of log files + shadow_dir -- optional shadow directory root + """ + # get shortcuts to job_rule members + run = job_rule.run_func + rule = job_rule.name + is_shell = job_rule.shellcmd is not None + + if os.name == "posix" and debug: + sys.stdin = open("/dev/stdin") + + if benchmark is not None: + from snakemake.benchmark import ( + BenchmarkRecord, + benchmarked, + write_benchmark_records, + ) + + # Change workdir if shadow defined and not using singularity. + # Otherwise, we do the change from inside the container. + passed_shadow_dir = None + if use_singularity and container_img: + passed_shadow_dir = shadow_dir + shadow_dir = None + + try: + with change_working_directory(shadow_dir): + if benchmark: + bench_records = [] + for bench_iteration in range(benchmark_repeats): + # Determine whether to benchmark this process directly or to defer + # benchmarking to the spawned child process. We benchmark this + # process unless the execution is done through the ``shell:``, + # ``script:``, or ``wrapper:`` stanza. + is_sub = ( + job_rule.shellcmd + or job_rule.script + or job_rule.wrapper + or job_rule.cwl + ) + if is_sub: + # The benchmarking through ``benchmarked()`` is started + # in the execution of the shell fragment, script, wrapper + # etc., as the child PID is available there. + bench_record = BenchmarkRecord() + run( + input, + output, + params, + wildcards, + threads, + resources, + log, + rule, + conda_env, + container_img, + singularity_args, + use_singularity, + env_modules, + bench_record, + jobid, + is_shell, + bench_iteration, + cleanup_scripts, + passed_shadow_dir, + edit_notebook, + conda_base_path, + basedir, + runtime_sourcecache_path, + ) + else: + # The benchmarking is started here as we have a run section + # and the generated Python function is executed in this + # process' thread.
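# For orientation only: the in-process branch that follows relies on a context
# manager that starts timing on entry and stores the results on exit. A
# hypothetical, much simplified stand-in (snakemake's real benchmarked() and
# BenchmarkRecord capture more than wall-clock time) could look like:
#
#     import time
#     from contextlib import contextmanager
#
#     @contextmanager
#     def timed_record(record):
#         start = time.perf_counter()
#         try:
#             yield record
#         finally:
#             record["running_time"] = time.perf_counter() - start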
+ with benchmarked() as bench_record: + run( + input, + output, + params, + wildcards, + threads, + resources, + log, + rule, + conda_env, + container_img, + singularity_args, + use_singularity, + env_modules, + bench_record, + jobid, + is_shell, + bench_iteration, + cleanup_scripts, + passed_shadow_dir, + edit_notebook, + conda_base_path, + basedir, + runtime_sourcecache_path, + ) + # Store benchmark record for this iteration + bench_records.append(bench_record) + else: + run( + input, + output, + params, + wildcards, + threads, + resources, + log, + rule, + conda_env, + container_img, + singularity_args, + use_singularity, + env_modules, + None, + jobid, + is_shell, + None, + cleanup_scripts, + passed_shadow_dir, + edit_notebook, + conda_base_path, + basedir, + runtime_sourcecache_path, + ) + except (KeyboardInterrupt, SystemExit) as e: + # Re-raise the keyboard interrupt in order to record an error in the + # scheduler but ignore it + raise e + except BaseException as ex: + # this ensures that exception can be re-raised in the parent thread + origin = get_exception_origin(ex, linemaps) + if origin is not None: + log_verbose_traceback(ex) + lineno, file = origin + raise RuleException( + format_error( + ex, lineno, linemaps=linemaps, snakefile=file, show_traceback=True + ) + ) + else: + # some internal bug, just reraise + raise ex + + if benchmark is not None: + try: + write_benchmark_records(bench_records, benchmark) + except BaseException as ex: + raise WorkflowError(ex) diff --git a/snakemake/executors/slurm/slurm_jobstep.py b/snakemake/executors/slurm/slurm_jobstep.py deleted file mode 100644 index 1c19404a8..000000000 --- a/snakemake/executors/slurm/slurm_jobstep.py +++ /dev/null @@ -1,118 +0,0 @@ -import os -import subprocess - -from snakemake_interface_executor_plugins.dag import DAGExecutorInterface -from snakemake_interface_executor_plugins.jobs import ExecutorJobInterface -from snakemake_interface_executor_plugins.workflow import WorkflowExecutorInterface -from snakemake_interface_executor_plugins.executors.remote import RemoteExecutor -from snakemake_interface_executor_plugins.persistence import StatsExecutorInterface -from snakemake_interface_executor_plugins.logging import LoggerExecutorInterface - - -class SlurmJobstepExecutor(RemoteExecutor): - """ - executes SLURM jobsteps and is *only* instaniated in - a SLURM job context - """ - - def __init__( - self, - workflow: WorkflowExecutorInterface, - dag: DAGExecutorInterface, - stats: StatsExecutorInterface, - logger: LoggerExecutorInterface, - max_status_checks_per_second=0.5, - ): - # overwrite the command to execute a single snakemake job if necessary - # exec_job = "..." - - super().__init__( - workflow, - dag, - stats, - logger, - None, - max_status_checks_per_second=max_status_checks_per_second, - disable_envvar_declarations=True, - ) - - # These environment variables are set by SLURM. - self.mem_per_node = os.getenv("SLURM_MEM_PER_NODE") - self.cpus_on_node = os.getenv("SLURM_CPUS_ON_NODE") - self.jobid = os.getenv("SLURM_JOB_ID") - - async def _wait_for_jobs(self): - pass - - def run( - self, - job: ExecutorJobInterface, - callback=None, - submit_callback=None, - error_callback=None, - ): - jobsteps = dict() - - # TODO revisit special handling for group job levels via srun at a later stage - # if job.is_group(): - - # def get_call(level_job, aux=""): - # # we need this calculation, because of srun's greediness and - # # SLURM's limits: it is not able to limit the memory if we divide the job - # # per CPU by itself. 
- - # level_mem = level_job.resources.get("mem_mb") - # if isinstance(level_mem, TBDString): - # level_mem = 100 - - # mem_per_cpu = max(level_mem // level_job.threads, 100) - # exec_job = self.format_job_exec(level_job) - - # # Note: The '--exclusive' flag prevents job steps triggered within an allocation - # # from oversubscribing a given c-group. As we are dealing only with SMP software, - # # the '--ntasks' is explicitly set to 1 by '-n1' per group job(step). - # return ( - # f"srun -J {job.groupid} --jobid {self.jobid}" - # f" --mem-per-cpu {mem_per_cpu} -c {level_job.threads}" - # f" --exclusive -n 1 {aux} {exec_job}" - # ) - - # for level in list(job.toposorted): - # # we need to ensure order - any: - # level_list = list(level) - # for level_job in level_list[:-1]: - # jobsteps[level_job] = subprocess.Popen( - # get_call(level_job), shell=True - # ) - # # now: the last one - # # this way, we ensure that level jobs depending on the current level get started - # jobsteps[level_list[-1]] = subprocess.Popen( - # get_call(level_list[-1], aux="--dependency=singleton"), shell=True - # ) - - if "mpi" in job.resources.keys(): - # MPI job: - # No need to prepend `srun`, as this will happen inside of the job's shell command or script (!). - # The following call invokes snakemake, which in turn takes care of all auxiliary work around the actual command, - # like remote file support, benchmark setup, error handling, etc. - # AND there can be stuff around the srun call within the job, like any commands which should be executed before. - call = self.format_job_exec(job) - else: - # SMP job, execute snakemake with srun, to ensure proper placing of threaded executables within the c-group - # The -n1 is important to prevent srun from executing the given command multiple times, depending on the relation between - # cpus per task and the number of CPU cores. - call = f"srun -n1 --cpu-bind=q {self.format_job_exec(job)}" - - # this dict is meant to support the yet-to-be-implemented feature of oversubscription in "ordinary" group jobs.
- jobsteps[job] = subprocess.Popen(call, shell=True) - - # wait until all steps are finished - error = False - for job, proc in jobsteps.items(): - if proc.wait() != 0: - self.print_job_error(job) - error = True - if error: - error_callback(job) - else: - callback(job) diff --git a/snakemake/executors/slurm/slurm_submit.py b/snakemake/executors/slurm/slurm_submit.py deleted file mode 100644 index 182846344..000000000 --- a/snakemake/executors/slurm/slurm_submit.py +++ /dev/null @@ -1,457 +0,0 @@ -from collections import namedtuple -from io import StringIO -from fractions import Fraction -import csv -import os -import time -import shlex -import subprocess -import uuid - -from snakemake_interface_executor_plugins.dag import DAGExecutorInterface -from snakemake_interface_executor_plugins.jobs import ExecutorJobInterface -from snakemake_interface_executor_plugins.workflow import WorkflowExecutorInterface -from snakemake_interface_executor_plugins.executors.remote import RemoteExecutor -from snakemake_interface_executor_plugins.persistence import StatsExecutorInterface -from snakemake_interface_executor_plugins.logging import LoggerExecutorInterface - -from snakemake.logging import logger -from snakemake.exceptions import WorkflowError -from snakemake.common import async_lock - -SlurmJob = namedtuple("SlurmJob", "job jobid callback error_callback slurm_logfile") - - -def get_account(): - """ - tries to deduce the acccount from recent jobs, - returns None, if none is found - """ - cmd = f'sacct -nu "{os.environ["USER"]}" -o Account%256 | head -n1' - try: - sacct_out = subprocess.check_output( - cmd, shell=True, text=True, stderr=subprocess.PIPE - ) - return sacct_out.strip() - except subprocess.CalledProcessError as e: - logger.warning( - f"No account was given, not able to get a SLURM account via sacct: {e.stderr}" - ) - return None - - -def test_account(account): - """ - tests whether the given account is registered, raises an error, if not - """ - cmd = f'sacctmgr -n -s list user "{os.environ["USER"]}" format=account%256' - try: - accounts = subprocess.check_output( - cmd, shell=True, text=True, stderr=subprocess.PIPE - ) - except subprocess.CalledProcessError as e: - raise WorkflowError( - f"Unable to test the validity of the given or guessed SLURM account '{account}' with sacctmgr: {e.stderr}" - ) - - accounts = accounts.split() - - if account not in accounts: - raise WorkflowError( - f"The given account {account} appears to be invalid. Available accounts:\n{', '.join(accounts)}" - ) - - -def get_default_partition(job): - """ - if no partition is given, checks whether a fallback onto a default partition is possible - """ - try: - out = subprocess.check_output( - r"sinfo -o %P", shell=True, text=True, stderr=subprocess.PIPE - ) - except subprocess.CalledProcessError as e: - raise WorkflowError( - f"Failed to run sinfo for retrieval of cluster partitions: {e.stderr}" - ) - for partition in out.split(): - # a default partition is marked with an asterisk, but this is not part of the name - if "*" in partition: - # the decode-call is necessary, because the output of sinfo is bytes - return partition.replace("*", "") - logger.warning( - f"No partition was given for rule '{job}', and unable to find a default partition." - " Trying to submit without partition information." - " You may want to invoke snakemake with --default-resources 'slurm_partition='." 
- ) - return "" - - -class SlurmExecutor(RemoteExecutor): - """ - the SLURM_Executor abstracts execution on SLURM - clusters using snakemake resource string - """ - - def __init__( - self, - workflow: WorkflowExecutorInterface, - dag: DAGExecutorInterface, - stats: StatsExecutorInterface, - logger: LoggerExecutorInterface, - jobname="snakejob_{name}_{jobid}", - max_status_checks_per_second=0.5, - ): - super().__init__( - workflow, - dag, - stats, - logger, - None, - jobname=jobname, - max_status_checks_per_second=max_status_checks_per_second, - ) - self.run_uuid = str(uuid.uuid4()) - self._fallback_account_arg = None - self._fallback_partition = None - - def additional_general_args(self): - # we need to set -j to 1 here, because the behaviour - # of snakemake is to submit all jobs at once, otherwise. - # However, the SLURM Executor is supposed to submit jobs - # one after another, so we need to set -j to 1 for the - # JobStep Executor, which in turn handles the launch of - # SLURM jobsteps. - return [" --slurm-jobstep", "--jobs 1"] - - def cancel(self): - # Jobs are collected to reduce load on slurmctld - jobids = " ".join([job.jobid for job in self.active_jobs]) - if len(jobids) > 0: - try: - # timeout set to 60, because a scheduler cycle usually is - # about 30 sec, but can be longer in extreme cases. - # Under 'normal' circumstances, 'scancel' is executed in - # virtually no time. - subprocess.check_output( - f"scancel {jobids}", - text=True, - shell=True, - timeout=60, - stderr=subprocess.PIPE, - ) - except subprocess.TimeoutExpired: - logger.warning("Unable to cancel jobs within a minute.") - self.shutdown() - - def get_account_arg(self, job: ExecutorJobInterface): - """ - checks whether the desired account is valid, - returns a default account, if applicable - else raises an error - implicetly. - """ - if job.resources.get("slurm_account"): - # here, we check whether the given or guessed account is valid - # if not, a WorkflowError is raised - test_account(job.resources.slurm_account) - return f" -A {job.resources.slurm_account}" - else: - if self._fallback_account_arg is None: - logger.warning("No SLURM account given, trying to guess.") - account = get_account() - if account: - logger.warning(f"Guessed SLURM account: {account}") - self._fallback_account_arg = f" -A {account}" - else: - logger.warning( - "Unable to guess SLURM account. Trying to proceed without." - ) - self._fallback_account_arg = ( - "" # no account specific args for sbatch - ) - return self._fallback_account_arg - - def get_partition_arg(self, job: ExecutorJobInterface): - """ - checks whether the desired partition is valid, - returns a default partition, if applicable - else raises an error - implicetly. 
- """ - if job.resources.get("slurm_partition"): - partition = job.resources.slurm_partition - else: - if self._fallback_partition is None: - self._fallback_partition = get_default_partition(job) - partition = self._fallback_partition - if partition: - return f" -p {partition}" - else: - return "" - - def run( - self, - job: ExecutorJobInterface, - callback=None, - submit_callback=None, - error_callback=None, - ): - super()._run(job) - jobid = job.jobid - - log_folder = f"group_{job.name}" if job.is_group() else f"rule_{job.name}" - - slurm_logfile = f".snakemake/slurm_logs/{log_folder}/%j.log" - os.makedirs(os.path.dirname(slurm_logfile), exist_ok=True) - - # generic part of a submission string: - # we use a run_uuid as the job-name, to allow `--name`-based - # filtering in the job status checks (`sacct --name` and `squeue --name`) - call = f"sbatch --job-name {self.run_uuid} -o {slurm_logfile} --export=ALL" - - call += self.get_account_arg(job) - call += self.get_partition_arg(job) - - if job.resources.get("runtime"): - call += f" -t {job.resources.runtime}" - else: - logger.warning( - "No wall time information given. This might or might not work on your cluster. " - "If not, specify the resource runtime in your rule or as a reasonable " - "default via --default-resources." - ) - - if job.resources.get("constraint"): - call += f" -C {job.resources.constraint}" - if job.resources.get("mem_mb_per_cpu"): - call += f" --mem-per-cpu {job.resources.mem_mb_per_cpu}" - elif job.resources.get("mem_mb"): - call += f" --mem {job.resources.mem_mb}" - else: - logger.warning( - "No job memory information ('mem_mb' or 'mem_mb_per_cpu') is given " - "- submitting without. This might or might not work on your cluster." - ) - - # MPI job - if job.resources.get("mpi", False): - if job.resources.get("nodes", False): - call += f" --nodes={job.resources.get('nodes', 1)}" - if job.resources.get("tasks", False): - call += f" --ntasks={job.resources.get('tasks', 1)}" - - cpus_per_task = job.threads - if job.resources.get("cpus_per_task"): - if not isinstance(cpus_per_task, int): - raise WorkflowError( - f"cpus_per_task must be an integer, but is {cpus_per_task}" - ) - cpus_per_task = job.resources.cpus_per_task - # ensure that at least 1 cpu is requested - # because 0 is not allowed by slurm - cpus_per_task = max(1, cpus_per_task) - call += f" --cpus-per-task={cpus_per_task}" - - if job.resources.get("slurm_extra"): - call += f" {job.resources.slurm_extra}" - - exec_job = self.format_job_exec(job) - # ensure that workdir is set correctly - # use short argument as this is the same in all slurm versions - # (see https://github.com/snakemake/snakemake/issues/2014) - call += f" -D {self.workflow.workdir_init}" - # and finally the job to execute with all the snakemake parameters - call += f" --wrap={shlex.quote(exec_job)}" - - logger.debug(f"sbatch call: {call}") - try: - out = subprocess.check_output( - call, shell=True, text=True, stderr=subprocess.STDOUT - ).strip() - except subprocess.CalledProcessError as e: - raise WorkflowError( - f"SLURM job submission failed. The error message was {e.output}" - ) - - slurm_jobid = out.split(" ")[-1] - slurm_logfile = slurm_logfile.replace("%j", slurm_jobid) - logger.info( - f"Job {jobid} has been submitted with SLURM jobid {slurm_jobid} (log: {slurm_logfile})." 
- ) - self.active_jobs.append( - SlurmJob(job, slurm_jobid, callback, error_callback, slurm_logfile) - ) - - async def job_stati(self, command): - """obtain SLURM job status of all submitted jobs with sacct - - Keyword arguments: - command -- a slurm command that returns one line for each job with: - "|" - """ - try: - time_before_query = time.time() - command_res = subprocess.check_output( - command, text=True, shell=True, stderr=subprocess.PIPE - ) - query_duration = time.time() - time_before_query - logger.debug( - f"The job status was queried with command: {command}\n" - f"It took: {query_duration} seconds\n" - f"The output is:\n'{command_res}'\n" - ) - res = { - # We split the second field in the output, as the State field - # could contain info beyond the JOB STATE CODE according to: - # https://slurm.schedmd.com/sacct.html#OPT_State - entry[0]: entry[1].split(sep=None, maxsplit=1)[0] - for entry in csv.reader(StringIO(command_res), delimiter="|") - } - except subprocess.CalledProcessError as e: - - def fmt_err(err_type, err_msg): - if err_msg is not None: - return f" {err_type} error: {err_msg.strip()}" - else: - return "" - - logger.error( - f"The job status query failed with command: {command}\n" - f"Error message: {e.stderr.strip()}\n" - ) - pass - - return (res, query_duration) - - async def _wait_for_jobs(self): - from throttler import Throttler - - # busy wait on job completion - # This is only needed if your backend does not allow to use callbacks - # for obtaining job status. - fail_stati = ( - "BOOT_FAIL", - "CANCELLED", - "DEADLINE", - "FAILED", - "NODE_FAIL", - "OUT_OF_MEMORY", - "PREEMPTED", - "TIMEOUT", - "ERROR", - ) - # intialize time to sleep in seconds - MIN_SLEEP_TIME = 20 - # Cap sleeping time between querying the status of all active jobs: - # If `AccountingStorageType`` for `sacct` is set to `accounting_storage/none`, - # sacct will query `slurmctld` (instead of `slurmdbd`) and this in turn can - # rely on default config, see: https://stackoverflow.com/a/46667605 - # This config defaults to `MinJobAge=300`, which implies that jobs will be - # removed from `slurmctld` within 6 minutes of finishing. So we're conservative - # here, with half that time - MAX_SLEEP_TIME = 180 - sleepy_time = MIN_SLEEP_TIME - # only start checking statuses after bit -- otherwise no jobs are in slurmdbd, yet - time.sleep(2 * sleepy_time) - while True: - # Initialize all query durations to specified - # 5 times the status_rate_limiter, to hit exactly - # the status_rate_limiter for the first async below. - # It is dynamically updated afterwards. - sacct_query_duration = ( - self.status_rate_limiter._period / self.status_rate_limiter._rate_limit - ) * 5 - # keep track of jobs already seen in sacct accounting - active_jobs_seen_by_sacct = set() - # always use self.lock to avoid race conditions - async with async_lock(self.lock): - if not self.wait: - return - active_jobs = self.active_jobs - active_jobs_ids = {j.jobid for j in active_jobs} - self.active_jobs = list() - still_running = list() - STATUS_ATTEMPTS = 5 - # this code is inspired by the snakemake profile: - # https://github.com/Snakemake-Profiles/slurm/blob/a0e559e1eca607d0bd26c15f94d609e6905f8a8e/%7B%7Bcookiecutter.profile_name%7D%7D/slurm-status.py#L27 - for i in range(STATUS_ATTEMPTS): - # use self.status_rate_limiter and adaptive query - # timing to avoid too many API calls in retries. 
- rate_limit = Fraction( - min( - self.status_rate_limiter._rate_limit - / self.status_rate_limiter._period, - # if slurmdbd (sacct) is strained and slow, reduce the query frequency - (1 / sacct_query_duration) / 5, - ) - ).limit_denominator() - missing_sacct_status = set() - async with Throttler( - rate_limit=rate_limit.numerator, - period=rate_limit.denominator, - ): - (status_of_jobs, sacct_query_duration) = await self.job_stati( - # -X: only show main job, no substeps - f"sacct -X --parsable2 --noheader --format=JobIdRaw,State --name {self.run_uuid}" - ) - logger.debug(f"status_of_jobs after sacct is: {status_of_jobs}") - # only take jobs that are still active - active_jobs_ids_with_current_sacct_status = ( - set(status_of_jobs.keys()) & active_jobs_ids - ) - active_jobs_seen_by_sacct = ( - active_jobs_seen_by_sacct - | active_jobs_ids_with_current_sacct_status - ) - logger.debug( - f"active_jobs_seen_by_sacct are: {active_jobs_seen_by_sacct}" - ) - missing_sacct_status = ( - active_jobs_seen_by_sacct - - active_jobs_ids_with_current_sacct_status - ) - if not missing_sacct_status: - break - if i >= STATUS_ATTEMPTS - 1: - logger.warning( - f"Unable to get the status of all active_jobs that should be in slurmdbd, even after {STATUS_ATTEMPTS} attempts.\n" - f"The jobs with the following slurm job ids were previously seen by sacct, but sacct doesn't report them any more:\n" - f"{missing_sacct_status}\n" - f"Please double-check with your slurm cluster administrator, that slurmdbd job accounting is properly set up.\n" - ) - for j in active_jobs: - # the job probably didn't make it into slurmdbd yet, so - # `sacct` doesn't return it - if not j.jobid in status_of_jobs: - # but the job should still be queueing or running and - # appear in slurmdbd (and thus `sacct` output) later - still_running.append(j) - continue - status = status_of_jobs[j.jobid] - if status == "COMPLETED": - j.callback(j.job) - active_jobs_seen_by_sacct.remove(j.jobid) - elif status == "UNKNOWN": - # the job probably does not exist anymore, but 'sacct' did not work - # so we assume it is finished - j.callback(j.job) - active_jobs_seen_by_sacct.remove(j.jobid) - elif status in fail_stati: - self.print_job_error( - j.job, - msg=f"SLURM-job '{j.jobid}' failed, SLURM status is: '{status}'", - aux_logs=[j.slurm_logfile], - ) - j.error_callback(j.job) - active_jobs_seen_by_sacct.remove(j.jobid) - else: # still running? 
- still_running.append(j) - - # no jobs finished in the last query period - if not active_jobs_ids - {j.jobid for j in still_running}: - # sleep a little longer, but never too long - sleepy_time = min(sleepy_time + 10, MAX_SLEEP_TIME) - else: - sleepy_time = MIN_SLEEP_TIME - async with async_lock(self.lock): - self.active_jobs.extend(still_running) - time.sleep(sleepy_time) diff --git a/snakemake/executors/touch.py b/snakemake/executors/touch.py new file mode 100644 index 000000000..b5afb0736 --- /dev/null +++ b/snakemake/executors/touch.py @@ -0,0 +1,73 @@ +__author__ = "Johannes Köster" +__copyright__ = "Copyright 2022, Johannes Köster" +__email__ = "johannes.koester@uni-due.de" +__license__ = "MIT" + +import time + +from snakemake_interface_executor_plugins.executors.real import RealExecutor +from snakemake_interface_executor_plugins.dag import DAGExecutorInterface +from snakemake_interface_executor_plugins.workflow import WorkflowExecutorInterface +from snakemake_interface_executor_plugins.logging import LoggerExecutorInterface +from snakemake_interface_executor_plugins.jobs import ( + ExecutorJobInterface, +) +from snakemake_interface_executor_plugins.executors.base import SubmittedJobInfo +from snakemake_interface_executor_plugins import CommonSettings + +from snakemake.exceptions import print_exception + + +common_settings = CommonSettings( + non_local_exec=False, + implies_no_shared_fs=False, + touch_exec=True, +) + + +class Executor(RealExecutor): + def __init__( + self, + workflow: WorkflowExecutorInterface, + logger: LoggerExecutorInterface, + ): + super().__init__( + workflow, + logger, + pass_envvar_declarations_to_cmd=False, + ) + + def run_job( + self, + job: ExecutorJobInterface, + ): + job_info = SubmittedJobInfo(job=job) + try: + # Touching of output files will be done by handle_job_success + time.sleep(0.1) + self.report_job_submission(job_info) + self.report_job_success(job_info) + except OSError as ex: + print_exception(ex, self.workflow.linemaps) + self.report_job_error(job_info) + + def get_exec_mode(self): + raise NotImplementedError() + + def handle_job_success(self, job: ExecutorJobInterface): + super().handle_job_success(job, ignore_missing_output=True) + + def cancel(self): + # nothing to do + pass + + def shutdown(self): + # nothing to do + pass + + def get_python_executable(self): + raise NotImplementedError() + + @property + def cores(self): + return self.workflow.resource_settings.cores diff --git a/snakemake/io.py b/snakemake/io.py index 638142efa..7d3e2caf0 100755 --- a/snakemake/io.py +++ b/snakemake/io.py @@ -1386,87 +1386,6 @@ def replace_constraint(match: re.Match): return updated -def split_git_path(path): - file_sub = re.sub(r"^git\+file:/+", "/", path) - (file_path, version) = file_sub.split("@") - file_path = os.path.realpath(file_path) - root_path = get_git_root(file_path) - if file_path.startswith(root_path): - file_path = file_path[len(root_path) :].lstrip("/") - return (root_path, file_path, version) - - -def get_git_root(path): - """ - Args: - path: (str) Path a to a directory/file that is located inside the repo - Returns: - path to the root folder for git repo - """ - import git - - try: - git_repo = git.Repo(path, search_parent_directories=True) - return git_repo.git.rev_parse("--show-toplevel") - except git.exc.NoSuchPathError: - tail, head = os.path.split(path) - return get_git_root_parent_directory(tail, path) - - -def get_git_root_parent_directory(path, input_path): - """ - This function will recursively go through parent directories 
until a git - repository is found or until no parent directories are left, in which case - an error will be raised. This is needed when providing a path to a - file/folder that is located on a branch/tag not currently checked out. - - Args: - path: (str) Path a to a directory that is located inside the repo - input_path: (str) origin path, used when raising WorkflowError - Returns: - path to the root folder for git repo - """ - import git - - try: - git_repo = git.Repo(path, search_parent_directories=True) - return git_repo.git.rev_parse("--show-toplevel") - except git.exc.NoSuchPathError: - tail, head = os.path.split(path) - if tail is None: - raise WorkflowError( - f"Neither provided git path ({input_path}) " - + "or parent directories contain a valid git repo." - ) - else: - return get_git_root_parent_directory(tail, input_path) - - -def git_content(git_file): - """ - This function will extract a file from a git repository, one located on - the filesystem. - The expected format is git+file:///path/to/your/repo/path_to_file@version - - Args: - env_file (str): consist of path to repo, @, version, and file information - Ex: git+file:///home/smeds/snakemake-wrappers/bio/fastqc/wrapper.py@0.19.3 - Returns: - file content or None if the expected format isn't meet - """ - import git - - if git_file.startswith("git+file:"): - (root_path, file_path, version) = split_git_path(git_file) - return git.Repo(root_path).git.show(f"{version}:{file_path}") - else: - raise WorkflowError( - "Provided git path ({}) doesn't meet the " - "expected format:".format(git_file) + ", expected format is " - "git+file://PATH_TO_REPO/PATH_TO_FILE_INSIDE_REPO@VERSION" - ) - - def strip_wildcard_constraints(pattern): """Return a string that does not contain any wildcard constraints.""" if is_callable(pattern): @@ -1682,45 +1601,6 @@ class Log(Namedlist): pass -def _load_configfile(configpath_or_obj, filetype="Config"): - "Tries to load a configfile first as JSON, then as YAML, into a dict." - import yaml - - if isinstance(configpath_or_obj, str) or isinstance(configpath_or_obj, Path): - obj = open(configpath_or_obj, encoding="utf-8") - else: - obj = configpath_or_obj - - try: - with obj as f: - try: - return json.load(f, object_pairs_hook=collections.OrderedDict) - except ValueError: - f.seek(0) # try again - try: - import yte - - return yte.process_yaml(f, require_use_yte=True) - except yaml.YAMLError: - raise WorkflowError( - f"{filetype} file is not valid JSON or YAML. " - "In case of YAML, make sure to not mix " - "whitespace and tab indentation." - ) - except FileNotFoundError: - raise WorkflowError(f"{filetype} file {configpath_or_obj} not found.") - - -def load_configfile(configpath): - "Loads a JSON or YAML configfile as a dict, then checks that it's a dict." - config = _load_configfile(configpath) - if not isinstance(config, dict): - raise WorkflowError( - "Config file must be given as JSON or YAML with keys at top level." 
- ) - return config - - ##### Wildcard pumping detection ##### diff --git a/snakemake/jobs.py b/snakemake/jobs.py index 2ff8a1f22..a2906db9a 100644 --- a/snakemake/jobs.py +++ b/snakemake/jobs.py @@ -15,6 +15,7 @@ from operator import attrgetter from typing import Optional from abc import ABC, abstractmethod +from snakemake.settings import DeploymentMethod from snakemake_interface_executor_plugins.utils import lazy_property from snakemake_interface_executor_plugins.jobs import ( @@ -88,7 +89,7 @@ def has_products(self, include_logfiles=True): def _get_scheduler_resources(job): if job._scheduler_resources is None: - if job.dag.workflow.run_local or job.is_local: + if job.dag.workflow.local_exec or job.is_local: job._scheduler_resources = job.resources else: job._scheduler_resources = Resources( @@ -166,7 +167,6 @@ class Job(AbstractJob, SingleJobExecutorInterface): "temp_output", "protected_output", "touch_output", - "subworkflow_input", "_hash", "_attempt", "_group", @@ -233,7 +233,6 @@ def __init__( self.dynamic_output, self.dynamic_input = set(), set() self.temp_output, self.protected_output = set(), set() self.touch_output = set() - self.subworkflow_input = dict() for f in self.output: f_ = output_mapping[f] if f_ in self.rule.dynamic_output: @@ -248,20 +247,6 @@ def __init__( f_ = input_mapping[f] if f_ in self.rule.dynamic_input: self.dynamic_input.add(f) - if f_ in self.rule.subworkflow_input: - self.subworkflow_input[f] = self.rule.subworkflow_input[f_] - elif "subworkflow" in f.flags: - sub = f.flags["subworkflow"] - if f in self.subworkflow_input: - other = self.subworkflow_input[f] - if sub != other: - raise WorkflowError( - "The input file {} is ambiguously " - "associated with two subworkflows {} " - "and {}.".format(f, sub, other), - rule=self.rule, - ) - self.subworkflow_input[f] = sub @property def is_updated(self): @@ -322,8 +307,8 @@ def updated(self): group = self.dag.get_job_group(self) groupid = None if group is None: - if self.dag.workflow.run_local or self.is_local: - groupid = self.dag.workflow.local_groupid + if self.dag.workflow.local_exec or self.is_local: + groupid = self.dag.workflow.group_settings.local_groupid else: groupid = group.jobid @@ -415,7 +400,7 @@ def attempt(self, attempt): @property def resources(self): if self._resources is None: - if self.dag.workflow.run_local or self.is_local: + if self.dag.workflow.local_exec or self.is_local: skip_evaluation = None else: # tmpdir should be evaluated in the context of the actual execution @@ -482,7 +467,11 @@ def is_containerized(self): @property def container_img(self): - if self.dag.workflow.use_singularity and self.container_img_url: + if ( + DeploymentMethod.APPTAINER + in self.dag.workflow.deployment_settings.deployment_method + and self.container_img_url + ): return self.dag.container_imgs[self.container_img_url] return None @@ -637,10 +626,7 @@ def dynamic_wildcards(self): @property def missing_input(self): """Return missing input files.""" - # omit file if it comes from a subworkflow - return set( - f for f in self.input if not f.exists and not f in self.subworkflow_input - ) + return set(f for f in self.input if not f.exists) @property def existing_remote_input(self): @@ -867,11 +853,11 @@ def prepare(self): self.benchmark.prepare() # wait for input files, respecting keep_remote_local - force_stay_on_remote = not self.dag.keep_remote_local + force_stay_on_remote = not self.dag.workflow.storage_settings.keep_remote_local wait_for_files( self.input, force_stay_on_remote=force_stay_on_remote, - 
latency_wait=self.dag.workflow.latency_wait, + latency_wait=self.dag.workflow.execution_settings.latency_wait, ) if not self.is_shadow: @@ -996,7 +982,6 @@ def format_wildcards(self, string, **variables): resources=self.resources, log=self.log, jobid=self.jobid, - version=self.rule.version, name=self.name, rule=self.rule.name, rulename=self.rule.name, @@ -1121,8 +1106,8 @@ def log_error( ): logger.job_error(**self.get_log_error_info(msg, indent, aux_logs, **kwargs)) - def register(self): - self.dag.workflow.persistence.started(self) + def register(self, external_jobid: Optional[str] = None): + self.dag.workflow.persistence.started(self, external_jobid) def get_wait_for_files(self): wait_for_files = [] @@ -1134,7 +1119,8 @@ def get_wait_for_files(self): if self.shadow_dir: wait_for_files.append(self.shadow_dir) if ( - self.dag.workflow.use_conda + DeploymentMethod.CONDA + in self.dag.workflow.deployment_settings.deployment_method and self.conda_env and not self.conda_env.is_named and not self.conda_env.is_containerized @@ -1162,22 +1148,21 @@ def postprocess( handle_touch=True, error=False, ignore_missing_output=False, - assume_shared_fs=True, - latency_wait=None, - keep_metadata=True, ): if self.dag.is_edit_notebook_job(self): # No postprocessing necessary, we have just created the skeleton notebook and # execution will anyway stop afterwards. return - if assume_shared_fs: + if self.dag.workflow.storage_settings.assume_shared_fs: if not error and handle_touch: self.dag.handle_touch(self) if handle_log: self.dag.handle_log(self) if not error: self.dag.check_and_touch_output( - self, wait=latency_wait, ignore_missing_output=ignore_missing_output + self, + wait=self.dag.workflow.execution_settings.latency_wait, + ignore_missing_output=ignore_missing_output, ) self.dag.unshadow_output(self, only_log=error) if not error: @@ -1187,13 +1172,14 @@ def postprocess( else: if not error: self.dag.check_and_touch_output( - self, wait=latency_wait, no_touch=True, force_stay_on_remote=True + self, + wait=self.dag.workflow.execution_settings.latency_wait, + no_touch=True, + force_stay_on_remote=True, ) if not error: try: - self.dag.workflow.persistence.finished( - self, keep_metadata=keep_metadata - ) + self.dag.workflow.persistence.finished(self) except IOError as e: raise WorkflowError( "Error recording metadata for finished job " @@ -1281,7 +1267,7 @@ def __init__(self, id, jobs, global_resources): self._log = None self._inputsize = None self._all_products = None - self._attempt = self.dag.workflow.attempt + self._attempt = self.dag.workflow.execution_settings.attempt self._jobid = None @property @@ -1375,9 +1361,9 @@ def log_error(self, msg=None, aux_logs: Optional[list] = None, **kwargs): **kwargs, ) - def register(self): + def register(self, external_jobid: Optional[str] = None): for job in self.jobs: - job.register() + job.register(external_jobid=external_jobid) def remove_existing_output(self): for job in self.jobs: @@ -1413,7 +1399,8 @@ def get_wait_for_files(self): if job.shadow_dir: wait_for_files.append(job.shadow_dir) if ( - self.dag.workflow.use_conda + DeploymentMethod.CONDA + in self.dag.workflow.deployment_settings.deployment_method and job.conda_env and not job.conda_env.is_named ): @@ -1427,7 +1414,7 @@ def resources(self): self._resources = GroupResources.basic_layered( toposorted_jobs=self.toposorted, constraints=self.global_resources, - run_local=self.dag.workflow.run_local, + run_local=self.dag.workflow.local_exec, additive_resources=["runtime"], sortby=["runtime"], ) diff --git 
a/snakemake/linting/rules.py b/snakemake/linting/rules.py index 269080ad3..5184a46c0 100644 --- a/snakemake/linting/rules.py +++ b/snakemake/linting/rules.py @@ -80,15 +80,6 @@ def lint_not_used_params( links=[links.params], ) - def lint_version(self, rule): - if rule.version: - yield Lint( - title="The version directive is deprecated", - body="It was meant for documenting tool version, but this has been replaced " - "by using the conda or container directive.", - links=[links.package_management, links.containers], - ) - def lint_dynamic(self, rule): for file in chain(rule.output, rule.input): if is_flagged(file, "dynamic"): diff --git a/snakemake/logging.py b/snakemake/logging.py index 6617a7099..d3ec65ab9 100644 --- a/snakemake/logging.py +++ b/snakemake/logging.py @@ -15,7 +15,7 @@ import inspect import textwrap -from snakemake_interface_executor_plugins.utils import ExecMode +from snakemake_interface_executor_plugins.settings import ExecMode from snakemake.common import DYNAMIC_FILL @@ -34,9 +34,7 @@ class ColorizingStreamHandler(_logging.StreamHandler): "ERROR": RED, } - def __init__( - self, nocolor=False, stream=sys.stderr, use_threads=False, mode=ExecMode.default - ): + def __init__(self, nocolor=False, stream=sys.stderr, mode=ExecMode.DEFAULT): super().__init__(stream=stream) self._output_lock = threading.Lock() @@ -46,7 +44,7 @@ def __init__( def can_color_tty(self, mode): if "TERM" in os.environ and os.environ["TERM"] == "dumb": return False - if mode == ExecMode.subprocess: + if mode == ExecMode.SUBPROCESS: return True return self.is_tty and not platform.system() == "Windows" @@ -293,13 +291,13 @@ def __init__(self): self.quiet = set() self.logfile = None self.last_msg_was_job_info = False - self.mode = ExecMode.default + self.mode = ExecMode.DEFAULT self.show_failed_logs = False self.logfile_handler = None self.dryrun = False def setup_logfile(self): - if self.mode == ExecMode.default and not self.dryrun: + if self.mode == ExecMode.DEFAULT and not self.dryrun: os.makedirs(os.path.join(".snakemake", "log"), exist_ok=True) self.logfile = os.path.abspath( os.path.join( @@ -314,7 +312,7 @@ def setup_logfile(self): self.logger.addHandler(self.logfile_handler) def cleanup(self): - if self.mode == ExecMode.default and self.logfile_handler is not None: + if self.mode == ExecMode.DEFAULT and self.logfile_handler is not None: self.logger.removeHandler(self.logfile_handler) self.logfile_handler.close() self.log_handler = [self.text_handler] @@ -339,7 +337,7 @@ def set_level(self, level): self.logger.setLevel(level) def logfile_hint(self): - if self.mode == ExecMode.default and not self.dryrun: + if self.mode == ExecMode.DEFAULT and not self.dryrun: logfile = self.get_logfile() self.info(f"Complete log: {os.path.relpath(logfile)}") @@ -411,8 +409,13 @@ def d3dag(self, **msg): msg["level"] = "d3dag" self.handler(msg) - def is_quiet_about(self, msg_type): - return msg_type in self.quiet or "all" in self.quiet + def is_quiet_about(self, msg_type: str): + from snakemake.settings import Quietness + + return ( + Quietness.ALL in self.quiet + or Quietness.parse_choice(msg_type) in self.quiet + ) def text_handler(self, msg): """The default snakemake log handler. 
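As an aside, a minimal sketch (not part of the patch) of how the enum-based quiet check above is meant to behave, assuming `logger.quiet` holds a set of `Quietness` members and that `Quietness` (from the new `snakemake/settings.py`) defines members such as `PROGRESS` and `RULES` next to `ALL` — the member names below are illustrative:

    from snakemake.logging import logger
    from snakemake.settings import Quietness

    # Silence only progress output; other message types still get logged.
    logger.quiet = {Quietness.PROGRESS}           # PROGRESS assumed for illustration
    assert logger.is_quiet_about("progress")      # parse_choice("progress") -> Quietness.PROGRESS
    assert not logger.is_quiet_about("rules")

    # Quietness.ALL short-circuits the check and silences everything.
    logger.quiet = {Quietness.ALL}
    assert logger.is_quiet_about("rules")
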
@@ -691,8 +694,7 @@ def setup_logger( nocolor=False, stdout=False, debug=False, - use_threads=False, - mode=ExecMode.default, + mode=ExecMode.DEFAULT, show_failed_logs=False, dryrun=False, ): @@ -717,7 +719,6 @@ def setup_logger( stream_handler = ColorizingStreamHandler( nocolor=nocolor, stream=sys.stdout if stdout else sys.stderr, - use_threads=use_threads, mode=mode, ) logger.set_stream_handler(stream_handler) diff --git a/snakemake/modules.py b/snakemake/modules.py index 4c400d4bc..2dad56608 100644 --- a/snakemake/modules.py +++ b/snakemake/modules.py @@ -107,7 +107,8 @@ def use_rules( def get_snakefile(self): if self.meta_wrapper: return wrapper.get_path( - self.meta_wrapper + "/test/Snakefile", self.workflow.wrapper_prefix + self.meta_wrapper + "/test/Snakefile", + self.workflow.workflow_settings.wrapper_prefix, ) elif self.snakefile: return self.snakefile diff --git a/snakemake/notebook.py b/snakemake/notebook.py index 02cb1a81a..372b63ed9 100644 --- a/snakemake/notebook.py +++ b/snakemake/notebook.py @@ -16,13 +16,6 @@ KERNEL_SHUTDOWN_RE = re.compile(r"Kernel shutdown: (?P\S+)") -class EditMode: - def __init__(self, server_addr=None, draft_only=False): - if server_addr is not None: - self.ip, self.port = server_addr.split(":") - self.draft_only = draft_only - - def get_cell_sources(source): import nbformat diff --git a/snakemake/parser.py b/snakemake/parser.py index 73346ed34..104d0923f 100644 --- a/snakemake/parser.py +++ b/snakemake/parser.py @@ -70,6 +70,7 @@ def __init__(self, token): class TokenAutomaton: subautomata: Dict[str, Any] = {} + deprecated: Dict[str, str] = {} def __init__(self, snakefile: "Snakefile", base_indent=0, dedent=0, root=True): self.root = root @@ -112,7 +113,15 @@ def consume(self): def error(self, msg, token): raise SyntaxError(msg, (self.snakefile.path, lineno(token), None, None)) - def subautomaton(self, automaton, *args, **kwargs): + def subautomaton(self, automaton, *args, token=None, **kwargs): + if automaton in self.deprecated: + assert ( + token is not None + ), "bug: deprecation encountered but subautomaton not called with a token" + self.error( + f"Keyword {automaton} is deprecated. 
{self.deprecated[automaton]}", + token, + ) return self.subautomata[automaton]( self.snakefile, *args, @@ -321,91 +330,6 @@ def keyword(self): return "global_containerized" -# subworkflows - - -class SubworkflowKeywordState(SectionKeywordState): - prefix = "Subworkflow" - - -class SubworkflowSnakefile(SubworkflowKeywordState): - pass - - -class SubworkflowWorkdir(SubworkflowKeywordState): - pass - - -class SubworkflowConfigfile(SubworkflowKeywordState): - pass - - -class Subworkflow(GlobalKeywordState): - subautomata = dict( - snakefile=SubworkflowSnakefile, - workdir=SubworkflowWorkdir, - configfile=SubworkflowConfigfile, - ) - - def __init__(self, snakefile, base_indent=0, dedent=0, root=True): - super().__init__(snakefile, base_indent=base_indent, dedent=dedent, root=root) - self.state = self.name - self.has_snakefile = False - self.has_workdir = False - self.has_name = False - self.primary_token = None - - def end(self): - if not (self.has_snakefile or self.has_workdir): - self.error( - "A subworkflow needs either a path to a Snakefile or to a workdir.", - self.primary_token, - ) - yield ")" - - def name(self, token): - if is_name(token): - yield f"workflow.subworkflow({token.string!r}", token - self.has_name = True - elif is_colon(token) and self.has_name: - self.primary_token = token - self.state = self.block - else: - self.error("Expected name after subworkflow keyword.", token) - - def block_content(self, token): - if is_name(token): - try: - if token.string == "snakefile": - self.has_snakefile = True - if token.string == "workdir": - self.has_workdir = True - for t in self.subautomaton(token.string).consume(): - yield t - except KeyError: - self.error( - "Unexpected keyword {} in " - "subworkflow definition".format(token.string), - token, - ) - except StopAutomaton as e: - self.indentation(e.token) - for t in self.block(e.token): - yield t - elif is_comment(token): - yield "\n", token - yield token.string, token - elif is_string(token): - # ignore docstring - pass - else: - self.error( - "Expecting subworkflow keyword, comment or docstrings " - "inside a subworkflow definition.", - token, - ) - - class Localrules(GlobalKeywordState): def block_content(self, token): if is_comma(token): @@ -538,7 +462,7 @@ def start(self): yield "\n" yield ( "def __rule_{rulename}(input, output, params, wildcards, threads, " - "resources, log, version, rule, conda_env, container_img, " + "resources, log, rule, conda_env, container_img, " "singularity_args, use_singularity, env_modules, bench_record, jobid, " "is_shell, bench_iteration, cleanup_scripts, shadow_dir, edit_notebook, " "conda_base_path, basedir, runtime_sourcecache_path, {rule_func_marker}=True):".format( @@ -667,7 +591,7 @@ def args(self): yield ( ", input, output, params, wildcards, threads, resources, log, " "config, rule, conda_env, conda_base_path, container_img, singularity_args, env_modules, " - "bench_record, workflow.wrapper_prefix, jobid, bench_iteration, " + "bench_record, workflow.workflow_settings.wrapper_prefix, jobid, bench_iteration, " "cleanup_scripts, shadow_dir, runtime_sourcecache_path" ) @@ -700,7 +624,6 @@ def args(self): resources=Resources, retries=Retries, priority=Priority, - version=Version, log=Log, message=Message, benchmark=Benchmark, @@ -717,6 +640,9 @@ def args(self): default_target=DefaultTarget, localrule=LocalRule, ) +rule_property_deprecated = dict( + version="Use conda or container directive instead (see docs)." 
+) class Rule(GlobalKeywordState): @@ -730,6 +656,7 @@ class Rule(GlobalKeywordState): cwl=CWL, **rule_property_subautomata, ) + deprecated = rule_property_deprecated def __init__(self, snakefile, base_indent=0, dedent=0, root=True): super().__init__(snakefile, base_indent=base_indent, dedent=dedent, root=root) @@ -798,7 +725,7 @@ def block_content(self, token): token, ) for t in self.subautomaton( - token.string, rulename=self.rulename + token.string, token=token, rulename=self.rulename ).consume(): yield t except KeyError: @@ -928,7 +855,7 @@ def block_content(self, token): self.has_snakefile = True if token.string == "meta_wrapper": self.has_meta_wrapper = True - for t in self.subautomaton(token.string).consume(): + for t in self.subautomaton(token.string, token=token).consume(): yield t except KeyError: self.error( @@ -956,6 +883,7 @@ def block_content(self, token): class UseRule(GlobalKeywordState): subautomata = rule_property_subautomata + deprecated = rule_property_deprecated def __init__(self, snakefile, base_indent=0, dedent=0, root=True): super().__init__(snakefile, base_indent=base_indent, dedent=dedent, root=root) @@ -1168,7 +1096,9 @@ def block_content(self, token): yield token.string, token elif is_name(token): try: - self._with_block.extend(self.subautomaton(token.string).consume()) + self._with_block.extend( + self.subautomaton(token.string, token=token).consume() + ) yield from () except KeyError: self.error( @@ -1202,7 +1132,6 @@ class Python(TokenAutomaton): ruleorder=Ruleorder, rule=Rule, checkpoint=Checkpoint, - subworkflow=Subworkflow, localrules=Localrules, onsuccess=OnSuccess, onerror=OnError, @@ -1216,6 +1145,7 @@ class Python(TokenAutomaton): module=Module, use=UseRule, ) + deprecated = dict(subworkflow="Use module directive instead (see docs).") def __init__(self, snakefile, base_indent=0, dedent=0, root=True): super().__init__(snakefile, base_indent=base_indent, dedent=dedent, root=root) @@ -1225,7 +1155,7 @@ def python(self, token: tokenize.TokenInfo): if not (is_indent(token) or is_dedent(token)): if self.lasttoken is None or self.lasttoken.isspace(): try: - for t in self.subautomaton(token.string).consume(): + for t in self.subautomaton(token.string, token=token).consume(): yield t except KeyError: yield token.string, token diff --git a/snakemake/path_modifier.py b/snakemake/path_modifier.py index 4fc7781fe..1e15007c3 100644 --- a/snakemake/path_modifier.py +++ b/snakemake/path_modifier.py @@ -97,7 +97,7 @@ def is_annotated_callable(value): return bool(value.callable) if ( - self.workflow.default_remote_provider is None + self.workflow.storage_settings.default_remote_provider is None or is_flagged(path, "remote_object") or is_flagged(path, "local") or is_annotated_callable(path) @@ -106,9 +106,9 @@ def is_annotated_callable(value): return path # This will convert any AnnotatedString to str - fullpath = f"{self.workflow.default_remote_prefix}/{path}" + fullpath = f"{self.workflow.storage_settings.default_remote_prefix}/{path}" fullpath = os.path.normpath(fullpath) - remote = self.workflow.default_remote_provider.remote(fullpath) + remote = self.workflow.storage_settings.default_remote_provider.remote(fullpath) return remote @property diff --git a/snakemake/persistence.py b/snakemake/persistence.py index 90779c68c..41e9f1956 100755 --- a/snakemake/persistence.py +++ b/snakemake/persistence.py @@ -14,6 +14,8 @@ from functools import lru_cache from itertools import count from pathlib import Path +from contextlib import contextmanager +from typing import Optional 
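For context, an illustrative view (not applied by this patch) of what the `deprecated` mappings introduced in the parser above do in practice: a Snakefile that still uses a removed keyword now fails at parse time with a SyntaxError built from the f-string in `subautomaton()`, e.g. for the old `version` directive:

    rule a:
        output:
            "a.txt"
        version:
            "1.0"
        shell:
            "touch {output}"

    # -> SyntaxError: Keyword version is deprecated. Use conda or container directive instead (see docs).
    #    (TokenAutomaton.error() attaches the Snakefile path and line number.)

The top-level `subworkflow` keyword is handled the same way, pointing users to the module directive instead.
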
from snakemake_interface_executor_plugins.persistence import ( PersistenceExecutorInterface, @@ -180,6 +182,7 @@ def locked(self): return True return False + @contextmanager def lock_warn_only(self): if self.locked: logger.info( @@ -187,14 +190,20 @@ def lock_warn_only(self): "means that another Snakemake instance is running on this directory. " "Another possibility is that a previous run exited unexpectedly." ) + yield + @contextmanager def lock(self): if self.locked: raise snakemake.exceptions.LockException() - self._lock(self.all_inputfiles(), "input") - self._lock(self.all_outputfiles(), "output") + try: + self._lock(self.all_inputfiles(), "input") + self._lock(self.all_outputfiles(), "output") + yield + finally: + self.unlock() - def unlock(self, *args): + def unlock(self): logger.debug("unlocking") for lockfile in self._lockfile.values(): try: @@ -256,17 +265,16 @@ def conda_cleanup_envs(self): if d not in in_use: shutil.rmtree(os.path.join(self.conda_env_archive_path, d)) - def started(self, job, external_jobid=None): + def started(self, job, external_jobid: Optional[str] = None): for f in job.output: self._record(self._incomplete_path, {"external_jobid": external_jobid}, f) - def finished(self, job, keep_metadata=True): - if not keep_metadata: + def finished(self, job): + if not self.dag.workflow.execution_settings.keep_metadata: for f in job.expanded_output: self._delete_record(self._incomplete_path, f) return - version = str(job.rule.version) if job.rule.version is not None else None code = self._code(job.rule) input = self._input(job) log = self._log(job) @@ -289,7 +297,6 @@ def finished(self, job, keep_metadata=True): self._record( self._metadata_path, { - "version": version, "code": code, "rule": job.rule.name, "input": input, @@ -352,9 +359,6 @@ def external_jobids(self, job): def metadata(self, path): return self._read_record(self._metadata_path, path) - def version(self, path): - return self.metadata(path).get("version") - def rule(self, path): return self.metadata(path).get("rule") @@ -388,10 +392,6 @@ def input_checksums(self, job, input_path): for output_path in job.output ) - def version_changed(self, job, file=None): - """Yields output files with changed versions or bool if file given.""" - return _bool_or_gen(self._version_changed, job, file=file) - def code_changed(self, job, file=None): """Yields output files with changed code or bool if file given.""" return _bool_or_gen(self._code_changed, job, file=file) @@ -412,11 +412,6 @@ def container_changed(self, job, file=None): """Yields output files with changed container img or bool if file given.""" return _bool_or_gen(self._container_changed, job, file=file) - def _version_changed(self, job, file=None): - assert file is not None - recorded = self.version(file) - return recorded is not None and recorded != job.rule.version - def _code_changed(self, job, file=None): assert file is not None recorded = self.code(file) @@ -442,8 +437,9 @@ def _container_changed(self, job, file=None): recorded = self.container_img_url(file) return recorded is not None and recorded != job.container_img_url + @contextmanager def noop(self, *args): - pass + yield def _b64id(self, s): return urlsafe_b64encode(str(s).encode()).decode() diff --git a/snakemake/remote/FTP.py b/snakemake/remote/FTP.py index 508f508d8..4d3110e95 100644 --- a/snakemake/remote/FTP.py +++ b/snakemake/remote/FTP.py @@ -108,7 +108,7 @@ def __init__( # === Implementations of abstract class members === def get_default_kwargs(self, **defaults): - """define defaults beyond 
thos set in PooledDomainObject""" + """define defaults beyond those set in PooledDomainObject""" return super().get_default_kwargs( **{"port": 21, "password": None, "username": None} ) diff --git a/snakemake/remote/HTTP.py b/snakemake/remote/HTTP.py index b1ea58280..163db1a87 100644 --- a/snakemake/remote/HTTP.py +++ b/snakemake/remote/HTTP.py @@ -3,12 +3,15 @@ __email__ = "tomkinsc@broadinstitute.org" __license__ = "MIT" +from dataclasses import dataclass import os import re import collections import shutil import email.utils from contextlib import contextmanager +import snakemake +import snakemake.io # module-specific from snakemake.remote import AbstractRemoteProvider, DomainObject @@ -99,6 +102,15 @@ def __init__( ) self.additional_request_string = additional_request_string + async def inventory(self, cache: snakemake.io.IOCache): + """Obtain all info with a single HTTP request.""" + name = self.local_file() + with self.httpr(verb="HEAD") as httpr: + res = ResponseHandler(httpr) + cache.mtime[name] = snakemake.io.Mtime(remote=res.mtime()) + cache.exists_remote[name] = res.exists() + cache.size[name] = res.size() + # === Implementations of abstract class members === @contextmanager # makes this a context manager. after 'yield' is __exit__() @@ -144,24 +156,20 @@ def httpr(self, verb="GET", stream=False): url = self.remote_file() + self.additional_request_string - if verb.upper() == "GET": - r = requests.get(url, *args_to_use, stream=stream, **kwargs_to_use) - if verb.upper() == "HEAD": - r = requests.head(url, *args_to_use, **kwargs_to_use) + try: + if verb.upper() == "GET": + r = requests.get(url, *args_to_use, stream=stream, **kwargs_to_use) + if verb.upper() == "HEAD": + r = requests.head(url, *args_to_use, **kwargs_to_use) - yield r - r.close() + yield r + finally: + r.close() def exists(self): if self._matched_address: with self.httpr(verb="HEAD") as httpr: - # if a file redirect was found - if httpr.status_code in range(300, 308): - raise HTTPFileException( - f"The file specified appears to have been moved (HTTP {httpr.status_code}), check the URL or try adding 'allow_redirects=True' to the remote() file object: {httpr.url}" - ) - return httpr.status_code == requests.codes.ok - return False + return ResponseHandler(httpr).exists() else: raise HTTPFileException( "The file cannot be parsed as an HTTP path in form 'host:port/abs/path/to/file': %s" @@ -171,23 +179,7 @@ def exists(self): def mtime(self): if self.exists(): with self.httpr(verb="HEAD") as httpr: - file_mtime = self.get_header_item(httpr, "last-modified", default=None) - logger.debug(f"HTTP last-modified: {file_mtime}") - - epochTime = 0 - - if file_mtime is not None: - modified_tuple = email.utils.parsedate_tz(file_mtime) - if modified_tuple is None: - logger.debug( - "HTTP last-modified not in RFC2822 format: `{}`".format( - file_mtime - ) - ) - else: - epochTime = email.utils.mktime_tz(modified_tuple) - - return epochTime + return ResponseHandler(httpr).mtime() else: raise HTTPFileException( "The file does not seem to exist remotely: %s" % self.remote_file() @@ -196,11 +188,7 @@ def mtime(self): def size(self): if self.exists(): with self.httpr(verb="HEAD") as httpr: - content_size = int( - self.get_header_item(httpr, "content-size", default=0) - ) - - return content_size + return ResponseHandler(httpr).size() else: return self._iofile.size_local @@ -237,20 +225,55 @@ def _upload(self): "Upload is not permitted for the HTTP remote provider. Is an output set to HTTP.remote()?" 
) - def get_header_item(self, httpr, header_name, default): - """ - Since HTTP header capitalization may differ, this returns - a header value regardless of case - """ - - header_value = default - for k, v in httpr.headers.items(): - if k.lower() == header_name: - header_value = v - return header_value - @property def list(self): raise HTTPFileException( "The HTTP Remote Provider does not currently support list-based operations like glob_wildcards()." ) + + +@dataclass +class ResponseHandler: + response: requests.Response + + def exists(self): + if self.response.status_code in range(300, 308): + raise HTTPFileException( + f"The file specified appears to have been moved (HTTP {self.response.status_code}), check the URL or try adding 'allow_redirects=True' to the remote() file object: {self.response.url}" + ) + return self.response.status_code == requests.codes.ok + + def mtime(self): + file_mtime = get_header_item(self.response, "last-modified", default=None) + logger.debug(f"HTTP last-modified: {file_mtime}") + + epochTime = 0 + + if file_mtime is not None: + modified_tuple = email.utils.parsedate_tz(file_mtime) + if modified_tuple is None: + logger.debug( + "HTTP last-modified not in RFC2822 format: `{}`".format(file_mtime) + ) + else: + epochTime = email.utils.mktime_tz(modified_tuple) + + return epochTime + + def size(self): + content_size = int(get_header_item(self.response, "content-size", default=0)) + + return content_size + + +def get_header_item(httpr, header_name, default): + """ + Since HTTP header capitalization may differ, this returns + a header value regardless of case + """ + + header_value = default + for k, v in httpr.headers.items(): + if k.lower() == header_name: + header_value = v + return header_value diff --git a/snakemake/report/__init__.py b/snakemake/report/__init__.py index ea8424df0..c373f403e 100644 --- a/snakemake/report/__init__.py +++ b/snakemake/report/__init__.py @@ -297,7 +297,7 @@ def code(self): wrapper.get_script( self._rule.wrapper, self._rule.workflow.sourcecache, - prefix=self._rule.workflow.wrapper_prefix, + prefix=self._rule.workflow.workflow_settings.wrapper_prefix, ), self._rule.workflow.sourcecache, ) diff --git a/snakemake/resources.py b/snakemake/resources.py index 3dd052e64..0025831b6 100644 --- a/snakemake/resources.py +++ b/snakemake/resources.py @@ -4,11 +4,15 @@ import re import tempfile +from snakemake_interface_executor_plugins.resources import ( + DefaultResourcesExecutorInterface, +) + from snakemake.exceptions import ResourceScopesException, WorkflowError from snakemake.common import TBDString -class DefaultResources: +class DefaultResources(DefaultResourcesExecutorInterface): defaults = { "mem_mb": "max(2*input.size_mb, 1000)", "disk_mb": "max(2*input.size_mb, 1000)", @@ -541,6 +545,7 @@ def _highest_proportion(group): def parse_resources(resources_args, fallback=None): """Parse resources from args.""" resources = dict() + if resources_args is not None: valid = re.compile(r"[a-zA-Z_]\w*$") diff --git a/snakemake/ruleinfo.py b/snakemake/ruleinfo.py index a0ccf49b6..3fa726414 100644 --- a/snakemake/ruleinfo.py +++ b/snakemake/ruleinfo.py @@ -33,7 +33,6 @@ def __init__(self, func=None): self.resources = None self.priority = None self.retries = None - self.version = None self.log = None self.docstring = None self.group = None diff --git a/snakemake/rules.py b/snakemake/rules.py index 3f51e5542..82fb67cd0 100644 --- a/snakemake/rules.py +++ b/snakemake/rules.py @@ -17,7 +17,7 @@ except ImportError: # python < 3.11 import sre_constants 
-from snakemake_interface_executor_plugins.utils import ExecMode +from snakemake_interface_executor_plugins.settings import ExecMode from snakemake.io import ( IOFile, @@ -61,10 +61,11 @@ ) from snakemake.resources import infer_resources from snakemake_interface_executor_plugins.utils import not_iterable, lazy_property +from snakemake_interface_common.rules import RuleInterface -class Rule: - def __init__(self, *args, lineno=None, snakefile=None, restart_times=0): +class Rule(RuleInterface): + def __init__(self, *args, lineno=None, snakefile=None): """ Create a rule @@ -73,7 +74,7 @@ def __init__(self, *args, lineno=None, snakefile=None, restart_times=0): """ if len(args) == 2: name, workflow = args - self.name = name + self._name = name self.workflow = workflow self.docstring = None self.message = None @@ -87,21 +88,19 @@ def __init__(self, *args, lineno=None, snakefile=None, restart_times=0): self.temp_output = set() self.protected_output = set() self.touch_output = set() - self.subworkflow_input = dict() self.shadow_depth = None self.resources = None self.priority = 0 - self._version = None self._log = Log() self._benchmark = None self._conda_env = None self._container_img = None self.is_containerized = False self.env_modules = None - self.group = None + self._group = None self._wildcard_names = None - self.lineno = lineno - self.snakefile = snakefile + self._lineno = lineno + self._snakefile = snakefile self.run_func = None self.shellcmd = None self.script = None @@ -113,7 +112,7 @@ def __init__(self, *args, lineno=None, snakefile=None, restart_times=0): self.is_handover = False self.is_branched = False self.is_checkpoint = False - self.restart_times = 0 + self._restart_times = 0 self.basedir = None self.input_modifier = None self.output_modifier = None @@ -123,7 +122,7 @@ def __init__(self, *args, lineno=None, snakefile=None, restart_times=0): self.module_globals = None elif len(args) == 1: other = args[0] - self.name = other.name + self._name = other.name self.workflow = other.workflow self.docstring = other.docstring self.message = other.message @@ -137,25 +136,23 @@ def __init__(self, *args, lineno=None, snakefile=None, restart_times=0): self.temp_output = set(other.temp_output) self.protected_output = set(other.protected_output) self.touch_output = set(other.touch_output) - self.subworkflow_input = dict(other.subworkflow_input) self.shadow_depth = other.shadow_depth self.resources = other.resources self.priority = other.priority - self.version = other.version self._log = other._log self._benchmark = other._benchmark self._conda_env = other._conda_env self._container_img = other._container_img self.is_containerized = other.is_containerized self.env_modules = other.env_modules - self.group = other.group + self._group = other.group self._wildcard_names = ( set(other._wildcard_names) if other._wildcard_names is not None else None ) - self.lineno = other.lineno - self.snakefile = other.snakefile + self._lineno = other.lineno + self._snakefile = other.snakefile self.run_func = other.run_func self.shellcmd = other.shellcmd self.script = other.script @@ -167,7 +164,7 @@ def __init__(self, *args, lineno=None, snakefile=None, restart_times=0): self.is_handover = other.is_handover self.is_branched = True self.is_checkpoint = other.is_checkpoint - self.restart_times = other.restart_times + self._restart_times = other.restart_times self.basedir = other.basedir self.input_modifier = other.input_modifier self.output_modifier = other.output_modifier @@ -270,6 +267,52 @@ def partially_expand(f, 
wildcards): return branch, non_dynamic_wildcards return branch + @property + def name(self): + return self._name + + @name.setter + def name(self, name): + self._name = name + + @property + def lineno(self): + return self._lineno + + @property + def snakefile(self): + return self._snakefile + + @property + def restart_times(self): + if self.workflow.remote_execution_settings.preemptible_rules.is_preemptible( + self.name + ): + return self.workflow.remote_execution_settings.preemptible_retries + if self._restart_times is None: + return self.workflow.execution_settings.retries + return self._restart_times + + @restart_times.setter + def restart_times(self, restart_times): + self._restart_times = restart_times + + @property + def group(self): + if self.workflow.local_exec: + return None + else: + overwrite_group = self.workflow.group_settings.overwrite_groups.get( + self.name + ) + if overwrite_group is not None: + return overwrite_group + return self._group + + @group.setter + def group(self, group): + self._group = group + @property def is_shell(self): return self.shellcmd is not None @@ -332,18 +375,6 @@ def has_wildcards(self): """ return bool(self.wildcard_names) - @property - def version(self): - return self._version - - @version.setter - def version(self, version): - if isinstance(version, str) and "\n" in version: - raise WorkflowError( - "Version string may not contain line breaks.", rule=self - ) - self._version = version - @property def benchmark(self): return self._benchmark @@ -579,7 +610,7 @@ def _set_inoutput_item(self, item, output=False, name=None): else: if ( contains_wildcard_constraints(item) - and self.workflow.mode != ExecMode.subprocess + and self.workflow.execution_settings.mode != ExecMode.SUBPROCESS ): logger.warning( "Wildcard constraints in inputs are ignored. 
(rule: {})".format( @@ -587,7 +618,7 @@ def _set_inoutput_item(self, item, output=False, name=None): ) ) - if self.workflow.all_temp and output: + if self.workflow.storage_settings.all_temp and output: # mark as temp if all output files shall be marked as temp item = flag(item, "temp") @@ -620,22 +651,6 @@ def _set_inoutput_item(self, item, output=False, name=None): report_obj.htmlindex, ) item.flags["report"] = r - if is_flagged(item, "subworkflow"): - if output: - raise SyntaxError("Only input files may refer to a subworkflow") - else: - # record the workflow this item comes from - sub = item.flags["subworkflow"] - if _item in self.subworkflow_input: - other = self.subworkflow_input[_item] - if sub != other: - raise WorkflowError( - "The input file {} is ambiguously " - "associated with two subworkflows " - "{} and {}.".format(item, sub, other), - rule=self, - ) - self.subworkflow_input[_item] = sub inoutput.append(_item) if name: inoutput._add_name(name) @@ -1127,8 +1142,8 @@ def apply(name, res, threads=None): threads = apply("_cores", self.resources["_cores"]) if threads is None: raise WorkflowError("Threads must be given as an int", rule=self) - if self.workflow.max_threads is not None: - threads = min(threads, self.workflow.max_threads) + if self.workflow.resource_settings.max_threads is not None: + threads = min(threads, self.workflow.resource_settings.max_threads) resources["_cores"] = threads for name, res in list(self.resources.items()): @@ -1393,14 +1408,14 @@ def log(self): def _to_iofile(self, files): def cleanup(f): - prefix = self.rule.workflow.default_remote_prefix + prefix = self.rule.workflow.storage_settings.default_remote_prefix # remove constraints and turn this into a plain string cleaned = strip_wildcard_constraints(f) modified_by = get_flag_value(f, PATH_MODIFIER_FLAG) if ( - self.rule.workflow.default_remote_provider is not None + self.rule.workflow.storage_settings.default_remote_provider is not None and f.startswith(prefix) and not is_flagged(f, "local") ): diff --git a/snakemake/scheduler.py b/snakemake/scheduler.py index 7526a97ac..427341725 100644 --- a/snakemake/scheduler.py +++ b/snakemake/scheduler.py @@ -12,32 +12,12 @@ from snakemake_interface_executor_plugins.scheduler import JobSchedulerExecutorInterface from snakemake_interface_executor_plugins.registry import ExecutorPluginRegistry +from snakemake_interface_executor_plugins.registry import Plugin as ExecutorPlugin -from snakemake.executors import ( - AbstractExecutor, - DryrunExecutor, - TouchExecutor, - CPUExecutor, -) -from snakemake.executors import ( - GenericClusterExecutor, - SynchronousClusterExecutor, - DRMAAExecutor, - KubernetesExecutor, - TibannaExecutor, -) - -from snakemake.executors.slurm.slurm_submit import SlurmExecutor -from snakemake.executors.slurm.slurm_jobstep import SlurmJobstepExecutor -from snakemake.executors.flux import FluxExecutor -from snakemake.executors.google_lifesciences import GoogleLifeSciencesExecutor -from snakemake.executors.ga4gh_tes import TaskExecutionServiceExecutor from snakemake.exceptions import RuleException, WorkflowError, print_exception -from snakemake.common import ON_WINDOWS from snakemake.logging import logger from fractions import Fraction -from snakemake.stats import Stats registry = ExecutorPluginRegistry() @@ -65,76 +45,19 @@ async def __aexit__(self, *args): class JobScheduler(JobSchedulerExecutorInterface): - def __init__( - self, - workflow, - dag, - local_cores=1, - dryrun=False, - touch=False, - slurm=None, - slurm_jobstep=None, - 
cluster=None, - cluster_status=None, - cluster_sync=None, - cluster_cancel=None, - cluster_cancel_nargs=None, - cluster_sidecar=None, - drmaa=None, - drmaa_log_dir=None, - env_modules=None, - kubernetes=None, - k8s_cpu_scalar=1.0, - k8s_service_account_name=None, - container_image=None, - flux=None, - tibanna=None, - tibanna_sfn=None, - az_batch=False, - az_batch_enable_autoscale=False, - az_batch_account_url=None, - google_lifesciences=None, - google_lifesciences_regions=None, - google_lifesciences_location=None, - google_lifesciences_cache=False, - google_lifesciences_service_account_email=None, - google_lifesciences_network=None, - google_lifesciences_subnetwork=None, - tes=None, - precommand="", - preemption_default=None, - preemptible_rules=None, - tibanna_config=False, - jobname=None, - keepgoing=False, - max_jobs_per_second=None, - max_status_checks_per_second=100, - # Note this argument doesn't seem to be used (greediness) - greediness=1.0, - force_use_threads=False, - scheduler_type=None, - scheduler_ilp_solver=None, - executor_args=None, - ): + def __init__(self, workflow, executor_plugin: ExecutorPlugin): """Create a new instance of KnapsackJobScheduler.""" - - cores = workflow.global_resources["_cores"] - - self.cluster = cluster - self.cluster_sync = cluster_sync - self.dag = dag self.workflow = workflow - self.dryrun = dryrun - self.touch = touch - self.quiet = workflow.quiet - self.keepgoing = keepgoing + + self.dryrun = self.workflow.dryrun + self.touch = self.workflow.touch + self.quiet = self.workflow.output_settings.quiet + self.keepgoing = self.workflow.execution_settings.keep_going self.running = set() self.failed = set() self.finished_jobs = 0 - self.greediness = 1 - self.max_jobs_per_second = max_jobs_per_second - self.scheduler_type = scheduler_type - self.scheduler_ilp_solver = scheduler_ilp_solver + self.greediness = self.workflow.scheduling_settings.greediness + self.max_jobs_per_second = self.workflow.scheduling_settings.max_jobs_per_second self._tofinish = [] self._toerror = [] self.handle_job_success = True @@ -142,17 +65,19 @@ def __init__( self.print_progress = not self.quiet and not self.dryrun self.update_dynamic = not self.dryrun + nodes_unset = workflow.global_resources["_nodes"] is None + self.global_resources = { name: (sys.maxsize if res is None else res) for name, res in workflow.global_resources.items() } - if workflow.global_resources["_nodes"] is not None: + if not nodes_unset: # Do not restrict cores locally if nodes are used (i.e. in case of cluster/cloud submission). 
self.global_resources["_cores"] = sys.maxsize + self.resources = dict(self.global_resources) - use_threads = force_use_threads or (os.name != "posix") self._open_jobs = threading.Semaphore(0) self._lock = threading.Lock() @@ -161,290 +86,279 @@ def __init__( self._finished = False self._job_queue = None self._last_job_selection_empty = False - self._submit_callback = self._noop - self._finish_callback = self._proceed + self.submit_callback = self._noop + self.finish_callback = self._proceed - self._stats = Stats() + if workflow.remote_execution_settings.immediate_submit: + self.submit_callback = self._proceed + self.finish_callback = self._noop self._local_executor = None - if dryrun: - self._executor: AbstractExecutor = DryrunExecutor( - workflow, - dag, - ) - elif touch: - self._executor = TouchExecutor( - workflow, - dag, - self.stats, - logger, - ) - # We have chosen an executor custom plugin - elif executor_args is not None: - plugin = registry.plugins[executor_args._executor.name] - self._local_executor = CPUExecutor( - workflow, - dag, - self.stats, - logger, - local_cores, - ) - self._executor = plugin.executor( - workflow, - dag, - self.stats, - logger, - cores, - executor_args=executor_args, - ) - - elif slurm: - if ON_WINDOWS: - raise WorkflowError("SLURM execution is not supported on Windows.") - self._local_executor = CPUExecutor( - workflow, - dag, - self.stats, - logger, - local_cores, - ) - # we need to adjust the maximum status checks per second - # on a SLURM cluster, to not overstrain the scheduler; - # timings for tested SLURM clusters, extracted from --verbose - # output with: - # ``` - # grep "sacct output" .snakemake/log/2023-02-13T210004.601290.snakemake.log | \ - # awk '{ counter += 1; sum += $6; sum_of_squares += ($6)^2 } \ - # END { print "average: ",sum/counter," sd: ",sqrt((sum_of_squares - sum^2/counter) / counter); } - # ```` - # * cluster 1: - # * sacct: average: 0.073896 sd: 0.0640178 - # * scontrol: average: 0.0193017 sd: 0.0358858 - # Thus, 2 status checks per second should leave enough - # capacity for everybody. - # TODO: check timings on other slurm clusters, to: - # * confirm that this cap is reasonable - # * check if scontrol is the quicker option across the board - if max_status_checks_per_second > 2: - max_status_checks_per_second = 2 - - self._executor = SlurmExecutor( - workflow, - dag, - self.stats, + if self.workflow.local_exec: + self._executor = executor_plugin.executor( + self.workflow, logger, - max_status_checks_per_second=max_status_checks_per_second, ) - - elif slurm_jobstep: - self._executor = SlurmJobstepExecutor( - workflow, - dag, - self.stats, + else: + self._executor = executor_plugin.executor( + self.workflow, logger, ) - self._local_executor = self._executor - - elif cluster or cluster_sync or (drmaa is not None): - if not workflow.immediate_submit: - # No local jobs when using immediate submit! 
- # Otherwise, they will fail due to missing input - self._local_executor = CPUExecutor( - workflow, - dag, - self.stats, - logger, - local_cores, - ) - - if cluster or cluster_sync: - if cluster_sync: - constructor = SynchronousClusterExecutor - else: - constructor = partial( - GenericClusterExecutor, - statuscmd=cluster_status, - cancelcmd=cluster_cancel, - cancelnargs=cluster_cancel_nargs, - sidecarcmd=cluster_sidecar, - max_status_checks_per_second=max_status_checks_per_second, - ) - - self._executor = constructor( - workflow, - dag, - self.stats, - logger, - submitcmd=(cluster or cluster_sync), - jobname=jobname, - ) - if workflow.immediate_submit: - self._submit_callback = self._proceed - self.update_dynamic = False - self.print_progress = False - self.update_resources = False - self.handle_job_success = False - else: - self._executor = DRMAAExecutor( - workflow, - dag, - self.stats, + self._local_executor = ( + ExecutorPluginRegistry() + .get_plugin("local") + .executor( + self.workflow, logger, - drmaa_args=drmaa, - drmaa_log_dir=drmaa_log_dir, - jobname=jobname, - max_status_checks_per_second=max_status_checks_per_second, ) - elif kubernetes: - self._local_executor = CPUExecutor( - workflow, - dag, - self.stats, - logger, - local_cores, ) - self._executor = KubernetesExecutor( - workflow, - dag, - self.stats, - logger, - kubernetes, - container_image=container_image, - k8s_cpu_scalar=k8s_cpu_scalar, - k8s_service_account_name=k8s_service_account_name, - ) - elif tibanna: - self._local_executor = CPUExecutor( - workflow, - dag, - self.stats, - logger, - local_cores, - use_threads=use_threads, - ) - - self._executor = TibannaExecutor( - workflow, - dag, - self.stats, - logger, - cores, - tibanna_sfn, - precommand=precommand, - tibanna_config=tibanna_config, - container_image=container_image, - ) - - elif flux: - self._local_executor = CPUExecutor( - workflow, - dag, - self.stats, - logger, - local_cores, - ) - - self._executor = FluxExecutor( - workflow, - dag, - self.stats, - logger, - ) - - elif az_batch: - try: - from snakemake.executors.azure_batch import AzBatchExecutor - except ImportError as e: - raise WorkflowError( - "Unable to load Azure Batch executor. 
You have to install " - "the msrest, azure-core, azure-batch, azure-mgmt-batch, and azure-identity packages.", - e, - ) - self._local_executor = CPUExecutor( - workflow, - dag, - self.stats, - logger, - local_cores, - ) - self._executor = AzBatchExecutor( - workflow, - dag, - self.stats, - logger, - container_image=container_image, - az_batch_account_url=az_batch_account_url, - az_batch_enable_autoscale=az_batch_enable_autoscale, - ) - - elif google_lifesciences: - self._local_executor = CPUExecutor( - workflow, - dag, - self.stats, - logger, - local_cores, - ) - - self._executor = GoogleLifeSciencesExecutor( - workflow, - dag, - self.stats, - logger, - container_image=container_image, - regions=google_lifesciences_regions, - location=google_lifesciences_location, - cache=google_lifesciences_cache, - service_account_email=google_lifesciences_service_account_email, - network=google_lifesciences_network, - subnetwork=google_lifesciences_subnetwork, - preemption_default=preemption_default, - preemptible_rules=preemptible_rules, - ) - elif tes: - self._local_executor = CPUExecutor( - workflow, - dag, - self.stats, - logger, - local_cores, - ) - - self._executor = TaskExecutionServiceExecutor( - workflow, - dag, - self.stats, - logger, - tes_url=tes, - container_image=container_image, - ) - - else: - self._executor = CPUExecutor( - workflow, - dag, - self.stats, - logger, - cores, - use_threads=use_threads, - ) + # elif slurm: + # if ON_WINDOWS: + # raise WorkflowError("SLURM execution is not supported on Windows.") + # self._local_executor = CPUExecutor( + # workflow, + # dag, + # self.stats, + # logger, + # local_cores, + # ) + # # we need to adjust the maximum status checks per second + # # on a SLURM cluster, to not overstrain the scheduler; + # # timings for tested SLURM clusters, extracted from --verbose + # # output with: + # # ``` + # # grep "sacct output" .snakemake/log/2023-02-13T210004.601290.snakemake.log | \ + # # awk '{ counter += 1; sum += $6; sum_of_squares += ($6)^2 } \ + # # END { print "average: ",sum/counter," sd: ",sqrt((sum_of_squares - sum^2/counter) / counter); } + # # ```` + # # * cluster 1: + # # * sacct: average: 0.073896 sd: 0.0640178 + # # * scontrol: average: 0.0193017 sd: 0.0358858 + # # Thus, 2 status checks per second should leave enough + # # capacity for everybody. + # # TODO: check timings on other slurm clusters, to: + # # * confirm that this cap is reasonable + # # * check if scontrol is the quicker option across the board + # if max_status_checks_per_second > 2: + # max_status_checks_per_second = 2 + + # self._executor = SlurmExecutor( + # workflow, + # dag, + # self.stats, + # logger, + # max_status_checks_per_second=max_status_checks_per_second, + # ) + + # elif slurm_jobstep: + # self._executor = SlurmJobstepExecutor( + # workflow, + # dag, + # self.stats, + # logger, + # ) + # self._local_executor = self._executor + + # elif cluster or cluster_sync or (drmaa is not None): + # if not workflow.remote_execution_settings.immediate_submit: + # # No local jobs when using immediate submit! 
+ # # Otherwise, they will fail due to missing input + # self._local_executor = CPUExecutor( + # workflow, + # dag, + # self.stats, + # logger, + # local_cores, + # ) + + # if cluster or cluster_sync: + # if cluster_sync: + # constructor = SynchronousClusterExecutor + # else: + # constructor = partial( + # GenericClusterExecutor, + # statuscmd=cluster_status, + # cancelcmd=cluster_cancel, + # cancelnargs=cluster_cancel_nargs, + # sidecarcmd=cluster_sidecar, + # max_status_checks_per_second=max_status_checks_per_second, + # ) + + # self._executor = constructor( + # workflow, + # dag, + # self.stats, + # logger, + # submitcmd=(cluster or cluster_sync), + # jobname=jobname, + # ) + # if workflow.remote_execution_settings.immediate_submit: + # self._submit_callback = self._proceed + # self.update_dynamic = False + # self.print_progress = False + # self.update_resources = False + # self.handle_job_success = False + # else: + # self._executor = DRMAAExecutor( + # workflow, + # dag, + # self.stats, + # logger, + # drmaa_args=drmaa, + # drmaa_log_dir=drmaa_log_dir, + # jobname=jobname, + # max_status_checks_per_second=max_status_checks_per_second, + # ) + # elif kubernetes: + # self._local_executor = CPUExecutor( + # workflow, + # dag, + # self.stats, + # logger, + # local_cores, + # ) + + # self._executor = KubernetesExecutor( + # workflow, + # dag, + # self.stats, + # logger, + # kubernetes, + # container_image=container_image, + # k8s_cpu_scalar=k8s_cpu_scalar, + # k8s_service_account_name=k8s_service_account_name, + # ) + # elif tibanna: + # self._local_executor = CPUExecutor( + # workflow, + # dag, + # self.stats, + # logger, + # local_cores, + # use_threads=use_threads, + # ) + + # self._executor = TibannaExecutor( + # workflow, + # dag, + # self.stats, + # logger, + # cores, + # tibanna_sfn, + # precommand=precommand, + # tibanna_config=tibanna_config, + # container_image=container_image, + # ) + + # elif flux: + # self._local_executor = CPUExecutor( + # workflow, + # dag, + # self.stats, + # logger, + # local_cores, + # ) + + # self._executor = FluxExecutor( + # workflow, + # dag, + # self.stats, + # logger, + # ) + + # elif az_batch: + # try: + # from snakemake.executors.azure_batch import AzBatchExecutor + # except ImportError as e: + # raise WorkflowError( + # "Unable to load Azure Batch executor. 
You have to install " + # "the msrest, azure-core, azure-batch, azure-mgmt-batch, and azure-identity packages.", + # e, + # ) + # self._local_executor = CPUExecutor( + # workflow, + # dag, + # self.stats, + # logger, + # local_cores, + # ) + # self._executor = AzBatchExecutor( + # workflow, + # dag, + # self.stats, + # logger, + # container_image=container_image, + # az_batch_account_url=az_batch_account_url, + # az_batch_enable_autoscale=az_batch_enable_autoscale, + # ) + + # elif google_lifesciences: + # self._local_executor = CPUExecutor( + # workflow, + # dag, + # self.stats, + # logger, + # local_cores, + # ) + + # self._executor = GoogleLifeSciencesExecutor( + # workflow, + # dag, + # self.stats, + # logger, + # container_image=container_image, + # regions=google_lifesciences_regions, + # location=google_lifesciences_location, + # cache=google_lifesciences_cache, + # service_account_email=google_lifesciences_service_account_email, + # network=google_lifesciences_network, + # subnetwork=google_lifesciences_subnetwork, + # preemption_default=preemption_default, + # preemptible_rules=preemptible_rules, + # ) + # elif tes: + # self._local_executor = CPUExecutor( + # workflow, + # dag, + # self.stats, + # logger, + # local_cores, + # ) + + # self._executor = TaskExecutionServiceExecutor( + # workflow, + # dag, + # self.stats, + # logger, + # tes_url=tes, + # container_image=container_image, + # ) + + # else: + # self._executor = CPUExecutor( + # workflow, + # dag, + # self.stats, + # logger, + # cores, + # use_threads=use_threads, + # ) from throttler import Throttler - if self.max_jobs_per_second and not self.dryrun: + if not self.dryrun: max_jobs_frac = Fraction(self.max_jobs_per_second).limit_denominator() self.rate_limiter = Throttler( rate_limit=max_jobs_frac.numerator, period=max_jobs_frac.denominator ) - else: # essentially no rate limit self.rate_limiter = DummyRateLimiter() # Choose job selector (greedy or ILP) self.job_selector = self.job_selector_greedy - if scheduler_type == "ilp": + if self.workflow.scheduling_settings.scheduler == "ilp": import pulp if pulp.apis.LpSolverDefault is None: @@ -478,13 +392,13 @@ def stats(self): @property def open_jobs(self): """Return open jobs.""" - jobs = self.dag.ready_jobs + jobs = self.workflow.dag.ready_jobs if not self.dryrun: jobs = [ job for job in jobs - if not job.dynamic_input and not self.dag.dynamic(job) + if not job.dynamic_input and not self.workflow.dag.dynamic(job) ] return jobs @@ -493,9 +407,9 @@ def remaining_jobs(self): """Return jobs to be scheduled including not yet ready ones.""" return [ job - for job in self.dag.needrun_jobs() + for job in self.workflow.dag.needrun_jobs() if job not in self.running - and not self.dag.finished(job) + and not self.workflow.dag.finished(job) and job not in self.failed ] @@ -537,7 +451,10 @@ def schedule(self): continue # all runnable jobs have finished, normal shutdown - if not needrun and (not running or self.workflow.immediate_submit): + if not needrun and ( + not running + or self.workflow.remote_execution_settings.immediate_submit + ): self._executor.shutdown() if errors: logger.error(_ERROR_MSG_FINAL) @@ -590,13 +507,17 @@ def schedule(self): with self._lock: self.running.update(run) # remove from ready_jobs - self.dag.register_running(run) + self.workflow.dag.register_running(run) # actually run jobs local_runjobs = [job for job in run if job.is_local] runjobs = [job for job in run if not job.is_local] - self.run(local_runjobs, executor=self._local_executor or self._executor) - 
self.run(runjobs) + if local_runjobs: + self.run( + local_runjobs, executor=self._local_executor or self._executor + ) + if runjobs: + self.run(runjobs) except (KeyboardInterrupt, SystemExit): logger.info( "Terminating processes on user request, this might take some time." @@ -632,7 +553,7 @@ def _finish_jobs(self): logger.job_finished(jobid=job.jobid) self.progress() - self.dag.finish(job, update_dynamic=self.update_dynamic) + self.workflow.dag.finish(job, update_dynamic=self.update_dynamic) self._tofinish.clear() def _error_jobs(self): @@ -644,13 +565,7 @@ def _error_jobs(self): def run(self, jobs, executor=None): if executor is None: executor = self._executor - - executor.run_jobs( - jobs, - callback=self._finish_callback, - submit_callback=self._submit_callback, - error_callback=self._error, - ) + executor.run_jobs(jobs) def get_executor(self, job): if job.is_local and self._local_executor is not None: @@ -669,6 +584,7 @@ def _free_resources(self, job): def _proceed(self, job): """Do stuff after job is finished.""" with self._lock: + logger.debug(f"Completion of job {job.rules} reported to scheduler.") self._tofinish.append(job) if self.dryrun: @@ -681,7 +597,7 @@ def _proceed(self, job): # go on scheduling if there is any free core self._open_jobs.release() - def _error(self, job): + def error_callback(self, job): with self._lock: self._toerror.append(job) self._open_jobs.release() @@ -699,10 +615,10 @@ def _handle_error(self, job): # attempt starts counting from 1, but the first attempt is not # a restart, hence we subtract 1. if job.restart_times > job.attempt - 1: - logger.info(f"Trying to restart job {self.dag.jobid(job)}.") + logger.info(f"Trying to restart job {self.workflow.dag.jobid(job)}.") job.attempt += 1 # add job to those being ready again - self.dag._ready_jobs.add(job) + self.workflow.dag._ready_jobs.add(job) else: self._errors = True self.failed.add(job) @@ -748,7 +664,9 @@ def size_gb(f): return f.size / 1e9 temp_files = { - temp_file for job in jobs for temp_file in self.dag.temp_input(job) + temp_file + for job in jobs + for temp_file in self.workflow.dag.temp_input(job) } temp_job_improvement = { @@ -872,26 +790,27 @@ def _solve_ilp(self, prob): import pulp old_path = os.environ["PATH"] - if self.workflow.scheduler_solver_path is None: + if self.workflow.scheduling_settings.solver_path is None: # Temporarily prepend the given snakemake env to the path, such that the solver can be found in any case. # This is needed for cluster envs, where the cluster job might have a different environment but # still needs access to the solver binary. 
os.environ["PATH"] = "{}:{}".format( - self.workflow.scheduler_solver_path, os.environ["PATH"] + self.workflow.scheduling_settings.solver_path, + os.environ["PATH"], ) try: solver = ( - pulp.get_solver(self.scheduler_ilp_solver) - if self.scheduler_ilp_solver + pulp.get_solver(self.workflow.scheduling_settings.ilp_solver) + if self.workflow.scheduling_settings.ilp_solver else pulp.apis.LpSolverDefault ) finally: os.environ["PATH"] = old_path - solver.msg = self.workflow.verbose + solver.msg = self.workflow.output_settings.verbose prob.solve(solver) def required_by_job(self, temp_file, job): - return 1 if temp_file in self.dag.temp_input(job) else 0 + return 1 if temp_file in self.workflow.dag.temp_input(job) else 0 def job_selector_greedy(self, jobs): """ @@ -989,12 +908,16 @@ def job_weight(self, job): ] def job_reward(self, job): - if self.touch or self.dryrun or self.workflow.immediate_submit: + if ( + self.touch + or self.dryrun + or self.workflow.remote_execution_settings.immediate_submit + ): temp_size = 0 input_size = 0 else: try: - temp_size = self.dag.temp_size(job) + temp_size = self.workflow.dag.temp_size(job) input_size = job.inputsize except FileNotFoundError: # If the file is not yet present, this shall not affect the @@ -1014,4 +937,4 @@ def job_reward(self, job): def progress(self): """Display the progress.""" - logger.progress(done=self.finished_jobs, total=len(self.dag)) + logger.progress(done=self.finished_jobs, total=len(self.workflow.dag)) diff --git a/snakemake/settings.py b/snakemake/settings.py new file mode 100644 index 000000000..d3d190301 --- /dev/null +++ b/snakemake/settings.py @@ -0,0 +1,381 @@ +from abc import ABC +from dataclasses import dataclass, field +from enum import Enum +import importlib +from pathlib import Path +from typing import Optional +from collections.abc import Mapping, Sequence, Set + +import immutables + +from snakemake_interface_common.exceptions import ApiError +from snakemake_interface_executor_plugins.settings import ( + RemoteExecutionSettingsExecutorInterface, + DeploymentSettingsExecutorInterface, + ExecutionSettingsExecutorInterface, + StorageSettingsExecutorInterface, + DeploymentMethod, + ExecMode, +) +from snakemake_interface_common.settings import SettingsEnumBase + +from snakemake.common import dict_to_key_value_args, get_container_image +from snakemake.common.configfile import load_configfile +from snakemake.resources import DefaultResources +from snakemake.utils import update_config +from snakemake.exceptions import WorkflowError + + +class RerunTrigger(SettingsEnumBase): + MTIME = 0 + PARAMS = 1 + INPUT = 2 + SOFTWARE_ENV = 3 + CODE = 4 + + +class ChangeType(SettingsEnumBase): + CODE = 0 + INPUT = 1 + PARAMS = 2 + + +class SettingsBase(ABC): + def __post_init__(self): + self._check() + + def _check(self): + # by default, nothing to check + # override this method in subclasses if needed + pass + + +class NotebookEditMode: + def __init__(self, server_addr: Optional[str] = None, draft_only: bool = False): + if server_addr is not None: + self.ip, self.port = server_addr.split(":") + self.draft_only = draft_only + + +@dataclass +class ExecutionSettings(SettingsBase, ExecutionSettingsExecutorInterface): + """ + Parameters + ---------- + + batch: + whether to compute only a partial DAG, defined by the given Batch object + cache: + list of rules to cache + cores: + the number of provided cores (ignored when using cluster/cloud support) + nodes: + the number of provided cluster nodes (ignored without cluster/cloud support) + 
local_cores: + the number of provided local cores if in cluster mode (ignored without cluster/cloud support) + """ + + latency_wait: int = 3 + keep_going: bool = False + debug: bool = False + standalone: bool = False + ignore_ambiguity: bool = False + lock: bool = True + ignore_incomplete: bool = False + wait_for_files: Sequence[str] = tuple() + no_hooks: bool = False + retries: int = 0 + attempt: int = 1 + use_threads: bool = False + shadow_prefix: Optional[Path] = None + mode: ExecMode = ExecMode.DEFAULT + keep_incomplete: bool = False + keep_metadata: bool = True + edit_notebook: Optional[NotebookEditMode] = None + cleanup_scripts: bool = True + + +@dataclass +class WorkflowSettings(SettingsBase): + wrapper_prefix: Optional[str] = None + + +class Batch: + """Definition of a batch for calculating only a partial DAG.""" + + def __init__(self, rulename: str, idx: int, batches: int): + assert idx <= batches + assert idx > 0 + self.rulename = rulename + self.idx = idx + self.batches = batches + + def get_batch(self, items: list): + """Return the defined batch of the given items. + Items are usually input files.""" + # make sure that we always consider items in the same order + if len(items) < self.batches: + raise WorkflowError( + "Batching rule {} has fewer input files than batches. " + "Please choose a smaller number of batches.".format(self.rulename) + ) + items = sorted(items) + + # we can split the items evenly using divmod: + # len(items) = (self.batches * quotient) + remainder + # Because the remainder is always < the divisor (self.batches), + # each batch will be equal to the quotient + (1 or 0 items) + # from the remainder + k, m = divmod(len(items), self.batches) + + # self.idx is one-based, hence we have to subtract 1 + idx = self.idx - 1 + + # The first m batches will have k (quotient) items + + # one item from the remainder (m). Once we consume all items + # from the remainder, the remaining batches only contain k items. + i = idx * k + min(idx, m) + batch_len = (idx + 1) * k + min(idx + 1, m) + + if self.is_final: + # extend the last batch to cover the rest of the list + return items[i:] + else: + return items[i:batch_len] + + @property + def is_final(self): + return self.idx == self.batches + + def __str__(self): + return f"{self.idx}/{self.batches} (rule {self.rulename})"
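# Worked example (illustrative only, not part of the module above): how Batch.get_batch
# splits items. With 10 input files and batches=3, divmod(10, 3) yields k=3, m=1,
# so batch 1 gets 4 items and batches 2 and 3 get 3 items each; the final batch
# additionally absorbs any trailing items via items[i:].
items = sorted(f"sample_{i}.txt" for i in range(10))  # hypothetical input files
for idx in (1, 2, 3):
    k, m = divmod(len(items), 3)
    start = (idx - 1) * k + min(idx - 1, m)
    end = idx * k + min(idx, m)
    print(f"batch {idx}/3:", items[start:end])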
+ + +@dataclass +class DAGSettings(SettingsBase): + targets: Set[str] = frozenset() + target_jobs: Set[str] = frozenset() + target_files_omit_workdir_adjustment: bool = False + batch: Optional[Batch] = None + forcetargets: bool = False + forceall: bool = False + forcerun: Set[str] = frozenset() + until: Set[str] = frozenset() + omit_from: Set[str] = frozenset() + force_incomplete: bool = False + allowed_rules: Set[str] = frozenset() + rerun_triggers: Set[RerunTrigger] = RerunTrigger.all() + max_inventory_wait_time: int = 20 + cache: Optional[Sequence[str]] = None + + def _check(self): + if self.batch is not None and self.forceall: + raise WorkflowError( + "--batch may not be combined with --forceall, because recomputed upstream " + "jobs in subsequent batches may render already obtained results outdated." + ) + + +@dataclass +class StorageSettings(SettingsBase, StorageSettingsExecutorInterface): + default_remote_provider: Optional[str] = None + default_remote_prefix: Optional[str] = None + assume_shared_fs: bool = True + keep_remote_local: bool = False + notemp: bool = False + all_temp: bool = False + + def __post_init__(self): + self.default_remote_provider = self._get_default_remote_provider() + super().__post_init__() + + def _get_default_remote_provider(self): + if self.default_remote_provider is not None: + try: + rmt = importlib.import_module( + "snakemake.remote." + self.default_remote_provider + ) + except ImportError: + raise ApiError( + f"Unknown default remote provider {self.default_remote_provider}." + ) + if rmt.RemoteProvider.supports_default: + return rmt.RemoteProvider( + keep_local=self.keep_remote_local, is_default=True + ) + else: + raise ApiError( + f"Remote provider {self.default_remote_provider} does not (yet) " + "support being used as default provider." + ) + + +class CondaCleanupPkgs(SettingsEnumBase): + TARBALLS = 0 + CACHE = 1 + + +@dataclass +class DeploymentSettings(SettingsBase, DeploymentSettingsExecutorInterface): + """ + Parameters + ---------- + + deployment_method: + deployment method to use (CONDA, APPTAINER, ENV_MODULES) + conda_prefix: + the directory in which conda environments will be created (default None) + conda_cleanup_pkgs: + whether to clean up conda tarballs after env creation (default None), valid values: "tarballs", "cache" + conda_create_envs_only: + if specified, only builds the conda environments specified for each job, then exits. + list_conda_envs: + list conda environments and their location on disk. + conda_base_path: + Path to conda base environment (this can be used to overwrite the search path for conda, mamba, and activate). + """ + + deployment_method: Set[DeploymentMethod] = frozenset() + conda_prefix: Optional[Path] = None + conda_cleanup_pkgs: Optional[CondaCleanupPkgs] = None + conda_base_path: Optional[Path] = None + conda_frontend: str = "mamba" + conda_not_block_search_path_envvars: bool = False + apptainer_args: str = "" + apptainer_prefix: Optional[Path] = None + + def imply_deployment_method(self, method: DeploymentMethod): + self.deployment_method = set(self.deployment_method) + self.deployment_method.add(method)
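# Minimal usage sketch (illustrative only, assuming the imports and classes above;
# the values are made up): deployment methods are stored as a set, and
# imply_deployment_method() adds a method without relying on the frozenset default.
settings = DeploymentSettings(conda_frontend="conda", conda_prefix=Path("~/.snakemake/conda"))
settings.imply_deployment_method(DeploymentMethod.CONDA)
assert DeploymentMethod.CONDA in settings.deployment_method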
+ """ + + prioritytargets: Set[str] = frozenset() + scheduler: str = "ilp" + ilp_solver: Optional[str] = None + solver_path: Optional[Path] = None + greediness: Optional[float] = None + max_jobs_per_second: int = 10 + + def __post_init__(self): + self.greediness = self._get_greediness() + + def _get_greediness(self): + if self.greediness is None: + return 0.5 if self.prioritytargets else 1.0 + else: + return self.greediness + + def _check(self): + if not (0 < self.greedyness <= 1.0): + raise ApiError("greediness must be >0 and <=1") + + +@dataclass +class ResourceSettings(SettingsBase): + cores: Optional[int] = None + nodes: Optional[int] = None + local_cores: Optional[int] = None + max_threads: Optional[int] = None + resources: Mapping[str, int] = immutables.Map() + overwrite_threads: Mapping[str, int] = immutables.Map() + overwrite_scatter: Mapping[str, int] = immutables.Map() + overwrite_resource_scopes: Mapping[str, str] = immutables.Map() + overwrite_resources: Mapping[str, Mapping[str, int]] = immutables.Map() + default_resources: Optional[DefaultResources] = None + + def __post_init__(self): + if self.default_resources is None: + self.default_resources = DefaultResources(mode="bare") + + +@dataclass +class ConfigSettings(SettingsBase): + config: Mapping[str, str] = immutables.Map() + configfiles: Sequence[Path] = tuple() + config_args: Optional[str] = None + + def __post_init__(self): + self.overwrite_config = self._get_overwrite_config() + self.configfiles = self._get_configfiles() + self.config_args = self._get_config_args() + + def _get_overwrite_config(self): + overwrite_config = dict() + if self.configfiles: + for f in self.configfiles: + update_config(overwrite_config, load_configfile(f)) + if self.config: + update_config(overwrite_config, self.config) + return overwrite_config + + def _get_configfiles(self): + return list(map(Path.absolute, self.configfiles)) + + def _get_config_args(self): + if self.config_args is None: + return dict_to_key_value_args(self.config) + else: + return self.config_args + + +class Quietness(SettingsEnumBase): + RULES = 0 + PROGRESS = 1 + ALL = 2 + + +@dataclass +class OutputSettings(SettingsBase): + printshellcmds: bool = False + nocolor: bool = False + quiet: Optional[Set[Quietness]] = None + debug_dag: bool = False + verbose: bool = False + show_failed_logs: bool = False + log_handlers: Sequence[object] = tuple() + keep_logger: bool = False + + +@dataclass +class PreemptibleRules: + rules: Set[str] = frozenset() + all: bool = False + + def is_preemptible(self, rulename: str): + return self.all or rulename in self.rules + + +@dataclass +class RemoteExecutionSettings(SettingsBase, RemoteExecutionSettingsExecutorInterface): + jobname: str = "snakejob.{rulename}.{jobid}.sh" + jobscript: Optional[Path] = None + max_status_checks_per_second: float = 100.0 + seconds_between_status_checks: int = 10 + container_image: str = get_container_image() + preemptible_retries: Optional[int] = None + preemptible_rules: PreemptibleRules = field(default_factory=PreemptibleRules) + envvars: Sequence[str] = tuple() + immediate_submit: bool = False + + +@dataclass +class GroupSettings(SettingsBase): + overwrite_groups: Mapping[str, str] = immutables.Map() + group_components: Mapping[str, int] = immutables.Map() + local_groupid: str = "local" diff --git a/snakemake/shell.py b/snakemake/shell.py index 0715ed1f9..87435fe23 100644 --- a/snakemake/shell.py +++ b/snakemake/shell.py @@ -288,7 +288,10 @@ def __new__( if jobid is not None: with cls._lock: - del 
cls._processes[jobid] + try: + del cls._processes[jobid] + except KeyError: + pass if retcode: raise sp.CalledProcessError(retcode, cmd) diff --git a/snakemake/snakemake.code-workspace b/snakemake/snakemake.code-workspace new file mode 100644 index 000000000..bab1b7f61 --- /dev/null +++ b/snakemake/snakemake.code-workspace @@ -0,0 +1,8 @@ +{ + "folders": [ + { + "path": ".." + } + ], + "settings": {} +} \ No newline at end of file diff --git a/snakemake/sourcecache.py b/snakemake/sourcecache.py index 84137462b..a479cad06 100644 --- a/snakemake/sourcecache.py +++ b/snakemake/sourcecache.py @@ -23,7 +23,7 @@ smart_join, ) from snakemake.exceptions import WorkflowError, SourceFileError -from snakemake.io import split_git_path +from snakemake.common.git import split_git_path def _check_git_args(tag: str = None, branch: str = None, commit: str = None): diff --git a/snakemake/spawn_jobs.py b/snakemake/spawn_jobs.py new file mode 100644 index 000000000..f0da13ab4 --- /dev/null +++ b/snakemake/spawn_jobs.py @@ -0,0 +1,151 @@ +from dataclasses import dataclass +import os +import sys +from typing import TypeVar, TYPE_CHECKING, Any +from snakemake_interface_executor_plugins.utils import format_cli_arg, join_cli_args + +if TYPE_CHECKING: + from snakemake.workflow import Workflow + + TWorkflow = TypeVar("TWorkflow", bound="Workflow") +else: + TWorkflow = Any + + +@dataclass +class SpawnedJobArgsFactory: + workflow: TWorkflow + + def get_default_remote_provider_args(self): + has_default_remote_provider = ( + self.workflow.storage_settings.default_remote_provider is not None + ) + if has_default_remote_provider: + return join_cli_args( + [ + format_cli_arg( + "--default-remote-prefix", + self.workflow.storage_settings.default_remote_prefix, + ), + format_cli_arg( + "--default-remote-provider", + self.workflow.storage_settings.default_remote_provider.name, + ), + ] + ) + else: + return "" + + def get_set_resources_args(self): + return format_cli_arg( + "--set-resources", + [ + f"{rule}:{name}={value}" + for rule, res in self.workflow.resource_settings.overwrite_resources.items() + for name, value in res.items() + ], + skip=not self.workflow.resource_settings.overwrite_resources, + ) + + def get_resource_scopes_args(self): + return format_cli_arg( + "--set-resource-scopes", + self.workflow.resource_settings.overwrite_resource_scopes, + ) + + def workflow_property_to_arg( + self, property, flag=None, quote=True, skip=False, invert=False, attr=None + ): + if skip: + return "" + + # Get the value of the property. If property is nested, follow the hierarchy until + # reaching the final value. + query = property.split(".") + base = self.workflow + for prop in query[:-1]: + base = getattr(base, prop) + value = getattr(base, query[-1]) + + if value is not None and attr is not None: + value = getattr(value, attr) + + if flag is None: + flag = f"--{query[-1].replace('_', '-')}" + + if invert and isinstance(value, bool): + value = not value + + return format_cli_arg(flag, value, quote=quote) + + def general_args( + self, + pass_default_remote_provider_args: bool = True, + pass_default_resources_args: bool = False, + ): + """Return a string to add to self.exec_job that includes additional + arguments from the command line. This is currently used in the + ClusterExecutor and CPUExecutor, as both were using the same + code. Both have base class of the RealExecutor. 
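# Sketch (illustrative only, separate from spawn_jobs.py above): how
# workflow_property_to_arg derives a CLI flag when no explicit flag is given --
# only the last component of the dotted settings path is used, with underscores
# turned into dashes.
def derive_flag(dotted_property: str) -> str:
    leaf = dotted_property.split(".")[-1]
    return f"--{leaf.replace('_', '-')}"

assert derive_flag("execution_settings.keep_incomplete") == "--keep-incomplete"
assert derive_flag("deployment_settings.conda_prefix") == "--conda-prefix"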
+ """ + w2a = self.workflow_property_to_arg + + args = [ + "--force", + "--target-files-omit-workdir-adjustment", + "--keep-remote", + "--max-inventory-time 0", + "--nocolor", + "--notemp", + "--no-hooks", + "--nolock", + "--ignore-incomplete", + w2a("execution_settings.keep_incomplete"), + w2a("rerun_triggers"), + w2a( + "execution_settings.cleanup_scripts", + invert=True, + flag="--skip-script-cleanup", + ), + w2a("execution_settings.shadow_prefix"), + w2a("deployment_settings.deployment_method"), + w2a("deployment_settings.conda_frontend"), + w2a("deployment_settings.conda_prefix"), + w2a( + "conda_base_path", + skip=not self.workflow.storage_settings.assume_shared_fs, + ), + w2a("deployment_settings.apptainer_prefix"), + w2a("deployment_settings.apptainer_args"), + w2a("resource_settings.max_threads"), + w2a( + "execution_settings.keep_metadata", flag="--drop-metadata", invert=True + ), + w2a("workflow_settings.wrapper_prefix"), + w2a("resource_settings.overwrite_threads", flag="--set-threads"), + w2a("resource_settings.overwrite_scatter", flag="--set-scatter"), + w2a("deployment_settings.conda_not_block_search_path_envvars"), + w2a("overwrite_configfiles", flag="--configfiles"), + w2a("config_settings.config_args", flag="--config"), + w2a("output_settings.printshellcmds"), + w2a("execution_settings.latency_wait"), + w2a("scheduling_settings.scheduler", flag="--scheduler"), + format_cli_arg( + "--scheduler-solver-path", + os.path.dirname(sys.executable), + skip=not self.workflow.storage_settings.assume_shared_fs, + ), + w2a( + "overwrite_workdir", + flag="--directory", + skip=self.workflow.storage_settings.assume_shared_fs, + ), + self.get_set_resources_args(), + self.get_resource_scopes_args(), + ] + if pass_default_remote_provider_args: + args.append(self.get_default_remote_provider_args()) + if pass_default_resources_args: + args.append(w2a("resource_settings.default_resources", attr="args")) + + return join_cli_args(args) diff --git a/snakemake/stats.py b/snakemake/stats.py deleted file mode 100644 index 05d96075b..000000000 --- a/snakemake/stats.py +++ /dev/null @@ -1,82 +0,0 @@ -from collections import defaultdict -import json -import time - -from snakemake_interface_executor_plugins.jobs import ExecutorJobInterface - -fmt_time = time.ctime - - -class Stats: - def __init__(self): - self.starttime = dict() - self.endtime = dict() - - def report_job_start(self, job): - if job.is_group(): - for j in job: - self.starttime[j] = time.time() - else: - self.starttime[job] = time.time() - - def report_job_end(self, job): - if job.is_group(): - for j in job: - self.endtime[j] = time.time() - else: - self.endtime[job] = time.time() - - @property - def rule_stats(self): - runtimes = defaultdict(list) - for job, t in self.starttime.items(): - runtimes[job.rule].append(self.endtime[job] - t) - for rule, runtimes in runtimes.items(): - yield (rule, sum(runtimes) / len(runtimes), min(runtimes), max(runtimes)) - - @property - def file_stats(self): - for job, t in self.starttime.items(): - for f in job.expanded_output: - start, stop = t, self.endtime[job] - yield f, fmt_time(start), fmt_time(stop), stop - start, job - - @property - def overall_runtime(self): - if self.starttime and self.endtime: - return max(self.endtime.values()) - min(self.starttime.values()) - else: - return 0 - - def to_json(self, path): - rule_stats = { - rule.name: { - "mean-runtime": mean_runtime, - "min-runtime": min_runtime, - "max-runtime": max_runtime, - } - for rule, mean_runtime, min_runtime, max_runtime in 
self.rule_stats - } - file_stats = { - f: { - "start-time": start, - "stop-time": stop, - "duration": duration, - "priority": job.priority - if job.priority != ExecutorJobInterface.HIGHEST_PRIORITY - else "highest", - "resources": dict(job.resources.items()), - } - for f, start, stop, duration, job in self.file_stats - } - - with open(path, "w") as f: - json.dump( - { - "total_runtime": self.overall_runtime, - "rules": rule_stats, - "files": file_stats, - }, - f, - indent=4, - ) diff --git a/snakemake/target_jobs.py b/snakemake/target_jobs.py index d74951750..339e75170 100644 --- a/snakemake/target_jobs.py +++ b/snakemake/target_jobs.py @@ -6,11 +6,11 @@ from snakemake.common import parse_key_value_arg -def parse_target_jobs_cli_args(args): +def parse_target_jobs_cli_args(target_jobs_args): errmsg = "Invalid target wildcards definition: entries have to be defined as WILDCARD=VALUE pairs" - if args.target_jobs is not None: + if target_jobs_args is not None: target_jobs = list() - for entry in args.target_jobs: + for entry in target_jobs_args: rulename, wildcards = entry.split(":", 1) if wildcards: diff --git a/snakemake/unit_tests/__init__.py b/snakemake/unit_tests/__init__.py index b6d3ddce7..0d777c2dd 100644 --- a/snakemake/unit_tests/__init__.py +++ b/snakemake/unit_tests/__init__.py @@ -5,6 +5,7 @@ from snakemake.logging import logger from snakemake import __version__ +from snakemake.exceptions import WorkflowError class RuleTest: @@ -26,7 +27,7 @@ def expected_path(self): return self.path / "expected" -def generate(dag, path, deploy=["conda", "singularity"], configfiles=None): +def generate(dag, path: Path, deploy=["conda", "singularity"], configfiles=None): """Generate unit tests from given dag at a given path.""" logger.info("Generating unit tests for each rule...") @@ -43,7 +44,6 @@ def generate(dag, path, deploy=["conda", "singularity"], configfiles=None): lstrip_blocks=True, ) - path = Path(path) os.makedirs(path, exist_ok=True) with open(path / "common.py", "w") as common: diff --git a/snakemake/unit_tests/templates/ruletest.py.jinja2 b/snakemake/unit_tests/templates/ruletest.py.jinja2 index 73569fefb..b957f1319 100644 --- a/snakemake/unit_tests/templates/ruletest.py.jinja2 +++ b/snakemake/unit_tests/templates/ruletest.py.jinja2 @@ -32,7 +32,7 @@ def test_{{ ruletest.name }}(): "{{ ruletest.target }}", "-f", "-j1", - "--keep-target-files", + "--target-files-omit-workdir-adjustment", {% if configfiles %} "--configfile", {% for configfile in configfiles %} diff --git a/snakemake/utils.py b/snakemake/utils.py index 7f7f17610..2a9aaf479 100644 --- a/snakemake/utils.py +++ b/snakemake/utils.py @@ -16,11 +16,11 @@ import sys from urllib.parse import urljoin -from snakemake.io import regex, Namedlist, Wildcards, _load_configfile +from snakemake.io import regex, Namedlist, Wildcards +from snakemake.common.configfile import _load_configfile from snakemake.logging import logger from snakemake.common import ON_WINDOWS from snakemake.exceptions import WorkflowError -import snakemake def validate(data, schema, set_default=True): @@ -468,13 +468,12 @@ def read_job_properties( def min_version(version): """Require minimum snakemake version, raise workflow error if not met.""" import pkg_resources + from snakemake.common import __version__ - if pkg_resources.parse_version(snakemake.__version__) < pkg_resources.parse_version( - version - ): + if pkg_resources.parse_version(__version__) < pkg_resources.parse_version(version): raise WorkflowError( "Expecting Snakemake version {} or higher (you are 
currently using {}).".format( - version, snakemake.__version__ + version, __version__ ) ) diff --git a/snakemake/workflow.py b/snakemake/workflow.py index efbd668b6..ac7e84f30 100644 --- a/snakemake/workflow.py +++ b/snakemake/workflow.py @@ -3,17 +3,44 @@ __email__ = "johannes.koester@uni-due.de" __license__ = "MIT" +from dataclasses import dataclass, field import re import os import sys from collections import OrderedDict +from collections.abc import Mapping from itertools import filterfalse, chain from functools import partial import copy from pathlib import Path +from typing import List, Optional, Set +from snakemake.common.workdir_handler import WorkdirHandler +from snakemake.settings import ( + ConfigSettings, + DAGSettings, + DeploymentMethod, + DeploymentSettings, + ExecutionSettings, + GroupSettings, + OutputSettings, + RemoteExecutionSettings, + RerunTrigger, + ResourceSettings, + SchedulingSettings, + StorageSettings, + WorkflowSettings, +) from snakemake_interface_executor_plugins.workflow import WorkflowExecutorInterface -from snakemake_interface_executor_plugins.utils import ExecMode +from snakemake_interface_executor_plugins.cli import ( + SpawnedJobArgsFactoryExecutorInterface, +) +from snakemake_interface_executor_plugins.utils import lazy_property +from snakemake_interface_executor_plugins import ExecutorSettingsBase +from snakemake_interface_executor_plugins.registry.plugin import ( + Plugin as ExecutorPlugin, +) +from snakemake_interface_executor_plugins.settings import ExecMode from snakemake.logging import logger, format_resources from snakemake.rules import Rule, Ruleorder, RuleProxy @@ -25,7 +52,7 @@ NoRulesException, WorkflowError, ) -from snakemake.dag import DAG +from snakemake.dag import DAG, ChangeType from snakemake.scheduler import JobScheduler from snakemake.parser import parse import snakemake.io @@ -84,168 +111,66 @@ infer_source_file, ) from snakemake.deployment.conda import Conda -from snakemake import sourcecache +from snakemake import api, sourcecache +@dataclass class Workflow(WorkflowExecutorInterface): - def __init__( - self, - snakefile=None, - rerun_triggers=None, - jobscript=None, - overwrite_shellcmd=None, - overwrite_config=None, - overwrite_workdir=None, - overwrite_configfiles=None, - overwrite_clusterconfig=None, - overwrite_threads=None, - overwrite_scatter=None, - overwrite_groups=None, - overwrite_resources=None, - overwrite_resource_scopes=None, - group_components=None, - config_args=None, - debug=False, - verbose=False, - use_conda=False, - conda_frontend=None, - conda_prefix=None, - use_singularity=False, - use_env_modules=False, - singularity_prefix=None, - singularity_args="", - shadow_prefix=None, - scheduler_type="ilp", - scheduler_ilp_solver=None, - mode=ExecMode.default, - wrapper_prefix=None, - printshellcmds=False, - restart_times=None, - attempt=1, - default_remote_provider=None, - default_remote_prefix="", - run_local=True, - assume_shared_fs=True, - default_resources=None, - cache=None, - nodes=1, - cores=1, - resources=None, - conda_cleanup_pkgs=None, - edit_notebook=False, - envvars=None, - max_inventory_wait_time=20, - conda_not_block_search_path_envvars=False, - execute_subworkflows=True, - scheduler_solver_path=None, - conda_base_path=None, - check_envvars=True, - max_threads=None, - all_temp=False, - local_groupid="local", - keep_metadata=True, - latency_wait=3, - executor_args=None, - cleanup_scripts=True, - immediate_submit=False, - keep_incomplete=False, - quiet=False, - ): + config_settings: ConfigSettings + 
resource_settings: ResourceSettings + workflow_settings: WorkflowSettings + storage_settings: Optional[StorageSettings] = None + dag_settings: Optional[DAGSettings] = None + execution_settings: Optional[ExecutionSettings] = None + deployment_settings: Optional[DeploymentSettings] = None + scheduling_settings: Optional[SchedulingSettings] = None + output_settings: Optional[OutputSettings] = None + remote_execution_settings: Optional[RemoteExecutionSettings] = None + group_settings: Optional[GroupSettings] = None + executor_settings: ExecutorSettingsBase = None + check_envvars: bool = True + cache_rules: Mapping[str, str] = field(default_factory=dict) + overwrite_workdir: Optional[str] = None + _workdir_handler: Optional[WorkdirHandler] = field(init=False, default=None) + + def __post_init__(self): """ Create the controller. """ + self.global_resources = dict(self.resource_settings.resources) + self.global_resources["_cores"] = self.resource_settings.cores + self.global_resources["_nodes"] = self.resource_settings.nodes - self.global_resources = dict() if resources is None else resources - self.global_resources["_cores"] = cores - self.global_resources["_nodes"] = nodes - - self._rerun_triggers = ( - frozenset(rerun_triggers) if rerun_triggers is not None else frozenset() - ) self._rules = OrderedDict() self.default_target = None - self._workdir = None - self.overwrite_workdir = overwrite_workdir self._workdir_init = os.path.abspath(os.curdir) - self._cleanup_scripts = cleanup_scripts self._ruleorder = Ruleorder() self._localrules = set() self._linemaps = dict() self.rule_count = 0 - self.basedir = os.path.dirname(snakefile) - self._main_snakefile = os.path.abspath(snakefile) self.included = [] self.included_stack = [] - self._jobscript = jobscript self._persistence: Persistence = None - self._subworkflows = dict() - self.overwrite_shellcmd = overwrite_shellcmd - self.overwrite_config = overwrite_config or dict() - self._overwrite_configfiles = overwrite_configfiles - self.overwrite_clusterconfig = overwrite_clusterconfig or dict() - self._overwrite_threads = overwrite_threads or dict() - self._overwrite_resources = overwrite_resources or dict() - self._config_args = config_args - self._immediate_submit = immediate_submit + self._dag: Optional[DAG] = None self._onsuccess = lambda log: None self._onerror = lambda log: None self._onstart = lambda log: None - self._debug = debug - self._verbose = verbose self._rulecount = 0 - self._use_conda = use_conda - self._conda_frontend = conda_frontend - self._conda_prefix = conda_prefix - self._use_singularity = use_singularity - self._use_env_modules = use_env_modules - self.singularity_prefix = singularity_prefix - self._singularity_args = singularity_args - self._shadow_prefix = shadow_prefix - self._scheduler_type = scheduler_type - self.scheduler_ilp_solver = scheduler_ilp_solver self.global_container_img = None self.global_is_containerized = False - self.mode = mode - self._wrapper_prefix = wrapper_prefix - self._printshellcmds = printshellcmds - self.restart_times = restart_times - self.attempt = attempt - self.default_remote_provider = default_remote_provider - self._default_remote_prefix = default_remote_prefix - self.configfiles = ( - [] if overwrite_configfiles is None else list(overwrite_configfiles) - ) - self.run_local = run_local - self._assume_shared_fs = assume_shared_fs + self.configfiles = list(self.config_settings.configfiles) self.report_text = None - self.conda_cleanup_pkgs = conda_cleanup_pkgs - self._edit_notebook = edit_notebook 
# environment variables to pass to jobs # These are defined via the "envvars:" syntax in the Snakefile itself self._envvars = set() - self.overwrite_groups = overwrite_groups or dict() - self.group_components = group_components or dict() - self._scatter = dict(overwrite_scatter or dict()) - self._overwrite_scatter = overwrite_scatter or dict() - self._overwrite_resource_scopes = overwrite_resource_scopes or dict() + self._scatter = dict(self.resource_settings.overwrite_scatter) self._resource_scopes = ResourceScopes.defaults() - self._resource_scopes.update(self.overwrite_resource_scopes) - self._conda_not_block_search_path_envvars = conda_not_block_search_path_envvars - self._execute_subworkflows = execute_subworkflows + self._resource_scopes.update(self.resource_settings.overwrite_resource_scopes) self.modules = dict() self._sourcecache = SourceCache() - self.scheduler_solver_path = scheduler_solver_path - self._conda_base_path = conda_base_path - self.check_envvars = check_envvars - self._max_threads = max_threads - self.all_temp = all_temp - self._executor_args = executor_args self._scheduler = None - self._local_groupid = local_groupid - self._keep_metadata = keep_metadata - self._latency_wait = latency_wait - self._keep_incomplete = keep_incomplete - self._quiet = quiet + self._spawned_job_general_args = None + self._executor_plugin = None _globals = globals() from snakemake.shell import shell @@ -261,57 +186,99 @@ def __init__( self.vanilla_globals = dict(_globals) self.modifier_stack = [WorkflowModifier(self, globals=_globals)] + self._output_file_cache = None + self.cache_rules = dict() - self.enable_cache = False - if cache is not None: - self.enable_cache = True - self.cache_rules = {rulename: "all" for rulename in cache} - if self.default_remote_provider is not None: - self._output_file_cache = RemoteOutputFileCache( - self.default_remote_provider - ) - else: - self._output_file_cache = LocalOutputFileCache() - else: - self._output_file_cache = None - self.cache_rules = dict() + self.globals["config"] = copy.deepcopy(self.config_settings.overwrite_config) - if default_resources is not None: - self._default_resources = default_resources - else: - # only _cores, _nodes, and _tmpdir - self._default_resources = DefaultResources(mode="bare") + @property + def enable_cache(self): + return ( + self.execution_settings is not None + and self.execution_settings.cache is not None + ) - self.iocache = snakemake.io.IOCache(max_inventory_wait_time) + def check_cache_rules(self): + for rule in self.rules: + cache_mode = self.cache_rules.get(rule.name) + if cache_mode: + if len(rule.output) > 1: + if not all(out.is_multiext for out in rule.output): + raise WorkflowError( + "Rule is marked for between workflow caching but has multiple output files. " + "This is only allowed if multiext() is used to declare them (see docs on between " + "workflow caching).", + rule=rule, + ) + if not self.enable_cache: + logger.warning( + f"Workflow defines that rule {rule.name} is eligible for caching between workflows " + "(use the --cache argument to enable this)." + ) + if rule.benchmark: + raise WorkflowError( + "Rules with a benchmark directive may not be marked as eligible " + "for between-workflow caching at the same time. The reason is that " + "when the result is taken from cache, there is no way to fill the benchmark file with " + "any reasonable values. 
Either remove the benchmark directive or disable " + "between-workflow caching for this rule.", + rule=rule, + ) - self.globals["config"] = copy.deepcopy(self.overwrite_config) + @property + def attempt(self): + if self.execution_settings is None: + # if not executing, we can safely set this to 1 + return 1 + return self.execution_settings.attempt - if envvars is not None: - self.register_envvars(*envvars) + @property + def executor_plugin(self): + return self._executor_plugin @property - def quiet(self): - return self._quiet + def dryrun(self): + if self.executor_plugin is None: + return False + else: + return self.executor_plugin.common_settings.dryrun_exec @property - def assume_shared_fs(self): - return self._assume_shared_fs + def touch(self): + import snakemake.executors.touch + + return isinstance( + self.executor_plugin.executor, snakemake.executors.touch.Executor + ) @property - def keep_incomplete(self): - return self._keep_incomplete + def use_threads(self): + return ( + self.workflow.execution_settings.use_threads + or (os.name not in ["posix", "nt"]) + or not self.local_exec + ) @property - def executor_args(self): - return self._executor_args + def local_exec(self): + if self.executor_plugin is not None: + return self.executor_plugin.common_settings.local_exec + else: + return True @property - def default_remote_prefix(self): - return self._default_remote_prefix + def non_local_exec(self): + return not self.local_exec + + @lazy_property + def spawned_job_args_factory(self) -> SpawnedJobArgsFactoryExecutorInterface: + from snakemake.spawn_jobs import SpawnedJobArgsFactory + + return SpawnedJobArgsFactory(self) @property - def immediate_submit(self): - return self._immediate_submit + def basedir(self): + return os.path.dirname(self.main_snakefile) @property def scheduler(self): @@ -325,42 +292,10 @@ def scheduler(self, scheduler): def envvars(self): return self._envvars - @property - def jobscript(self): - return self._jobscript - - @property - def verbose(self): - return self._verbose - @property def sourcecache(self): return self._sourcecache - @property - def edit_notebook(self): - return self._edit_notebook - - @property - def cleanup_scripts(self): - return self._cleanup_scripts - - @property - def debug(self): - return self._debug - - @property - def use_env_modules(self): - return self._use_env_modules - - @property - def use_singularity(self): - return self._use_singularity - - @property - def use_conda(self): - return self._use_conda - @property def workdir_init(self): return self._workdir_init @@ -374,8 +309,12 @@ def persistence(self): return self._persistence @property - def main_snakefile(self): - return self._main_snakefile + def dag(self): + return self._dag + + @property + def main_snakefile(self) -> str: + return self.included[0].get_path_or_uri() @property def output_file_cache(self): @@ -385,95 +324,19 @@ def output_file_cache(self): def resource_scopes(self): return self._resource_scopes - @property - def overwrite_resource_scopes(self): - return self._overwrite_resource_scopes - - @property - def default_resources(self): - return self._default_resources - - @property - def scheduler_type(self): - return self._scheduler_type - - @property - def printshellcmds(self): - return self._printshellcmds - - @property - def config_args(self): - return self._config_args - @property def overwrite_configfiles(self): - return self._overwrite_configfiles - - @property - def conda_not_block_search_path_envvars(self): - return self._conda_not_block_search_path_envvars - - 
@property - def local_groupid(self): - return self._local_groupid + return self.config_settings.configfiles @property - def overwrite_scatter(self): - return self._overwrite_scatter - - @property - def overwrite_threads(self): - return self._overwrite_threads - - @property - def wrapper_prefix(self): - return self._wrapper_prefix - - @property - def keep_metadata(self): - return self._keep_metadata - - @property - def max_threads(self): - return self._max_threads - - @property - def execute_subworkflows(self): - return self._execute_subworkflows - - @property - def singularity_args(self): - return self._singularity_args - - @property - def conda_prefix(self): - return self._conda_prefix - - @property - def conda_frontend(self): - return self._conda_frontend - - @property - def shadow_prefix(self): - return self._shadow_prefix - - @property - def rerun_triggers(self): - return self._rerun_triggers - - @property - def latency_wait(self): - return self._latency_wait - - @property - def overwrite_resources(self): - return self._overwrite_resources + def rerun_triggers(self) -> Set[RerunTrigger]: + return self.dag_settings.rerun_triggers @property def conda_base_path(self): - if self._conda_base_path: - return self._conda_base_path - if self.use_conda: + if self.deployment_settings.conda_base_path: + return self.deployment_settings.conda_base_path + if DeploymentMethod.CONDA in self.deployment_settings.deployment_method: try: return Conda().prefix_path except CreateCondaEnvironmentException as e: @@ -520,11 +383,10 @@ def lint(self, json=False): return linted def get_cache_mode(self, rule: Rule): - return self.cache_rules.get(rule.name) - - @property - def subworkflows(self): - return self._subworkflows.values() + if self.dag_settings.cache is None: + return None + else: + return self.cache_rules.get(rule.name) @property def rules(self): @@ -567,6 +429,8 @@ def check(self): raise UnknownRuleException( rulename, prefix="Error in ruleorder definition." 
) + self.check_cache_rules() + self.check_localrules() def add_rule( self, @@ -622,7 +486,8 @@ def list_rules(self, only_targets=False): if only_targets: rules = filterfalse(Rule.has_wildcards, rules) for rule in sorted(rules, key=lambda r: r.name): - logger.rule_info(name=rule.name, docstring=rule.docstring) + docstring = f" ({rule.docstring})" if rule.docstring else "" + print(rule.name + docstring) def list_resources(self): for resource in set( @@ -655,114 +520,33 @@ def inputfile(self, path): """ if isinstance(path, Path): path = str(path) - if self.default_remote_provider is not None: + if self.storage_settings.default_remote_provider is not None: path = self.modifier.modify_path(path) return IOFile(path) - def execute( + def _prepare_dag( self, - targets=None, - target_jobs=None, - dryrun=False, - generate_unit_tests=None, - touch=False, - scheduler_type=None, - scheduler_ilp_solver=None, - local_cores=1, - forcetargets=False, - forceall=False, - forcerun=None, - until=[], - omit_from=[], - prioritytargets=None, - keepgoing=False, - printdag=False, - slurm=None, - slurm_jobstep=None, - cluster=None, - cluster_sync=None, - jobname=None, - ignore_ambiguity=False, - printrulegraph=False, - printfilegraph=False, - printd3dag=False, - drmaa=None, - drmaa_log_dir=None, - kubernetes=None, - k8s_cpu_scalar=1.0, - k8s_service_account_name=None, - flux=None, - tibanna=None, - tibanna_sfn=None, - az_batch=False, - az_batch_enable_autoscale=False, - az_batch_account_url=None, - google_lifesciences=None, - google_lifesciences_regions=None, - google_lifesciences_location=None, - google_lifesciences_cache=False, - google_lifesciences_service_account_email=None, - google_lifesciences_network=None, - google_lifesciences_subnetwork=None, - tes=None, - precommand="", - preemption_default=None, - preemptible_rules=None, - tibanna_config=False, - container_image=None, - stats=None, - force_incomplete=False, - ignore_incomplete=False, - list_version_changes=False, - list_code_changes=False, - list_input_changes=False, - list_params_changes=False, - list_untracked=False, - list_conda_envs=False, - summary=False, - archive=None, - delete_all_output=False, - delete_temp_output=False, - detailed_summary=False, - wait_for_files=None, - nolock=False, - unlock=False, - notemp=False, - nodeps=False, - cleanup_metadata=None, - conda_cleanup_envs=False, - cleanup_containers=False, - cleanup_shadow=False, - subsnakemake=None, - updated_files=None, - keep_target_files=False, - # Note that keep_shadow doesn't seem to be used? 
- keep_shadow=False, - keep_remote_local=False, - allowed_rules=None, - max_jobs_per_second=None, - max_status_checks_per_second=None, - greediness=1.0, - no_hooks=False, - force_use_threads=False, - conda_create_envs_only=False, - cluster_status=None, - cluster_cancel=None, - cluster_cancel_nargs=None, - cluster_sidecar=None, - report=None, - report_stylesheet=None, - export_cwl=False, - batch=None, - keepincomplete=False, - containerize=False, - ): - self.check_localrules() + forceall: bool, + ignore_incomplete: bool, + lock_warn_only: bool, + nolock: bool = False, + shadow_prefix: Optional[str] = None, + ) -> DAG: + if self.dag_settings.cache is not None: + self.cache_rules.update( + {rulename: "all" for rulename in self.dag_settings.cache} + ) + if self.storage_settings.default_remote_provider is not None: + self._output_file_cache = RemoteOutputFileCache( + self.storage_settings.default_remote_provider + ) + else: + self._output_file_cache = LocalOutputFileCache() def rules(items): return map(self._rules.__getitem__, filter(self.is_rule, items)) - if keep_target_files: + if self.dag_settings.target_files_omit_workdir_adjustment: def files(items): return filterfalse(self.is_rule, items) @@ -777,28 +561,27 @@ def files(items): ) return map(relpath, filterfalse(self.is_rule, items)) - if not targets and not target_jobs: + self.iocache = snakemake.io.IOCache(self.dag_settings.max_inventory_wait_time) + + if not self.dag_settings.targets and not self.dag_settings.target_jobs: targets = ( [self.default_target] if self.default_target is not None else list() ) + else: + targets = self.dag_settings.targets - if prioritytargets is None: - prioritytargets = list() - if forcerun is None: - forcerun = list() - if until is None: - until = list() - if omit_from is None: - omit_from = list() + prioritytargets = set() + if self.scheduling_settings is not None: + prioritytargets = self.scheduling_settings.prioritytargets priorityrules = set(rules(prioritytargets)) priorityfiles = set(files(prioritytargets)) - forcerules = set(rules(forcerun)) - forcefiles = set(files(forcerun)) - untilrules = set(rules(until)) - untilfiles = set(files(until)) - omitrules = set(rules(omit_from)) - omitfiles = set(files(omit_from)) + forcerules = set(rules(self.dag_settings.forcerun)) + forcefiles = set(files(self.dag_settings.forcerun)) + untilrules = set(rules(self.dag_settings.until)) + untilfiles = set(files(self.dag_settings.until)) + omitrules = set(rules(self.dag_settings.omit_from)) + omitfiles = set(files(self.dag_settings.omit_from)) targetrules = set( chain( rules(targets), @@ -812,34 +595,24 @@ def files(items): if ON_WINDOWS: targetfiles = set(tf.replace(os.sep, os.altsep) for tf in targetfiles) - if forcetargets: + if self.dag_settings.forcetargets: forcefiles.update(targetfiles) forcerules.update(targetrules) rules = self.rules - if allowed_rules: - allowed_rules = set(allowed_rules) - rules = [rule for rule in rules if rule.name in allowed_rules] - - if wait_for_files is not None: - try: - snakemake.io.wait_for_files( - wait_for_files, latency_wait=self.latency_wait - ) - except IOError as e: - logger.error(str(e)) - return False + if self.dag_settings.allowed_rules: + rules = [ + rule for rule in rules if rule.name in self.dag_settings.allowed_rules + ] - dag = DAG( + self._dag = DAG( self, rules, - dryrun=dryrun, targetfiles=targetfiles, targetrules=targetrules, - target_jobs_def=target_jobs, # when cleaning up conda or containers, we should enforce all possible jobs # since their envs shall not be 
deleted - forceall=forceall or conda_cleanup_envs or cleanup_containers, + forceall=forceall, forcefiles=forcefiles, forcerules=forcerules, priorityfiles=priorityfiles, @@ -848,452 +621,488 @@ def files(items): untilrules=untilrules, omitfiles=omitfiles, omitrules=omitrules, - ignore_ambiguity=ignore_ambiguity, - force_incomplete=force_incomplete, - ignore_incomplete=ignore_incomplete - or printdag - or printrulegraph - or printfilegraph, - notemp=notemp, - keep_remote_local=keep_remote_local, - batch=batch, + ignore_incomplete=ignore_incomplete, ) self._persistence = Persistence( nolock=nolock, - dag=dag, - conda_prefix=self.conda_prefix, - singularity_prefix=self.singularity_prefix, - shadow_prefix=self.shadow_prefix, - warn_only=dryrun - or printrulegraph - or printfilegraph - or printdag - or summary - or detailed_summary - or archive - or list_version_changes - or list_code_changes - or list_input_changes - or list_params_changes - or list_untracked - or delete_all_output - or delete_temp_output, + dag=self._dag, + conda_prefix=self.deployment_settings.conda_prefix, + singularity_prefix=self.deployment_settings.apptainer_prefix, + shadow_prefix=shadow_prefix, + warn_only=lock_warn_only, ) - if self.mode in [ExecMode.subprocess, ExecMode.remote]: - self.persistence.deactivate_cache() + def generate_unit_tests(self, path: Path): + """Generate unit tests for the workflow. - if cleanup_metadata: - failed = [] - for f in cleanup_metadata: - success = self.persistence.cleanup_metadata(f) - if not success: - failed.append(f) - if failed: - logger.warning( - "Failed to clean up metadata for the following files because the metadata was not present.\n" - "If this is expected, there is nothing to do.\nOtherwise, the reason might be file system latency " - "or still running jobs.\nConsider running metadata cleanup again.\nFiles:\n" - + "\n".join(failed) - ) - return True + Arguments + path -- Path to the directory where the unit tests shall be generated. + """ + from snakemake import unit_tests - if unlock: - try: - self.persistence.cleanup_locks() - logger.info("Unlocking working directory.") - return True - except IOError: - logger.error( - "Error: Unlocking the directory {} failed. Maybe " - "you don't have the permissions?" 
- ) - return False + self._prepare_dag( + forceall=self.dag_settings.forceall, + ignore_incomplete=False, + lock_warn_only=False, + ) + self._build_dag() + + deploy = [] + if DeploymentMethod.CONDA in self.deployment_settings.deployment_method: + deploy.append("conda") + if DeploymentMethod.APPTAINER in self.deployment_settings.deployment_method: + deploy.append("singularity") + unit_tests.generate( + self.dag, path, deploy, configfiles=self.overwrite_configfiles + ) - logger.info("Building DAG of jobs...") - dag.init() - dag.update_checkpoint_dependencies() - dag.check_dynamic() + def cleanup_metadata(self, paths: List[Path]): + self._prepare_dag( + forceall=self.dag_settings.forceall, + ignore_incomplete=True, + lock_warn_only=False, + ) + failed = [] + for path in paths: + success = self.persistence.cleanup_metadata(path) + if not success: + failed.append(path) + if failed: + raise WorkflowError( + "Failed to clean up metadata for the following files because the metadata was not present.\n" + "If this is expected, there is nothing to do.\nOtherwise, the reason might be file system latency " + "or still running jobs.\nConsider running metadata cleanup again.\nFiles:\n" + + "\n".join(failed) + ) - self.persistence.lock() + def unlock(self): + self._prepare_dag( + forceall=self.dag_settings.forceall, + ignore_incomplete=self.execution_settings.ignore_incomplete, + lock_warn_only=False, + ) + self._build_dag() + try: + self.persistence.cleanup_locks() + logger.info("Unlocked working directory.") + except IOError as e: + raise WorkflowError( + f"Error: Unlocking the directory {os.getcwd()} failed. Maybe " + "you don't have the permissions?", + e, + ) - if cleanup_shadow: + def cleanup_shadow(self): + self._prepare_dag(forceall=False, ignore_incomplete=False, lock_warn_only=False) + self._build_dag() + with self.persistence.lock(): self.persistence.cleanup_shadow() - return True - if containerize: - from snakemake.deployment.containerize import containerize + def delete_output(self, only_temp: bool = False, dryrun: bool = False): + self._prepare_dag(forceall=False, ignore_incomplete=False, lock_warn_only=True) + self._build_dag() - containerize(self, dag) - return True + self.dag.clean(only_temp=only_temp, dryrun=dryrun) + + def list_untracked(self): + self._prepare_dag(forceall=False, ignore_incomplete=False, lock_warn_only=True) + self._build_dag() + + self.dag.list_untracked() + + def list_changes(self, change_type: ChangeType): + self._prepare_dag(forceall=False, ignore_incomplete=False, lock_warn_only=True) + self._build_dag() + + items = self.dag.get_outputs_with_changes(change_type) + if items: + print(*items, sep="\n") + + def archive(self, path: Path): + """Archive the workflow. + + Arguments + path -- Path to the archive file. 
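# Sketch of the recurring shape of these new Workflow methods (illustrative only,
# not an addition to workflow.py): prepare the DAG with the locking semantics the
# operation needs, build it, then perform a single action on it.
def _example_operation(workflow, detailed: bool = False):
    workflow._prepare_dag(forceall=False, ignore_incomplete=True, lock_warn_only=True)
    workflow._build_dag()
    print("\n".join(workflow.dag.summary(detailed=detailed)))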
+ """ + self._prepare_dag(forceall=False, ignore_incomplete=False, lock_warn_only=True) + self._build_dag() + + self.dag.archive(path) + + def summary(self, detailed: bool = False): + self._prepare_dag( + forceall=self.dag_settings.forceall, + ignore_incomplete=True, + lock_warn_only=True, + ) + self._build_dag() + + print("\n".join(self.dag.summary(detailed=detailed))) + + def conda_cleanup_envs(self): + self._prepare_dag(forceall=True, ignore_incomplete=True, lock_warn_only=False) + self._build_dag() + self.persistence.conda_cleanup_envs() + + def printdag(self): + self._prepare_dag( + forceall=self.dag_settings.forceall, + ignore_incomplete=self.execution_settings.ignore_incomplete, + lock_warn_only=True, + ) + self._build_dag() + print(self.dag) + + def printrulegraph(self): + self._prepare_dag( + forceall=self.dag_settings.forceall, + ignore_incomplete=self.execution_settings.ignore_incomplete, + lock_warn_only=True, + ) + self._build_dag() + self.dag.rule_dot() + + def printfilegraph(self): + self._prepare_dag( + forceall=self.dag_settings.forceall, + ignore_incomplete=True, + lock_warn_only=True, + ) + self._build_dag() + print(self.dag.filegraph_dot()) + + def printd3dag(self): + self._prepare_dag( + forceall=self.dag_settings.forceall, + ignore_incomplete=True, + lock_warn_only=True, + ) + self._build_dag() + + self.dag.d3dag() + + def containerize(self): + from snakemake.deployment.containerize import containerize + + self._prepare_dag( + forceall=self.dag_settings.forceall, + ignore_incomplete=False, + lock_warn_only=False, + ) + self._build_dag() + with self.persistence.lock(): + containerize(self, self.dag) + + def export_cwl(self, path: Path): + """Export the workflow as CWL document. + + Arguments + path -- the path to the CWL document to be created. 
+ """ + self._prepare_dag( + forceall=self.dag_settings.forceall, + ignore_incomplete=True, + lock_warn_only=False, + ) + self._build_dag() + + from snakemake.cwl import dag_to_cwl + import json + + with open(path, "w") as cwl: + json.dump(dag_to_cwl(self.dag), cwl, indent=4) + + def create_report(self, path: Path, stylesheet: Optional[Path] = None): + from snakemake.report import auto_report + + self._prepare_dag( + forceall=self.dag_settings.forceall, + ignore_incomplete=False, + lock_warn_only=False, + ) + self._build_dag() + + auto_report(self.dag, path, stylesheet=stylesheet) + + def conda_list_envs(self): + self._prepare_dag( + forceall=self.dag_settings.forceall, + ignore_incomplete=False, + lock_warn_only=False, + ) + self._build_dag() if ( - self.subworkflows - and self.execute_subworkflows - and not printdag - and not printrulegraph - and not printfilegraph + DeploymentMethod.APPTAINER in self.deployment_settings.deployment_method + and self.storage_settings.assume_shared_fs ): - # backup globals - globals_backup = dict(self.globals) - # execute subworkflows - for subworkflow in self.subworkflows: - subworkflow_targets = subworkflow.targets(dag) - logger.debug( - "Files requested from subworkflow:\n {}".format( - "\n ".join(subworkflow_targets) - ) + self.dag.pull_container_imgs() + self.dag.create_conda_envs( + dryrun=True, + quiet=True, + ) + print("environment", "container", "location", sep="\t") + for env in set(job.conda_env for job in self.dag.jobs): + if env and not env.is_named: + print( + env.file.simplify_path(), + env.container_img_url or "", + simplify_path(env.address), + sep="\t", ) - updated = list() - if subworkflow_targets: - logger.info(f"Executing subworkflow {subworkflow.name}.") - if not subsnakemake( - subworkflow.snakefile, - workdir=subworkflow.workdir, - targets=subworkflow_targets, - cores=self._cores, - nodes=self.nodes, - resources=self.global_resources, - configfiles=[subworkflow.configfile] - if subworkflow.configfile - else None, - updated_files=updated, - rerun_triggers=self.rerun_triggers, - ): - return False - dag.updated_subworkflow_files.update( - subworkflow.target(f) for f in updated - ) - else: - logger.info( - f"Subworkflow {subworkflow.name}: {NOTHING_TO_BE_DONE_MSG}" - ) - if self.subworkflows: - logger.info("Executing main workflow.") - # rescue globals - self.globals.update(globals_backup) - - dag.postprocess(update_needrun=False) - if not dryrun: - # deactivate IOCache such that from now on we always get updated - # size, existence and mtime information - # ATTENTION: this may never be removed without really good reason. - # Otherwise weird things may happen. - self.iocache.deactivate() - # clear and deactivate persistence cache, from now on we want to see updates - self.persistence.deactivate_cache() + return True - if nodeps: - missing_input = [ - f - for job in dag.targetjobs - for f in job.input - if dag.needrun(job) and not os.path.exists(f) - ] - if missing_input: - logger.error( - "Dependency resolution disabled (--nodeps) " - "but missing input " - "files detected. If this happens on a cluster, please make sure " - "that you handle the dependencies yourself or turn off " - "--immediate-submit. 
Missing input files:\n{}".format( - "\n".join(missing_input) - ) - ) - return False + def conda_create_envs(self): + self._prepare_dag( + forceall=self.dag_settings.forceall, + ignore_incomplete=self.execution_settings.ignore_incomplete, + lock_warn_only=False, + ) + self._build_dag() - if self.immediate_submit and any(dag.checkpoint_jobs): - logger.error( - "Immediate submit mode (--immediate-submit) may not be used for workflows " - "with checkpoint jobs, as the dependencies cannot be determined before " - "execution in such cases." - ) - return False + if ( + DeploymentMethod.APPTAINER in self.deployment_settings.deployment_method + and self.storage_settings.assume_shared_fs + ): + self.dag.pull_container_imgs() + self.dag.create_conda_envs() + + def conda_cleanup_envs(self): + self._prepare_dag( + forceall=self.dag_settings.forceall, + ignore_incomplete=self.execution_settings.ignore_incomplete, + lock_warn_only=False, + ) + self._build_dag() + self.persistence.conda_cleanup_envs() + + def container_cleanup_images(self): + self._prepare_dag( + forceall=self.dag_settings.forceall, + ignore_incomplete=self.execution_settings.ignore_incomplete, + lock_warn_only=False, + ) + self._build_dag() + self.persistence.cleanup_containers() - updated_files.extend(f for job in dag.needrun_jobs() for f in job.output) + def _build_dag(self): + logger.info("Building DAG of jobs...") + self.dag.init() + self.dag.update_checkpoint_dependencies() + self.dag.check_dynamic() - if generate_unit_tests: - from snakemake import unit_tests + def execute( + self, + executor_plugin: ExecutorPlugin, + executor_settings: ExecutorSettingsBase, + updated_files: Optional[List[str]] = None, + ): + from snakemake.shell import shell - path = generate_unit_tests - deploy = [] - if self.use_conda: - deploy.append("conda") - if self.use_singularity: - deploy.append("singularity") - unit_tests.generate( - dag, path, deploy, configfiles=self.overwrite_configfiles - ) - return True - elif export_cwl: - from snakemake.cwl import dag_to_cwl - import json + shell.conda_block_conflicting_envvars = ( + not self.deployment_settings.conda_not_block_search_path_envvars + ) - with open(export_cwl, "w") as cwl: - json.dump(dag_to_cwl(dag), cwl, indent=4) - return True - elif report: - from snakemake.report import auto_report + if self.remote_execution_settings.envvars: + self.register_envvars(*self.remote_execution_settings.envvars) - auto_report(dag, report, stylesheet=report_stylesheet) - return True - elif printd3dag: - dag.d3dag() - return True - elif printdag: - print(dag) - return True - elif printrulegraph: - print(dag.rule_dot()) - return True - elif printfilegraph: - print(dag.filegraph_dot()) - return True - elif summary: - print("\n".join(dag.summary(detailed=False))) - return True - elif detailed_summary: - print("\n".join(dag.summary(detailed=True))) - return True - elif archive: - dag.archive(archive) - return True - elif delete_all_output: - dag.clean(only_temp=False, dryrun=dryrun) - return True - elif delete_temp_output: - dag.clean(only_temp=True, dryrun=dryrun) - return True - elif list_version_changes: - items = dag.get_outputs_with_changes("version") - if items: - print(*items, sep="\n") - return True - elif list_code_changes: - items = dag.get_outputs_with_changes("code") - if items: - print(*items, sep="\n") - return True - elif list_input_changes: - items = dag.get_outputs_with_changes("input") - if items: - print(*items, sep="\n") - return True - elif list_params_changes: - items = 
dag.get_outputs_with_changes("params") - if items: - print(*items, sep="\n") - return True - elif list_untracked: - dag.list_untracked() - return True + self._executor_plugin = executor_plugin + self.executor_settings = executor_settings - if self.use_singularity and self.assume_shared_fs: - dag.pull_container_imgs( - dryrun=dryrun or list_conda_envs or cleanup_containers, - quiet=list_conda_envs, - ) - if self.use_conda: - dag.create_conda_envs( - dryrun=dryrun or list_conda_envs or conda_cleanup_envs, - quiet=list_conda_envs, - ) - if conda_create_envs_only: - return True - - if list_conda_envs: - print("environment", "container", "location", sep="\t") - for env in set(job.conda_env for job in dag.jobs): - if env and not env.is_named: - print( - env.file.simplify_path(), - env.container_img_url or "", - simplify_path(env.address), - sep="\t", - ) - return True + if self.execution_settings.wait_for_files: + try: + snakemake.io.wait_for_files( + self.execution_settings.wait_for_files, + latency_wait=self.execution_settings.latency_wait, + ) + except IOError as e: + logger.error(str(e)) + return False - if conda_cleanup_envs: - self.persistence.conda_cleanup_envs() - return True + self._prepare_dag( + forceall=self.dag_settings.forceall, + ignore_incomplete=self.execution_settings.ignore_incomplete, + lock_warn_only=self.dryrun, + nolock=not self.execution_settings.lock, + shadow_prefix=self.execution_settings.shadow_prefix, + ) - if cleanup_containers: - self.persistence.cleanup_containers() - return True + if self.execution_settings.mode in [ExecMode.SUBPROCESS, ExecMode.REMOTE]: + self.persistence.deactivate_cache() - self.scheduler = JobScheduler( - self, - dag, - local_cores=local_cores, - dryrun=dryrun, - touch=touch, - slurm=slurm, - slurm_jobstep=slurm_jobstep, - cluster=cluster, - cluster_status=cluster_status, - cluster_cancel=cluster_cancel, - cluster_cancel_nargs=cluster_cancel_nargs, - cluster_sidecar=cluster_sidecar, - cluster_sync=cluster_sync, - jobname=jobname, - max_jobs_per_second=max_jobs_per_second, - max_status_checks_per_second=max_status_checks_per_second, - keepgoing=keepgoing, - drmaa=drmaa, - drmaa_log_dir=drmaa_log_dir, - kubernetes=kubernetes, - k8s_cpu_scalar=k8s_cpu_scalar, - k8s_service_account_name=k8s_service_account_name, - flux=flux, - tibanna=tibanna, - tibanna_sfn=tibanna_sfn, - az_batch=az_batch, - az_batch_enable_autoscale=az_batch_enable_autoscale, - az_batch_account_url=az_batch_account_url, - google_lifesciences=google_lifesciences, - google_lifesciences_regions=google_lifesciences_regions, - google_lifesciences_location=google_lifesciences_location, - google_lifesciences_cache=google_lifesciences_cache, - google_lifesciences_service_account_email=google_lifesciences_service_account_email, - google_lifesciences_network=google_lifesciences_network, - google_lifesciences_subnetwork=google_lifesciences_subnetwork, - tes=tes, - preemption_default=preemption_default, - preemptible_rules=preemptible_rules, - precommand=precommand, - tibanna_config=tibanna_config, - container_image=container_image, - greediness=greediness, - force_use_threads=force_use_threads, - scheduler_type=scheduler_type, - scheduler_ilp_solver=scheduler_ilp_solver, - executor_args=self.executor_args, - ) + self._build_dag() + + with self.persistence.lock(): + self.dag.postprocess(update_needrun=False) + if not self.dryrun: + # deactivate IOCache such that from now on we always get updated + # size, existence and mtime information + # ATTENTION: this may never be removed without 
really good reason. + # Otherwise weird things may happen. + self.iocache.deactivate() + # clear and deactivate persistence cache, from now on we want to see updates + self.persistence.deactivate_cache() + + if self.remote_execution_settings.immediate_submit and any( + self.dag.checkpoint_jobs + ): + raise WorkflowError( + "Immediate submit mode (--immediate-submit) may not be used for workflows " + "with checkpoint jobs, as the dependencies cannot be determined before " + "execution in such cases." + ) - if not dryrun: - if len(dag): - from snakemake.shell import shell - - shell_exec = shell.get_executable() - if shell_exec is not None: - logger.info(f"Using shell: {shell_exec}") - if cluster or cluster_sync or drmaa: - logger.resources_info(f"Provided cluster nodes: {self.nodes}") - elif kubernetes or tibanna or google_lifesciences: - logger.resources_info(f"Provided cloud nodes: {self.nodes}") - else: - if self._cores is not None: - warning = ( - "" - if self._cores > 1 - else " (use --cores to define parallelism)" - ) - logger.resources_info(f"Provided cores: {self._cores}{warning}") + if updated_files is not None: + updated_files.extend( + f for job in self.dag.needrun_jobs() for f in job.output + ) + + if ( + DeploymentMethod.APPTAINER in self.deployment_settings.deployment_method + and self.storage_settings.assume_shared_fs + ): + self.dag.pull_container_imgs() + if DeploymentMethod.CONDA in self.deployment_settings.deployment_method: + self.dag.create_conda_envs() + + self.scheduler = JobScheduler(self, executor_plugin) + + if not self.dryrun: + if len(self.dag): + from snakemake.shell import shell + + shell_exec = shell.get_executable() + if shell_exec is not None: + logger.info(f"Using shell: {shell_exec}") + if not self.local_exec: + logger.resources_info(f"Provided remote nodes: {self.nodes}") + else: + if self._cores is not None: + warning = ( + "" + if self._cores > 1 + else " (use --cores to define parallelism)" + ) + logger.resources_info( + f"Provided cores: {self._cores}{warning}" + ) + logger.resources_info( + "Rules claiming more threads will be scaled down." + ) + + provided_resources = format_resources(self.global_resources) + if provided_resources: logger.resources_info( - "Rules claiming more threads will be scaled down." 
+ f"Provided resources: {provided_resources}" ) - provided_resources = format_resources(self.global_resources) - if provided_resources: - logger.resources_info(f"Provided resources: {provided_resources}") - - if self.run_local and any(rule.group for rule in self.rules): - logger.info("Group jobs: inactive (local execution)") + if self.local_exec and any(rule.group for rule in self.rules): + logger.info("Group jobs: inactive (local execution)") - if not self.use_conda and any(rule.conda_env for rule in self.rules): - logger.info("Conda environments: ignored") + if ( + DeploymentMethod.CONDA + not in self.deployment_settings.deployment_method + and any(rule.conda_env for rule in self.rules) + ): + logger.info("Conda environments: ignored") - if not self.use_singularity and any( - rule.container_img for rule in self.rules - ): - logger.info("Singularity containers: ignored") + if ( + DeploymentMethod.APPTAINER + not in self.deployment_settings.deployment_method + and any(rule.container_img for rule in self.rules) + ): + logger.info("Singularity containers: ignored") - if self.mode == ExecMode.default: - logger.run_info("\n".join(dag.stats())) - else: - logger.info(NOTHING_TO_BE_DONE_MSG) - else: - # the dryrun case - if len(dag): - logger.run_info("\n".join(dag.stats())) + if self.execution_settings.mode == ExecMode.DEFAULT: + logger.run_info("\n".join(self.dag.stats())) + else: + logger.info(NOTHING_TO_BE_DONE_MSG) + return else: - logger.info(NOTHING_TO_BE_DONE_MSG) - return True - if self.quiet: - # in case of dryrun and quiet, just print above info and exit - return True - - if not dryrun and not no_hooks: - self._onstart(logger.get_logfile()) - - def log_provenance_info(): - provenance_triggered_jobs = [ - job - for job in dag.needrun_jobs(exclude_finished=False) - if dag.reason(job).is_provenance_triggered() - ] - if provenance_triggered_jobs: - logger.info( - "Some jobs were triggered by provenance information, " - "see 'reason' section in the rule displays above.\n" - "If you prefer that only modification time is used to " - "determine whether a job shall be executed, use the command " - "line option '--rerun-triggers mtime' (also see --help).\n" - "If you are sure that a change for a certain output file (say, ) won't " - "change the result (e.g. because you just changed the formatting of a script " - "or environment definition), you can also wipe its metadata to skip such a trigger via " - "'snakemake --cleanup-metadata '. 
" - ) - logger.info( - "Rules with provenance triggered jobs: " - + ",".join( - sorted(set(job.rule.name for job in provenance_triggered_jobs)) + # the dryrun case + if len(self.dag): + logger.run_info("\n".join(self.dag.stats())) + else: + logger.info(NOTHING_TO_BE_DONE_MSG) + return + if self.output_settings.quiet: + # in case of dryrun and quiet, just print above info and exit + return + + if not self.dryrun and not self.execution_settings.no_hooks: + self._onstart(logger.get_logfile()) + + def log_provenance_info(): + provenance_triggered_jobs = [ + job + for job in self.dag.needrun_jobs(exclude_finished=False) + if self.dag.reason(job).is_provenance_triggered() + ] + if provenance_triggered_jobs: + logger.info( + "Some jobs were triggered by provenance information, " + "see 'reason' section in the rule displays above.\n" + "If you prefer that only modification time is used to " + "determine whether a job shall be executed, use the command " + "line option '--rerun-triggers mtime' (also see --help).\n" + "If you are sure that a change for a certain output file (say, ) won't " + "change the result (e.g. because you just changed the formatting of a script " + "or environment definition), you can also wipe its metadata to skip such a trigger via " + "'snakemake --cleanup-metadata '. " ) - ) - logger.info("") + logger.info( + "Rules with provenance triggered jobs: " + + ",".join( + sorted( + set(job.rule.name for job in provenance_triggered_jobs) + ) + ) + ) + logger.info("") - has_checkpoint_jobs = any(dag.checkpoint_jobs) + has_checkpoint_jobs = any(self.dag.checkpoint_jobs) - try: - success = self.scheduler.schedule() - except Exception as e: - if dryrun: - log_provenance_info() - raise e - - if not self.immediate_submit and not dryrun and self.mode == ExecMode.default: - dag.cleanup_workdir() - - if success: - if dryrun: - if len(dag): - logger.run_info("\n".join(dag.stats())) - dag.print_reasons() + try: + success = self.scheduler.schedule() + except Exception as e: + if self.dryrun: log_provenance_info() - logger.info("") - logger.info( - "This was a dry-run (flag -n). The order of jobs " - "does not reflect the order of execution." - ) - if has_checkpoint_jobs: + raise e + + if ( + not self.remote_execution_settings.immediate_submit + and not self.dryrun + and self.execution_settings.mode == ExecMode.DEFAULT + ): + self.dag.cleanup_workdir() + + if success: + if self.dryrun: + if len(self.dag): + logger.run_info("\n".join(self.dag.stats())) + self.dag.print_reasons() + log_provenance_info() + logger.info("") logger.info( - "The run involves checkpoint jobs, " - "which will result in alteration of the DAG of " - "jobs (e.g. adding more jobs) after their completion." + "This was a dry-run (flag -n). The order of jobs " + "does not reflect the order of execution." ) + if has_checkpoint_jobs: + logger.info( + "The run involves checkpoint jobs, " + "which will result in alteration of the DAG of " + "jobs (e.g. adding more jobs) after their completion." 
+ ) + else: + logger.logfile_hint() + if not self.dryrun and not self.execution_settings.no_hooks: + self._onsuccess(logger.get_logfile()) else: - if stats: - self.scheduler.stats.to_json(stats) + if not self.dryrun and not self.execution_settings.no_hooks: + self._onerror(logger.get_logfile()) logger.logfile_hint() - if not dryrun and not no_hooks: - self._onsuccess(logger.get_logfile()) - return True - else: - if not dryrun and not no_hooks: - self._onerror(logger.get_logfile()) - logger.logfile_hint() - return False + raise WorkflowError("At least one job did not complete successfully.") @property def current_basedir(self): @@ -1370,7 +1179,6 @@ def include( snakefile, overwrite_default_target=False, print_compilation=False, - overwrite_shellcmd=None, ): """ Include a snakefile. @@ -1388,7 +1196,6 @@ def include( code, linemap, rulecount = parse( snakefile, self, - overwrite_shellcmd=self.overwrite_shellcmd, rulecount=self._rulecount, ) self._rulecount = rulecount @@ -1431,7 +1238,7 @@ def global_wildcard_constraints(self, **content): def scattergather(self, **content): """Register scattergather defaults.""" self._scatter.update(content) - self._scatter.update(self.overwrite_scatter) + self._scatter.update(self.resource_settings.overwrite_scatter) # add corresponding wildcard constraint self.global_wildcard_constraints(scatteritem=r"\d+-of-\d+") @@ -1451,29 +1258,30 @@ def func(key, *args, **wildcards): def resourcescope(self, **content): """Register resource scope defaults""" self.resource_scopes.update(content) - self.resource_scopes.update(self.overwrite_resource_scopes) + self.resource_scopes.update(self.resource_settings.overwrite_resource_scopes) def workdir(self, workdir): """Register workdir.""" if self.overwrite_workdir is None: - os.makedirs(workdir, exist_ok=True) - self._workdir = workdir - os.chdir(workdir) + self._workdir_handler = WorkdirHandler(Path(workdir)) + self._workdir_handler.change_to() def configfile(self, fp): """Update the global config with data from the given file.""" + from snakemake.common.configfile import load_configfile + if not self.modifier.skip_configfile: if os.path.exists(fp): self.configfiles.append(fp) - c = snakemake.io.load_configfile(fp) + c = load_configfile(fp) update_config(self.config, c) - if self.overwrite_config: + if self.config_settings.overwrite_config: logger.info( "Config file {} is extended by additional config specified via the command line.".format( fp ) ) - update_config(self.config, self.overwrite_config) + update_config(self.config, self.config_settings.overwrite_config) elif not self.overwrite_configfiles: fp_full = os.path.abspath(fp) raise WorkflowError( @@ -1481,7 +1289,7 @@ def configfile(self, fp): ) else: # CLI configfiles have been specified, do not throw an error but update with their values - update_config(self.config, self.overwrite_config) + update_config(self.config, self.config_settings.overwrite_config) def set_pepfile(self, path): try: @@ -1517,15 +1325,6 @@ def config(self): def ruleorder(self, *rulenames): self._ruleorder.add(*map(self.modifier.modify_rulename, rulenames)) - def subworkflow(self, name, snakefile=None, workdir=None, configfile=None): - # Take absolute path of config file, because it is relative to current - # workdir, which could be changed for the subworkflow. 
- if configfile: - configfile = os.path.abspath(configfile) - sw = Subworkflow(self, name, snakefile, workdir, configfile) - self._subworkflows[name] = sw - self.globals[name] = sw.target - def localrules(self, *rulenames): self._localrules.update(rulenames) @@ -1582,8 +1381,15 @@ def decorate(ruleinfo): if ruleinfo.params: rule.set_params(*ruleinfo.params[0], **ruleinfo.params[1]) # handle default resources - if self.default_resources is not None: - rule.resources = copy.deepcopy(self.default_resources.parsed) + if self.resource_settings.default_resources is not None: + rule.resources = copy.deepcopy( + self.resource_settings.default_resources.parsed + ) + else: + rule.resources = dict() + # Always require one node + rule.resources["_nodes"] = 1 + if ruleinfo.threads is not None: if ( not isinstance(ruleinfo.threads, int) @@ -1594,12 +1400,17 @@ def decorate(ruleinfo): "Threads value has to be an integer, float, or a callable.", rule=rule, ) - if name in self.overwrite_threads: - rule.resources["_cores"] = self.overwrite_threads[name] + if name in self.resource_settings.overwrite_threads: + rule.resources["_cores"] = self.resource_settings.overwrite_threads[ + name + ] else: if isinstance(ruleinfo.threads, float): ruleinfo.threads = int(ruleinfo.threads) rule.resources["_cores"] = ruleinfo.threads + else: + rule.resources["_cores"] = 1 + if ruleinfo.shadow_depth: if ruleinfo.shadow_depth not in ( True, @@ -1639,8 +1450,8 @@ def decorate(ruleinfo): rule=rule, ) rule.resources.update(resources) - if name in self.overwrite_resources: - rule.resources.update(self.overwrite_resources[name]) + if name in self.resource_settings.overwrite_resources: + rule.resources.update(self.resource_settings.overwrite_resources[name]) if ruleinfo.priority: if not isinstance(ruleinfo.priority, int) and not isinstance( @@ -1656,12 +1467,9 @@ def decorate(ruleinfo): raise RuleException( "Retries values have to be integers >= 0", rule=rule ) - rule.restart_times = ( - self.restart_times if ruleinfo.retries is None else ruleinfo.retries - ) - if ruleinfo.version: - rule.version = ruleinfo.version + rule.restart_times = ruleinfo.retries + if ruleinfo.log: rule.log_modifier = ruleinfo.log.modifier rule.set_log(*ruleinfo.log.paths, **ruleinfo.log.kwpaths) @@ -1670,13 +1478,14 @@ def decorate(ruleinfo): if ruleinfo.benchmark: rule.benchmark_modifier = ruleinfo.benchmark.modifier rule.benchmark = ruleinfo.benchmark.paths - if not self.run_local: - group = self.overwrite_groups.get(name) or ruleinfo.group - if group is not None: - rule.group = group + + group = ruleinfo.group + if group is not None: + rule.group = group + if ruleinfo.wrapper: rule.conda_env = snakemake.wrapper.get_conda_env( - ruleinfo.wrapper, prefix=self.wrapper_prefix + ruleinfo.wrapper, prefix=self.workflow_settings.wrapper_prefix ) # TODO retrieve suitable singularity image @@ -1771,40 +1580,20 @@ def decorate(ruleinfo): self._localrules.add(rule.name) rule.is_handover = True - if ruleinfo.cache: - if len(rule.output) > 1: - if not rule.output[0].is_multiext: - raise WorkflowError( - "Rule is marked for between workflow caching but has multiple output files. 
" - "This is only allowed if multiext() is used to declare them (see docs on between " - "workflow caching).", - rule=rule, - ) - if not self.enable_cache: - logger.warning( - "Workflow defines that rule {} is eligible for caching between workflows " - "(use the --cache argument to enable this).".format(rule.name) - ) - else: - if ruleinfo.cache is True or "omit-software" or "all": - self.cache_rules[rule.name] = ( - "all" if ruleinfo.cache is True else ruleinfo.cache - ) - else: - raise WorkflowError( - "Invalid value for cache directive. Use True or 'omit-software'.", - rule=rule, - ) - if ruleinfo.benchmark and self.get_cache_mode(rule): + if ruleinfo.cache and not ( + ruleinfo.cache is True + or ruleinfo.cache == "omit-software" + or ruleinfo.cache == "all" + ): raise WorkflowError( - "Rules with a benchmark directive may not be marked as eligible " - "for between-workflow caching at the same time. The reason is that " - "when the result is taken from cache, there is no way to fill the benchmark file with " - "any reasonable values. Either remove the benchmark directive or disable " - "between-workflow caching for this rule.", + "Invalid value for cache directive. Use 'all' or 'omit-software'.", rule=rule, ) + self.cache_rules[rule.name] = ( + "all" if ruleinfo.cache is True else ruleinfo.cache + ) + if ruleinfo.default_target is True: self.default_target = rule.name elif not (ruleinfo.default_target is False): @@ -1985,13 +1774,6 @@ def decorate(ruleinfo): return decorate - def version(self, version): - def decorate(ruleinfo): - ruleinfo.version = version - return ruleinfo - - return decorate - def group(self, group): def decorate(ruleinfo): ruleinfo.group = group @@ -2155,59 +1937,3 @@ def decorate(maybe_ruleinfo): @staticmethod def _empty_decorator(f): return f - - -class Subworkflow: - def __init__(self, workflow, name, snakefile, workdir, configfile): - self.workflow = workflow - self.name = name - self._snakefile = snakefile - self._workdir = workdir - self.configfile = configfile - - @property - def snakefile(self): - if self._snakefile is None: - return os.path.abspath(os.path.join(self.workdir, "Snakefile")) - if not os.path.isabs(self._snakefile): - return os.path.abspath(os.path.join(self.workflow.basedir, self._snakefile)) - return self._snakefile - - @property - def workdir(self): - workdir = "." 
if self._workdir is None else self._workdir - if not os.path.isabs(workdir): - return os.path.abspath(os.path.join(self.workflow.basedir, workdir)) - return workdir - - def target(self, paths): - if not_iterable(paths): - path = paths - path = ( - path - if os.path.isabs(path) or path.startswith("root://") - else os.path.join(self.workdir, path) - ) - return flag(path, "subworkflow", self) - return [self.target(path) for path in paths] - - def targets(self, dag): - def relpath(f): - if f.startswith(self.workdir): - return os.path.relpath(f, start=self.workdir) - # do not adjust absolute targets outside of workdir - return f - - return [ - relpath(f) - for job in dag.jobs - for f in job.subworkflow_input - if job.subworkflow_input[f] is self - ] - - -def srcdir(path): - """Return the absolute path, relative to the source directory of the current Snakefile.""" - if not workflow.included_stack: - return None - return workflow.current_basedir.join(path).get_path_or_uri() diff --git a/test-environment.yml b/test-environment.yml index cfc1729c7..a11471aef 100644 --- a/test-environment.yml +++ b/test-environment.yml @@ -8,7 +8,7 @@ dependencies: - stopit - datrie - boto3 - - moto + - moto =3 # breaking changes in moto 4 that would need an update of S3Mocked - junit-xml # needed for S3Mocked - httpretty - wrapt @@ -71,3 +71,7 @@ dependencies: - azure-batch - azure-mgmt-batch - azure-identity + - nodejs # for cwltool + - apptainer + - squashfuse # for apptainer + - immutables diff --git a/test.py b/test.py new file mode 100644 index 000000000..676e3b227 --- /dev/null +++ b/test.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass, field +from abc import ABC, abstractmethod +import typing + + +class A(ABC): + @property + @abstractmethod + def test(self) -> typing.Optional[int]: + ... 
+ + +@dataclass +class B(A): + test: int = frozenset() + + +print(B(test=2)) +print(isinstance(B(test=2), A)) diff --git a/tests/common.py b/tests/common.py index 7f5f4666f..50cece258 100644 --- a/tests/common.py +++ b/tests/common.py @@ -4,6 +4,7 @@ __license__ = "MIT" import os +from pathlib import Path import signal import sys import shlex @@ -18,10 +19,11 @@ import subprocess import tarfile -from snakemake.api import snakemake -from snakemake.shell import shell +from snakemake_interface_executor_plugins.registry import ExecutorPluginRegistry + +from snakemake import api, settings from snakemake.common import ON_WINDOWS -from snakemake.resources import DefaultResources, GroupResources, ResourceScopes +from snakemake.resources import ResourceScopes def dpath(path): @@ -136,12 +138,51 @@ def run( cleanup=True, conda_frontend="mamba", config=dict(), - targets=None, + targets=set(), container_image=os.environ.get("CONTAINER_IMAGE", "snakemake/snakemake:latest"), shellcmd=None, sigint_after=None, overwrite_resource_scopes=None, - **params, + executor="local", + executor_settings=None, + cleanup_scripts=True, + scheduler_ilp_solver=None, + report=None, + report_stylesheet=None, + deployment_method=frozenset(), + shadow_prefix=None, + until=frozenset(), + omit_from=frozenset(), + forcerun=frozenset(), + conda_list_envs=False, + conda_prefix=None, + wrapper_prefix=None, + printshellcmds=False, + default_remote_provider=None, + default_remote_prefix=None, + archive=None, + cluster=None, + cluster_status=None, + retries=0, + resources=dict(), + default_resources=None, + group_components=dict(), + max_threads=None, + overwrite_groups=dict(), + configfiles=list(), + overwrite_resources=dict(), + batch=None, + envvars=list(), + cache=None, + edit_notebook=None, + overwrite_scatter=dict(), + generate_unit_tests=None, + force_incomplete=False, + containerize=False, + forceall=False, + all_temp=False, + cleanup_metadata=None, + rerun_triggers=settings.RerunTrigger.all(), ): """ Test the Snakefile in the path. 
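
The reworked `run()` helper whose signature appears above is essentially a thin wrapper around the new dataclass-based API; the hunks that follow thread each keyword argument into the corresponding settings object. A condensed sketch of that call chain, using only names that occur in this patch (the snakefile path, workdir, target set, and the bare-default constructors are illustrative and assume the settings classes provide defaults for omitted fields):

```python
from pathlib import Path

from snakemake import api, settings

# Condensed from the tests/common.py changes below: open the API, describe
# the workflow, build the DAG, then execute with a named executor plugin.
# "Snakefile", the workdir, and the target set are placeholders.
with api.SnakemakeApi(settings.OutputSettings(verbose=True)) as snakemake_api:
    workflow_api = snakemake_api.workflow(
        resource_settings=settings.ResourceSettings(cores=1),
        config_settings=settings.ConfigSettings(),
        storage_settings=settings.StorageSettings(),
        workflow_settings=settings.WorkflowSettings(),
        snakefile=Path("Snakefile"),
        workdir=Path("."),
    )
    dag_api = workflow_api.dag(
        dag_settings=settings.DAGSettings(targets={"all"}),
        deployment_settings=settings.DeploymentSettings(),
    )
    dag_api.execute_workflow(
        executor="local",
        execution_settings=settings.ExecutionSettings(),
        remote_execution_settings=settings.RemoteExecutionSettings(),
        scheduling_settings=settings.SchedulingSettings(),
        group_settings=settings.GroupSettings(),
    )
```

As the updated helper below does, exceptions raised inside the `with` block can be rendered via `snakemake_api.print_exception(e)`.
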
@@ -177,21 +218,6 @@ def run( config = dict(config) - # handle subworkflow - if subpath is not None: - # set up a working directory for the subworkflow and pass it in `config` - # for now, only one subworkflow is supported - assert os.path.exists(subpath) and os.path.isdir( - subpath - ), "{} does not exist".format(subpath) - subworkdir = os.path.join(tmpdir, "subworkdir") - os.mkdir(subworkdir) - - # copy files - for f in os.listdir(subpath): - copy(os.path.join(subpath, f), subworkdir) - config["subworkdir"] = subworkdir - # copy files for f in os.listdir(path): copy(os.path.join(path, f), tmpdir) @@ -199,6 +225,9 @@ def run( # Snakefile is now in temporary directory snakefile = join(tmpdir, snakefile) + snakemake_api = None + exception = None + # run snakemake if shellcmd: if not shellcmd.startswith("snakemake"): @@ -231,29 +260,120 @@ def run( print(e.stdout.decode(), file=sys.stderr) else: assert sigint_after is None, "Cannot sent SIGINT when calling directly" - success = snakemake( - snakefile=original_snakefile if no_tmpdir else snakefile, - cores=cores, - nodes=nodes, - workdir=path if no_tmpdir else tmpdir, - stats="stats.txt", - config=config, - verbose=True, - targets=targets, - conda_frontend=conda_frontend, - container_image=container_image, - overwrite_resource_scopes=( - ResourceScopes(overwrite_resource_scopes) - if overwrite_resource_scopes is not None - else overwrite_resource_scopes + + if cluster is not None: + executor = "cluster-generic" + plugin = ExecutorPluginRegistry().get_plugin(executor) + executor_settings = plugin.executor_settings_class( + submit_cmd=cluster, status_cmd=cluster_status + ) + nodes = 3 + + success = True + + with api.SnakemakeApi( + settings.OutputSettings( + verbose=True, + printshellcmds=printshellcmds, + show_failed_logs=True, ), - **params, - ) + ) as snakemake_api: + try: + workflow_api = snakemake_api.workflow( + resource_settings=settings.ResourceSettings( + cores=cores, + nodes=nodes, + overwrite_resource_scopes=( + ResourceScopes(overwrite_resource_scopes) + if overwrite_resource_scopes is not None + else dict() + ), + overwrite_resources=overwrite_resources, + resources=resources, + default_resources=default_resources, + max_threads=max_threads, + overwrite_scatter=overwrite_scatter, + ), + config_settings=settings.ConfigSettings( + config=config, + configfiles=configfiles, + ), + storage_settings=settings.StorageSettings( + default_remote_provider=default_remote_provider, + default_remote_prefix=default_remote_prefix, + all_temp=all_temp, + ), + workflow_settings=settings.WorkflowSettings( + wrapper_prefix=wrapper_prefix, + ), + snakefile=Path(original_snakefile if no_tmpdir else snakefile), + workdir=Path(path if no_tmpdir else tmpdir), + ) + + dag_api = workflow_api.dag( + dag_settings=settings.DAGSettings( + targets=targets, + until=until, + omit_from=omit_from, + forcerun=forcerun, + batch=batch, + force_incomplete=force_incomplete, + cache=cache, + forceall=forceall, + rerun_triggers=rerun_triggers, + ), + deployment_settings=settings.DeploymentSettings( + conda_frontend=conda_frontend, + conda_prefix=conda_prefix, + deployment_method=deployment_method, + ), + ) + + if report is not None: + dag_api.create_report(path=report, stylesheet=report_stylesheet) + elif conda_list_envs: + dag_api.conda_list_envs() + elif archive is not None: + dag_api.archive(Path(archive)) + elif generate_unit_tests is not None: + dag_api.generate_unit_tests(Path(generate_unit_tests)) + elif containerize: + dag_api.containerize() + elif 
cleanup_metadata: + dag_api.cleanup_metadata(cleanup_metadata) + else: + dag_api.execute_workflow( + executor=executor, + execution_settings=settings.ExecutionSettings( + cleanup_scripts=cleanup_scripts, + shadow_prefix=shadow_prefix, + retries=retries, + edit_notebook=edit_notebook, + ), + remote_execution_settings=settings.RemoteExecutionSettings( + container_image=container_image, + seconds_between_status_checks=0, + envvars=envvars, + ), + scheduling_settings=settings.SchedulingSettings( + ilp_solver=scheduler_ilp_solver, + ), + group_settings=settings.GroupSettings( + group_components=group_components, + overwrite_groups=overwrite_groups, + ), + executor_settings=executor_settings, + ) + except Exception as e: + success = False + exception = e if shouldfail: assert not success, "expected error on execution" else: if not success: + if snakemake_api is not None and exception is not None: + snakemake_api.print_exception(exception) print("Workdir:") print_tree(tmpdir, exclude=".snakemake/conda") assert success, "expected successful execution" diff --git a/tests/test01/Snakefile b/tests/test01/Snakefile index cc3eb7b16..e09780a69 100644 --- a/tests/test01/Snakefile +++ b/tests/test01/Snakefile @@ -44,11 +44,10 @@ rule rule1: 'test.inter' output: 'dir/test.out' log: a='log/logfile.log' - version: version() threads: 3 shell: 'if [ {threads} -ne 3 ]; then echo "This test has to be run with -j3 in order to succeed!"; exit 1; fi; ' \ - 'echo {TEST}; echo {version}; cp {input[0]} {output[0]}; ' # append a comment + 'echo {TEST}; cp {input[0]} {output[0]}; ' # append a comment 'echo test > {log.a}' rule rule2: diff --git a/tests/test14/Snakefile.nonstandard b/tests/test14/Snakefile.nonstandard deleted file mode 100644 index 570e64c03..000000000 --- a/tests/test14/Snakefile.nonstandard +++ /dev/null @@ -1,50 +0,0 @@ -from snakemake.shell import shell - -chromosomes = [1, 2, 3, 4, 5] - - -envvars: - "TESTVAR", - "TESTVAR2", - - -rule all: - input: - "test.predictions", - "test.2.inter2", - - -rule compute1: - input: - "{name}.in", - expand("raw.{i}.txt", i=range(22)), - output: - ["{name}.%s.inter" % c for c in chromosomes], - params: - prefix="{name}", - run: - for out in output: - shell('(cat {input[0]} && echo "Part {out}") > {out}') - print(os.getcwd()) - print(os.listdir()) - - -rule compute2: - input: - "{name}.{chromosome}.inter", - output: - "{name}.{chromosome}.inter2", - params: - test="a=b", - threads: workflow.cores * 0.5 - shell: - "echo copy; cp {input[0]} {output[0]}" - - -rule gather: - input: - ["{name}.%s.inter2" % c for c in chromosomes], - output: - "{name}.predictions", - run: - shell("cat {} > {}".format(" ".join(input), output[0])) diff --git a/tests/test14/expected-results/test.1.inter b/tests/test14/expected-results/test.1.inter deleted file mode 100644 index 5cc1d91f6..000000000 --- a/tests/test14/expected-results/test.1.inter +++ /dev/null @@ -1,2 +0,0 @@ -testz0r -Part test.1.inter diff --git a/tests/test14/expected-results/test.1.inter2 b/tests/test14/expected-results/test.1.inter2 deleted file mode 100644 index 5cc1d91f6..000000000 --- a/tests/test14/expected-results/test.1.inter2 +++ /dev/null @@ -1,2 +0,0 @@ -testz0r -Part test.1.inter diff --git a/tests/test14/expected-results/test.2.inter b/tests/test14/expected-results/test.2.inter deleted file mode 100644 index 8b02f7f0b..000000000 --- a/tests/test14/expected-results/test.2.inter +++ /dev/null @@ -1,2 +0,0 @@ -testz0r -Part test.2.inter diff --git a/tests/test14/expected-results/test.2.inter2 
b/tests/test14/expected-results/test.2.inter2 deleted file mode 100644 index 8b02f7f0b..000000000 --- a/tests/test14/expected-results/test.2.inter2 +++ /dev/null @@ -1,2 +0,0 @@ -testz0r -Part test.2.inter diff --git a/tests/test14/expected-results/test.3.inter b/tests/test14/expected-results/test.3.inter deleted file mode 100644 index 5144542ec..000000000 --- a/tests/test14/expected-results/test.3.inter +++ /dev/null @@ -1,2 +0,0 @@ -testz0r -Part test.3.inter diff --git a/tests/test14/expected-results/test.3.inter2 b/tests/test14/expected-results/test.3.inter2 deleted file mode 100644 index 5144542ec..000000000 --- a/tests/test14/expected-results/test.3.inter2 +++ /dev/null @@ -1,2 +0,0 @@ -testz0r -Part test.3.inter diff --git a/tests/test14/expected-results/test.predictions b/tests/test14/expected-results/test.predictions deleted file mode 100644 index 7d97db630..000000000 --- a/tests/test14/expected-results/test.predictions +++ /dev/null @@ -1,10 +0,0 @@ -testz0r -Part test.1.inter -testz0r -Part test.2.inter -testz0r -Part test.3.inter -testz0r -Part test.4.inter -testz0r -Part test.5.inter diff --git a/tests/test14/qsub b/tests/test14/qsub deleted file mode 100755 index 0bc8aabba..000000000 --- a/tests/test14/qsub +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash -echo `date` >> qsub.log -tail -n1 $1 >> qsub.log -# simulate printing of job id by a random number -echo $RANDOM -cat $1 >> qsub.log -sh $1 diff --git a/tests/test14/qsub.py b/tests/test14/qsub.py deleted file mode 100755 index c3d4fcbad..000000000 --- a/tests/test14/qsub.py +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/env python3 -import sys -import os -import random - -from snakemake.utils import read_job_properties - -jobscript = sys.argv[1] -job_properties = read_job_properties(jobscript) -with open("qsub.log", "a") as log: - print(job_properties, file=log) - -print(random.randint(1, 100)) -os.system("sh {}".format(jobscript)) diff --git a/tests/test14/raw.10.txt b/tests/test14/raw.10.txt deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test14/raw.11.txt b/tests/test14/raw.11.txt deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test14/raw.12.txt b/tests/test14/raw.12.txt deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test14/raw.13.txt b/tests/test14/raw.13.txt deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test14/raw.14.txt b/tests/test14/raw.14.txt deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test14/raw.15.txt b/tests/test14/raw.15.txt deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test14/raw.16.txt b/tests/test14/raw.16.txt deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test14/raw.17.txt b/tests/test14/raw.17.txt deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test14/raw.18.txt b/tests/test14/raw.18.txt deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test14/raw.19.txt b/tests/test14/raw.19.txt deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test14/raw.2.txt b/tests/test14/raw.2.txt deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test14/raw.20.txt b/tests/test14/raw.20.txt deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test14/raw.21.txt b/tests/test14/raw.21.txt deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test14/raw.3.txt b/tests/test14/raw.3.txt deleted file mode 100644 index e69de29bb..000000000 diff --git 
a/tests/test14/raw.4.txt b/tests/test14/raw.4.txt deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test14/raw.5.txt b/tests/test14/raw.5.txt deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test14/raw.6.txt b/tests/test14/raw.6.txt deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test14/raw.7.txt b/tests/test14/raw.7.txt deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test14/raw.8.txt b/tests/test14/raw.8.txt deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test14/raw.9.txt b/tests/test14/raw.9.txt deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test14/test.in b/tests/test14/test.in deleted file mode 100644 index ce667834a..000000000 --- a/tests/test14/test.in +++ /dev/null @@ -1 +0,0 @@ -testz0r diff --git a/tests/test_azure_batch_executor.py b/tests/test_azure_batch_executor.py deleted file mode 100644 index 2e54fd133..000000000 --- a/tests/test_azure_batch_executor.py +++ /dev/null @@ -1,26 +0,0 @@ -import sys -import os -import re - -sys.path.insert(0, os.path.dirname(__file__)) - -from common import * - - -@azbatch -def test_az_batch_executor(): - # AZ_BATCH_ACCOUNT_URL=https://${batch_account_name}.${region}.batch.azure.com - bau = os.getenv("AZ_BATCH_ACCOUNT_URL") - prefix = os.getenv("AZ_BLOB_PREFIX") - wdir = dpath("test_azure_batch") - blob_account_url = os.getenv("AZ_BLOB_ACCOUNT_URL") - assert blob_account_url is not None and blob_account_url.strip() != "" - - run( - path=wdir, - default_remote_prefix=prefix, - container_image="snakemake/snakemake", - envvars=["AZ_BLOB_ACCOUNT_URL", "AZ_BLOB_CREDENTIAL"], - az_batch=True, - az_batch_account_url=bau, - ) diff --git a/tests/test_cluster_sidecar/Snakefile b/tests/test_cluster_sidecar/Snakefile deleted file mode 100644 index 80c388472..000000000 --- a/tests/test_cluster_sidecar/Snakefile +++ /dev/null @@ -1,21 +0,0 @@ - - - -rule all: - input: - "f.1", - "f.2", - - -rule one: - output: - "f.1", - shell: - "touch {output}" - - -rule two: - output: - "f.2", - shell: - "touch {output}" diff --git a/tests/test_cluster_sidecar/expected-results/f.1 b/tests/test_cluster_sidecar/expected-results/f.1 deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test_cluster_sidecar/expected-results/f.2 b/tests/test_cluster_sidecar/expected-results/f.2 deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test_cluster_sidecar/expected-results/launched.txt b/tests/test_cluster_sidecar/expected-results/launched.txt deleted file mode 100644 index 3381db118..000000000 --- a/tests/test_cluster_sidecar/expected-results/launched.txt +++ /dev/null @@ -1,2 +0,0 @@ -SNAKEMAKE_CLUSTER_SIDECAR_VARS=FIRST_LINE -SNAKEMAKE_CLUSTER_SIDECAR_VARS=FIRST_LINE diff --git a/tests/test_cluster_sidecar/expected-results/sidecar.txt b/tests/test_cluster_sidecar/expected-results/sidecar.txt deleted file mode 100644 index bfeb0fe5e..000000000 --- a/tests/test_cluster_sidecar/expected-results/sidecar.txt +++ /dev/null @@ -1,2 +0,0 @@ -sidecar started -sidecar stopped diff --git a/tests/test_cluster_sidecar/sbatch b/tests/test_cluster_sidecar/sbatch deleted file mode 100755 index 7995c09c4..000000000 --- a/tests/test_cluster_sidecar/sbatch +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/bash -set -x -echo "SNAKEMAKE_CLUSTER_SIDECAR_VARS=$SNAKEMAKE_CLUSTER_SIDECAR_VARS" >>launched.txt -echo --sbatch-- >> sbatch.log -echo `date` >> sbatch.log -tail -n1 $1 >> sbatch.log -cat $1 >> sbatch.log -# daemonize job script -nohup sh 
$1 0<&- &>/dev/null & -# print PID for job number -echo $! diff --git a/tests/test_cluster_sidecar/sidecar.sh b/tests/test_cluster_sidecar/sidecar.sh deleted file mode 100755 index 7849e66b2..000000000 --- a/tests/test_cluster_sidecar/sidecar.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -set -ex - -echo "FIRST_LINE" -echo "sidecar started" > sidecar.txt -sleep infinity & -pid=$! - -catch() -{ - set -x - kill -TERM $pid || true - echo "sidecar stopped" >> sidecar.txt - exit 0 -} - -trap catch SIGTERM SIGINT - -wait diff --git a/tests/test_cluster_sidecar/test.in b/tests/test_cluster_sidecar/test.in deleted file mode 100644 index ce667834a..000000000 --- a/tests/test_cluster_sidecar/test.in +++ /dev/null @@ -1 +0,0 @@ -testz0r diff --git a/tests/test_cluster_statusscript/Snakefile.nonstandard b/tests/test_cluster_statusscript/Snakefile.nonstandard deleted file mode 100644 index ce066f7c7..000000000 --- a/tests/test_cluster_statusscript/Snakefile.nonstandard +++ /dev/null @@ -1,32 +0,0 @@ - - -chromosomes = [1,2,3,4,5] - -envvars: - "TESTVAR" - - - -rule all: - input: 'test.predictions', 'test.2.inter2' - -rule compute1: - input: '{name}.in' - output: ['{name}.%s.inter'%c for c in chromosomes] - params: prefix="{name}" - run: - for out in output: - shell('(cat {input[0]} && echo "Part {out}") > {out}') - -rule compute2: - input: '{name}.{chromosome}.inter' - output: '{name}.{chromosome}.inter2' - params: test="a=b" - threads: workflow.cores * 0.5 - shell: 'echo copy; cp {input[0]} {output[0]}' - -rule gather: - input: ['{name}.%s.inter2'%c for c in chromosomes] - output: '{name}.predictions' - run: - shell('cat {} > {}'.format(' '.join(input), output[0])) diff --git a/tests/test_cluster_statusscript/expected-results/test.1.inter b/tests/test_cluster_statusscript/expected-results/test.1.inter deleted file mode 100644 index 5cc1d91f6..000000000 --- a/tests/test_cluster_statusscript/expected-results/test.1.inter +++ /dev/null @@ -1,2 +0,0 @@ -testz0r -Part test.1.inter diff --git a/tests/test_cluster_statusscript/expected-results/test.1.inter2 b/tests/test_cluster_statusscript/expected-results/test.1.inter2 deleted file mode 100644 index 5cc1d91f6..000000000 --- a/tests/test_cluster_statusscript/expected-results/test.1.inter2 +++ /dev/null @@ -1,2 +0,0 @@ -testz0r -Part test.1.inter diff --git a/tests/test_cluster_statusscript/expected-results/test.2.inter b/tests/test_cluster_statusscript/expected-results/test.2.inter deleted file mode 100644 index 8b02f7f0b..000000000 --- a/tests/test_cluster_statusscript/expected-results/test.2.inter +++ /dev/null @@ -1,2 +0,0 @@ -testz0r -Part test.2.inter diff --git a/tests/test_cluster_statusscript/expected-results/test.2.inter2 b/tests/test_cluster_statusscript/expected-results/test.2.inter2 deleted file mode 100644 index 8b02f7f0b..000000000 --- a/tests/test_cluster_statusscript/expected-results/test.2.inter2 +++ /dev/null @@ -1,2 +0,0 @@ -testz0r -Part test.2.inter diff --git a/tests/test_cluster_statusscript/expected-results/test.3.inter b/tests/test_cluster_statusscript/expected-results/test.3.inter deleted file mode 100644 index 5144542ec..000000000 --- a/tests/test_cluster_statusscript/expected-results/test.3.inter +++ /dev/null @@ -1,2 +0,0 @@ -testz0r -Part test.3.inter diff --git a/tests/test_cluster_statusscript/expected-results/test.3.inter2 b/tests/test_cluster_statusscript/expected-results/test.3.inter2 deleted file mode 100644 index 5144542ec..000000000 --- a/tests/test_cluster_statusscript/expected-results/test.3.inter2 +++ /dev/null @@ 
-1,2 +0,0 @@ -testz0r -Part test.3.inter diff --git a/tests/test_cluster_statusscript/expected-results/test.predictions b/tests/test_cluster_statusscript/expected-results/test.predictions deleted file mode 100644 index 7d97db630..000000000 --- a/tests/test_cluster_statusscript/expected-results/test.predictions +++ /dev/null @@ -1,10 +0,0 @@ -testz0r -Part test.1.inter -testz0r -Part test.2.inter -testz0r -Part test.3.inter -testz0r -Part test.4.inter -testz0r -Part test.5.inter diff --git a/tests/test_cluster_statusscript/qsub b/tests/test_cluster_statusscript/qsub deleted file mode 100755 index 5857ca5f0..000000000 --- a/tests/test_cluster_statusscript/qsub +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/bash -if [[ ! -z "$SNAKEMAKE_PROFILE" ]]; then - >&2 echo "SNAKEMAKE_PROFILE should not be set" - exit 1 -fi -echo `date` >> qsub.log -tail -n1 $1 >> qsub.log -# simulate printing of job id by a random number -echo $RANDOM -cat $1 >> qsub.log -sh $1 diff --git a/tests/test_cluster_statusscript/status.sh b/tests/test_cluster_statusscript/status.sh deleted file mode 100755 index e9d4078f6..000000000 --- a/tests/test_cluster_statusscript/status.sh +++ /dev/null @@ -1 +0,0 @@ -echo success diff --git a/tests/test_cluster_statusscript/test.in b/tests/test_cluster_statusscript/test.in deleted file mode 100644 index ce667834a..000000000 --- a/tests/test_cluster_statusscript/test.in +++ /dev/null @@ -1 +0,0 @@ -testz0r diff --git a/tests/test_cluster_statusscript_multi/Snakefile.nonstandard b/tests/test_cluster_statusscript_multi/Snakefile.nonstandard deleted file mode 100644 index eefffaefd..000000000 --- a/tests/test_cluster_statusscript_multi/Snakefile.nonstandard +++ /dev/null @@ -1,13 +0,0 @@ - - -envvars: - "TESTVAR" - - - -rule all: - input: 'output.txt' - -rule compute: - output: 'output.txt' - shell: 'touch {output}' diff --git a/tests/test_cluster_statusscript_multi/expected-results/output.txt b/tests/test_cluster_statusscript_multi/expected-results/output.txt deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test_cluster_statusscript_multi/sbatch b/tests/test_cluster_statusscript_multi/sbatch deleted file mode 100755 index 39b6e88da..000000000 --- a/tests/test_cluster_statusscript_multi/sbatch +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -echo `date` >> sbatch.log -tail -n1 $1 >> sbatch.log -# simulate printing of job id by a random number plus the name -# of the cluster -echo "$RANDOM;name-of-cluster" -cat $1 >> sbatch.log -sh $1 diff --git a/tests/test_cluster_statusscript_multi/status.sh b/tests/test_cluster_statusscript_multi/status.sh deleted file mode 100755 index 9e3eb1724..000000000 --- a/tests/test_cluster_statusscript_multi/status.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash - -# The argument passed from sbatch is "jobid;cluster_name" - -arg="$1" -jobid="${arg%%;*}" -cluster="${arg##*;}" - -echo success diff --git a/tests/test_conda_function/Snakefile b/tests/test_conda_function/Snakefile index 647027fd0..5bf12873a 100644 --- a/tests/test_conda_function/Snakefile +++ b/tests/test_conda_function/Snakefile @@ -25,9 +25,11 @@ rule all: rule a: output: "test_{version}_v.out", + log: + err="test_{version}.log", params: name_prefix="foo", conda: conda_func shell: - r"rg --version | head -n1 | grep -o 'ripgrep [0-9]*\.[0-9]*\.[0-9]*' > {output}" + r"(rg --version | grep -o 'ripgrep [0-9]*\.[0-9]*\.[0-9]*' > {output}) 2> {log.err}" diff --git a/tests/test_conda_named/Snakefile b/tests/test_conda_named/Snakefile index 1742416bd..9b67e8026 100644 --- 
a/tests/test_conda_named/Snakefile +++ b/tests/test_conda_named/Snakefile @@ -1,10 +1,12 @@ import subprocess as sp +import sys try: shell( "mamba create -y -n xxx-test-env -c conda-forge --override-channels ripgrep==13.0.0" ) + print("created conda env", file=sys.stderr) except sp.CalledProcessError as e: print(e.stderr) raise e @@ -21,7 +23,9 @@ onerror: rule a: output: "test.out", + log: + err="test.log", conda: "xxx-test-env" shell: - r"rg --version | head -n1 | grep -o 'ripgrep [0-9]*\.[0-9]*\.[0-9]*' > {output}" + r"(rg --version | grep -o 'ripgrep [0-9]*\.[0-9]*\.[0-9]*' > {output}) 2> {log.err}" diff --git a/tests/test_conda_pin_file/Snakefile b/tests/test_conda_pin_file/Snakefile index 098048136..2a30065d6 100644 --- a/tests/test_conda_pin_file/Snakefile +++ b/tests/test_conda_pin_file/Snakefile @@ -6,5 +6,5 @@ rule a: conda: "test-env.yaml" shell: - "rg --version | head -n1 | cut -f2 -d' ' > {output}" + "rg --version > version.txt; head -n1 version.txt | cut -f2 -d' ' > {output}" diff --git a/tests/test_converting_path_for_r_script/Snakefile b/tests/test_converting_path_for_r_script/Snakefile index c19e776e7..6afca7630 100644 --- a/tests/test_converting_path_for_r_script/Snakefile +++ b/tests/test_converting_path_for_r_script/Snakefile @@ -14,5 +14,7 @@ rule step1: param_dir=Path("dir") output: out_file=Path("out-file.txt") + conda: + "env.yaml" script: "r-script.R" diff --git a/tests/test_converting_path_for_r_script/env.yaml b/tests/test_converting_path_for_r_script/env.yaml new file mode 100644 index 000000000..d81c9475f --- /dev/null +++ b/tests/test_converting_path_for_r_script/env.yaml @@ -0,0 +1,4 @@ +channels: + - conda-forge +dependencies: + - r-base diff --git a/tests/test_google_lifesciences.py b/tests/test_google_lifesciences.py deleted file mode 100644 index 82bcf288f..000000000 --- a/tests/test_google_lifesciences.py +++ /dev/null @@ -1,191 +0,0 @@ -import os -import requests -import sys -import tempfile -import google.auth - -from google.cloud import storage - -sys.path.insert(0, os.path.dirname(__file__)) - -from common import * - - -def has_google_credentials(): - credentials, _ = google.auth.default() - return credentials - - -google_credentials = pytest.mark.skipif( - not has_google_credentials(), - reason="Skipping google lifesciences tests because " - "Google credentials were not found in the environment.", -) - - -def get_default_service_account_email(): - """Returns the default service account if running on a GCE VM, otherwise None.""" - response = requests.get( - "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/email", - headers={"Metadata-Flavor": "Google"}, - ) - if response.status_code == requests.codes.ok: - return response.text - else: - return None - - -def cleanup_google_storage(prefix, bucket_name="snakemake-testing", restrict_to=None): - """Given a storage prefix and a bucket, recursively delete files there - This is intended to run after testing to ensure that - the bucket is cleaned up. 
- - Arguments: - prefix (str) : the "subfolder" or prefix for some files in the buckets - bucket_name (str) : the name of the bucket, default snakemake-testing - restrict_to (list) : only delete files in these paths (None deletes all) - """ - client = storage.Client() - bucket = client.get_bucket(bucket_name) - blobs = bucket.list_blobs(prefix="source") - for blob in blobs: - blob.delete() - blobs = bucket.list_blobs(prefix=prefix) - for blob in blobs: - if restrict_to is None or f"{bucket_name}/{blob.name}" in restrict_to: - blob.delete() - if restrict_to is None: - # Using API we get an exception about bucket deletion - shell("gsutil -m rm -r gs://{bucket.name}/* || true") - bucket.delete() - - -def create_google_storage(bucket_name="snakemake-testing"): - """Given a bucket name, create the Google storage bucket, - intending to be used for testing and then cleaned up by - cleanup_google_storage - - Arguments: - bucket_name (str) : the name of the bucket, default snakemake-testing - """ - client = storage.Client() - return client.create_bucket(bucket_name) - - -def get_temp_bucket_name(): - return "snakemake-testing-%s-bucket" % next(tempfile._get_candidate_names()) - - -@google_credentials -def test_google_lifesciences(): - bucket_name = get_temp_bucket_name() - create_google_storage(bucket_name) - storage_prefix = "test_google_lifesciences" - workdir = dpath("test_google_lifesciences") - try: - run( - workdir, - use_conda=True, - default_remote_prefix="%s/%s" % (bucket_name, storage_prefix), - google_lifesciences=True, - google_lifesciences_cache=False, - google_lifesciences_service_account_email=get_default_service_account_email(), - preemption_default=None, - preemptible_rules=["pack=1"], - ) - finally: - cleanup_google_storage(storage_prefix, bucket_name) - - -@pytest.mark.skip( - reason="Cannot test using touch with a remote prefix until the container image is deployed." 
-) -@google_credentials -def test_touch_remote_prefix(): - bucket_name = get_temp_bucket_name() - create_google_storage(bucket_name) - storage_prefix = "test_touch_remote_prefix" - workdir = dpath("test_touch_remote_prefix") - try: - run( - workdir, - use_conda=True, - default_remote_prefix="%s/%s" % (bucket_name, storage_prefix), - google_lifesciences=True, - google_lifesciences_cache=False, - google_lifesciences_service_account_email=get_default_service_account_email(), - ) - finally: - cleanup_google_storage(storage_prefix, bucket_name) - - -@google_credentials -def test_cloud_checkpoints_issue574(): - """see Github issue #574""" - bucket_name = get_temp_bucket_name() - create_google_storage(bucket_name) - storage_prefix = "test_cloud_checkpoints_issue574" - workdir = dpath("test_cloud_checkpoints_issue574") - try: - run( - workdir, - use_conda=True, - default_remote_prefix="%s/%s" % (bucket_name, storage_prefix), - google_lifesciences=True, - google_lifesciences_cache=False, - google_lifesciences_service_account_email=get_default_service_account_email(), - ) - finally: - cleanup_google_storage(storage_prefix, bucket_name) - - -def test_github_issue1396(): - bucket_name = get_temp_bucket_name() - create_google_storage(bucket_name) - storage_prefix = "test_github_issue1396" - workdir = dpath("test_github_issue1396") - try: - run( - workdir, - default_remote_prefix="%s/%s" % (bucket_name, storage_prefix), - google_lifesciences=True, - google_lifesciences_cache=False, - dryrun=True, - ) - finally: - cleanup_google_storage(storage_prefix, bucket_name) - - -def test_github_issue1460(): - service_account_email = get_default_service_account_email() - bucket_name = get_temp_bucket_name() - create_google_storage(bucket_name) - storage_prefix = "test_github_issue1460" - prefix = "%s/%s" % (bucket_name, storage_prefix) - workdir = dpath("test_github_issue1460") - try: - run( - workdir, - default_remote_prefix=prefix, - google_lifesciences=True, - google_lifesciences_cache=False, - google_lifesciences_service_account_email=service_account_email, - ) - cleanup_google_storage( - storage_prefix, - bucket_name, - restrict_to=[ - f"{prefix}/test.txt", - f"{prefix}/blob.txt", - f"{prefix}/pretest.txt", - ], - ) - run( - workdir, - default_remote_prefix=prefix, - google_lifesciences=True, - google_lifesciences_cache=False, - google_lifesciences_service_account_email=service_account_email, - ) - finally: - cleanup_google_storage(storage_prefix, bucket_name) diff --git a/tests/test_kubernetes.py b/tests/test_kubernetes.py deleted file mode 100644 index b5a640293..000000000 --- a/tests/test_kubernetes.py +++ /dev/null @@ -1,102 +0,0 @@ -import os -import sys -import uuid - -sys.path.insert(0, os.path.dirname(__file__)) - -from snakemake.resources import DefaultResources - -from common import * - - -@pytest.fixture(scope="module") -def kubernetes_cluster(): - class Cluster: - def __init__(self): - self.cluster = "t-{}".format(uuid.uuid4()) - self.bucket_name = self.cluster - - try: - shell( - """ - gcloud container clusters create {self.cluster} --num-nodes 3 --scopes storage-rw --zone us-central1-a --machine-type n1-standard-2 --local-ssd-count=1 - gcloud container clusters get-credentials {self.cluster} --zone us-central1-a - gsutil mb gs://{self.bucket_name} - """ - ) - except Exception as e: - try: - self.delete() - except: - # ignore errors during deletion - pass - raise e - - def delete(self): - shell( - """ - gcloud container clusters delete {self.cluster} --zone us-central1-a --quiet || true - gsutil 
rm -r gs://{self.bucket_name} || true - """ - ) - - def run(self, test="test_kubernetes", **kwargs): - try: - run( - dpath(test), - kubernetes="default", - default_remote_provider="GS", - default_remote_prefix=self.bucket_name, - no_tmpdir=True, - default_resources=DefaultResources( - ["mem_mb=1000", "disk_mb=50"] - ), # ensure that we don't get charged too much - **kwargs - ) - except Exception as e: - shell( - "for p in `kubectl get pods | grep ^snakejob- | cut -f 1 -d ' '`; do kubectl logs $p; done" - ) - raise e - - def reset(self): - print("Resetting bucket...", file=sys.stderr) - shell("gsutil -m rm -r gs://{self.bucket_name}/* || true") - - cluster = Cluster() - yield cluster - cluster.delete() - - -@gcloud -def test_kubernetes_plain(kubernetes_cluster): - kubernetes_cluster.reset() - kubernetes_cluster.run() - - -@gcloud -@pytest.mark.skip(reason="need a faster cloud compute instance to run this") -def test_kubernetes_conda(kubernetes_cluster): - kubernetes_cluster.reset() - kubernetes_cluster.run(use_conda=True) - - -@gcloud -@pytest.mark.skip(reason="need a faster cloud compute instance to run this") -def test_kubernetes_singularity(kubernetes_cluster): - kubernetes_cluster.reset() - kubernetes_cluster.run(use_singularity=True) - - -@gcloud -@pytest.mark.skip(reason="need a faster cloud compute instance to run this") -def test_kubernetes_conda_singularity(kubernetes_cluster): - kubernetes_cluster.reset() - kubernetes_cluster.run(use_singularity=True, use_conda=True) - - -@gcloud() -@pytest.mark.skip(reason="need a faster cloud compute instance to run this") -def test_issue1041(kubernetes_cluster): - kubernetes_cluster.reset() - kubernetes_cluster.run(test="test_issue1041") diff --git a/tests/test_list_untracked/Snakefile b/tests/test_list_untracked/Snakefile index e484473b5..a6749a90f 100644 --- a/tests/test_list_untracked/Snakefile +++ b/tests/test_list_untracked/Snakefile @@ -2,4 +2,4 @@ shell.executable("bash") rule run_test: output: "leftover_files" - shell: "python -m snakemake -s Snakefile_inner --list-untracked 2> {output}" + shell: "python -m snakemake -s Snakefile_inner --list-untracked > {output}" diff --git a/tests/test_list_untracked/expected-results/leftover_files b/tests/test_list_untracked/expected-results/leftover_files index fac3ff55b..b0c84bdef 100644 --- a/tests/test_list_untracked/expected-results/leftover_files +++ b/tests/test_list_untracked/expected-results/leftover_files @@ -1,2 +1 @@ -Building DAG of jobs... some_subdir/not_used diff --git a/tests/test_list_untracked/expected-results/leftover_files_WIN b/tests/test_list_untracked/expected-results/leftover_files_WIN index fdd0cbdde..c6499a63e 100644 --- a/tests/test_list_untracked/expected-results/leftover_files_WIN +++ b/tests/test_list_untracked/expected-results/leftover_files_WIN @@ -1,2 +1 @@ -Building DAG of jobs... 
some_subdir\not_used diff --git a/tests/test_pipes/Snakefile b/tests/test_pipes/Snakefile index 17bc46d70..31e1e349c 100644 --- a/tests/test_pipes/Snakefile +++ b/tests/test_pipes/Snakefile @@ -9,7 +9,7 @@ rule a: output: pipe("test.{i}.txt") shell: - "for i in {{0..2}}; do echo {wildcards.i} >> {output}; done" + r"echo -e '{wildcards.i}\n{wildcards.i}\n{wildcards.i}' > {output}" rule b: diff --git a/tests/test_slurm.py b/tests/test_slurm.py deleted file mode 100644 index 0f49348e1..000000000 --- a/tests/test_slurm.py +++ /dev/null @@ -1,99 +0,0 @@ -__authors__ = ["Christian Meesters", "Johannes Köster"] -__copyright__ = "Copyright 2022, Christian Meesters, Johannes Köster" -__email__ = "johannes.koester@uni-due.de" -__license__ = "MIT" - -import os -import sys - -sys.path.insert(0, os.path.dirname(__file__)) - -from .common import * -from .conftest import skip_on_windows - - -@skip_on_windows -def test_slurm_mpi(): - run( - dpath("test_slurm_mpi"), - slurm=True, - show_failed_logs=True, - use_conda=True, - default_resources=DefaultResources( - ["slurm_account=runner", "slurm_partition=debug"] - ), - ) - - -@skip_on_windows -def test_slurm_group_job(): - """ - same test as test_group_job(), - but for SLURM - checks whether - the group-property is correctly - propagated. - """ - run( - dpath("test_group_jobs"), - slurm=True, - show_failed_logs=True, - default_resources=DefaultResources( - ["slurm_account=runner", "slurm_partition=debug", "tasks=1", "mem_mb=0"] - ), - ) - - -@skip_on_windows -def test_slurm_group_parallel(): - """ - same test as test_group_job(), - but for SLURM - checks whether - the group-property is correctly - propagated. - """ - run( - dpath("test_group_parallel"), - slurm=True, - show_failed_logs=True, - default_resources=DefaultResources( - ["slurm_account=runner", "slurm_partition=debug", "tasks=1", "mem_mb=0"] - ), - ) - - -@skip_on_windows -def test_slurm_complex(): - os.environ["TESTVAR"] = "test" - os.environ["TESTVAR2"] = "test" - run( - dpath("test14"), - snakefile="Snakefile.nonstandard", - show_failed_logs=True, - slurm=True, - default_resources=DefaultResources( - [ - "slurm_account=runner", - "slurm_partition=debug", - "tasks=1", - "mem_mb=0", - "disk_mb=max(2*input.size_mb, 200)", - ] - ), - ) - - -@skip_on_windows -def test_slurm_extra_arguments(): - """Make sure arguments to default resources - are allowed to contain = signs, which is needed - for extra slurm arguments""" - run( - dpath("test_slurm_mpi"), - slurm=True, - show_failed_logs=True, - use_conda=True, - default_resources=DefaultResources( - ["slurm_account=runner", "slurm_partition=debug", - "slurm_extra='--mail-type=none'"] - ), - ) diff --git a/tests/test_srcdir/Snakefile b/tests/test_srcdir/Snakefile deleted file mode 100644 index 4e0c97879..000000000 --- a/tests/test_srcdir/Snakefile +++ /dev/null @@ -1,7 +0,0 @@ -rule: - output: - "test.out" - params: - srcdir("script.sh") - shell: - "sh {params} > {output}" diff --git a/tests/test_srcdir/expected-results/test.out b/tests/test_srcdir/expected-results/test.out deleted file mode 100644 index 9daeafb98..000000000 --- a/tests/test_srcdir/expected-results/test.out +++ /dev/null @@ -1 +0,0 @@ -test diff --git a/tests/test_srcdir/script.sh b/tests/test_srcdir/script.sh deleted file mode 100644 index 7e182fdda..000000000 --- a/tests/test_srcdir/script.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/sh -echo test diff --git a/tests/test_temp/qsub b/tests/test_temp/qsub deleted file mode 100755 index 63a46f970..000000000 --- a/tests/test_temp/qsub +++ 
/dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -echo `date` >> qsub.log -tail -n1 $1 >> qsub.log -# simulate printing of job id by a random number -echo $RANDOM -$1 diff --git a/tests/test_tes.py b/tests/test_tes.py deleted file mode 100644 index 8cc7ed09c..000000000 --- a/tests/test_tes.py +++ /dev/null @@ -1,84 +0,0 @@ -import os -import sys -import subprocess -import requests_mock -import json - -sys.path.insert(0, os.path.dirname(__file__)) - -from common import * - - -TES_URL = "http://localhost:8000" -FUNNEL_SERVER_USER = "funnel" -FUNNEL_SERVER_PASSWORD = "funnel_password" - -TEST_POST_RESPONSE = {"id": "id_1"} - -TEST_TASK = {"id": "id_1", "state": "COMPLETE"} - -TES_TOKEN = ( - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." - + "eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ." - + "SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" -) - - -def _validate_task(task): - print("\n>>>> _validate_task", file=sys.stderr) - required_keys = ["inputs", "outputs", "executors"] - return all(i in task.keys() for i in required_keys) - - -def _post_task(request, context): - outdir = dpath("test_tes") - print("\n>>>> _post_task", file=sys.stderr) - task = json.loads(request.body) - print(task, file=sys.stderr) - if _validate_task(task): - context.status_code = 200 - # create output file - print("\n create output files in {}".format(outdir), file=sys.stderr) - with open("{}/test_output.txt".format(outdir), "w+") as f: - f.write("output") - # create log file - with open("{}/test_log.txt".format(outdir), "w+") as f: - f.write("log") - return TEST_POST_RESPONSE - else: - context.status_code = 400 - return None - - -def _get_task(request, context): - print("\n>>>> _get_task", file=sys.stderr) - context.status_code = 200 - return TEST_TASK - - -def test_tes(requests_mock): - requests_mock.register_uri("POST", "{}/v1/tasks".format(TES_URL), json=_post_task) - requests_mock.register_uri( - "GET", "{}/v1/tasks/id_1".format(TES_URL), json=_get_task - ) - workdir = dpath("test_tes") - print("\n>>>> run workflow in {}".format(workdir), file=sys.stderr) - run( - workdir, - snakefile="Snakefile", - tes=TES_URL, - no_tmpdir=True, - cleanup=False, - forceall=True, - ) - os.environ["TES_TOKEN"] = TES_TOKEN - os.environ["FUNNEL_SERVER_USER"] = FUNNEL_SERVER_USER - os.environ["FUNNEL_SERVER_PASSWORD"] = FUNNEL_SERVER_PASSWORD - run( - workdir, - snakefile="Snakefile", - tes=TES_URL, - no_tmpdir=True, - cleanup=False, - forceall=True, - ) diff --git a/tests/test_tibanna.py b/tests/test_tibanna.py deleted file mode 100644 index 5975ec1c9..000000000 --- a/tests/test_tibanna.py +++ /dev/null @@ -1,24 +0,0 @@ -import os -import sys - -sys.path.insert(0, os.path.dirname(__file__)) - -from common import * - - -# run tibanna test before any moto-related tests because they apparently render AWS environment variables invalid or uneffective. 
-def test_tibanna(): - workdir = dpath("test_tibanna") - subprocess.check_call(["python", "cleanup.py"], cwd=workdir) - - os.environ["TEST_ENVVAR1"] = "test" - os.environ["TEST_ENVVAR2"] = "test" - - run( - workdir, - use_conda=True, - default_remote_prefix="snakemake-tibanna-test/1", - tibanna=True, - tibanna_sfn="tibanna_unicorn_johannes", - tibanna_config=["spot_instance=true"], - ) diff --git a/tests/testapi.py b/tests/testapi.py index 894ea272d..41474e62f 100644 --- a/tests/testapi.py +++ b/tests/testapi.py @@ -108,9 +108,9 @@ def test_lockexception(): persistence = Persistence() persistence.all_inputfiles = lambda: ["A.txt"] persistence.all_outputfiles = lambda: ["B.txt"] - persistence.lock() - try: - persistence.lock() - except LockException as e: - return True - assert False + with persistence.lock(): + try: + persistence.lock() + except LockException as e: + return True + assert False diff --git a/tests/tests.py b/tests/tests.py index 8de59722e..f5e862604 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -5,19 +5,20 @@ import os import sys -import uuid import subprocess as sp from pathlib import Path +from snakemake.resources import DefaultResources, GroupResources +from snakemake.settings import RerunTrigger -from snakemake.cli import parse_cores_jobs -from snakemake.exceptions import CliException -from snakemake.utils import available_cpu_count +from snakemake.shell import shell sys.path.insert(0, os.path.dirname(__file__)) from .common import * from .conftest import skip_on_windows, only_on_windows, ON_WINDOWS, needs_strace +from snakemake_interface_executor_plugins.settings import DeploymentMethod + def test_list_untracked(): run(dpath("test_list_untracked")) @@ -40,7 +41,7 @@ def test_github_issue_14(): shutil.rmtree(tmpdir) # And not here - tmpdir = run(dpath("test_github_issue_14"), cleanup=False) + tmpdir = run(dpath("test_github_issue_14"), cleanup=False, cleanup_scripts=True) assert not os.listdir(os.path.join(tmpdir, ".snakemake", "scripts")) shutil.rmtree(tmpdir) @@ -66,6 +67,7 @@ def test04(): run(dpath("test04"), targets=["test.out"]) +@skip_on_windows # error: "The filename, directory name, or volume label syntax is incorrect". 
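The locking pattern that the updated test_lockexception above relies on can be written out as a standalone helper for orientation. This is a minimal sketch, not part of the patch: it assumes Persistence and LockException are importable from snakemake.persistence and snakemake.exceptions respectively, and it only mirrors what the test exercises (persistence.lock() is now used as a context manager, and an overlapping lock() on the same files raises LockException).

from snakemake.persistence import Persistence   # assumed import path, as in tests/testapi.py
from snakemake.exceptions import LockException  # assumed import path

def second_lock_is_rejected(persistence: Persistence) -> bool:
    # Same setup as test_lockexception above: fake the tracked input/output files.
    persistence.all_inputfiles = lambda: ["A.txt"]
    persistence.all_outputfiles = lambda: ["B.txt"]
    with persistence.lock():        # first lock, held for the duration of the block
        try:
            persistence.lock()      # overlapping lock attempt on the same files
        except LockException:
            return True             # expected: the second lock is refused
        return False                # no exception: locking silently allowed a second holder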
def test05(): run(dpath("test05")) @@ -102,84 +104,59 @@ def test13(): run(dpath("test13")) -@skip_on_windows -def test14(): - os.environ["TESTVAR"] = "test" - os.environ["TESTVAR2"] = "test" - run(dpath("test14"), snakefile="Snakefile.nonstandard", cluster="./qsub") - - -@skip_on_windows -def test_cluster_statusscript(): - os.environ["TESTVAR"] = "test" - run( - dpath("test_cluster_statusscript"), - snakefile="Snakefile.nonstandard", - cluster="./qsub", - cluster_status="./status.sh", - ) - - -@skip_on_windows -def test_cluster_cancelscript(): - outdir = run( - dpath("test_cluster_cancelscript"), - snakefile="Snakefile.nonstandard", - shellcmd=( - "snakemake -j 10 --cluster=./sbatch --cluster-cancel=./scancel.sh " - "--cluster-status=./status.sh -s Snakefile.nonstandard" - ), - shouldfail=True, - cleanup=False, - sigint_after=4, - ) - scancel_txt = open("%s/scancel.txt" % outdir).read() - scancel_lines = scancel_txt.splitlines() - assert len(scancel_lines) == 1 - assert scancel_lines[0].startswith("cancel") - assert len(scancel_lines[0].split(" ")) == 3 - - -@skip_on_windows -def test_cluster_sidecar(): - run( - dpath("test_cluster_sidecar"), - shellcmd=("snakemake -j 10 --cluster=./sbatch --cluster-sidecar=./sidecar.sh"), - ) - - -@skip_on_windows -def test_cluster_cancelscript_nargs1(): - outdir = run( - dpath("test_cluster_cancelscript"), - snakefile="Snakefile.nonstandard", - shellcmd=( - "snakemake -j 10 --cluster=./sbatch --cluster-cancel=./scancel.sh " - "--cluster-status=./status.sh --cluster-cancel-nargs=1 " - "-s Snakefile.nonstandard" - ), - shouldfail=True, - cleanup=False, - sigint_after=4, - ) - scancel_txt = open("%s/scancel.txt" % outdir).read() - scancel_lines = scancel_txt.splitlines() - assert len(scancel_lines) == 2 - assert scancel_lines[0].startswith("cancel") - assert scancel_lines[1].startswith("cancel") - assert len(scancel_lines[0].split(" ")) == 2 - assert len(scancel_lines[1].split(" ")) == 2 - - -@skip_on_windows -def test_cluster_statusscript_multi(): - os.environ["TESTVAR"] = "test" - run( - dpath("test_cluster_statusscript_multi"), - snakefile="Snakefile.nonstandard", - cluster="./sbatch", - cluster_status="./status.sh", - ) +# TODO reenable once cluster-generic plugin is released +# @skip_on_windows +# def test_cluster_cancelscript(): +# outdir = run( +# dpath("test_cluster_cancelscript"), +# snakefile="Snakefile.nonstandard", +# shellcmd=( +# "snakemake -j 10 --cluster=./sbatch --cluster-cancel=./scancel.sh " +# "--cluster-status=./status.sh -s Snakefile.nonstandard" +# ), +# shouldfail=True, +# cleanup=False, +# sigint_after=4, +# ) +# scancel_txt = open("%s/scancel.txt" % outdir).read() +# scancel_lines = scancel_txt.splitlines() +# assert len(scancel_lines) == 1 +# assert scancel_lines[0].startswith("cancel") +# assert len(scancel_lines[0].split(" ")) == 3 + + +# @skip_on_windows +# def test_cluster_cancelscript_nargs1(): +# outdir = run( +# dpath("test_cluster_cancelscript"), +# snakefile="Snakefile.nonstandard", +# shellcmd=( +# "snakemake -j 10 --cluster=./sbatch --cluster-cancel=./scancel.sh " +# "--cluster-status=./status.sh --cluster-cancel-nargs=1 " +# "-s Snakefile.nonstandard" +# ), +# shouldfail=True, +# cleanup=False, +# sigint_after=4, +# ) +# scancel_txt = open("%s/scancel.txt" % outdir).read() +# scancel_lines = scancel_txt.splitlines() +# assert len(scancel_lines) == 2 +# assert scancel_lines[0].startswith("cancel") +# assert scancel_lines[1].startswith("cancel") +# assert len(scancel_lines[0].split(" ")) == 2 +# assert 
len(scancel_lines[1].split(" ")) == 2 + + +# @skip_on_windows +# def test_cluster_statusscript_multi(): +# os.environ["TESTVAR"] = "test" +# run( +# dpath("test_cluster_statusscript_multi"), +# snakefile="Snakefile.nonstandard", +# cluster="./sbatch", +# cluster_status="./status.sh", +# ) def test15(): @@ -282,21 +259,13 @@ def test_shell(): @skip_on_windows def test_temp(): - run(dpath("test_temp"), cluster="./qsub", targets="test.realigned.bam".split()) + run(dpath("test_temp"), targets="test.realigned.bam".split()) def test_keyword_list(): run(dpath("test_keyword_list")) -# Fails on WIN because some snakemake doesn't release the logfile -# which cause a PermissionError when the test setup tries to -# remove the temporary files -@skip_on_windows -def test_subworkflows(): - run(dpath("test_subworkflows"), subpath=dpath("test02")) - - def test_globwildcards(): run(dpath("test_globwildcards")) @@ -328,7 +297,7 @@ def test_touch(): def test_touch_flag_with_directories(): - run(dpath("test_touch_with_directories"), touch=True) + run(dpath("test_touch_with_directories"), executor="touch") def test_config(): @@ -363,11 +332,6 @@ def test_wildcard_count_ambiguity(): run(dpath("test_wildcard_count_ambiguity")) -@skip_on_windows -def test_srcdir(): - run(dpath("test_srcdir")) - - def test_multiple_includes(): run(dpath("test_multiple_includes")) @@ -388,12 +352,6 @@ def test_remote(): run(dpath("test_remote"), cores=1) -@skip_on_windows -def test_cluster_sync(): - os.environ["TESTVAR"] = "test" - run(dpath("test14"), snakefile="Snakefile.nonstandard", cluster_sync="./qsub") - - @pytest.mark.skip(reason="This does not work reliably in CircleCI.") def test_symlink_temp(): run(dpath("test_symlink_temp"), shouldfail=True) @@ -405,7 +363,11 @@ def test_empty_include(): @skip_on_windows def test_script(): - run(dpath("test_script"), use_conda=True, check_md5=False) + run( + dpath("test_script"), + deployment_method={DeploymentMethod.CONDA}, + check_md5=False, + ) def test_script_python(): @@ -427,9 +389,10 @@ def test_shadow_prefix(): run(dpath("test_shadow_prefix"), shadow_prefix="shadowdir") -@skip_on_windows -def test_shadow_prefix_qsub(): - run(dpath("test_shadow_prefix"), shadow_prefix="shadowdir", cluster="./qsub") +# TODO add again once generic cluster plugin is released +# @skip_on_windows +# def test_shadow_prefix_qsub(): +# run(dpath("test_shadow_prefix"), shadow_prefix="shadowdir", cluster="./qsub") @skip_on_windows @@ -493,32 +456,40 @@ def test_issue328(): def test_conda(): - run(dpath("test_conda"), use_conda=True) + run(dpath("test_conda"), deployment_method={DeploymentMethod.CONDA}) def test_conda_list_envs(): - run(dpath("test_conda"), list_conda_envs=True, check_results=False) + run(dpath("test_conda"), conda_list_envs=True, check_results=False) def test_upstream_conda(): - run(dpath("test_conda"), use_conda=True, conda_frontend="conda") + run( + dpath("test_conda"), + deployment_method={DeploymentMethod.CONDA}, + conda_frontend="conda", + ) @skip_on_windows def test_deploy_script(): - run(dpath("test_deploy_script"), use_conda=True) + run(dpath("test_deploy_script"), deployment_method={DeploymentMethod.CONDA}) @skip_on_windows def test_deploy_hashing(): - tmpdir = run(dpath("test_deploy_hashing"), use_conda=True, cleanup=False) + tmpdir = run( + dpath("test_deploy_hashing"), + deployment_method={DeploymentMethod.CONDA}, + cleanup=False, + ) assert len(next(os.walk(os.path.join(tmpdir, ".snakemake/conda")))[1]) == 2 def test_conda_custom_prefix(): run( 
dpath("test_conda_custom_prefix"), - use_conda=True, + deployment_method={DeploymentMethod.CONDA}, conda_prefix="custom", set_pythonpath=False, ) @@ -528,12 +499,12 @@ def test_conda_custom_prefix(): def test_conda_cmd_exe(): # Tests the conda environment activation when cmd.exe # is used as the shell - run(dpath("test_conda_cmd_exe"), use_conda=True) + run(dpath("test_conda_cmd_exe"), deployment_method={DeploymentMethod.CONDA}) @skip_on_windows # wrappers are for linux and macos only def test_wrapper(): - run(dpath("test_wrapper"), use_conda=True) + run(dpath("test_wrapper"), deployment_method={DeploymentMethod.CONDA}) @skip_on_windows # wrappers are for linux and macos only @@ -548,7 +519,9 @@ def test_wrapper_local_git_prefix(): print("Cloning complete.") run( - dpath("test_wrapper"), use_conda=True, wrapper_prefix=f"git+file://{tmpdir}" + dpath("test_wrapper"), + deployment_method={DeploymentMethod.CONDA}, + wrapper_prefix=f"git+file://{tmpdir}", ) @@ -657,59 +630,62 @@ def test_dup_out_patterns(): run(dpath("test_dup_out_patterns"), shouldfail=True) -@skip_on_windows -def test_restartable_job_cmd_exit_1_no_restart(): - """Test the restartable job feature on ``exit 1`` - - The shell snippet in the Snakemake file will fail the first time - and succeed the second time. - """ - run( - dpath("test_restartable_job_cmd_exit_1"), - cluster="./qsub", - restart_times=0, - shouldfail=True, - ) - - -@skip_on_windows -def test_restartable_job_cmd_exit_1_one_restart(): - # Restarting once is enough - run( - dpath("test_restartable_job_cmd_exit_1"), - cluster="./qsub", - restart_times=1, - printshellcmds=True, - ) - - -@skip_on_windows -def test_restartable_job_qsub_exit_1(): - """Test the restartable job feature when qsub fails - - The qsub in the subdirectory will fail the first time and succeed the - second time. - """ - # Even two consecutive times should fail as files are cleared - run( - dpath("test_restartable_job_qsub_exit_1"), - cluster="./qsub", - restart_times=0, - shouldfail=True, - ) - run( - dpath("test_restartable_job_qsub_exit_1"), - cluster="./qsub", - restart_times=0, - shouldfail=True, - ) - # Restarting once is enough - run( - dpath("test_restartable_job_qsub_exit_1"), - cluster="./qsub", - restart_times=1, - shouldfail=False, - ) +# TODO reactivate once generic cluster executor is properly released +# @skip_on_windows +# def test_restartable_job_cmd_exit_1_no_restart(): +# """Test the restartable job feature on ``exit 1`` + +# The shell snippet in the Snakemake file will fail the first time +# and succeed the second time. +# """ +# run( +# dpath("test_restartable_job_cmd_exit_1"), +# cluster="./qsub", +# retries=0, +# shouldfail=True, +# ) + + +# TODO reactivate once generic cluster executor is properly released +# @skip_on_windows +# def test_restartable_job_cmd_exit_1_one_restart(): +# # Restarting once is enough +# run( +# dpath("test_restartable_job_cmd_exit_1"), +# cluster="./qsub", +# retries=1, +# printshellcmds=True, +# ) + + +# TODO reactivate once generic cluster executor is properly released +# @skip_on_windows +# def test_restartable_job_qsub_exit_1(): +# """Test the restartable job feature when qsub fails + +# The qsub in the subdirectory will fail the first time and succeed the +# second time. 
+# """ +# # Even two consecutive times should fail as files are cleared +# run( +# dpath("test_restartable_job_qsub_exit_1"), +# cluster="./qsub", +# retries=0, +# shouldfail=True, +# ) +# run( +# dpath("test_restartable_job_qsub_exit_1"), +# cluster="./qsub", +# retries=0, +# shouldfail=True, +# ) +# # Restarting once is enough +# run( +# dpath("test_restartable_job_qsub_exit_1"), +# cluster="./qsub", +# retries=1, +# shouldfail=False, +# ) def test_threads(): @@ -809,16 +785,17 @@ def test_remote_log(): @connected -@pytest.mark.xfail def test_remote_http(): run(dpath("test_remote_http")) @skip_on_windows @connected -@pytest.mark.xfail def test_remote_http_cluster(): - run(dpath("test_remote_http"), cluster=os.path.abspath(dpath("test14/qsub"))) + run( + dpath("test_remote_http"), + cluster=os.path.abspath(dpath("test_group_job_fail/qsub")), + ) def test_profile(): @@ -828,7 +805,7 @@ def test_profile(): @skip_on_windows @connected def test_singularity(): - run(dpath("test_singularity"), use_singularity=True) + run(dpath("test_singularity"), deployment_method={DeploymentMethod.APPTAINER}) @skip_on_windows @@ -836,7 +813,7 @@ def test_singularity_invalid(): run( dpath("test_singularity"), targets=["invalid.txt"], - use_singularity=True, + deployment_method={DeploymentMethod.APPTAINER}, shouldfail=True, ) @@ -846,7 +823,7 @@ def test_singularity_module_invalid(): run( dpath("test_singularity_module"), targets=["invalid.txt"], - use_singularity=True, + deployment_method={DeploymentMethod.APPTAINER}, shouldfail=True, ) @@ -856,8 +833,7 @@ def test_singularity_module_invalid(): def test_singularity_conda(): run( dpath("test_singularity_conda"), - use_singularity=True, - use_conda=True, + deployment_method={DeploymentMethod.CONDA, DeploymentMethod.APPTAINER}, conda_frontend="conda", ) @@ -865,17 +841,19 @@ def test_singularity_conda(): @skip_on_windows @connected def test_singularity_none(): - run(dpath("test_singularity_none"), use_singularity=True) + run(dpath("test_singularity_none"), deployment_method={DeploymentMethod.APPTAINER}) @skip_on_windows @connected def test_singularity_global(): - run(dpath("test_singularity_global"), use_singularity=True) + run( + dpath("test_singularity_global"), deployment_method={DeploymentMethod.APPTAINER} + ) def test_issue612(): - run(dpath("test_issue612"), dryrun=True) + run(dpath("test_issue612"), executor="dryrun") def test_bash(): @@ -894,16 +872,10 @@ def test_log_input(): run(dpath("test_log_input")) -@skip_on_windows -@connected -def test_cwl(): - run(dpath("test_cwl")) - - @skip_on_windows @connected def test_cwl_singularity(): - run(dpath("test_cwl"), use_singularity=True) + run(dpath("test_cwl"), deployment_method={DeploymentMethod.APPTAINER}) def test_issue805(): @@ -940,7 +912,7 @@ def test_group_jobs(): @skip_on_windows def test_group_jobs_attempts(): - run(dpath("test_group_jobs_attempts"), cluster="./qsub", restart_times=2) + run(dpath("test_group_jobs_attempts"), cluster="./qsub", retries=2) def assert_resources(resources: dict, **expected_resources): @@ -1089,9 +1061,9 @@ def test_resources_can_be_overwritten_as_global(): @skip_on_windows def test_scopes_submitted_to_cluster(mocker): - from snakemake.executors import AbstractExecutor + from snakemake.spawn_jobs import SpawnedJobArgsFactory - spy = mocker.spy(AbstractExecutor, "get_resource_scopes_args") + spy = mocker.spy(SpawnedJobArgsFactory, "get_resource_scopes_args") run( dpath("test_group_jobs_resources"), cluster="./qsub", @@ -1101,12 +1073,12 @@ def 
test_scopes_submitted_to_cluster(mocker): default_resources=DefaultResources(["mem_mb=0"]), ) - assert spy.spy_return == "--set-resource-scopes 'fake_res=local'" + assert spy.spy_return == "--set-resource-scopes \"fake_res='local'\"" @skip_on_windows def test_resources_submitted_to_cluster(mocker): - from snakemake.executors import AbstractExecutor + from snakemake_interface_executor_plugins.executors.base import AbstractExecutor spy = mocker.spy(AbstractExecutor, "get_resource_declarations_dict") run( @@ -1126,7 +1098,7 @@ def test_resources_submitted_to_cluster(mocker): @skip_on_windows def test_excluded_resources_not_submitted_to_cluster(mocker): - from snakemake.executors import AbstractExecutor + from snakemake_interface_executor_plugins.executors.base import AbstractExecutor spy = mocker.spy(AbstractExecutor, "get_resource_declarations_dict") run( @@ -1145,7 +1117,7 @@ def test_excluded_resources_not_submitted_to_cluster(mocker): @skip_on_windows def test_group_job_resources_with_pipe(mocker): import copy - from snakemake.executors import RealExecutor + from snakemake_interface_executor_plugins.executors.real import RealExecutor spy = mocker.spy(GroupResources, "basic_layered") @@ -1207,8 +1179,8 @@ def test_group_job_fail(): @skip_on_windows # Not supported, but could maybe be implemented. https://stackoverflow.com/questions/48542644/python-and-windows-named-pipes -def test_pipes(): - run(dpath("test_pipes")) +def test_pipes_simple(): + run(dpath("test_pipes"), printshellcmds=True) @skip_on_windows @@ -1277,7 +1249,11 @@ def test_issue930(): @skip_on_windows def test_issue635(): - run(dpath("test_issue635"), use_conda=True, check_md5=False) + run( + dpath("test_issue635"), + deployment_method={DeploymentMethod.CONDA}, + check_md5=False, + ) # TODO remove skip @@ -1297,7 +1273,11 @@ def test_convert_to_cwl(): def test_issue1037(): - run(dpath("test_issue1037"), dryrun=True, cluster="qsub", targets=["Foo_A.done"]) + run( + dpath("test_issue1037"), + executor="dryrun", + targets=["Foo_A.done"], + ) def test_issue1046(): @@ -1318,11 +1298,11 @@ def test_issue1092(): @skip_on_windows def test_issue1093(): - run(dpath("test_issue1093"), use_conda=True) + run(dpath("test_issue1093"), deployment_method={DeploymentMethod.CONDA}) def test_issue958(): - run(dpath("test_issue958"), cluster="dummy", dryrun=True) + run(dpath("test_issue958"), executor="dryrun") def test_issue471(): @@ -1335,7 +1315,7 @@ def test_issue1085(): @skip_on_windows def test_issue1083(): - run(dpath("test_issue1083"), use_singularity=True) + run(dpath("test_issue1083"), deployment_method={DeploymentMethod.APPTAINER}) @skip_on_windows # Fails with "The flag 'pipe' used in rule two is only valid for outputs @@ -1426,7 +1406,7 @@ def test_github_issue52(): @skip_on_windows def test_github_issue78(): - run(dpath("test_github_issue78"), use_singularity=True) + run(dpath("test_github_issue78"), deployment_method={DeploymentMethod.APPTAINER}) def test_envvars(): @@ -1468,7 +1448,7 @@ def test_github_issue988(): ) def test_github_issue1062(): # old code failed in dry run - run(dpath("test_github_issue1062"), dryrun=True) + run(dpath("test_github_issue1062"), executor="dryrun") def test_output_file_cache(): @@ -1510,72 +1490,13 @@ def test_core_dependent_threads(): @skip_on_windows def test_env_modules(): - run(dpath("test_env_modules"), use_env_modules=True) - - -class TestParseCoresJobs: - def run_test(self, func, ref): - if ref is None: - with pytest.raises(CliException): - func() - return - assert func() == ref - - 
@pytest.mark.parametrize( - ("input", "output"), - [ - [(1, 1), (1, 1)], - [(4, 4), (4, 4)], - [(None, None), (1, 1)], - [("all", "unlimited"), (available_cpu_count(), sys.maxsize)], - ], - ) - def test_no_exec(self, input, output): - self.run_test(lambda: parse_cores_jobs(*input, True, False, False), output) - # Test dryrun seperately - self.run_test(lambda: parse_cores_jobs(*input, False, False, True), output) - - @pytest.mark.parametrize( - ("input", "output"), - [ - [(1, 1), (1, 1)], - [(4, 4), (4, 4)], - [(None, 1), (None, 1)], - [(None, None), None], - [(1, None), None], - [("all", "unlimited"), (available_cpu_count(), sys.maxsize)], - ], - ) - def test_non_local_job(self, input, output): - self.run_test(lambda: parse_cores_jobs(*input, False, True, False), output) - - @pytest.mark.parametrize( - ("input", "output"), - [ - [(1, 1), (1, None)], - [(4, 4), (4, None)], - [(None, 1), (1, None)], - [(None, None), None], - [(1, None), (1, None)], - [(None, "all"), (available_cpu_count(), None)], - [(None, "unlimited"), None], - [("all", "unlimited"), (available_cpu_count(), None)], - ], - ) - def test_local_job(self, input, output): - self.run_test(lambda: parse_cores_jobs(*input, False, False, False), output) + run(dpath("test_env_modules"), deployment_method={DeploymentMethod.ENV_MODULES}) @skip_on_windows @connected def test_container(): - run(dpath("test_container"), use_singularity=True) - - -def test_linting(): - snakemake( - snakefile=os.path.join(dpath("test14"), "Snakefile.nonstandard"), lint=True - ) + run(dpath("test_container"), deployment_method={DeploymentMethod.APPTAINER}) @skip_on_windows @@ -1590,16 +1511,16 @@ def test_string_resources(): def test_jupyter_notebook(): - run(dpath("test_jupyter_notebook"), use_conda=True) + run(dpath("test_jupyter_notebook"), deployment_method={DeploymentMethod.CONDA}) def test_jupyter_notebook_draft(): - from snakemake.notebook import EditMode + from snakemake.settings import NotebookEditMode run( dpath("test_jupyter_notebook_draft"), - use_conda=True, - edit_notebook=EditMode(draft_only=True), + deployment_method={DeploymentMethod.CONDA}, + edit_notebook=NotebookEditMode(draft_only=True), targets=["results/result_intermediate.txt"], check_md5=False, ) @@ -1625,7 +1546,7 @@ def test_github_issue640(): run( dpath("test_github_issue640"), targets=["Output/FileWithRights"], - dryrun=True, + executor="dryrun", cleanup=False, ) @@ -1643,38 +1564,6 @@ def test_generate_unit_tests(): sp.check_call(["pytest", ".tests", "-vs"], cwd=tmpdir) -@skip_on_windows -def test_metadata_migration(): - outpath = Path( - "tests/test_metadata_migration/some/veryveryveryveryveryveryvery/veryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryvery/veryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryvery/veryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryvery/veryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryvery/veryveryveryveryveryveryveryveryveryveryveryveryverylong" - ) - os.makedirs(outpath, exist_ok=True) - metapath = Path( - 
"tests/test_metadata_migration/.snakemake/metadata/@c29tZS92ZXJ5dmVyeXZlcnl2ZXJ5dmVyeXZlcnl2ZXJ5L3Zlcnl2ZXJ5dmVyeXZlcnl2ZXJ5dmVyeXZlcnl2ZXJ5dmVyeXZlcnl2ZXJ5dmVyeXZlcnl2ZXJ5dmVyeXZlcnl2ZXJ5dmVyeXZlcnl2ZXJ5L3Zlcnl2ZXJ5dmVyeXZlcnl2ZXJ5dmVyeXZlcnl2ZXJ5dmVyeXZlcnl2ZXJ5dmVyeXZlcnl2ZXJ5dmVyeXZlcnl2ZXJ5dmVyeXZlcn/@l2ZXJ5dmVyeXZlcnl2ZXJ5dmVyeS92ZXJ5dmVyeXZlcnl2ZXJ5dmVyeXZlcnl2ZXJ5dmVyeXZlcnl2ZXJ5dmVyeXZlcnl2ZXJ5dmVyeXZlcnl2ZXJ5dmVyeXZlcnl2ZXJ5dmVyeXZlcnkvdmVyeXZlcnl2ZXJ5dmVyeXZlcnl2ZXJ5dmVyeXZlcnl2ZXJ5dmVyeXZlcnl2ZXJ5dmVyeXZlcnl2ZXJ5dmVyeXZlcnkvdmVyeXZlcnl2ZXJ5dmVy" - ) - os.makedirs(metapath, exist_ok=True) - exppath = Path( - "tests/test_metadata_migration/expected-results/some/veryveryveryveryveryveryvery/veryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryvery/veryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryvery/veryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryvery/veryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryvery/veryveryveryveryveryveryveryveryveryveryveryveryverylong" - ) - os.makedirs(exppath, exist_ok=True) - with open(outpath / "path.txt", "w"): - # generate empty file - pass - # generate artificial incomplete metadata in v1 format for migration - with open( - metapath - / "eXZlcnl2ZXJ5dmVyeXZlcnl2ZXJ5dmVyeXZlcnl2ZXJ5dmVyeWxvbmcvcGF0aC50eHQ=", - "w", - ) as meta: - print('{"incomplete": true, "external_jobid": null}', file=meta) - with open(exppath / "path.txt", "w") as out: - print("updated", file=out) - - # run workflow, incomplete v1 metadata should be migrated and trigger rerun of the rule, - # which will save different data than the output contained in the git repo. - run(dpath("test_metadata_migration"), force_incomplete=True) - - def test_paramspace(): run(dpath("test_paramspace")) @@ -1689,7 +1578,10 @@ def test_github_issue806(): @skip_on_windows def test_containerized(): - run(dpath("test_containerized"), use_conda=True, use_singularity=True) + run( + dpath("test_containerized"), + deployment_method={DeploymentMethod.CONDA, DeploymentMethod.APPTAINER}, + ) @skip_on_windows @@ -1739,7 +1631,11 @@ def test_modules_specific(): @skip_on_windows # works in principle but the test framework modifies the target path separator def test_modules_meta_wrapper(): - run(dpath("test_modules_meta_wrapper"), targets=["mapped/a.bam.bai"], dryrun=True) + run( + dpath("test_modules_meta_wrapper"), + targets=["mapped/a.bam.bai"], + executor="dryrun", + ) def test_use_rule_same_module(): @@ -1747,11 +1643,11 @@ def test_use_rule_same_module(): def test_module_complex(): - run(dpath("test_module_complex"), dryrun=True) + run(dpath("test_module_complex"), executor="dryrun") def test_module_complex2(): - run(dpath("test_module_complex2"), dryrun=True) + run(dpath("test_module_complex2"), executor="dryrun") @skip_on_windows @@ -1815,7 +1711,7 @@ def test_github_issue1069(): def test_touch_pipeline_with_temp_dir(): # Issue #1028 - run(dpath("test_touch_pipeline_with_temp_dir"), forceall=True, touch=True) + run(dpath("test_touch_pipeline_with_temp_dir"), forceall=True, executor="touch") def test_all_temp(): @@ -1832,7 +1728,11 @@ def test_github_issue1158(): def test_converting_path_for_r_script(): - run(dpath("test_converting_path_for_r_script"), cores=1) + run( + dpath("test_converting_path_for_r_script"), + cores=1, + deployment_method={DeploymentMethod.CONDA}, + ) def test_ancient_dag(): @@ -1858,12 +1758,16 @@ def test_issue1331(): @skip_on_windows def test_conda_named(): - 
run(dpath("test_conda_named"), use_conda=True) + run(dpath("test_conda_named"), deployment_method={DeploymentMethod.CONDA}) @skip_on_windows def test_conda_function(): - run(dpath("test_conda_function"), use_conda=True, cores=1) + run( + dpath("test_conda_function"), + deployment_method={DeploymentMethod.CONDA}, + cores=1, + ) @skip_on_windows @@ -1914,7 +1818,7 @@ def test_service_jobs(): def test_incomplete_params(): - run(dpath("test_incomplete_params"), dryrun=True, printshellcmds=True) + run(dpath("test_incomplete_params"), executor="dryrun", printshellcmds=True) @skip_on_windows @@ -1934,11 +1838,11 @@ def test_pipe_depend_target_file(): @skip_on_windows # platform independent issue def test_github_issue1500(): - run(dpath("test_github_issue1500"), dryrun=True) + run(dpath("test_github_issue1500"), executor="dryrun") def test_github_issue1542(): - run(dpath("test_github_issue1542"), dryrun=True) + run(dpath("test_github_issue1542"), executor="dryrun") def test_github_issue1550(): @@ -1962,7 +1866,7 @@ def test_lazy_resources(): def test_cleanup_metadata_fail(): - run(dpath("test09"), cleanup_metadata=["xyz"]) + run(dpath("test09"), cleanup_metadata=["xyz"], shouldfail=True) @skip_on_windows # same on win, no need to test @@ -2005,7 +1909,7 @@ def test_retries(): def test_retries_not_overriden(): - run(dpath("test_retries_not_overriden"), restart_times=3, shouldfail=True) + run(dpath("test_retries_not_overriden"), retries=3, shouldfail=True) @skip_on_windows # OS agnostic @@ -2015,7 +1919,7 @@ def test_module_input_func(): @skip_on_windows # the testcase only has a linux-64 pin file def test_conda_pin_file(): - run(dpath("test_conda_pin_file"), use_conda=True) + run(dpath("test_conda_pin_file"), deployment_method={DeploymentMethod.CONDA}) @skip_on_windows # sufficient to test this on linux @@ -2024,21 +1928,27 @@ def test_github_issue1618(): def test_conda_python_script(): - run(dpath("test_conda_python_script"), use_conda=True) + run(dpath("test_conda_python_script"), deployment_method={DeploymentMethod.CONDA}) def test_conda_python_3_7_script(): - run(dpath("test_conda_python_3_7_script"), use_conda=True) + run( + dpath("test_conda_python_3_7_script"), + deployment_method={DeploymentMethod.CONDA}, + ) def test_prebuilt_conda_script(): - sp.run("conda env create -f tests/test_prebuilt_conda_script/env.yaml", shell=True) - run(dpath("test_prebuilt_conda_script"), use_conda=True) + sp.run( + f"conda env create -f {dpath('test_prebuilt_conda_script/env.yaml')}", + shell=True, + ) + run(dpath("test_prebuilt_conda_script"), deployment_method={DeploymentMethod.CONDA}) @skip_on_windows def test_github_issue1818(): - run(dpath("test_github_issue1818"), rerun_triggers="input") + run(dpath("test_github_issue1818"), rerun_triggers={RerunTrigger.INPUT}) @skip_on_windows # not platform dependent