From ff1c544aac7bf77693dcb940e05805d11d0e9944 Mon Sep 17 00:00:00 2001 From: GitHub Action Date: Sat, 16 Oct 2021 18:45:59 +0200 Subject: [PATCH 1/2] feat: Add support for new dbt build command and arguments --- airflow_dbt_python/operators/dbt.py | 66 ++++++- tests/conftest.py | 19 ++ tests/test_dbt_build.py | 290 ++++++++++++++++++++++++++++ tests/test_dbt_compile.py | 3 +- tests/test_dbt_debug.py | 1 + tests/test_dbt_list.py | 5 + tests/test_dbt_run.py | 18 -- tests/test_dbt_source.py | 15 +- 8 files changed, 395 insertions(+), 22 deletions(-) create mode 100644 tests/test_dbt_build.py diff --git a/airflow_dbt_python/operators/dbt.py b/airflow_dbt_python/operators/dbt.py index 67bed85e..7015d345 100644 --- a/airflow_dbt_python/operators/dbt.py +++ b/airflow_dbt_python/operators/dbt.py @@ -14,6 +14,8 @@ from dbt.contracts.results import RunExecutionResult, RunResult, agate from dbt.logger import log_manager from dbt.main import initialize_config_values, parse_args, track_run +from dbt.semver import VersionSpecifier +from dbt.version import installed as installed_version from airflow import AirflowException from airflow.models.baseoperator import BaseOperator @@ -308,6 +310,7 @@ def __init__( self, full_refresh: Optional[bool] = None, models: Optional[list[str]] = None, + select: Optional[list[str]] = None, fail_fast: Optional[bool] = None, threads: Optional[int] = None, exclude: Optional[list[str]] = None, @@ -527,6 +530,7 @@ class DbtLsOperator(DbtBaseOperator): "exclude", "selector", "dbt_output", + "output_keys", ] def __init__( @@ -537,6 +541,7 @@ def __init__( exclude: Optional[list[str]] = None, selector: Optional[str] = None, dbt_output: Optional[str] = None, + output_keys: Optional[list[str]] = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -546,6 +551,7 @@ def __init__( self.exclude = exclude self.selector = selector self.dbt_output = dbt_output + self.output_keys = output_keys # Convinience alias @@ -599,7 +605,9 @@ class 
DbtSourceOperator(DbtBaseOperator): def __init__( self, # Only one subcommand is currently provided - subcommand: str = "snapshot-freshness", + subcommand: str = "freshness" + if not installed_version < VersionSpecifier.from_version_string("0.21.0") + else "snapshot-freshness", select: Optional[list[str]] = None, dbt_output: Optional[Union[str, Path]] = None, threads: Optional[int] = None, @@ -617,6 +625,62 @@ def __init__( self.state = state +class DbtBuildOperator(DbtBaseOperator): + """Execute dbt build. + + The build command combines the run, test, seed, and snapshot commands into one. The + full Documentation for the dbt build command can be found here: + https://docs.getdbt.com/reference/commands/build. + """ + + command = "build" + + __dbt_args__ = DbtBaseOperator.__dbt_args__ + [ + "full_refresh", + "select", + "fail_fast", + "threads", + "exclude", + "selector", + "state", + "defer", + "no_defer", + "data", + "schema", + "show", + ] + + def __init__( + self, + full_refresh: Optional[bool] = None, + select: Optional[list[str]] = None, + fail_fast: Optional[bool] = None, + threads: Optional[int] = None, + exclude: Optional[list[str]] = None, + selector: Optional[str] = None, + state: Optional[Union[str, Path]] = None, + defer: Optional[bool] = None, + no_defer: Optional[bool] = None, + data: Optional[bool] = None, + schema: Optional[bool] = None, + show: Optional[bool] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.full_refresh = full_refresh + self.select = select + self.fail_fast = fail_fast + self.threads = threads + self.exclude = exclude + self.selector = selector + self.state = state + self.defer = defer + self.no_defer = no_defer + self.data = data + self.schema = schema + self.show = show + + def run_result_factory(data: list[tuple[Any, Any]]): """Dictionary factory for dbt's run_result. 
diff --git a/tests/conftest.py b/tests/conftest.py index 491864d7..fbfbb531 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -228,3 +228,22 @@ def s3_bucket(mocked_s3_res): bucket = "airflow-dbt-test-s3-bucket" mocked_s3_res.create_bucket(Bucket=bucket) return bucket + + +BROKEN_SQL = """ +SELECT + field1 AS field1 +FROM + non_existent_table +WHERE + field1 > 1 +""" + + +@pytest.fixture +def broken_file(dbt_project_dir): + d = dbt_project_dir / "models" + m = d / "broken.sql" + m.write_text(BROKEN_SQL) + yield m + m.unlink() diff --git a/tests/test_dbt_build.py b/tests/test_dbt_build.py new file mode 100644 index 00000000..90994a3c --- /dev/null +++ b/tests/test_dbt_build.py @@ -0,0 +1,290 @@ +import json +from unittest.mock import patch + +import pytest +from dbt.contracts.results import RunStatus +from dbt.version import __version__ as DBT_VERSION +from packaging.version import parse + +from airflow import AirflowException +from airflow_dbt_python.operators.dbt import DbtBuildOperator + +condition = False +try: + from airflow_dbt_python.hooks.dbt_s3 import DbtS3Hook +except ImportError: + condition = True +no_s3_hook = pytest.mark.skipif( + condition, reason="S3Hook not available, consider installing amazon extras" +) + +DBT_VERSION = parse(DBT_VERSION) +IS_DBT_VERSION_LESS_THAN_0_21 = DBT_VERSION.minor < 21 and DBT_VERSION.major == 0 + +if IS_DBT_VERSION_LESS_THAN_0_21: + pytest.skip( + "skipping DbtBuildOperator tests as dbt build command is available " + f"in dbt version 0.21 or later, and found version {DBT_VERSION} installed", + allow_module_level=True, + ) + + +def test_dbt_build_mocked_all_args(): + """Test mocked dbt build call with all arguments.""" + op = DbtBuildOperator( + task_id="dbt_task", + project_dir="/path/to/project/", + profiles_dir="/path/to/profiles/", + profile="dbt-profile", + target="dbt-target", + vars={"target": "override"}, + log_cache_events=True, + bypass_cache=True, + full_refresh=True, + select=["/path/to/model.sql", 
"+/another/model.sql+2"], + fail_fast=True, + threads=3, + exclude=["/path/to/model/to/exclude.sql"], + selector="a-selector", + state="/path/to/state/", + data=True, + schema=True, + show=True, + ) + args = [ + "build", + "--project-dir", + "/path/to/project/", + "--profiles-dir", + "/path/to/profiles/", + "--profile", + "dbt-profile", + "--target", + "dbt-target", + "--vars", + "{target: override}", + "--log-cache-events", + "--bypass-cache", + "--full-refresh", + "--select", + "/path/to/model.sql", + "+/another/model.sql+2", + "--fail-fast", + "--threads", + "3", + "--exclude", + "/path/to/model/to/exclude.sql", + "--selector", + "a-selector", + "--state", + "/path/to/state/", + "--data", + "--schema", + "--show", + ] + + with patch.object(DbtBuildOperator, "run_dbt_command") as mock: + mock.return_value = ([], True) + op.execute({}) + mock.assert_called_once_with(args) + + +def test_dbt_build_mocked_default(): + """Test mocked dbt build call with default arguments.""" + op = DbtBuildOperator( + task_id="dbt_task", + do_xcom_push=False, + ) + + assert op.command == "build" + + args = ["build"] + + with patch.object(DbtBuildOperator, "run_dbt_command") as mock: + mock.return_value = ([], True) + res = op.execute({}) + mock.assert_called_once_with(args) + + assert res == [] + + +def test_dbt_build_mocked_with_do_xcom_push(): + op = DbtBuildOperator( + task_id="dbt_task", + do_xcom_push=True, + ) + + assert op.command == "build" + + args = ["build"] + + with patch.object(DbtBuildOperator, "run_dbt_command") as mock: + mock.return_value = ([], True) + res = op.execute({}) + mock.assert_called_once_with(args) + + assert isinstance(json.dumps(res), str) + assert res == [] + + +def test_dbt_build_non_existent_model(profiles_file, dbt_project_file, model_files): + op = DbtBuildOperator( + task_id="dbt_task", + project_dir=dbt_project_file.parent, + profiles_dir=profiles_file.parent, + select=["fake"], + full_refresh=True, + do_xcom_push=True, + ) + + execution_results = 
op.execute({}) + assert len(execution_results["results"]) == 0 + assert isinstance(json.dumps(execution_results), str) + + +def test_dbt_build_select(profiles_file, dbt_project_file, model_files): + op = DbtBuildOperator( + task_id="dbt_task", + project_dir=dbt_project_file.parent, + profiles_dir=profiles_file.parent, + select=[str(m.stem) for m in model_files], + do_xcom_push=True, + ) + execution_results = op.execute({}) + build_result = execution_results["results"][0] + + assert build_result["status"] == RunStatus.Success + + +def test_dbt_build_models_full_refresh(profiles_file, dbt_project_file, model_files): + op = DbtBuildOperator( + task_id="dbt_task", + project_dir=dbt_project_file.parent, + profiles_dir=profiles_file.parent, + select=[str(m.stem) for m in model_files], + full_refresh=True, + do_xcom_push=True, + ) + execution_results = op.execute({}) + build_result = execution_results["results"][0] + + assert build_result["status"] == RunStatus.Success + assert isinstance(json.dumps(execution_results), str) + + +def test_dbt_build_fails_with_malformed_sql( + profiles_file, dbt_project_file, broken_file +): + op = DbtBuildOperator( + task_id="dbt_task", + project_dir=dbt_project_file.parent, + profiles_dir=profiles_file.parent, + select=[str(broken_file.stem)], + full_refresh=True, + ) + + with pytest.raises(AirflowException): + op.execute({}) + + +def test_dbt_build_fails_with_non_existent_project(profiles_file, dbt_project_file): + op = DbtBuildOperator( + task_id="dbt_task", + project_dir="/home/fake/project", + profiles_dir="/home/fake/profiles/", + full_refresh=True, + ) + + with pytest.raises(AirflowException): + op.execute({}) + + +@no_s3_hook +def test_dbt_build_models_from_s3( + s3_bucket, profiles_file, dbt_project_file, model_files +): + hook = DbtS3Hook() + bucket = hook.get_bucket(s3_bucket) + + with open(dbt_project_file) as pf: + project_content = pf.read() + bucket.put_object(Key="project/dbt_project.yml", Body=project_content.encode()) + + 
with open(profiles_file) as pf: + profiles_content = pf.read() + bucket.put_object(Key="project/profiles.yml", Body=profiles_content.encode()) + + for model_file in model_files: + with open(model_file) as mf: + model_content = mf.read() + bucket.put_object( + Key=f"project/models/{model_file.name}", Body=model_content.encode() + ) + + op = DbtBuildOperator( + task_id="dbt_task", + project_dir=f"s3://{s3_bucket}/project/", + profiles_dir=f"s3://{s3_bucket}/project/", + select=[str(m.stem) for m in model_files], + do_xcom_push=True, + ) + execution_results = op.execute({}) + print(execution_results) + build_result = execution_results["results"][0] + + assert build_result["status"] == RunStatus.Success + + +@no_s3_hook +def test_dbt_build_models_with_profile_from_s3( + s3_bucket, profiles_file, dbt_project_file, model_files +): + hook = DbtS3Hook() + bucket = hook.get_bucket(s3_bucket) + + with open(profiles_file) as pf: + profiles_content = pf.read() + bucket.put_object(Key="project/profiles.yml", Body=profiles_content.encode()) + + op = DbtBuildOperator( + task_id="dbt_task", + project_dir=dbt_project_file.parent, + profiles_dir=f"s3://{s3_bucket}/project/", + select=[str(m.stem) for m in model_files], + do_xcom_push=True, + ) + execution_results = op.execute({}) + build_result = execution_results["results"][0] + + assert build_result["status"] == RunStatus.Success + + +@no_s3_hook +def test_dbt_build_models_with_project_from_s3( + s3_bucket, profiles_file, dbt_project_file, model_files +): + hook = DbtS3Hook() + bucket = hook.get_bucket(s3_bucket) + + with open(dbt_project_file) as pf: + project_content = pf.read() + bucket.put_object(Key="project/dbt_project.yml", Body=project_content.encode()) + + for model_file in model_files: + with open(model_file) as mf: + model_content = mf.read() + bucket.put_object( + Key=f"project/models/{model_file.name}", Body=model_content.encode() + ) + + op = DbtBuildOperator( + task_id="dbt_task", + 
project_dir=f"s3://{s3_bucket}/project/", + profiles_dir=profiles_file.parent, + select=[str(m.stem) for m in model_files], + do_xcom_push=True, + ) + execution_results = op.execute({}) + build_result = execution_results["results"][0] + + assert build_result["status"] == RunStatus.Success diff --git a/tests/test_dbt_compile.py b/tests/test_dbt_compile.py index 6362e08b..f22bfe85 100644 --- a/tests/test_dbt_compile.py +++ b/tests/test_dbt_compile.py @@ -162,8 +162,7 @@ def test_dbt_compile_models(profiles_file, dbt_project_file, model_files, compil with open(compile_dir / "model_3.sql") as f: model_3 = f.read() - - assert clean_lines(model_3) == clean_lines(COMPILED_MODEL_3) + assert clean_lines(model_3)[0:2] == clean_lines(COMPILED_MODEL_3)[0:2] with open(compile_dir / "model_4.sql") as f: model_4 = f.read() diff --git a/tests/test_dbt_debug.py b/tests/test_dbt_debug.py index 146adf61..f23e60a2 100644 --- a/tests/test_dbt_debug.py +++ b/tests/test_dbt_debug.py @@ -48,6 +48,7 @@ def test_dbt_debug_mocked_all_args(): def test_dbt_debug_mocked_default(): + """Test mocked dbt debug call with default arguments.""" op = DbtDebugOperator( task_id="dbt_task", ) diff --git a/tests/test_dbt_list.py b/tests/test_dbt_list.py index ca93ce63..8324410a 100644 --- a/tests/test_dbt_list.py +++ b/tests/test_dbt_list.py @@ -5,6 +5,7 @@ def test_dbt_ls_mocked_all_args(): + """Test mocked dbt ls call with all arguments.""" op = DbtLsOperator( task_id="dbt_task", project_dir="/path/to/project/", @@ -19,6 +20,7 @@ def test_dbt_ls_mocked_all_args(): exclude=["/path/to/data/to/exclude.sql"], selector="a-selector", dbt_output="json", + output_keys=["a-key", "another-key"], ) args = [ "ls", @@ -45,6 +47,9 @@ def test_dbt_ls_mocked_all_args(): "a-selector", "--output", "json", + "--output-keys", + "a-key", + "another-key", ] with patch.object(DbtLsOperator, "run_dbt_command") as mock: diff --git a/tests/test_dbt_run.py b/tests/test_dbt_run.py index acf6b7f0..807f62c2 100644 --- 
a/tests/test_dbt_run.py +++ b/tests/test_dbt_run.py @@ -153,24 +153,6 @@ def test_dbt_run_models_full_refresh(profiles_file, dbt_project_file, model_file assert isinstance(json.dumps(execution_results), str) -BROKEN_SQL = """ -SELECT - field1 AS field1 -FROM - non_existent_table -WHERE - field1 > 1 -""" - - -@pytest.fixture -def broken_file(dbt_project_dir): - d = dbt_project_dir / "models" - m = d / "broken.sql" - m.write_text(BROKEN_SQL) - return m - - def test_dbt_run_fails_with_malformed_sql(profiles_file, dbt_project_file, broken_file): op = DbtRunOperator( task_id="dbt_task", diff --git a/tests/test_dbt_source.py b/tests/test_dbt_source.py index 9896a3f4..3639b3d4 100644 --- a/tests/test_dbt_source.py +++ b/tests/test_dbt_source.py @@ -1,10 +1,17 @@ from pathlib import Path from unittest.mock import patch +from dbt.version import __version__ as DBT_VERSION +from packaging.version import parse + from airflow_dbt_python.operators.dbt import DbtSourceOperator +DBT_VERSION = parse(DBT_VERSION) +IS_DBT_VERSION_LESS_THAN_0_21 = DBT_VERSION.minor < 21 and DBT_VERSION.major == 0 + def test_dbt_source_mocked_all_args(): + """Test mocked dbt source call with all arguments.""" op = DbtSourceOperator( task_id="dbt_task", subcommand="freshness", @@ -52,13 +59,17 @@ def test_dbt_source_mocked_all_args(): def test_dbt_source_mocked_default(): + """Test mocked dbt source call with default arguments.""" op = DbtSourceOperator( task_id="dbt_task", ) assert op.command == "source" - args = ["source", "snapshot-freshness"] + if IS_DBT_VERSION_LESS_THAN_0_21: + args = ["source", "snapshot-freshness"] + else: + args = ["source", "freshness"] with patch.object(DbtSourceOperator, "run_dbt_command") as mock: mock.return_value = ([], True) @@ -69,6 +80,7 @@ def test_dbt_source_mocked_default(): def test_dbt_source_basic(profiles_file, dbt_project_file): + """Test the execution of a dbt source basic operator.""" op = DbtSourceOperator( task_id="dbt_task", 
project_dir=dbt_project_file.parent, @@ -85,6 +97,7 @@ def test_dbt_source_basic(profiles_file, dbt_project_file): def test_dbt_source_different_output(profiles_file, dbt_project_file): + """Test dbt source operator execution with different output.""" new_sources = Path(dbt_project_file.parent) / "target/new_sources.json" if new_sources.exists(): new_sources.unlink() From 34589580df7738487b7c595696cd873f00faaa6e Mon Sep 17 00:00:00 2001 From: GitHub Action Date: Sun, 17 Oct 2021 13:00:36 +0200 Subject: [PATCH 2/2] feat: Support select argument for run, test, and compile operators dbt 0.21 added the select flag to the run, test, and compile commands with the intention for it to replace models and bring them in line with the rest of the dbt commands. We have added support for a select argument but kept backwards compatibility with models as we intend to support dbt versions >= 0.19 (at least for the moment). --- airflow_dbt_python/operators/dbt.py | 35 +++++++++++++++++------ tests/test_dbt_compile.py | 38 ++++++++++++++++++++++++- tests/test_dbt_list.py | 1 + tests/test_dbt_run.py | 42 +++++++++++++++++++++++++++- tests/test_dbt_test.py | 43 ++++++++++++++++++++++++++++- 5 files changed, 148 insertions(+), 11 deletions(-) diff --git a/airflow_dbt_python/operators/dbt.py b/airflow_dbt_python/operators/dbt.py index 7015d345..f8302749 100644 --- a/airflow_dbt_python/operators/dbt.py +++ b/airflow_dbt_python/operators/dbt.py @@ -22,6 +22,10 @@ from airflow.models.xcom import XCOM_RETURN_KEY from airflow.utils.decorators import apply_defaults +IS_DBT_VERSION_LESS_THAN_0_21 = ( + int(installed_version.minor) < 21 and int(installed_version.major) == 0 +) + class DbtBaseOperator(BaseOperator): """The basic Airflow dbt operator. 
@@ -297,6 +301,7 @@ class DbtRunOperator(DbtBaseOperator): __dbt_args__ = DbtBaseOperator.__dbt_args__ + [ "full_refresh", "models", + "select", "fail_fast", "threads", "exclude", @@ -322,7 +327,6 @@ def __init__( ) -> None: super().__init__(**kwargs) self.full_refresh = full_refresh - self.models = models self.fail_fast = fail_fast self.threads = threads self.exclude = exclude @@ -331,6 +335,11 @@ def __init__( self.defer = defer self.no_defer = no_defer + if IS_DBT_VERSION_LESS_THAN_0_21: + self.models = models or select + else: + self.select = select or models + class DbtSeedOperator(DbtBaseOperator): """Executes dbt seed.""" @@ -378,6 +387,7 @@ class DbtTestOperator(DbtBaseOperator): "schema", "fail_fast", "models", + "select", "threads", "exclude", "selector", @@ -391,6 +401,7 @@ def __init__( data: Optional[bool] = None, schema: Optional[bool] = None, models: Optional[list[str]] = None, + select: Optional[list[str]] = None, fail_fast: Optional[bool] = None, threads: Optional[int] = None, exclude: Optional[list[str]] = None, @@ -403,7 +414,6 @@ def __init__( super().__init__(**kwargs) self.data = data self.schema = schema - self.models = models self.fail_fast = fail_fast self.threads = threads self.exclude = exclude @@ -412,6 +422,11 @@ def __init__( self.defer = defer self.no_defer = no_defer + if IS_DBT_VERSION_LESS_THAN_0_21: + self.models = models or select + else: + self.select = select or models + class DbtCompileOperator(DbtBaseOperator): """Executes dbt compile.""" @@ -424,6 +439,7 @@ class DbtCompileOperator(DbtBaseOperator): "fail_fast", "threads", "models", + "select", "exclude", "selector", "state", @@ -434,6 +450,7 @@ def __init__( parse_only: Optional[bool] = None, full_refresh: Optional[bool] = None, models: Optional[list[str]] = None, + select: Optional[list[str]] = None, fail_fast: Optional[bool] = None, threads: Optional[int] = None, exclude: Optional[list[str]] = None, @@ -444,13 +461,17 @@ def __init__( super().__init__(**kwargs) 
self.parse_only = parse_only self.full_refresh = full_refresh - self.models = models self.fail_fast = fail_fast self.threads = threads self.exclude = exclude self.selector = selector self.state = state + if IS_DBT_VERSION_LESS_THAN_0_21: + self.models = models or select + else: + self.select = select or models + class DbtDepsOperator(DbtBaseOperator): """Executes dbt deps.""" @@ -526,7 +547,6 @@ class DbtLsOperator(DbtBaseOperator): __dbt_args__ = DbtBaseOperator.__dbt_args__ + [ "resource_type", "select", - "models", "exclude", "selector", "dbt_output", @@ -537,7 +557,6 @@ def __init__( self, resource_type: Optional[list[str]] = None, select: Optional[list[str]] = None, - models: Optional[list[str]] = None, exclude: Optional[list[str]] = None, selector: Optional[str] = None, dbt_output: Optional[str] = None, @@ -547,7 +566,6 @@ def __init__( super().__init__(**kwargs) self.resource_type = resource_type self.select = select - self.models = models self.exclude = exclude self.selector = selector self.dbt_output = dbt_output @@ -596,9 +614,10 @@ class DbtSourceOperator(DbtBaseOperator): __dbt_args__ = DbtBaseOperator.__dbt_args__ + [ "select", - "models", + "threads", "exclude", "selector", + "state", "dbt_output", ] @@ -618,11 +637,11 @@ def __init__( ) -> None: super().__init__(positional_args=[subcommand], **kwargs) self.select = select - self.dbt_output = dbt_output self.threads = threads self.exclude = exclude self.selector = selector self.state = state + self.dbt_output = dbt_output class DbtBuildOperator(DbtBaseOperator): diff --git a/tests/test_dbt_compile.py b/tests/test_dbt_compile.py index f22bfe85..24b427de 100644 --- a/tests/test_dbt_compile.py +++ b/tests/test_dbt_compile.py @@ -8,6 +8,7 @@ DBT_VERSION = parse(DBT_VERSION) IS_DBT_VERSION_LESS_THAN_0_20 = DBT_VERSION.minor < 20 and DBT_VERSION.major == 0 +IS_DBT_VERSION_LESS_THAN_0_21 = DBT_VERSION.minor < 21 and DBT_VERSION.major == 0 def test_dbt_compile_mocked_all_args(): @@ -30,6 +31,12 @@ def 
test_dbt_compile_mocked_all_args(): selector="a-selector", state="/path/to/state/", ) + + if IS_DBT_VERSION_LESS_THAN_0_21: + SELECTION_KEY = "models" + else: + SELECTION_KEY = "select" + args = [ "compile", "--project-dir", @@ -49,7 +56,7 @@ def test_dbt_compile_mocked_all_args(): "--fail-fast", "--threads", "2", - "--models", + f"--{SELECTION_KEY}", "/path/to/model1.sql", "/path/to/model2.sql", "--exclude", @@ -205,3 +212,32 @@ def test_dbt_compile_models_full_refresh( model_4 = f.read() assert clean_lines(model_4) == clean_lines(COMPILED_MODEL_4) + + +def test_dbt_compile_uses_correct_argument_according_to_version(): + """Test if dbt run operator sets the proper attribute based on dbt version.""" + op = DbtCompileOperator( + task_id="dbt_task", + project_dir="/path/to/project/", + profiles_dir="/path/to/profiles/", + profile="dbt-profile", + target="dbt-target", + vars={"target": "override"}, + log_cache_events=True, + bypass_cache=True, + parse_only=True, + full_refresh=True, + fail_fast=True, + models=["/path/to/model1.sql", "/path/to/model2.sql"], + threads=2, + exclude=["/path/to/data/to/exclude.sql"], + selector="a-selector", + state="/path/to/state/", + ) + + if IS_DBT_VERSION_LESS_THAN_0_21: + assert op.models == ["/path/to/model1.sql", "/path/to/model2.sql"] + assert getattr(op, "select", None) is None + else: + assert op.select == ["/path/to/model1.sql", "/path/to/model2.sql"] + assert getattr(op, "models", None) is None diff --git a/tests/test_dbt_list.py b/tests/test_dbt_list.py index 8324410a..c34e1e73 100644 --- a/tests/test_dbt_list.py +++ b/tests/test_dbt_list.py @@ -22,6 +22,7 @@ def test_dbt_ls_mocked_all_args(): dbt_output="json", output_keys=["a-key", "another-key"], ) + args = [ "ls", "--project-dir", diff --git a/tests/test_dbt_run.py b/tests/test_dbt_run.py index 807f62c2..805559dd 100644 --- a/tests/test_dbt_run.py +++ b/tests/test_dbt_run.py @@ -3,6 +3,8 @@ import pytest from dbt.contracts.results import RunStatus +from dbt.version import 
__version__ as DBT_VERSION +from packaging.version import parse from airflow import AirflowException from airflow_dbt_python.operators.dbt import DbtRunOperator @@ -16,6 +18,9 @@ condition, reason="S3Hook not available, consider installing amazon extras" ) +DBT_VERSION = parse(DBT_VERSION) +IS_DBT_VERSION_LESS_THAN_0_21 = DBT_VERSION.minor < 21 and DBT_VERSION.major == 0 + def test_dbt_run_mocked_all_args(): """Test mocked dbt run call with all arguments.""" @@ -36,6 +41,12 @@ def test_dbt_run_mocked_all_args(): selector="a-selector", state="/path/to/state/", ) + + if IS_DBT_VERSION_LESS_THAN_0_21: + SELECTION_KEY = "models" + else: + SELECTION_KEY = "select" + args = [ "run", "--project-dir", @@ -51,7 +62,7 @@ def test_dbt_run_mocked_all_args(): "--log-cache-events", "--bypass-cache", "--full-refresh", - "--models", + f"--{SELECTION_KEY}", "/path/to/model.sql", "+/another/model.sql+2", "--fail-fast", @@ -72,6 +83,7 @@ def test_dbt_run_mocked_all_args(): def test_dbt_run_mocked_default(): + """Test mocked dbt run call with default arguments.""" op = DbtRunOperator( task_id="dbt_task", do_xcom_push=False, @@ -266,3 +278,31 @@ def test_dbt_run_models_with_project_from_s3( run_result = execution_results["results"][0] assert run_result["status"] == RunStatus.Success + + +def test_dbt_run_uses_correct_argument_according_to_version(): + """Test if dbt run operator sets the proper attribute based on dbt version.""" + op = DbtRunOperator( + task_id="dbt_task", + project_dir="/path/to/project/", + profiles_dir="/path/to/profiles/", + profile="dbt-profile", + target="dbt-target", + vars={"target": "override"}, + log_cache_events=True, + bypass_cache=True, + full_refresh=True, + models=["/path/to/model.sql", "+/another/model.sql+2"], + fail_fast=True, + threads=3, + exclude=["/path/to/model/to/exclude.sql"], + selector="a-selector", + state="/path/to/state/", + ) + + if IS_DBT_VERSION_LESS_THAN_0_21: + assert op.models == ["/path/to/model.sql", "+/another/model.sql+2"] + 
assert getattr(op, "select", None) is None + else: + assert op.select == ["/path/to/model.sql", "+/another/model.sql+2"] + assert getattr(op, "models", None) is None diff --git a/tests/test_dbt_test.py b/tests/test_dbt_test.py index 5939a818..dbad6bb6 100644 --- a/tests/test_dbt_test.py +++ b/tests/test_dbt_test.py @@ -2,6 +2,8 @@ import pytest from dbt.contracts.results import TestStatus +from dbt.version import __version__ as DBT_VERSION +from packaging.version import parse from airflow_dbt_python.operators.dbt import DbtTestOperator @@ -14,8 +16,12 @@ condition, reason="S3Hook not available, consider installing amazon extras" ) +DBT_VERSION = parse(DBT_VERSION) +IS_DBT_VERSION_LESS_THAN_0_21 = DBT_VERSION.minor < 21 and DBT_VERSION.major == 0 + def test_dbt_test_mocked_all_args(): + """Test mocked dbt test call with all arguments.n""" op = DbtTestOperator( task_id="dbt_task", project_dir="/path/to/project/", @@ -34,6 +40,12 @@ def test_dbt_test_mocked_all_args(): state="/path/to/state/", no_defer=True, ) + + if IS_DBT_VERSION_LESS_THAN_0_21: + SELECTION_KEY = "models" + else: + SELECTION_KEY = "select" + args = [ "test", "--project-dir", @@ -50,7 +62,7 @@ def test_dbt_test_mocked_all_args(): "--bypass-cache", "--data", "--schema", - "--models", + f"--{SELECTION_KEY}", "/path/to/models", "--threads", "2", @@ -272,3 +284,32 @@ def test_dbt_test_with_project_from_s3( results = op.execute({}) for test_result in results["results"]: assert test_result["status"] == TestStatus.Pass + + +def test_dbt_compile_uses_correct_argument_according_to_version(): + """Test if dbt run operator sets the proper attribute based on dbt version.""" + op = DbtTestOperator( + task_id="dbt_task", + project_dir="/path/to/project/", + profiles_dir="/path/to/profiles/", + profile="dbt-profile", + target="dbt-target", + vars={"target": "override"}, + log_cache_events=True, + bypass_cache=True, + data=True, + schema=True, + models=["/path/to/models"], + threads=2, + 
exclude=["/path/to/data/to/exclude.sql"], + selector="a-selector", + state="/path/to/state/", + no_defer=True, + ) + + if IS_DBT_VERSION_LESS_THAN_0_21: + assert op.models == ["/path/to/models"] + assert getattr(op, "select", None) is None + else: + assert op.select == ["/path/to/models"] + assert getattr(op, "models", None) is None