In [1]:
import copy
import json
import os
import re
import sys

# Make the project's backend package (e.g. `app.dataService`, imported below) importable.
# Resolves to ../backend relative to the current working directory, so the notebook is
# expected to be started from a directory that is a sibling of `backend/`.
ROOT = os.path.join(os.path.dirname(os.getcwd()), 'backend')
if ROOT not in sys.path:
    sys.path.append(ROOT)

# Third-party imports stay after the sys.path tweak (as in the original ordering), in case
# any of them is only resolvable from `backend/` (appended paths are searched last).
import numpy as np
import pandas as pd
from moz_sql_parser import parse as ps
from sentence_transformers import SentenceTransformer, util
from sklearn.cluster import AgglomerativeClustering
from termcolor import colored

# NOTE(review): a limit this high means very deep recursion can exhaust the C stack and
# crash the kernel instead of raising RecursionError (see sys.setrecursionlimit docs).
# Presumably raised for deeply nested SQL parses — TODO confirm the smallest value that works.
RECURSION_LIMIT = 1_000_000
sys.setrecursionlimit(RECURSION_LIMIT)

# Sentence-transformer embedding model, loaded once for the whole notebook.
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

In [2]:
from app.dataService.dataService import DataService

In [3]:
# Build the data service over the Spider benchmark; it prints progress
# messages while loading its model.
dataService = DataService("spider")

# Example natural-language request and the database it targets, used as
# input for text-to-SQL generation.
# NOTE: the decoding loop further down reassigns `db_id` per row.
text = "films and film prices that cost below 10 dollars"
db_id = "cinema"

=== begin loading model ===
=== finish loading model ===


In [4]:
# Load one split of the Spider dataset (a JSON list of examples) into a DataFrame.
dataset_type = "train_spider"  # or "dev"
spider_path = os.path.join(ROOT, "app", "data", "dataset", "spider", f"{dataset_type}.json")
# Explicit UTF-8: open()'s default encoding is locale-dependent (e.g. cp1252 on
# Windows) and can fail on non-ASCII characters in the dataset text.
with open(spider_path, "r", encoding="utf-8") as f:
    spider_data = json.load(f)
df = pd.DataFrame(spider_data)

In [14]:
# Peek at the first five examples to check which columns were loaded.
df.head(5)

Unnamed: 0,db_id,query,query_toks,query_toks_no_value,question,question_toks,sql
0,department_management,SELECT count(*) FROM head WHERE age > 56,"[SELECT, count, (, *, ), FROM, head, WHERE, ag...","[select, count, (, *, ), from, head, where, ag...",How many heads of the departments are older th...,"[How, many, heads, of, the, departments, are, ...","{'from': {'table_units': [['table_unit', 1]], ..."
1,department_management,"SELECT name , born_state , age FROM head ORD...","[SELECT, name, ,, born_state, ,, age, FROM, he...","[select, name, ,, born_state, ,, age, from, he...","List the name, born state and age of the heads...","[List, the, name, ,, born, state, and, age, of...","{'from': {'table_units': [['table_unit', 1]], ..."
2,department_management,"SELECT creation , name , budget_in_billions ...","[SELECT, creation, ,, name, ,, budget_in_billi...","[select, creation, ,, name, ,, budget_in_billi...","List the creation year, name and budget of eac...","[List, the, creation, year, ,, name, and, budg...","{'from': {'table_units': [['table_unit', 0]], ..."
3,department_management,"SELECT max(budget_in_billions) , min(budget_i...","[SELECT, max, (, budget_in_billions, ), ,, min...","[select, max, (, budget_in_billions, ), ,, min...",What are the maximum and minimum budget of the...,"[What, are, the, maximum, and, minimum, budget...","{'from': {'table_units': [['table_unit', 0]], ..."
4,department_management,SELECT avg(num_employees) FROM department WHER...,"[SELECT, avg, (, num_employees, ), FROM, depar...","[select, avg, (, num_employees, ), from, depar...",What is the average number of employees of the...,"[What, is, the, average, number, of, employees...","{'from': {'table_units': [['table_unit', 0]], ..."


## Parse SQL with the Spider parser
### Assumptions
  1. The SQL is valid.
  2. Only table names have aliases.
  3. A query contains at most one `intersect`/`union`/`except`.

- `val`: number(float)/string(str)/sql(dict)
- `col_unit`: (agg_id, col_id, isDistinct(bool))
- `val_unit`: (unit_op, col_unit1, col_unit2)
- `table_unit`: (table_type, table_id/sql), where `table_type` is `'table_unit'` or `'sql'`
- `cond_unit`: (not_op, op_id, val_unit, val1, val2)
- `condition`: [cond_unit1, 'and'/'or', cond_unit2, ...]

```
sql {
  'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
  'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
  'where': condition
  'groupBy': [col_unit1, col_unit2, ...]
  'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
  'having': condition
  'limit': None/limit value
  'intersect': None/sql
  'except': None/sql
  'union': None/sql
}
```

In [15]:
from app.dataService.utils.processSQL import decode_sql

# Row ids to skip, per dataset split. Splits not listed here are processed in full.
# NOTE(review): train_spider row 3153 is skipped; the reason is not recorded here
# (presumably parsing/decoding fails on that query) — TODO confirm and document.
SKIP_ROWIDS = {"train_spider": {3153}}
skip_rowids = SKIP_ROWIDS.get(dataset_type, set())

# Row whose decoded clauses are printed as a worked example.
EXAMPLE_ROWID = 0

# Parse and decode every query in the split. All rows are decoded, not just the
# printed example, so (presumably as a smoke test) any query that breaks the
# decoder raises here.
for rowid, row in df.iterrows():
    if rowid in skip_rowids:
        continue

    db_id = row["db_id"]
    sql = row["query"]
    parsed = dataService.parsesql(sql, db_id)
    table = parsed["table"]
    sql_parse = parsed["sql_parse"]

    # Decode the whole statement, then each clause of the parse on its own.
    sql_decoded = decode_sql.decode_sql(sql_parse, table)
    select_decoded = decode_sql.decode_select(sql_parse["select"], table)
    from_decoded = decode_sql.decode_from(sql_parse["from"], table)
    where_decoded = decode_sql.decode_where(sql_parse["where"], table)
    groupby_decoded = decode_sql.decode_groupby(sql_parse["groupBy"], table)
    orderby_decoded = decode_sql.decode_orderby(sql_parse["orderBy"], table)
    having_decoded = decode_sql.decode_having(sql_parse["having"], table)
    limit_decoded = decode_sql.decode_limit(sql_parse["limit"], table)
    intersect_decoded = decode_sql.decode_intersect(sql_parse["intersect"], table)
    except_decoded = decode_sql.decode_except(sql_parse["except"], table)
    union_decoded = decode_sql.decode_union(sql_parse["union"], table)

    if rowid == EXAMPLE_ROWID:
        decoded_parts = [
            ("sql_decoded", sql_decoded),
            ("select_decoded", select_decoded),
            ("from_decoded", from_decoded),
            ("where_decoded", where_decoded),
            ("groupby_decoded", groupby_decoded),
            ("orderby_decoded", orderby_decoded),
            ("having_decoded", having_decoded),
            ("limit_decoded", limit_decoded),
            ("intersect_decoded", intersect_decoded),
            ("except_decoded", except_decoded),
            ("union_decoded", union_decoded),
        ]
        for label, value in decoded_parts:
            print(f"{label}: {value}")
            print("*" * 10)

sql_decoded: {'select': ('', [('none', ('none', 'management: *', ''), None)]), 'from': {'table_units': [('table_unit', 'head')], 'conds': []}, 'where': [(False, '>', ('none', ('none', 'head: age', ''), None), 56.0, None)], 'groupBy': [], 'orderBy': [], 'having': [], 'limit': None, 'intersect': None, 'except': None, 'union': None}
**********
select_decoded: ('', [('none', ('none', 'management: *', ''), None)])
**********
from_decoded: {'table_units': [('table_unit', 'head')], 'conds': []}
**********
where_decoded: [(False, '>', ('none', ('none', 'head: age', ''), None), 56.0, None)]
**********
groupby_decoded: []
**********
orderby_decoded: []
**********
having_decoded: []
**********
limit_decoded: None
**********
intersect_decoded: None
**********
except_decoded: None
**********
union_decoded: None
**********
