From 3b51eb195f8b14af218cba6562cd47d1f7459e3c Mon Sep 17 00:00:00 2001
From: mostaphaRoudsari
Date: Thu, 11 Aug 2022 15:25:31 -0400
Subject: [PATCH] feat(dag): add support for Grouped DAGs

---
 pollination_dsl/dag/__init__.py |  2 +-
 pollination_dsl/dag/base.py     | 44 ++++++++++++++++++++++++++++++++++++++++-
 pollination_dsl/dag/task.py     |  2 +-
 3 files changed, 45 insertions(+), 3 deletions(-)

diff --git a/pollination_dsl/dag/__init__.py b/pollination_dsl/dag/__init__.py
index 2567da7..316cfeb 100644
--- a/pollination_dsl/dag/__init__.py
+++ b/pollination_dsl/dag/__init__.py
@@ -1,4 +1,4 @@
 from .inputs import Inputs  # expose for easy import
 from .outputs import Outputs
-from .base import DAG
+from .base import DAG, GroupedDAG
 from .task import task
diff --git a/pollination_dsl/dag/base.py b/pollination_dsl/dag/base.py
index a131c85..513fb98 100644
--- a/pollination_dsl/dag/base.py
+++ b/pollination_dsl/dag/base.py
@@ -1,8 +1,9 @@
 from dataclasses import dataclass
 import inspect
+import warnings
 
 from queenbee.recipe.dag import DAG as QBDAG
-
+from queenbee.recipe.recipe import TemplateFunction
 
 from ..common import _BaseClass
 
@@ -111,3 +112,44 @@ def _dependencies(self):
                 dependencies[key].append(v)
 
         return dependencies
+
+
+@dataclass
+class GroupedDAG(DAG):
+    """A grouped DAG is a special DAG that will be executed on the same Pod.
+
+    Grouped DAG is useful to group similar small tasks together to run them faster.
+    Unlike the default DAG, the tasks in group DAG cannot have a for loop.
+    """
+
+    @property
+    def queenbee(self) -> QBDAG:
+        """Queenbee DAG flagged as grouped through its annotations.
+
+        Raises:
+            ValueError: If a task has a loop or returns a parameter output.
+        """
+        dag = super().queenbee
+        dag.annotations['__is_grouped__'] = True
+        if len(dag.tasks) == 1:
+            warnings.warn(
+                'A grouped DAG usually has more than one task. Consider using a '
+                f'standard DAG for "{dag.name}"'
+            )
+        for task in dag.tasks:
+            if task.loop:
+                raise ValueError(
+                    f'Found a loop object in Task "{task.name}" in GroupedDAG '
+                    f'{dag.name}. Either remove the for loop or use a standard DAG.'
+                )
+
+            # We can technically support this but I'm keeping it simple for now.
+            # Raise instead of assert so the check survives ``python -O``.
+            if task.parameter_returns:
+                raise ValueError(
+                    f'Found an invalid task "{task.name}" with parameter output in '
+                    f'GroupedDAG "{dag.name}". Only file or folder outputs are allowed '
+                    f'in GroupedDAG.\n{task.parameter_returns}'
+                )
+
+        return dag
diff --git a/pollination_dsl/dag/task.py b/pollination_dsl/dag/task.py
index 7a938eb..8dd8a55 100644
--- a/pollination_dsl/dag/task.py
+++ b/pollination_dsl/dag/task.py
@@ -165,7 +165,7 @@ def _get_task_returns(func) -> NamedTuple:
     pattern = r'[\'\"]from[\'\"]\s*:\s*.*\._outputs\.(\S*)\s*[,}]'
     parent = func.__name__.replace('_', '-')
     src = inspect.getsource(func)
-    # remove the last } which happens in case of parameters input. Somene who
+    # remove the last } which happens in case of parameters input. Someone who
     # knows regex better than I do should be able to fix this by changing the pattern
     # here is an example to recreate the issue.
     # return [