Skip to content

Commit

Permalink
Advanced DAG partitioning (#232)
Browse files Browse the repository at this point in the history
* add todos

* fmt

* add test case

* refactoring
  • Loading branch information
johanneskoester committed Sep 16, 2020
1 parent 07f7300 commit aff0b57
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 12 deletions.
69 changes: 68 additions & 1 deletion snakemake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ def snakemake(
messaging=None,
edit_notebook=None,
envvars=None,
overwrite_groups=None,
group_components=None,
):
"""Run snakemake on a given snakefile.
Expand Down Expand Up @@ -280,8 +282,10 @@ def snakemake(
log_handler (function): redirect snakemake output to this custom log handler, a function that takes a log message dictionary (see below) as its only argument (default None). The log message dictionary for the log handler has to following entries:
keep_incomplete (bool): keep incomplete output files of failed jobs
edit_notebook (object): "notebook.Listen" object to configuring notebook server for interactive editing of a rule notebook. If None, do not edit.
log_handler (list): redirect snakemake output to this list of custom log handler, each a function that takes a log message dictionary (see below) as its only argument (default []). The log message dictionary for the log handler has to following entries:
scheduler (str): Select scheduling algorithm (default ilp)
overwrite_groups (dict): Rule to group assignments (default None)
group_components (dict): Number of connected components given groups shall span before being split up (1 by default if empty)
log_handler (list): redirect snakemake output to this list of custom log handler, each a function that takes a log message dictionary (see below) as its only argument (default []). The log message dictionary for the log handler has to following entries:
:level:
the log level ("info", "error", "debug", "progress", "job_info")
Expand Down Expand Up @@ -517,6 +521,8 @@ def snakemake(
overwrite_configfiles=configfiles,
overwrite_clusterconfig=cluster_config_content,
overwrite_threads=overwrite_threads,
overwrite_groups=overwrite_groups,
group_components=group_components,
config_args=config_args,
debug=debug,
verbose=verbose,
Expand Down Expand Up @@ -639,6 +645,8 @@ def snakemake(
cluster_status=cluster_status,
max_jobs_per_second=max_jobs_per_second,
max_status_checks_per_second=max_status_checks_per_second,
overwrite_groups=overwrite_groups,
group_components=group_components,
)
success = workflow.execute(
targets=targets,
Expand Down Expand Up @@ -775,6 +783,32 @@ def parse_batch(args):
return None


def parse_groups(args):
    """Parse ``--groups`` CLI entries into a rule-to-group mapping.

    Each entry must have the form RULE=GROUP. Returns an empty dict when
    no groups were given on the command line.
    """
    errmsg = "Invalid groups definition: entries have to be defined as RULE=GROUP pairs"
    if args.groups is None:
        return {}
    # parse_key_value_arg yields (rule, group) pairs, which dict() consumes
    return dict(parse_key_value_arg(entry, errmsg=errmsg) for entry in args.groups)


def parse_group_components(args):
    """Parse ``--group-components`` CLI entries into a group-to-count mapping.

    Each entry must have the form GROUP=COMPONENTS, with COMPONENTS being a
    positive integer. Returns an empty dict when the flag was not given.

    Raises:
        ValueError: if an entry is malformed or COMPONENTS is not a
            positive integer.
    """
    errmsg = "Invalid group components definition: entries have to be defined as GROUP=COMPONENTS pairs (with COMPONENTS being a positive integer)"
    group_components = dict()
    if args.group_components is not None:
        for entry in args.group_components:
            group, count = parse_key_value_arg(entry, errmsg=errmsg)
            try:
                count = int(count)
            except ValueError:
                # Re-raise with the user-facing message; suppress the
                # uninformative int() context with "from None".
                raise ValueError(errmsg) from None
            if count <= 0:
                raise ValueError(errmsg)
            group_components[group] = count
    return group_components


def parse_key_value_arg(arg, errmsg):
try:
key, val = arg.split("=", 1)
Expand Down Expand Up @@ -1206,6 +1240,34 @@ def get_argument_parser(profile=None):
),
)

# TODO add group_partitioning, allowing to define --group rulename=groupname.
# i.e. setting groups via the CLI for improving cluster performance given
# available resources.
# TODO add an additional flag --group-components groupname=3, allowing to set the
# number of connected components a group is allowed to span. By default, this is 1
# (as now), but the flag allows to extend this. This can be used to run e.g.
# 3 jobs of the same rule in the same group, although they are not connected.
# Can be helpful for putting together many small jobs or benefiting from shared memory
# setups.

group_group = parser.add_argument_group("GROUPING")
group_group.add_argument(
"--groups",
nargs="+",
help="Assign rules to groups (this overwrites any "
"group definitions from the workflow).",
)
group_group.add_argument(
"--group-components",
nargs="+",
help="Set the number of connected components a group is "
"allowed to span. By default, this is 1, but this flag "
"allows to extend this. This can be used to run e.g. 3 "
"jobs of the same rule in the same group, although they "
"are not connected. It can be helpful for putting together "
"many small jobs or benefitting of shared memory setups.",
)

group_report = parser.add_argument_group("REPORTS")

group_report.add_argument(
Expand Down Expand Up @@ -2076,6 +2138,9 @@ def adjust_path(f):
default_resources = DefaultResources(args.default_resources)
batch = parse_batch(args)
overwrite_threads = parse_set_threads(args)

overwrite_groups = parse_groups(args)
group_components = parse_group_components(args)
except ValueError as e:
print(e, file=sys.stderr)
print("", file=sys.stderr)
Expand Down Expand Up @@ -2443,6 +2508,8 @@ def open_browser():
keep_incomplete=args.keep_incomplete,
edit_notebook=args.edit_notebook,
envvars=args.envvars,
overwrite_groups=overwrite_groups,
group_components=group_components,
log_handler=log_handler,
)

Expand Down
14 changes: 14 additions & 0 deletions snakemake/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
__license__ = "MIT"

from functools import update_wrapper
import itertools
import platform
import hashlib
import inspect
Expand Down Expand Up @@ -139,3 +140,16 @@ def log_location(msg):
logger.debug(
"{}: {info.filename}, {info.function}, {info.lineno}".format(msg, info=info)
)


def group_into_chunks(n, iterable):
    """Yield successive tuples of at most n items taken from iterable.

    The final chunk may be shorter than n; an empty iterable yields
    nothing. See https://stackoverflow.com/a/8998040.
    """
    stream = iter(iterable)
    # Two-argument iter() keeps calling the lambda until it returns the
    # sentinel (the empty tuple), i.e. until the stream is exhausted.
    yield from iter(lambda: tuple(itertools.islice(stream, n)), ())
20 changes: 19 additions & 1 deletion snakemake/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from snakemake.exceptions import RemoteFileException, WorkflowError, ChildIOException
from snakemake.exceptions import InputFunctionException
from snakemake.logging import logger
from snakemake.common import DYNAMIC_FILL
from snakemake.common import DYNAMIC_FILL, group_into_chunks
from snakemake.deployment import conda, singularity
from snakemake.output_index import OutputIndex
from snakemake import workflow
Expand Down Expand Up @@ -1041,8 +1041,26 @@ def update_groups(self):
for j in group:
if j not in groups:
groups[j] = group

self._group = groups

self._update_group_components()

def _update_group_components(self):
# span connected components if requested
for groupid, conn_components in groupby(
set(self._group.values()), key=lambda group: group.groupid
):
n_components = self.workflow.group_components.get(groupid, 1)
if n_components > 1:
for chunk in group_into_chunks(n_components, conn_components):
if len(chunk) > 1:
primary = chunk[0]
for secondary in chunk[1:]:
primary.merge(secondary)
for j in primary:
self._group[j] = primary

def update_ready(self, jobs=None):
"""Update information whether a job is ready to execute.
Expand Down
17 changes: 9 additions & 8 deletions snakemake/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,14 +613,15 @@ def job_selector_ilp(self, jobs):
) / lpSum([self.required_by_job(temp_file, job) for job in jobs])

# TODO enable this code once we have switched to pulp >=2.0
# if pulp.apis.LpSolverDefault is None:
# raise WorkflowError(
# "You need to install at least one LP solver compatible with PuLP (e.g. coincbc). "
# "See https://coin-or.github.io/pulp for details. Alternatively, run Snakemake with "
# "--scheduler greedy."
# )
# # disable extensive logging
# pulp.apis.LpSolverDefault.msg = False
if pulp.apis.LpSolverDefault is None:
raise WorkflowError(
"You need to install at least one LP solver compatible with PuLP (e.g. coincbc). "
"See https://coin-or.github.io/pulp for details. Alternatively, run Snakemake with "
"--scheduler greedy."
)
# disable extensive logging
pulp.apis.LpSolverDefault.msg = False

prob.solve()
selected_jobs = [
job for job, variable in scheduled_jobs.items() if variable.value() == 1.0
Expand Down
10 changes: 8 additions & 2 deletions snakemake/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def __init__(
overwrite_configfiles=None,
overwrite_clusterconfig=dict(),
overwrite_threads=dict(),
overwrite_groups=None,
group_components=None,
config_args=None,
debug=False,
verbose=False,
Expand Down Expand Up @@ -173,6 +175,8 @@ def __init__(
# environment variables to pass to jobs
# These are defined via the "envvars:" syntax in the Snakefile itself
self.envvars = set()
self.overwrite_groups = overwrite_groups or dict()
self.group_components = group_components or dict()

self.enable_cache = False
if cache is not None:
Expand Down Expand Up @@ -1228,8 +1232,10 @@ def decorate(ruleinfo):
rule.message = ruleinfo.message
if ruleinfo.benchmark:
rule.benchmark = ruleinfo.benchmark
if not self.run_local and ruleinfo.group is not None:
rule.group = ruleinfo.group
if not self.run_local:
group = self.overwrite_groups.get(name) or ruleinfo.group
if group is not None:
rule.group = group
if ruleinfo.wrapper:
rule.conda_env = snakemake.wrapper.get_conda_env(
ruleinfo.wrapper, prefix=self.wrapper_prefix
Expand Down
32 changes: 32 additions & 0 deletions tests/test_multicomp_group_jobs/Snakefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Ten sample ids drive the per-sample rules a and b below.
samples = list(range(10))


# Target rule: request the final aggregated output.
rule all:
    input:
        "test.out"


# Per-sample rule with no group of its own; the test assigns it to a
# group via the --groups CLI override.
rule a:
    output:
        "a/{sample}.out"
    shell:
        "touch {output}"


# Consumes rule a's per-sample output; also grouped via the CLI override.
rule b:
    input:
        "a/{sample}.out"
    output:
        "b/{sample}.out"
    shell:
        "touch {output}"


# Aggregates all per-sample outputs; carries an in-workflow group.
rule c:
    input:
        expand("b/{sample}.out", sample=samples)
    output:
        "test.out"
    group: 1
    shell:
        "touch {output}"
Empty file.
6 changes: 6 additions & 0 deletions tests/test_multicomp_group_jobs/qsub
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/bin/bash
# Minimal qsub stand-in used by the cluster tests: it records each
# submission in qsub.log and then runs the job script synchronously.
echo `date` >> qsub.log
tail -n1 $1 >> qsub.log
# simulate printing of job id by a random number
echo $RANDOM
sh $1
10 changes: 10 additions & 0 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,16 @@ def test_group_jobs():
run(dpath("test_group_jobs"), cluster="./qsub")


@skip_on_windows
def test_multicomp_group_jobs():
    # Rules a and b are assigned to a common group via the CLI override,
    # and that group is allowed to span 2 connected components, so pairs
    # of per-sample jobs should be submitted together as one cluster job.
    run(
        dpath("test_multicomp_group_jobs"),
        cluster="./qsub",
        overwrite_groups={"a": "group0", "b": "group0"},
        group_components={"group0": 2},
    )


@skip_on_windows
def test_group_job_fail():
run(dpath("test_group_job_fail"), cluster="./qsub", shouldfail=True)
Expand Down

0 comments on commit aff0b57

Please sign in to comment.