In [None]:
import json
import collections
import os
# os.chdir('/Users/jaitoor/Documents/dev/aero/aero/dbt/dbt-snowflake-usage/')

# Relative path to the dbt project directory (resolved against the current
# working directory when the files are opened).
project_path = "../../../aero/dbt/dbt-snowflake-usage/"
# manifest.json inside the project's target/ directory; every later cell that
# opens the manifest reads it through this variable.
manifest_path = os.path.join(project_path, 'target/manifest.json')


In [None]:
import json
import collections
import os

def get_sql_file_path(node):
    """Build the path of the SQL file for ``node`` inside the 'models' directory.

    The layout is hard-coded: the file is expected at ``models/<node>.sql``.
    Adjust this helper if the project keeps its SQL files elsewhere.
    """
    sql_file_name = node + '.sql'
    models_dir = 'models'
    return os.path.join(models_dir, sql_file_name)

def order_nodes_by_dependency(nodes):
    """Return the ids in ``nodes`` ordered so dependencies come first.

    Args:
        nodes: Mapping of node id -> node dict. Dependencies are read from
            ``node['depends_on']['nodes']``; ids listed there that are not
            keys of ``nodes`` are ignored.

    Returns:
        A list of every key of ``nodes``; a node only appears after all of
        its dependencies that are themselves keys of ``nodes``.

    Raises:
        ValueError: if the remaining nodes depend on each other in a cycle
            (including a node that lists itself), so no order exists.

    The input mapping is left untouched: resolution happens on a private
    copy of the dependency lists.
    """
    # node id -> dependencies that have not been emitted yet.
    pending = {
        node_id: list((node.get('depends_on') or {}).get('nodes') or [])
        for node_id, node in nodes.items()
    }

    ordered_nodes = []
    while pending:
        progressed = False
        for node_id in list(pending):
            unresolved = pending[node_id]
            if unresolved:
                # Drop dependencies that were already emitted (or were never
                # part of `nodes`); the node itself is emitted on a later pass
                # once this list is empty.
                still_pending = [dep for dep in unresolved if dep in pending]
                if len(still_pending) != len(unresolved):
                    progressed = True
                pending[node_id] = still_pending
            else:
                ordered_nodes.append(node_id)
                del pending[node_id]
                progressed = True
        if not progressed:
            # A full pass changed nothing, so every later pass would be
            # identical: the remaining nodes form a dependency cycle.
            raise ValueError(
                f"Cyclic dependency among nodes: {sorted(pending)}"
            )
    return ordered_nodes

def order_nodes():
    """Load the manifest at ``manifest_path`` and return its node ids in dependency order.

    The 'nodes' mapping of the manifest (an empty dict when the key is
    absent) is handed to ``order_nodes_by_dependency`` and that function's
    result is returned unchanged.
    """
    with open(manifest_path, 'r') as manifest_handle:
        manifest_data = json.load(manifest_handle)

    manifest_nodes = manifest_data.get('nodes', {})
    return order_nodes_by_dependency(manifest_nodes)

# Load the manifest for lookups. order_nodes() reads the file again on its own,
# so `nodes` here is an independent, complete copy of the node dictionary.
with open(manifest_path, 'r') as file:
    data = json.load(file)

node_list = order_nodes()
nodes = data['nodes']

# node id -> list of SQL statements (pre-hooks, CREATE ... AS, post-hooks),
# in the order they have to run.
node_sql = {}
for node in node_list:
    node_data = nodes[node]
    create_type = node_data['config']['materialized']

    # Only nodes that carry compiled SQL and are materialized as a view or a
    # table can be turned into a CREATE ... AS statement; everything else
    # gets no entry in node_sql.
    if 'compiled_code' not in node_data or create_type not in ['view', 'table', 'incremental']:
        continue

    sql_list = []
    table_name = node_data['name']
    schema_name = node_data['schema']

    if create_type == 'view':
        create_string = f'CREATE OR REPLACE VIEW {schema_name}.{table_name} AS '
    else:
        # 'table' and 'incremental' are both created as a full transient table.
        create_string = f'CREATE OR REPLACE TRANSIENT TABLE {schema_name}.{table_name} AS '

    unrendered_config = node_data.get('unrendered_config') or {}

    if 'pre-hook' in unrendered_config:
        pre_hooks = unrendered_config['pre-hook']
        if isinstance(pre_hooks, list):
            for hook in pre_hooks:
                sql_list.append(f"""--{table_name} pre-hook\n  {hook} ;\n""")
        else:
            sql_list.append(f"""--{table_name} pre-hook\n {pre_hooks} ;\n""")

    sql_list.append(f"""--{table_name} compiled_code \n {create_string} {node_data['compiled_code']} ;\n""")

    if 'post-hook' in unrendered_config:
        post_hooks = unrendered_config['post-hook']
        # A post-hook may be a single statement or a list of statements, just
        # like a pre-hook; emit one entry per statement instead of formatting
        # the Python list itself into the SQL text.
        if not isinstance(post_hooks, list):
            post_hooks = [post_hooks]
        for hook in post_hooks:
            sql_list.append(f"""--{table_name} post-hook\n {hook} ;\n""")

    node_sql[node] = sql_list
        

def format_query(query):
    """Return ``query`` with each backslash-escaped single quote turned into two single quotes."""
    backslash_quote = "\\'"
    doubled_quote = "''"
    # Splitting on the escaped form and re-joining with the doubled form
    # substitutes every occurrence.
    return doubled_quote.join(query.split(backslash_quote))

# combined_sql
# with open('data/queries.py', 'w') as file:
#     file.write('def get_native_app_query_list():\n\treturn [')
#     for query in sql_list:
#         file.write('"""\n' + format_query(query) + '""",\n')
#     file.write(']\n')

In [None]:
import json
import networkx as nx

# Read the dbt manifest file
with open(manifest_path) as f:
    manifest = json.load(f)

# Directed graph with one edge per dependency, pointing from the dependency
# to the model that needs it.
graph = nx.DiGraph()

# First pass: register every model as a node. Doing this before any edge is
# added keeps the node insertion order equal to the manifest order.
for model_name, model_data in manifest['nodes'].items():
    if model_data['resource_type'] == 'model':
        graph.add_node(model_name)

# Second pass: add an edge for each dependency of a model.
for model_name, model_data in manifest['nodes'].items():
    if model_data['resource_type'] != 'model':
        continue
    for dep in model_data.get('depends_on', {}).get('nodes', []):
        # Skip source dependencies. The check is on the "source." id prefix:
        # a plain substring test would also drop models whose id merely
        # contains the word "source".
        if not dep.startswith("source."):
            graph.add_edge(dep, model_name)

In [None]:
# Node whose upstream lineage should be extracted
target_node = "model.dbt_snowflake_usage.warehouse_era"

# Every node the target depends on, directly or transitively, together with
# the target itself
ancestors = nx.ancestors(graph, target_node) | {target_node}

# Restrict the graph to that lineage
subgraph = graph.subgraph(ancestors)

In [None]:
# nodes['model.dbt_snowflake_usage.test_results_60d']

In [None]:
# Dependency-first ordering of the lineage subgraph; the bare name on the
# last line displays the list as the cell output.
execution_order = [*nx.topological_sort(subgraph)]
execution_order

In [None]:
snowflake_tasks = []

def get_task_name(node):
    """Derive a task name from a node id: 'task_' followed by the id with dots replaced by underscores."""
    id_parts = node.split('.')
    return '_'.join(['task', *id_parts])
for node in execution_order:
    print(node)
    # Graph nodes that never produced SQL (anything not materialized as a
    # view/table/incremental, or without compiled code) have no entry in
    # node_sql; they get no task instead of raising KeyError.
    sql_statements = node_sql.get(node, [])
    if not sql_statements:
        continue

    # Only wait on predecessors that get a task of their own; naming a task
    # that is never created would make the CREATE TASK statement invalid.
    # NOTE(review): ordering implied *through* a skipped predecessor is not
    # carried over to its own predecessors - confirm this is acceptable.
    dependencies = [dep for dep in subgraph.predecessors(node) if node_sql.get(dep)]
    print(dependencies)
    # Generate "after" parameter value
    if len(dependencies) > 0:
        after_parameter = 'AFTER ' + ', '.join(get_task_name(dep) for dep in dependencies)
    else:
        after_parameter = ''

    final_task_name = get_task_name(node)
    last_index = len(sql_statements) - 1
    for index, sql_statement in enumerate(sql_statements):
        sql_statement_without_schema = sql_statement.replace("SNOWFLAKE_USAGE.SANDBOX.","").replace(";","").replace("SANDBOX.","")
        # Each statement of a node needs its own task name: reusing one name
        # would make every CREATE OR REPLACE overwrite the previous task and
        # the chained AFTER clause point at the task itself. The last
        # statement keeps the node's canonical name, so tasks of downstream
        # nodes (which reference get_task_name(dep)) wait for the whole chain.
        if index == last_index:
            task_name = final_task_name
        else:
            task_name = f"{final_task_name}_{index}"
        snowflake_task = f"""
            CREATE OR REPLACE TASK {task_name} 
            WAREHOUSE = DEMO
            --STATEMENT_EXECUTION_TIMEOUT_IN_SECONDS = <timeout> 
            {after_parameter} 
            AS\n{sql_statement_without_schema};
            """
        # The next statement of this node runs after the one just emitted.
        after_parameter = f"AFTER {task_name}"
        snowflake_tasks.append(snowflake_task)

In [None]:
# Emit all generated task statements back to back, with no separator between them
print(*snowflake_tasks, sep="")