In [1]:
# Format: 
# Level 0: Q | RS -> [node_begin](<value>,<node_phi>,<node_phi>)[node_end] | ...
# Level 1: Q | RS | <node_embeddings> -> [node_begin](<op>,[node_1],[node_2])[node_end] | [node_begin](<op>,[node_3],[node_phi])[node_end] | ...
# Level 2: Q | RS | <node_embeddings> -> [node_begin](<op>,[node_1],[node_2])[node_end] | [node_begin](<op>,[node_3],[node_4])[node_end] | ...
# .
# .
# .
# Level 9: Q | RS | <node_embeddings> -> [node_begin](<op>,[node_1],[node_2])[node_end]

In [10]:
from anytree import Node, RenderTree, LevelOrderIter
import json
from mo_sql_parsing import parse

import pandas as pd
from random import sample

from ra_preproc import ast_to_ra

In [11]:
with open('/root/repo/CS726/parsed/bird_dev_parsed.json', 'r') as f:
    bird_dev = pd.DataFrame(json.load(f))

In [12]:
bird_dev.head()

Unnamed: 0,db_id,question,evidence,SQL,difficulty,SQL_parse
0,california_schools,What is the highest eligible free rate for K-1...,Eligible free rate for K-12 = `Free Meal Count...,SELECT `Free Meal Count (K-12)` / `Enrollment ...,simple,{'select': {'value': {'div': ['Free Meal Count...
1,california_schools,Please list the lowest three eligible free rat...,Eligible free rates for students aged 5-17 = `...,SELECT `Free Meal Count (Ages 5-17)` / `Enroll...,moderate,{'select': {'value': {'div': ['Free Meal Count...
2,california_schools,Please list the zip code of all the charter sc...,Charter schools refers to `Charter School (Y/N...,SELECT T2.Zip FROM frpm AS T1 INNER JOIN schoo...,simple,"{'select': {'value': 'T2.Zip'}, 'from': [{'val..."
3,california_schools,What is the unabbreviated mailing address of t...,,SELECT T2.MailStreet FROM frpm AS T1 INNER JOI...,simple,"{'select': {'value': 'T2.MailStreet'}, 'from':..."
4,california_schools,Please list the phone numbers of the direct ch...,Charter schools refers to `Charter School (Y/N...,SELECT T2.Phone FROM frpm AS T1 INNER JOIN sch...,moderate,"{'select': {'value': 'T2.Phone'}, 'from': [{'v..."


In [13]:
sql = "SELECT name, age FROM Employees WHERE department = 'HR' AND title='Manager' OR age>=40"
# sql = "select count(*) from employees"

In [14]:
def balance_tree(node, ht):
    if node.height == ht:
        if node.children:
            for child in node.children:
                balance_tree(child, ht - 1)
        return node
    else:
        keep_node = Node("keep", parent=node.parent, children=[node])
        keep_node.n_type = node.n_type
        node.parent = keep_node
        balance_tree(node, ht - 1)

In [15]:
def fix_height(node, ht):
    while node.height != ht:
        node = Node("keep", children=[node], n_type=node.n_type)
    return node

In [17]:
def extract_nodes_at_level(
    nodes, level, nodes_prev_lvl_order=None, node_phi="[node_phi]", shuffle=False
):
    nodes_repr = []
    nodes_order = {}
    if level == 0:
        for i, node in enumerate(nodes[level]):
            nodes_repr.append((node.val, node_phi, node_phi, node.n_type))
            nodes_order[hash(node)] = i
    else:
        assert nodes_prev_lvl_order is not None
        for i, node in enumerate(nodes[level]):
            nodes_repr.append(
                (
                    node.name,
                    nodes_prev_lvl_order[hash(node.children[0])],
                    (
                        nodes_prev_lvl_order[hash(node.children[1])]
                        if len(node.children) > 1
                        else node_phi
                    ),
                    node.n_type
                )
            )
            nodes_order[hash(node)] = i
    if shuffle:
        new_order = sample(range(len(nodes_repr)), len(nodes_repr))
        rev_hash = {v: k for k, v in nodes_order.items()}
        nodes_repr_ = []
        nodes_order_ = {}
        for i_, i in enumerate(new_order):
            nodes_repr_.append(nodes_repr[i])
            nodes_order_[rev_hash[i]] = i_
        return nodes_repr_, nodes_order_
    return nodes_repr, nodes_order

In [18]:
def print_tree(root):
    for pre, fill, node in RenderTree(root):
        print("%s%s" % (pre, node.name+"|"+node.n_type))
        
        
def get_node_repr(parsed):
    root = ast_to_ra(parsed)
    print_tree(root)
    root = balance_tree(root, root.height)
    root = fix_height(root, 9)
    print_tree(root)

    nodes = {}
    for h in range(0, root.height + 1):
        nodes[h] = {
            v: k
            for k, v in enumerate(
                LevelOrderIter(root, filter_=lambda n: n.height == h)
            )
        }
    nodes_repr = {}
    flag_shuffle = True
    for i in range(root.height + 1):
        if i == 0:
            nodes_repr[0] = extract_nodes_at_level(nodes, 0, shuffle=flag_shuffle)
        else:
            nodes_repr[i] = extract_nodes_at_level(
                nodes, i, nodes_repr[i - 1][1], shuffle=flag_shuffle
            )
    return nodes_repr

parsed_sql = parse(sql)


nodes_repr = get_node_repr(parsed_sql)

Project|Table
├── Val_list|Value
│   ├── Value|Value
│   └── Value|Value
└── Selection|Table
    ├── Or|Predicate
    │   ├── And|Predicate
    │   │   ├── eq|Predicate
    │   │   │   ├── Value|Value
    │   │   │   └── literal|Agg
    │   │   │       └── Value|Value
    │   │   └── eq|Predicate
    │   │       ├── Value|Value
    │   │       └── literal|Agg
    │   │           └── Value|Value
    │   └── gte|Predicate
    │       ├── Value|Value
    │       └── Value|Value
    └── Table|Table
keep|Table
└── keep|Table
    └── keep|Table
        └── Project|Table
            ├── Selection|Table
            │   ├── Or|Predicate
            │   │   ├── And|Predicate
            │   │   │   ├── eq|Predicate
            │   │   │   │   ├── literal|Agg
            │   │   │   │   │   └── Value|Value
            │   │   │   │   └── keep|Value
            │   │   │   │       └── Value|Value
            │   │   │   └── eq|Predicate
            │   │   │       ├── literal|Agg
            │   │

In [11]:
# import re

# foo = "[NODE_BEGIN]name[NODE_END][NODE_BEGIN]country[NODE_END][NODE_BEGIN]age[NODE_END][NODE_BEGIN]singer[NODE_END]"

# foo2 = re.split(r"\[NODE_END\]", "[NODE_BEGIN]greater than or equal to[NB]st2[NB]st4[NODE_END][NODE_BEGIN]equal to[NB]st1[NB]st3[NODE_END]")
# print(foo2)
# bar = re.match(r"^\[NODE_BEGIN\]([a-z ]+)\[NB\](st\d+)\[NB\](st\d+)\[NODE_END\]$", )
# for j in bar.groups():
#     print(j)

In [12]:

# foo = "[NODE_BEGIN]greater than or equal to[NB]st2[NB]st4[NODE_END][NODE_BEGIN]equal to[NB]st1[NB]st3[NODE_END]"

# [(x + "[NODE_END]") for x in foo.split("[NODE_END]") if x != ""]

# foo.count("[NB]")


In [10]:
def get_node_repr_training(node_tuple):
    node_repr = '[NODE_BEGIN]'
    node_repr += str(node_tuple[0])
    if node_tuple[1] != '[node_phi]':
        node_repr += '[NB]st' + str(node_tuple[1])
    if node_tuple[2] != '[node_phi]':
        node_repr += '[NB]st' + str(node_tuple[2])
    node_repr += '[NODE_END]'
    return node_repr

In [11]:
parsed_sql = parse(sql)


nodes_repr = get_node_repr(parsed_sql)

keep
└── keep
    └── keep
        └── Project
            ├── Selection
            │   ├── Or
            │   │   ├── And
            │   │   │   ├── eq
            │   │   │   │   ├── literal
            │   │   │   │   │   └── Value
            │   │   │   │   └── keep
            │   │   │   │       └── Value
            │   │   │   └── eq
            │   │   │       ├── literal
            │   │   │       │   └── Value
            │   │   │       └── keep
            │   │   │           └── Value
            │   │   └── keep
            │   │       └── keep
            │   │           └── gte
            │   │               ├── Value
            │   │               └── Value
            │   └── keep
            │       └── keep
            │           └── keep
            │               └── keep
            │                   └── Table
            └── keep
                └── keep
                    └── keep
                        └── keep
                            └── Val_

In [14]:
RenderTree(root)

NameError: name 'root' is not defined

In [15]:
def gen_training_repr(nodes_repr):
    training_repr = []
    for k, v in nodes_repr.items():
        training_repr.append(
            "".join([get_node_repr_training(node_tuple) for node_tuple in v[0]])
        )
    return training_repr

In [16]:
training_repr = gen_training_repr(nodes_repr)

In [17]:
for i, level in enumerate(training_repr):
    print(f"Level {i}:")
    print(level)

Level 0:
[NODE_BEGIN]Manager[NODE_END][NODE_BEGIN]age[NODE_END][NODE_BEGIN]name[NODE_END][NODE_BEGIN]40[NODE_END][NODE_BEGIN]department[NODE_END][NODE_BEGIN]Employees[NODE_END][NODE_BEGIN]title[NODE_END][NODE_BEGIN]age[NODE_END][NODE_BEGIN]HR[NODE_END]
Level 1:
[NODE_BEGIN]gte[NB]st1[NB]st3[NODE_END][NODE_BEGIN]literal[NB]st8[NODE_END][NODE_BEGIN]keep[NB]st4[NODE_END][NODE_BEGIN]literal[NB]st0[NODE_END][NODE_BEGIN]keep[NB]st6[NODE_END][NODE_BEGIN]keep[NB]st5[NODE_END][NODE_BEGIN]Val_list[NB]st2[NB]st7[NODE_END]
Level 2:
[NODE_BEGIN]keep[NB]st5[NODE_END][NODE_BEGIN]keep[NB]st0[NODE_END][NODE_BEGIN]eq[NB]st1[NB]st2[NODE_END][NODE_BEGIN]keep[NB]st6[NODE_END][NODE_BEGIN]eq[NB]st3[NB]st4[NODE_END]
Level 3:
[NODE_BEGIN]keep[NB]st3[NODE_END][NODE_BEGIN]keep[NB]st0[NODE_END][NODE_BEGIN]keep[NB]st1[NODE_END][NODE_BEGIN]And[NB]st2[NB]st4[NODE_END]
Level 4:
[NODE_BEGIN]keep[NB]st1[NODE_END][NODE_BEGIN]Or[NB]st3[NB]st2[NODE_END][NODE_BEGIN]keep[NB]st0[NODE_END]
Level 5:
[NODE_BEGIN]Selection[NB]st