Skip to content

Commit

Permalink
feat(dag): add support for Grouped DAGs
Browse files Browse the repository at this point in the history
  • Loading branch information
mostaphaRoudsari committed Oct 8, 2022
1 parent 1d1b4aa commit 3b51eb1
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pollination_dsl/dag/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .inputs import Inputs # expose for easy import
from .outputs import Outputs
from .base import DAG
from .base import DAG, GroupedDAG
from .task import task
37 changes: 36 additions & 1 deletion pollination_dsl/dag/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from dataclasses import dataclass
import inspect
import warnings

from queenbee.recipe.dag import DAG as QBDAG

from queenbee.recipe.recipe import TemplateFunction
from ..common import _BaseClass


Expand Down Expand Up @@ -111,3 +112,37 @@ def _dependencies(self):
dependencies[key].append(v)

return dependencies



@dataclass
class GroupedDAG(DAG):
    """A grouped DAG is a special DAG that will be executed on the same Pod.

    A grouped DAG is useful to group similar small tasks together to run them
    faster. Compared to the default DAG it has two restrictions, both enforced
    when the ``queenbee`` property is accessed:

    * tasks cannot have a for loop;
    * tasks cannot have parameter outputs - only file or folder outputs are
      allowed.
    """

    @property
    def queenbee(self) -> QBDAG:
        """Return the Queenbee DAG of the parent class, marked as grouped.

        The DAG built by the parent class is annotated with
        ``__is_grouped__ = True`` and every task in it is validated against
        the restrictions of a grouped DAG.

        Returns:
            The annotated Queenbee DAG.

        Raises:
            ValueError: If a task in the DAG has a loop.
            AssertionError: If a task in the DAG has parameter returns.

        Warns:
            UserWarning: If the DAG has exactly one task, in which case a
                standard DAG is likely a better fit.
        """
        dag = super().queenbee
        dag.annotations['__is_grouped__'] = True

        if len(dag.tasks) == 1:
            warnings.warn(
                'A grouped DAG usually has more than one task. Consider using a '
                f'standard DAG for "{dag.name}"'
            )

        for task in dag.tasks:
            if task.loop:
                raise ValueError(
                    f'Found a loop object in Task "{task.name}" in GroupedDAG '
                    f'{dag.name}. Either remove the for loop or use a standard DAG.'
                )

            # We can technically support this but I'm keeping it simple for now.
            # Raised explicitly instead of using an ``assert`` statement so the
            # check is not stripped when Python runs with -O. The exception type
            # is kept as AssertionError for backward compatibility.
            if task.parameter_returns:
                raise AssertionError(
                    f'Found an invalid task "{task.name}" with parameter output in '
                    f'GroupedDAG "{dag.name}". Only file or folder outputs are '
                    f'allowed in GroupedDAG.\n{task.parameter_returns}'
                )

        return dag
2 changes: 1 addition & 1 deletion pollination_dsl/dag/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _get_task_returns(func) -> NamedTuple:
pattern = r'[\'\"]from[\'\"]\s*:\s*.*\._outputs\.(\S*)\s*[,}]'
parent = func.__name__.replace('_', '-')
src = inspect.getsource(func)
# remove the last } which happens in case of parameters input. Somene who
# remove the last } which happens in case of parameters input. Someone who
# knows regex better than I do should be able to fix this by changing the pattern
# here is an example to recreate the issue.
# return [
Expand Down

0 comments on commit 3b51eb1

Please sign in to comment.