diff --git a/airflow_dbt_python/operators/dbt.py b/airflow_dbt_python/operators/dbt.py index 2927270..140fc21 100644 --- a/airflow_dbt_python/operators/dbt.py +++ b/airflow_dbt_python/operators/dbt.py @@ -8,15 +8,12 @@ from dataclasses import asdict, is_dataclass from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any, Callable, Iterator, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, TypeVar, Union from airflow import AirflowException from airflow.models.baseoperator import BaseOperator from airflow.models.xcom import XCOM_RETURN_KEY from airflow.version import version -from dbt.contracts.results import RunExecutionResult, agate - -from airflow_dbt_python.hooks.dbt import BaseConfig, DbtHook, LogFormat, Output # apply_defaults is deprecated in version 2 and beyond. This allows us to # support version 1 and deal with the deprecation warning. @@ -30,6 +27,12 @@ def apply_defaults(func: T) -> T: return func +if TYPE_CHECKING: + from dbt.contracts.results import RunExecutionResult + + from airflow_dbt_python.hooks.dbt import BaseConfig + + class DbtBaseOperator(BaseOperator): """The basic Airflow dbt operator. @@ -107,6 +110,8 @@ def __init__( replace_on_push: bool = False, **kwargs, ) -> None: + from airflow_dbt_python.hooks.dbt import LogFormat + super().__init__(**kwargs) self.project_dir = project_dir self.profiles_dir = profiles_dir @@ -314,6 +319,8 @@ def prepare_directory(self, tmp_dir: str): def dbt_hook(self): """Provides an existing DbtHook or creates one.""" if self._dbt_hook is None: + from airflow_dbt_python.hooks.dbt import DbtHook + self._dbt_hook = DbtHook() return self._dbt_hook @@ -611,6 +618,8 @@ def __init__( indirect_selection: Optional[str] = None, **kwargs, ) -> None: + from airflow_dbt_python.hooks.dbt import Output + super().__init__(**kwargs) self.resource_types = resource_types self.select = select @@ -757,6 +766,8 @@ def run_result_factory(data: list[tuple[Any, Any]]): We need to handle dt.datetime and agate.table.Table. The rest of the types should already be JSON-serializable. """ + from dbt.contracts.results import agate + d = {} for key, val in data: if isinstance(val, dt.datetime):