Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions airflow_dbt_python/operators/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,12 @@
from dataclasses import asdict, is_dataclass
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Callable, Iterator, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, TypeVar, Union

from airflow import AirflowException
from airflow.models.baseoperator import BaseOperator
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.version import version
from dbt.contracts.results import RunExecutionResult, agate

from airflow_dbt_python.hooks.dbt import BaseConfig, DbtHook, LogFormat, Output

# apply_defaults is deprecated in version 2 and beyond. This allows us to
# support version 1 and deal with the deprecation warning.
Expand All @@ -30,6 +27,12 @@ def apply_defaults(func: T) -> T:
return func


if TYPE_CHECKING:
from dbt.contracts.results import RunExecutionResult

from airflow_dbt_python.hooks.dbt import BaseConfig


class DbtBaseOperator(BaseOperator):
"""The basic Airflow dbt operator.

Expand Down Expand Up @@ -107,6 +110,8 @@ def __init__(
replace_on_push: bool = False,
**kwargs,
) -> None:
from airflow_dbt_python.hooks.dbt import LogFormat

super().__init__(**kwargs)
self.project_dir = project_dir
self.profiles_dir = profiles_dir
Expand Down Expand Up @@ -314,6 +319,8 @@ def prepare_directory(self, tmp_dir: str):
def dbt_hook(self):
"""Provides an existing DbtHook or creates one."""
if self._dbt_hook is None:
from airflow_dbt_python.hooks.dbt import DbtHook

self._dbt_hook = DbtHook()
return self._dbt_hook

Expand Down Expand Up @@ -611,6 +618,8 @@ def __init__(
indirect_selection: Optional[str] = None,
**kwargs,
) -> None:
from airflow_dbt_python.hooks.dbt import Output

super().__init__(**kwargs)
self.resource_types = resource_types
self.select = select
Expand Down Expand Up @@ -757,6 +766,8 @@ def run_result_factory(data: list[tuple[Any, Any]]):
We need to handle dt.datetime and agate.table.Table.
The rest of the types should already be JSON-serializable.
"""
from dbt.contracts.results import agate

d = {}
for key, val in data:
if isinstance(val, dt.datetime):
Expand Down