Skip to content

Commit

Permalink
test(dags): Add more example DAG tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasfarias committed May 4, 2022
1 parent 0a1284d commit 67de5e1
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 7 deletions.
2 changes: 1 addition & 1 deletion examples/basic_dag.py
Expand Up @@ -3,7 +3,6 @@

from airflow import DAG
from airflow.utils.dates import days_ago

from airflow_dbt_python.operators.dbt import DbtRunOperator

with DAG(
Expand All @@ -12,6 +11,7 @@
start_date=days_ago(1),
catchup=False,
dagrun_timeout=dt.timedelta(minutes=60),
default_args={"retries": 2},
) as dag:
dbt_run = DbtRunOperator(
task_id="dbt_run_hourly",
Expand Down
28 changes: 26 additions & 2 deletions tests/conftest.py
Expand Up @@ -58,7 +58,8 @@
MODEL_2 = """
{{ config(
materialized="table",
schema="another"
schema="another",
tags=["deprecated"],
) }}
SELECT
Expand All @@ -69,7 +70,8 @@
MODEL_3 = """
{{ config(
materialized="incremental",
schema="a_schema"
schema="a_schema",
tags=["hourly"],
) }}
SELECT
Expand All @@ -86,6 +88,7 @@
MODEL_4 = """
{{ config(
materialized="view",
tags=["hourly"],
) }}
{% set l = ("a", "b", "c") %}
Expand Down Expand Up @@ -521,3 +524,24 @@ def test_files(tmp_path_factory, dbt_project_file):
f1.unlink()
f2.unlink()
f3.unlink()


def pytest_addoption(parser):
    """Register the ``--run-integration`` command line flag.

    The flag is a boolean switch (``store_true``) that defaults to ``False``.

    Args:
        parser: Object exposing ``addoption``; the option is registered on it.
    """
    parser.addoption(
        "--run-integration",
        action="store_true",
        default=False,
        help="run integration tests",
    )


def pytest_collection_modifyitems(config, items):
    """Skip integration tests unless ``--run-integration`` was given.

    Args:
        config: Object whose ``getoption`` reports the value of
            ``--run-integration``.
        items: Collected test items. Every item that has ``"integration"``
            among its keywords gets a skip marker added in place.
    """
    if config.getoption("--run-integration"):
        # Flag given: leave the collected items untouched.
        return

    skip_integration = pytest.mark.skip(
        reason="need --run-integration to run integration tests"
    )
    for item in items:
        if "integration" in item.keywords:
            item.add_marker(skip_integration)
129 changes: 125 additions & 4 deletions tests/dags/test_dbt_dags.py
@@ -1,4 +1,5 @@
import datetime as dt
import typing

import pendulum
import pytest
Expand All @@ -14,16 +15,21 @@
DbtBaseOperator,
DbtRunOperator,
DbtSeedOperator,
DbtSourceFreshnessOperator,
DbtTestOperator,
)

DATA_INTERVAL_START = pendulum.datetime(2022, 1, 1, tz="UTC")
DATA_INTERVAL_END = DATA_INTERVAL_START + dt.timedelta(hours=1)


def test_dags_loaded():
@pytest.fixture(scope="session")
def dagbag():
    """Provide a session-scoped ``DagBag`` loaded from the ``examples/`` folder.

    The bag is built with ``include_examples=False``. Because the fixture is
    session-scoped, the same instance is reused by every test requesting it.
    """
    # Return the bag directly instead of binding a local that shadows the
    # fixture's own name.
    return DagBag(dag_folder="examples/", include_examples=False)


def test_dags_loaded(dagbag):
assert dagbag.import_errors == {}

for dag_id in dagbag.dag_ids:
Expand Down Expand Up @@ -129,6 +135,10 @@ def taskflow_dag(
catchup=False,
schedule_interval=None,
tags=["taskflow", "dbt"],
default_args={
"retries": 3,
"on_failure_callback": lambda _: print("Failed"),
},
)
def generate_dag():
@task
Expand Down Expand Up @@ -200,6 +210,11 @@ def test_dbt_operators_in_taskflow_dag(taskflow_dag, dbt_project_file, profiles_
ti.run(ignore_ti_state=True)

assert ti.state == TaskInstanceState.SUCCESS
assert ti.task.retries == taskflow_dag.default_args["retries"]
assert (
ti.task.on_failure_callback
== taskflow_dag.default_args["on_failure_callback"]
)

if isinstance(ti.task, DbtBaseOperator):
assert ti.task.profiles_dir == str(profiles_file.parent)
Expand All @@ -217,9 +232,23 @@ def test_dbt_operators_in_taskflow_dag(taskflow_dag, dbt_project_file, profiles_
)


def test_example_basic_dag(dagbag):
def assert_dbt_results(
    results, expected_results: "typing.Dict[typing.Union[RunStatus, TestStatus], int]"
):
    """Assert that dbt results match the expected per-status counts.

    The annotation is kept as a string so it is never evaluated at definition
    time: subscripting the builtin ``dict`` raises ``TypeError`` before
    Python 3.9.

    Args:
        results: Mapping with a ``"results"`` entry holding a sequence of
            mappings, each with a ``"status"`` entry.
        expected_results: Expected number of results for each status.

    Raises:
        AssertionError: If the total number of results, or the number of
            results for any expected status, does not match.
    """
    statuses = [result["status"] for result in results["results"]]

    assert len(statuses) == sum(
        expected_results.values()
    ), "Expected number of results doesn't match"

    for state, count in expected_results.items():
        # Compare with ``==`` rather than hashing the statuses, so the check
        # relies only on equality between a result status and the expected key.
        found = sum(status == state for status in statuses)
        assert (
            found == count
        ), f"Expected {count} results with status {state}, found {found}"


def test_example_basic_dag(
dagbag, dbt_project_file, profiles_file, model_files, seed_files
):
"""Test the example basic DAG."""
dag = dagbag.get_dag(dag_id="example_basic_dag")
dag = dagbag.get_dag(dag_id="example_basic_dbt")

assert dag is not None
assert len(dag.tasks) == 1
Expand All @@ -229,11 +258,41 @@ def test_example_basic_dag(dagbag):
assert dbt_run.select == ["+tag:hourly"]
assert dbt_run.exclude == ["tag:deprecated"]
assert dbt_run.full_refresh is False
assert dbt_run.retries == 2

dbt_run.project_dir = dbt_project_file.parent
dbt_run.profiles_dir = profiles_file.parent
dbt_run.target = "test"
dbt_run.profile = "default"

dagrun = dag.create_dagrun(
state=DagRunState.RUNNING,
execution_date=dag.start_date,
data_interval=(dag.start_date, DATA_INTERVAL_END),
start_date=DATA_INTERVAL_END,
run_type=DagRunType.MANUAL,
)

ti = dagrun.get_task_instance(task_id="dbt_run_hourly")
ti.task = dbt_run

ti.run(ignore_ti_state=True)

assert ti.state == TaskInstanceState.SUCCESS

results = ti.xcom_pull(
task_ids="dbt_run_hourly",
key="return_value",
)
expected = {
RunStatus.Success: 2,
}
assert_dbt_results(results, expected)


def test_example_dbt_project_in_s3_dag(dagbag):
"""Test the example basic DAG."""
dag = dagbag.get_dag(dag_id="dbt_project_in_s3_dag")
dag = dagbag.get_dag(dag_id="example_basic_dbt_run_with_s3")

assert dag is not None
assert len(dag.tasks) == 2
Expand All @@ -243,3 +302,65 @@ def test_example_dbt_project_in_s3_dag(dagbag):
assert dbt_run.select == ["+tag:hourly"]
assert dbt_run.exclude == ["tag:deprecated"]
assert dbt_run.full_refresh is False


def test_example_complete_dbt_workflow_dag(
    dagbag,
    dbt_project_file,
    profiles_file,
    model_files,
    seed_files,
    singular_tests_files,
    generic_tests_files,
):
    """Test the example complete dbt workflow DAG.

    Every task of the DAG is pointed at the test project and profiles, run,
    and expected to succeed. For every task that is not a
    ``DbtSourceFreshnessOperator``, the value pulled from XCom is checked
    against the expected per-status counts.

    ``model_files``, ``seed_files``, ``singular_tests_files`` and
    ``generic_tests_files`` are requested but never read here; presumably they
    create the project files the tasks need -- confirm in conftest.
    """
    dag = dagbag.get_dag(dag_id="example_complete_dbt_workflow")

    assert dag is not None
    assert len(dag.tasks) == 5

    # Expected result counts keyed by task id. A lookup table (instead of an
    # if/elif chain) lets an unknown task id fail loudly rather than silently
    # reusing the expectation left over from the previous loop iteration.
    expected_by_task = {
        "dbt_run_incremental_hourly": {RunStatus.Success: 1},
        "dbt_seed": {RunStatus.Success: 2},
        "dbt_run_hourly": {RunStatus.Success: 2},
        "dbt_test": {TestStatus.Pass: 7},
    }

    dagrun = dag.create_dagrun(
        state=DagRunState.RUNNING,
        execution_date=dag.start_date,
        data_interval=(dag.start_date, DATA_INTERVAL_END),
        start_date=DATA_INTERVAL_END,
        run_type=DagRunType.MANUAL,
    )

    for task in dag.tasks:
        # Redirect the example task to the test project and profiles.
        task.project_dir = dbt_project_file.parent
        task.profiles_dir = profiles_file.parent
        task.target = "test"
        task.profile = "default"

        ti = dagrun.get_task_instance(task_id=task.task_id)
        ti.task = task

        ti.run(ignore_ti_state=True)

        assert ti.state == TaskInstanceState.SUCCESS

        if isinstance(task, DbtSourceFreshnessOperator):
            # Source freshness results are not checked by this test.
            continue

        assert (
            task.task_id in expected_by_task
        ), f"No expected results defined for task {task.task_id}"

        results = ti.xcom_pull(
            task_ids=task.task_id,
            key="return_value",
        )
        assert_dbt_results(results, expected_by_task[task.task_id])

0 comments on commit 67de5e1

Please sign in to comment.