In [1]:
from pathlib import Path
import pandas as pd
import json

from mo_sql_parsing import parse

In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("seeklhy/codes-3b")

print(tokenizer("subtree"))
print(tokenizer("st1"))
print(tokenizer("st20"))
print(tokenizer("st200"))
print(tokenizer("[LEFT_CHILD]"))

{'input_ids': [1113, 3242], 'attention_mask': [1, 1]}
{'input_ids': [270, 35], 'attention_mask': [1, 1]}
{'input_ids': [270, 36, 34], 'attention_mask': [1, 1, 1]}
{'input_ids': [270, 36, 34, 34], 'attention_mask': [1, 1, 1, 1]}
{'input_ids': [77, 10765, 81, 31590, 79], 'attention_mask': [1, 1, 1, 1, 1]}


In [6]:
import re

foo = "[NODE_BEGIN]name[NODE_END][NODE_BEGIN]country[NODE_END][NODE_BEGIN]age[NODE_END][NODE_BEGIN]singer[NODE_END]"

foo2 = re.split(r"\[NODE_END\]", "[NODE_BEGIN]greater than or equal to[NB]st2[NB]st4[NODE_END][NODE_BEGIN]equal to[NB]st1[NB]st3[NODE_END]")
print(foo2)
bar = re.match(r"^\[NODE_BEGIN\]([a-z ]+)\[NB\](st\d+)\[NB\](st\d+)\[NODE_END\]$", )
for j in bar.groups():
    print(j)

['[NODE_BEGIN]greater than or equal to[NB]st2[NB]st4', '[NODE_BEGIN]equal to[NB]st1[NB]st3', '']


TypeError: match() missing 1 required positional argument: 'string'

In [18]:
foo = "[NODE_BEGIN]greater than or equal to[NB]st2[NB]st4[NODE_END][NODE_BEGIN]equal to[NB]st1[NB]st3[NODE_END]"

[(x + "[NODE_END]") for x in foo.split("[NODE_END]") if x != ""]

foo.count("[NB]")


4

In [1]:
operations = ["union", "intersection", "difference", "selection", "cartesian product", "projection", "and", "or", "greater than", "greater than or equal to", "less than", "less than or equal to", "order by ascending", "order by descending", "group by", "limit", "in", "not in ", "like", "not like", "sum", "max", "min", "count", "average", "distinct", "keep"]

operation_types = ["predicate", "relation", "schema_constant", "any"]

operations_type_map = { "union": "relation",
                       "intersection": "relation",
                        "difference": "relation",
                        "selection": "relation",
                        "cartesian product": "relation",
                        "projection": "relation",
                        "and": "predicate",
                        "or": "predicate",
                        "greater than": "predicate",
                        "greater than or equal to": "predicate",
                        "less than": "predicate",
                        "less than or equal to": "predicate",
                        "order by ascending": "relation",
                        "order by descending": "relation",
                        "group by": "relation",
                        "limit": "relation",
                        "in": "predicate",
                        "not in ": "predicate",
                        "like": "predicate",
                        "not like": "predicate",
                        "sum": "schema_constant",
                        "max": "schema_constant",
                        "min": "schema_constant",
                        "count": "schema_constant",
                        "average": "schema_constant",
                        "distinct": "schema_constant",
                        "keep" : "any"
}

print(len(operations))
print(len(operations_type_map.keys()))

27
26


In [None]:
level 0 node samples = "[NODE_BEGIN]name[NODE_END][NODE_BEGIN]country[NODE_END][NODE_BEGIN]age[NODE_END][NODE_BEGIN]singer[NODE_END][NODE_BEGIN]60[NODE_END]"
level 1 node samnple = "[NODE_BEGIN]greater than or equal to[NB]st2[NB]st4[NODE_END]" #binary
level 1 node sample = "[NODE_BEGIN]keep[NB]st2[NODE_END]" #unary



"level_wise_subtrees": [
            \\name, country, age, singer, 60
            "[NODE_BEGIN]name[NODE_END][NODE_BEGIN]country[NODE_END][NODE_BEGIN]age[NODE_END][NODE_BEGIN]singer[NODE_END][NODE_BEGIN]60[NODE_END]",
            "[NODE_BEGIN]greater than or equal to[NB]st2[NB]st4[NODE_END][NODE_BEGIN]"
        ]

In [2]:
PATH_DS = Path('data')
PATH_BIRD_SQL = PATH_DS / 'bird'
PATH_SPIDER_SQL = PATH_DS / 'spider'

In [3]:
generate_out_path = lambda path, suffix: path.parent / ''.join(path.parts[-1].split('.')[:-1] + [suffix])

In [4]:
paths = [
    PATH_SPIDER_SQL / "processed" / "dev_gold.json",
    PATH_SPIDER_SQL / "processed" / "train_gold.json",
    PATH_BIRD_SQL / "dev" / "dev.json",
    PATH_BIRD_SQL / "train" / "train.json",
]

In [5]:
def parse_sql(df, key="SQL"):
    errors = []
    df['SQL_parse'] = None
    for i in range(df.shape[0]):
        try:
            df.at[i, 'SQL_parse'] = parse(df.at[i, key])
        except Exception as e:
            errors.append(i)
    return df, errors

In [6]:
def add_sql_parse(path: Path):
    with open(path) as f:
        df = pd.DataFrame(json.load(f))
        df.set_index("question_id", inplace=True)
    df, errors = parse_sql(df)
    out_path = generate_out_path(path, "_parse.json")
    df.to_json(out_path, orient="records", indent=4)
    if len(errors) > 0:
        with open(generate_out_path(path, "_parse_errors.json"), "w") as f:
            json.dump(errors, f)

In [7]:
def filter_sql(df):
    df_sql = df["SQL"].copy()

    # remove queries with inner joins and nested queries
    df_sql = df_sql[~df_sql.str.contains("inner join", case=False)]
    df_sql = df_sql[~df_sql.str.contains(r"\(.*\bSELECT\b.*\)", regex=True, case=False)]

    # TODO: add filter number of tokens < 55

    return df.loc[df_sql.index]

In [9]:
for path in paths:
    with open(path) as f:
        df = pd.DataFrame(json.load(f))
        if "question_id" not in df.columns:
            df["question_id"] = df.index
        df.set_index("question_id", inplace=True)
        df = filter_sql(df)
        out_path = generate_out_path(path, "_filtered.json")
        df.to_json(out_path, orient="records", indent=4)
        print(df.shape)

(949, 3)
(6464, 3)
(336, 5)
(1984, 4)
