diff --git a/airflow_dbt_python/operators/dbt.py b/airflow_dbt_python/operators/dbt.py index 67bed85e..f8302749 100644 --- a/airflow_dbt_python/operators/dbt.py +++ b/airflow_dbt_python/operators/dbt.py @@ -14,12 +14,18 @@ from dbt.contracts.results import RunExecutionResult, RunResult, agate from dbt.logger import log_manager from dbt.main import initialize_config_values, parse_args, track_run +from dbt.semver import VersionSpecifier +from dbt.version import installed as installed_version from airflow import AirflowException from airflow.models.baseoperator import BaseOperator from airflow.models.xcom import XCOM_RETURN_KEY from airflow.utils.decorators import apply_defaults +IS_DBT_VERSION_LESS_THAN_0_21 = ( + int(installed_version.minor) < 21 and int(installed_version.major) == 0 +) + class DbtBaseOperator(BaseOperator): """The basic Airflow dbt operator. @@ -295,6 +301,7 @@ class DbtRunOperator(DbtBaseOperator): __dbt_args__ = DbtBaseOperator.__dbt_args__ + [ "full_refresh", "models", + "select", "fail_fast", "threads", "exclude", @@ -308,6 +315,7 @@ def __init__( self, full_refresh: Optional[bool] = None, models: Optional[list[str]] = None, + select: Optional[list[str]] = None, fail_fast: Optional[bool] = None, threads: Optional[int] = None, exclude: Optional[list[str]] = None, @@ -319,7 +327,6 @@ def __init__( ) -> None: super().__init__(**kwargs) self.full_refresh = full_refresh - self.models = models self.fail_fast = fail_fast self.threads = threads self.exclude = exclude @@ -328,6 +335,11 @@ def __init__( self.defer = defer self.no_defer = no_defer + if IS_DBT_VERSION_LESS_THAN_0_21: + self.models = models or select + else: + self.select = select or models + class DbtSeedOperator(DbtBaseOperator): """Executes dbt seed.""" @@ -375,6 +387,7 @@ class DbtTestOperator(DbtBaseOperator): "schema", "fail_fast", "models", + "select", "threads", "exclude", "selector", @@ -388,6 +401,7 @@ def __init__( data: Optional[bool] = None, schema: Optional[bool] = None, 
models: Optional[list[str]] = None, + select: Optional[list[str]] = None, fail_fast: Optional[bool] = None, threads: Optional[int] = None, exclude: Optional[list[str]] = None, @@ -400,7 +414,6 @@ def __init__( super().__init__(**kwargs) self.data = data self.schema = schema - self.models = models self.fail_fast = fail_fast self.threads = threads self.exclude = exclude @@ -409,6 +422,11 @@ def __init__( self.defer = defer self.no_defer = no_defer + if IS_DBT_VERSION_LESS_THAN_0_21: + self.models = models or select + else: + self.select = select or models + class DbtCompileOperator(DbtBaseOperator): """Executes dbt compile.""" @@ -421,6 +439,7 @@ class DbtCompileOperator(DbtBaseOperator): "fail_fast", "threads", "models", + "select", "exclude", "selector", "state", @@ -431,6 +450,7 @@ def __init__( parse_only: Optional[bool] = None, full_refresh: Optional[bool] = None, models: Optional[list[str]] = None, + select: Optional[list[str]] = None, fail_fast: Optional[bool] = None, threads: Optional[int] = None, exclude: Optional[list[str]] = None, @@ -441,13 +461,17 @@ def __init__( super().__init__(**kwargs) self.parse_only = parse_only self.full_refresh = full_refresh - self.models = models self.fail_fast = fail_fast self.threads = threads self.exclude = exclude self.selector = selector self.state = state + if IS_DBT_VERSION_LESS_THAN_0_21: + self.models = models or select + else: + self.select = select or models + class DbtDepsOperator(DbtBaseOperator): """Executes dbt deps.""" @@ -523,29 +547,29 @@ class DbtLsOperator(DbtBaseOperator): __dbt_args__ = DbtBaseOperator.__dbt_args__ + [ "resource_type", "select", - "models", "exclude", "selector", "dbt_output", + "output_keys", ] def __init__( self, resource_type: Optional[list[str]] = None, select: Optional[list[str]] = None, - models: Optional[list[str]] = None, exclude: Optional[list[str]] = None, selector: Optional[str] = None, dbt_output: Optional[str] = None, + output_keys: Optional[list[str]] = None, **kwargs, ) -> 
None: super().__init__(**kwargs) self.resource_type = resource_type self.select = select - self.models = models self.exclude = exclude self.selector = selector self.dbt_output = dbt_output + self.output_keys = output_keys # Convinience alias @@ -590,16 +614,19 @@ class DbtSourceOperator(DbtBaseOperator): __dbt_args__ = DbtBaseOperator.__dbt_args__ + [ "select", - "models", + "threads", "exclude", "selector", + "state", "dbt_output", ] def __init__( self, # Only one subcommand is currently provided - subcommand: str = "snapshot-freshness", + subcommand: str = "freshness" + if not installed_version < VersionSpecifier.from_version_string("0.21.0") + else "snapshot-freshness", select: Optional[list[str]] = None, dbt_output: Optional[Union[str, Path]] = None, threads: Optional[int] = None, @@ -610,11 +637,67 @@ def __init__( ) -> None: super().__init__(positional_args=[subcommand], **kwargs) self.select = select + self.threads = threads + self.exclude = exclude + self.selector = selector + self.state = state self.dbt_output = dbt_output + + +class DbtBuildOperator(DbtBaseOperator): + """Execute dbt build. + + The build command combines the run, test, seed, and snapshot commands into one. The + full Documentation for the dbt build command can be found here: + https://docs.getdbt.com/reference/commands/build. 
+ """ + + command = "build" + + __dbt_args__ = DbtBaseOperator.__dbt_args__ + [ + "full_refresh", + "select", + "fail_fast", + "threads", + "exclude", + "selector", + "state", + "defer", + "no_defer", + "data", + "schema", + "show", + ] + + def __init__( + self, + full_refresh: Optional[bool] = None, + select: Optional[list[str]] = None, + fail_fast: Optional[bool] = None, + threads: Optional[int] = None, + exclude: Optional[list[str]] = None, + selector: Optional[str] = None, + state: Optional[Union[str, Path]] = None, + defer: Optional[bool] = None, + no_defer: Optional[bool] = None, + data: Optional[bool] = None, + schema: Optional[bool] = None, + show: Optional[bool] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.full_refresh = full_refresh + self.select = select + self.fail_fast = fail_fast self.threads = threads self.exclude = exclude self.selector = selector self.state = state + self.defer = defer + self.no_defer = no_defer + self.data = data + self.schema = schema + self.show = show def run_result_factory(data: list[tuple[Any, Any]]): diff --git a/tests/conftest.py b/tests/conftest.py index 491864d7..fbfbb531 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -228,3 +228,22 @@ def s3_bucket(mocked_s3_res): bucket = "airflow-dbt-test-s3-bucket" mocked_s3_res.create_bucket(Bucket=bucket) return bucket + + +BROKEN_SQL = """ +SELECT + field1 AS field1 +FROM + non_existent_table +WHERE + field1 > 1 +""" + + +@pytest.fixture +def broken_file(dbt_project_dir): + d = dbt_project_dir / "models" + m = d / "broken.sql" + m.write_text(BROKEN_SQL) + yield m + m.unlink() diff --git a/tests/test_dbt_build.py b/tests/test_dbt_build.py new file mode 100644 index 00000000..90994a3c --- /dev/null +++ b/tests/test_dbt_build.py @@ -0,0 +1,290 @@ +import json +from unittest.mock import patch + +import pytest +from dbt.contracts.results import RunStatus +from dbt.version import __version__ as DBT_VERSION +from packaging.version import parse + +from 
airflow import AirflowException +from airflow_dbt_python.operators.dbt import DbtBuildOperator + +condition = False +try: + from airflow_dbt_python.hooks.dbt_s3 import DbtS3Hook +except ImportError: + condition = True +no_s3_hook = pytest.mark.skipif( + condition, reason="S3Hook not available, consider installing amazon extras" +) + +DBT_VERSION = parse(DBT_VERSION) +IS_DBT_VERSION_LESS_THAN_0_21 = DBT_VERSION.minor < 21 and DBT_VERSION.major == 0 + +if IS_DBT_VERSION_LESS_THAN_0_21: + pytest.skip( + "skipping DbtBuildOperator tests as dbt build command is available " + f"in dbt version 0.21 or later, and found version {DBT_VERSION} installed", + allow_module_level=True, + ) + + +def test_dbt_build_mocked_all_args(): + """Test mocked dbt build call with all arguments.""" + op = DbtBuildOperator( + task_id="dbt_task", + project_dir="/path/to/project/", + profiles_dir="/path/to/profiles/", + profile="dbt-profile", + target="dbt-target", + vars={"target": "override"}, + log_cache_events=True, + bypass_cache=True, + full_refresh=True, + select=["/path/to/model.sql", "+/another/model.sql+2"], + fail_fast=True, + threads=3, + exclude=["/path/to/model/to/exclude.sql"], + selector="a-selector", + state="/path/to/state/", + data=True, + schema=True, + show=True, + ) + args = [ + "build", + "--project-dir", + "/path/to/project/", + "--profiles-dir", + "/path/to/profiles/", + "--profile", + "dbt-profile", + "--target", + "dbt-target", + "--vars", + "{target: override}", + "--log-cache-events", + "--bypass-cache", + "--full-refresh", + "--select", + "/path/to/model.sql", + "+/another/model.sql+2", + "--fail-fast", + "--threads", + "3", + "--exclude", + "/path/to/model/to/exclude.sql", + "--selector", + "a-selector", + "--state", + "/path/to/state/", + "--data", + "--schema", + "--show", + ] + + with patch.object(DbtBuildOperator, "run_dbt_command") as mock: + mock.return_value = ([], True) + op.execute({}) + mock.assert_called_once_with(args) + + +def 
test_dbt_build_mocked_default(): + """Test mocked dbt build call with default arguments.""" + op = DbtBuildOperator( + task_id="dbt_task", + do_xcom_push=False, + ) + + assert op.command == "build" + + args = ["build"] + + with patch.object(DbtBuildOperator, "run_dbt_command") as mock: + mock.return_value = ([], True) + res = op.execute({}) + mock.assert_called_once_with(args) + + assert res == [] + + +def test_dbt_build_mocked_with_do_xcom_push(): + op = DbtBuildOperator( + task_id="dbt_task", + do_xcom_push=True, + ) + + assert op.command == "build" + + args = ["build"] + + with patch.object(DbtBuildOperator, "run_dbt_command") as mock: + mock.return_value = ([], True) + res = op.execute({}) + mock.assert_called_once_with(args) + + assert isinstance(json.dumps(res), str) + assert res == [] + + +def test_dbt_build_non_existent_model(profiles_file, dbt_project_file, model_files): + op = DbtBuildOperator( + task_id="dbt_task", + project_dir=dbt_project_file.parent, + profiles_dir=profiles_file.parent, + select=["fake"], + full_refresh=True, + do_xcom_push=True, + ) + + execution_results = op.execute({}) + assert len(execution_results["results"]) == 0 + assert isinstance(json.dumps(execution_results), str) + + +def test_dbt_build_select(profiles_file, dbt_project_file, model_files): + op = DbtBuildOperator( + task_id="dbt_task", + project_dir=dbt_project_file.parent, + profiles_dir=profiles_file.parent, + select=[str(m.stem) for m in model_files], + do_xcom_push=True, + ) + execution_results = op.execute({}) + build_result = execution_results["results"][0] + + assert build_result["status"] == RunStatus.Success + + +def test_dbt_build_models_full_refresh(profiles_file, dbt_project_file, model_files): + op = DbtBuildOperator( + task_id="dbt_task", + project_dir=dbt_project_file.parent, + profiles_dir=profiles_file.parent, + select=[str(m.stem) for m in model_files], + full_refresh=True, + do_xcom_push=True, + ) + execution_results = op.execute({}) + build_result = 
execution_results["results"][0] + + assert build_result["status"] == RunStatus.Success + assert isinstance(json.dumps(execution_results), str) + + +def test_dbt_build_fails_with_malformed_sql( + profiles_file, dbt_project_file, broken_file +): + op = DbtBuildOperator( + task_id="dbt_task", + project_dir=dbt_project_file.parent, + profiles_dir=profiles_file.parent, + select=[str(broken_file.stem)], + full_refresh=True, + ) + + with pytest.raises(AirflowException): + op.execute({}) + + +def test_dbt_build_fails_with_non_existent_project(profiles_file, dbt_project_file): + op = DbtBuildOperator( + task_id="dbt_task", + project_dir="/home/fake/project", + profiles_dir="/home/fake/profiles/", + full_refresh=True, + ) + + with pytest.raises(AirflowException): + op.execute({}) + + +@no_s3_hook +def test_dbt_build_models_from_s3( + s3_bucket, profiles_file, dbt_project_file, model_files +): + hook = DbtS3Hook() + bucket = hook.get_bucket(s3_bucket) + + with open(dbt_project_file) as pf: + project_content = pf.read() + bucket.put_object(Key="project/dbt_project.yml", Body=project_content.encode()) + + with open(profiles_file) as pf: + profiles_content = pf.read() + bucket.put_object(Key="project/profiles.yml", Body=profiles_content.encode()) + + for model_file in model_files: + with open(model_file) as mf: + model_content = mf.read() + bucket.put_object( + Key=f"project/models/{model_file.name}", Body=model_content.encode() + ) + + op = DbtBuildOperator( + task_id="dbt_task", + project_dir=f"s3://{s3_bucket}/project/", + profiles_dir=f"s3://{s3_bucket}/project/", + select=[str(m.stem) for m in model_files], + do_xcom_push=True, + ) + execution_results = op.execute({}) + print(execution_results) + build_result = execution_results["results"][0] + + assert build_result["status"] == RunStatus.Success + + +@no_s3_hook +def test_dbt_build_models_with_profile_from_s3( + s3_bucket, profiles_file, dbt_project_file, model_files +): + hook = DbtS3Hook() + bucket = 
hook.get_bucket(s3_bucket) + + with open(profiles_file) as pf: + profiles_content = pf.read() + bucket.put_object(Key="project/profiles.yml", Body=profiles_content.encode()) + + op = DbtBuildOperator( + task_id="dbt_task", + project_dir=dbt_project_file.parent, + profiles_dir=f"s3://{s3_bucket}/project/", + select=[str(m.stem) for m in model_files], + do_xcom_push=True, + ) + execution_results = op.execute({}) + build_result = execution_results["results"][0] + + assert build_result["status"] == RunStatus.Success + + +@no_s3_hook +def test_dbt_build_models_with_project_from_s3( + s3_bucket, profiles_file, dbt_project_file, model_files +): + hook = DbtS3Hook() + bucket = hook.get_bucket(s3_bucket) + + with open(dbt_project_file) as pf: + project_content = pf.read() + bucket.put_object(Key="project/dbt_project.yml", Body=project_content.encode()) + + for model_file in model_files: + with open(model_file) as mf: + model_content = mf.read() + bucket.put_object( + Key=f"project/models/{model_file.name}", Body=model_content.encode() + ) + + op = DbtBuildOperator( + task_id="dbt_task", + project_dir=f"s3://{s3_bucket}/project/", + profiles_dir=profiles_file.parent, + select=[str(m.stem) for m in model_files], + do_xcom_push=True, + ) + execution_results = op.execute({}) + build_result = execution_results["results"][0] + + assert build_result["status"] == RunStatus.Success diff --git a/tests/test_dbt_compile.py b/tests/test_dbt_compile.py index 6362e08b..24b427de 100644 --- a/tests/test_dbt_compile.py +++ b/tests/test_dbt_compile.py @@ -8,6 +8,7 @@ DBT_VERSION = parse(DBT_VERSION) IS_DBT_VERSION_LESS_THAN_0_20 = DBT_VERSION.minor < 20 and DBT_VERSION.major == 0 +IS_DBT_VERSION_LESS_THAN_0_21 = DBT_VERSION.minor < 21 and DBT_VERSION.major == 0 def test_dbt_compile_mocked_all_args(): @@ -30,6 +31,12 @@ def test_dbt_compile_mocked_all_args(): selector="a-selector", state="/path/to/state/", ) + + if IS_DBT_VERSION_LESS_THAN_0_21: + SELECTION_KEY = "models" + else: + 
SELECTION_KEY = "select" + args = [ "compile", "--project-dir", @@ -49,7 +56,7 @@ "--fail-fast", "--threads", "2", - "--models", + f"--{SELECTION_KEY}", "/path/to/model1.sql", "/path/to/model2.sql", "--exclude", @@ -162,8 +169,7 @@ def test_dbt_compile_models(profiles_file, dbt_project_file, model_files, compil with open(compile_dir / "model_3.sql") as f: model_3 = f.read() - - assert clean_lines(model_3) == clean_lines(COMPILED_MODEL_3) + assert clean_lines(model_3)[0:2] == clean_lines(COMPILED_MODEL_3)[0:2] with open(compile_dir / "model_4.sql") as f: model_4 = f.read() @@ -206,3 +212,32 @@ def test_dbt_compile_models_full_refresh( model_4 = f.read() assert clean_lines(model_4) == clean_lines(COMPILED_MODEL_4) + + +def test_dbt_compile_uses_correct_argument_according_to_version(): + """Test if dbt compile operator sets the proper attribute based on dbt version.""" + op = DbtCompileOperator( + task_id="dbt_task", + project_dir="/path/to/project/", + profiles_dir="/path/to/profiles/", + profile="dbt-profile", + target="dbt-target", + vars={"target": "override"}, + log_cache_events=True, + bypass_cache=True, + parse_only=True, + full_refresh=True, + fail_fast=True, + models=["/path/to/model1.sql", "/path/to/model2.sql"], + threads=2, + exclude=["/path/to/data/to/exclude.sql"], + selector="a-selector", + state="/path/to/state/", + ) + + if IS_DBT_VERSION_LESS_THAN_0_21: + assert op.models == ["/path/to/model1.sql", "/path/to/model2.sql"] + assert getattr(op, "select", None) is None + else: + assert op.select == ["/path/to/model1.sql", "/path/to/model2.sql"] + assert getattr(op, "models", None) is None diff --git a/tests/test_dbt_debug.py b/tests/test_dbt_debug.py index 146adf61..f23e60a2 100644 --- a/tests/test_dbt_debug.py +++ b/tests/test_dbt_debug.py @@ -48,6 +48,7 @@ def test_dbt_debug_mocked_all_args(): def test_dbt_debug_mocked_default(): + """Test mocked dbt debug call with default arguments.""" op = DbtDebugOperator( 
task_id="dbt_task", ) diff --git a/tests/test_dbt_list.py b/tests/test_dbt_list.py index ca93ce63..c34e1e73 100644 --- a/tests/test_dbt_list.py +++ b/tests/test_dbt_list.py @@ -5,6 +5,7 @@ def test_dbt_ls_mocked_all_args(): + """Test mocked dbt ls call with all arguments.""" op = DbtLsOperator( task_id="dbt_task", project_dir="/path/to/project/", @@ -19,7 +20,9 @@ def test_dbt_ls_mocked_all_args(): exclude=["/path/to/data/to/exclude.sql"], selector="a-selector", dbt_output="json", + output_keys=["a-key", "another-key"], ) + args = [ "ls", "--project-dir", @@ -45,6 +48,9 @@ def test_dbt_ls_mocked_all_args(): "a-selector", "--output", "json", + "--output-keys", + "a-key", + "another-key", ] with patch.object(DbtLsOperator, "run_dbt_command") as mock: diff --git a/tests/test_dbt_run.py b/tests/test_dbt_run.py index acf6b7f0..805559dd 100644 --- a/tests/test_dbt_run.py +++ b/tests/test_dbt_run.py @@ -3,6 +3,8 @@ import pytest from dbt.contracts.results import RunStatus +from dbt.version import __version__ as DBT_VERSION +from packaging.version import parse from airflow import AirflowException from airflow_dbt_python.operators.dbt import DbtRunOperator @@ -16,6 +18,9 @@ condition, reason="S3Hook not available, consider installing amazon extras" ) +DBT_VERSION = parse(DBT_VERSION) +IS_DBT_VERSION_LESS_THAN_0_21 = DBT_VERSION.minor < 21 and DBT_VERSION.major == 0 + def test_dbt_run_mocked_all_args(): """Test mocked dbt run call with all arguments.""" @@ -36,6 +41,12 @@ def test_dbt_run_mocked_all_args(): selector="a-selector", state="/path/to/state/", ) + + if IS_DBT_VERSION_LESS_THAN_0_21: + SELECTION_KEY = "models" + else: + SELECTION_KEY = "select" + args = [ "run", "--project-dir", @@ -51,7 +62,7 @@ def test_dbt_run_mocked_all_args(): "--log-cache-events", "--bypass-cache", "--full-refresh", - "--models", + f"--{SELECTION_KEY}", "/path/to/model.sql", "+/another/model.sql+2", "--fail-fast", @@ -72,6 +83,7 @@ def test_dbt_run_mocked_all_args(): def 
test_dbt_run_mocked_default(): + """Test mocked dbt run call with default arguments.""" op = DbtRunOperator( task_id="dbt_task", do_xcom_push=False, @@ -153,24 +165,6 @@ def test_dbt_run_models_full_refresh(profiles_file, dbt_project_file, model_file assert isinstance(json.dumps(execution_results), str) -BROKEN_SQL = """ -SELECT - field1 AS field1 -FROM - non_existent_table -WHERE - field1 > 1 -""" - - -@pytest.fixture -def broken_file(dbt_project_dir): - d = dbt_project_dir / "models" - m = d / "broken.sql" - m.write_text(BROKEN_SQL) - return m - - def test_dbt_run_fails_with_malformed_sql(profiles_file, dbt_project_file, broken_file): op = DbtRunOperator( task_id="dbt_task", @@ -284,3 +278,31 @@ def test_dbt_run_models_with_project_from_s3( run_result = execution_results["results"][0] assert run_result["status"] == RunStatus.Success + + +def test_dbt_run_uses_correct_argument_according_to_version(): + """Test if dbt run operator sets the proper attribute based on dbt version.""" + op = DbtRunOperator( + task_id="dbt_task", + project_dir="/path/to/project/", + profiles_dir="/path/to/profiles/", + profile="dbt-profile", + target="dbt-target", + vars={"target": "override"}, + log_cache_events=True, + bypass_cache=True, + full_refresh=True, + models=["/path/to/model.sql", "+/another/model.sql+2"], + fail_fast=True, + threads=3, + exclude=["/path/to/model/to/exclude.sql"], + selector="a-selector", + state="/path/to/state/", + ) + + if IS_DBT_VERSION_LESS_THAN_0_21: + assert op.models == ["/path/to/model.sql", "+/another/model.sql+2"] + assert getattr(op, "select", None) is None + else: + assert op.select == ["/path/to/model.sql", "+/another/model.sql+2"] + assert getattr(op, "models", None) is None diff --git a/tests/test_dbt_source.py b/tests/test_dbt_source.py index 9896a3f4..3639b3d4 100644 --- a/tests/test_dbt_source.py +++ b/tests/test_dbt_source.py @@ -1,10 +1,17 @@ from pathlib import Path from unittest.mock import patch +from dbt.version import __version__ as 
DBT_VERSION +from packaging.version import parse + from airflow_dbt_python.operators.dbt import DbtSourceOperator +DBT_VERSION = parse(DBT_VERSION) +IS_DBT_VERSION_LESS_THAN_0_21 = DBT_VERSION.minor < 21 and DBT_VERSION.major == 0 + def test_dbt_source_mocked_all_args(): + """Test mocked dbt source call with all arguments.""" op = DbtSourceOperator( task_id="dbt_task", subcommand="freshness", @@ -52,13 +59,17 @@ def test_dbt_source_mocked_all_args(): def test_dbt_source_mocked_default(): + """Test mocked dbt source call with default arguments.""" op = DbtSourceOperator( task_id="dbt_task", ) assert op.command == "source" - args = ["source", "snapshot-freshness"] + if IS_DBT_VERSION_LESS_THAN_0_21: + args = ["source", "snapshot-freshness"] + else: + args = ["source", "freshness"] with patch.object(DbtSourceOperator, "run_dbt_command") as mock: mock.return_value = ([], True) @@ -69,6 +80,7 @@ def test_dbt_source_mocked_default(): def test_dbt_source_basic(profiles_file, dbt_project_file): + """Test the execution of a dbt source basic operator.""" op = DbtSourceOperator( task_id="dbt_task", project_dir=dbt_project_file.parent, @@ -85,6 +97,7 @@ def test_dbt_source_basic(profiles_file, dbt_project_file): def test_dbt_source_different_output(profiles_file, dbt_project_file): + """Test dbt source operator execution with different output.""" new_sources = Path(dbt_project_file.parent) / "target/new_sources.json" if new_sources.exists(): new_sources.unlink() diff --git a/tests/test_dbt_test.py b/tests/test_dbt_test.py index 5939a818..dbad6bb6 100644 --- a/tests/test_dbt_test.py +++ b/tests/test_dbt_test.py @@ -2,6 +2,8 @@ import pytest from dbt.contracts.results import TestStatus +from dbt.version import __version__ as DBT_VERSION +from packaging.version import parse from airflow_dbt_python.operators.dbt import DbtTestOperator @@ -14,8 +16,12 @@ condition, reason="S3Hook not available, consider installing amazon extras" ) +DBT_VERSION = parse(DBT_VERSION) 
+IS_DBT_VERSION_LESS_THAN_0_21 = DBT_VERSION.minor < 21 and DBT_VERSION.major == 0 + + def test_dbt_test_mocked_all_args(): + """Test mocked dbt test call with all arguments.""" + op = DbtTestOperator( task_id="dbt_task", project_dir="/path/to/project/", @@ -34,6 +40,12 @@ state="/path/to/state/", no_defer=True, ) + + if IS_DBT_VERSION_LESS_THAN_0_21: + SELECTION_KEY = "models" + else: + SELECTION_KEY = "select" + args = [ "test", "--project-dir", @@ -50,7 +62,7 @@ "--bypass-cache", "--data", "--schema", - "--models", + f"--{SELECTION_KEY}", "/path/to/models", "--threads", "2", @@ -272,3 +284,32 @@ results = op.execute({}) for test_result in results["results"]: assert test_result["status"] == TestStatus.Pass + + +def test_dbt_test_uses_correct_argument_according_to_version(): + """Test if dbt test operator sets the proper attribute based on dbt version.""" + op = DbtTestOperator( + task_id="dbt_task", + project_dir="/path/to/project/", + profiles_dir="/path/to/profiles/", + profile="dbt-profile", + target="dbt-target", + vars={"target": "override"}, + log_cache_events=True, + bypass_cache=True, + data=True, + schema=True, + models=["/path/to/models"], + threads=2, + exclude=["/path/to/data/to/exclude.sql"], + selector="a-selector", + state="/path/to/state/", + no_defer=True, + ) + + if IS_DBT_VERSION_LESS_THAN_0_21: + assert op.models == ["/path/to/models"] + assert getattr(op, "select", None) is None + else: + assert op.select == ["/path/to/models"] + assert getattr(op, "models", None) is None