In [1]:
import os
import re
import json
import pickle
import random
from template_config import *
from collections import defaultdict
from nltk.stem.porter import PorterStemmer
from nltk.stem.wordnet import WordNetLemmatizer

In [13]:
import nltk
nltk.download('wordnet')

[nltk_data] Downloading package wordnet to /home/taoyu/nltk_data...
[nltk_data]   Unzipping corpora/wordnet.zip.


True

In [2]:
ps = PorterStemmer()
lmtzr = WordNetLemmatizer()

In [5]:
def read_in_all_data(data_path=DATA_PATH):
    training_data = json.load(open(os.path.join(data_path, "train_spider.json")))
    tables_org = json.load(open(os.path.join(data_path, "tables.json")))
    tables = {tab['db_id']: tab for tab in tables_org}

    return training_data, tables

In [6]:
def get_all_question_query_pairs(data):
    question_query_pairs = []
    for item in data:
        question_query_pairs.append((item['question_toks'], item['query'], item['db_id']))
    return question_query_pairs

In [7]:
training_data, tables = read_in_all_data("data")

train_qq_pairs = get_all_question_query_pairs(training_data)

In [9]:
print("Training question-query pair count: {}".format(len(train_qq_pairs)))

Training question-query pair count: 7000


In [10]:
def is_value(token):
    """
    as values can either be a numerical digit or a string literal, then we can
    detect if a token is a value by matching with regex
    """
    is_number = True
    try:
        float(token)
    except ValueError:
        is_number = False
    is_string = token.startswith("\"") or token.startswith("\'") or token.endswith("\"") or token.endswith("\'")

    return is_number or is_string


def remove_all_from_clauses(query_keywords):
    """
    remove all keywords from from clauses, until there is no more from clauses
    e.g. select {} from {} as {} where {} = {} --> select {} where {} = {}
    """
    # remove from clause by deleting the range from "FROM" to "WHERE" or "GROUP"
    start_location = 0
    count = 0
    while "FROM" in query_keywords:
        count += 1
        if count > 5:
            break
            print("error query_keywords: ", query_keywords)
        from_location = query_keywords.index("FROM")
        end_token_locations = [len(query_keywords)]  # defaulting to the end of the list
        for end_token in ["WHERE", "GROUP", "ORDER"]:
            try:
                end_token_locations.append(query_keywords.index(end_token, start_location))
            except ValueError:
                pass

        query_keywords = query_keywords[:from_location] + [FROM_SYMBOL] + query_keywords[min(end_token_locations):]
        start_location = min(end_token_locations)
        
    return query_keywords


def strip_query(query, table):
    """
    returns (stripped query, non keywords)
    """
    #get table column names info
    column_types = table['column_types']
    table_names_original = [cn.lower() for cn in table['table_names_original']]
    table_names = [cn.lower() for cn in table['table_names']]
    column_names = [cn.lower() for i, cn in table['column_names']]
    column_names_original = [cn.lower() for i, cn in table['column_names_original']]

    #clean query: replace values, numbers, column names with SYMBOL
    query_keywords = []
    columns = table_names_original + table_names

    query = query.replace(";","")
    query = query.replace("\t","")
    query = query.replace("(", " ( ").replace(")", " ) ")
    # then replace all stuff enclosed by "" with a numerical value to get it marked as {VALUE}
    str_1 = re.findall("\"[^\"]*\"", query)
    str_2 = re.findall("\'[^\']*\'", query)
    values = str_1 + str_2
    for val in values:
        query = query.replace(val.strip(), VALUE_STR_SYMBOL)

    query_tokenized = query.split(' ')
    float_nums = re.findall("[-+]?\d*\.\d+", query)
    query_tokenized = [VALUE_NUM_SYMBOL if qt in float_nums else qt for qt in query_tokenized]
    query = " ".join(query_tokenized)
    int_nums = [i.strip() for i in re.findall("[^tT]\d+", query)]
    query_tokenized = [VALUE_NUM_SYMBOL if qt in int_nums else qt for qt in query_tokenized]
    nums = float_nums + int_nums
        
    #query_tokenized = query.split(' ')
    cols_dict = {}
    for token in query_tokenized:
        if len(token.strip()) == 0:  # in case there are more than one space used
            continue
        if IGNORE_COMMAS_AND_ROUND_BRACKETS:
            keywords_dict = SQL_KEYWORDS_AND_OPERATORS_WITHOUT_COMMAS_AND_BRACES
        else:
            keywords_dict = SQL_KEYWORDS_AND_OPERATORS

        if token.upper() not in keywords_dict and token != VALUE_STR_SYMBOL and token != VALUE_NUM_SYMBOL:
            token = token.upper()
            if USE_COLUMN_AND_VALUE_REPLACEMENT_TOKEN:
                token = re.sub("[T]\d+\.", '', token)
                token = re.sub(r"\"|\'", '', token)
                token = re.sub("[T]\d+", '', token).lower()
#                 if token in table_names_original:
#                     query_keywords.append(TABLE_SYMBOL)
#                     continue
                if token != '' and token in column_names_original:
                    try:
                        tok_ind = column_names_original.index(token)
                    except:
                        print("\ntable: {}".format(table['db_id']))
                        print("\ntoken: {}".format(token))
                        print("column_names_original: {}".format(column_names_original))
                        print("query: {}".format(query))
                        print("query_tokenized: {}".format(query_tokenized))
                        exit(1)
                    col_type = column_types[tok_ind]
                    col_name = column_names[tok_ind]
                    columns.append(col_name)
                    columns.append(token)
                    if token not in cols_dict:
                        cols_dict[token] = COLUMN_SYMBOL.replace("}", str(len(cols_dict)))
                    query_keywords.append(cols_dict[token])
                elif token in table_names_original:
                    query_keywords.append(TABLE_SYMBOL)
                    continue
                    
        else:
            query_keywords.append(token.upper())

    if "FROM" in query_keywords:
        query_keywords = remove_all_from_clauses(query_keywords)

    if USE_LIMITED_KEYWORD_SET:
        query_keywords = [kw for kw in query_keywords if kw in LIMITED_KEYWORD_SET]

    columns_lemed = [lmtzr.lemmatize(w) for w in " ".join(columns).split(" ") if w not in LOW_CHAR]
    columns_lemed_stemed = [ps.stem(w) for w in columns_lemed]

    return " ".join(query_keywords), values, nums, columns_lemed_stemed


def filter_string(cs):
    return "".join([c.upper() for c in cs if c.isalpha() or c == ' '])


def process_question(question, values, nums, columns):

    question = " ".join(question).lower()
    values = [re.sub(r"\"|\'", '', val) for val in values]
    for val in values:
        val = val.lower()
        try:
            question = re.sub(r'\b'+val+r'\b', VALUE_STR_SYMBOL, question)
        except:
            continue

    for num in nums:
        num = num.strip()
        question = re.sub(r'\b'+num+r'\b', VALUE_NUM_SYMBOL, question)

    question_toks = question.split(" ")
    question_lemed = [lmtzr.lemmatize(w) for w in question_toks]
    question_lemed_stemed = [ps.stem(w) for w in question_lemed]
    replace_inds = [i for i, qt in enumerate(question_lemed_stemed) if qt in columns]
    #print("question_stemed: {}".format(question_stemed))
    #print("replace_inds: {}".format(replace_inds))
    for ind in replace_inds:
        question_toks[ind] = COLUMN_SYMBOL

    question_template = ' '.join(question_toks)

    return question_template

In [11]:
KEY_KEYWORD_SET = {"SELECT", "WHERE", "GROUP", "HAVING", "ORDER", "BY", "LIMIT", "EXCEPT", "UNION", "INTERSECT"}
ALL_KEYWORD_SET = {"SELECT", "WHERE", "GROUP", "HAVING", "DESC", "ORDER", "BY", "LIMIT", "EXCEPT", "UNION", 
                   "INTERSECT", "NOT", "IN", "OR", "LIKE", "(", ">", ")", "COUNT"}

WHERE_OPS = ['=', '>', '<', '>=', '<=', '!=', 'LIKE', 'IS', 'EXISTS']
AGG_OPS = ['MAX', 'MIN', 'SUM', 'AVG']
DASC = ['ASC', 'DESC']
def general_pattern(pattern):
    general_pattern_list = []
    for x in pattern.split(" "):
        if x in KEY_KEYWORD_SET:
            general_pattern_list.append(x)
    
    return " ".join(general_pattern_list)

def sub_pattern(pattern):
    general_pattern_list = []
    for x in pattern.split(" "):
        if x in ALL_KEYWORD_SET:
            general_pattern_list.append(x)
    
    return " ".join(general_pattern_list)

def tune_pattern(pattern):
    general_pattern_list = []
    cols_dict = {}
    for x in pattern.split(" "):
        if "{COLUMN" in x:
            if x not in cols_dict:
                cols_dict[x] = COLUMN_SYMBOL.replace("}", str(len(cols_dict))+"}")
            general_pattern_list.append(cols_dict[x])
            continue
            
        if "{VALUE" in x:
            general_pattern_list.append("{VALUE}")
            continue
            
        if x == 'DISTINCT':
            continue
        elif x in DASC:
            general_pattern_list.append("{DASC}")
        elif x in WHERE_OPS:
            general_pattern_list.append("{OP}")
        elif x in AGG_OPS:
            general_pattern_list.append("{AGG}")
        else:
            general_pattern_list.append(x)
    
    return " ".join(general_pattern_list)

In [14]:
training_question_pattern_pairs = []
training_patterns = set()

pattern_question_dict = defaultdict(list)

# train_qq_pairs
for eid, (question, query, bd_id) in enumerate(train_qq_pairs):
    table = tables[bd_id]
    if eid % 500 == 0:
        print("processing eid: ", eid)
    
    pattern, values, nums, columns = strip_query(query, table)
    question_template = process_question(question, values, nums, columns)
    
    gen_pattern = general_pattern(pattern)
    more_pattern = sub_pattern(pattern)
    tu_pattern = tune_pattern(pattern)
    
    pattern_question_dict[tu_pattern].append(' '.join(question) + " ||| " + 
                                              question_template + " ||| " + more_pattern
                                              + " ||| " + query)
#     print("\n--------------------------------------")
#     print("original question: {}".format(' '.join(question).encode('utf-8')))
#     print("question: {}".format(question_template.encode('utf-8')))
#     print("query: {}".format(query))
#     print("pattern: {}".format(pattern))
#     print("values: {}".format(values))
#     print("nums: {}".format(nums))
#     print("columns: {}".format(columns))

processing eid:  0
processing eid:  500
processing eid:  1000
processing eid:  1500
processing eid:  2000
processing eid:  2500
processing eid:  3000
processing eid:  3500
processing eid:  4000
processing eid:  4500
processing eid:  5000
processing eid:  5500
processing eid:  6000
processing eid:  6500


In [15]:
print("total pattern number: {}".format(len(pattern_question_dict)))
pattern_question_dict = sorted(pattern_question_dict.items(), key=lambda kv: len(kv[1]), reverse=True)

total pattern number: 517


In [17]:
# filter_nums = [762, 275, 241, 204, 202, 164, 98, 59, 55, 48]

In [16]:
for sql, qts in pattern_question_dict:
#     if len(qts) not in filter_nums:
#         continue
    print("\n--------------------------------------------")
    print("SQL Pattern: {}".format(sql))
    print("count: {}".format(len(qts)))
    for qt in qts:
        q, q_template, sql, sql_more = qt.split("|||")
        print("question: ", q.replace("""'""", "").replace("""``""", ''))
#         print("question: ", q_template.replace("""'""", "").replace("""``""", ''))
        print("SQL: {} \n".format(sql_more))
#     for qt in qts:
#         q, q_template, sql_temp, sql_more = qt.split("|||")
#     #     print("question: ", q_template)
#     #     print("sql_temp: ", sql_temp)
#     #     print("sql_more: ", sql_more)
#         if sql == 'SELECT {COLUMN0} {FROM} WHERE {COLUMN4} {OP} {VALUE_STR} AND {COLUMN5} {OP} {VALUE_STR}':
#             print(sql_more)



--------------------------------------------
SQL Pattern: SELECT {COLUMN0} {FROM} WHERE {COLUMN1} {OP} {VALUE}
count: 794
question:  What are the distinct creation years of the departments managed by a secretary born in state Alabama  ? 
SQL:  SELECT DISTINCT T1.creation FROM department AS T1 JOIN management AS T2 ON T1.department_id  =  T2.department_id JOIN head AS T3 ON T2.head_id  =  T3.head_id WHERE T3.born_state  =  'Alabama' 

question:  What are the distinct ages of the heads who are acting ? 
SQL:  SELECT DISTINCT T1.age FROM management AS T2 JOIN head AS T1 ON T1.head_id  =  T2.head_id WHERE T2.temporary_acting  =  'Yes' 

question:  Please show the themes of competitions with host cities having populations larger than 1000 . 
SQL:  SELECT T2.Theme FROM city AS T1 JOIN farm_competition AS T2 ON T1.City_ID  =  T2.Host_city_ID WHERE T1.Population  >  1000 

question:  What are the themes of competitions that have corresponding host cities with more than 1000 residents ? 
SQL: 

In [17]:
for sql, qts in pattern_question_dict:
    print("\n")
    print("SQL Pattern: {}".format(sql))
    print("count: ", len(qts))



SQL Pattern: SELECT {COLUMN0} {FROM} WHERE {COLUMN1} {OP} {VALUE}
count:  794


SQL Pattern: SELECT {COLUMN0} {FROM}
count:  275


SQL Pattern: SELECT {COLUMN0} , {COLUMN1} {FROM} WHERE {COLUMN2} {OP} {VALUE}
count:  260


SQL Pattern: SELECT COUNT ( * ) {FROM}
count:  241


SQL Pattern: SELECT {COLUMN0} {FROM} GROUP BY {COLUMN0} ORDER BY COUNT ( * ) {DASC} LIMIT {VALUE}
count:  239


SQL Pattern: SELECT {COLUMN0} , {COLUMN1} {FROM}
count:  210


SQL Pattern: SELECT {COLUMN0} , COUNT ( * ) {FROM} GROUP BY {COLUMN0}
count:  208


SQL Pattern: SELECT COUNT ( * ) {FROM} WHERE {COLUMN0} {OP} {VALUE}
count:  206


SQL Pattern: SELECT {COLUMN0} {FROM} ORDER BY {COLUMN1} {DASC} LIMIT {VALUE}
count:  184


SQL Pattern: SELECT COUNT ( {COLUMN0} ) {FROM}
count:  166


SQL Pattern: SELECT {COLUMN0} {FROM} WHERE {COLUMN1} {OP} {VALUE} INTERSECT SELECT {COLUMN0} {FROM} WHERE {COLUMN1} {OP} {VALUE}
count:  151


SQL Pattern: SELECT {COLUMN0} {FROM} WHERE {COLUMN1} {OP} {VALUE} AND {COLUMN2} {OP} {

In [18]:
for sql_template, qts in pattern_question_dict:
    print("\n--------------------------------------")
    print("SQL Pattern: {}".format(sql_template))
    print("count: ", len(qts))
    sql_dict = defaultdict(int)
    for qt in qts:
        q, q_template, sql, sql_more = qt.split("|||")
        sql_dict[sql] += 1
        
    sql_count = sorted(sql_dict.items(), key=lambda kv: kv[1])
    for sql, count in sql_count:
        print("SQL: {}, count: {}".format(sql, count))


--------------------------------------
SQL Pattern: SELECT {COLUMN0} {FROM} WHERE {COLUMN1} {OP} {VALUE}
count:  794
SQL:  SELECT WHERE LIKE , count: 47
SQL:  SELECT WHERE > , count: 103
SQL:  SELECT WHERE , count: 644

--------------------------------------
SQL Pattern: SELECT {COLUMN0} {FROM}
count:  275
SQL:  SELECT , count: 275

--------------------------------------
SQL Pattern: SELECT {COLUMN0} , {COLUMN1} {FROM} WHERE {COLUMN2} {OP} {VALUE}
count:  260
SQL:  SELECT WHERE LIKE , count: 15
SQL:  SELECT WHERE > , count: 22
SQL:  SELECT WHERE , count: 223

--------------------------------------
SQL Pattern: SELECT COUNT ( * ) {FROM}
count:  241
SQL:  SELECT COUNT ( ) , count: 241

--------------------------------------
SQL Pattern: SELECT {COLUMN0} {FROM} GROUP BY {COLUMN0} ORDER BY COUNT ( * ) {DASC} LIMIT {VALUE}
count:  239
SQL:  SELECT GROUP BY ORDER BY COUNT ( ) LIMIT , count: 22
SQL:  SELECT GROUP BY ORDER BY COUNT ( ) DESC LIMIT , count: 217

--------------------------------

In [19]:
for sql_template, qts in pattern_question_dict:
    print("\n--------------------------------------")
    print("SQL Pattern: {}".format(sql_template))
    print("count: ", len(qts))
    for qt in qts:
        q, q_template, sql, sql_more = qt.split("|||")
        print("question: ", q)
        print("SQL: {} \n".format(sql_more))


--------------------------------------
SQL Pattern: SELECT {COLUMN0} {FROM} WHERE {COLUMN1} {OP} {VALUE}
count:  794
question:  What are the distinct creation years of the departments managed by a secretary born in state 'Alabama ' ? 
SQL:  SELECT DISTINCT T1.creation FROM department AS T1 JOIN management AS T2 ON T1.department_id  =  T2.department_id JOIN head AS T3 ON T2.head_id  =  T3.head_id WHERE T3.born_state  =  'Alabama' 

question:  What are the distinct ages of the heads who are acting ? 
SQL:  SELECT DISTINCT T1.age FROM management AS T2 JOIN head AS T1 ON T1.head_id  =  T2.head_id WHERE T2.temporary_acting  =  'Yes' 

question:  Please show the themes of competitions with host cities having populations larger than 1000 . 
SQL:  SELECT T2.Theme FROM city AS T1 JOIN farm_competition AS T2 ON T1.City_ID  =  T2.Host_city_ID WHERE T1.Population  >  1000 

question:  What are the themes of competitions that have corresponding host cities with more than 1000 residents ? 
SQL:  SE