In [1]:
# coding=utf-8
from __future__ import print_function

import json
import os
import pickle
import sys
from itertools import chain

import numpy as np

from asdl.asdl import ASDLGrammar
from asdl.hypothesis import Hypothesis
from datasets.wikisql.lib.common import detokenize
from datasets.wikisql.lib.dbengine import DBEngine
from datasets.wikisql.lib.query import Query
from asdl.lang.sql.sql_transition_system import SqlTransitionSystem, sql_query_to_asdl_ast, asdl_ast_to_sql_query
from datasets.wikisql.utils import my_detokenize, find_sub_sequence
from asdl.transition_system import GenTokenAction
from components.action_info import ActionInfo
from components.vocab import VocabEntry, Vocab
from model.wikisql.dataset import WikiSqlExample, WikiSqlTable, TableColumn

In [2]:
from datasets.wikisql.dataset import load_dataset
from datasets.wikisql.dataset import get_action_infos

In [3]:
engine = DBEngine('tranx.0.2.0/data/wikisql/train.db')
data_file = 'dsl_parser/my_wikisql/wikisql_train.json'
grammar = ASDLGrammar.from_text(open('asdl/lang/sql/sql_asdl.txt').read())

transition_system = SqlTransitionSystem(grammar)

In [4]:
parser = SqlTransitionSystem(grammar)

In [5]:
query = Query('1', 2)
query

SELECT MIN col1 FROM table

In [6]:
query.to_dict()

{'sel': '1', 'agg': 2, 'conds': []}

In [7]:
asdl_ast = sql_query_to_asdl_ast(query, grammar)
asdl_ast.sanity_check()
actions = transition_system.get_actions(asdl_ast)

In [8]:
actions

[ApplyRule[stmt -> Select(agg_op? agg, column_idx col_idx, cond_expr* conditions)],
 ApplyRule[agg_op -> Min()],
 SelectColumnAction[id=1],
 Reduce]

In [9]:
actions.count

<function list.count(value, /)>

In [10]:
query_reconstr = asdl_ast_to_sql_query(asdl_ast)

In [11]:
query_reconstr

SELECT MIN col1 FROM table

In [40]:
q = {"sel": 1, "conds": [[1, 0, "New series began in June 2011"], [2, 2, "Vse horosho"]], "agg": 0}

qu = Query.from_dict(q)

In [41]:
qu

SELECT  col1 FROM table WHERE col1 = New series began in June 2011 AND col2 < Vse horosho

In [42]:
qu.to_dict()

{'sel': 1,
 'agg': 0,
 'conds': [[1, 0, 'New series began in June 2011'], [2, 2, 'Vse horosho']]}

In [43]:
qu.to_query('text')

'SELECT col1 FROM table WHERE col1 = New series began in June 2011 AND col2 < Vse horosho'

In [54]:
a = qu.to_query('text')

In [27]:
cmp_op_idx2op_symbol = {0: '=', 1: '>', 2: '<'}

cmp_op_idx2op_name = {0: 'Equal', 1: 'GreaterThan', 2: 'LessThan'}
ctr_name2cmp_op_idx = {v: k for k, v in cmp_op_idx2op_name.items()}
agg_idx2op_name = {1: 'Max', 2: 'Min', 3: 'Count', 4: 'Sum', 5: 'Avg'}
ctr_name2agg_idx = {v: k for k, v in agg_idx2op_name.items()}

In [55]:
asdl_ast = sql_query_to_asdl_ast(qu, grammar)
asdl_ast.sanity_check()
actions = transition_system.get_actions(asdl_ast)

In [56]:
query_reconstr = asdl_ast_to_sql_query(asdl_ast)

In [57]:
query_reconstr.to_query('text')

'SELECT col1 FROM table WHERE col3 > New series began in June 2011 AND col1 = Smth in 4'

In [17]:
actions

[ApplyRule[stmt -> Select(agg_op? agg, column_idx col_idx, cond_expr* conditions)],
 Reduce,
 SelectColumnAction[id=1],
 ApplyRule[cond_expr -> Condition(cmp_op op, column_idx col_idx, string value)],
 ApplyRule[cmp_op -> Equal()],
 SelectColumnAction[id=3],
 GenToken[New],
 GenToken[series],
 GenToken[began],
 GenToken[in],
 GenToken[June],
 GenToken[2011],
 GenToken[</primitive>],
 Reduce]

In [18]:
import pandas as pd
data = pd.read_csv('my_wikisql/train.csv')

In [19]:
data.head()

Unnamed: 0,question,sql
0,Tell me what the notes are for South Australia,SELECT Notes FROM table WHERE Current slogan =...
1,What is the current series where the new serie...,SELECT Current series FROM table WHERE Notes =...
2,What is the format for South Australia?,SELECT Format FROM table WHERE State/territory...
3,Name the background colour for the Australian ...,SELECT Text/background colour FROM table WHERE...
4,how many times is the fuel propulsion is cng?,SELECT COUNT Fleet Series (Quantity) FROM tabl...


In [20]:
cmp_op_idx2op_name = {0: 'Equal', 1: 'GreaterThan', 2: 'LessThan'}
ctr_name2cmp_op_idx = {v: k for k, v in cmp_op_idx2op_name.items()}
agg_idx2op_name = {1: 'Max', 2: 'Min', 3: 'Count', 4: 'Sum', 5: 'Avg'}
ctr_name2agg_idx = {v: k for k, v in agg_idx2op_name.items()}

In [21]:
ctr_name2cmp_op_idx

{'Equal': 0, 'GreaterThan': 1, 'LessThan': 2}

In [22]:
data.loc[4]['sql']

'SELECT COUNT Fleet Series (Quantity) FROM table WHERE Fuel Propulsion = CNG'

In [23]:
a = []
for sql in data['sql']:
    if sql.count('WHERE') != 0:
        a.append(sql)

len(a)

55932

In [None]:
qq = data.loc[0]['sql']
qq