Skip to content

Commit

Permalink
fix: wrong rule names when nesting module imports (#1817)
Browse files Browse the repository at this point in the history
### Description

Fix wrong rule names when nesting module imports.

### QC
<!-- Make sure that you can tick the boxes below. -->

* [X] The PR contains a test case for the changes or the changes are
already covered by an existing test case.
* [X] The documentation (`docs/`) is updated to reflect the changes or
this is not necessary (e.g. if the change does neither modify the
language nor the behavior or functionalities of Snakemake).

---------

Co-authored-by: Johannes Köster <johannes.koester@tu-dortmund.de>
Co-authored-by: Johannes Köster <johannes.koester@uni-due.de>
  • Loading branch information
3 people committed Aug 5, 2023
1 parent b33aeec commit 65c79a4
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 14 deletions.
14 changes: 13 additions & 1 deletion snakemake/common/__init__.py
Expand Up @@ -252,7 +252,19 @@ def group_into_chunks(n, iterable):
class Rules:
"""A namespace for rules so that they can be accessed via dot notation."""

pass
def __init__(self):
self._rules = dict()

def _register_rule(self, name, rule):
self._rules[name] = rule

def __getattr__(self, name):
from snakemake.exceptions import WorkflowError

try:
return self._rules[name]
except KeyError:
raise WorkflowError(f"Rule {name} is not defined in this workflow.")


class Scatter:
Expand Down
35 changes: 27 additions & 8 deletions snakemake/modules.py
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
import types
import re
from snakemake.common import Rules

from snakemake.exceptions import WorkflowError
from snakemake.path_modifier import PathModifier
Expand Down Expand Up @@ -50,6 +51,7 @@ def __init__(
self.meta_wrapper = meta_wrapper
self.config = config
self.skip_validation = skip_validation
self.parent_modifier = self.workflow.modifier

if prefix is not None:
if isinstance(prefix, Path):
Expand All @@ -76,7 +78,7 @@ def use_rules(
skip_global_report_caption=False,
):
snakefile = self.get_snakefile()
with WorkflowModifier(
modifier = WorkflowModifier(
self.workflow,
config=self.config,
base_snakefile=snakefile,
Expand All @@ -85,15 +87,20 @@ def use_rules(
skip_global_report_caption=skip_global_report_caption,
rule_exclude_list=exclude_rules,
rule_whitelist=self.get_rule_whitelist(rules),
rulename_modifier=get_name_modifier_func(rules, name_modifier),
resolved_rulename_modifier=get_name_modifier_func(
rules, name_modifier, parent_modifier=self.parent_modifier
),
local_rulename_modifier=get_name_modifier_func(rules, name_modifier),
ruleinfo_overwrite=ruleinfo,
allow_rule_overwrite=True,
namespace=self.name,
replace_prefix=self.replace_prefix,
prefix=self.prefix,
replace_wrapper_tag=self.get_wrapper_tag(),
):
)
with modifier:
self.workflow.include(snakefile, overwrite_default_target=True)
self.parent_modifier.inherit_rule_proxies(modifier)

def get_snakefile(self):
if self.meta_wrapper:
Expand Down Expand Up @@ -138,7 +145,8 @@ def __init__(
skip_configfile=False,
skip_validation=False,
skip_global_report_caption=False,
rulename_modifier=None,
resolved_rulename_modifier=None,
local_rulename_modifier=None,
rule_whitelist=None,
rule_exclude_list=None,
ruleinfo_overwrite=None,
Expand All @@ -153,7 +161,8 @@ def __init__(
self.base_snakefile = parent_modifier.base_snakefile
self.globals = parent_modifier.globals
self.skip_configfile = parent_modifier.skip_configfile
self.rulename_modifier = parent_modifier.rulename_modifier
self.resolved_rulename_modifier = parent_modifier.resolved_rulename_modifier
self.local_rulename_modifier = parent_modifier.local_rulename_modifier
self.skip_validation = parent_modifier.skip_validation
self.skip_global_report_caption = parent_modifier.skip_global_report_caption
self.rule_whitelist = parent_modifier.rule_whitelist
Expand All @@ -175,12 +184,15 @@ def __init__(

self.workflow = workflow
self.base_snakefile = base_snakefile
self.rule_proxies = Rules()

if config is not None:
self.globals["config"] = config
self.globals["rules"] = self.rule_proxies

self.skip_configfile = skip_configfile
self.rulename_modifier = rulename_modifier
self.resolved_rulename_modifier = resolved_rulename_modifier
self.local_rulename_modifier = local_rulename_modifier
self.skip_validation = skip_validation
self.skip_global_report_caption = skip_global_report_caption
self.rule_whitelist = rule_whitelist
Expand All @@ -191,14 +203,21 @@ def __init__(
self.replace_wrapper_tag = replace_wrapper_tag
self.namespace = namespace

def inherit_rule_proxies(self, child_modifier):
if child_modifier.local_rulename_modifier is not None:
for name, rule in child_modifier.rule_proxies._rules.items():
self.rule_proxies._register_rule(
child_modifier.local_rulename_modifier(name), rule
)

def skip_rule(self, rulename):
return (
self.rule_whitelist is not None and rulename not in self.rule_whitelist
) or (self.rule_exclude_list is not None and rulename in self.rule_exclude_list)

def modify_rulename(self, rulename):
if self.rulename_modifier is not None:
return self.rulename_modifier(rulename)
if self.resolved_rulename_modifier is not None:
return self.resolved_rulename_modifier(rulename)
return rulename

def modify_path(self, path, property=None):
Expand Down
10 changes: 5 additions & 5 deletions snakemake/workflow.py
Expand Up @@ -246,7 +246,6 @@ def __init__(
_globals = globals()
_globals["workflow"] = self
_globals["cluster_config"] = copy.deepcopy(self.overwrite_clusterconfig)
_globals["rules"] = Rules()
_globals["checkpoints"] = Checkpoints()
_globals["scatter"] = Scatter()
_globals["gather"] = Gather()
Expand Down Expand Up @@ -1805,9 +1804,10 @@ def decorate(ruleinfo):
self.globals[ruleinfo.func.__name__] = ruleinfo.func

rule_proxy = RuleProxy(rule)
if orig_name is not None:
setattr(self.globals["rules"], orig_name, rule_proxy)
setattr(self.globals["rules"], rule.name, rule_proxy)
# Register rule under its original name.
# Modules using this snakefile as a module, will register it additionally under their
# requested name.
self.modifier.rule_proxies._register_rule(orig_name, rule_proxy)

if checkpoint:
self.globals["checkpoints"].register(rule, fallback_name=orig_name)
Expand Down Expand Up @@ -2118,7 +2118,7 @@ def decorate(maybe_ruleinfo):
with WorkflowModifier(
self,
parent_modifier=self.modifier,
rulename_modifier=get_name_modifier_func(
resolved_rulename_modifier=get_name_modifier_func(
rules, name_modifier, parent_modifier=self.modifier
),
ruleinfo_overwrite=ruleinfo,
Expand Down
20 changes: 20 additions & 0 deletions tests/test_module_nested/Snakefile
@@ -0,0 +1,20 @@
shell.executable("bash")


module shallow_module:
snakefile:
"module_shallow.smk"


use rule deep_work from shallow_module as shallow_work


rule all:
input:
"foo.txt",
default_target: True


assert hasattr(
rules, "shallow_work"
), f"bug: rule cannot be accessed as shallow_work: {dir(rules)}"
1 change: 1 addition & 0 deletions tests/test_module_nested/expected-results/foo.txt
@@ -0,0 +1 @@
I was here
16 changes: 16 additions & 0 deletions tests/test_module_nested/module_deep.smk
@@ -0,0 +1,16 @@
# deep_module.smk
rule all:
input:
"foo.txt",


rule work:
output:
"foo.txt",
shell:
"echo 'I was here' > {output}"


# rules.work has to work even if the rule is renamed in a parent module
# The rulename itself can be already modified.
assert hasattr(rules, "work"), f"bug: rule cannot be accessed as work: {dir(rules)}"
19 changes: 19 additions & 0 deletions tests/test_module_nested/module_shallow.smk
@@ -0,0 +1,19 @@
module deep_module:
snakefile:
"module_deep.smk"


use rule work from deep_module as deep_work


rule all:
input:
"foo.txt",
default_target: True


# rules.deep_work has to work even if the rule is renamed in a parent module
# The rulename itself can be already modified.
assert hasattr(
rules, "deep_work"
), f"bug: rule cannot be accessed as deep_work: {dir(rules)}"
4 changes: 4 additions & 0 deletions tests/tests.py
Expand Up @@ -1705,6 +1705,10 @@ def test_modules_all():
run(dpath("test_modules_all"), targets=["a"])


def test_module_nested():
run(dpath("test_module_nested"))


def test_modules_all_exclude_1():
# Fail due to conflicting rules
run(dpath("test_modules_all_exclude"), shouldfail=True)
Expand Down

0 comments on commit 65c79a4

Please sign in to comment.