From 55a7445bf724a06bf6beda81b7161d94cf8400d2 Mon Sep 17 00:00:00 2001 From: Robert J McGinness Date: Thu, 22 Sep 2022 01:03:13 -0400 Subject: [PATCH] issue #26499 fixed _ZipResult length bug --- airflow/models/xcom_arg.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 2fb60195ef911..b5e0a645452b8 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -29,7 +29,7 @@ from airflow.utils.context import Context from airflow.utils.edgemodifier import EdgeModifier from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.types import NOTSET +from airflow.utils.types import NOTSET, ArgNotSet if TYPE_CHECKING: from airflow.models.dag import DAG @@ -315,7 +315,7 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any: task_id = self.operator.task_id result = context["ti"].xcom_pull(task_ids=task_id, key=str(self.key), default=NOTSET, session=session) - if result is not NOTSET: + if not isinstance(result, ArgNotSet): return result if self.key == XCOM_RETURN_KEY: return None @@ -388,6 +388,9 @@ class _ZipResult(Sequence): def __init__(self, values: Sequence[Sequence | dict], *, fillvalue: Any = NOTSET) -> None: self.values = values self.fillvalue = fillvalue + # use the generator here, rather than in __len__ to improve efficiency + lengths = (len(v) for v in self.values) + self.length = min(lengths) if isinstance(self.fillvalue, ArgNotSet) else max(lengths) @staticmethod def _get_or_fill(container: Sequence | dict, index: Any, fillvalue: Any) -> Any: @@ -402,10 +405,7 @@ def __getitem__(self, index: Any) -> Any: return tuple(self._get_or_fill(value, index, self.fillvalue) for value in self.values) def __len__(self) -> int: - lengths = (len(v) for v in self.values) - if self.fillvalue is NOTSET: - return min(lengths) - return max(lengths) + return self.length class ZipXComArg(XComArg): @@ -426,13 +426,13 @@ def __repr__(self) -> str: args_iter = iter(self.args) first = repr(next(args_iter)) rest = ", ".join(repr(arg) for arg in args_iter) - if self.fillvalue is NOTSET: + if isinstance(self.fillvalue, ArgNotSet): return f"{first}.zip({rest})" return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})" def _serialize(self) -> dict[str, Any]: args = [serialize_xcom_arg(arg) for arg in self.args] - if self.fillvalue is NOTSET: + if isinstance(self.fillvalue, ArgNotSet): return {"args": args} return {"args": args, "fillvalue": self.fillvalue} @@ -452,7 +452,7 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: ready_lengths = [length for length in all_lengths if length is not None] if len(ready_lengths) != len(self.args): return None # If any of the referenced XComs is not ready, we are not ready either. - if self.fillvalue is NOTSET: + if isinstance(self.fillvalue, ArgNotSet): return min(ready_lengths) return max(ready_lengths)