Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 91 additions & 8 deletions airflow_dbt_python/operators/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,18 @@
from dbt.contracts.results import RunExecutionResult, RunResult, agate
from dbt.logger import log_manager
from dbt.main import initialize_config_values, parse_args, track_run
from dbt.semver import VersionSpecifier
from dbt.version import installed as installed_version

from airflow import AirflowException
from airflow.models.baseoperator import BaseOperator
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.utils.decorators import apply_defaults

IS_DBT_VERSION_LESS_THAN_0_21 = (
int(installed_version.minor) < 21 and int(installed_version.major) == 0
)


class DbtBaseOperator(BaseOperator):
"""The basic Airflow dbt operator.
Expand Down Expand Up @@ -295,6 +301,7 @@ class DbtRunOperator(DbtBaseOperator):
__dbt_args__ = DbtBaseOperator.__dbt_args__ + [
"full_refresh",
"models",
"select",
"fail_fast",
"threads",
"exclude",
Expand All @@ -308,6 +315,7 @@ def __init__(
self,
full_refresh: Optional[bool] = None,
models: Optional[list[str]] = None,
select: Optional[list[str]] = None,
fail_fast: Optional[bool] = None,
threads: Optional[int] = None,
exclude: Optional[list[str]] = None,
Expand All @@ -319,7 +327,6 @@ def __init__(
) -> None:
super().__init__(**kwargs)
self.full_refresh = full_refresh
self.models = models
self.fail_fast = fail_fast
self.threads = threads
self.exclude = exclude
Expand All @@ -328,6 +335,11 @@ def __init__(
self.defer = defer
self.no_defer = no_defer

if IS_DBT_VERSION_LESS_THAN_0_21:
self.models = models or select
else:
self.select = select or models


class DbtSeedOperator(DbtBaseOperator):
"""Executes dbt seed."""
Expand Down Expand Up @@ -375,6 +387,7 @@ class DbtTestOperator(DbtBaseOperator):
"schema",
"fail_fast",
"models",
"select",
"threads",
"exclude",
"selector",
Expand All @@ -388,6 +401,7 @@ def __init__(
data: Optional[bool] = None,
schema: Optional[bool] = None,
models: Optional[list[str]] = None,
select: Optional[list[str]] = None,
fail_fast: Optional[bool] = None,
threads: Optional[int] = None,
exclude: Optional[list[str]] = None,
Expand All @@ -400,7 +414,6 @@ def __init__(
super().__init__(**kwargs)
self.data = data
self.schema = schema
self.models = models
self.fail_fast = fail_fast
self.threads = threads
self.exclude = exclude
Expand All @@ -409,6 +422,11 @@ def __init__(
self.defer = defer
self.no_defer = no_defer

if IS_DBT_VERSION_LESS_THAN_0_21:
self.models = models or select
else:
self.select = select or models


class DbtCompileOperator(DbtBaseOperator):
"""Executes dbt compile."""
Expand All @@ -421,6 +439,7 @@ class DbtCompileOperator(DbtBaseOperator):
"fail_fast",
"threads",
"models",
"select",
"exclude",
"selector",
"state",
Expand All @@ -431,6 +450,7 @@ def __init__(
parse_only: Optional[bool] = None,
full_refresh: Optional[bool] = None,
models: Optional[list[str]] = None,
select: Optional[list[str]] = None,
fail_fast: Optional[bool] = None,
threads: Optional[int] = None,
exclude: Optional[list[str]] = None,
Expand All @@ -441,13 +461,17 @@ def __init__(
super().__init__(**kwargs)
self.parse_only = parse_only
self.full_refresh = full_refresh
self.models = models
self.fail_fast = fail_fast
self.threads = threads
self.exclude = exclude
self.selector = selector
self.state = state

if IS_DBT_VERSION_LESS_THAN_0_21:
self.models = models or select
else:
self.select = select or models


class DbtDepsOperator(DbtBaseOperator):
"""Executes dbt deps."""
Expand Down Expand Up @@ -523,29 +547,29 @@ class DbtLsOperator(DbtBaseOperator):
__dbt_args__ = DbtBaseOperator.__dbt_args__ + [
"resource_type",
"select",
"models",
"exclude",
"selector",
"dbt_output",
"output_keys",
]

def __init__(
self,
resource_type: Optional[list[str]] = None,
select: Optional[list[str]] = None,
models: Optional[list[str]] = None,
exclude: Optional[list[str]] = None,
selector: Optional[str] = None,
dbt_output: Optional[str] = None,
output_keys: Optional[list[str]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.resource_type = resource_type
self.select = select
self.models = models
self.exclude = exclude
self.selector = selector
self.dbt_output = dbt_output
self.output_keys = output_keys


# Convinience alias
Expand Down Expand Up @@ -590,16 +614,19 @@ class DbtSourceOperator(DbtBaseOperator):

__dbt_args__ = DbtBaseOperator.__dbt_args__ + [
"select",
"models",
"threads",
"exclude",
"selector",
"state",
"dbt_output",
]

def __init__(
self,
# Only one subcommand is currently provided
subcommand: str = "snapshot-freshness",
subcommand: str = "freshness"
if not installed_version < VersionSpecifier.from_version_string("0.21.0")
else "snapshot-freshness",
select: Optional[list[str]] = None,
dbt_output: Optional[Union[str, Path]] = None,
threads: Optional[int] = None,
Expand All @@ -610,11 +637,67 @@ def __init__(
) -> None:
super().__init__(positional_args=[subcommand], **kwargs)
self.select = select
self.threads = threads
self.exclude = exclude
self.selector = selector
self.state = state
self.dbt_output = dbt_output


class DbtBuildOperator(DbtBaseOperator):
"""Execute dbt build.

The build command combines the run, test, seed, and snapshot commands into one. The
full Documentation for the dbt build command can be found here:
https://docs.getdbt.com/reference/commands/build.
"""

command = "build"

__dbt_args__ = DbtBaseOperator.__dbt_args__ + [
"full_refresh",
"select",
"fail_fast",
"threads",
"exclude",
"selector",
"state",
"defer",
"no_defer",
"data",
"schema",
"show",
]

def __init__(
self,
full_refresh: Optional[bool] = None,
select: Optional[list[str]] = None,
fail_fast: Optional[bool] = None,
threads: Optional[int] = None,
exclude: Optional[list[str]] = None,
selector: Optional[str] = None,
state: Optional[Union[str, Path]] = None,
defer: Optional[bool] = None,
no_defer: Optional[bool] = None,
data: Optional[bool] = None,
schema: Optional[bool] = None,
show: Optional[bool] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.full_refresh = full_refresh
self.select = select
self.fail_fast = fail_fast
self.threads = threads
self.exclude = exclude
self.selector = selector
self.state = state
self.defer = defer
self.no_defer = no_defer
self.data = data
self.schema = schema
self.show = show


def run_result_factory(data: list[tuple[Any, Any]]):
Expand Down
19 changes: 19 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,22 @@ def s3_bucket(mocked_s3_res):
bucket = "airflow-dbt-test-s3-bucket"
mocked_s3_res.create_bucket(Bucket=bucket)
return bucket


BROKEN_SQL = """
SELECT
field1 AS field1
FROM
non_existent_table
WHERE
field1 > 1
"""


@pytest.fixture
def broken_file(dbt_project_dir):
d = dbt_project_dir / "models"
m = d / "broken.sql"
m.write_text(BROKEN_SQL)
yield m
m.unlink()
Loading