Skip to content

Commit

Permalink
issue apache#26499 fixed _ZipResult length bug
Browse files Browse the repository at this point in the history
  • Loading branch information
rjmcginness committed Sep 23, 2022
1 parent 55d1146 commit 55a7445
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions airflow/models/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from airflow.utils.context import Context
from airflow.utils.edgemodifier import EdgeModifier
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.types import NOTSET
from airflow.utils.types import NOTSET, ArgNotSet

if TYPE_CHECKING:
from airflow.models.dag import DAG
Expand Down Expand Up @@ -315,7 +315,7 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
task_id = self.operator.task_id
result = context["ti"].xcom_pull(task_ids=task_id, key=str(self.key), default=NOTSET, session=session)
if result is not NOTSET:
if not isinstance(result, ArgNotSet):
return result
if self.key == XCOM_RETURN_KEY:
return None
Expand Down Expand Up @@ -388,6 +388,9 @@ class _ZipResult(Sequence):
def __init__(self, values: Sequence[Sequence | dict], *, fillvalue: Any = NOTSET) -> None:
self.values = values
self.fillvalue = fillvalue
# use the generator here, rather than in __len__ to improve efficiency
lengths = (len(v) for v in self.values)
self.length = min(lengths) if isinstance(self.fillvalue, ArgNotSet) else max(lengths)

@staticmethod
def _get_or_fill(container: Sequence | dict, index: Any, fillvalue: Any) -> Any:
Expand All @@ -402,10 +405,7 @@ def __getitem__(self, index: Any) -> Any:
return tuple(self._get_or_fill(value, index, self.fillvalue) for value in self.values)

def __len__(self) -> int:
lengths = (len(v) for v in self.values)
if self.fillvalue is NOTSET:
return min(lengths)
return max(lengths)
return self.length


class ZipXComArg(XComArg):
Expand All @@ -426,13 +426,13 @@ def __repr__(self) -> str:
args_iter = iter(self.args)
first = repr(next(args_iter))
rest = ", ".join(repr(arg) for arg in args_iter)
if self.fillvalue is NOTSET:
if isinstance(self.fillvalue, ArgNotSet):
return f"{first}.zip({rest})"
return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})"

def _serialize(self) -> dict[str, Any]:
args = [serialize_xcom_arg(arg) for arg in self.args]
if self.fillvalue is NOTSET:
if isinstance(self.fillvalue, ArgNotSet):
return {"args": args}
return {"args": args, "fillvalue": self.fillvalue}

Expand All @@ -452,7 +452,7 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
ready_lengths = [length for length in all_lengths if length is not None]
if len(ready_lengths) != len(self.args):
return None # If any of the referenced XComs is not ready, we are not ready either.
if self.fillvalue is NOTSET:
if isinstance(self.fillvalue, ArgNotSet):
return min(ready_lengths)
return max(ready_lengths)

Expand Down

0 comments on commit 55a7445

Please sign in to comment.