In [1]:
import os
import re
import sys
import copy

# Put `<parent of the current working directory>/backend` on sys.path so that
# packages living there can be imported. The result depends on the directory
# the kernel was started from.
ROOT = os.path.join(os.path.dirname(os.getcwd()), 'backend')
# Guard against appending the same entry again when the cell is re-run.
if ROOT not in sys.path:
    sys.path.append(ROOT)
    
import json
import os
import pandas as pd
from termcolor import colored
import numpy as np
from sentence_transformers import SentenceTransformer, util
from sklearn.cluster import AgglomerativeClustering
# Instantiate the 'paraphrase-MiniLM-L6-v2' SentenceTransformer once, at notebook level.
# NOTE(review): `model` is not referenced by any cell visible here — confirm it is used later.
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

In [2]:
from app.dataService.dataService import DataService

In [3]:
# Inputs for generating SQL and the execution results from plain text:
# a sample question, a database id, and the DataService instance (built with "spider").
# NOTE(review): `text` is not used by the cells visible here, and `db_id` is
# reassigned inside the later driver loop.
text = "films and film prices that cost below 10 dollars"
db_id = "cinema"
dataService = DataService("spider")

In [4]:
# Read one split of the Spider dataset from the backend data directory.
# The path is relative to the notebook's working directory.
dataset_type = "train_spider"
# Alternative split name; swap it in above to load a different file.
# dataset_type="dev"
with open(f"../backend/app/data/dataset/spider/{dataset_type}.json", "r") as f:
    spider_data = json.load(f)
# Wrap the parsed JSON in a DataFrame for inspection and row-wise access below.
df = pd.DataFrame(spider_data)

In [5]:
# Preview the first rows of the loaded dataset.
df.head()

Unnamed: 0,db_id,query,query_toks,query_toks_no_value,question,question_toks,sql
0,department_management,SELECT count(*) FROM head WHERE age > 56,"[SELECT, count, (, *, ), FROM, head, WHERE, ag...","[select, count, (, *, ), from, head, where, ag...",How many heads of the departments are older th...,"[How, many, heads, of, the, departments, are, ...","{'from': {'table_units': [['table_unit', 1]], ..."
1,department_management,"SELECT name , born_state , age FROM head ORD...","[SELECT, name, ,, born_state, ,, age, FROM, he...","[select, name, ,, born_state, ,, age, from, he...","List the name, born state and age of the heads...","[List, the, name, ,, born, state, and, age, of...","{'from': {'table_units': [['table_unit', 1]], ..."
2,department_management,"SELECT creation , name , budget_in_billions ...","[SELECT, creation, ,, name, ,, budget_in_billi...","[select, creation, ,, name, ,, budget_in_billi...","List the creation year, name and budget of eac...","[List, the, creation, year, ,, name, and, budg...","{'from': {'table_units': [['table_unit', 0]], ..."
3,department_management,"SELECT max(budget_in_billions) , min(budget_i...","[SELECT, max, (, budget_in_billions, ), ,, min...","[select, max, (, budget_in_billions, ), ,, min...",What are the maximum and minimum budget of the...,"[What, are, the, maximum, and, minimum, budget...","{'from': {'table_units': [['table_unit', 0]], ..."
4,department_management,SELECT avg(num_employees) FROM department WHER...,"[SELECT, avg, (, num_employees, ), FROM, depar...","[select, avg, (, num_employees, ), from, depar...",What is the average number of employees of the...,"[What, is, the, average, number, of, employees...","{'from': {'table_units': [['table_unit', 0]], ..."


## Parse SQL using spider parser
### Assumptions:
  1. The SQL is correct.
  2. Only table names have aliases.
  3. There is only one intersect/union/except.

- `val`: number(float)/string(str)/sql(dict)
- `col_unit`: (agg_id, col_id, isDistinct(bool))
- `val_unit`: (unit_op, col_unit1, col_unit2)
- `table_unit`: (table_type, table_id/sql), where `table_type` is either `table_unit` or `sql`
- `cond_unit`: (not_op, op_id, val_unit, val1, val2)
- `condition`: [cond_unit1, 'and'/'or', cond_unit2, ...]

```
sql {
  'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
  'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
  'where': condition
  'groupBy': [col_unit1, col_unit2, ...]
  'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
  'having': condition
  'limit': None/limit value
  'intersect': None/sql
  'except': None/sql
  'union': None/sql
}
```

In [6]:
# English phrase placed in front of a column description for each aggregate
# name; 'none' contributes no text.
agg_dict = dict(
    none='',
    max='the maximum of ',
    min='the minimum of ',
    count='the number of ',
    sum='the sum of ',
    avg='the average number of ',
)

In [7]:
# English phrase for each condition operator, listed as (operator, phrase)
# pairs and turned into a lookup table.
where_dict = dict([
    ('=', 'equals to'),
    ('>', 'is larger than'),
    ('<', 'is smaller than'),
    ('>=', 'is larger than or equals to'),
    ('<=', 'is smaller than or equals to'),
    ('!=', 'not equals to'),
    ('between', 'is between'),
    ('where', 'where'),
    ('in', 'in'),
    ('like', 'like'),
    ('is', 'is'),
    ('exists', 'exists'),
])

In [8]:
def is_none(token):
    """Tell whether `token` stands for "no value".

    A token counts as none-like when it is None, the string 'none',
    or the empty string.
    """
    if token is None:
        return True
    return token == 'none' or token == ''

In [9]:
def col_id2text(col_id, agg_id=None):
    """Turn a column identifier string into readable text.

    Every ':' becomes "'s" (so 'department: ranking' reads as
    "department's ranking"). Every '*' becomes 'entries' when `agg_id`
    is 'count', and 'all information' otherwise.
    """
    star_text = 'entries' if agg_id == 'count' else 'all information'
    return col_id.replace(':', '\'s').replace('*', star_text)

def col_unit2text(col_unit):
    """Describe a col_unit `(agg_id, col_id, isDistinct)` in English.

    Returns '' for a none-like col_unit; otherwise the aggregate phrase
    from `agg_dict` followed by the readable column text.
    """
    if not is_none(col_unit):
        agg_id, col_id, _is_distinct = col_unit
        # TODO: isDistinct is ignored
        prefix = agg_dict[agg_id]
        return prefix + col_id2text(col_id, agg_id)
    return ''

In [10]:
def val_unit2text(val_unit):
    """Describe a val_unit `(unit_op, col_unit1, col_unit2)` in English.

    Returns '' for a none-like val_unit. Without an operator only the first
    column is described; with one, both column descriptions are joined by
    the operator surrounded by single spaces.
    """
    if is_none(val_unit):
        return ''
    unit_op, left_unit, right_unit = val_unit
    left_text = col_unit2text(left_unit)
    if not is_none(unit_op):
        return left_text + " {} ".format(unit_op) + col_unit2text(right_unit)
    return left_text

In [11]:
# Example: a val_unit with no operator and a single, non-aggregated column.
val_unit2text(('none', ('none', 'department: ranking', ''), None))

"department's ranking"

In [12]:
def cond_unit2text(cond_unit):
    """Render one cond_unit as an English phrase.

    Args:
        cond_unit: `(not_op, op_id, val_unit, val1, val2)`, or a none-like
            value. `op_id` must be a key of `where_dict`. `val1` and `val2`
            are inserted with `str.format` as they are.

    Returns:
        '' for a none-like cond_unit; otherwise
        "<val_unit text>[ not] <operator phrase> <val1>[ and <val2>]".

    Raises:
        KeyError: if `op_id` is not a key of `where_dict`.
    """
    if is_none(cond_unit):
        return ''
    not_op, op_id, val_unit, val1, val2 = cond_unit
    # Negate only when not_op is actually set. The earlier test was inverted:
    # it emitted "not" for a none-like not_op and never for a true one. The
    # inserted text has no trailing space, because the template already puts
    # one space before the operator phrase.
    negated = bool(not_op) and not is_none(not_op)
    not_text = " not" if negated else ''
    # val2 is only present for two-operand operators such as 'between'.
    val2_text = "" if is_none(val2) else ' and {}'.format(val2)
    return "{}{} {} {}{}".format(val_unit2text(val_unit), not_text,
                                 where_dict[op_id], val1, val2_text)

In [13]:
# Example: a non-negated 'between' condition with a lower and an upper bound.
cond_unit2text((False,
  'between',
  ('none', ('none', 'department: ranking', ''), None),
  10.0,
  15.0))

"department's ranking is between 10.0 and 15.0"

In [14]:
def condition2text(condition):
    """Render a condition list as one English phrase.

    Args:
        condition: `[cond_unit1, 'and'/'or', cond_unit2, ...]`. Items at even
            positions are cond_units and items at odd positions are the
            connectives between them. May be empty or None.

    Returns:
        The cond_unit phrases joined by their connectives (each connective
        surrounded by single spaces), or "" when there is nothing to render.
    """
    # Use the argument itself. The earlier version iterated the notebook
    # global `where_decoded`, so the parameter was ignored and the function
    # failed on a fresh kernel.
    if not condition:
        return ""
    pieces = []
    for position, item in enumerate(condition):
        if position % 2 == 1:
            pieces.append(" {} ".format(item))
        else:
            pieces.append(cond_unit2text(item))
    return "".join(pieces)

In [16]:
def select_unit2text(select_unit):
    """Describe one select item `(agg_id, val_unit)` in English.

    Returns '' for a none-like item; otherwise the aggregate phrase from
    `agg_dict` immediately followed by the val_unit description.
    """
    if not is_none(select_unit):
        agg_id, val_unit = select_unit
        agg_text = agg_dict[agg_id]
        return "{}{}".format(agg_text, val_unit2text(val_unit))
    return ''

In [17]:
def select2text(select):
    """Describe a select clause `(isDistinct, [select_unit, ...])` in English.

    Only the list at index 1 is used; the item descriptions are joined
    with ', '.
    """
    select_units = select[1]
    return ', '.join(select_unit2text(unit) for unit in select_units)

In [18]:
# Example: a select clause holding a single count(*) item.
# NOTE(review): the aggregate 'count' sits on the select item while the inner
# col_unit carries 'none', so '*' is rendered as 'all information' rather
# than 'entries' — confirm this is the intended wording.
example_select = ('', [('count', ('none', ('none', '*', ''), None))])

select2text(example_select)

'the number of all information'

In [19]:
# Translate the first 20 rows of the dataset into text: parse each query,
# decode it, and print the SQL followed by the generated sentence.
# NOTE(review): `decode_sql` is not imported in any cell visible here, and the
# recorded output of this cell is "NameError: name 'decode_sql' is not defined".
# Import it in the imports cell before running — the module path is not
# shown in this notebook, TODO confirm.
for i in range(20): 
    row = df.loc[i]
    db_id = row["db_id"]
    sql = row["query"]
    # parsesql is indexed below by "table" and "sql_parse".
    parsed = dataService.parsesql(sql, db_id)
    table = parsed["table"]
    sql_parse = parsed["sql_parse"]

    # decode sql (whole statement)
    sql_decoded = decode_sql.decode_sql(sql_parse, table)
    # decode the WHERE clause on its own
    where_decoded = decode_sql.decode_where(sql_parse["where"], table)
    
    select_sentence = select2text(sql_decoded["select"])
    where_sentence = condition2text(where_decoded)
    print(row['query'])
    # Leave out the "where" part when the query has no WHERE condition.
    if where_sentence == "":
        print("Find {}.".format(select_sentence))
    else:
        print("Find {} where {}.".format(select_sentence, where_sentence))
    print()

=== begin loading sql parser ===
=== finish loading sql parser ===


NameError: name 'decode_sql' is not defined