In [None]:
from sqlglot.optimizer.optimizer import optimize
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
import json
import sqlglot

In [None]:
# Postgres query under test: extracts JSON payload fields from the raw events table.
# The closing triple quote is required; without it the `schema_map` assignment
# below is swallowed into the string literal and the cell fails to parse.
sql = """
SELECT (re.payload ->> 'customer_id')::INTEGER AS customer_id, row_to_json((SELECT e.* FROM (SELECT re.event_id, re.payload ->> 'device_type' AS device_type) AS e)) AS event_data, re.payload ->> 'event_name' AS event_type, row_to_json(re.*) AS raw_event_json FROM my_project.raw.raw_events AS re
"""

# Maps a fully qualified `catalog.schema.table` name to the `source.`/`model.`
# prefixed identifier of the node that produces it. Looked up with a
# lower-cased table name when resolving lineage sources.
schema_map = {
    'my_project.raw.raw_customers': 'source.my_project.raw.raw_customers',
    'my_project.raw.raw_events': 'source.my_project.raw.raw_events',
    'my_project.staging.stg_customers': 'model.my_project.stg_customers',
    'my_project.staging.stg_events': 'model.my_project.stg_events',
    'my_project.analytics.customers': 'model.my_project.customers',
    'my_project.analytics.customer_summary': 'model.my_project.customer_summary',
}


In [None]:
# Parse the query with the Postgres dialect, then qualify and optimize it.
# The dialect must be passed as the string "postgres"; a bare `postgres`
# is an undefined name and raises NameError.
parsed_sql = sqlglot.parse_one(sql, read="postgres")
# NOTE(review): confirm that `.qualify(...)` on the parsed expression performs the
# intended table/column qualification with `schema_map` — TODO verify against sqlglot docs.
qualified_sql = parsed_sql.qualify(schema=schema_map, dialect="postgres", quote_identifiers=False)
sql_query = optimize(qualified_sql)

In [None]:
# Render the optimized expression to SQL text.
# Guarded so that re-running this cell is a no-op: once `sql_query` is already
# a string, calling `.sql()` on it again would raise AttributeError.
if not isinstance(sql_query, str):
    sql_query = sql_query.sql()

In [None]:
from IPython.display import display, Code

# Pretty-print the generated SQL and render it with SQL syntax highlighting.
formatted_sql = sqlglot.transpile(sql_query, pretty=True)[0]
# `language` is the highlighter name and must be the literal "sql",
# not the `sql` variable (which holds the query text).
display(Code(formatted_sql, language="sql"))


In [None]:
## PRINT QUERY

# Hand-pasted SQL text, pretty-printed for inspection.
# The first projection is quoted as "re"."event_id" so that every double quote
# in the statement is balanced, matching the other identifiers in this literal.
s = 'WITH e AS (SELECT "re"."event_id" AS "event_id", JSON_EXTRACT_SCALAR("re"."payload", \'$.device_type\') AS "device_type") SELECT CAST((JSON_EXTRACT_SCALAR("re"."payload", \'$.customer_id\')) AS INT) AS "customer_id", ROW_TO_JSON((SELECT "e"."event_id" AS "event_id", "e"."device_type" AS "device_type" FROM "e" AS "e")) AS "event_data", JSON_EXTRACT_SCALAR("re"."payload", \'$.event_name\') AS "event_type", ROW_TO_JSON("re".*) AS "raw_event_json" FROM "my_project"."raw"."raw_events" AS "re"'

formatted_s = sqlglot.transpile(s, pretty=True)[0]
display(Code(formatted_s, language="sql"))


## WALK test

In [None]:

import sqlglot.lineage as lineage
from sqlglot import exp

# Step 1: parse the Postgres query text held in `sql`.
parsed_sql = sqlglot.parse_one(sql, read="postgres")

# Step 2: qualify the parsed tree, passing `schema_map` as the schema.
qualified_sql = parsed_sql.qualify(
    schema=schema_map,
    dialect="postgres",
    quote_identifiers=False,
)

# Step 3: optimize, then build the lineage node for the `event_type` output column.
sql_query = optimize(qualified_sql)
lineage_node = lineage.lineage(
    sql=sql_query,
    column='event_type',
    dialect="postgres",
)


## Lineage Find test

In [241]:
# Query whose lineage is inspected in the cells that follow.
sql = """
SELECT (re.payload ->> 'customer_id')::INTEGER AS customer_id, row_to_json((SELECT e.* FROM (SELECT re.event_id, re.payload ->> 'device_type' AS device_type) AS e)) AS event_data, re.payload ->> 'event_name' AS event_type, row_to_json(re.*) AS raw_event_json FROM my_project.raw.raw_events AS re
"""

# Parse -> qualify -> optimize; `sql_query` ends up holding the optimized expression.
parsed_sql = sqlglot.parse_one(sql, read="postgres")
qualified_sql = parsed_sql.qualify(
    schema=schema_map,
    dialect="postgres",
    quote_identifiers=False,
)
sql_query = optimize(qualified_sql)

from IPython.display import display, Code

# Render the optimized query as pretty-printed, syntax-highlighted SQL.
formatted_s = sqlglot.transpile(sql_query.sql(), pretty=True)[0]
display(Code(formatted_s, language="sql"))



In [None]:
# Build {table alias: "catalog.db.table"} from the optimized query.
expr = sql_query

table_alias_map = {}
for table in expr.find_all(exp.Table):
    # Only tables carrying both a catalog and a db are recorded.
    if not (table.catalog and table.db):
        continue
    table_alias_map[table.alias] = f"{table.catalog}.{table.db}.{table.name}"

table_alias_map

In [None]:
# Build {table alias: "catalog.db.table"} from the lineage node's source expression.
expr = lineage_node.source

table_alias_map = {}
for table in expr.find_all(exp.Table):
    # Only tables carrying both a catalog and a db are recorded.
    if not (table.catalog and table.db):
        continue
    table_alias_map[table.alias] = f"{table.catalog}.{table.db}.{table.name}"

table_alias_map

In [None]:
# Inspect the lineage node's source expression: look up an `exp.Table` node inside it.
lineage_node.source.find(exp.Table)

In [None]:
# Debugging aid: list the attributes of the `exp.Table` node found in the
# expression of the first downstream lineage node.
# NOTE(review): indexes `downstream[0]` unconditionally — assumes at least one downstream node; confirm.
dir(lineage_node.downstream[0].expression.find(exp.Table))


In [246]:
# Pretty-print the table-name -> node-id mapping as 4-space-indented JSON.
# NOTE(review): this literal repeats the entries of `schema_map` from an earlier cell.
print(
    json.dumps(
        {
            'my_project.raw.raw_customers': 'source.my_project.raw.raw_customers',
            'my_project.raw.raw_events': 'source.my_project.raw.raw_events',
            'my_project.staging.stg_customers': 'model.my_project.stg_customers',
            'my_project.staging.stg_events': 'model.my_project.stg_events',
            'my_project.analytics.customers': 'model.my_project.customers',
            'my_project.analytics.customer_summary': 'model.my_project.customer_summary',
        },
        indent=4,
    )
)

{
    "my_project.raw.raw_customers": "source.my_project.raw.raw_customers",
    "my_project.raw.raw_events": "source.my_project.raw.raw_events",
    "my_project.staging.stg_customers": "model.my_project.stg_customers",
    "my_project.staging.stg_events": "model.my_project.stg_events",
    "my_project.analytics.customers": "model.my_project.customers",
    "my_project.analytics.customer_summary": "model.my_project.customer_summary"
}


In [None]:



def look_for_group_by_expr(parent_node):
    """Collect the ``table.column`` names referenced by the enclosing GROUP BY.

    The GROUP BY clause is read from the select that encloses
    ``parent_node.source`` (``parent_node.source.parent_select``). Each grouping
    expression is unwrapped through its ``.this`` chain until an ``exp.Column``
    is reached; grouping expressions that never reach a column (for example a
    literal ordinal) are skipped instead of raising ``AttributeError``.

    TODO: find a way to get the catalog and db of columns used in group bys.

    Args:
        parent_node: Lineage node whose ``source`` expression exposes
            ``parent_select``.

    Returns:
        set[str]: ``"<table>.<column>"`` strings. Empty when there is no
        enclosing select, no GROUP BY clause, or no grouping expression
        resolves to a column.
    """
    sources = set()

    enclosing_select = parent_node.source.parent_select
    if enclosing_select is None:
        # The source has no enclosing select, so there is no GROUP BY to read.
        return sources

    # A select keeps at most one GROUP BY clause under the "group" arg.
    group_clause = enclosing_select.args.get("group")
    if group_clause is None:
        return sources

    for grouping_expr in group_clause.expressions:
        node = grouping_expr
        # Follow the `.this` chain down to the underlying column; stop as soon
        # as the chain ends (None) or hits a value without `.this` (e.g. a str).
        while node is not None and not isinstance(node, exp.Column):
            node = getattr(node, "this", None)
        if node is None:
            continue
        sources.add(f"{node.table}.{node.name}")
    return sources



# NOTE(review): fragment — the opening `if` branch and the enclosing function are not
# visible in this cell; `parent_node` and `sources` come from that missing context.
elif isinstance(parent_node.expression, exp.Table): # or isinstance(parent_node.source, exp.Table)
        # Final instance: resolve the column name and its fully qualified table.
        # The column name is the last dot-separated part of the node name.
        from_column_name = parent_node.name.split('.')[-1]
        from_catalog = parent_node.expression.catalog ## expression or source; possible to get the table alias from parent_node.expression.alias
        from_schema = parent_node.expression.db
        from_table_name = parent_node.expression.name
        from_full_tablename = f"{from_catalog}.{from_schema}.{from_table_name}"

        ## TODO add columns used in group by to lineage?
        # This assumes every group-by column comes from a single table; change that.
        # group_by_columns = look_for_group_by_expr(parent_node)
        # if(group_by_columns):
        #     sources.update(group_by_columns)

        # Look up the parent id by lower-cased `catalog.schema.table` name.
        # NOTE(review): `.get` yields None for a table missing from `schema_map`,
        # which would add a "None.<column>" entry to `sources` — confirm intended.
        parent_model_id = schema_map.get(from_full_tablename.lower())
        sources.update({f"{parent_model_id}.{from_column_name}"})
