In [None]:
%%bash
# Install Apache Airflow pinned against the official constraints file that
# matches this interpreter's major.minor version.

# Airflow needs a home. `~/airflow` is the default, but you can put it
# somewhere else if you prefer (optional).
export AIRFLOW_HOME=~/airflow

AIRFLOW_VERSION=2.2.2
# major.minor of the interpreter, e.g. "3.8" from "Python 3.8.12".
PYTHON_VERSION="$(python --version | cut -d " " -f 2 | cut -d "." -f 1-2)"
CONSTRAINT_URL="https://raw.githubusercontent.com/apache/airflow/constraints-${AIRFLOW_VERSION}/constraints-${PYTHON_VERSION}.txt"
# For example: https://raw.githubusercontent.com/apache/airflow/constraints-2.2.2/constraints-3.8.txt

# Use `python -m pip` rather than bare `pip` so the packages land in the same
# interpreter whose version selected the constraints file above.
python -m pip install --upgrade pip
python -m pip install "apache-airflow==${AIRFLOW_VERSION}" --constraint "${CONSTRAINT_URL}" --user

# Optional: `airflow standalone` initialises the database, makes a user and
# starts all components. Then visit localhost:8080 in the browser, log in with
# the admin account details shown on the terminal, and enable the
# example_bash_operator DAG on the home page.
# airflow standalone

In [2]:
%%bash
# Create the role and schema the dbt project uses, then compile the project.
# Both statements are idempotent so this cell can be re-run safely: plain
# CREATE ROLE / CREATE SCHEMA fail if the object already exists.
# NOTE(review): credentials are hard-coded for the local course database.
env PGPASSWORD=corise psql -U corise -d dbt -c "DO \$\$ BEGIN IF NOT EXISTS (SELECT 1 FROM pg_roles WHERE rolname = 'reporting') THEN CREATE ROLE reporting; END IF; END \$\$;"
env PGPASSWORD=corise psql -U corise -d dbt -c 'CREATE SCHEMA IF NOT EXISTS dbt_ramnath_v'
dbt compile --project-dir=/workspace/dbt-explore/dbt-greenery

CREATE ROLE
CREATE SCHEMA
Running with dbt=0.21.0
Found 29 models, 32 tests, 4 snapshots, 0 analyses, 560 macros, 1 operation, 0 seed files, 8 sources, 1 exposure

17:26:23 | Concurrency: 4 threads (target='dev')
17:26:23 | 
17:26:24 | Done.


In [4]:
%%writefile ~/airflow/dags/dbt-greenery.py
import logging
from copy import copy
from logging import Logger
from typing import Dict, List, Optional
from airflow_dbt.operators.dbt_operator import (
    DbtSeedOperator,
    DbtSnapshotOperator,
    DbtRunOperator,
    DbtTestOperator,
)

from airflow import DAG
from airflow.models import Variable, BaseOperator
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.task_group import TaskGroup

# Module-level logger; logging.getLogger(__name__) returns this same logger
# object wherever it is called in this module.
logger = logging.getLogger(__name__)

# UI colour for dbt run tasks in the Airflow web UI (#f5f5dc is CSS "beige").
DbtRunOperator.ui_color = '#f5f5dc'


class DbtNode:
    """A node of the dbt manifest's ``child_map`` graph.

    Args:
        full_name: Manifest id, e.g. ``model.<project>.<name>``.
        children: Ids of the nodes listed as this node's children in ``child_map``.
        config: The node's ``config`` dict from ``manifest['nodes']``; ``None``
            when the id has no entry there.
    """

    def __init__(self, full_name: str, children: List[str], config: Optional[dict]):
        self.full_name = full_name
        self.children = children
        self.is_model = self.full_name.startswith('model')
        # Last dotted component of the id, e.g. "stg_users".
        self.name = self.full_name.split('.')[-1]
        # Tolerate a missing config (None) or a config without the key instead
        # of raising TypeError / KeyError; such nodes are treated as not persisted.
        materialized = (config or {}).get('materialized')
        self.is_persisted = self.is_model and materialized in ('table', 'incremental', 'view')


class DbtTaskGenerator:
    """Builds one Airflow task per persisted dbt model and wires their dependencies.

    Dependencies come from the manifest's ``child_map``. Nodes that are not
    persisted (e.g. ephemeral models or tests) get no task. Instead, a persisted
    node's children are replaced by the nearest persisted descendants reachable
    through them, so the task graph skips over the non-persisted nodes.
    """

    def __init__(
        self, dag: DAG, manifest: dict
    ) -> None:
        """
        Args:
            dag: DAG the generated tasks are attached to.
            manifest: Parsed dbt manifest; its ``nodes`` and ``child_map`` are read.
        """
        self.dag: DAG = dag
        self.manifest = manifest
        self.logger: Logger = logging.getLogger(__name__)
        self.persisted_node_map: Dict[str, DbtNode] = self._get_persisted_parent_to_child_map()

    def _get_persisted_parent_to_child_map(self) -> Dict[str, DbtNode]:
        """Map each persisted node id to a DbtNode whose children are persisted ids only."""
        node_info = self.manifest['nodes']
        parent_to_child_map = self.manifest['child_map']

        all_nodes: Dict[str, DbtNode] = {
            node_name: DbtNode(
                full_name=node_name,
                children=children,
                config=node_info.get(node_name, {}).get('config')
            )
            for node_name, children in parent_to_child_map.items()
        }

        # Memo of resolved descendants for non-persisted nodes, shared across all
        # parents so a subgraph reachable from several parents is walked only once.
        resolved: Dict[str, List[str]] = {}
        persisted_nodes: Dict[str, DbtNode] = {}
        for node_name, node in all_nodes.items():
            if not node.is_persisted:
                continue
            pruned = copy(node)
            pruned.children = self._get_persisted_children(node, all_nodes, resolved)
            persisted_nodes[node_name] = pruned

        return persisted_nodes

    @classmethod
    def _get_persisted_children(
        cls,
        node: DbtNode,
        all_nodes: Dict[str, DbtNode],
        _cache: Optional[Dict[str, List[str]]] = None,
    ) -> List[str]:
        """Return the nearest persisted descendants of ``node``.

        Persisted children are kept as-is. Each non-persisted child is replaced,
        recursively, by its own persisted descendants. Ids are de-duplicated in
        first-seen order, because one descendant can be reachable through
        several non-persisted paths and would otherwise produce duplicate edges.

        Args:
            node: Node whose children are resolved.
            all_nodes: Every child_map node, keyed by id.
            _cache: Optional memo of already-resolved non-persisted node ids.
        """
        if _cache is None:
            _cache = {}
        # dict used as an insertion-ordered set.
        persisted_children: Dict[str, None] = {}
        for child_key in node.children:
            child_node = all_nodes[child_key]
            if child_node.is_persisted:
                persisted_children[child_key] = None
                continue
            if child_key not in _cache:
                _cache[child_key] = cls._get_persisted_children(child_node, all_nodes, _cache)
            persisted_children.update(dict.fromkeys(_cache[child_key]))

        return list(persisted_children)

    def add_all_tasks(self) -> None:
        """Create tasks for every persisted model and wire the dependencies between them."""
        # Hand copies to _add_tasks so the generator's own node map is never mutated.
        nodes_to_add: Dict[str, DbtNode] = {}
        for node_name, node in self.persisted_node_map.items():
            included_node = copy(node)
            included_node.children = list(node.children)
            nodes_to_add[node_name] = included_node

        self._add_tasks(nodes_to_add)

    def _add_tasks(self, nodes_to_add: Dict[str, DbtNode]) -> None:
        """Create one run task per model in ``nodes_to_add`` and link parents to children."""
        dbt_model_tasks = self._create_dbt_run_model_tasks(nodes_to_add)
        self.logger.info(f'{len(dbt_model_tasks)} tasks created for models')

        for parent_node in nodes_to_add.values():
            if parent_node.is_model:
                self._add_model_dependencies(dbt_model_tasks, parent_node)

    def _create_dbt_run_model_tasks(self, nodes_to_add: Dict[str, DbtNode]) -> Dict[str, BaseOperator]:
        """Return ``{model id: run task}`` for every model in ``nodes_to_add``."""
        dbt_model_tasks: Dict[str, BaseOperator] = {
            node.full_name: self._create_dbt_run_task(node.name)
            for node in nodes_to_add.values()
            if node.is_model
        }
        return dbt_model_tasks

    def _create_dbt_run_task(self, model_name: str) -> BaseOperator:
        """Create a DbtRunOperator task (task_id = model name) that runs just that model.

        See https://docs.getdbt.com/docs/running-a-dbt-project/running-dbt-in-production#using-airflow
        """
        return DbtRunOperator(
            dag=self.dag,
            task_id=model_name,
            dir='/workspace/dbt-explore/dbt-greenery',
            models=model_name,
            verbose=True
        )

    @staticmethod
    def _add_model_dependencies(dbt_model_tasks: Dict[str, BaseOperator], parent_node: DbtNode) -> None:
        """Set ``parent >> child`` for each child of ``parent_node`` that has a task."""
        parent_task = dbt_model_tasks[parent_node.full_name]
        for child_key in parent_node.children:
            child = dbt_model_tasks.get(child_key)
            # Compare against None explicitly instead of relying on operator truthiness.
            if child is not None:
                parent_task >> child

from datetime import datetime
from airflow import DAG
import json
import os

# Build the DAG from the compiled dbt manifest stored next to this file.
# NOTE(review): `dbt compile` writes target/manifest.json inside the dbt
# project; confirm it is copied into this folder before the scheduler parses it.
CUR_DIR = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(CUR_DIR, "manifest.json"), encoding="utf-8") as file:
    manifest = json.load(file)

dag = DAG(
    dag_id="dbt_connected_task_creator_test_dag",
    start_date=datetime(2021, 12, 6),
    schedule_interval="0 1 * * *",
    # Without this, Airflow (catchup_by_default=True) schedules one run for
    # every missed interval since start_date, each running every model.
    catchup=False,
)
dbt_task_generator = DbtTaskGenerator(dag, manifest)
dbt_task_generator.add_all_tasks()

Writing /home/gitpod/airflow/dags/dbt-greenery.py


In [5]:
!cp dags/foobar.py ~/airflow/dags

In [79]:
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List
import json

class DbtNode:
    """One entry of ``manifest['nodes']`` together with its graph neighbours.

    Attributes:
        node_info: The raw manifest entry for this node.
        node_name: Manifest id, e.g. ``model.<project>.<name>``.
        parents: Ids listed under ``depends_on.nodes`` (empty if absent).
        children: Ids that depend on this node, as passed in by the caller.
        materialized: ``config.materialized``, or ``''`` if absent.
        is_model: Whether the id starts with ``model``.
        is_persisted: A model materialized as table, incremental or view.
    """

    def __init__(self, node_name: str, node_info: Dict, node_children: List):
        depends_on = node_info.get('depends_on', {})
        config = node_info.get('config', {})

        self.node_info = node_info
        self.node_name = node_name
        self.parents = depends_on.get('nodes', [])
        self.children = node_children
        self.materialized = config.get('materialized', '')
        self.is_model = node_name.startswith('model')
        if self.is_model:
            self.is_persisted = self.materialized in ('table', 'incremental', 'view')
        else:
            self.is_persisted = False

    def __repr__(self):
        return '<DbtNode> {} ({})'.format(self.node_name, self.materialized)

@dataclass
class DbtProject:
    """A dbt project loaded from its compiled ``target/manifest.json``.

    Attributes:
        dbt_project_dir: Root directory of the dbt project. Required; the
            ``None`` default is kept only for interface compatibility.
        manifest: The parsed manifest (set in ``__post_init__``).
        nodes: One ``DbtNode`` per entry of ``manifest['nodes']``.
    """
    dbt_project_dir: str = None

    def __post_init__(self):
        if self.dbt_project_dir is None:
            # Same exception type Path(None) would raise, with a clearer message.
            raise TypeError('DbtProject requires dbt_project_dir (got None)')
        self.manifest = self.__load_manifest()
        self.nodes = self.__load_nodes()

    def __load_manifest(self):
        """Read and parse ``<dbt_project_dir>/target/manifest.json``."""
        manifest_path = Path(self.dbt_project_dir) / 'target' / 'manifest.json'
        # Decode explicitly as UTF-8 instead of relying on the platform default.
        return json.loads(manifest_path.read_text(encoding='utf-8'))

    def __load_nodes(self):
        """Build a DbtNode for every manifest node, attaching its child_map entry."""
        # Tolerate a manifest without child_map, as is already done for 'nodes'.
        child_map = self.manifest.get('child_map', {})
        return [
            DbtNode(node_name, node_info, child_map.get(node_name, []))
            for node_name, node_info in self.manifest.get('nodes', {}).items()
        ]
            


# Load the compiled manifest of the greenery dbt project.
DBT_PROJECT_DIR = '/workspace/dbt-explore/dbt-greenery'
dbt = DbtProject(DBT_PROJECT_DIR)
