Skip to content

Commit

Permalink
Revert "Revert "Save scheduler execution time by caching dags (apache#30704)" (apache#31413)"
Browse files Browse the repository at this point in the history

This reverts commit e6f2117.
  • Loading branch information
potiuk committed May 19, 2023
1 parent e6f2117 commit 978a6f7
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
from collections import Counter
from dataclasses import dataclass
from datetime import datetime, timedelta
from functools import lru_cache, partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Collection, Iterable, Iterator
from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable, Iterator

from sqlalchemy import and_, func, not_, or_, text
from sqlalchemy.exc import OperationalError
Expand Down Expand Up @@ -1052,8 +1053,13 @@ def _do_scheduling(self, session: Session) -> int:
callback_tuples = self._schedule_all_dag_runs(guard, dag_runs, session)

# Send the callbacks after we commit to ensure the context is up to date when it gets run
# Cache DAG lookups to save time when scheduling many dag_runs for the same DAG
cached_get_dag: Callable[[str], DAG | None] = lru_cache()(
partial(self.dagbag.get_dag, session=session)
)
for dag_run, callback_to_run in callback_tuples:
dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
dag = cached_get_dag(dag_run.dag_id)

if not dag:
self.log.error("DAG '%s' not found in serialized_dag table", dag_run.dag_id)
continue
Expand Down Expand Up @@ -1317,8 +1323,14 @@ def _update_state(dag: DAG, dag_run: DagRun):
tags={"dag_id": dag.dag_id},
)

# Cache DAG lookups to save time when scheduling many dag_runs for the same DAG
cached_get_dag: Callable[[str], DAG | None] = lru_cache()(
partial(self.dagbag.get_dag, session=session)
)

for dag_run in dag_runs:
dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
dag = dag_run.dag = cached_get_dag(dag_run.dag_id)

if not dag:
self.log.error("DAG '%s' not found in serialized_dag table", dag_run.dag_id)
continue
Expand Down

0 comments on commit 978a6f7

Please sign in to comment.