Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,12 @@ jobs:
- '3.10'
- '3.9'
airflow-version:
- '3.0'
- '2.10'
- '2.9'
- '2.8'
dbt-version:
- 1.9
- 1.8
exclude:
# Incompatible combinations
- python-version: 3.12
airflow-version: '2.8'

runs-on: ubuntu-latest
steps:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tagged_release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
with:
token: ${{ secrets.GH_DEPLOY_TOKEN }}
checkName: CI
ref: ${{ github.sha }}
ref: ${{ github.ref_name }}
# Wait for one hour
timeoutSeconds: 3600
intervalSeconds: 60
Expand Down
85 changes: 67 additions & 18 deletions tests/dags/test_dbt_dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
airflow = pytest.importorskip("airflow", minversion="2.2")

from airflow import DAG, settings
from airflow import __version__ as airflow_version
from airflow.models import DagBag, DagRun
from airflow.models.serialized_dag import SerializedDagModel
from airflow.utils.state import DagRunState, TaskInstanceState
Expand All @@ -27,6 +28,41 @@

DATA_INTERVAL_START = pendulum.datetime(2022, 1, 1, tz="UTC")
DATA_INTERVAL_END = DATA_INTERVAL_START + dt.timedelta(hours=1)
AIRFLOW_MAJOR_VERSION = int(airflow_version.split(".", 1)[0])


def _create_dagrun(
parent_dag: DAG,
state: DagRunState,
logical_date: dt.datetime,
data_interval: tuple[dt.datetime, dt.datetime],
start_date: dt.datetime,
run_type: DagRunType,
) -> DagRun:
if AIRFLOW_MAJOR_VERSION >= 3:
from airflow.utils.types import DagRunTriggeredByType # type: ignore

return parent_dag.create_dagrun( # type: ignore
run_id=f"{parent_dag.dag_id}-{logical_date.isoformat()}-RUN",
state=state,
logical_date=logical_date,
data_interval=data_interval,
start_date=start_date,
conf={},
backfill_id=None,
creating_job_id=None,
run_type=run_type,
run_after=dt.datetime(1970, 1, 1, 0, 0, 0, tzinfo=dt.timezone.utc),
triggered_by=DagRunTriggeredByType.TIMETABLE,
)
else:
return parent_dag.create_dagrun(
state=state,
execution_date=logical_date, # type: ignore
data_interval=data_interval,
start_date=start_date,
run_type=run_type,
)


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -115,9 +151,10 @@ def test_dbt_operators_in_dag(
basic_dag, dbt_project_file, profiles_file, clear_dagruns
):
"""Assert DAG contains correct dbt operators when running."""
dagrun = basic_dag.create_dagrun(
dagrun = _create_dagrun(
basic_dag,
state=DagRunState.RUNNING,
execution_date=DATA_INTERVAL_START,
logical_date=DATA_INTERVAL_START,
data_interval=(DATA_INTERVAL_START, DATA_INTERVAL_END),
start_date=DATA_INTERVAL_END,
run_type=DagRunType.MANUAL,
Expand Down Expand Up @@ -216,9 +253,15 @@ def test_dbt_operators_in_taskflow_dag(
taskflow_dag, dbt_project_file, profiles_file, clear_dagruns
):
"""Assert DAG contains correct dbt operators when running."""
dagrun = taskflow_dag.create_dagrun(
if AIRFLOW_MAJOR_VERSION >= 3:
dag = DAG.from_sdk_dag(taskflow_dag)
else:
dag = taskflow_dag

dagrun = _create_dagrun(
dag,
state=DagRunState.RUNNING,
execution_date=DATA_INTERVAL_START,
logical_date=DATA_INTERVAL_START,
data_interval=(DATA_INTERVAL_START, DATA_INTERVAL_END),
start_date=DATA_INTERVAL_END,
run_type=DagRunType.MANUAL,
Expand All @@ -232,16 +275,18 @@ def test_dbt_operators_in_taskflow_dag(
"dbt_test_taskflow",
):
ti = dagrun.get_task_instance(task_id=task_id)
ti.task = taskflow_dag.get_task(task_id=task_id)
ti.task = dag.get_task(task_id=task_id)

ti.run(ignore_ti_state=True)

assert ti.state == TaskInstanceState.SUCCESS
assert ti.task.retries == taskflow_dag.default_args["retries"]
assert (
ti.task.on_failure_callback
== taskflow_dag.default_args["on_failure_callback"]
)
assert ti.task.retries == dag.default_args["retries"]

if isinstance(ti.task.on_failure_callback, list):
failure_callback = ti.task.on_failure_callback[0]
else:
failure_callback = ti.task.on_failure_callback
assert failure_callback == dag.default_args["on_failure_callback"]

if isinstance(ti.task, DbtBaseOperator):
assert ti.task.profiles_dir == str(profiles_file.parent)
Expand Down Expand Up @@ -349,9 +394,10 @@ def test_dbt_operators_in_connection_dag(
target_connection_dag, dbt_project_file, clear_dagruns
):
"""Assert DAG contains correct dbt operators when running."""
dagrun = target_connection_dag.create_dagrun(
dagrun = _create_dagrun(
target_connection_dag,
state=DagRunState.RUNNING,
execution_date=DATA_INTERVAL_START,
logical_date=DATA_INTERVAL_START,
data_interval=(DATA_INTERVAL_START, DATA_INTERVAL_END),
start_date=DATA_INTERVAL_END,
run_type=DagRunType.MANUAL,
Expand Down Expand Up @@ -414,9 +460,10 @@ def test_example_basic_dag(
dbt_run.target = "test"
dbt_run.profile = "default"

dagrun = dag.create_dagrun(
dagrun = _create_dagrun(
dag,
state=DagRunState.RUNNING,
execution_date=dag.start_date,
logical_date=dag.start_date,
data_interval=(dag.start_date, DATA_INTERVAL_END),
start_date=DATA_INTERVAL_END,
run_type=DagRunType.MANUAL,
Expand Down Expand Up @@ -460,9 +507,10 @@ def test_example_dbt_project_in_github_dag(dagbag, connection, clear_dagruns):
assert dag is not None
assert len(dag.tasks) == 3

dagrun = dag.create_dagrun(
dagrun = _create_dagrun(
dag,
state=DagRunState.RUNNING,
execution_date=dag.start_date,
logical_date=dag.start_date,
data_interval=(dag.start_date, DATA_INTERVAL_END),
start_date=DATA_INTERVAL_END,
run_type=DagRunType.MANUAL,
Expand Down Expand Up @@ -511,9 +559,10 @@ def test_example_complete_dbt_workflow_dag(
assert dag is not None
assert len(dag.tasks) == 5

dagrun = dag.create_dagrun(
dagrun = _create_dagrun(
dag,
state=DagRunState.RUNNING,
execution_date=dag.start_date,
logical_date=dag.start_date,
data_interval=(dag.start_date, DATA_INTERVAL_END),
start_date=DATA_INTERVAL_END,
run_type=DagRunType.MANUAL,
Expand Down
2 changes: 1 addition & 1 deletion tests/operators/test_dbt_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class FakeTaskInstance:
def __init__(self):
self.xcom = {}

def xcom_push(self, key, value, execution_date):
def xcom_push(self, key, value, execution_date=None):
"""Fake xcom_push method that stores the value in instance attribute."""
self.xcom[key] = (value, execution_date)

Expand Down
Loading
Loading