# TEXT2SQL

In [1]:
import os
import json
import transformers

pretrained_weights = {
    ("bert", "base"): "bert-base-uncased",
    ("bert", "large"): "bert-large-uncased-whole-word-masking",
    ("roberta", "base"): "roberta-base",
    ("roberta", "large"): "roberta-large",
    ("albert", "xlarge"): "albert-xlarge-v2"
}


def read_jsonl(jsonl):
    for line in open(jsonl, encoding="utf8"):
        sample = json.loads(line.rstrip())
        yield sample

def read_conf(conf_path):
    config = {}
    for line in open(conf_path, encoding="utf8"):
        if line.strip() == "" or line[0] == "#":
             continue
        fields = line.strip().split("\t")
        config[fields[0]] = fields[1]
    config["train_data_path"] =  os.path.abspath(config["train_data_path"])
    config["dev_data_path"] =  os.path.abspath(config["dev_data_path"])

    return config

def create_base_model(config):
    weights_name = pretrained_weights[(config["base_class"], config["base_name"])]
    if config["base_class"] == "bert":
        return transformers.BertModel.from_pretrained(weights_name)
    elif config["base_class"] == "roberta":
        return transformers.RobertaModel.from_pretrained(weights_name)
    elif config["base_class"] == "albert":
        return transformers.AlbertModel.from_pretrained(weights_name)
    else:
        raise Exception("base_class {0} not supported".format(config["base_class"]))

def create_tokenizer(config):
    weights_name = pretrained_weights[(config["base_class"], config["base_name"])]
    if config["base_class"] == "bert":
        return transformers.BertTokenizer.from_pretrained(weights_name)
    elif config["base_class"] == "roberta":
        return transformers.RobertaTokenizer.from_pretrained(weights_name)
    elif config["base_class"] == "albert":
        return transformers.AlbertTokenizer.from_pretrained(weights_name)
    else:
        raise Exception("base_class {0} not supported".format(config["base_class"]))

In [2]:
import json
import os
import string
import unicodedata

def is_whitespace(c):
    if c == " " or c == "\t" or c == "\n" or c == "\r":
        return True
    cat = unicodedata.category(c)
    if cat == "Zs":
        return True
    return False

def is_punctuation(c):
    """Checks whether `chars` is a punctuation character."""
    cp = ord(c)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
        (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
        return True
    cat = unicodedata.category(c)
    if cat.startswith("P") or cat.startswith("S"):
        return True
    return False

def basic_tokenize(doc):
    doc_tokens = []
    char_to_word = []
    word_to_char_start = []
    prev_is_whitespace = True
    prev_is_punc = False
    prev_is_num = False
    for pos, c in enumerate(doc):
        if is_whitespace(c):
            prev_is_whitespace = True
            prev_is_punc = False
        else:
            if prev_is_whitespace or is_punctuation(c) or prev_is_punc or (prev_is_num and not str(c).isnumeric()):
                doc_tokens.append(c)
                word_to_char_start.append(pos)
            else:
                doc_tokens[-1] += c
            prev_is_whitespace = False
            prev_is_punc = is_punctuation(c)
            prev_is_num = str(c).isnumeric()
        char_to_word.append(len(doc_tokens) - 1)

    return doc_tokens, char_to_word, word_to_char_start

In [3]:

def get_schema(tables):
    schema, headers, colTypes, naturalMap = {}, {}, {}, {}
    for table in tables:
        values = [set() for _ in range(len(table["header"]))]
        for row in table["rows"]:
            for i, value in enumerate(row):
                values[i].add(str(value).lower())
        columns = {column: values[i] for i, column in enumerate(table["header"])}

        trans = {"text": "string", "real": "real", "integer": "integer"}
        colTypes[table["id"]] = {col:trans[ty.lower()] for ty, col in zip(table["types"], table["header"])}
        schema[table["id"]] = columns
        naturalMap[table["id"]] = {col: col for col in columns}
        headers[table["id"]] = table["header"]

    return schema, headers, colTypes, naturalMap


# if __name__ == "__main__":
#     data_path = os.path.join("WikiSQL", "data")
#     for phase in ["train", "dev", "test"]:
#         src_file = os.path.join(data_path, phase + ".jsonl")
#         schema_file = os.path.join(data_path, phase + ".tables.jsonl")
#         output_file = os.path.join("data", "wiki" + phase + ".jsonl")
#         schema, headers, colTypes, naturalMap = get_schema(utils.read_jsonl(schema_file))

In [4]:
from pathlib import Path

# data_path = Path().absolute().parent.parent / "data/sqldata/wikisql"
data_path = Path().absolute().parent / "src/data/train"

In [76]:
for table in tables:
    break

In [77]:
values = [set() for _ in range(len(table["header"]))]
for row in table["rows"]:
    for i, value in enumerate(row):
        values[i].add(str(value).lower())
columns = {column: values[i] for i, column in enumerate(table["header"])}

In [80]:
trans = {"text": "string", "real": "real", "integer": "integer"}
colTypes[table["id"]] = {col:trans[ty.lower()] for ty, col in zip(table["types"], table["header"])}

In [81]:
colTypes

{'1-1000181-1': {'State/territory': 'string',
  'Text/background colour': 'string',
  'Format': 'string',
  'Current slogan': 'string',
  'Current series': 'string',
  'Notes': 'string'},
 '1-10006830-1': {'Aircraft': 'string',
  'Description': 'string',
  'Max Gross Weight': 'string',
  'Total disk area': 'string',
  'Max disk Loading': 'string'},
 '1-10007452-3': {'Order Year': 'string',
  'Manufacturer': 'string',
  'Model': 'string',
  'Fleet Series (Quantity)': 'string',
  'Powertrain (Engine/Transmission)': 'string',
  'Fuel Propulsion': 'string'},
 '1-10015132-1': {'Player': 'string',
  'No.': 'string',
  'Nationality': 'string',
  'Position': 'string',
  'Years in Toronto': 'string',
  'School/Club Team': 'string'},
 '1-10015132-14': {'Player': 'string',
  'No.': 'real',
  'Nationality': 'string',
  'Position': 'string',
  'Years in Toronto': 'string',
  'School/Club Team': 'string'},
 '1-10015132-2': {'Player': 'string',
  'No.': 'string',
  'Nationality': 'string',
  'Positio

In [79]:
table["types"]

['INTEGER',
 'TEXT',
 'TEXT',
 'INTEGER',
 'TEXT',
 'TEXT',
 'TEXT',
 'TEXT',
 'TEXT',
 'TEXT',
 'TEXT',
 'TEXT',
 'TEXT',
 'INTEGER']

In [75]:
tables = read_jsonl(schema_file)

In [73]:
src_file = data_path / "train.jsonl"
schema_file = data_path / "train_table.jsonl"
output_file = data_path / "hydra_train.jsonl"
schema, headers, colTypes, naturalMap = get_schema(read_jsonl(schema_file))

KeyError: 'timestamp'

In [37]:
for raw_sample in read_jsonl(src_file):
    break

In [38]:
table_id = raw_sample["table_id"]
sql = raw_sample["sql"]

cur_schema = schema[table_id]
header = headers[table_id]
cond_col_values = {header[cond[0]]: str(cond[2]) for cond in sql["conds"]}

In [46]:
column_meta = []
for col in header:
    if col in cond_col_values:
        column_meta.append((col, colTypes[table_id][col], cond_col_values[col]))
    else:
        detected_val = None
        # for cond_col_val in cond_col_values.values():
        #     if cond_col_val.lower() in cur_schema[col]:
        #         detected_val = cond_col_val
        #         break
        column_meta.append((col, colTypes[table_id][col], detected_val))

In [53]:
cnt = 0

In [54]:
example = SQLExample(
            cnt,
            raw_sample["question"],
            table_id,
            column_meta,
            sql["agg"],
            int(sql["sel"]),
            [(int(cond[0]), cond[1], str(cond[2])) for cond in sql["conds"]])

In [59]:
d = example.dump_to_json()

In [66]:
d = json.loads(d)

In [13]:
cnt = 0
print("processing {0}...".format(src_file))
with open(output_file, "w", encoding="utf8") as f:
    for raw_sample in utils.read_jsonl(src_file):
        table_id = raw_sample["table_id"]
        sql = raw_sample["sql"]

        cur_schema = schema[table_id]
        header = headers[table_id]
        cond_col_values = {header[cond[0]]: str(cond[2]) for cond in sql["conds"]}
        column_meta = []
        for col in header:
            if col in cond_col_values:
                column_meta.append((col, colTypes[table_id][col], cond_col_values[col]))
            else:
                detected_val = None
                # for cond_col_val in cond_col_values.values():
                #     if cond_col_val.lower() in cur_schema[col]:
                #         detected_val = cond_col_val
                #         break
                column_meta.append((col, colTypes[table_id][col], detected_val))

        example = SQLExample(
            cnt,
            raw_sample["question"],
            table_id,
            column_meta,
            sql["agg"],
            int(sql["sel"]),
            [(int(cond[0]), cond[1], str(cond[2])) for cond in sql["conds"]])

        f.write(example.dump_to_json() + "\n")
        cnt += 1

In [14]:
s = data[0].strip()

In [15]:
d = json.loads(s)
keys = ["qid", "question", "table_id", "column_meta", "agg", "select", "conditions", "tokens", "char_to_word", "word_to_char_start", "value_start_end", "valid"]

In [17]:
d = example.dump_to_json()

{'phase': 1,
 'table_id': '1-1000181-1',
 'question': 'Tell me what the notes are for South Australia ',
 'sql': {'sel': 5, 'conds': [[3, 0, 'SOUTH AUSTRALIA']], 'agg': 0}}

In [16]:
[d[k] for k in keys]

KeyError: 'qid'

In [72]:
class SQLExample(object):
    def __init__(self,
                 qid,
                 question,
                 table_id,
                 column_meta,
                 agg=None,
                 select=None,
                 conditions=None,
                 tokens=None,
                 char_to_word=None,
                 word_to_char_start=None,
                 value_start_end=None,
                 valid=True):
        self.qid = qid
        self.question = question
        self.table_id = table_id
        self.column_meta = column_meta
        self.agg = agg
        self.select = select
        self.conditions = conditions
        self.valid = valid
        if tokens is None:
            self.tokens, self.char_to_word, self.word_to_char_start = basic_tokenize(question)
            self.value_start_end = {}
            if conditions is not None and len(conditions) > 0:
                cur_start = None
                for cond in conditions:
                    value = cond[-1]
                    value_tokens, _, _ = basic_tokenize(value)
                    val_len = len(value_tokens)
                    for i in range(len(self.tokens)):
                        if " ".join(self.tokens[i:i+val_len]).lower() != " ".join(value_tokens).lower():
                            continue
                        s = self.word_to_char_start[i]
                        e = len(question) if i + val_len >= len(self.word_to_char_start) else self.word_to_char_start[i + val_len]
                        recovered_answer_text = question[s:e].strip()
                        if value.lower() == recovered_answer_text.lower():
                            cur_start = i
                            break

                    if cur_start is None:
                        self.valid = False
                        print([value, value_tokens, question, self.tokens])
                        # for c in question:
                        #     print((c, ord(c), unicodedata.category(c)))
                        # raise Exception()
                    else:
                        self.value_start_end[value] = (cur_start, cur_start + val_len)
        else:
            self.tokens, self.char_to_word, self.word_to_char_start, self.value_start_end = tokens, char_to_word, word_to_char_start, value_start_end

    @staticmethod
    def load_from_json(s):
        d = json.loads(s)
        keys = ["qid", "question", "table_id", "column_meta", "agg", "select", "conditions", "tokens", "char_to_word", "word_to_char_start", "value_start_end", "valid"]

        return SQLExample(*[d[k] for k in keys])

    def dump_to_json(self):
        d = {}
        d["qid"] = self.qid
        d["question"] = self.question
        d["table_id"] = self.table_id
        d["column_meta"] = self.column_meta
        d["agg"] = self.agg
        d["select"] = self.select
        d["conditions"] = self.conditions
        d["tokens"] = self.tokens
        d["char_to_word"] = self.char_to_word
        d["word_to_char_start"] = self.word_to_char_start
        d["value_start_end"] = self.value_start_end
        d["valid"] = self.valid

        return json.dumps(d)

    def output_SQ(self, return_str=True):
        agg_ops = ['NA', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
        cond_ops = ['=', '>', '<', 'OP']

        agg_text = agg_ops[self.agg]
        select_text = self.column_meta[self.select][0]
        cond_texts = []
        for wc, op, value_text in self.conditions:
            column_text = self.column_meta[wc][0]
            op_text = cond_ops[op]
            cond_texts.append(column_text + op_text + value_text)

        if return_str:
            sq = agg_text + ", " + select_text + ", " + " AND ".join(cond_texts)
        else:
            sq = (agg_text, select_text, set(cond_texts))
        return sq

