In [1]:
import pandas as pd
from sqlalchemy import create_engine
from joblib import dump, load
from keras.preprocessing.sequence import pad_sequences
import numpy as np

from select_col import get_col_model
from agg_col import get_agg_model
from k_where import get_k_where_model
from where_col import get_where_col_model
from where_ops import get_where_ops_model
from where_value import get_where_value_model

Using TensorFlow backend.


In [2]:
# Build a throwaway SQLite database (the 'sqlite://' URL names no file) that
# holds a small `users` table to run the generated query against.
table_name = 'users'
engine = create_engine('sqlite://', echo=False)

df = pd.DataFrame(
    {
        'name': ['alan', 'joe', 'rick'],
        'year': ['1st year', '2nd year', '3rd year'],
        'major': ['CS', 'EE', 'CE'],
    }
)

# Persist the frame as a table, then read every row back as a sanity check.
df.to_sql(table_name, con=engine)
engine.execute("SELECT * FROM users").fetchall()

[(0, 'alan', '1st year', 'CS'),
 (1, 'joe', '2nd year', 'EE'),
 (2, 'rick', '3rd year', 'CE')]

In [3]:
# Natural-language question that the rest of the notebook turns into SQL.
NLQ = 'what is the major of the user with name joe '

In [4]:
# Load the tokenizer persisted in tokenizer.joblib.
# NOTE(review): joblib's load unpickles the file, so only load artifacts that
# come from a trusted source.
tokenizer = load('tokenizer.joblib')
# Length every sequence is padded to in text2seq.
max_len = 419
# Not referenced by any visible cell; presumably the vocabulary size the
# models were built with -- TODO confirm.
max_token_index = 1246

In [5]:
def sketch2Query(table_name,agg_col,select_col,where_cols,where_ops,where_values):
    """Assemble a SQL SELECT statement from the predicted query sketch.

    Parameters
    ----------
    table_name : str
        Table to select from.
    agg_col : str, None, or one-element sequence of str
        Aggregate function name (e.g. ``'MAX'``). ``''`` (or ``None``) means
        no aggregate. A one-element list/array is unwrapped to its only
        element, so the raw result of an ``inverse_transform`` call can be
        passed directly.
    select_col : str
        Column to select.
    where_cols, where_ops, where_values : sequence of str
        Parallel sequences describing the conditions: condition ``i`` is
        ``where_cols[i] where_ops[i] "where_values[i]"``. The number of
        conditions is ``len(where_cols)``.

    Returns
    -------
    str
        The query text, terminated by ``';'``. Conditions are joined with
        ``AND`` and every value is wrapped in double quotes.

    Raises
    ------
    ValueError
        If ``where_ops`` or ``where_values`` has fewer entries than
        ``where_cols``.

    Notes
    -----
    Table and column names are inserted verbatim; only the condition values
    are escaped. Values are double-quoted to keep the single-condition output
    unchanged; SQLite accepts that form as a string literal.
    """
    # Normalise agg_col to a plain string (callers may hand over a
    # one-element array instead of the label itself).
    if agg_col is None:
        agg_col = ''
    elif not isinstance(agg_col, str):
        agg_col = agg_col[0] if len(agg_col) else ''
    agg_col = str(agg_col)

    if agg_col != '':
        select_expr = agg_col + '(' + str(select_col) + ')'
    else:
        select_expr = str(select_col)
    query = 'SELECT ' + select_expr + ' FROM ' + table_name

    n_conditions = len(where_cols)
    if n_conditions == 0:
        return query + ';'
    if len(where_ops) < n_conditions or len(where_values) < n_conditions:
        raise ValueError(
            'where_ops and where_values need one entry per where column '
            '(got %d columns, %d ops, %d values)'
            % (n_conditions, len(where_ops), len(where_values))
        )

    conditions = []
    for col, op, value in zip(where_cols, where_ops, where_values):
        # Double any embedded quote so the value cannot terminate the literal.
        escaped = str(value).replace('"', '""')
        conditions.append(str(col) + ' ' + str(op) + ' "' + escaped + '"')
    # A single WHERE keyword, with the conditions joined by AND.
    return query + ' WHERE ' + ' AND '.join(conditions) + ';'

In [6]:
def text2seq(text):
    """Encode ``text`` as a flat, fixed-length sequence for the models.

    The text is run through the module-level ``tokenizer``, padded at the
    end (``padding='post'``) to ``max_len`` entries, and flattened to 1-D.
    """
    sequences = tokenizer.texts_to_sequences([text])
    padded = pad_sequences(sequences, maxlen=max_len, padding='post')
    return np.ravel(padded)


In [7]:
# Model that scores (question, column) pairs for the SELECT column.
select_col_model = get_col_model()

Instructions for updating:
Colocations handled automatically by placer.


In [8]:
# Distance based: score every table column against the question. The next
# cell picks the column with the smallest score.
input_q = text2seq(NLQ)
preds_col = []
for column in df.columns:
    input_c = text2seq(column)
    column_score = select_col_model.predict([[input_q], [input_c]])
    preds_col.append(column_score)
preds_col

[array([[1.0633502]], dtype=float32),
 array([[1.0431721]], dtype=float32),
 array([[0.85307306]], dtype=float32)]

In [9]:
# The column with the smallest predicted distance becomes the SELECT column.
closest_col_idx = np.argmin(preds_col)
selected_col = df.columns[closest_col_idx]
selected_col

'major'

In [10]:
# Decide whether the selected column needs an aggregate function (softmax
# classifier). lb_agg maps the model output back to an aggregate label.
# NOTE(review): joblib's load unpickles the file -- trusted artifacts only.
lb_agg = load('lb_agg.joblib')
agg_model = get_agg_model()

In [11]:
# Class probabilities for the aggregate, given the question and the SELECT column.
input_c = text2seq(selected_col)
agg_inputs = [[input_q], [input_c]]
pred_agg = agg_model.predict(agg_inputs)
pred_agg

array([[9.8474979e-01, 7.3686393e-04, 7.7564674e-03, 8.4247455e-05,
        4.8483445e-05, 6.6239852e-03]], dtype=float32)

In [12]:
# Turn the probability vector into its label via lb_agg.
# NOTE: this rebinds `pred_agg`, so re-running only this cell would feed the
# labels (not the probabilities) back into inverse_transform.
pred_agg = lb_agg.inverse_transform(pred_agg)
pred_agg

array([''], dtype='<U5')

In [13]:
# How many WHERE conditions does the question need?
# lb_k maps the model output back to a condition count.
# NOTE(review): joblib's load unpickles the file -- trusted artifacts only.
lb_k = load('lb_k.joblib')
k_where_model = get_k_where_model()

In [14]:
# Softmax based: class probabilities over the number of conditions,
# predicted from the question alone.
k_where_inputs = [[input_q]]
pred_k_where = k_where_model.predict(k_where_inputs)
pred_k_where

array([[1.5023233e-04, 9.3717235e-01, 6.1912246e-02, 7.4570067e-04,
        1.9479246e-05]], dtype=float32)

In [15]:
# Turn the probability vector into the predicted condition count via lb_k.
# NOTE: this rebinds `pred_k_where`, so re-running only this cell would feed
# the labels (not the probabilities) back into inverse_transform.
pred_k_where = lb_k.inverse_transform(pred_k_where)
pred_k_where

array([1])

In [16]:
# Model that scores (question, column) pairs for the WHERE-clause columns.
where_col_model = get_where_col_model()

In [17]:
# Distance based: score every column as a WHERE-clause candidate, then
# flatten the per-column predictions into one 1-D array.
preds_where_col = []
for column in df.columns:
    input_c = text2seq(column)
    where_col_score = where_col_model.predict([[input_q], [input_c]])
    preds_where_col.append(where_col_score)
preds_where_col = np.ravel(preds_where_col)
preds_where_col

array([0.4337846 , 0.42755792, 0.76692   ], dtype=float32)

In [18]:
# Keep the k columns with the smallest predicted distance, closest first.
k = pred_k_where[0]
min_dist_idx = np.argsort(preds_where_col)[:k]
where_columns = [df.columns[idx] for idx in min_dist_idx]
where_columns

['year']

In [19]:
# Pick the comparison operator for each WHERE column (softmax classifier).
# lb_ops maps the model output back to an operator label.
# NOTE(review): joblib's load unpickles the file -- trusted artifacts only.
lb_ops = load('lb_ops.joblib')
where_ops_model = get_where_ops_model()

In [20]:
# Predict one operator per WHERE column. The model is run once per column and
# the same probabilities are used for both the printout and the label lookup.
where_ops=[]
for column in where_columns:
    input_c=text2seq(column)
    ops_probs=where_ops_model.predict([[input_q],[input_c]])
    print(ops_probs)
    where_ops.append(lb_ops.inverse_transform(ops_probs)[0])
where_ops

[[0.00618434 0.9869921  0.00682344]]


['=']

In [21]:
# Predict the WHERE-clause values: the model outputs a start and an end
# position that are used to slice the value out of the question.
where_value_model = get_where_value_model()

In [22]:
# Extract one value per WHERE column by slicing the question between the
# predicted start (temp[0]) and end (temp[1]) positions.
# NOTE(review): the positions index NLQ character by character -- confirm this
# matches the units the model was trained to predict.
where_values=[]
# The operator encoding does not depend on the loop index, so compute it once.
# Skipped when there are no conditions, as the original loop never ran then.
encoded_ops=lb_ops.transform(where_ops) if len(where_columns)>0 else []
for i in range(len(where_columns)):
    input_c=text2seq(where_columns[i])
    input_ops=encoded_ops[i]
    temp=where_value_model.predict([[input_q],[input_c],[input_ops]])
    where_values.append(NLQ[np.argmax(temp[0]):np.argmax(temp[1])])
where_values

['with name joe']

In [23]:
# pred_agg is a one-element array of labels; sketch2Query expects the label
# string itself, so pass its single entry.
print(sketch2Query(table_name,pred_agg[0],selected_col,where_columns,where_ops,where_values))

SELECT major FROM users WHERE year = "with name joe";
