In [None]:
# Setup: imports (stdlib, then third-party), backend import path, embedding model.
import copy
import json
import os
import re
import sys

import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer, util
from sklearn.cluster import AgglomerativeClustering
from termcolor import colored

# Make the project's backend importable so later cells can do `from app.dataService ...`.
# The path is derived from the working directory: it assumes the notebook runs from a
# directory that sits next to `backend/` (i.e. `<parent>/backend`) — TODO confirm.
ROOT = os.path.join(os.path.dirname(os.getcwd()), 'backend')
if ROOT not in sys.path:
    sys.path.append(ROOT)

# Sentence-embedding model.
# NOTE(review): not referenced by the cells shown here; confirm it is still needed,
# since constructing it is comparatively expensive.
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

In [None]:
from app.dataService.dataService import DataService

In [None]:
# Example natural-language question and its target database.
# NOTE(review): `text` and `db_id` are not used by the cells shown here — presumably
# consumed by a later text-to-SQL cell; confirm or remove.
text = "films and film prices that cost below 10 dollars"
db_id = "cinema"

# Backend service for the Spider dataset; the parsing cell below calls
# `dataService.parsesql(query, db_id)` on it.
dataService = DataService("spider")

In [None]:
# Load one Spider split into a DataFrame. Each row is one example; the parsing cell
# below reads its `query` and `db_id` columns.
dataset_type = "train_spider"  # set to "dev" to inspect the development split
# Build the path from ROOT (the backend directory added to sys.path in the setup cell)
# instead of repeating the relative "../backend" prefix; both resolve to the same file.
dataset_path = os.path.join(ROOT, "app", "data", "dataset", "spider", f"{dataset_type}.json")
# Explicit encoding: don't depend on the platform's default text encoding.
with open(dataset_path, "r", encoding="utf-8") as f:
    spider_data = json.load(f)
df = pd.DataFrame(spider_data)

In [None]:
# Quick look at the first five examples of the loaded split.
df.head(5)

## Parse SQL using the Spider parser
### Assumptions
1. The SQL is syntactically correct.
2. Only table names have aliases.
3. A query contains at most one `INTERSECT`/`UNION`/`EXCEPT`.

- `val`: number(float)/string(str)/sql(dict)
- `col_unit`: (agg_id, col_id, isDistinct(bool))
- `val_unit`: (unit_op, col_unit1, col_unit2)
- `table_unit`: (table_type, table_id/sql), where `table_type` is either `'table_unit'` (a plain table id) or `'sql'` (a nested subquery)
- `cond_unit`: (not_op, op_id, val_unit, val1, val2)
- `condition`: [cond_unit1, 'and'/'or', cond_unit2, ...]

```
sql {
  'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
  'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
  'where': condition
  'groupBy': [col_unit1, col_unit2, ...]
  'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
  'having': condition
  'limit': None/limit value
  'intersect': None/sql
  'except': None/sql
  'union': None/sql
}
```

In [None]:
# Inline CSS prepended to each rendered SQL snippet below: elements with class
# `column-id` get a gold background and `entity-id` a grey one (these spans are
# presumably emitted by `sql2text`).
style = """
<style>
    span.column-id {
        background: gold;
    }
    span.entity-id {
        background: #aaa;
    }
</style>
"""

In [None]:
from app.dataService.utils.processSQL import decode_sql, sql2text
from IPython.display import display

# How many leading examples to parse and render.
N_EXAMPLES = 20

# For each example: parse its SQL via the data service, decode the parse together with
# `parsed["table"]` (presumably the database schema), and render the result as HTML
# styled by `style`. `head` clamps to the available rows, so a split with fewer than
# N_EXAMPLES rows no longer raises KeyError the way `df.loc[i]` did.
for i, row in df.head(N_EXAMPLES).iterrows():
    parsed = dataService.parsesql(row["query"], row["db_id"])

    # decode sql (whole statement)
    sql_decoded = decode_sql(parsed["sql_parse"], parsed["table"])
    display({'text/html': style + sql2text(sql_decoded)}, raw=True)