In [249]:
import sys
from antlr4 import *
from grammar.SQLiteLexer import SQLiteLexer
from grammar.SQLiteParser import SQLiteParser
from grammar.SQLiteListener import SQLiteListener
import re
from readlisp import readlisp, LispSymbol
import copy
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [2]:
def pre_process_tree(filename):
    with open(filename) as f:
        lines = []
        for line in f.readlines():
            if "like" in line.lower():
                new_line = re.sub("""(?!\()([^\s]+?\s+?(?:not\s)?like\s+?["'][^\s]+?["'])\s+?(?!\))""", r' (\1) ', line, flags = re.IGNORECASE)
                # print(re.findall("""(?!\()([^\s]+?\s+?(?:not\s)?like\s+?["'][^\s]+?["'])\s+?(?!\))""", line, flags = re.IGNORECASE))
                lines.append(new_line)
            else:
                lines.append(line)
    with open(filename, "w+") as f:
        f.writelines(lines)
        
def remove_lisp_symbol(l):
    if type(l) in [LispSymbol, int, float]:
        if type(l) == LispSymbol:
            return l.name.lower()
        else:
            return str(l)
    elif type(l) == list:
        return [remove_lisp_symbol(ls) for ls in l]
    else:
        print("Uhoh", type(l), l)
        
def get_parsed_tree(filename):
    pre_process_tree(filename)
    input_stream = FileStream(filename)
    lexer = SQLiteLexer(input_stream)
    stream = CommonTokenStream(lexer)
    parser = SQLiteParser(stream)
    tree = parser.parse()
    tree = remove_lisp_symbol(readlisp(tree.toStringTree(recog=parser)))
    return tree

In [29]:
test1 = get_parsed_tree("test1.sql")
test2 = get_parsed_tree("test2.sql")
test3 = get_parsed_tree("test3.sql")
test4 = get_parsed_tree("test4.sql")

In [372]:
select_s_exp_3 = test3[1][1][1]
select_s_exp_4 = test4[1][1][1]
select_s_exp_4

['factored_select_stmt',
 ['select_core',
  'select',
  ['result_column',
   ['expr',
    ['table_name', ['any_name', 'c1']],
    '.',
    ['column_name', ['any_name', 'cand_name']]]],
  ',',
  ['result_column',
   ['expr',
    ['table_name', ['any_name', 'c2']],
    '.',
    ['column_name', ['any_name', 'cmte_nm']]]],
  'from',
  ['join_clause',
   ['table_or_subquery',
    ['table_name', ['any_name', 'cand']],
    ['table_alias', 'c1']],
   ['join_operator', 'inner', 'join'],
   ['table_or_subquery',
    ['table_name', ['any_name', 'comm']],
    ['table_alias', 'c2']],
   ['join_constraint',
    'on',
    ['expr',
     ['expr',
      ['table_name', ['any_name', 'c1']],
      '.',
      ['column_name', ['any_name', 'cand_id']]],
     '=',
     ['expr',
      ['table_name', ['any_name', 'c2']],
      '.',
      ['column_name', ['any_name', 'cand_id']]]]]]],
 'order',
 'by',
 ['ordering_term',
  ['expr',
   ['table_name', ['any_name', 'c1']],
   '.',
   ['column_name', ['any_name', 'can

In [436]:
import random
from pprint import pprint

def gen_fresh_name(base_name, used_names):
    if base_name not in used_names:
        new_name = base_name
    else:
        for i in range(2, 10000):
            if f"{base_name}_{i}" in used_names:
                continue
            new_name = f"{base_name}_{i}"
            break
    used_names.append(new_name)
    return new_name

def gen_name_mapping(old_name_sets, new_names):
    # old names is a list of set
    # new names is a list of strings
    mapping = {}
    for old_names, new_name in zip(old_name_sets, new_names):
      for old_name in old_names:
        mapping[old_name] = new_name
    return mapping
    
class TableOp():
    def tbl_op_dispatcher(s_exp):
        print(s_exp)
        if s_exp[0] == "table_or_subquery":
            if s_exp[1][0] == 'table_name':
                print("dispatching to tbl_ref")
                return TableReference(s_exp)
            else:
                print("dispatching to select statement")
                return SelectStatement(s_exp[1][0])
        elif s_exp[0] == "join_clause":
            print("dispatching to join")
            return JoinTable.from_s_exp(s_exp)
        
        
class RenamedTable(TableOp):
    def __init__(self, name, table, cols):
        self.name = name
        self.table = table
        self.cols = cols # these are invented alias names 

    def infer_out_schema(self, schema):
        return self.table.infer_out_schema(schema)

    def rename(self, schema, used_table_names):
        return self.table.rename(schema, used_table_names)
    
    def rename_ops(self, mapping):
        pass
    
    def to_rkt(self, schema):
        return "(AS " + self.table.to_rkt(schema) + "\n[\"" + self.name + "\" (list " + " ".join(["\"" + col + "\"" for col in self.cols]) + ")])"
            
class TableReference(TableOp):
    def __init__(self, s_exp):
        self.name = s_exp[1][1][1]
        if len(s_exp) > 2:
            self.alias = s_exp[2][1]
        else:
            self.alias = None
            
    def infer_out_schema(self, schema):
        tbl_schema = schema[self.name]
        to_return = []
        for col in tbl_schema:
            full_name = self.alias + "." + col if self.alias else self.name + "." + col
            to_return.append(set([full_name, col]))
        
        print("returning", to_return)
        return to_return

    def rename(self, schema, used_table_names):
        tbl_schema = schema[self.name]
        tbl_name = self.alias if self.alias else self.name
        
        new_table_name = gen_fresh_name(tbl_name, used_table_names)
        
        # collecting old names
        old_names = [set([c, f"{tbl_name}.{c}"]) for c in tbl_schema]
        new_names = [f"{new_table_name}.{c}" for c in tbl_schema]
      
        mappings = gen_name_mapping(old_names, new_names)
        
        return RenamedTable(new_table_name, self, list(tbl_schema.keys())), mappings, used_table_names

    def rename_ops(self, mapping):
        pass
    
    def to_rkt(self, schema):
        return "(NAMED " + self.name + ")"

        
class JoinTable(TableOp):
    def __init__(self, left_tbl, right_tbl, join_op, constraint = None):
        self.left_tbl = left_tbl
        self.right_tbl = right_tbl
        self.join_op = join_op
        self.constraint = constraint
        
    @classmethod
    def from_s_exp(cls, s_exp):
        left_tbl = TableOp.tbl_op_dispatcher(s_exp[1])
        join_op = s_exp[2][1:]
        right_tbl = TableOp.tbl_op_dispatcher(s_exp[3])
        if "inner" in join_op:
            print("Wrapping inner table as select")
            if type(s_exp[4]) == list:
                constraint = PredicateOp.pred_op_dispatcher(s_exp[4])
            else:
                constraint = None
            cross_table = cls(left_tbl, right_tbl, 'cross')
            to_return = SelectStatement([], wrap = True, wrap_tbl = cross_table, wrap_constraints = constraint)
            return to_return
        else:
            if type(s_exp[4]) == list:
                constraint = PredicateOp.pred_op_dispatcher(s_exp[4])
            else:
                constraint = None
            return cls(left_tbl, right_tbl, join_op, constraint)

    def infer_out_schema(self, schema):        
        left_cols = self.left_tbl.infer_out_schema(schema)
        right_cols = self.right_tbl.infer_out_schema(schema)
        left_cols_set = set()
        right_cols_set = set()
        for col_set in left_cols:
             left_cols_set = left_cols_set.union(col_set)
        for col_set in right_cols:
             right_cols_set = right_cols_set.union(col_set)
        duplicate_col_names = left_cols_set & right_cols_set
        
        if hasattr(self, "alias") and self.alias:
            name = tbl.alias
        else:
            name = None
        
        used_names = set()
        non_referrables = [set() for i in range(len(left_cols + right_cols))]
        for index, colset in enumerate(left_cols + right_cols):
            for col_name in colset:
                if col_name in used_names:
                    non_referrables[index].add(col_name)
                else:
                    used_names.add(col_name)
            
        to_return = left_cols + right_cols
        if not name:
            to_return = [{c for c in nameset if c not in non_refs} for nameset, non_refs in zip(to_return, non_referrables)]
        else:
            to_return = [set([f'{name}.{c}' for c in nameset if c not in non_refs and "." not in c] 
                             + [c for c in nameset if c not in non_refs and '.' in c]) for nameset, non_refs in zip(to_return, non_referrables)]
#         for index, nameset in enumerate(left_cols):
#             for col in nameset:
#                 if col in duplicate_col_names:
#                     if name:
#                         to_return[index].union(set([col, name + "." + col]))
#                     else:
#                         to_return[index].union(set([col]))
        
        print("Inferring join, returning", to_return)
        return to_return
    
    
    def rename(self, schema, used_table_names):
        left_cols_old = self.left_tbl.infer_out_schema(schema)
        right_cols_old = self.right_tbl.infer_out_schema(schema)
        my_cols_old = self.infer_out_schema(schema)
        
        self.left_tbl, mappings_left, used_table_names = self.left_tbl.rename(schema, used_table_names)
        self.right_tbl, mappings_right, used_table_names = self.right_tbl.rename(schema, used_table_names)
        
        left_cols_new = self.left_tbl.infer_out_schema(schema)
        right_cols_new = self.right_tbl.infer_out_schema(schema)
        
        self.rename_ops({**mappings_left, **mappings_right})
        
        new_table_name = gen_fresh_name(self.alias if (hasattr(self, "alias") and self.alias) else "t", used_table_names)
        
        mappings = {}        
        # the list of new names
        column_list = []
        used_col_names = []
        for index, col_set in enumerate(my_cols_old):
            old_full_name = [col for col in col_set if "." in col]
            old_short_name = [col for col in col_set if "." not in col]
            base_name = old_full_name[0].split(".")[-1] if len(old_full_name) > 0 else "c"
            new_name = gen_fresh_name(base_name, used_col_names)
            
            for old_name in col_set:
                mappings[old_name] = new_table_name + "." + new_name
            column_list.append(new_name)
            
        print("My mapping from join:")
        pprint(mappings)
      
        return RenamedTable(new_table_name, self, column_list), mappings, used_table_names 

    def rename_ops(self, mapping):
        if self.constraint:
            self.constraint.rename_ops(mapping)
            
    def to_rkt(self, schema):
        if "left" in self.join_op:
            op = "LEFT-OUTER-JOIN"
        else:
            op = "JOIN"
             
        return f"({op} {self.left_tbl.to_rkt(schema)} {self.right_tbl.to_rkt(schema)} {self.constraint.to_rkt(schema) if self.constraint else ''})"
    

class SelectStatement(TableOp):
    def __init__(self, entire_s_exp, wrap = False, wrap_tbl = None, wrap_constraints = None):
        if wrap:
            self.columns = [AllColumn()]
            self.col_names = [None]
            self.tables = [wrap_tbl]
            self.tbl_names = [None]
            self.subquery_tree = wrap_tbl
            self.where_tree = wrap_constraints
            self.group_col = None
            self.having_tree = None
            self.order_cols = []
            self.ordering_dir = []
        else:
            self.from_s_exp(entire_s_exp)

    
    def from_s_exp(self, entire_s_exp):
        s_exp = entire_s_exp[1]

        select_index = s_exp.index('select')
        from_index = s_exp.index('from')
        self.columns = []
        self.col_names = []
        for term in s_exp[select_index:from_index]:
            if type(term) == list and term[0] == 'result_column':
                self.columns.append(ColOp.col_op_dispatcher(term[1]))
                if "as" in term:
                    alias = term[term.index("as") + 1][1]
                    self.col_names.append(alias)
                else:
                    self.col_names.append(None)

        self.tables = []
        self.tbl_names = [] 
        if "where" in s_exp:
            where_index = s_exp.index('where')
            for term in s_exp[from_index+1:where_index]:
                if type(term) == list:
                    self.tables.append(TableOp.tbl_op_dispatcher(term))
                if "as" in term:
                    alias = term[term.index("as") + 1][1]
                    self.tbl_names.append(alias)
                else:
                    self.tbl_names.append(None)
        else:
            for term in s_exp[from_index+1:]:
                if type(term) == list:
                    self.tables.append(TableOp.tbl_op_dispatcher(term))
                if "as" in term:
                    alias = term[term.index("as") + 1][1]
                    self.tbl_names.append(alias)
                else:
                    self.tbl_names.append(None)

        if len(self.tables) > 1:
            self.tables[0].alias = self.tbl_names[0]
            self.tables[1].alias = self.tbl_names[1]
            self.subquery_tree = JoinTable(self.tables[0], self.tables[1], 'cross')
            if len(self.tables) > 2:
                for i, tbl in enumerate(self.tables[2:]):
                    tbl.alias = self.tbl_names[i+2]
                    self.subquery_tree = JoinTable(left_tbl = self.subquery_tree, right_tbl = tbl, join_op = 'cross')
        else:
            self.subquery_tree = self.tables[0]
            self.subquery_tree.alias = self.tbl_names[0]

        if "where" in s_exp:
            self.where_tree = PredicateOp.pred_op_dispatcher(s_exp[where_index+1])
        else:
            self.where_tree = None

        if "group" in s_exp:
            groupby_index = list(filter(lambda i: s_exp[i] == 'group' and s_exp[i+1] == 'by', range(len(s_exp)-2)))[0]
            if "having" in s_exp:  
                having_index = s_exp.index('having')
                self.group_col = [ColOp.col_op_dispatcher(col) for col in s_exp[groupby_index + 2:having_index]]
            else:
                self.group_col = [ColOp.col_op_dispatcher(col) for col in s_exp[groupby_index + 2:]]
        else:
            self.group_col = None
        
        if "having" in s_exp:
            having_index = s_exp.index('having')
            self.having_tree = PredicateOp.pred_op_dispatcher(s_exp[having_index+1])
        else:
            self.having_tree = None

        self.order_cols = []
        self.ordering_dir = []
        if 'order' in entire_s_exp:
            orderby_index = list(filter(lambda i: entire_s_exp[i] == 'order' and entire_s_exp[i+1] == 'by', range(len(entire_s_exp)-2)))[0]
            if "limit" in entire_s_exp:
                end = entire_s_exp.index('limit')
            else:
                end = len(entire_s_exp)
            for term in entire_s_exp[orderby_index+2:end]:
                if type(term) == list and term[0] == "ordering_term":
                    self.order_cols.append(ColOp.col_op_dispatcher(term[1]))
                    if len(term) > 2:
                        self.ordering_dir.append(term[2])
                    else:
                        self.ordering_dir.append("asc")

        if 'limit' in entire_s_exp:
            limit_index = entire_s_exp.index('limit')
            self.limit = entire_s_exp[limit_index+1][1][1]

    
    def infer_out_schema(self, schema):
        to_return = []
        if type(self.columns[0]) == AllColumn:
            schema = self.subquery_tree.infer_out_schema(schema)
            to_return = [{col for col in nameset if "." not in col} for nameset in schema]
        else:
            for col, col_name in zip(self.columns, self.col_names):
                if col_name:
                    to_return.append(set([col_name]))
                else:
                    if type(col) in [UnaryColumnOp, BinaryColumnOp, ConstantColumn]:
                        to_return.append(set())
                    else:
                        to_return.append(set([col.name]))
                    
        if hasattr(self, "alias") and self.alias:
            for col in to_return:
                if col:
                    col.add(self.alias + "." + next(iter(col)))
        print("Inferring select, returning", to_return)
        return to_return

    def rename(self, schema, used_table_names):
        child_old_names = self.subquery_tree.infer_out_schema(schema)
        my_old_names = self.infer_out_schema(schema)
        
        self.subquery_tree, child_mappings, used_table_names = self.subquery_tree.rename(schema, used_table_names)
        child_new_names = self.subquery_tree.infer_out_schema(schema)
        
        # only rename ops in select and where, ops in having should be renamed using my_mappings instead of mappings
        self.rename_ops(child_mappings)
        print("My child old names are", child_old_names)
        new_table_name = gen_fresh_name(self.alias if (hasattr(self, "alias") and self.alias) else "t", used_table_names)
        print("my old names are", my_old_names)
        mappings = {}
        column_list = []
        used_col_names = []
        #print("My old names are", old_names)
        for index, col_set in enumerate(my_old_names):
            old_full_name = [col for col in col_set if "." in col]
            old_short_name = [col for col in col_set if "." not in col]
            if len(old_full_name) > 0:
                base_name = old_full_name[0]
            elif len(old_short_name) > 0:
                base_name = old_short_name[0]
            else:
                base_name = 'c'
            new_name = gen_fresh_name(base_name, used_col_names)

            for old_name in col_set:
                mappings[old_name] = new_table_name + "." + new_name

            column_list.append(new_name)
          
        print(f"My mapping is [Select {new_table_name}]:")
        pprint(mappings)
        
        # old names overwrites new names
        mappings_copy = copy.deepcopy(mappings)
        for name in child_mappings:
            mappings_copy[name] = child_mappings[name]
        
        # self.rename_having(mappings_copy)
        
        return RenamedTable(new_table_name, self, column_list), mappings, used_table_names

    def rename_ops(self, mapping):
        replacement_col_names = []
        if type(self.columns[0]) == AllColumn:
            print("Renaming inner table")
            pass
        else:
            for col in self.columns:
                col.rename_ops(mapping)

        if self.where_tree:
            self.where_tree.rename_ops(mapping)

        if self.group_col:
            for col in self.group_col:
                col.rename_ops(mapping) 
                
    def rename_having(self, mapping):
        if self.having_tree:
            self.having_tree.rename_ops(mapping)

        for col in self.order_cols:
            col.rename_ops(mapping)
            
    def to_rkt(self, schema):
        where_part = "\nWHERE " + self.where_tree.to_rkt(schema) if self.where_tree else ""
        group_part = "\nGROUP-BY (list " + " ".join([col.to_rkt(schema) + " " for col in self.group_col])+ ")" if self.group_col else ""
        having_part = "\nHAVING " + self.having_tree.to_rkt(schema) if self.having_tree else ""
        
        return "(SELECT " + "(VALS " + " ".join([col.to_rkt(schema) for col in self.columns]) + ")" + "\nFROM " + self.subquery_tree.to_rkt(schema) + where_part + " " + group_part + " " + having_part + ")"
            

class ColOp():
    def col_op_dispatcher(s_exp):
        print(s_exp)
        if s_exp == "*" or s_exp == "[*]":
            print("dispatching to *")
            return AllColumn()
        assert s_exp[0] == "expr"
        if s_exp[1][0] == "column_name":
            print("dispatching to col name")
            return Column(s_exp[1])
        elif s_exp[1][0] == "literal_value":
            print("dispatching to col value")
            return ConstantColumn(s_exp[1])
        elif s_exp[1][0] == "function_name" and len(s_exp) == 3:
            print("dispatching to unary")
            return UnaryColumnOp(s_exp)
        elif '.' in s_exp:
            print("dispatching to col name")
            return Column(s_exp, table = True)
        elif len(s_exp) == 4:
            print("dispatching to binary")
            return BinaryColumnOp(s_exp)       

class Column(ColOp):
    def __init__(self, s_exp, table = False):
        if not table:
            assert s_exp[0] == "column_name"
            self.name = s_exp[1][1]
            self.table = None
        else:
            assert s_exp[1][0] == "table_name"
            self.name = s_exp[3][1][1]
            self.table = s_exp[1][1][1]

    def rename_ops(self, mapping):
        if self.table:
            new_name = mapping[self.table + "." + self.name] 
            self.table = new_name.split(".")[0]
            self.name = new_name.split(".")[1]
        else:
            new_name = mapping[self.name] 
            self.name = new_name.split(".")[1]
       
      
    def to_rkt(self, schema):
        return "\"" + (self.table + "." + self.name if self.table else self.name) + "\""
        
        
class ConstantColumn(ColOp):
    def __init__(self, s_exp):
        assert s_exp[0] == "literal_value"
        self.value = s_exp[1]
        
    def to_rkt(self, schema):
        return str(self.value)
    
    def rename_ops(self, mapping, tbl_name):
        pass
    
class UnaryColumnOp(ColOp):
    def __init__(self, s_exp):
        self.op = s_exp[1][1][1]
        self.col_ops = []
        for child in s_exp[2]:
            self.col_ops.append(ColOp.col_op_dispatcher(child))

    def rename_ops(self, mapping):
        for child in self.col_ops:
            child.rename_ops(mapping)
            
    def to_rkt(self, schema):
        return "(" + self.op + " " + " ".join(["\"" + (col.table + "." + col.name if col.table else col.name) + "\""  for col in self.col_ops]) + ")"
            
class BinaryColumnOp(ColOp):
    def __init__(self, s_exp):
        self.op = s_exp[2]
        self.left_col_op = ColOp.col_op_dispatcher(s_exp[1])
        self.right_col_op = ColOp.col_op_dispatcher(s_exp[3])

    def rename_ops(self, mapping):
        self.left_col_op.rename_ops(mapping)
        self.right_col_op.rename_ops(mapping)
        
    def to_rkt(self, schema):
        return "(BINOP " + self.left_col_op.to_rkt(schema) + " " + self.op + " " + self.right_col_op.to_rkt(schema) + ")"
    

class AllColumn(ColOp):
    def __init__(self):
        pass

    def rename_ops(self, mapping):
        pass
    
    def to_rkt(self, schema):
        # TODO: how to interpret COUNT(*)
        pass
    
class PredicateOp():
    def pred_op_dispatcher(s_exp):
        print(s_exp)
        if s_exp[0] == 'join_constraint':
            print("dispatching to Join")
            return JoinPredicate(s_exp)
        assert s_exp[0] == "expr"
        s_exp = find_ultimate_pred(s_exp)
        print("cleaned s_exp", s_exp)
        if 'and' in s_exp:
            print("dispatching to And")
            return AndPred(s_exp)
        elif 'or' in s_exp:
            print("dispatching to Or")
            return OrPred(s_exp)
        elif s_exp[1][0] == 'unary_operator' and s_exp[1][1] == 'not':
            print("dispatching to Not")
            return NotPred(s_exp)
        elif "like" in s_exp:
            print("dispatching to like")
            return LikePredicate(s_exp)
        else:
            print("dispatching to Predicate")
            return Predicate(s_exp)
        

class JoinPredicate(PredicateOp):
    def __init__(self, s_exp):
        assert s_exp[0] == 'join_constraint'
        pred_stmt = s_exp[2]
        self.op = pred_stmt[2]
        self.left_pred_op = ColOp.col_op_dispatcher(pred_stmt[1])
        self.right_pred_op = ColOp.col_op_dispatcher(pred_stmt[3])
    
    def rename_ops(self, mapping):
        self.left_pred_op.rename_ops(mapping)
        self.right_pred_op.rename_ops(mapping)
        
    def to_rkt(self, schema):
        return "(BINOP " + self.left_pred_op.to_rkt(schema) + " " + self.op + " " + self.right_pred_op.to_rkt(schema) + ")"

def find_ultimate_pred(s_exp):
    assert s_exp[0] == "expr"
    if type(s_exp[1][0]) == list:
        return find_ultimate_pred(s_exp[1][0])
    else:
        return s_exp
    
class Predicate(PredicateOp):
    def __init__(self, s_exp):
        assert s_exp[0] == 'expr'
        pred_stmt = s_exp
        self.left_pred_op = ColOp.col_op_dispatcher(pred_stmt[1])
        if pred_stmt[2] == 'not':
            self.op = pred_stmt[2:4]
            self.right_pred_op = ColOp.col_op_dispatcher(pred_stmt[4])
        else:
            self.op = pred_stmt[2]
            self.right_pred_op = ColOp.col_op_dispatcher(pred_stmt[3])

    def rename_ops(self, mapping):
        self.left_pred_op.rename_ops(mapping)
        self.right_pred_op.rename_ops(mapping)

    def to_rkt(self, schema):
        return "(BINOP " + self.left_pred_op.to_rkt(schema) + " " + self.op + " " + self.right_pred_op.to_rkt(schema) + ")"
    
class LikePredicate(Predicate):
    def __init__(self, s_exp):
        assert s_exp[0] == 'expr'
        self.left_pred_op = ColOp.col_op_dispatcher(s_exp[1])
        if s_exp[2] == 'not':
            self.like = False
            self.pattern = s_exp[4][1][1]
        else:
            self.like = True
            self.pattern = s_exp[3][1][1]
        print("my pattern is", self.pattern)
    
    def rename_ops(self, mapping):
        self.left_pred_op.rename_ops(mapping)
        
    def to_rkt(self, schema):
        if self.like:
            return "(LIKEOP " + self.left_pred_op.to_rkt(schema) + " \"" + self.pattern + "\")"
        else:
            return "(NOT (LIKEOP " + self.left_pred_op.to_rkt(schema) + " \"" + self.pattern + "\"))"
            
class AndPred(PredicateOp):
    def __init__(self, s_exp):
        self.left_pred_op = PredicateOp.pred_op_dispatcher(s_exp[1])
        self.right_pred_op = PredicateOp.pred_op_dispatcher(s_exp[3])

    def rename_ops(self, mapping):
        self.left_pred_op.rename_ops(mapping)
        self.right_pred_op.rename_ops(mapping)
        
    def to_rkt(self, schema):
        return "(AND " + self.left_pred_op.to_rkt(schema) + " " + self.right_pred_op.to_rkt(schema) + ")"
        
class OrPred(PredicateOp):
    def __init__(self, s_exp):
        self.left_pred_op = PredicateOp.pred_op_dispatcher(s_exp[1])
        self.right_pred_op = PredicateOp.pred_op_dispatcher(s_exp[3])
    
    def rename_ops(self, mapping):
        self.left_pred_op.rename_ops(mapping)
        self.right_pred_op.rename_ops(mapping)
        
    def to_rkt(self, schema):
        return "(OR " + self.left_pred_op.to_rkt(schema) + " " + self.right_pred_op.to_rkt(schema) + ")"
        
def NotPred(PredicateOp):
    def __init__(self, s_exp):
        pred_stmt = s_exp[2][1][0]
        self.op = pred_stmt[2]
        self.pred_op = PredOp.pred_op_dispatcher(pred_stmt[1])
    
    def rename_ops(self, mapping):
        self.pred_op.rename_ops(mapping)
        
    def to_rkt(self, schema):
        return "(NOT " + self.pred_op.to_rkt(schema) + ")"
 

In [437]:
a = SelectStatement(select_s_exp)

['expr', ['table_name', ['any_name', 'cand']], '.', ['column_name', ['any_name', 'cand_name']]]
dispatching to col name
['expr', ['table_name', ['any_name', 'comm']], '.', ['column_name', ['any_name', 'cmte_nm']]]
dispatching to col name
['join_clause', ['table_or_subquery', ['table_name', ['any_name', 'cand']]], ['join_operator', 'join'], ['table_or_subquery', ['table_name', ['any_name', 'comm']]], 'join_constraint']
dispatching to join
['table_or_subquery', ['table_name', ['any_name', 'cand']]]
dispatching to tbl_ref
['table_or_subquery', ['table_name', ['any_name', 'comm']]]
dispatching to tbl_ref
['expr', ['expr', ['table_name', ['any_name', 'comm']], '.', ['column_name', ['any_name', 'cand_id']]], '=', ['expr', ['table_name', ['any_name', 'cand']], '.', ['column_name', ['any_name', 'cand_id']]]]
cleaned s_exp ['expr', ['expr', ['table_name', ['any_name', 'comm']], '.', ['column_name', ['any_name', 'cand_id']]], '=', ['expr', ['table_name', ['any_name', 'cand']], '.', ['column_name

In [438]:
schema2 = {"indiv_sample_nyc": {
        "cmte_id": "int", 
        "transaction_amt": "int",
        "name": "str"
}, "comm": {
        "cmte_id": "int", 
        "cmte_nm": "int",
        "cand_id": "int"
},
    "cand": {"cand_name": "str", "cand_id": "int"}
}


cand_schema = {"cand_name": "str", "cand_id": "int"}
comm_schema = {"cmte_nm": "int", "cand_id": "int"}
schema1 = {'cand': cand_schema, 'comm': comm_schema}

In [439]:
a = a.rename(schema2, [])[0]

returning [{'cand_name', 'cand.cand_name'}, {'cand_id', 'cand.cand_id'}]
returning [{'cmte_id', 'comm.cmte_id'}, {'comm.cmte_nm', 'cmte_nm'}, {'cand_id', 'comm.cand_id'}]
Inferring join, returning [{'cand_name', 'cand.cand_name'}, {'cand_id', 'cand.cand_id'}, {'cmte_id', 'comm.cmte_id'}, {'comm.cmte_nm', 'cmte_nm'}, {'comm.cand_id'}]
Inferring select, returning [{'cand_name'}, {'cmte_nm'}]
returning [{'cand_name', 'cand.cand_name'}, {'cand_id', 'cand.cand_id'}]
returning [{'cmte_id', 'comm.cmte_id'}, {'comm.cmte_nm', 'cmte_nm'}, {'cand_id', 'comm.cand_id'}]
returning [{'cand_name', 'cand.cand_name'}, {'cand_id', 'cand.cand_id'}]
returning [{'cmte_id', 'comm.cmte_id'}, {'comm.cmte_nm', 'cmte_nm'}, {'cand_id', 'comm.cand_id'}]
Inferring join, returning [{'cand_name', 'cand.cand_name'}, {'cand_id', 'cand.cand_id'}, {'cmte_id', 'comm.cmte_id'}, {'comm.cmte_nm', 'cmte_nm'}, {'comm.cand_id'}]
returning [{'cand_name', 'cand.cand_name'}, {'cand_id', 'cand.cand_id'}]
returning [{'cmte_id', 'com

In [440]:
result = a.to_rkt(schema2)
result

'(AS (SELECT (VALS "t.cand_name" "t.cmte_nm")\nFROM (AS (JOIN (AS (NAMED cand)\n["cand" (list "cand_name" "cand_id")]) (AS (NAMED comm)\n["comm" (list "cmte_id" "cmte_nm" "cand_id")]) )\n["t" (list "cand_name" "cand_id" "cmte_id" "cmte_nm" "cand_id_2")])\nWHERE (BINOP "t.cand_id_2" = "t.cand_id")  )\n["t_2" (list "cand_name" "cmte_nm")])'

In [441]:
print(result)

(AS (SELECT (VALS "t.cand_name" "t.cmte_nm")
FROM (AS (JOIN (AS (NAMED cand)
["cand" (list "cand_name" "cand_id")]) (AS (NAMED comm)
["comm" (list "cmte_id" "cmte_nm" "cand_id")]) )
["t" (list "cand_name" "cand_id" "cmte_id" "cmte_nm" "cand_id_2")])
WHERE (BINOP "t.cand_id_2" = "t.cand_id")  )
["t_2" (list "cand_name" "cmte_nm")])


In [442]:
a.table.subquery_tree.table.infer_out_schema(schema2)

returning [{'cand_name', 'cand.cand_name'}, {'cand_id', 'cand.cand_id'}]
returning [{'cmte_id', 'comm.cmte_id'}, {'comm.cmte_nm', 'cmte_nm'}, {'cand_id', 'comm.cand_id'}]
Inferring join, returning [{'cand_name', 'cand.cand_name'}, {'cand_id', 'cand.cand_id'}, {'cmte_id', 'comm.cmte_id'}, {'comm.cmte_nm', 'cmte_nm'}, {'comm.cand_id'}]


[{'cand.cand_name', 'cand_name'},
 {'cand.cand_id', 'cand_id'},
 {'cmte_id', 'comm.cmte_id'},
 {'cmte_nm', 'comm.cmte_nm'},
 {'comm.cand_id'}]

for every rename call, should distinguish
1. get children old col names
2. get my own old names
3. rename child. This returns us a child mapping of old to new
4. rename_ops with child mapping (it always uses tehc hild mapping) (RENAME OPS ONLY TAKES MAPPING + self)
5. invent my new names
6. create my mappings from my old to new names
7. rename having + order colops using parent mapping
8. wrap and return renamed table + mappings


In [326]:
(AS (SELECT (VALS "c_0" "c_1")
    FROM (AS (JOIN (AS (NAMED cand)
                        ["cand_0" (list "cand_name" "cand_id")]) 
                    (AS (NAMED comm)
                        ["comm_0" (list "cmte_id" "cmte_nm" "cand_id")]) )
        ["t_0" (list "cand_name_0" "cand_id_0" "cmte_id_0" "cmte_nm_0" "cand_id_1")])
    WHERE (BINOP "t_0.cand_id_1" = "t_0.cand_id_0")  )
    ["t_1" (list "c_0" "c_1")])

SyntaxError: invalid syntax (<ipython-input-326-def6d4ca4b53>, line 1)

In [290]:
(AS (JOIN (AS (NAMED cand)
            ["cand_0" (list "cand_name" "cand_id")]) 
         (AS (NAMED comm)
            ["comm_0" (list "cmte_id" "cmte_nm" "cand_id")]) )
["t_0" (list "cand_name_0" "cand_name_0" "cand_id_0" "cand_id_0" "cmte_id_0" "cmte_id_0" "cmte_nm_0" "cmte_nm_0" "cand_id_1" "cand_id_1")])

SyntaxError: invalid syntax (<ipython-input-290-eae5f93fb052>, line 1)

In [None]:
(AS (SELECT (VALS "tbl_1.cand_name" "tbl_1.cmte_nm")
    FROM (AS (JOIN (AS (NAMED cand)
                    ["cand" (list "cand.cand_name" "cand.cand_id")]) 
                  (AS (NAMED comm)
                    ["comm" (list "comm.cmte_id" "comm.cmte_nm" "comm.cand_id")]) )
            ["tbl_0" (list "cand.cand_name" "cand.cand_id" "comm.cmte_id" "comm.cmte_nm" "comm.cand_id2")])
    WHERE (BINOP "comm.cand_id" = "tbl_0.cand_id")  )
["tbl_1" (list "tbl_0.cand_name" "tbl_0.cand_id" "tbl_0.cmte_id" "tbl_0.cmte_nm" "temp_col_name_4")])

In [None]:
(SELECT (VALS "tbl_1.cmte_id" "tbl_1.total_amount" "tbl_1.num_donations" "tbl_1.cmte_nm")
    FROM (AS (SELECT (VALS "t3.cmte_id" "t3.transaction_amt" "t3.name" "t3.cmte_id2" "t3.cmte_nm")
        FROM (AS (JOIN (AS (NAMED indiv_sample_nyc)
            ["indiv_sample_nyc" (list "indiv_sample_nyc.cmte_id" "indiv_sample_nyc.transaction_amt" "indiv_sample_nyc.name")]) (AS (NAMED comm)
            ["comm" (list "comm.cmte_id" "comm.cmte_nm")]) )
        ["tbl_0" (list "indiv_sample_nyc.cmte_id" "indiv_sample_nyc.transaction_amt" "indiv_sample_nyc.name" "comm.cmte_id2" "comm.cmte_nm")])
        WHERE (BINOP "tbl_0.cmte_id" == "comm.cmte_id")  )
    ["t3" (list "tbl_0.cmte_id" "tbl_0.transaction_amt" "tbl_0.name" "temp_col_name_3" "tbl_0.cmte_nm")])
WHERE (AND (AND (LIKEOP "name" "'%trump%'") (LIKEOP "name" "'%donald%'")) (NOT (LIKEOP "name" "'%inc%'"))) 
GROUP-BY (list ""cmte_id"" ) )

In [None]:
# TODO: renamed columns featured in having or order by -> wrap in select
# SELECT *
# (SELECT 
#     cmte_id,
#     SUM(transaction_amt) as total_amount,
#     COUNT(*) as num_donations,
#     cmte_nm
# FROM 
#         (SELECT *
#         FROM indiv_sample_nyc, comm
#         WHERE indiv_sample_nyc.cmte_id == comm.cmte_id) as t3
# WHERE  (name LIKE '%TRUMP%') AND  (name LIKE '%DONALD%') AND  (name NOT LIKE '%INC%')
# GROUP BY cmte_id) AS [a, b, c]
# WHERE c > 10
# ORDER BY c DESC

In [None]:
b = SelectStatement(test3[1][1][1])
b = b.rename(schema1, 0)[0]
print(b.to_rkt(schema1))

In [None]:
b.table.subquery_tree.table

1. Finish to do items within the classes, ensure it works
2. to rkt functions
3. testing
4. python testing on limit/order, LIKEOP (then remove for rkt tree)
5. error logging feedback
6. dockerization

nice to have
- fstrings for to_rkt

In [None]:
SELECT 
    cmte_id,
    SUM(transaction_amt) as total_amount,
    COUNT(*) as num_donations,
    cmte_nm
FROM 
        SELECT *
        FROM indiv_sample_nyc, comm
        WHERE indiv_sample_nyc.cmte_id == comm.cmte_id
WHERE  (name LIKE '%TRUMP%') AND  (name LIKE '%DONALD%') AND  (name NOT LIKE '%INC%') GROUP BY cmte_id
ORDER BY total_amount DESC


