In [1]:
%load_ext autoreload
%autoreload 2
import pandas as pd
import numpy as np
import sqlparse
import re
from SQLParser import *
import SQLFunctions as funcs
from SQLParserError import SQLParserError
import warnings

In [2]:
# Parser instance reused by the cells below (they call sql.default_string on the sample queries).
# SQLParser is not imported by name; it is expected to come from `from SQLParser import *` in the first cell.
sql = SQLParser()

In [49]:
# Sample query #1: a short SELECT with one multi-condition JOIN and no WHERE clause.
# The text stops mid-identifier ("and end_"), so the JOIN conditions run to the very end of the script
# (presumably to exercise the end-of-script branch of join_rule_sub; confirm).
# NOTE(review): in notebook order the next cell rebinds `raw`, so this value is only used if that cell is skipped.
raw = '''
select
            t.year_of_calendar
            , datediff(usr.end_dttm, usr.start_dttm) as days_per_trip
            , case
                when non_flick.connections_per_day is not null then 1
                else 0
            end as flick_flag
            , case
                when lt.rank_in_month is not null then 1
                else 0
            end as longest_trip_flag
        from month as t
        join non_flickers as non_flick on usr.subs_id = non_flick.subs_id
                and t.calendar_date = non_flick.calendar_date
                and connections_per_day is not null
                and connections_per_day1 is not null
                and connections_per_day2 is not null
                and end_
'''

In [63]:
# Sample query #2: a full SELECT with three JOINs followed by a WHERE clause:
#   - `join non_flickers`: ON plus two AND conditions spread over several lines,
#   - `join prd2_dic_v.geo_country`: ON written on its own line with a second AND condition,
#   - `left join cls_moscow`: a single inline condition.
# NOTE(review): the aliases usr, lt and t1 are referenced but never introduced in FROM/JOIN, so this is
# formatter test input rather than runnable SQL.
raw = '''
select
            t.year_of_calendar
            , t.calendar_date
            , t.end_of_month
            , usr.subs_id
            , usr.msisdn
            , usr.start_dttm
            , usr.end_dttm
            , usr.home_province_code
            , usr.province_code_correct
            , usr.country_code
            , usr.home_flag
            , usr.end_date
            , usr.home_region
            , usr.destination
            , case
                when (unix_timestamp(usr.end_dttm) - unix_timestamp(t.calendar_date)) > 86400 then 86400 / 60 / 60
                when (unix_timestamp(usr.end_dttm) - unix_timestamp(t.calendar_date)) <= 86400 and (unix_timestamp(t.calendar_date) - unix_timestamp(usr.start_dttm)) <= 86400 then
                    round((unix_timestamp(usr.end_dttm) - unix_timestamp(usr.start_dttm)) / 60 / 60, 2)
                when (unix_timestamp(usr.end_dttm) - unix_timestamp(t.calendar_date)) <= 86400 and (unix_timestamp(t.calendar_date) - unix_timestamp(usr.start_dttm)) > 86400 then
                    round((unix_timestamp(usr.end_dttm) - unix_timestamp(t.calendar_date)) / 60 / 60, 2)
            end as time_per_day
            , datediff(usr.end_dttm, usr.start_dttm) as days_per_trip
            , case
                when non_flick.connections_per_day is not null then 1
                else 0
            end as flick_flag
            , case
                when lt.rank_in_month is not null then 1
                else 0
            end as longest_trip_flag
        from month as t
        join non_flickers as non_flick on usr.subs_id = non_flick.subs_id
                and t.calendar_date = non_flick.calendar_date
                and connections_per_day is not null
        join prd2_dic_v.geo_country as cntr 
        on t.country_code = substr(cntr.country_code, 1, 2) and t1.calendar_date = non_flick.calendar_date1
        left join cls_moscow as msc on t.subs_id = msc.subs_id
        where t.start_dttm <= '2022-10-31 23:59:59'
            and t.end_date >= '20221001'
'''

In [64]:
# Pass the current sample through SQLParser.default_string; the result `d` is the input for join_rule_sub.
# NOTE(review): judging by the saved output further down, default_string upper-cases keywords and
# reindents the query — confirm against the SQLParser module.
d = sql.default_string(raw)

In [38]:
# Print the query after the JOIN ... ON rule has been applied.
# NOTE(review): this cell sits above the cell that defines join_rule_sub, so on a fresh
# "Restart & Run All" it only works if the star import already provides that name — confirm, or move
# it below the definition. The saved output (lists of lines) does not match the current function,
# which returns a string, so it was presumably produced by an earlier version.
print(join_rule_sub(d))

['JOIN non_flickers AS non_flick ON usr.subs_id = non_flick.subs_id', 'AND t.calendar_date = non_flick.calendar_date', 'AND connections_per_day IS NOT NULL', 'AND connections_per_day IS NOT NULL', 'AND connections_per_day IS NOT NULL', 'AND end_']
[]


In [52]:
# Keywords that open a JOIN clause when they start a line.
_JOIN_KEYWORDS = ('JOIN', 'LEFT', 'RIGHT', 'INNER', 'FULL', 'OUTER', 'CROSS')
# Keywords that end the JOIN section when they start a line.
_JOIN_END_KEYWORDS = ('WHERE', 'ORDER BY', 'GROUP BY')
# One indentation level used for the rewritten ON block.
_JOIN_INDENT = 4 * ' '

# `\b` keeps identifiers such as REGION / BRAND / VENDOR from being mistaken for ON / AND / OR.
_ON_KEYWORD = re.compile(r'\bON ')
_ON_BRACKET = re.compile(r'\bON \(')
_LEADING_AND_OR = re.compile(r'\s*(?:AND|OR)\b')


def _find_join_groups(rows):
    """Group the non-empty script rows into JOIN clauses.

    Args:
        rows: list of ``(line_index, text)`` pairs for the non-empty lines of the script.

    Returns:
        A list of groups. Each group is a list of ``(line_index, text)`` pairs that starts
        with a row beginning with a JOIN keyword and runs up to (not including) the next row
        that begins with a JOIN keyword or with WHERE / ORDER BY / GROUP BY, or to the end
        of the script.
    """
    groups = []
    start = None  # position of the open JOIN row; None (not 0) so a JOIN on row 0 is tracked too
    last = len(rows) - 1
    for pos in range(last):
        if start is None and rows[pos][1].startswith(_JOIN_KEYWORDS):
            start = pos
        if start is None:
            continue
        following = rows[pos + 1][1]
        if following.startswith(_JOIN_KEYWORDS + _JOIN_END_KEYWORDS):
            groups.append(rows[start:pos + 1])
            start = None
        elif pos + 1 == last:
            # The script ends while the JOIN is still open: take everything that is left.
            groups.append(rows[start:])
    return groups


def _find_closing_paren(fragments):
    """Locate the ``)`` that balances a parenthesis opened just before ``fragments[0]``.

    Parentheses inside single-quoted SQL literals are ignored. SQL comments and
    double-quoted identifiers are not recognised.

    Returns:
        ``(fragment_number, char_position)`` of the balancing ``)``, or ``None`` when the
        fragments end before the parenthesis is closed.
    """
    depth = 1
    in_literal = False
    for frag_no, fragment in enumerate(fragments):
        for char_pos, char in enumerate(fragment):
            if char == "'":
                in_literal = not in_literal
            elif in_literal:
                continue
            elif char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
                if depth == 0:
                    return frag_no, char_pos
    return None


def _wrap_unbracketed_group(lines, group):
    """Wrap the conditions of a JOIN written without brackets in ``ON ( ... )``.

    The group is left untouched unless its first row contains the ON keyword and at least
    one later row starts with AND / OR, so an opening bracket is never emitted without its
    closing counterpart.
    """
    join_index, join_text = group[0]
    on_match = _ON_KEYWORD.search(join_text)
    condition_rows = [(idx, text) for idx, text in group[1:] if _LEADING_AND_OR.match(text)]
    if on_match is None or not condition_rows:
        return

    # Text before ON is kept verbatim (including its trailing space), as before.
    lines[join_index] = (
        join_text[:on_match.start()]
        + '\n' + _JOIN_INDENT + 'ON (\n' + _JOIN_INDENT
        + join_text[on_match.end():]
    )
    for idx, text in condition_rows:
        lines[idx] = _JOIN_INDENT + text
    # The bracket is closed right after the last AND / OR row of the group.
    lines[condition_rows[-1][0]] += '\n' + _JOIN_INDENT + ')'


def _reindent_bracketed_group(lines, group, opener_pos):
    """Re-indent a JOIN whose conditions the author already wrapped in ``ON ( ... )``.

    ``ON (`` moves to its own line, every condition is indented by two levels and the
    balancing ``)`` is put on its own line at one level. If the parenthesis after ON closes
    on the same row it only wraps part of an expression, so the group is handled as
    unbracketed. If it is never closed inside the group, the group is left untouched.
    """
    open_index, open_text = group[opener_pos]
    on_match = _ON_BRACKET.search(open_text)
    head, tail = open_text[:on_match.start()], open_text[on_match.end():]
    followers = [(idx, text.strip()) for idx, text in group[opener_pos + 1:]]

    closer = _find_closing_paren([tail] + [text for _, text in followers])
    if closer is None:
        return
    frag_no, char_pos = closer
    if frag_no == 0:
        _wrap_unbracketed_group(lines, group)
        return

    lines[open_index] = head + '\n' + _JOIN_INDENT + 'ON (\n' + 2 * _JOIN_INDENT + tail
    for idx, text in followers[:frag_no - 1]:
        lines[idx] = 2 * _JOIN_INDENT + text
    idx, text = followers[frag_no - 1]
    lines[idx] = (
        2 * _JOIN_INDENT + text[:char_pos]
        + '\n' + _JOIN_INDENT + ')'
        + text[char_pos + 1:]
    )


def _format_join_group(lines, group):
    """Rewrite one multi-row JOIN group inside ``lines`` (modified in place)."""
    opener_pos = next(
        (pos for pos, (_, text) in enumerate(group) if _ON_BRACKET.search(text)),
        None,
    )
    if opener_pos is None:
        _wrap_unbracketed_group(lines, group)
    else:
        _reindent_bracketed_group(lines, group, opener_pos)


def join_rule_sub(raw):
    '''
    Enforce the layout rule for "JOIN ... ON" expressions.

    A JOIN with a single condition may keep that condition on the JOIN line. When a JOIN
    has more than one condition, the conditions are moved to new lines together with ON
    and wrapped in brackets:

        JOIN a ON a.x = b.x          JOIN a
        AND a.y = b.y         ->         ON (
                                         a.x = b.x
                                         AND a.y = b.y
                                         )

    Only JOIN / WHERE / ORDER BY / GROUP BY keywords that start a line (upper case, no
    leading whitespace) are recognised. Rows are rewritten by position, not by searching
    for their text, so duplicate rows and regex metacharacters in the SQL are safe.

    Args:
        raw: the SQL script as a single string.

    Returns:
        The script with every multi-condition JOIN rewritten. All other lines, including
        empty ones, are returned unchanged.

    Raises:
        AssertionError: if ``raw`` is not a string.
        SQLParserError: if rewriting a JOIN group fails unexpectedly; the original
            exception is chained as the cause.
    '''
    assert isinstance(raw, str), 'Аргумент raw должен быть типа string'

    lines = raw.split('\n')
    rows = [(idx, text) for idx, text in enumerate(lines) if len(text) > 0]

    for group in _find_join_groups(rows):
        if len(group) < 2:
            # A one-row JOIN has its single condition inline, which the rule allows.
            continue
        try:
            _format_join_group(lines, group)
        except Exception as exc:
            raise SQLParserError("Something went wrong in " '{!r}'.format(group[0][1])) from exc

    return '\n'.join(lines)

In [65]:
# Print the reformatted query for the current `d`.
# NOTE(review): the saved output below starts with lines of the form "<row> <number> <number>", which
# the function as currently defined does not print — presumably left over from a debugging version;
# re-execute the cell before sharing the notebook.
print(join_rule_sub(d))

AND t.calendar_date = non_flick.calendar_date 3 1
AND connections_per_day IS NOT NULL 3 2
AND t1.calendar_date = non_flick.calendar_date1 2 1

SELECT t.year_of_calendar ,
       t.calendar_date ,
       t.end_of_month ,
       usr.subs_id ,
       usr.msisdn ,
       usr.start_dttm ,
       usr.end_dttm ,
       usr.home_province_code ,
       usr.province_code_correct ,
       usr.country_code ,
       usr.home_flag ,
       usr.end_date ,
       usr.home_region ,
       usr.destination ,
       CASE
           WHEN (unix_timestamp(usr.end_dttm) - unix_timestamp(t.calendar_date)) > 86400 THEN 86400 / 60 / 60
           WHEN (unix_timestamp(usr.end_dttm) - unix_timestamp(t.calendar_date)) <= 86400
                AND (unix_timestamp(t.calendar_date) - unix_timestamp(usr.start_dttm)) <= 86400 THEN round((unix_timestamp(usr.end_dttm) - unix_timestamp(usr.start_dttm)) / 60 / 60, 2)
           WHEN (unix_timestamp(usr.end_dttm) - unix_timestamp(t.calendar_date)) <= 86400
                AN