From 77738743403e8e16da385018211f0309cb380404 Mon Sep 17 00:00:00 2001 From: Tao Yu Date: Tue, 25 Jun 2019 18:00:09 -0700 Subject: [PATCH] Add files via upload --- evaluation.py | 938 +++++++++++++++++++++++++++++++++++++++++++++++++ process_sql.py | 566 +++++++++++++++++++++++++++++ 2 files changed, 1504 insertions(+) create mode 100644 evaluation.py create mode 100644 process_sql.py diff --git a/evaluation.py b/evaluation.py new file mode 100644 index 0000000..066120e --- /dev/null +++ b/evaluation.py @@ -0,0 +1,938 @@ +################################ +# val: number(float)/string(str)/sql(dict) +# col_unit: (agg_id, col_id, isDistinct(bool)) +# val_unit: (unit_op, col_unit1, col_unit2) +# table_unit: (table_type, col_unit/sql) +# cond_unit: (not_op, op_id, val_unit, val1, val2) +# condition: [cond_unit1, 'and'/'or', cond_unit2, ...] +# sql { +# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) +# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} +# 'where': condition +# 'groupBy': [col_unit1, col_unit2, ...] +# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...]) +# 'having': condition +# 'limit': None/limit value +# 'intersect': None/sql +# 'except': None/sql +# 'union': None/sql +# } +################################ + +import os, sys +import json +import sqlite3 +import traceback +import argparse + +from process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql + +# Flag to disable value evaluation +DISABLE_VALUE = True +# Flag to disable distinct in select evaluation +DISABLE_DISTINCT = True + + +CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') +JOIN_KEYWORDS = ('join', 'on', 'as') + +WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') +UNIT_OPS = ('none', '-', '+', "*", '/') +AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') +TABLE_TYPE = { + 'sql': "sql", + 'table_unit': "table_unit", +} + +COND_OPS = ('and', 'or') +SQL_OPS = ('intersect', 'union', 'except') +ORDER_OPS = ('desc', 'asc') + + +HARDNESS = { + "component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'), + "component2": ('except', 'union', 'intersect') +} + + +def condition_has_or(conds): + return 'or' in conds[1::2] + + +def condition_has_like(conds): + return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]] + + +def condition_has_sql(conds): + for cond_unit in conds[::2]: + val1, val2 = cond_unit[3], cond_unit[4] + if val1 is not None and type(val1) is dict: + return True + if val2 is not None and type(val2) is dict: + return True + return False + + +def val_has_op(val_unit): + return val_unit[0] != UNIT_OPS.index('none') + + +def has_agg(unit): + return unit[0] != AGG_OPS.index('none') + + +def accuracy(count, total): + if count == total: + return 1 + return 0 + + +def recall(count, total): + if count == total: + return 1 + return 0 + + +def F1(acc, rec): + if (acc + rec) == 0: + return 0 + return (2. 
* acc * rec) / (acc + rec) + + +def get_scores(count, pred_total, label_total): + if pred_total != label_total: + return 0,0,0 + elif count == pred_total: + return 1,1,1 + return 0,0,0 + + +def eval_sel(pred, label): + pred_sel = pred['select'][1] + label_sel = label['select'][1] + label_wo_agg = [unit[1] for unit in label_sel] + pred_total = len(pred_sel) + label_total = len(label_sel) + cnt = 0 + cnt_wo_agg = 0 + + for unit in pred_sel: + if unit in label_sel: + cnt += 1 + label_sel.remove(unit) + if unit[1] in label_wo_agg: + cnt_wo_agg += 1 + label_wo_agg.remove(unit[1]) + + return label_total, pred_total, cnt, cnt_wo_agg + + +def eval_where(pred, label): + pred_conds = [unit for unit in pred['where'][::2]] + label_conds = [unit for unit in label['where'][::2]] + label_wo_agg = [unit[2] for unit in label_conds] + pred_total = len(pred_conds) + label_total = len(label_conds) + cnt = 0 + cnt_wo_agg = 0 + + for unit in pred_conds: + if unit in label_conds: + cnt += 1 + label_conds.remove(unit) + if unit[2] in label_wo_agg: + cnt_wo_agg += 1 + label_wo_agg.remove(unit[2]) + + return label_total, pred_total, cnt, cnt_wo_agg + + +def eval_group(pred, label): + pred_cols = [unit[1] for unit in pred['groupBy']] + label_cols = [unit[1] for unit in label['groupBy']] + pred_total = len(pred_cols) + label_total = len(label_cols) + cnt = 0 + pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols] + label_cols = [label.split(".")[1] if "." in label else label for label in label_cols] + for col in pred_cols: + if col in label_cols: + cnt += 1 + label_cols.remove(col) + return label_total, pred_total, cnt + + +def eval_having(pred, label): + pred_total = label_total = cnt = 0 + if len(pred['groupBy']) > 0: + pred_total = 1 + if len(label['groupBy']) > 0: + label_total = 1 + + pred_cols = [unit[1] for unit in pred['groupBy']] + label_cols = [unit[1] for unit in label['groupBy']] + if pred_total == label_total == 1 \ + and pred_cols == label_cols \ + and pred['having'] == label['having']: + cnt = 1 + + return label_total, pred_total, cnt + + +def eval_order(pred, label): + pred_total = label_total = cnt = 0 + if len(pred['orderBy']) > 0: + pred_total = 1 + if len(label['orderBy']) > 0: + label_total = 1 + if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \ + ((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)): + cnt = 1 + return label_total, pred_total, cnt + + +def eval_and_or(pred, label): + pred_ao = pred['where'][1::2] + label_ao = label['where'][1::2] + pred_ao = set(pred_ao) + label_ao = set(label_ao) + + if pred_ao == label_ao: + return 1,1,1 + return len(pred_ao),len(label_ao),0 + + +def get_nestedSQL(sql): + nested = [] + for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]: + if type(cond_unit[3]) is dict: + nested.append(cond_unit[3]) + if type(cond_unit[4]) is dict: + nested.append(cond_unit[4]) + if sql['intersect'] is not None: + nested.append(sql['intersect']) + if sql['except'] is not None: + nested.append(sql['except']) + if sql['union'] is not None: + nested.append(sql['union']) + return nested + + +def eval_nested(pred, label): + label_total = 0 + pred_total = 0 + cnt = 0 + if pred is not None: + pred_total += 1 + if label is not None: + label_total += 1 + if pred is not None and label is not None: + cnt += Evaluator().eval_exact_match(pred, label) + return label_total, pred_total, cnt + + +def eval_IUEN(pred, label): + lt1, pt1, cnt1 = 
eval_nested(pred['intersect'], label['intersect']) + lt2, pt2, cnt2 = eval_nested(pred['except'], label['except']) + lt3, pt3, cnt3 = eval_nested(pred['union'], label['union']) + label_total = lt1 + lt2 + lt3 + pred_total = pt1 + pt2 + pt3 + cnt = cnt1 + cnt2 + cnt3 + return label_total, pred_total, cnt + + +def get_keywords(sql): + res = set() + if len(sql['where']) > 0: + res.add('where') + if len(sql['groupBy']) > 0: + res.add('group') + if len(sql['having']) > 0: + res.add('having') + if len(sql['orderBy']) > 0: + res.add(sql['orderBy'][0]) + res.add('order') + if sql['limit'] is not None: + res.add('limit') + if sql['except'] is not None: + res.add('except') + if sql['union'] is not None: + res.add('union') + if sql['intersect'] is not None: + res.add('intersect') + + # or keyword + ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2] + if len([token for token in ao if token == 'or']) > 0: + res.add('or') + + cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2] + # not keyword + if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0: + res.add('not') + + # in keyword + if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0: + res.add('in') + + # like keyword + if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0: + res.add('like') + + return res + + +def eval_keywords(pred, label): + pred_keywords = get_keywords(pred) + label_keywords = get_keywords(label) + pred_total = len(pred_keywords) + label_total = len(label_keywords) + cnt = 0 + + for k in pred_keywords: + if k in label_keywords: + cnt += 1 + return label_total, pred_total, cnt + + +def count_agg(units): + return len([unit for unit in units if has_agg(unit)]) + + +def count_component1(sql): + count = 0 + if len(sql['where']) > 0: + count += 1 + if len(sql['groupBy']) > 0: + count += 1 + if len(sql['orderBy']) > 0: + count += 1 + if sql['limit'] is not None: + count += 1 + if len(sql['from']['table_units']) > 0: # JOIN + count += len(sql['from']['table_units']) - 1 + + ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2] + count += len([token for token in ao if token == 'or']) + cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2] + count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) + + return count + + +def count_component2(sql): + nested = get_nestedSQL(sql) + return len(nested) + + +def count_others(sql): + count = 0 + # number of aggregation + agg_count = count_agg(sql['select'][1]) + agg_count += count_agg(sql['where'][::2]) + agg_count += count_agg(sql['groupBy']) + if len(sql['orderBy']) > 0: + agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] + + [unit[2] for unit in sql['orderBy'][1] if unit[2]]) + agg_count += count_agg(sql['having']) + if agg_count > 1: + count += 1 + + # number of select columns + if len(sql['select'][1]) > 1: + count += 1 + + # number of where conditions + if len(sql['where']) > 1: + count += 1 + + # number of group by clauses + if len(sql['groupBy']) > 1: + count += 1 + + return count + + +class Evaluator: + """A simple evaluator""" + def __init__(self): + self.partial_scores = None + + def eval_hardness(self, sql): + count_comp1_ = count_component1(sql) + count_comp2_ = count_component2(sql) + count_others_ = count_others(sql) + + if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0: + return "easy" + elif (count_others_ <= 2 
and count_comp1_ <= 1 and count_comp2_ == 0) or \ + (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0): + return "medium" + elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \ + (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \ + (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1): + return "hard" + else: + return "extra" + + def eval_exact_match(self, pred, label): + partial_scores = self.eval_partial_match(pred, label) + self.partial_scores = partial_scores + + for key, score in partial_scores.items(): + if score['f1'] != 1: + return 0 + + if len(label['from']['table_units']) > 1: + label_tables = sorted(label['from']['table_units']) + pred_tables = sorted(pred['from']['table_units']) + return label_tables == pred_tables + return 1 + + def eval_partial_match(self, pred, label): + res = {} + + label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total) + res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total) + res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + label_total, pred_total, cnt = eval_group(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + label_total, pred_total, cnt = eval_having(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + label_total, pred_total, cnt = eval_order(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + label_total, pred_total, cnt = eval_and_or(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + label_total, pred_total, cnt = eval_IUEN(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + label_total, pred_total, cnt = eval_keywords(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + return res + + +def isValidSQL(sql, db): + conn = sqlite3.connect(db) + cursor = conn.cursor() + try: + cursor.execute(sql) + except: + return False + return True + + +def print_scores(scores, etype): + turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 'turn >4'] + levels = ['easy', 'medium', 'hard', 'extra', 'all', "joint_all"] + partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)', + 'group', 'order', 'and/or', 'IUEN', 'keywords'] + + print "{:20} {:20} {:20} {:20} {:20} {:20} 
{:20}".format("", *levels) + counts = [scores[level]['count'] for level in levels] + print "{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts) + + if etype in ["all", "exec"]: + print '===================== EXECUTION ACCURACY =====================' + this_scores = [scores[level]['exec'] for level in levels] + print "{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores) + + if etype in ["all", "match"]: + print '\n====================== EXACT MATCHING ACCURACY =====================' + exact_scores = [scores[level]['exact'] for level in levels] + print "{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores) + print '\n---------------------PARTIAL MATCHING ACCURACY----------------------' + for type_ in partial_types: + this_scores = [scores[level]['partial'][type_]['acc'] for level in levels] + print "{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores) + + print '---------------------- PARTIAL MATCHING RECALL ----------------------' + for type_ in partial_types: + this_scores = [scores[level]['partial'][type_]['rec'] for level in levels] + print "{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores) + + print '---------------------- PARTIAL MATCHING F1 --------------------------' + for type_ in partial_types: + this_scores = [scores[level]['partial'][type_]['f1'] for level in levels] + print "{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores) + + + print "\n\n{:20} {:20} {:20} {:20} {:20} {:20}".format("", *turns) + counts = [scores[turn]['count'] for turn in turns] + print "{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts) + + if etype in ["all", "exec"]: + print '===================== TRUN XECUTION ACCURACY =====================' + this_scores = [scores[turn]['exec'] for turn in turns] + print "{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores) + + if etype in ["all", "match"]: + print '\n====================== TRUN EXACT MATCHING ACCURACY =====================' + exact_scores = [scores[turn]['exact'] for turn in turns] + print "{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores) + +def evaluate(gold, predict, db_dir, etype, kmaps): + with open(gold) as f: + glist = [] + gseq_one = [] + for l in f.readlines(): + if len(l.strip()) == 0: + glist.append(gseq_one) + gseq_one = [] + else: + lstrip = l.strip().split('\t') + gseq_one.append(lstrip) + #glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0] + + with open(predict) as f: + plist = [] + pseq_one = [] + for l in f.readlines(): + if len(l.strip()) == 0: + plist.append(pseq_one) + pseq_one = [] + else: + pseq_one.append(l.strip().split('\t')) + #plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0] + # plist = [[("select product_type_code from products group by product_type_code order by count ( * ) desc limit value", "orchestra")]] + # glist = [[("SELECT product_type_code FROM Products GROUP BY product_type_code ORDER BY count(*) DESC LIMIT 1", "customers_and_orders")]] + evaluator = Evaluator() + + turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 'turn >4'] + levels = ['easy', 'medium', 'hard', 'extra', 'all', 'joint_all'] + partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)', + 'group', 'order', 'and/or', 
'IUEN', 'keywords'] + entries = [] + scores = {} + + for turn in turns: + scores[turn] = {'count': 0, 'exact': 0.} + scores[turn]['exec'] = 0 + + for level in levels: + scores[level] = {'count': 0, 'partial': {}, 'exact': 0.} + scores[level]['exec'] = 0 + for type_ in partial_types: + scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0} + + eval_err_num = 0 + for p, g in zip(plist, glist): + scores['joint_all']['count'] += 1 + turn_scores = {"exec": [], "exact": []} + for idx, pg in enumerate(zip(p, g)): + p, g = pg + p_str = p[0] + p_str = p_str.replace("value", "1") + g_str, db = g + db_name = db + db = os.path.join(db_dir, db, db + ".sqlite") + schema = Schema(get_schema(db)) + g_sql = get_sql(schema, g_str) + hardness = evaluator.eval_hardness(g_sql) + if idx > 3: + idx = ">4" + else: + idx += 1 + turn_id = "turn " + str(idx) + scores[turn_id]['count'] += 1 + scores[hardness]['count'] += 1 + scores['all']['count'] += 1 + + try: + p_sql = get_sql(schema, p_str) + except: + # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql + p_sql = { + "except": None, + "from": { + "conds": [], + "table_units": [] + }, + "groupBy": [], + "having": [], + "intersect": None, + "limit": None, + "orderBy": [], + "select": [ + False, + [] + ], + "union": None, + "where": [] + } + eval_err_num += 1 + print("eval_err_num:{}".format(eval_err_num)) + + # rebuild sql for value evaluation + kmap = kmaps[db_name] + g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema) + g_sql = rebuild_sql_val(g_sql) + g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap) + p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema) + p_sql = rebuild_sql_val(p_sql) + p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap) + + if etype in ["all", "exec"]: + exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql) + if exec_score: + scores[hardness]['exec'] += 1 + scores[turn_id]['exec'] += 1 + turn_scores['exec'].append(1) + else: + turn_scores['exec'].append(0) + + if etype in ["all", "match"]: + exact_score = evaluator.eval_exact_match(p_sql, g_sql) + partial_scores = evaluator.partial_scores + if exact_score == 0: + turn_scores['exact'].append(0) + print("{} pred: {}".format(hardness,p_str)) + print("{} gold: {}".format(hardness,g_str)) + print("") + else: + turn_scores['exact'].append(1) + scores[turn_id]['exact'] += exact_score + scores[hardness]['exact'] += exact_score + scores['all']['exact'] += exact_score + for type_ in partial_types: + if partial_scores[type_]['pred_total'] > 0: + scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc'] + scores[hardness]['partial'][type_]['acc_count'] += 1 + if partial_scores[type_]['label_total'] > 0: + scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec'] + scores[hardness]['partial'][type_]['rec_count'] += 1 + scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1'] + if partial_scores[type_]['pred_total'] > 0: + scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc'] + scores['all']['partial'][type_]['acc_count'] += 1 + if partial_scores[type_]['label_total'] > 0: + scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec'] + scores['all']['partial'][type_]['rec_count'] += 1 + scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1'] + + entries.append({ + 'predictSQL': p_str, + 'goldSQL': g_str, + 'hardness': hardness, + 'exact': exact_score, + 'partial': 
partial_scores
+            })
+
+        if all(v == 1 for v in turn_scores["exec"]):
+            scores['joint_all']['exec'] += 1
+
+        if all(v == 1 for v in turn_scores["exact"]):
+            scores['joint_all']['exact'] += 1
+
+    for turn in turns:
+        if scores[turn]['count'] == 0:
+            continue
+        if etype in ["all", "exec"]:
+            scores[turn]['exec'] /= scores[turn]['count']
+
+        if etype in ["all", "match"]:
+            scores[turn]['exact'] /= scores[turn]['count']
+
+    for level in levels:
+        if scores[level]['count'] == 0:
+            continue
+        if etype in ["all", "exec"]:
+            scores[level]['exec'] /= scores[level]['count']
+
+        if etype in ["all", "match"]:
+            scores[level]['exact'] /= scores[level]['count']
+            for type_ in partial_types:
+                if scores[level]['partial'][type_]['acc_count'] == 0:
+                    scores[level]['partial'][type_]['acc'] = 0
+                else:
+                    scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
+                                                             scores[level]['partial'][type_]['acc_count'] * 1.0
+                if scores[level]['partial'][type_]['rec_count'] == 0:
+                    scores[level]['partial'][type_]['rec'] = 0
+                else:
+                    scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
+                                                             scores[level]['partial'][type_]['rec_count'] * 1.0
+                if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
+                    scores[level]['partial'][type_]['f1'] = 1
+                else:
+                    scores[level]['partial'][type_]['f1'] = \
+                        2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
+                        scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
+
+    print_scores(scores, etype)
+
+
+def eval_exec_match(db, p_str, g_str, pred, gold):
+    """
+    Return 1 if the values between prediction and gold match
+    at the corresponding indices. Multiple col_units (pairs) are currently not supported.
+ """ + conn = sqlite3.connect(db) + cursor = conn.cursor() + try: + cursor.execute(p_str) + p_res = cursor.fetchall() + except: + return False + + cursor.execute(g_str) + q_res = cursor.fetchall() + + def res_map(res, val_units): + rmap = {} + for idx, val_unit in enumerate(val_units): + key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2])) + rmap[key] = [r[idx] for r in res] + return rmap + + p_val_units = [unit[1] for unit in pred['select'][1]] + q_val_units = [unit[1] for unit in gold['select'][1]] + return res_map(p_res, p_val_units) == res_map(q_res, q_val_units) + + +# Rebuild SQL functions for value evaluation +def rebuild_cond_unit_val(cond_unit): + if cond_unit is None or not DISABLE_VALUE: + return cond_unit + + not_op, op_id, val_unit, val1, val2 = cond_unit + if type(val1) is not dict: + val1 = None + else: + val1 = rebuild_sql_val(val1) + if type(val2) is not dict: + val2 = None + else: + val2 = rebuild_sql_val(val2) + return not_op, op_id, val_unit, val1, val2 + + +def rebuild_condition_val(condition): + if condition is None or not DISABLE_VALUE: + return condition + + res = [] + for idx, it in enumerate(condition): + if idx % 2 == 0: + res.append(rebuild_cond_unit_val(it)) + else: + res.append(it) + return res + + +def rebuild_sql_val(sql): + if sql is None or not DISABLE_VALUE: + return sql + + sql['from']['conds'] = rebuild_condition_val(sql['from']['conds']) + sql['having'] = rebuild_condition_val(sql['having']) + sql['where'] = rebuild_condition_val(sql['where']) + sql['intersect'] = rebuild_sql_val(sql['intersect']) + sql['except'] = rebuild_sql_val(sql['except']) + sql['union'] = rebuild_sql_val(sql['union']) + + return sql + + +# Rebuild SQL functions for foreign key evaluation +def build_valid_col_units(table_units, schema): + col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']] + prefixs = [col_id[:-2] for col_id in col_ids] + valid_col_units= [] + for value in schema.idMap.values(): + if '.' 
in value and value[:value.index('.')] in prefixs: + valid_col_units.append(value) + return valid_col_units + + +def rebuild_col_unit_col(valid_col_units, col_unit, kmap): + if col_unit is None: + return col_unit + + agg_id, col_id, distinct = col_unit + if col_id in kmap and col_id in valid_col_units: + col_id = kmap[col_id] + if DISABLE_DISTINCT: + distinct = None + return agg_id, col_id, distinct + + +def rebuild_val_unit_col(valid_col_units, val_unit, kmap): + if val_unit is None: + return val_unit + + unit_op, col_unit1, col_unit2 = val_unit + col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap) + col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap) + return unit_op, col_unit1, col_unit2 + + +def rebuild_table_unit_col(valid_col_units, table_unit, kmap): + if table_unit is None: + return table_unit + + table_type, col_unit_or_sql = table_unit + if isinstance(col_unit_or_sql, tuple): + col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap) + return table_type, col_unit_or_sql + + +def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap): + if cond_unit is None: + return cond_unit + + not_op, op_id, val_unit, val1, val2 = cond_unit + val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap) + return not_op, op_id, val_unit, val1, val2 + + +def rebuild_condition_col(valid_col_units, condition, kmap): + for idx in range(len(condition)): + if idx % 2 == 0: + condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap) + return condition + + +def rebuild_select_col(valid_col_units, sel, kmap): + if sel is None: + return sel + distinct, _list = sel + new_list = [] + for it in _list: + agg_id, val_unit = it + new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap))) + if DISABLE_DISTINCT: + distinct = None + return distinct, new_list + + +def rebuild_from_col(valid_col_units, from_, kmap): + if from_ is None: + return from_ + + from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']] + from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap) + return from_ + + +def rebuild_group_by_col(valid_col_units, group_by, kmap): + if group_by is None: + return group_by + + return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by] + + +def rebuild_order_by_col(valid_col_units, order_by, kmap): + if order_by is None or len(order_by) == 0: + return order_by + + direction, val_units = order_by + new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units] + return direction, new_val_units + + +def rebuild_sql_col(valid_col_units, sql, kmap): + if sql is None: + return sql + + sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap) + sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap) + sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap) + sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap) + sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap) + sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap) + sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap) + sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap) + sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap) + + return sql + + +def build_foreign_key_map(entry): + cols_orig = entry["column_names_original"] + tables_orig = 
entry["table_names_original"] + + # rebuild cols corresponding to idmap in Schema + cols = [] + for col_orig in cols_orig: + if col_orig[0] >= 0: + t = tables_orig[col_orig[0]] + c = col_orig[1] + cols.append("__" + t.lower() + "." + c.lower() + "__") + else: + cols.append("__all__") + + def keyset_in_list(k1, k2, k_list): + for k_set in k_list: + if k1 in k_set or k2 in k_set: + return k_set + new_k_set = set() + k_list.append(new_k_set) + return new_k_set + + foreign_key_list = [] + foreign_keys = entry["foreign_keys"] + for fkey in foreign_keys: + key1, key2 = fkey + key_set = keyset_in_list(key1, key2, foreign_key_list) + key_set.add(key1) + key_set.add(key2) + + foreign_key_map = {} + for key_set in foreign_key_list: + sorted_list = sorted(list(key_set)) + midx = sorted_list[0] + for idx in sorted_list: + foreign_key_map[cols[idx]] = cols[midx] + + return foreign_key_map + + +def build_foreign_key_map_from_json(table): + with open(table) as f: + data = json.load(f) + tables = {} + for entry in data: + tables[entry['db_id']] = build_foreign_key_map(entry) + return tables + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--gold', dest='gold', type=str) + parser.add_argument('--pred', dest='pred', type=str) + parser.add_argument('--db', dest='db', type=str) + parser.add_argument('--table', dest='table', type=str) + parser.add_argument('--etype', dest='etype', type=str) + args = parser.parse_args() + + gold = args.gold + pred = args.pred + db_dir = args.db + table = args.table + etype = args.etype + + assert etype in ["all", "exec", "match"], "Unknown evaluation method" + + kmaps = build_foreign_key_map_from_json(table) + + evaluate(gold, pred, db_dir, etype, kmaps) diff --git a/process_sql.py b/process_sql.py new file mode 100644 index 0000000..ee25b9c --- /dev/null +++ b/process_sql.py @@ -0,0 +1,566 @@ +################################ +# Assumptions: +# 1. sql is correct +# 2. only table name has alias +# 3. only one intersect/union/except +# +# val: number(float)/string(str)/sql(dict) +# col_unit: (agg_id, col_id, isDistinct(bool)) +# val_unit: (unit_op, col_unit1, col_unit2) +# table_unit: (table_type, col_unit/sql) +# cond_unit: (not_op, op_id, val_unit, val1, val2) +# condition: [cond_unit1, 'and'/'or', cond_unit2, ...] +# sql { +# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) +# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} +# 'where': condition +# 'groupBy': [col_unit1, col_unit2, ...] 
+# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
+# 'having': condition
+# 'limit': None/limit value
+# 'intersect': None/sql
+# 'except': None/sql
+# 'union': None/sql
+# }
+################################
+
+import json
+import sqlite3
+from nltk import word_tokenize
+
+CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
+JOIN_KEYWORDS = ('join', 'on', 'as')
+
+WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
+UNIT_OPS = ('none', '-', '+', "*", '/')
+AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
+TABLE_TYPE = {
+    'sql': "sql",
+    'table_unit': "table_unit",
+}
+
+COND_OPS = ('and', 'or')
+SQL_OPS = ('intersect', 'union', 'except')
+ORDER_OPS = ('desc', 'asc')
+
+
+
+class Schema:
+    """
+    Simple schema which maps table&column to a unique identifier
+    """
+    def __init__(self, schema):
+        self._schema = schema
+        self._idMap = self._map(self._schema)
+
+    @property
+    def schema(self):
+        return self._schema
+
+    @property
+    def idMap(self):
+        return self._idMap
+
+    def _map(self, schema):
+        idMap = {'*': "__all__"}
+        id = 1
+        for key, vals in schema.iteritems():
+            for val in vals:
+                idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__"
+                id += 1
+
+        for key in schema:
+            idMap[key.lower()] = "__" + key.lower() + "__"
+            id += 1
+
+        return idMap
+
+
+def get_schema(db):
+    """
+    Get database's schema, which is a dict with table name as key
+    and list of column names as value
+    :param db: database path
+    :return: schema dict
+    """
+
+    schema = {}
+    conn = sqlite3.connect(db)
+    cursor = conn.cursor()
+
+    # fetch table names
+    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
+    tables = [str(table[0].lower()) for table in cursor.fetchall()]
+
+    # fetch table info
+    for table in tables:
+        cursor.execute("PRAGMA table_info({})".format(table))
+        schema[table] = [str(col[1].lower()) for col in cursor.fetchall()]
+
+    return schema
+
+
+def get_schema_from_json(fpath):
+    with open(fpath) as f:
+        data = json.load(f)
+
+    schema = {}
+    for entry in data:
+        table = str(entry['table'].lower())
+        cols = [str(col['column_name'].lower()) for col in entry['col_data']]
+        schema[table] = cols
+
+    return schema
+
+
+def tokenize(string):
+    string = str(string)
+    string = string.replace("\'", "\"")  # ensure all string values are wrapped in double quotes
+ quote_idxs = [idx for idx, char in enumerate(string) if char == '"'] + assert len(quote_idxs) % 2 == 0, "Unexpected quote" + + # keep string value as token + vals = {} + for i in range(len(quote_idxs)-1, -1, -2): + qidx1 = quote_idxs[i-1] + qidx2 = quote_idxs[i] + val = string[qidx1: qidx2+1] + key = "__val_{}_{}__".format(qidx1, qidx2) + string = string[:qidx1] + key + string[qidx2+1:] + vals[key] = val + + toks = [word.lower() for word in word_tokenize(string)] + # replace with string value token + for i in range(len(toks)): + if toks[i] in vals: + toks[i] = vals[toks[i]] + + # find if there exists !=, >=, <= + eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="] + eq_idxs.reverse() + prefix = ('!', '>', '<') + for eq_idx in eq_idxs: + pre_tok = toks[eq_idx-1] + if pre_tok in prefix: + toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ] + + return toks + + +def scan_alias(toks): + """Scan the index of 'as' and build the map for all alias""" + as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as'] + alias = {} + for idx in as_idxs: + alias[toks[idx+1]] = toks[idx-1] + return alias + + +def get_tables_with_alias(schema, toks): + tables = scan_alias(toks) + for key in schema: + assert key not in tables, "Alias {} has the same name in table".format(key) + tables[key] = key + return tables + + +def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None): + """ + :returns next idx, column id + """ + tok = toks[start_idx] + if tok == "*": + return start_idx + 1, schema.idMap[tok] + + if '.' in tok: # if token is a composite + alias, col = tok.split('.') + key = tables_with_alias[alias] + "." + col + return start_idx+1, schema.idMap[key] + + assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty" + + for alias in default_tables: + table = tables_with_alias[alias] + if tok in schema.schema[table]: + key = table + "." 
+ tok + return start_idx+1, schema.idMap[key] + + assert False, "Error col: {}".format(tok) + + +def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): + """ + :returns next idx, (agg_op id, col_id) + """ + idx = start_idx + len_ = len(toks) + isBlock = False + isDistinct = False + if toks[idx] == '(': + isBlock = True + idx += 1 + + if toks[idx] in AGG_OPS: + agg_id = AGG_OPS.index(toks[idx]) + idx += 1 + assert idx < len_ and toks[idx] == '(' + idx += 1 + if toks[idx] == "distinct": + idx += 1 + isDistinct = True + idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) + assert idx < len_ and toks[idx] == ')' + idx += 1 + return idx, (agg_id, col_id, isDistinct) + + if toks[idx] == "distinct": + idx += 1 + isDistinct = True + agg_id = AGG_OPS.index("none") + idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) + + if isBlock: + assert toks[idx] == ')' + idx += 1 # skip ')' + + return idx, (agg_id, col_id, isDistinct) + + +def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): + idx = start_idx + len_ = len(toks) + isBlock = False + if toks[idx] == '(': + isBlock = True + idx += 1 + + col_unit1 = None + col_unit2 = None + unit_op = UNIT_OPS.index('none') + + idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) + if idx < len_ and toks[idx] in UNIT_OPS: + unit_op = UNIT_OPS.index(toks[idx]) + idx += 1 + idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) + + if isBlock: + assert toks[idx] == ')' + idx += 1 # skip ')' + + return idx, (unit_op, col_unit1, col_unit2) + + +def parse_table_unit(toks, start_idx, tables_with_alias, schema): + """ + :returns next idx, table id, table name + """ + idx = start_idx + len_ = len(toks) + key = tables_with_alias[toks[idx]] + + if idx + 1 < len_ and toks[idx+1] == "as": + idx += 3 + else: + idx += 1 + + return idx, schema.idMap[key], key + + +def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None): + idx = start_idx + len_ = len(toks) + + isBlock = False + if toks[idx] == '(': + isBlock = True + idx += 1 + + if toks[idx] == 'select': + idx, val = parse_sql(toks, idx, tables_with_alias, schema) + elif "\"" in toks[idx]: # token is a string value + val = toks[idx] + idx += 1 + else: + try: + val = float(toks[idx]) + idx += 1 + except: + end_idx = idx + while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\ + and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS: + end_idx += 1 + + idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables) + idx = end_idx + + if isBlock: + assert toks[idx] == ')' + idx += 1 + + return idx, val + + +def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None): + idx = start_idx + len_ = len(toks) + conds = [] + + while idx < len_: + idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) + not_op = False + if toks[idx] == 'not': + not_op = True + idx += 1 + + assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx]) + op_id = WHERE_OPS.index(toks[idx]) + idx += 1 + val1 = val2 = None + if op_id == WHERE_OPS.index('between'): # between..and... 
special case: dual values + idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) + assert toks[idx] == 'and' + idx += 1 + idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables) + else: # normal case: single value + idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) + val2 = None + + conds.append((not_op, op_id, val_unit, val1, val2)) + + if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS): + break + + if idx < len_ and toks[idx] in COND_OPS: + conds.append(toks[idx]) + idx += 1 # skip and/or + + return idx, conds + + +def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None): + idx = start_idx + len_ = len(toks) + + assert toks[idx] == 'select', "'select' not found" + idx += 1 + isDistinct = False + if idx < len_ and toks[idx] == 'distinct': + idx += 1 + isDistinct = True + val_units = [] + + while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS: + agg_id = AGG_OPS.index("none") + if toks[idx] in AGG_OPS: + agg_id = AGG_OPS.index(toks[idx]) + idx += 1 + idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) + val_units.append((agg_id, val_unit)) + if idx < len_ and toks[idx] == ',': + idx += 1 # skip ',' + + return idx, (isDistinct, val_units) + + +def parse_from(toks, start_idx, tables_with_alias, schema): + """ + Assume in the from clause, all table units are combined with join + """ + assert 'from' in toks[start_idx:], "'from' not found" + + len_ = len(toks) + idx = toks.index('from', start_idx) + 1 + default_tables = [] + table_units = [] + conds = [] + + while idx < len_: + isBlock = False + if toks[idx] == '(': + isBlock = True + idx += 1 + + if toks[idx] == 'select': + idx, sql = parse_sql(toks, idx, tables_with_alias, schema) + table_units.append((TABLE_TYPE['sql'], sql)) + else: + if idx < len_ and toks[idx] == 'join': + idx += 1 # skip join + idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema) + table_units.append((TABLE_TYPE['table_unit'],table_unit)) + default_tables.append(table_name) + if idx < len_ and toks[idx] == "on": + idx += 1 # skip on + idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) + if len(conds) > 0: + conds.append('and') + conds.extend(this_conds) + + if isBlock: + assert toks[idx] == ')' + idx += 1 + if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): + break + + return idx, table_units, conds, default_tables + + +def parse_where(toks, start_idx, tables_with_alias, schema, default_tables): + idx = start_idx + len_ = len(toks) + + if idx >= len_ or toks[idx] != 'where': + return idx, [] + + idx += 1 + idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) + return idx, conds + + +def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables): + idx = start_idx + len_ = len(toks) + col_units = [] + + if idx >= len_ or toks[idx] != 'group': + return idx, col_units + + idx += 1 + assert toks[idx] == 'by' + idx += 1 + + while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): + idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) + col_units.append(col_unit) + if idx < len_ and toks[idx] == ',': + idx += 1 # skip ',' + else: + break + + return idx, col_units + + +def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables): + idx = start_idx + len_ = len(toks) + val_units 
= []
+    order_type = 'asc' # default type is 'asc'
+
+    if idx >= len_ or toks[idx] != 'order':
+        return idx, val_units
+
+    idx += 1
+    assert toks[idx] == 'by'
+    idx += 1
+
+    while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
+        idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
+        val_units.append(val_unit)
+        if idx < len_ and toks[idx] in ORDER_OPS:
+            order_type = toks[idx]
+            idx += 1
+        if idx < len_ and toks[idx] == ',':
+            idx += 1  # skip ','
+        else:
+            break
+
+    return idx, (order_type, val_units)
+
+
+def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
+    idx = start_idx
+    len_ = len(toks)
+
+    if idx >= len_ or toks[idx] != 'having':
+        return idx, []
+
+    idx += 1
+    idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
+    return idx, conds
+
+
+def parse_limit(toks, start_idx):
+    idx = start_idx
+    len_ = len(toks)
+
+    if idx < len_ and toks[idx] == 'limit':
+        idx += 2
+        # make non-integer limit values (e.g. the placeholder "value") work: fall back to 1 as a fake limit number
+        if type(toks[idx-1]) != int:
+            return idx, 1
+
+        return idx, int(toks[idx-1])
+
+    return idx, None
+
+
+def parse_sql(toks, start_idx, tables_with_alias, schema):
+    isBlock = False # indicate whether this is a block of sql/sub-sql
+    len_ = len(toks)
+    idx = start_idx
+
+    sql = {}
+    if toks[idx] == '(':
+        isBlock = True
+        idx += 1
+
+    # parse from clause in order to get default tables
+    from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema)
+    sql['from'] = {'table_units': table_units, 'conds': conds}
+    # select clause
+    _, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables)
+    idx = from_end_idx
+    sql['select'] = select_col_units
+    # where clause
+    idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables)
+    sql['where'] = where_conds
+    # group by clause
+    idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables)
+    sql['groupBy'] = group_col_units
+    # having clause
+    idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables)
+    sql['having'] = having_conds
+    # order by clause
+    idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables)
+    sql['orderBy'] = order_col_units
+    # limit clause
+    idx, limit_val = parse_limit(toks, idx)
+    sql['limit'] = limit_val
+
+    idx = skip_semicolon(toks, idx)
+    if isBlock:
+        assert toks[idx] == ')'
+        idx += 1  # skip ')'
+    idx = skip_semicolon(toks, idx)
+
+    # intersect/union/except clause
+    for op in SQL_OPS:  # initialize IUE
+        sql[op] = None
+    if idx < len_ and toks[idx] in SQL_OPS:
+        sql_op = toks[idx]
+        idx += 1
+        idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema)
+        sql[sql_op] = IUE_sql
+    return idx, sql
+
+
+def load_data(fpath):
+    with open(fpath) as f:
+        data = json.load(f)
+    return data
+
+
+def get_sql(schema, query):
+    toks = tokenize(query)
+    tables_with_alias = get_tables_with_alias(schema.schema, toks)
+    _, sql = parse_sql(toks, 0, tables_with_alias, schema)
+
+    return sql
+
+
+def skip_semicolon(toks, start_idx):
+    idx = start_idx
+    while idx < len(toks) and toks[idx] == ";":
+        idx += 1
+    return idx
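
Usage sketch (illustrative; the file names below are placeholders). Evaluation is driven by the argparse flags defined in evaluation.py's __main__ block. Per the readers in evaluate(), the gold file holds one "SQL<TAB>db_id" line per turn with a blank line separating interactions, and the prediction file holds one predicted SQL per turn in the same layout:

    python evaluation.py --gold gold.txt --pred pred.txt --db database/ --table tables.json --etype all

process_sql.py can also be used on its own to parse a single query into the dict documented in the header comments. A minimal sketch, assuming a hypothetical Spider-style database directory:

    import os
    from process_sql import get_schema, Schema, get_sql

    db = os.path.join("database", "concert_singer", "concert_singer.sqlite")  # hypothetical path
    schema = Schema(get_schema(db))  # wraps {table: [columns]} with a table.column -> id map
    sql = get_sql(schema, "SELECT count(*) FROM singer")
    # sql['select'] == (False, [(3, (0, (0, '__all__', False), None))])
    # i.e. not distinct, one (agg_id, val_unit) pair with agg_id 3 == AGG_OPS.index('count')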