In [None]:
# Standard library
import os
import re
import sys
import copy

# Make the sibling `backend/` directory importable so that the `app.*`
# modules used in later cells resolve. The path is derived from the current
# working directory, so this assumes the notebook is launched from a
# directory that sits next to `backend/`.
ROOT = os.path.join(os.path.dirname(os.getcwd()), 'backend')
if ROOT not in sys.path:
    sys.path.append(ROOT)
    
from moz_sql_parser import parse as ps
# NOTE(review): 1,000,000 is far above what the C stack can usually support;
# runaway recursion may then crash the kernel instead of raising
# RecursionError. Presumably raised for deeply nested SQL parsing -- confirm
# the depth actually needed and lower this if possible.
sys.setrecursionlimit(1000000)
import json
import os  # NOTE(review): duplicate of the `import os` above; redundant but harmless.
import pandas as pd
from termcolor import colored
import numpy as np
from sentence_transformers import SentenceTransformer, util
from sklearn.cluster import AgglomerativeClustering
# Sentence-embedding model, created once when this cell runs (presumably
# downloads weights on first use -- confirm for offline environments).
# NOTE(review): `model`, `ps`, `colored`, `util`, `re`, `copy`, `np` and
# `AgglomerativeClustering` are not used in the cells visible here; presumably
# they are used further down the notebook -- confirm before removing any.
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

In [None]:
from app.dataService.dataService import DataService

In [None]:
# Sample inputs: a natural-language question and the id of the database it
# targets. This cell only sets these values and builds the data service; it
# does not generate any SQL itself.
# NOTE(review): `text` is not used by the cells visible here, and `db_id` is
# reassigned on every iteration of the parsing loop further down -- confirm
# both are still needed.
text, db_id = "films and film prices that cost below 10 dollars", "cinema"

# Project data-access wrapper (from backend/app) configured for Spider.
dataService = DataService("spider")

In [None]:
# Read one split of the Spider dataset into a DataFrame.
# Set `dataset_type = "dev"` to load the development split instead.
dataset_type = "train_spider"

# Resolve the file from ROOT (the `backend/` directory computed in the setup
# cell) rather than repeating a relative path. This points at the same
# location as "../backend/app/data/dataset/spider/<split>.json".
dataset_path = os.path.join(ROOT, "app", "data", "dataset", "spider", f"{dataset_type}.json")

# JSON text is UTF-8 by specification. Be explicit rather than relying on the
# platform default encoding (e.g. cp1252 on Windows), which can fail on any
# non-ASCII characters in the data.
with open(dataset_path, "r", encoding="utf-8") as f:
    spider_data = json.load(f)
df = pd.DataFrame(spider_data)

# The parsing loop reads these two columns from every row.
assert {"db_id", "query"} <= set(df.columns), "dataset is missing 'db_id'/'query' columns"

In [None]:
# Quick look at the first rows (the loop below relies on `db_id` and `query`).
df.head(n=5)

## Parse SQL with the Spider parser
### Assumptions
1. The SQL is syntactically correct.
2. Only table names have aliases.
3. A query contains at most one INTERSECT/UNION/EXCEPT.

- `val`: number(float)/string(str)/sql(dict)
- `col_unit`: (agg_id, col_id, isDistinct(bool))
- `val_unit`: (unit_op, col_unit1, col_unit2)
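- `table_unit`: (table_type, table_id/sql), where `table_type` is either `'table_unit'` (the second element is a table id) or `'sql'` (the second element is a nested sql dict)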
- `cond_unit`: (not_op, op_id, val_unit, val1, val2)
- `condition`: [cond_unit1, 'and'/'or', cond_unit2, ...]

```
sql {
  'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
  'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
  'where': condition
  'groupBy': [col_unit1, col_unit2, ...]
  'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
  'having': condition
  'limit': None/limit value
  'intersect': None/sql
  'except': None/sql
  'union': None/sql
}
```

In [None]:
from app.dataService.utils.processSQL import decode_sql

# Row ids to skip, keyed by dataset split.
# NOTE(review): train_spider row 3153 is presumably a query the parser cannot
# handle -- confirm and record the reason here.
SKIP_ROWIDS = {"train_spider": {3153}}
# Row whose decoded clauses are printed as a worked example.
DEBUG_ROWID = 0

# Parse every query of the selected split and run each per-clause decoder on
# it. This acts as a smoke test that the decoders run without error; only the
# DEBUG_ROWID row is printed.
# NOTE(review): the loop assigns module-level names (`db_id`, `sql`, `table`,
# `*_decoded`, ...), so after it finishes they hold the last processed row.
rows_to_skip = SKIP_ROWIDS.get(dataset_type, set())
for rowid, row in df.iterrows():
    # Every split is processed; only its known-bad rows are skipped.
    if rowid in rows_to_skip:
        continue

    db_id = row["db_id"]
    sql = row["query"]
    parsed = dataService.parsesql(sql, db_id)
    table = parsed["table"]
    sql_parse = parsed["sql_parse"]

    # Whole statement.
    sql_decoded = decode_sql.decode_sql(sql_parse, table)

    # SELECT receives the whole parsed statement, not just sql_parse["select"]:
    # a `*` in the select list can only be resolved against the FROM tables.
    select_decoded = decode_sql.decode_select(sql_parse, table)

    # Remaining clauses each take their own sub-tree of the parse.
    from_decoded = decode_sql.decode_from(sql_parse["from"], table)
    where_decoded = decode_sql.decode_where(sql_parse["where"], table)
    groupby_decoded = decode_sql.decode_groupby(sql_parse["groupBy"], table)
    orderby_decoded = decode_sql.decode_orderby(sql_parse["orderBy"], table)
    having_decoded = decode_sql.decode_having(sql_parse["having"], table)
    limit_decoded = decode_sql.decode_limit(sql_parse["limit"], table)
    intersect_decoded = decode_sql.decode_intersect(sql_parse["intersect"], table)
    except_decoded = decode_sql.decode_except(sql_parse["except"], table)
    union_decoded = decode_sql.decode_union(sql_parse["union"], table)

    if rowid == DEBUG_ROWID:
        print(sql)
        # Each decoded clause is printed on its own line, followed by a separator.
        for label, decoded in (
            ("sql_decoded", sql_decoded),
            ("select_decoded", select_decoded),
            ("from_decoded", from_decoded),
            ("where_decoded", where_decoded),
            ("groupby_decoded", groupby_decoded),
            ("orderby_decoded", orderby_decoded),
            ("having_decoded", having_decoded),
            ("limit_decoded", limit_decoded),
            ("intersect_decoded", intersect_decoded),
            ("except_decoded", except_decoded),
            ("union_decoded", union_decoded),
        ):
            print(f"{label}: {decoded}")
            print("*" * 10)