In [80]:
import pandas as pd

!pip install sqlparse --user --upgrade
!pip install sqlglot --user --upgrade

import sqlparse

import sqlglot
import sqlglot.expressions as exp
from sqlglot import parse_one
from sqlglot.optimizer import optimize
from sqlglot import optimizer
from sqlglot.errors import OptimizeError
from sqlglot import lineage



In [81]:
#helper functions
#find all alias from cte and related table and schema
def obtain_table_alias_type(sql):
    
    for alias in parse_one(sql, dialect="redshift").find_all(exp.TableAlias):
        
        print(f"alias => {alias.this.this} | alias_table_type => {type(alias.parent_select)}" )
        
        break
        
#find schema by name
def find_schema_by_table_name(sql, table_name):
    
    for table in parse_one(sql, dialect="redshift").find_all(exp.Table):
        
        if (table_name == table.name):
            return str(table.args['db'])
            
def obtain_list_column_table(sql):
    
    column_l = []
    table_l = []
    alias_l = []
    
    for column in parse_one(sql, dialect="redshift").find_all(exp.Column):
        
        column_l.append(column.name)
        table_l.append(column.table)
        #print(column.key)
        #print(f"Column => {column.name} | DB => {column.table}" )
    
    #in case none of the field has table information, the only table in the sql will be the source table
    if(all(elem == '' for elem in table_l)):
        
        tablename = ''
        
        for table in parse_one(sql, dialect="redshift").find_all(exp.Table):
            
            if(table.name != ''):
                tablename = table.name
    
        for n in range(len(table_l)):
        
            table_l[n] = tablename
    
    return column_l, table_l

#find all alias from cte and related table and schema
def obtain_list_table_alias(sql):
    
    unalias_name = []
    alias_l = []
    related_table_l = []
    alias_type_ = []
    schema_l = []
    
    for alias in parse_one(sql, dialect="redshift").find_all(exp.TableAlias):
        
        table_t = []
        #if it is cte
        if (alias.parent.name == ''):
            alias_type_.append(type(alias.parent.args['this']))
            alias_l.append(alias.this.this)
            column_t, table_t = obtain_list_column_table(str(alias.parent))
            related_table_l.append(list(dict.fromkeys(table_t)))
            unalias_name.append(alias.parent.name)
            
            schema = []
            
            for table in list(dict.fromkeys(table_t)):
                
                schema.append(find_schema_by_table_name(sql, table))
                
            schema_l.append(schema)
            
        #if it is normal table
        else:
            alias_type_.append('')
            related_table_l.append('')
            alias_l.append(alias.this.this)
            unalias_name.append(alias.parent.name)
            
            schema_l.append(find_schema_by_table_name(sql, alias.parent.name))
        
        #print(f"Column => {alias_l} | DB => {related_table_l}" )
        
    return alias_l, related_table_l, alias_type_, unalias_name, schema_l

def find_column_originated(sql):

    column_l = []
    table_l = []
    originated_l = []
    
    group_column_l = []
    group_table_l = []
    group_originated_l = []
    #print(sql)
    
    #need to use parse instead of parse_one
    for expression in sqlglot.parse(sql):
        #print(expression.args)
        
        keysList = list(expression.args.keys())
        #print(keysList)
    
        for key in keysList:
            if expression.args[key] != None:
                
                #find all expression/join key
                if type(expression.args[key]) is list:
                    for objecto in expression.args[key]:
                        for column in objecto.find_all(exp.Column):
                            #print(f"Column => {column.name} | type => {key}" )
                            column_l.append(column.name)
                            table_l.append(column.table)
                            originated_l.append(key)
                else:
                    #list all group by items
                    if key == 'group':
                        for groupby_clause in expression.args[key].expressions:
                            if (type(groupby_clause)==sqlglot.expressions.Column):
                                group_column_l.append(groupby_clause.this)
                                group_table_l.append(groupby_clause.table)
                            else:
                                group_column_l.append(groupby_clause)
                                group_table_l.append(None)
                            group_originated_l.append('group by')
    
    #combine into one single dataframe
    dict = {'output_column': column_l, 'table': table_l, 'action': originated_l} 
    df_relationship = pd.DataFrame(dict)
    dict = {'output_column': group_column_l, 'table': group_table_l, 'action': group_originated_l}
    df_group = pd.DataFrame(dict)
    df_relationship = df_relationship.append(df_group)
    
    return df_relationship

#just in case if there is union
def deep_find_column_originated(sql, table_name):
    
    sqlglot_ = parse_one(sql)
    if (type(sqlglot_)==sqlglot.expressions.Union):
        df_union_1 = find_column_originated(str(sqlglot_.args['this']))
        df_union_2 = find_column_originated(str(sqlglot_.args['expression']))
        
        df_union_1['union'] = ' left union'
        df_union_2['union'] = ' right union'
        
        df_ = pd.concat([df_union_1, df_union_2], ignore_index=True)
        
    else:
        df_ = find_column_originated(sql)
        df_['union'] = None
        
    df_['destined_table'] = table_name
    return df_ 


In [82]:
#sql restructuring helper functions
def sql_restructuring(sql):
    
    token = ''
    
    statement = sqlparse.parse(sql)[0]
    print(statement.get_type())
    #print(statement)
    
    if (statement.get_type() == "INSERT" or statement.get_type() == "CREATE"):

        token = 'insert'

        for expression in sqlglot.parse(sql):
            table_name = str(expression.args['this'].this)
            columns = []
            for column in expression.args['this'].expressions:
                columns.append(str(column))
            expression_ = str(expression.expression)

        #recompose the list of column into a single string: REAL Columns for the datamart
        all_column = columns

        sql = expression_
        #print(sql)

    elif(statement.get_type() == "UPDATE"):

        token = 'update'

        for expression in sqlglot.parse(sql):
            #print(expression.args)
            table_name = str(expression.args['this'])
            field_value_name = expression.expressions

            fieldname = []
            value = []

            for pair in field_value_name:
                fieldname.append(str(pair.this))
                value.append(str(pair.expression))

            from_ = str(expression.args['from'])
            where = str(expression.args['where'])

            sql_ = ''

            for value_, fieldname_ in zip(value, fieldname):

                sql_ = sql_ + ' ' + value_ + ' AS ' + fieldname_ + ','

            sql_ = sql_[:-1]

            sql_ = 'SELECT ' + sql_ + ' FROM ' + table_name + ' JOIN ' + from_.replace('FROM', ' ', 1) + ' ON (' + where.replace('WHERE', ' ', 1) + ')' 
            sql = sql_

            update_column = fieldname
            all_column = update_column
    
    elif(statement.get_type() == "SELECT"):
        
        token='select'
        
        all_column = []
        
        for expression in sqlglot.parse(sql):
            table_name = str(expression.args['from'].this)
        
        for expression in sqlglot.parse_one(sql):
            all_column.append(expression.args['this'].this)
            
        sql = sql
    
    else:
        print("can't define the sql type")
        all_column = ''
        table_name = ''
        sql = 'Invalid'
        token = 'Invalid'
        
        
    return all_column, sql, token, table_name

In [83]:
#print column level leneage dataframe when given a column name from a given sql

def column_lineage(column, sql, token):

    node_list = []
    source_list = []
    expression_list = []
    column_list = []
    alias_list = []
    reference_list = []

    for node in lineage.lineage(column, sql).walk():

        node_ = str(node.name)

        #print(node_)
        source = str(node.source)
        alias = str(node.expression.alias)
        expression = str(node.expression)
        #depth = expression.depth

        #full source
        node_list.append(column)
        source_list.append(source)
        alias_list.append(alias)
        expression_list.append(expression)
        column_list.append(node_)
        reference_list.append(node.reference_node_name)

    # Convert to DataFrame
    dict = {'node': node_list, 'output_node': column_list, 'alias': alias_list, 'reference node': reference_list, 'logic': expression_list, 'full source': source_list} 

    df_i = pd.DataFrame(dict)
    df_i.iloc[0, 3] = 'main'
    
    #determine all table and column relationship
    logics = []

    if (token == 'insert'):
        logics.append(df_i.iloc[0]['logic'])
        list_1 = df_i.loc[df_i['output_node'].str.isnumeric()]['logic'].to_list()
    else:
        logics.append(df_i.iloc[0]['full source'])
        list_1 = df_i.loc[df_i['output_node'].str.isnumeric()]['full source'].to_list()

    logics = logics + list_1
    
    column_f = []
    table_f = []

    for logic in logics:

        column_, table_ = obtain_list_column_table(logic)
        column_f = column_f + column_
        table_f = table_f + table_

    dict = {'field_name': column_f, 'table_name': table_f} 
    df_temp = pd.DataFrame(dict)
    df_temp

    table_list, table_component, logic_type, original_table, table_schema = obtain_list_table_alias(sql)

    # Convert to DataFrame
    dict = {'table_name': table_list, 'table component': table_component, 'logic': logic_type, 'original table': original_table, 'schema': table_schema} 

    table_alias = pd.DataFrame(dict)
    table_mapping = pd.merge(df_temp, table_alias, how='left', on='table_name')
    table_mapping = table_mapping[['table_name', 'table component', 'logic', 'original table', 'schema']]
    table_mapping = table_mapping[~table_mapping.astype(str).duplicated()]
    
    #clean up
    if (token=='insert'):
        df_i['column_l'], df_i['table_l'] = zip(*df_i['logic'].map(obtain_list_column_table))
        df_i.loc[df_i['logic']==df_i['full source'], 'column_l'] = ''
        df_i.loc[df_i['logic']==df_i['full source'], 'table_l'] = ''
    elif (token=='update'):
        df_i['column_l'], df_i['table_l'] = zip(*df_i['full source'].map(obtain_list_column_table))
        df_i.loc[df_i['logic']==df_i['full source'], 'column_l'] = ''
        df_i.loc[df_i['logic']==df_i['full source'], 'table_l'] = ''
    else:
        print('no action.')

    #make a list of all component of a node 
    main_df = df_i.loc[df_i['reference node']!='']
    main_df = main_df.explode(['column_l', 'table_l'])
    main_df['subnode'] = main_df['table_l'] + '.' + main_df['column_l']

    #produce field dictionary
    df_i.loc[df_i['output_node'].str.isnumeric(), 'output_node'] = df_i['reference node'] + '.' + df_i['alias']
    df_field = df_i[['output_node', 'logic']].drop_duplicates(ignore_index=True)
    df_field = df_field.rename(columns={"output_node": "field", "logic": "original field",}, errors="raise")
    
    #final table cleaning
    final_df = pd.merge(main_df, df_field, how='left', left_on='subnode', right_on='field')
    final_df = final_df[['node', 'output_node', 'alias', 'reference node', 'logic', 'full source', 'table_l', 'column_l', 'original field']]

    final_df = pd.merge(final_df, table_mapping, how='left', left_on='table_l', right_on='table_name')
    
    final_df['original field'] = final_df['original field'].fillna('noSchema')
    final_df.loc[:, 'temp_schema'] = final_df['original field'].map(lambda x: x.split('.')[0])
    final_df.loc[final_df['table component'].isna(), 'schema'] = final_df['temp_schema']
    final_df.loc[final_df['table component'].isna(), 'original table'] = final_df['table_name']

    final_df = final_df[['node', 'reference node', 'alias', 'logic_x', 
              'full source', 'schema', 'table_l', 'table component', 'original table', 'column_l']].fillna('')
    
    return final_df

In [84]:
def get_sp_column_lineage_relationship(sql_list, file_name):

    overall_df_relation = pd.DataFrame()
    overall_lineage = pd.DataFrame()

    n = 0

    for sql in sql_list:
        
        try: 
            #beautify sql by removing all comments
            sql = sqlparse.format(sql, strip_comments=True).strip()

            #prepocessing
            #change update and insert to respective select statement to fit in the format of sqlglot
            #prepare list of output table columns
            all_column, sql, token, table_name = sql_restructuring(sql)

            #print(f'after restructuring: {sql}')
            
            if (sql!=''):
                #print(all_column)
                #find the position of the column in the sql, whether it is in select/update/insert or join clause or group by
                df_relation = deep_find_column_originated(sql, table_name)

                for subquery in parse_one(sql).find_all(exp.Subquery):

                    df_subquery = deep_find_column_originated(str(subquery.args['this']), subquery.alias)
                    df_relation = pd.concat([df_relation, df_subquery], ignore_index=True)

                groupby_df = df_relation.loc[df_relation['action']=='group by']
                column_df = df_relation.loc[df_relation['action']=='expressions']
                join_df = df_relation.loc[df_relation['action']=='joins']

                lineage_df = pd.DataFrame()

                for column in all_column:
                    print(column)
                    #print(f'before obtaining cll: {sql}')
                    temp_df = column_lineage(column, sql, token)
                    temp_df['row_count'] = temp_df.shape[0]
                    lineage_df = pd.concat([lineage_df, temp_df], ignore_index=True)

                df_relation['sql_num'] = n
                lineage_df['sql_num'] = n
                overall_df_relation = pd.concat([overall_df_relation, df_relation], ignore_index=True)
                overall_lineage = pd.concat([overall_lineage, lineage_df], ignore_index=True)

                n = n+1
        
        except Exception as ex:
            raise Exception('Message error: ' + str(ex))

    # create a excel writer object
    with pd.ExcelWriter(file_name + ".xlsx") as writer:

        # use to_excel function and specify the sheet_name and index 
        # to store the dataframe in specified sheet
        overall_df_relation.to_excel(writer, sheet_name="Relationships", index=False)
        overall_lineage.to_excel(writer, sheet_name="Lineage", index=False)

In [89]:
import os
for file in os.listdir("stored procedure"):
    if file.endswith(".txt"):
        print(os.path.join("stored procedure", file))

        with open(os.path.join("stored procedure", file), 'r', encoding="utf8") as f:
            #remove illegal words
            text = f.read().replace('~', '!=')
            text = text.replace('#', '')
            lines = text.split(';')

            sql_2_b_processed = []
            fail_2_processed = []
            stored_procedure_called = []

            for sql in lines:

                if ('insert' in sql.lower() or 'update' in sql.lower()): #grab all building block to build the database

                    sql_2_b_processed.append(sql)

                elif ('call' in sql.lower()): #to grap all procedure called within procedure

                    stored_procedure_called.append(sql)

                else: 

                    fail_2_processed.append(sql)

            print(f"filename is {file.replace('.txt', '')}.")
            print(f"there are: {len(sql_2_b_processed)} lines.")
            
            get_sp_column_lineage_relationship(sql_2_b_processed, file.replace('.txt', ''))

stored procedure\sp_otp_po_cut_level.txt
filename is sp_otp_po_cut_level.
there are: 53 lines.


Format argument unsupported for TO_CHAR/TO_VARCHAR function


INSERT


  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)


COUNTRY_OF_ORIGIN
PO_CUT
STYLE
COLOR
STYLE_DESCRIPTION
GOODS_DESCRIPTION
PO_ISSUE_DATE
ORIGINAL_CRD_AT_ORIGIN
REVISED_CRD_AT_ORIGIN
ACTUAL_CRD_AT_ORIGIN
LOCAL_CURRENCY
SEASON
SOURCING_OFFICE
SOURCE_SYSTEM
DC_CODE
PO_TYPE
PURCHASING_GROUP
SBU
SUB_SBU
PRODUCT_LINE
PURCHASING_COMPANY
VENDOR_FFC
VENDOR_GROUP_NAME
FACTORY_FFC
DELAY_REASON
SHIPMENT_TERMS
PO_LOCATION
VENDOR_NAME
FACTORY_NAME
REPORT_ORDER_QTY_LUM
ORDER_AMOUNT_LOCAL_CURRENCY
SHIPPED_QTY_LUM
SHIPPED_AMT_LOCAL_CURRENCY
BOOKED_QTY_LUM
BOOKED_AMT_LOCAL_CURRENCY
EFFECTIVE_QTY
isSample
isCustomOrder
MANAGING_OFFICE
CURRENT_DATE_
src_sys
MISC3
MISC26
MISC33
PURCHASING_COMPANY_CODE
exch_rate_date
od_misc_flag
tbr_defect_flag
po_creation_date
shipment_id
hts_code
factory_designation
ERP_Factory_Code
ERP_Vendor_Code
PO_Season
Revised_In_DC_Date
Ship_Mode
PO_Acknowledgement_Date
PO_Complete_Status
Greenlight_date
First_Shipment_ID
First_Actual_crd
master_po
freight_paid_by
exit_cnty_port
ship_to_name
Brand_Requested_in_DC_date
Original_Re

  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)


sbu
product_line
misc3
UPDATE
BRAND
MAJOR_PRODUCT_CATEGORY_NAME


  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)


BUSINESS_UNIT
BRAND2
MARKET
PRODUCT_SUPPLY_GROUP
UPDATE
brand
UPDATE
exchange_rate
UPDATE
exchange_rate
UPDATE
certified_pocut_flag
certprintdate
UPDATE


  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)


table_names
ORDER_AMOUNT_USD
SHIPPED_AMT_USD
UNIT_PRICE
EFFECTIVE_CRD
fob_price_usd
UPDATE
Unit_Price_Average_in_USD
UPDATE
effective_crd_year
effective_crd_month
UPDATE
days_late
UPDATE
ORDER_TYPE_CALCULATED
BOOKED_AMT_USD
EFFECTIVE_AMT_USD


  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)


UPDATE
crc_code
crc_description
UPDATE
CRC_Owner
UPDATE
crc_code
crc_description
UPDATE
CRC_Owner
UPDATE


  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)


balance_qty
UPDATE
dc_name
destination_country
UPDATE
Shipment_Status
fr_release


  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)


Shipment_id_closing_date
Actual_Ship_Date
First_Actual_Ship_Date
UPDATE


  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)


Shipment_Closed_By
UPDATE
hts_product_category
hts_product_type
UPDATE
hts_product_category
UPDATE


  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)


costing_season
UPDATE
orders_to_be_produced_qty
orders_to_be_produced_amt_usd
actual_produced_qty
actual_produced_amt_usd
UPDATE
fob_duties_rate
UPDATE
fob_duties_rate
UPDATE


  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)


"style"
season
po_season
original_requested_production_end_date
latest_confirmed_production_end_date
payment_terms
UPDATE
dc_name
UPDATE
ship_to_name
UPDATE
season_cleaned
UPDATE
buy_month


  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)


UPDATE
VF_Fiscal_Year_Original_crd
VF_Fiscal_Year_Month_Original_crd
UPDATE
actual_ship_date
UPDATE
first_actual_ship_date
CREATE
UPDATE
standard_duty_rate


  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)


UPDATE
standard_duty_rate
UPDATE
standard_duty_rate
UPDATE
standard_duty_rate
UPDATE
preferential_duty_rate
UPDATE
cbm_pc_brand_cat


  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)


UPDATE
freight_cost
UNKNOWN
can't define the sql type
UPDATE
applied_duty_rate
calculated_duty_amt
fully_landed_cost


  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)


UPDATE
vendor_name
vendor_group_name
mdg_vendor_created_on
mdg_vendor_purchasing_block


  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)


UPDATE
factory_name
mdg_factory_created_on
mdg_factory_purchasing_block


  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)


INSERT
snapshot_created_date
snapshot_name


  df_relationship = df_relationship.append(df_group)


country_of_origin
po_cut
style
color
style_description
goods_description
po_issue_date
original_crd_at_origin
revised_crd_at_origin
actual_crd_at_origin
local_currency
unit_price
season
sourcing_office
source_system
dc_code
po_type
purchasing_group
sbu
sub_sbu
product_line
purchasing_company
vendor_ffc
vendor_group_name
factory_ffc
delay_reason
shipment_terms
po_location
destination_country
vendor_name
factory_name
report_order_qty_lum
order_amount_local_currency
shipped_qty_lum
shipped_amt_local_currency
booked_qty_lum
booked_amt_local_currency
effective_crd
effective_qty
order_type_calculated
qty_per_pack
issample
iscustomorder
managing_office
brand
major_product_category_name
business_unit
brand2
market
order_amount_usd
shipped_amt_usd
booked_amt_usd
effective_amt_usd
exchange_rate
product_supply_group
CURRENT_DATE
src_sys
table_names
effective_crd_year
effective_crd_month
misc3
misc26
misc33
purchasing_company_code
days_late
exch_rate_date
od_misc_flag
tbr_defect_flag
po_creation_d

  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)
  df_relationship = df_relationship.append(df_group)


snapshot_name
country_of_origin
po_cut
"style"
color
style_description
goods_description
po_issue_date
original_crd_at_origin
revised_crd_at_origin
actual_crd_at_origin
local_currency
unit_price
season
sourcing_office
source_system
dc_code
po_type
purchasing_group
sbu
sub_sbu
product_line
purchasing_company
vendor_ffc
vendor_group_name
factory_ffc
delay_reason
shipment_terms
po_location
destination_country
vendor_name
factory_name
report_order_qty_lum
order_amount_local_currency
shipped_qty_lum
shipped_amt_local_currency
booked_qty_lum
booked_amt_local_currency
effective_crd
effective_qty
order_type_calculated
qty_per_pack
issample
iscustomorder
managing_office
brand
major_product_category_name
business_unit
brand2
market
order_amount_usd
shipped_amt_usd
booked_amt_usd
effective_amt_usd
exchange_rate
product_supply_group
CURRENT_DATE
src_sys
table_names
effective_crd_year
effective_crd_month
misc3
misc26
misc33
purchasing_company_code
days_late
exch_rate_date
od_misc_flag
tbr_defect_fl