In [1]:
import sqlparse
from sqlparse.sql import Token, TokenList
import re
import json
import random, string
import copy

In [2]:
base_keywords = ['select','from','where','group','having']

In [3]:
def clean_query(tokens):
    tokens_exp = []
    tokens_cleaned = []
    for t in tokens:
        if isinstance(t,sqlparse.sql.Where):
            for tt in t.tokens:
                tokens_exp.append(tt)
        else:
            tokens_exp.append(t)
            
#     print(tokens_exp)
    
    for t in tokens_exp:
        append_flag = True
        for pattern in ['Token.Text.Whitespace*','Token.Punctuation*']:
            if re.match(pattern,str(t.ttype)):
                append_flag = False
        for t_type in [sqlparse.sql.Comment]:
            if type(t)==t_type:
                append_flag = False
                
        if append_flag:
            tokens_cleaned.append(t)
                
#     tokens_cleaned = [t for t in tokens_exp if t not in tokens_flagged]
#     tokens_cleaned = [t for t in tokens_cleaned if type(t) not in [sqlparse.sql.Comment]]
    return tokens_cleaned

In [4]:
def get_select_from_with(subq_token):
    for t in subq_token.tokens:
        if isinstance(t,sqlparse.sql.Parenthesis):
            return clean_query(t.tokens)

In [5]:
def check_token_is_base_keyword(t,kw):
    if kw == 'select':
        return t.match(sqlparse.tokens.Keyword.DML,['select'])
    else:
        return t.match(sqlparse.tokens.Keyword,[kw])

In [6]:

def parse_simple_select(q_tokens):
    assert check_token_is_base_keyword(q_tokens[0],'select')
    
    # initial high level parsing
    q_parse_dict_1 = {kw:[] for kw in base_keywords}
    base_keyword_idx = 0
    for t in q_tokens:
        is_kw_flag = False
        for kw in base_keywords:
            if check_token_is_base_keyword(t,kw):
                key = kw
                is_kw_flag = True
#         if check_token_is_base_keyword(t,base_keywords[base_keyword_idx]):
#             key = base_keywords[base_keyword_idx]
#             q_parse_dict_1[key] = []
#             if base_keyword_idx<len(base_keywords)-1:
#                 base_keyword_idx+=1
#             continue
        if not is_kw_flag:
            q_parse_dict_1[key].append(t)
    
    return q_parse_dict_1

In [7]:
def clean_col_def(c,n):
    return re.sub('\ '+n+'$','',c.split('--')[0].strip()).strip()
    
def parse_select_statement(s_tokens):
    op_dict = {}
    for t in s_tokens:
        if isinstance(t,sqlparse.sql.Identifier):
            if isinstance(clean_query(t.tokens)[0],sqlparse.sql.Case):
                op_dict[t.get_name()] = clean_col_def(t.value,t.get_name())
#                 print(t.value,t.get_name(),clean_col_def(t.value,t.get_name()))
            else:
#                 op_dict[t.get_name()] = t.get_real_name()
                op_dict[t.get_name()] = clean_col_def(t.value,t.get_name())
        elif isinstance(t,sqlparse.sql.IdentifierList):
            for tt in clean_query(t.tokens):
#                 op_dict[tt.get_name()] = tt.get_real_name()
                op_dict[tt.get_name()] = clean_col_def(tt.value,tt.get_name())
    return op_dict


def parse_from_statement(f_tokens):
    op_dict = {'input_tables':{}}
        
    # only identifier or identifier list present in f_tokens
    only_identifiers_flag=True
    for t in f_tokens:
        if not (isinstance(t,sqlparse.sql.Identifier) or isinstance(t,sqlparse.sql.IdentifierList)):
            only_identifiers_flag=False
    
    if only_identifiers_flag:
        for t in f_tokens:
            if isinstance(t,sqlparse.sql.Identifier):
                op_dict['input_tables'][t.get_name()] = t.get_real_name()
            elif isinstance(t,sqlparse.sql.IdentifierList):
                for tt in clean_query(t.tokens):
                    op_dict['input_tables'][tt.get_name()] = tt.get_real_name()
        return op_dict
    
    # join statement
    op_dict['joins'] = {}
    for i,t in enumerate(f_tokens):
        if isinstance(t,sqlparse.sql.Identifier):
            op_dict['input_tables'][t.get_name()] = t.get_real_name()
            
        if check_token_is_base_keyword(t,'left join') \
        or check_token_is_base_keyword(t,'right join') \
        or check_token_is_base_keyword(t,'inner join') \
        or check_token_is_base_keyword(t,'join') \
        or check_token_is_base_keyword(t,'full outer join'):
            key = f_tokens[i+1].get_name()
            join_dict = {}
            join_dict['type'] = t.value.replace(' join','')
            if join_dict['type']=='join':
                join_dict['type'] = 'inner'
            join_dict['on'] = f_tokens[i+3].value
            op_dict['joins'][key] = join_dict
    
    return op_dict        
            
    
     

In [8]:
test_sel_from_q = '''
        select a.a_id, --aaa
        max(a.a_ind) a_max,
        sum(b.b_ind) b_sum
        from a left join b 
        on a.a_id = b.b_id
        right join c
        on a.a_id = c.c_id
        where a.a_id != 0
        group by 1
        '''

test_sel_from_q_2 = '''
        select a.a_id, --aaa
        max(a.a_ind) a_max,
        sum(b.b_ind) b_sum
        from asd a, basd b 
        where a.a_id = b.b_id
        and a.a_id != 0
        group by 1
        '''


test_sel_from_q_3 = '''
        select a.a_id, --aaa
        max(a.a_ind) a_max,
        sum(b.b_ind) b_sum
        from a,
        (select * from basd) b 
        where a.a_id = b.b_id
        and a.a_id != 0
        group by 1
        '''

test_sel_from_q_4 = '''
        select a.a_id, --aaa
        max(a.a_ind) a_max,
        sum(b.b_ind) b_sum
        from asd a,
        (select * from basd) b 
        where a.a_id = b.b_id
        and a.a_id != 0
        group by 1
        '''

In [9]:
q_c = clean_query(sqlparse.parse(test_sel_from_q)[0].tokens)

In [10]:
q_c

[<DML 'select' at 0x2B983CF96A8>,
 <Identifier 'a.a_id' at 0x2B983CF1C78>,
 <IdentifierList 'max(a....' at 0x2B983D199A8>,
 <Keyword 'from' at 0x2B983D18468>,
 <Identifier 'a' at 0x2B983D195E8>,
 <Keyword 'left j...' at 0x2B983D185E8>,
 <Identifier 'b' at 0x2B983D19660>,
 <Keyword 'on' at 0x2B983D18768>,
 <Comparison 'a.a_id...' at 0x2B983D19750>,
 <Keyword 'right ...' at 0x2B983D18C48>,
 <Identifier 'c' at 0x2B983D196D8>,
 <Keyword 'on' at 0x2B983D18E28>,
 <Comparison 'a.a_id...' at 0x2B983D197C8>,
 <Keyword 'where' at 0x2B983D3E348>,
 <Comparison 'a.a_id...' at 0x2B983D19840>,
 <Keyword 'group' at 0x2B983D3E768>,
 <Keyword 'by' at 0x2B983D3E828>,
 <Integer '1' at 0x2B983D3E8E8>]

In [11]:
q_c_p = parse_simple_select(q_c)

In [12]:
q_c_p

{'select': [<Identifier 'a.a_id' at 0x2B983CF1C78>,
  <IdentifierList 'max(a....' at 0x2B983D199A8>],
 'from': [<Identifier 'a' at 0x2B983D195E8>,
  <Keyword 'left j...' at 0x2B983D185E8>,
  <Identifier 'b' at 0x2B983D19660>,
  <Keyword 'on' at 0x2B983D18768>,
  <Comparison 'a.a_id...' at 0x2B983D19750>,
  <Keyword 'right ...' at 0x2B983D18C48>,
  <Identifier 'c' at 0x2B983D196D8>,
  <Keyword 'on' at 0x2B983D18E28>,
  <Comparison 'a.a_id...' at 0x2B983D197C8>],
 'where': [<Comparison 'a.a_id...' at 0x2B983D19840>],
 'group': [<Keyword 'by' at 0x2B983D3E828>, <Integer '1' at 0x2B983D3E8E8>],
 'having': []}

In [13]:
parse_select_statement(q_c_p['select'])

{'a_id': 'a.a_id', 'a_max': 'max(a.a_ind)', 'b_sum': 'sum(b.b_ind)'}

In [14]:
# clean_query(clean_query(q_c_p['from'][0].tokens)[0].tokens)

In [15]:
from_w_par = q_c_p['from']
from_w_par

[<Identifier 'a' at 0x2B983D195E8>,
 <Keyword 'left j...' at 0x2B983D185E8>,
 <Identifier 'b' at 0x2B983D19660>,
 <Keyword 'on' at 0x2B983D18768>,
 <Comparison 'a.a_id...' at 0x2B983D19750>,
 <Keyword 'right ...' at 0x2B983D18C48>,
 <Identifier 'c' at 0x2B983D196D8>,
 <Keyword 'on' at 0x2B983D18E28>,
 <Comparison 'a.a_id...' at 0x2B983D197C8>]

In [16]:
check_token_is_base_keyword(from_w_par[1],'left join')

True

In [17]:
# simplq_dict = {}
# for t in from_w_par:
#     if isinstance(t,sqlparse.sql.Identifier):
#         subsel = get_select_from_with(t)
#         if subsel:
#             simplq_dict[t.get_name()] = subsel
#             replace

In [18]:
test_l = []
for t in q_c_p['from']:
    test_l.extend(list(t.flatten()))
test_l

[<Name 'a' at 0x2B983D18528>,
 <Keyword 'left j...' at 0x2B983D185E8>,
 <Name 'b' at 0x2B983D186A8>,
 <Keyword 'on' at 0x2B983D18768>,
 <Name 'a' at 0x2B983D18828>,
 <Punctuation '.' at 0x2B983D18888>,
 <Name 'a_id' at 0x2B983D188E8>,
 <Whitespace ' ' at 0x2B983D18948>,
 <Comparison '=' at 0x2B983D189A8>,
 <Whitespace ' ' at 0x2B983D18A08>,
 <Name 'b' at 0x2B983D18A68>,
 <Punctuation '.' at 0x2B983D18AC8>,
 <Name 'b_id' at 0x2B983D18B28>,
 <Keyword 'right ...' at 0x2B983D18C48>,
 <Name 'c' at 0x2B983D18D08>,
 <Keyword 'on' at 0x2B983D18E28>,
 <Name 'a' at 0x2B983D18EE8>,
 <Punctuation '.' at 0x2B983D18F48>,
 <Name 'a_id' at 0x2B983D18FA8>,
 <Whitespace ' ' at 0x2B983D3E048>,
 <Comparison '=' at 0x2B983D3E0A8>,
 <Whitespace ' ' at 0x2B983D3E108>,
 <Name 'c' at 0x2B983D3E168>,
 <Punctuation '.' at 0x2B983D3E1C8>,
 <Name 'c_id' at 0x2B983D3E228>]

In [19]:
not max([check_token_is_base_keyword(t,'select') for t in test_l])

True

In [20]:
parse_from_statement(q_c_p['from'])

{'input_tables': {'a': 'a', 'b': 'b', 'c': 'c'},
 'joins': {'b': {'type': 'left', 'on': 'a.a_id = b.b_id'},
  'c': {'type': 'right', 'on': 'a.a_id = c.c_id'}}}

In [21]:
def generate_name(l=8):
    x = ''.join(random.choices(string.ascii_letters + string.digits, k=l))
    return x

def get_all_queries(q_tokens_c):
    all_queries={}
    
    if q_tokens_c[0].match(sqlparse.tokens.Keyword.CTE,['with']):
        # get all subqueries from with

        # list,dict of all subqueries
        subq_list_raw = clean_query(q_tokens_c[1].tokens)
        all_queries.update({subq.get_name(): parse_simple_select(get_select_from_with(subq)) for subq in subq_list_raw})

        if q_tokens_c[2].match(sqlparse.tokens.Keyword.DML,['select']):
            all_queries['out'] = parse_simple_select(q_tokens_c[2:])
 
    elif q_tokens_c[0].match(sqlparse.tokens.Keyword.DML,['select']):
        # simpler query, parse it
        all_queries['out'] = parse_simple_select(q_tokens_c)
        
    return all_queries

def is_simple_select(parsed_t_s):
    from_t = []
    for t in parsed_t_s['from']:
        from_t.extend(list(t.flatten()))
    return not max([check_token_is_base_keyword(t,'select') for t in from_t])

def make_simple(parsed_t_s, init_parsed):
    if is_simple_select(parsed_t_s):
        parsed_t_s['from'] = parse_from_statement(parsed_t_s['from'])
        return parsed_t_s
    else:
        new_from = parse_from_statement(parsed_t_s['from'])
        for i in range(len(parsed_t_s['from'])):
            t = parsed_t_s['from'][i]
            if isinstance(t,sqlparse.sql.Identifier):
                # asdasd
                sel = parse_simple_select(get_select_from_with(t))
                if sel:
                    subq_name = generate_name()
                    new_from['input_tables'][t.get_name()] = subq_name
                    init_parsed[subq_name] = make_simple(sel,init_parsed)
            elif isinstance(t,sqlparse.sql.IdentifierList):
                for tt in clean_query(t.tokens):
                    sel = parse_simple_select(get_select_from_with(tt))
                    if sel:
                        subq_name = generate_name()
                        new_from['input_tables'][tt.get_name()] = subq_name
                        init_parsed[subq_name] = make_simple(sel,init_parsed)
        parsed_t_s['from'] = new_from
        return parsed_t_s
    
    

        
            
        
    
        
        
        

In [22]:
subq_test_q_1 = '''--asd
            with a as (select a_id1 as a_id, --abc
            case when aa = \'a1\' and a2=3 then 1 else 0 end aa_ind --def
            from (select q.* from `a123.aaa` q) abc) --pqr
        , b as (select b_id, 
            sum(bb) bb_sum 
            from `b123.bbb` 
            group by b_id
            having sum(bb)>2) 
        select a.a_id,a.a_ind,b.b_sum 
        from a left join b 
        on a.a_id = b.b_id'''


In [23]:
parsed_q = sqlparse.parse(subq_test_q_1)[0] # assumed that q has a single query
q_tokens = parsed_q.tokens
q_tokens_c = clean_query(q_tokens)

In [24]:
global all_queries
all_queries = get_all_queries(q_tokens_c)

In [25]:
all_queries

{'a': {'select': [<Identifier 'a_id1 ...' at 0x2B983D606D8>,
   <Identifier 'case w...' at 0x2B983D46570>],
  'from': [<Identifier '(selec...' at 0x2B983D465E8>],
  'where': [],
  'group': [],
  'having': []},
 'b': {'select': [<IdentifierList 'b_id, ...' at 0x2B983D46660>],
  'from': [<Identifier '`b123....' at 0x2B983D460C0>],
  'where': [],
  'group': [<Keyword 'by' at 0x2B983D501C8>,
   <Identifier 'b_id' at 0x2B983D46138>],
  'having': [<Comparison 'sum(bb...' at 0x2B983D46408>]},
 'out': {'select': [<IdentifierList 'a.a_id...' at 0x2B983D46750>],
  'from': [<Identifier 'a' at 0x2B983D46228>,
   <Keyword 'left j...' at 0x2B983D50E28>,
   <Identifier 'b' at 0x2B983D462A0>,
   <Keyword 'on' at 0x2B983D50FA8>,
   <Comparison 'a.a_id...' at 0x2B983D46480>],
  'where': [],
  'group': [],
  'having': []}}

In [26]:
ks = copy.deepcopy(list(all_queries.keys()))
for k in ks:
    all_queries[k] = make_simple(all_queries[k],all_queries)

In [27]:
all_queries

{'a': {'select': [<Identifier 'a_id1 ...' at 0x2B983D606D8>,
   <Identifier 'case w...' at 0x2B983D46570>],
  'from': {'input_tables': {'abc': 'kQ6PfV7Q'}},
  'where': [],
  'group': [],
  'having': []},
 'b': {'select': [<IdentifierList 'b_id, ...' at 0x2B983D46660>],
  'from': {'input_tables': {'`b123.bbb`': '`b123.bbb`'}},
  'where': [],
  'group': [<Keyword 'by' at 0x2B983D501C8>,
   <Identifier 'b_id' at 0x2B983D46138>],
  'having': [<Comparison 'sum(bb...' at 0x2B983D46408>]},
 'out': {'select': [<IdentifierList 'a.a_id...' at 0x2B983D46750>],
  'from': {'input_tables': {'a': 'a', 'b': 'b'},
   'joins': {'b': {'type': 'left', 'on': 'a.a_id = b.b_id'}}},
  'where': [],
  'group': [],
  'having': []},
 'kQ6PfV7Q': {'select': [<Identifier 'q.*' at 0x2B983D60570>],
  'from': {'input_tables': {'q': '`a123.aaa`'}},
  'where': [],
  'group': [],
  'having': []}}

In [28]:
all_queries['si2CjupY']['select'][0].is_wildcard()

KeyError: 'si2CjupY'

In [None]:
all_queries['si2CjupY']['select'][0].match(sqlparse.tokens.Wildcard,'q.*')

In [None]:
"""
Simple select query: Select query with no subquery
"""

class SelectQuery():
    def __init__(self,query_text):
        self.query_text = query_text
        self.simple_select_queries = {}
        self.keywords = ['select','from','where','group','having']
    
    def parse_select_query(self):
        self.init_parse_query()
        
        # get all initial select queries
        self.get_all_queries()
#         print(self.all_queries)
        # convert all queries to simple queries
        ks = copy.deepcopy(list(self.all_queries.keys()))
        for k in ks:
            self.all_queries[k] = self.make_simple(self.all_queries[k])
#         print(self.all_queries)
        # parse statements
        ks = copy.deepcopy(list(self.all_queries.keys()))
        for k in ks:
            # select
            self.all_queries[k]['select'] = self.parse_select_statement(self.all_queries[k]['select'])
            # where
            # group
            # having
        
    
    def init_parse_query(self):
        parsed_q = sqlparse.parse(self.query_text)[0] # assumed that q has a single query
        q_tokens = parsed_q.tokens
        q_tokens_c = self._clean_query(q_tokens)
        self.init_parse = q_tokens_c

    def get_all_queries(self):
        all_queries={}
        
        if self.init_parse[0].match(sqlparse.tokens.Keyword.CTE,['with']):
            # get all subqueries from with

            # list,dict of all subqueries
            subq_list_raw = self._clean_query(self.init_parse[1].tokens)
            all_queries.update({subq.get_name(): self.parse_simple_select(self._get_select_from_par(subq)) for subq in subq_list_raw})

            if self.init_parse[2].match(sqlparse.tokens.Keyword.DML,['select']):
                all_queries['out'] = self.parse_simple_select(self.init_parse[2:])

        elif self.init_parse[0].match(sqlparse.tokens.Keyword.DML,['select']):
            # simpler query, parse it
            all_queries['out'] = self.parse_simple_select(self.init_parse)

        self.all_queries = all_queries
    
    def parse_simple_select(self,q_tokens):
        assert self._check_token_is_base_keyword(q_tokens[0],'select')

        # initial high level parsing
        q_parse_dict_1 = {kw:[] for kw in self.keywords}
        base_keyword_idx = 0
        for t in q_tokens:
            is_kw_flag = False
            for kw in base_keywords:
                if self._check_token_is_base_keyword(t,kw):
                    key = kw
                    is_kw_flag = True
            if not is_kw_flag:
                q_parse_dict_1[key].append(t)

        return q_parse_dict_1
    
    def make_simple(self,parsed_t_s):
        if self._is_simple_select(parsed_t_s):
            parsed_t_s['from'] = self.parse_from_statement(parsed_t_s['from'])
            return parsed_t_s
        else:
            new_from = self.parse_from_statement(parsed_t_s['from'])
            for i in range(len(parsed_t_s['from'])):
                t = parsed_t_s['from'][i]
                if isinstance(t,sqlparse.sql.Identifier):
                    sel = self.parse_simple_select(self._get_select_from_par(t))
                    if sel:
                        subq_name = self._generate_name()
                        new_from['input_tables'][t.get_name()] = subq_name
                        self.all_queries[subq_name] = self.make_simple(sel)
                elif isinstance(t,sqlparse.sql.IdentifierList):
                    for tt in self._clean_query(t.tokens):
                        sel = self.parse_simple_select(self._get_select_from_par(tt))
                        if sel:
                            subq_name = self._generate_name()
                            new_from['input_tables'][tt.get_name()] = subq_name
                            self.all_queries[subq_name] = self.make_simple(sel)
            parsed_t_s['from'] = new_from
            return parsed_t_s              
    
    
    def parse_select_statement(self,s_tokens):
        op_dict = {}
        for t in s_tokens:
            if isinstance(t,sqlparse.sql.Identifier):
                if isinstance(self._clean_query(t.tokens)[0],sqlparse.sql.Case):
                    op_dict[t.get_name()] = self._clean_col_def(t.value,t.get_name())
                else:
                    op_dict[t.get_name()] = self._clean_col_def(t.value,t.get_name())
            elif isinstance(t,sqlparse.sql.IdentifierList):
                for tt in self._clean_query(t.tokens):
                    op_dict[tt.get_name()] = self._clean_col_def(tt.value,tt.get_name())
        return op_dict


    def parse_from_statement(self,f_tokens):
        op_dict = {'input_tables':{}}

        # only identifier or identifier list present in f_tokens
        only_identifiers_flag=True
        for t in f_tokens:
            if not (isinstance(t,sqlparse.sql.Identifier) or isinstance(t,sqlparse.sql.IdentifierList)):
                only_identifiers_flag=False

        if only_identifiers_flag:
            for t in f_tokens:
                if isinstance(t,sqlparse.sql.Identifier):
                    op_dict['input_tables'][t.get_name()] = t.get_real_name()
                elif isinstance(t,sqlparse.sql.IdentifierList):
                    for tt in self._clean_query(t.tokens):
                        op_dict['input_tables'][tt.get_name()] = tt.get_real_name()
            return op_dict

        # join statement
        op_dict['joins'] = {}
        for i,t in enumerate(f_tokens):
            if isinstance(t,sqlparse.sql.Identifier):
                op_dict['input_tables'][t.get_name()] = t.get_real_name()

            if self._check_token_is_base_keyword(t,['left join','right join','inner join','join','full outer join']):
                key = f_tokens[i+1].get_name()
                join_dict = {}
                join_dict['type'] = t.value.replace(' join','')
                if join_dict['type']=='join':
                    join_dict['type'] = 'inner'
                join_dict['on'] = f_tokens[i+3].value
                op_dict['joins'][key] = join_dict

        return op_dict        



    # helper functions
    def _clean_col_def(self,c,n):
        return re.sub('\ '+n+'$','',c.split('--')[0].strip()).strip()

    def _generate_name(self,l=8):
        x = ''.join(random.choices(string.ascii_letters + string.digits, k=l))
        return x
 
    def _is_simple_select(self,q_t):
        from_t = []
        for t in q_t['from']:
            from_t.extend(list(t.flatten()))
        return not max([self._check_token_is_base_keyword(t,'select') for t in from_t])
        
    def _check_token_is_base_keyword(self,t,kw):
        if 'select' in kw:
            return t.match(sqlparse.tokens.Keyword.DML,kw)
        else:
            return t.match(sqlparse.tokens.Keyword,kw)

    def _get_select_from_par(self,subq_token):
        for t in subq_token.tokens:
            if isinstance(t,sqlparse.sql.Parenthesis):
                return self._clean_query(t.tokens)
        
    def _clean_query(self,tokens):
        tokens_exp = []
        tokens_cleaned = []
        for t in tokens:
            if isinstance(t,sqlparse.sql.Where):
                for tt in t.tokens:
                    tokens_exp.append(tt)
            else:
                tokens_exp.append(t)

        for t in tokens_exp:
            append_flag = True
            for pattern in ['Token.Text.Whitespace*','Token.Punctuation*']:
                if re.match(pattern,str(t.ttype)):
                    append_flag = False
            for t_type in [sqlparse.sql.Comment]:
                if type(t)==t_type:
                    append_flag = False

            if append_flag:
                tokens_cleaned.append(t)

        return tokens_cleaned

In [None]:
test_query_obj = SelectQuery(subq_test_q_1)

In [None]:
test_query_obj.parse_select_query()

In [None]:
test_query_obj.all_queries

In [None]:
'select' in 'select'