Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for naming tasks in @requires #3077

Merged
merged 3 commits into from
Aug 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
49 changes: 31 additions & 18 deletions luigi/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,33 +279,47 @@ def run(self):
# ...
"""

def __init__(self, *tasks_to_inherit):
def __init__(self, *tasks_to_inherit, **kw_tasks_to_inherit):
super(inherits, self).__init__()
if not tasks_to_inherit:
raise TypeError("tasks_to_inherit cannot be empty")

if not tasks_to_inherit and not kw_tasks_to_inherit:
raise TypeError("tasks_to_inherit or kw_tasks_to_inherit must contain at least one task")
if tasks_to_inherit and kw_tasks_to_inherit:
raise TypeError("Only one of tasks_to_inherit or kw_tasks_to_inherit may be present")
self.tasks_to_inherit = tasks_to_inherit
self.kw_tasks_to_inherit = kw_tasks_to_inherit

def __call__(self, task_that_inherits):
# Get all parameter objects from each of the underlying tasks
for task_to_inherit in self.tasks_to_inherit:
task_iterator = self.tasks_to_inherit or self.kw_tasks_to_inherit.values()
for task_to_inherit in task_iterator:
for param_name, param_obj in task_to_inherit.get_params():
# Check if the parameter exists in the inheriting task
if not hasattr(task_that_inherits, param_name):
# If not, add it to the inheriting task
setattr(task_that_inherits, param_name, param_obj)

# Modify task_that_inherits by adding methods
def clone_parent(_self, **kwargs):
return _self.clone(cls=self.tasks_to_inherit[0], **kwargs)
task_that_inherits.clone_parent = clone_parent

def clone_parents(_self, **kwargs):
return [
_self.clone(cls=task_to_inherit, **kwargs)
for task_to_inherit in self.tasks_to_inherit
]
task_that_inherits.clone_parents = clone_parents
# Handle unnamed tasks as a list, named as a dictionary
if self.tasks_to_inherit:
def clone_parent(_self, **kwargs):
return _self.clone(cls=self.tasks_to_inherit[0], **kwargs)
task_that_inherits.clone_parent = clone_parent

def clone_parents(_self, **kwargs):
return [
_self.clone(cls=task_to_inherit, **kwargs)
for task_to_inherit in self.tasks_to_inherit
]
task_that_inherits.clone_parents = clone_parents
elif self.kw_tasks_to_inherit:
# Even if there is just one named task, return a dictionary
def clone_parents(_self, **kwargs):
return {
task_name: _self.clone(cls=task_to_inherit, **kwargs)
for task_name, task_to_inherit in self.kw_tasks_to_inherit.items()
}
task_that_inherits.clone_parents = clone_parents

return task_that_inherits

Expand All @@ -318,15 +332,14 @@ class requires:

"""

def __init__(self, *tasks_to_require):
def __init__(self, *tasks_to_require, **kw_tasks_to_require):
super(requires, self).__init__()
if not tasks_to_require:
raise TypeError("tasks_to_require cannot be empty")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This error isn't required as inherits will be called, and inherits throws the same errors.


self.tasks_to_require = tasks_to_require
self.kw_tasks_to_require = kw_tasks_to_require

def __call__(self, task_that_requires):
task_that_requires = inherits(*self.tasks_to_require)(task_that_requires)
task_that_requires = inherits(*self.tasks_to_require, **self.kw_tasks_to_require)(task_that_requires)

# Modify task_that_requires by adding requires method.
# If only one task is required, this single task is returned.
Expand Down
29 changes: 29 additions & 0 deletions test/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,20 @@ def requires(self):
self.assertEqual(str(child_task), 'blah.ChildTask(my_param=hello)')
self.assertIn(ParentTask(my_param='hello'), luigi.task.flatten(child_task.requires()))

def test_task_ids_using_inherits_kwargs(self):
class ParentTask(luigi.Task):
my_param = luigi.Parameter()
luigi.namespace('blah')

@inherits(parent=ParentTask)
class ChildTask(luigi.Task):
def requires(self):
return self.clone(ParentTask)
luigi.namespace('')
child_task = ChildTask(my_param='hello')
self.assertEqual(str(child_task), 'blah.ChildTask(my_param=hello)')
self.assertIn(ParentTask(my_param='hello'), luigi.task.flatten(child_task.requires()))

def _setup_parent_and_child_inherits(self):
class ParentTask(luigi.Task):
my_parameter = luigi.Parameter()
Expand Down Expand Up @@ -174,3 +188,18 @@ def test_requires_has_effect_MRO(self):
ChildTask = self._setup_requires_inheritence()
self.assertNotEqual(str(ChildTask.__mro__[0]),
str(ChildTask.__mro__[1]))

def test_kwargs_requires_gives_named_inputs(self):
class ParentTask(RunOnceTask):
def output(self):
return "Target"

@requires(parent_1=ParentTask, parent_2=ParentTask)
class ChildTask(RunOnceTask):
resulting_input = 'notset'

def run(self):
self.__class__.resulting_input = self.input()

self.assertTrue(self.run_locally_split('ChildTask'))
self.assertEqual(ChildTask.resulting_input, {'parent_1': 'Target', 'parent_2': 'Target'})