Skip to content

Commit

Permalink
[AIP-34] TaskGroup: A UI task grouping concept as an alternative to SubDagOperator apache#10153
Browse files Browse the repository at this point in the history

- Introduce TaskMixin (apache#10930)
- Test updates
  • Loading branch information
yuqian90 committed Sep 19, 2020
1 parent a82ed64 commit b6a80ec
Show file tree
Hide file tree
Showing 16 changed files with 1,935 additions and 177 deletions.
58 changes: 58 additions & 0 deletions airflow/example_dags/example_task_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Example DAG demonstrating the usage of the TaskGroup."""

from airflow.models.dag import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.dates import days_ago
from airflow.utils.task_group import TaskGroup

# [START howto_task_group]
with DAG(dag_id="example_task_group", start_date=days_ago(2)) as dag:
    # Single entry task; wired in front of section_1 at the bottom of the DAG.
    start = DummyOperator(task_id="start")

    # [START howto_task_group_section_1]
    # Operators instantiated inside a TaskGroup context are added to that
    # group. The short task ids below are reused in section_2; this presumably
    # works because the group derives the final task_id -- confirm against
    # TaskGroup.child_id.
    with TaskGroup("section_1", tooltip="Tasks for section_1") as section_1:
        task_1 = DummyOperator(task_id="task_1")
        task_2 = DummyOperator(task_id="task_2")
        task_3 = DummyOperator(task_id="task_3")

        # Fan out: task_1 runs first, then task_2 and task_3.
        task_1 >> [task_2, task_3]
    # [END howto_task_group_section_1]

    # [START howto_task_group_section_2]
    with TaskGroup("section_2", tooltip="Tasks for section_2") as section_2:
        task_1 = DummyOperator(task_id="task_1")

        # [START howto_task_group_inner_section_2]
        # TaskGroups can be nested: inner_section_2 is a child of section_2.
        with TaskGroup("inner_section_2", tooltip="Tasks for inner_section2") as inner_section_2:
            task_2 = DummyOperator(task_id="task_2")
            task_3 = DummyOperator(task_id="task_3")
            task_4 = DummyOperator(task_id="task_4")

            # Fan in: task_4 waits for both task_2 and task_3.
            [task_2, task_3] >> task_4
        # [END howto_task_group_inner_section_2]

    # [END howto_task_group_section_2]

    end = DummyOperator(task_id="end")

    # Whole groups can be chained with >> just like individual tasks.
    start >> section_1 >> section_2 >> end
# [END howto_task_group]
79 changes: 46 additions & 33 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@

from abc import ABCMeta, abstractmethod
from datetime import datetime, timedelta
from typing import Any, Callable, ClassVar, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union

from typing import (
Any, Callable, ClassVar, Dict, FrozenSet, Iterable, List, Optional, Sequence, Set, Type, Union
)

import attr
from cached_property import cached_property
Expand All @@ -43,6 +44,7 @@
from airflow.models.dag import DAG
from airflow.models.pool import Pool
from airflow.models.taskinstance import TaskInstance, clear_task_instances
from airflow.models.taskmixin import TaskMixin
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep
Expand All @@ -59,7 +61,7 @@


@functools.total_ordering
class BaseOperator(LoggingMixin):
class BaseOperator(LoggingMixin, TaskMixin):
"""
Abstract base class for all operators. Since operators create objects that
become nodes in the dag, BaseOperator contains many recursive methods for
Expand Down Expand Up @@ -324,9 +326,11 @@ def __init__(
do_xcom_push=True, # type: bool
inlets=None, # type: Optional[Dict]
outlets=None, # type: Optional[Dict]
task_group=None,
*args,
**kwargs
):
from airflow.utils.task_group import TaskGroupContext

if args or kwargs:
# TODO remove *args and **kwargs in Airflow 2.0
Expand All @@ -341,6 +345,11 @@ def __init__(
)
validate_key(task_id)
self.task_id = task_id
self.label = task_id
task_group = task_group or TaskGroupContext.get_current_task_group(dag)
if task_group:
self.task_id = task_group.child_id(task_id)
task_group.add(self)
self.owner = owner
self.email = email
self.email_on_retry = email_on_retry
Expand Down Expand Up @@ -472,13 +481,14 @@ def __hash__(self):
hash_components.append(repr(val))
return hash(tuple(hash_components))

# Composing Operators -----------------------------------------------

def __rshift__(self, other):
"""
Implements Self >> Other == self.set_downstream(other)
If "Other" is a DAG, the DAG is assigned to the Operator.
NOTE: This method is supposed to have moved to TaskMixin. But this override is needed
here because of this special treatment for DAG. It can be removed in Airflow 2.0.
"""
if isinstance(other, DAG):
# if this dag is already assigned, do nothing
Expand All @@ -494,6 +504,9 @@ def __lshift__(self, other):
Implements Self << Other == self.set_upstream(other)
If "Other" is a DAG, the DAG is assigned to the Operator.
NOTE: This method is supposed to have moved to TaskMixin. But this override is needed
here because of this special treatment for DAG. It can be removed in Airflow 2.0.
"""
if isinstance(other, DAG):
# if this dag is already assigned, do nothing
Expand All @@ -504,24 +517,6 @@ def __lshift__(self, other):
self.set_upstream(other)
return other

def __rrshift__(self, other):
    """
    Called for [DAG] >> [Operator] because DAGs don't have
    __rshift__ operators.

    Delegates to ``self.__lshift__(other)`` and returns ``self`` (not the
    result of ``__lshift__``), so the operator is what the expression
    evaluates to.
    """
    self.__lshift__(other)
    return self

def __rlshift__(self, other):
    """
    Called for [DAG] << [Operator] because DAGs don't have
    __lshift__ operators.

    Delegates to ``self.__rshift__(other)`` and returns ``self`` (not the
    result of ``__rshift__``), so the operator is what the expression
    evaluates to.
    """
    self.__rshift__(other)
    return self

# /Composing Operators ---------------------------------------------

@property
def dag(self):
"""
Expand Down Expand Up @@ -989,12 +984,30 @@ def add_only_new(self, item_set, item):
else:
item_set.add(item)

def _set_relatives(self, task_or_task_list, upstream=False):
"""Sets relatives for the task."""
try:
task_list = list(task_or_task_list)
except TypeError:
task_list = [task_or_task_list]
@property
def roots(self):
    """
    Return the entry points of this node as a list. Required by TaskMixin.

    A single operator is its own only root, so this is always ``[self]``.
    """
    return [self]

@property
def leaves(self):
    """
    Return the exit points of this node as a list. Required by TaskMixin.

    A single operator is its own only leaf, so this is always ``[self]``.
    """
    return [self]

def _set_relatives(
self,
task_or_task_list, # type: Union[TaskMixin, Sequence[TaskMixin]]
upstream=False,
):
"""Sets relatives for the task or task list."""
if not isinstance(task_or_task_list, Sequence):
task_or_task_list = [task_or_task_list]

task_list = [] # type: List["BaseOperator"]
for task_object in task_or_task_list:
task_object.update_relative(self, not upstream)
relatives = task_object.leaves if upstream else task_object.roots
task_list.extend(relatives)

for task in task_list:
if not isinstance(task, BaseOperator):
Expand All @@ -1005,8 +1018,8 @@ def _set_relatives(self, task_or_task_list, upstream=False):
# relationships can only be set if the tasks share a single DAG. Tasks
# without a DAG are assigned to that DAG.
dags = {
task._dag.dag_id: task._dag # pylint: disable=protected-access
for task in [self] + task_list if task.has_dag()}
task._dag.dag_id: task._dag # type: ignore # pylint: disable=protected-access,no-member
for task in self.roots + task_list if task.has_dag()} # pylint: disable=no-member

if len(dags) > 1:
raise AirflowException(
Expand Down Expand Up @@ -1036,14 +1049,14 @@ def _set_relatives(self, task_or_task_list, upstream=False):
def set_downstream(self, task_or_task_list):
    """
    Set a task or a task list to be directly downstream from the current
    task. Required by TaskMixin.

    :param task_or_task_list: a single task-like object or a sequence of
        them; passed through unchanged to ``_set_relatives``.
    """
    self._set_relatives(task_or_task_list, upstream=False)

def set_upstream(self, task_or_task_list):
    """
    Set a task or a task list to be directly upstream from the current
    task. Required by TaskMixin.

    :param task_or_task_list: a single task-like object or a sequence of
        them; passed through unchanged to ``_set_relatives``.
    """
    self._set_relatives(task_or_task_list, upstream=True)

Expand Down
50 changes: 44 additions & 6 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import warnings
from collections import OrderedDict, defaultdict
from datetime import timedelta, datetime
from typing import TYPE_CHECKING, Callable, Dict, FrozenSet, Iterable, List, Optional, Type, Union
from typing import Callable, Dict, FrozenSet, Iterable, List, Optional, Type, Union

import jinja2
import pendulum
Expand Down Expand Up @@ -64,9 +64,6 @@
from airflow.utils.sqlalchemy import UtcDateTime, Interval
from airflow.utils.state import State

if TYPE_CHECKING:
from airflow.models.baseoperator import BaseOperator # Avoid circular dependency

install_aliases()

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -245,6 +242,9 @@ def __init__(
jinja_environment_kwargs=None, # type: Optional[Dict]
tags=None, # type: Optional[List[str]]
):
from airflow.utils.task_group import TaskGroup
from airflow.models.baseoperator import BaseOperator

self.user_defined_macros = user_defined_macros
self.user_defined_filters = user_defined_filters
self.default_args = copy.deepcopy(default_args or {})
Expand Down Expand Up @@ -329,6 +329,7 @@ def __init__(

self.jinja_environment_kwargs = jinja_environment_kwargs
self.tags = tags
self._task_group = TaskGroup.create_root(self)

def __repr__(self):
return "<DAG: {self.dag_id}>".format(self=self)
Expand Down Expand Up @@ -591,6 +592,10 @@ def filepath(self):
fn = fn.replace(os.path.dirname(__file__) + '/', '')
return fn

@property
def task_group(self):
    """Return the root TaskGroup stored on this DAG as ``_task_group``."""
    return self._task_group

@property
def folder(self):
"""Folder location of where the DAG object is instantiated."""
Expand Down Expand Up @@ -1221,6 +1226,7 @@ def sub_dag(self, task_regex, include_downstream=False,
based on a regex that should match one or many tasks, and includes
upstream and downstream neighbours based on the flag passed.
"""
from airflow.models.baseoperator import BaseOperator

# deep-copying self.task_dict takes a long time, and we don't want all
# the tasks anyway, so we copy the tasks manually later
Expand All @@ -1242,9 +1248,38 @@ def sub_dag(self, task_regex, include_downstream=False,
# Make sure to not recursively deepcopy the dag while copying the task
dag.task_dict = {t.task_id: copy.deepcopy(t, {id(t.dag): dag})
for t in regex_match + also_include}

# Remove tasks not included in the subdag from task_group
def remove_excluded(group):
    """
    Recursively prune ``group`` so it only references tasks kept in the subdag.

    Operator children whose task_id is missing from ``dag.task_dict`` are
    dropped; the ones that remain are re-pointed at the copies stored in
    ``dag.task_dict``. Child groups are pruned recursively and removed
    entirely once they end up empty.

    :param group: the TaskGroup-like node whose ``children`` mapping is
        modified in place.
    """
    # Iterate over a snapshot because entries are popped/replaced in the loop.
    for child in list(group.children.values()):
        if isinstance(child, BaseOperator):
            if child.task_id not in dag.task_dict:
                group.children.pop(child.task_id)
            else:
                # The tasks in the subdag are a copy of tasks in the original dag
                # so update the reference in the TaskGroups too.
                group.children[child.task_id] = dag.task_dict[child.task_id]
        else:
            remove_excluded(child)

            # Remove this TaskGroup if it doesn't contain any tasks in this subdag
            if not child.children:
                group.children.pop(child.group_id)

remove_excluded(dag.task_group)

# Removing upstream/downstream references to tasks and TaskGroups that did not make
# the cut.
subdag_task_groups = dag.task_group.get_task_group_dict()
for group in subdag_task_groups.values():
group.upstream_group_ids = group.upstream_group_ids.intersection(subdag_task_groups.keys())
group.downstream_group_ids = group.downstream_group_ids.intersection(subdag_task_groups.keys())
group.upstream_task_ids = group.upstream_task_ids.intersection(dag.task_dict.keys())
group.downstream_task_ids = group.downstream_task_ids.intersection(dag.task_dict.keys())

for t in dag.tasks:
# Removing upstream/downstream references to tasks that did not
# made the cut
# make the cut
t._upstream_task_ids = t._upstream_task_ids.intersection(dag.task_dict.keys())
t._downstream_task_ids = t._downstream_task_ids.intersection(
dag.task_dict.keys())
Expand Down Expand Up @@ -1332,7 +1367,8 @@ def add_task(self, task):
elif task.end_date and self.end_date:
task.end_date = min(task.end_date, self.end_date)

if task.task_id in self.task_dict and self.task_dict[task.task_id] is not task:
if ((task.task_id in self.task_dict and self.task_dict[task.task_id] is not task)
or task.task_id in self._task_group.used_group_ids):
# TODO: raise an error in Airflow 2.0
warnings.warn(
'The requested task could not be added to the DAG because a '
Expand All @@ -1343,6 +1379,8 @@ def add_task(self, task):
else:
self.task_dict[task.task_id] = task
task.dag = self
# Add task_id to used_group_ids to prevent group_id and task_id collisions.
self._task_group.used_group_ids.add(task.task_id)

self.task_count = len(self.task_dict)

Expand Down
Loading

0 comments on commit b6a80ec

Please sign in to comment.