In [1]:
from sqlanalyzer import column_parser
import pandas as pd
import sqlparse
import re

In [2]:
def delevel_query(query_list):
    """
    Split a formatted query into its top-level lines and the lines nested one level below.

    Args:
        query_list (list): Lines of a formatted query, one clause per line, with nested
            sub-queries indented.
    Returns:
        tuple: ``(line_level, sub_query)``. ``line_level`` is a list of
            ``(line, "level_1")`` pairs for the top-level lines; ``sub_query`` is the list of
            nested lines with their first three characters removed.
    """
    pos_delete = pos_where = len(query_list)
    pos_from = None

    for i, line in enumerate(query_list):
        if line.startswith(('ORDER', 'GROUP')):
            pos_delete = i
        if line.startswith('WHERE'):
            pos_where = i
        if line.startswith('FROM'):
            # Index of the line *before* the (last) top-level FROM.
            pos_from = i - 1

    if pos_from is None:
        # Empty input or no top-level FROM: nothing can be split off. Previously this
        # path raised UnboundLocalError because pos_from was never assigned.
        return [(line, "level_1") for line in query_list], []

    # Everything up to and including the FROM line is top level.
    line_level = [(line, "level_1") for line in query_list[:pos_from + 2]]

    sub_query = []
    # The slice restarts at the FROM line itself, so that line is always the first entry
    # collected in sub_query; it is discarded again by the [1:] in the return below.
    for line in query_list[pos_from + 1:pos_where]:
        if line.startswith((' ', 'FROM')):
            # Drop the first three characters (the indentation / opening parenthesis).
            sub_query.append(line[3:])
        else:
            line_level.append((line, "level_1"))

    # Lines from WHERE up to (not including) the last ORDER/GROUP line are top level.
    line_level.extend((line, "level_1") for line in query_list[pos_where:pos_delete])

    return line_level, sub_query[1:]

In [3]:
def parse_sub_query(sub_query_list):
    """
    Re-format the lines of an extracted sub-query and return them as a list of lines.

    Args:
        sub_query_list (list): Lines of a sub-query.
    Returns:
        list: Lines of the re-formatted sub-query, with leading newlines, spaces and
            opening parentheses removed from the start of the text.
    """
    # The two lstrip calls are sequential on purpose: newlines first, then spaces.
    sub_query = "\n".join(sub_query_list).lstrip('\n').lstrip(' ')

    formatter = column_parser.Parser(sub_query)
    formatted = formatter.format_query(sub_query)

    # str.lstrip takes a set of characters, not a regex, so the parenthesis needs no
    # escaping. The previous '\(' was an invalid escape sequence and also stripped
    # leading backslashes as a side effect.
    query = formatted.lstrip('\n').lstrip(' ').lstrip('(')

    return query.split('\n')

In [4]:
def has_child(sub_query_list):
    """
    Look one level down for a nested sub-query.

    Args:
        sub_query_list (list): Lines of the query to inspect.
    Returns:
        tuple: ``(nested_lines, found)`` where ``nested_lines`` is the second element
            returned by ``delevel_query`` for the re-formatted input and ``found`` is
            True when that list is not empty.
    """
    formatted_lines = parse_sub_query(sub_query_list)
    _, nested_lines = delevel_query(formatted_lines)
    found = nested_lines != []
    return nested_lines, found

In [5]:
# Fixture 1: the LEFT JOIN target is an inline sub-query, so the query has one nested level.
query_1 = """SELECT * FROM sfdc.accounts sfdc_accounts
LEFT JOIN (SELECT MAX(dt) FROM sfdc.cases) AS sfdc_cases ON sfdc_cases.dt = sfdc_accounts.dt
WHERE dt > '2020-04-03' 
"""


In [76]:
# Fixture 2: same shape as query_1 but joining the table directly, i.e. no nested sub-query.
query_2 = """SELECT * FROM sfdc.accounts sfdc_accounts
LEFT JOIN sfdc.cases AS sfdc_cases ON sfdc_cases.dt = sfdc_accounts.dt
WHERE dt > '2020-04-03' 
"""

In [None]:
# Fixture 3: an outer SELECT over a derived table that itself joins three inline
# sub-queries (aliased s, b and u), followed by outer WHERE / ORDER BY / LIMIT clauses.
# NOTE: this cell has no execution count and no other cell shown here references query_3.
query_3 = """SELECT *
FROM (
  SELECT u.name, b.customer_tier_c, b.name, m.account, b.x18_digit_account_id_c, s.id, m.platform, m.mobile_os, m.num_requests, Row_number() OVER(
    PARTITION BY s.id
  ) row_
  FROM wbr.map_requests_by_account m
  INNER JOIN (
    SELECT DISTINCT id
    FROM mapbox_customer_data.styles 
    WHERE cast(dt as DATE) >= CURRENT_DATE - INTERVAL '14' DAY
      AND sources LIKE '%mapbox-streets-v7%'
  ) s
  ON m.service_metadata_version = s.id
  LEFT JOIN (
    SELECT customer_tier_c, csm_c, name, mapbox_username_c, x18_digit_account_id_c
    FROM sfdc.accounts
    WHERE cast(dt as DATE) = CURRENT_DATE - INTERVAL '1' DAY
  ) b
  ON m.account = b.mapbox_username_c
  LEFT JOIN (
    SELECT name, id
    FROM sfdc.users
    WHERE cast(dt as DATE) = CURRENT_DATE - INTERVAL '1' DAY
  ) u
  ON b.csm_c = u.id
  WHERE cast(m.dt as DATE) >= CURRENT_DATE - INTERVAL '14' DAY
    AND m.service_metadata = 'custom'
    AND m.service = 'styles'
    AND b.customer_tier_c IN (
      'Tier 0',
      'Tier 1',
      'Tier 2',
      'Tier 3',
      'Tier 4'
    )
)
WHERE row_ = 1
AND m.service_metadata = 'custom'
ORDER BY 1, 4
LIMIT 5000
"""

In [144]:
# Format query_1 with the sqlanalyzer parser and split the result into a list of lines.
# NOTE: this rebinds the notebook-global `formatter` and `formatted`, which
# get_all_scanned_cols / match_queried_fields and the final call cell read.
formatter = column_parser.Parser(query_1)
formatted = formatter.format_query(query_1)
query_list = formatted.split('\n')
query_list

['SELECT *',
 'FROM sfdc.accounts sfdc_accounts',
 'LEFT JOIN',
 '  (SELECT MAX(dt)',
 '   FROM sfdc.cases) AS sfdc_cases ON sfdc_cases.dt = sfdc_accounts.dt',
 "WHERE dt > '2020-04-03'"]

In [8]:
# Format query_1, then extract its first nested sub-query.
# has_subquery is True when has_child returned a non-empty list of nested lines.
formatter = column_parser.Parser(query_1)
formatted = formatter.format_query(query_1)
query_list = formatted.split('\n')
sub_query_list, has_subquery = has_child(query_list)

In [9]:
sub_query_list

['SELECT MAX(dt)',
 'FROM sfdc.cases) AS sfdc_cases ON sfdc_cases.dt = sfdc_accounts.dt']

In [10]:
has_subquery

True

In [11]:
# Go one level deeper: feed the extracted sub-query back through has_child.
# NOTE: this overwrites has_subquery from the previous level.
sub_sub_query_list, has_subquery = has_child(sub_query_list)

In [12]:
sub_sub_query_list

[]

In [13]:
has_subquery

False

In [14]:
# Mock column metadata for sfdc.accounts: one row per column name; the scalar
# 'db_table' value is repeated on every row.
db_fields_1 = pd.DataFrame({'db_table': 'sfdc.accounts', 
            'all_columns': ['platform', 'mobile_os', 'service', 'service_metadata', 'service_metadata_version', 'account', 'num_requests', 'dt']})
db_fields_1


Unnamed: 0,db_table,all_columns
0,sfdc.accounts,platform
1,sfdc.accounts,mobile_os
2,sfdc.accounts,service
3,sfdc.accounts,service_metadata
4,sfdc.accounts,service_metadata_version
5,sfdc.accounts,account
6,sfdc.accounts,num_requests
7,sfdc.accounts,dt


In [15]:
# Mock column metadata for sfdc.cases, in the same two-column layout as db_fields_1.
db_fields_2 = pd.DataFrame({'db_table': 'sfdc.cases', 
            'all_columns': ['account', 'num_requests', 'dt', 'owner', 'id']})
db_fields_2


Unnamed: 0,db_table,all_columns
0,sfdc.cases,account
1,sfdc.cases,num_requests
2,sfdc.cases,dt
3,sfdc.cases,owner
4,sfdc.cases,id


In [16]:
# Stack the per-table metadata into one frame with a fresh 0..n-1 index.
# pd.concat replaces DataFrame.append, which was deprecated in pandas 1.4 and removed in 2.0.
db_fields = pd.concat([db_fields_1, db_fields_2], ignore_index=True)

In [77]:
# Run the library's own match_queried_fields on query_2 against the mock metadata.
# NOTE: rebinds the notebook-global `formatter` to a parser built from query_2.
formatter = column_parser.Parser(query_2)
columns_queried = formatter.match_queried_fields(query_2, db_fields)

In [78]:
columns_queried

[{'database_name': 'sfdc', 'table_name': 'cases', 'column_name': 'owner'},
 {'database_name': 'sfdc', 'table_name': 'cases', 'column_name': 'id'},
 {'database_name': 'sfdc', 'table_name': 'cases', 'column_name': 'account'},
 {'database_name': 'sfdc', 'table_name': 'accounts', 'column_name': 'service'},
 {'database_name': 'sfdc', 'table_name': 'accounts', 'column_name': 'account'},
 {'database_name': 'sfdc',
  'table_name': 'accounts',
  'column_name': 'service_metadata_version'},
 {'database_name': 'sfdc', 'table_name': 'accounts', 'column_name': 'dt'},
 {'database_name': 'sfdc',
  'table_name': 'accounts',
  'column_name': 'service_metadata'},
 {'database_name': 'sfdc',
  'table_name': 'accounts',
  'column_name': 'platform'},
 {'database_name': 'sfdc',
  'table_name': 'accounts',
  'column_name': 'mobile_os'},
 {'database_name': 'sfdc',
  'table_name': 'accounts',
  'column_name': 'num_requests'},
 {'database_name': 'sfdc',
  'table_name': 'cases',
  'column_name': 'num_requests'},
 

In [79]:
# Format query_2 and split it into a {name: query} mapping; for this query the
# recorded output shows a single 'main' entry.
formatted_query = formatter.format_query(query_2)
cte_queries = formatter.parse_cte(formatted_query)

In [80]:
cte_queries

{'main': "SELECT *\nFROM sfdc.accounts sfdc_accounts\nLEFT JOIN sfdc.cases AS sfdc_cases ON sfdc_cases.dt = sfdc_accounts.dt\nWHERE dt > '2020-04-03'"}

In [81]:
# Call the library's private _get_all_scanned_cols directly on the CTE mapping.
all_columns_scanned = formatter._get_all_scanned_cols(cte_queries, db_fields)

In [82]:
all_columns_scanned

['sfdc.cases.owner',
 'sfdc.cases.id',
 'sfdc.cases.account',
 'sfdc.accounts.service',
 'sfdc.accounts.account',
 'sfdc.accounts.service_metadata_version',
 'sfdc.accounts.dt',
 'sfdc.accounts.service_metadata',
 'sfdc.accounts.platform',
 'sfdc.accounts.mobile_os',
 'sfdc.accounts.num_requests',
 'sfdc.cases.num_requests',
 'sfdc.cases.dt']

In [30]:
def _get_all_scanned_cols(self, cte_queries, meta_cols):
    """
    Get all scanned original columns.
    Args:
        param1 (dict): A dictionary of CTE's name:query pair.
        param2 (dict): A dictionary of metadata columns from Glue.
    Returns:
        list: A list of all scanned columns with db and table names.
    """
    all_columns_scanned = []

    for _,cte_query in cte_queries.items():

        table_alias_mapping = self.get_table_names(cte_query.split('\n'))
        variables = self._get_all_variables(cte_query)
        queried_columns = self._get_queried_columns(table_alias_mapping, meta_cols)
        if variables == []:
            original_columns_list = []
            for table in queried_columns:
                for k,v in table.items():
                    for t in v:
                        original_columns_list.append("{}.{}".format(k,t))
        else:
            original_columns_list = self._map_db_columns(variables, queried_columns, table_alias_mapping)
        all_columns_scanned.extend(list(set(original_columns_list)))
    return all_columns_scanned


In [106]:
# Step-by-step copy of the _get_all_scanned_cols body with prints added, so each
# intermediate value (alias mapping, variables, queried columns) can be inspected.
all_columns_scanned = []

for _,cte_query in cte_queries.items():

    table_alias_mapping = formatter.get_table_names(cte_query.split('\n'))
    print(table_alias_mapping)
    # The notebook's get_all_variables is used here in place of
    # formatter._get_all_variables(cte_query).
    variables = get_all_variables(cte_query)
    print("variables:", variables)
    queried_columns = formatter._get_queried_columns(table_alias_mapping, db_fields)
    print(queried_columns)
    
    if variables == []:
        # No variables found: list every column of every queried table.
        original_columns_list = []
        for table in queried_columns:
            for k,v in table.items():
                for t in v:
                    original_columns_list.append("{}.{}".format(k,t))
    else:
        original_columns_list = formatter._map_db_columns(variables, queried_columns, table_alias_mapping)
        
    # De-duplicate per query; set iteration order is not guaranteed.
    all_columns_scanned.extend(list(set(original_columns_list)))
    

{'sfdc_accounts': 'sfdc.accounts', 'sfdc_accounts.dt': 'sfdc.cases'}
variables: ['sfdc.accounts sfdc_accounts', '', 'sfdc.cases', 'sfdc_cases', 'sfdc_cases.dt', 'sfdc_accounts.dt', 'dt', '']
[{'sfdc.accounts': {'dt', 'num_requests', 'service', 'platform', 'service_metadata', 'account', 'service_metadata_version', 'mobile_os'}}, {'sfdc.cases': {'owner', 'id', 'dt', 'num_requests', 'account'}}]


In [110]:
all_columns_scanned

['sfdc.accounts.dt', 'sfdc.cases.dt']

In [34]:
formatter._get_all_variables(cte_queries['main'])

[]

In [36]:
cte_queries['main'].split('\n')

['SELECT *',
 'FROM sfdc.accounts sfdc_accounts',
 'LEFT JOIN dt',
 'FROM sfdc.cases AS sfdc_cases ON sfdc_cases.dt = sfdc_accounts.dt',
 "WHERE dt > '2020-04-03'"]

In [62]:
# Probe for the SELECT * check: r"\s[*]?" matches each whitespace character plus an
# optional '*' right after it, so a True entry means a '*' directly follows whitespace.
list(map(lambda x: '*' in x, re.findall(r"\s[*]?", 'SELECT *')))

[True]

In [103]:
def get_all_variables(query):
    """
    Collect the lowercase tokens of a query, line by line.

    A line in which a '*' directly follows a whitespace character (e.g. ``SELECT *``)
    contributes nothing. For every other line, each run of lowercase letters,
    underscores, dots and whitespace is taken, stripped of surrounding spaces and kept,
    including runs that become empty strings after stripping.

    Args:
        query (string): Query text; lines are separated by newlines.
    Returns:
        list: The collected tokens in order of appearance.
    """
    collected = []

    for line in query.split('\n'):
        star_after_space = any('*' in hit for hit in re.findall(r"\s[*]?", line))
        if star_after_space:
            continue
        collected.extend(token.strip(' ') for token in re.findall(r"[a-z_\s.]+", line))

    return collected


In [101]:
# Inline copy of the get_all_variables loop, run on the 'main' query so the
# intermediate list can be inspected.
all_variables = []

for e in cte_queries['main'].split('\n'):
    # Skip lines where a '*' directly follows whitespace (e.g. "SELECT *").
    if sum(list(map(lambda x: '*' in x, re.findall(r"\s[*]?", e)))):
        variable = []
    else:
        # Keep runs of lowercase letters, underscores, dots and whitespace.
        variable = [x.strip(' ') for x in re.findall(r"[a-z_\s.]+", e)]

    all_variables.extend(variable)

In [111]:
all_columns_scanned

['sfdc.accounts.dt', 'sfdc.cases.dt']

In [125]:
# Turn each "db.table.column" name into a payload dict.
column_payload = []
for column in all_columns_scanned:
    col_split = column.split('.')
    print(col_split)
    # Names with fewer than three dot-separated parts cannot be mapped to
    # database/table/column and are skipped. This replaces a bare `except: pass`
    # that would have hidden any other error as well.
    if len(col_split) < 3:
        continue
    db, table, col = col_split[:3]
    column_payload.append({
        "database_name": db,
        "table_name": table,
        "column_name": col,
    })

['sfdc', 'accounts', 'dt']
['sfdc', 'cases', 'dt']


In [126]:
column_payload

[{'database_name': 'sfdc', 'table_name': 'accounts', 'column_name': 'dt'},
 {'database_name': 'sfdc', 'table_name': 'cases', 'column_name': 'dt'}]

In [138]:
def get_all_scanned_cols(cte_queries, meta_cols):
    """
    Notebook copy of the scanned-column lookup, using get_all_variables for the variables.

    Relies on the notebook-global ``formatter`` for table names, queried columns and
    column mapping.

    Args:
        cte_queries (dict): Mapping of CTE name to query text; only the values are used.
        meta_cols: Column metadata, passed unchanged to ``formatter._get_queried_columns``.
    Returns:
        list: Scanned column names, de-duplicated per query (order within a query
            follows set iteration and is therefore not guaranteed).
    """
    collected = []

    for query_text in cte_queries.values():
        aliases = formatter.get_table_names(query_text.split('\n'))
        variables = get_all_variables(query_text)
        tables_with_columns = formatter._get_queried_columns(aliases, meta_cols)

        if variables != []:
            columns_for_query = formatter._map_db_columns(variables, tables_with_columns, aliases)
        else:
            # Nothing to map: list every column of every queried table.
            columns_for_query = []
            for table in tables_with_columns:
                for table_name, column_names in table.items():
                    columns_for_query.extend(
                        "{}.{}".format(table_name, name) for name in column_names
                    )

        collected.extend(set(columns_for_query))

    return collected

In [139]:
def match_queried_fields(query, db_fields):
    """
    Match the columns used by a query against the registered column metadata.

    Relies on the notebook-global ``formatter`` to format the query and split it into
    CTEs, and on ``get_all_scanned_cols`` to resolve the scanned columns.

    Args:
        query (string): the raw query.
        db_fields (dataframe): column metadata, passed unchanged to
            ``get_all_scanned_cols`` (a pandas DataFrame in this notebook).

    Return:
        column_payload (list): one dict per scanned column with the keys
            ``database_name``, ``table_name`` and ``column_name``.
    """
    formatted_query = formatter.format_query(query)
    cte_queries = formatter.parse_cte(formatted_query)

    all_columns_scanned = get_all_scanned_cols(cte_queries, db_fields)
    # Kept so the scanned columns stay visible in the notebook output.
    print(all_columns_scanned)

    column_payload = []
    for column in all_columns_scanned:
        # Best-effort, as before: entries that are not "db.table.column" strings are
        # skipped. The explicit checks replace a bare `except: pass`, which also hid
        # unrelated errors (including KeyboardInterrupt).
        parts = column.split('.') if isinstance(column, str) else []
        if len(parts) < 3:
            continue
        db, table, col = parts[:3]
        column_payload.append({
            "database_name": db,
            "table_name": table,
            "column_name": col,
        })

    return column_payload

In [146]:
# Run the notebook's match_queried_fields on `formatted`, the already-formatted
# query_1 text, using whatever parser the global `formatter` is currently bound to.
match_queried_fields(formatted, db_fields)

['sfdc.accounts.dt', 'sfdc.cases.dt']


[{'database_name': 'sfdc', 'table_name': 'accounts', 'column_name': 'dt'},
 {'database_name': 'sfdc', 'table_name': 'cases', 'column_name': 'dt'}]

In [145]:
print(formatted)

SELECT *
FROM sfdc.accounts sfdc_accounts
LEFT JOIN
  (SELECT MAX(dt)
   FROM sfdc.cases) AS sfdc_cases ON sfdc_cases.dt = sfdc_accounts.dt
WHERE dt > '2020-04-03'
