# TEXT2SQL

In [1]:
import os
import json
import transformers
import numpy as np
from tqdm.notebook import tqdm
# Lookup table from a (model family, model size) pair to the checkpoint
# identifier that create_base_model / create_tokenizer hand to
# `from_pretrained`. The pair is built from config["base_class"] and
# config["base_name"].
pretrained_weights = {
    ("bert", "base"): "bert-base-uncased",
    ("bert", "large"): "bert-large-uncased-whole-word-masking",
    ("roberta", "base"): "roberta-base",
    ("roberta", "large"): "roberta-large",
    ("albert", "xlarge"): "albert-xlarge-v2"
}

def read_jsonl(jsonl):
    """Lazily yield one decoded JSON value per line of a JSON Lines file.

    Args:
        jsonl: Path of a UTF-8 encoded file holding one JSON document per line.

    Yields:
        The object decoded from each non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # The context manager closes the file once the generator is exhausted or
    # garbage collected, instead of leaking the handle.
    with open(jsonl, encoding="utf8") as handle:
        for line in handle:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            yield json.loads(line)

def read_conf(conf_path):
    """Parse a tab-separated ``key<TAB>value`` configuration file.

    Blank lines and lines whose first character is ``#`` are ignored. All
    values are kept as strings. ``train_data_path`` and ``dev_data_path`` are
    converted to absolute paths.

    Args:
        conf_path: Path of the UTF-8 encoded configuration file.

    Returns:
        dict mapping each key to its (string) value.

    Raises:
        IndexError: If a non-comment line has no tab-separated value.
        KeyError: If ``train_data_path`` or ``dev_data_path`` is missing.
    """
    config = {}
    # Use a context manager so the file handle is always released.
    with open(conf_path, encoding="utf8") as handle:
        for line in handle:
            if line.strip() == "" or line[0] == "#":
                continue
            fields = line.strip().split("\t")
            config[fields[0]] = fields[1]

    for path_key in ("train_data_path", "dev_data_path"):
        config[path_key] = os.path.abspath(config[path_key])

    return config

def create_base_model(config):
    """Load the pretrained encoder selected by the configuration.

    The checkpoint name is looked up in ``pretrained_weights`` using
    ``(config["base_class"], config["base_name"])``; the model class is chosen
    from ``config["base_class"]`` ("bert", "roberta" or "albert").
    """
    family = config["base_class"]
    checkpoint = pretrained_weights[(family, config["base_name"])]

    if family == "bert":
        return transformers.BertModel.from_pretrained(checkpoint)
    if family == "roberta":
        return transformers.RobertaModel.from_pretrained(checkpoint)
    if family == "albert":
        return transformers.AlbertModel.from_pretrained(checkpoint)

    raise Exception("base_class {0} not supported".format(family))

def create_tokenizer(config):
    """Load the pretrained tokenizer matching the configured encoder.

    Uses the same ``pretrained_weights`` lookup as ``create_base_model`` and
    picks the tokenizer class from ``config["base_class"]``.
    """
    family = config["base_class"]
    checkpoint = pretrained_weights[(family, config["base_name"])]

    if family == "bert":
        return transformers.BertTokenizer.from_pretrained(checkpoint)
    if family == "roberta":
        return transformers.RobertaTokenizer.from_pretrained(checkpoint)
    if family == "albert":
        return transformers.AlbertTokenizer.from_pretrained(checkpoint)

    raise Exception("base_class {0} not supported".format(family))

In [2]:
import json
import os
import string
import unicodedata

def is_whitespace(c):
    """Return True if the single character `c` counts as whitespace.

    Space, tab, newline and carriage return are whitespace, as is any
    character in the Unicode "Zs" (space separator) category.
    """
    if c in (" ", "\t", "\n", "\r"):
        return True
    return unicodedata.category(c) == "Zs"

def is_punctuation(c):
    """Return True if the single character `c` is treated as punctuation."""
    code = ord(c)
    # Every ASCII character that is neither a letter nor a digit counts as
    # punctuation, including "^", "$" and "`", which Unicode does not place
    # in a punctuation category.
    in_ascii_symbol_range = (
        33 <= code <= 47
        or 58 <= code <= 64
        or 91 <= code <= 96
        or 123 <= code <= 126
    )
    if in_ascii_symbol_range:
        return True
    # Otherwise accept any Unicode punctuation (P*) or symbol (S*) category.
    return unicodedata.category(c).startswith(("P", "S"))

def basic_tokenize(doc):
    """Split `doc` into word tokens and record character/word alignments.

    A new token starts after whitespace, at every punctuation character,
    directly after a punctuation character, and when a non-numeric character
    follows a numeric one.

    Returns:
        A triple ``(doc_tokens, char_to_word, word_to_char_start)`` where
        ``char_to_word[i]`` is the index of the last token started at or
        before character ``i`` (-1 if none yet) and ``word_to_char_start[j]``
        is the character offset at which token ``j`` begins.
    """
    words = []
    char_to_word = []
    starts = []

    after_space = True
    after_punct = False
    after_num = False

    for offset, ch in enumerate(doc):
        if is_whitespace(ch):
            after_space, after_punct = True, False
            char_to_word.append(len(words) - 1)
            continue

        ch_is_punct = is_punctuation(ch)
        ch_is_num = str(ch).isnumeric()
        opens_token = (
            after_space
            or ch_is_punct
            or after_punct
            or (after_num and not ch_is_num)
        )
        if opens_token:
            words.append(ch)
            starts.append(offset)
        else:
            words[-1] += ch

        after_space = False
        after_punct = ch_is_punct
        after_num = ch_is_num
        char_to_word.append(len(words) - 1)

    return words, char_to_word, starts

In [3]:
def get_schema(tables):
    """Index table definitions by table id.

    Args:
        tables: Iterable of dicts with keys "id", "header", "types", "rows".

    Returns:
        A 4-tuple ``(schema, headers, colTypes, naturalMap)`` of dicts keyed
        by table id:
          * schema: column name -> set of lower-cased cell values (as str)
          * headers: the table's header list
          * colTypes: column name -> "string" / "real" / "integer"
          * naturalMap: column name -> the same column name
    """
    type_names = {"text": "string", "real": "real", "integer": "integer"}
    schema = {}
    headers = {}
    colTypes = {}
    naturalMap = {}

    for table in tables:
        header = table["header"]
        seen_values = [set() for _ in header]
        for row in table["rows"]:
            for idx, cell in enumerate(row):
                seen_values[idx].add(str(cell).lower())
        columns = dict(zip(header, seen_values))

        typed_columns = {
            name: type_names[kind.lower()]
            for kind, name in zip(table["types"], header)
        }
        table_id = table["id"]
        colTypes[table_id] = typed_columns
        schema[table_id] = columns
        naturalMap[table_id] = {name: name for name in columns}
        headers[table_id] = header

    return schema, headers, colTypes, naturalMap

# if __name__ == "__main__":
#     data_path = os.path.join("WikiSQL", "data")
#     for phase in ["train", "dev", "test"]:
#         src_file = os.path.join(data_path, phase + ".jsonl")
#         schema_file = os.path.join(data_path, phase + ".tables.jsonl")
#         output_file = os.path.join("data", "wiki" + phase + ".jsonl")
#         schema, headers, colTypes, naturalMap = get_schema(utils.read_jsonl(schema_file))

In [111]:
class SQLExample(object):
    """A question paired with its structured SQL annotation.

    When `tokens` is not supplied, the question is tokenized with
    `basic_tokenize` and every condition value is located in the question;
    the resulting word spans are stored in `value_start_end` (value text ->
    (start word, end word exclusive)). If any condition value cannot be
    located, `valid` is set to False.
    """

    # Serialized attributes, in constructor-argument order.
    _FIELDS = ("qid", "question", "table_id", "column_meta", "agg", "select",
               "conditions", "tokens", "char_to_word", "word_to_char_start",
               "value_start_end", "valid")

    def __init__(self,
                 qid,
                 question,
                 table_id,
                 column_meta,
                 agg=None,
                 select=None,
                 conditions=None,
                 tokens=None,
                 char_to_word=None,
                 word_to_char_start=None,
                 value_start_end=None,
                 valid=True):
        self.keys = list(self._FIELDS)
        self.qid = qid
        self.question = question
        self.table_id = table_id
        self.column_meta = column_meta
        self.agg = agg
        self.select = select
        self.conditions = conditions
        self.valid = valid

        if tokens is not None:
            # Already preprocessed (e.g. loaded from JSON): reuse as-is.
            self.tokens = tokens
            self.char_to_word = char_to_word
            self.word_to_char_start = word_to_char_start
            self.value_start_end = value_start_end
            return

        self.tokens, self.char_to_word, self.word_to_char_start = basic_tokenize(question)
        self.value_start_end = {}
        for cond in conditions or []:
            value = cond[-1]
            value_tokens, _, _ = basic_tokenize(value)
            # The start index is computed per condition, so a value that is
            # missing from the question can never inherit the span found for
            # a previous condition.
            start = self._find_value_start(
                question, self.tokens, self.word_to_char_start, value, value_tokens)
            if start is None:
                self.valid = False
            else:
                self.value_start_end[value] = (start, start + len(value_tokens))

    @staticmethod
    def _find_value_start(question, tokens, word_to_char_start, value, value_tokens):
        """Return the first word index where `value` occurs in the question, or None.

        A match requires both the token sequence and the raw question text
        covered by those tokens to equal the value, ignoring case.
        """
        val_len = len(value_tokens)
        target = " ".join(value_tokens).lower()
        for i in range(len(tokens)):
            if " ".join(tokens[i:i + val_len]).lower() != target:
                continue
            s = word_to_char_start[i]
            if i + val_len >= len(word_to_char_start):
                e = len(question)
            else:
                e = word_to_char_start[i + val_len]
            if value.lower() == question[s:e].strip().lower():
                return i
        return None

    def __str__(self):
        """Render every serialized attribute as a ``name: value`` line."""
        return "".join(f"{k}: {self.__dict__[k]}\n" for k in self.keys)

    @staticmethod
    def load_from_json(s):
        """Rebuild an example from a string produced by `dump_to_json`."""
        d = json.loads(s)
        return SQLExample(*[d[k] for k in SQLExample._FIELDS])

    def dump_to_json(self):
        """Serialize all attributes listed in `_FIELDS` to a JSON string."""
        return json.dumps({k: getattr(self, k) for k in self._FIELDS})

    def output_SQ(self, return_str=True):
        """Return the annotated query in a readable form.

        With `return_str` True the result is
        ``"<AGG>, <select column>, <cond> AND <cond> ..."``; otherwise it is
        the tuple ``(agg text, select column, set of condition texts)``.
        """
        agg_ops = ['NA', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
        cond_ops = ['=', '>', '<', 'OP']

        agg_text = agg_ops[self.agg]
        select_text = self.column_meta[self.select][0]
        cond_texts = [
            self.column_meta[wc][0] + cond_ops[op] + value_text
            for wc, op, value_text in self.conditions
        ]

        if return_str:
            return agg_text + ", " + select_text + ", " + " AND ".join(cond_texts)
        return (agg_text, select_text, set(cond_texts))



In [4]:
from pathlib import Path

# Locate the training split under <parent of cwd>/src/data/train and index
# its table definitions. `output_file` is where the converted examples for
# this split are meant to be written.
data_path = Path().absolute().parent / "src" / "data"
src_file = data_path / "train" / "train.jsonl"
schema_file = data_path / "train" / "train.tables.jsonl"
output_file = data_path / "train" / "hydra_train.jsonl"
schema, headers, colTypes, naturalMap = get_schema(read_jsonl(schema_file))

In [5]:
# Count the records in the source file without leaking the file handle or
# loading every line into memory at once.
with src_file.open("r", encoding="utf-8") as line_source:
    n_lines = sum(1 for _ in line_source)
n_lines

463238

In [6]:
# Count the records in the test split without leaking the file handle or
# loading every line into memory at once.
with (data_path / "test" / "test.jsonl").open("r", encoding="utf-8") as line_source:
    n_lines = sum(1 for _ in line_source)
n_lines

82865

In [116]:
# WikiSQL: point the conversion at the training split stored in
# <grandparent of cwd>/wikisqldata and index its table definitions.
data_path = Path().absolute().parent.parent / "wikisqldata"
src_file = data_path / "train.jsonl"
schema_file = data_path / "train.tables.jsonl"
output_file = data_path / "hydra_train.jsonl"
schema, headers, colTypes, naturalMap = get_schema(read_jsonl(schema_file))

In [117]:
# Convert every raw sample of `src_file` into an SQLExample and write it to
# `output_file`, one JSON document per line. `cnt` doubles as the example id.
cnt = 0
print("processing {0}...".format(src_file))

with open(output_file, "w", encoding="utf-8") as f:
    # Iterating the tqdm wrapper already advances the progress bar, so no
    # manual loader.update() is issued (it would count each record twice).
    loader = tqdm(read_jsonl(src_file))
    for raw_sample in loader:
        table_id = raw_sample["table_id"]
        sql = raw_sample["sql"]

        header = headers[table_id]
        col_types = colTypes[table_id]
        # Column name -> condition value (as text) for columns used in WHERE.
        cond_col_values = {header[cond[0]]: str(cond[2]) for cond in sql["conds"]}
        # One (name, type, condition value or None) triple per table column.
        column_meta = [
            (col, col_types[col], cond_col_values.get(col)) for col in header
        ]

        example = SQLExample(
            cnt,
            raw_sample["question"],
            table_id,
            column_meta,
            sql["agg"],
            int(sql["sel"]),
            [(int(cond[0]), cond[1], str(cond[2])) for cond in sql["conds"]])

        f.write(example.dump_to_json() + "\n")
        cnt += 1

processing C:\Users\simon\Desktop\Codes\wikisqldata\train.jsonl...


0it [00:00, ?it/s]

In [33]:
# Point the conversion at the test split under <parent of cwd>/src/data/test
# and index its table definitions.
data_path = Path().absolute().parent / "src/data/test"
src_file = data_path / "test.jsonl"
schema_file = data_path / "test.tables.jsonl"
output_file = data_path / "hydra_test.jsonl"
schema, headers, colTypes, naturalMap = get_schema(read_jsonl(schema_file))

In [118]:
# WikiSQL: point the conversion at the dev split stored in
# <grandparent of cwd>/wikisqldata and index its table definitions.
data_path = Path().absolute().parent.parent / "wikisqldata"
src_file = data_path / "dev.jsonl"
schema_file = data_path / "dev.tables.jsonl"
output_file = data_path / "hydra_dev.jsonl"
schema, headers, colTypes, naturalMap = get_schema(read_jsonl(schema_file))

In [119]:
# Convert every raw sample of `src_file` into an SQLExample and write it to
# `output_file`, one JSON document per line. `cnt` doubles as the example id.
cnt = 0
print("processing {0}...".format(src_file))

with open(output_file, "w", encoding="utf-8") as f:
    # Iterating the tqdm wrapper already advances the progress bar, so no
    # manual loader.update() is issued (it would count each record twice).
    loader = tqdm(read_jsonl(src_file))
    for raw_sample in loader:
        table_id = raw_sample["table_id"]
        sql = raw_sample["sql"]

        header = headers[table_id]
        col_types = colTypes[table_id]
        # Column name -> condition value (as text) for columns used in WHERE.
        cond_col_values = {header[cond[0]]: str(cond[2]) for cond in sql["conds"]}
        # One (name, type, condition value or None) triple per table column.
        column_meta = [
            (col, col_types[col], cond_col_values.get(col)) for col in header
        ]

        example = SQLExample(
            cnt,
            raw_sample["question"],
            table_id,
            column_meta,
            sql["agg"],
            int(sql["sel"]),
            [(int(cond[0]), cond[1], str(cond[2])) for cond in sql["conds"]])

        f.write(example.dump_to_json() + "\n")
        cnt += 1

processing C:\Users\simon\Desktop\Codes\wikisqldata\dev.jsonl...


0it [00:00, ?it/s]

# Data loader

In [231]:
class InputFeature(object):
    """Model-ready encoding of one question, with one entry per table column.

    The constructor stores the encoder inputs; the label attributes
    (`columns`, `agg`, `select`, `where_num`, `where`, `op`, `value_start`,
    `value_end`) start as None and are assigned afterwards by the caller.
    """

    def __init__(self,
                 question,
                 table_id,
                 tokens,
                 word_to_char_start,
                 word_to_subword,
                 subword_to_word,
                 input_ids,
                 input_mask,
                 segment_ids):
        # Encoder inputs and alignment tables.
        self.question = question
        self.table_id = table_id
        self.tokens = tokens
        self.word_to_char_start = word_to_char_start
        self.word_to_subword = word_to_subword
        self.subword_to_word = subword_to_word
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids

        # Label slots, empty until filled in.
        self.columns = self.agg = self.select = self.where_num = None
        self.where = self.op = self.value_start = self.value_end = None

        # Attribute names rendered by __str__, in display order.
        self.keys = [
            "question", "table_id", "tokens", "word_to_char_start", "word_to_subword", "subword_to_word", "input_ids", "input_mask", "segment_ids",
            "columns", "agg", "select", "where_num", "where", "op", "value_start", "value_end"
        ]

    def __str__(self):
        """Render each attribute named in `keys` as a ``name: value`` line."""
        lines = [f"{name}: {self.__dict__[name]}\n" for name in self.keys]
        return "".join(lines)

    def output_SQ(self, agg = None, sel = None, conditions = None, return_str=True):
        """Render a query in readable form.

        If `agg`, `sel` and `conditions` are all None they are taken from the
        stored labels: the select column is the argmax of `self.select`, and
        one condition ``(column, op, value_start, value_end)`` is produced for
        every column whose `where` entry is non-zero. Condition values are cut
        from the question text via the subword -> word -> character tables.

        Returns a string ``"<AGG>, <select column>, <cond> AND ..."`` when
        `return_str` is True, else ``(agg text, select column, set of conds)``.
        """
        aggregate_names = ['NA', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
        operator_names = ['=', '>', '<', 'OP']

        use_stored_labels = agg is None and sel is None and conditions is None
        if use_stored_labels:
            sel = np.argmax(self.select)
            agg = self.agg[sel]
            conditions = [
                (col, self.op[col], self.value_start[col], self.value_end[col])
                for col, flag in enumerate(self.where)
                if flag != 0
            ]

        agg_text = aggregate_names[agg]
        select_text = self.columns[sel]

        rendered = []
        for wc, op, vs, ve in conditions:
            column_text = self.columns[wc]
            op_text = operator_names[op]
            word_map = self.subword_to_word[wc]
            first_word, last_word = word_map[vs], word_map[ve]
            begin = self.word_to_char_start[first_word]
            # The span runs up to the start of the following word, or to the
            # end of the question when the value ends on the last word.
            if last_word + 1 >= len(self.word_to_char_start):
                finish = len(self.question)
            else:
                finish = self.word_to_char_start[last_word + 1]
            rendered.append(column_text + op_text + self.question[begin:finish].rstrip())

        if return_str:
            return agg_text + ", " + select_text + ", " + " AND ".join(rendered)
        return (agg_text, select_text, set(rendered))

In [232]:
# Experiment configuration for the notebook. `base_class` / `base_name`
# select the pretrained checkpoint via `pretrained_weights`;
# `max_total_length` is the padded/truncated sequence length used when
# encoding a column together with the question.
# NOTE(review): the remaining entries look like training hyper-parameters and
# output sizes; they are not read anywhere in the cells shown here.
config = {
    "model_type":"pytorch",

    "SAVE":1,
    "train_data_path":"data/wikitrain.jsonl",
    "dev_data_path":"data/wikidev.jsonl",
    "test_data_path":"data/wikitest.jsonl",

    "base_class":"bert",
    "base_name":"base",
    "max_total_length":96,
    "where_column_num":4,
    "op_num":4,
    "agg_num":6,

    "drop_rate":0.2,
    "learning_rate":3e-5,
    "decay":0.01,
    "epochs":5,
    "batch_size":256,
    "num_warmup_steps":400,
}
# Instantiate the encoder and the matching tokenizer; both become notebook
# globals (`tokenizer` is read by later cells).
model = create_base_model(config)
tokenizer = create_tokenizer(config)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# HydraFeaturizer is defined in a later cell of this notebook, so that cell
# has to be executed before this one.
featurizer = HydraFeaturizer(config)

In [233]:
# Debug cell: featurize only the first usable training example.
# get_input_feature / fill_label_feature are defined in later cells and must
# have been executed already.
data_path = Path().absolute().parent.parent / "wikisqldata"
train_data_path = data_path / "hydra_train.jsonl"
include_label = True
model_inputs = {k: [] for k in ["input_ids", "input_mask", "segment_ids"]}
if include_label:
    for k in ["agg", "select", "where_num", "where", "op", "value_start", "value_end"]:
        model_inputs[k] = []

pos = []
input_features = []
# The context manager releases the file handle even though the loop exits
# early.
with open(train_data_path, encoding="utf8") as train_file:
    for line in train_file:
        example = SQLExample.load_from_json(line)
        if include_label and not example.valid:
            # Labels cannot be built when a condition value was not located.
            continue
        input_feature = get_input_feature(example, config)
        if include_label:
            success = fill_label_feature(example, input_feature, config)
            if not success:
                continue

        input_features.append(input_feature)
        # Stop after the first example that made it through.
        break

In [234]:
# Show the last SQLExample loaded by the previous cell (uses SQLExample.__str__).
print(example)

qid: 0
question: Tell me what the notes are for South Australia 
table_id: 1-1000181-1
column_meta: [['State/territory', 'string', None], ['Text/background colour', 'string', None], ['Format', 'string', None], ['Current slogan', 'string', 'SOUTH AUSTRALIA'], ['Current series', 'string', None], ['Notes', 'string', None]]
agg: 0
select: 5
conditions: [[3, 0, 'SOUTH AUSTRALIA']]
tokens: ['Tell', 'me', 'what', 'the', 'notes', 'are', 'for', 'South', 'Australia']
char_to_word: [0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]
word_to_char_start: [0, 5, 8, 13, 17, 23, 27, 31, 37]
value_start_end: {'SOUTH AUSTRALIA': [7, 9]}
valid: True



In [None]:
# Show the most recently built InputFeature (uses InputFeature.__str__).
print(input_feature)

In [184]:
# Exploratory step: start an empty InputFeature for `example`. The list
# arguments are filled per column by the next cell.
max_total_length = config["max_total_length"]
input_feature = InputFeature(
    example.question,
    example.table_id,
    [],  # tokens
    example.word_to_char_start,
    [],  # word_to_subword
    [],  # subword_to_word
    [],  # input_ids
    [],  # input_mask
    []   # segment_ids
)

In [185]:
# Exploratory step: build one encoder input per table column by pairing
# "<column type> <column name>" with the question, and append the results to
# `input_feature`.
for column, col_type, _ in example.column_meta:
    # get query tokens
    tokens = []
    word_to_subword = []  # per kept word: (first subword, last subword + 1)
    subword_to_word = []  # per subword: index of the word it came from
    for i, query_token in enumerate(example.tokens):
        sub_tokens = tokenizer.tokenize(query_token)
        cur_pos = len(tokens)
        # Words that produce no subwords get no word_to_subword entry.
        if len(sub_tokens) > 0:
            word_to_subword += [(cur_pos, cur_pos + len(sub_tokens))]
            tokens.extend(sub_tokens)
            subword_to_word.extend([i] * len(sub_tokens))
    # NOTE(review): the question is passed as a list of subword strings in the
    # second (pair) position; confirm the tokenizer encodes it as intended.
    tokenize_result = tokenizer(
        col_type + " " + column,
        tokens,
        max_length=max_total_length,
        truncation="longest_first",
        padding="max_length"
    )
    input_ids = tokenize_result["input_ids"]
    input_mask = tokenize_result["attention_mask"]
    
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    # Length of the column part: everything up to and including the first
    # separator token.
    column_token_length = 0
    for i, token_id in enumerate(input_ids):
        if token_id == tokenizer.sep_token_id:
            column_token_length = i + 1
            break
    # Segment id 1 for every unpadded position after the column part.
    segment_ids = [0] * max_total_length
    for i in range(column_token_length, max_total_length):
        if input_mask[i] == 0:
            break
        segment_ids[i] = 1
    
    # Sanity output: decode only the positions marked as segment 1.
    print(tokenizer.decode(np.array(input_ids)*np.array(segment_ids), skip_special_tokens=True))
    
    # Shift the alignment tables so they index into the full input sequence.
    # NOTE(review): if truncation shortened the question, these tables may
    # refer to positions beyond the encoded sequence; verify.
    subword_to_word = [0] * column_token_length + subword_to_word
    word_to_subword = [(pos[0]+column_token_length, pos[1]+column_token_length) for pos in word_to_subword]
    
    input_feature.tokens.append(tokens)
    input_feature.word_to_subword.append(word_to_subword)
    input_feature.subword_to_word.append(subword_to_word)
    input_feature.input_ids.append(input_ids)
    input_feature.input_mask.append(input_mask)
    input_feature.segment_ids.append(segment_ids)

tell me what the notes are for south australia
tell me what the notes are for south australia
tell me what the notes are for south australia
tell me what the notes are for south australia
tell me what the notes are for south australia
tell me what the notes are for south australia


```
C:\Users\simon\miniconda3\envs\venv\lib\site-packages\transformers\tokenization_utils_base.py:2129: FutureWarning: The `truncation_strategy` argument is deprecated and will be removed in a future version, use `truncation=True` to truncate examples to a max length. You can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the maximal input size of the model (e.g. 512 for Bert).  If you have pairs of inputs, you can give a specific truncation strategy selected among `truncation='only_first'` (will only truncate the first sentence in the pairs) `truncation='only_second'` (will only truncate the second sentence in the pairs) or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).
  warnings.warn(
```

# Korean Test

In [236]:
from KoBertTokenizer import KoBertTokenizer
from transformers import BertConfig, BertModel
from transformers import ElectraTokenizer, ElectraModel

In [237]:
def get_model_tokenizer(model_type: str, special_tokens=None):
    """Load a Korean pretrained model together with its tokenizer.

    Args:
        model_type: "kobert" or "koelectra".
        special_tokens: Passed to the tokenizer as
            `additional_special_tokens`.

    Returns:
        ``(model, tokenizer)``; the model's token embeddings are resized to
        the tokenizer's length.

    Raises:
        KeyError: If `model_type` is not one of the supported names.
    """
    checkpoints = {
        "kobert": "monologg/kobert",
        "koelectra": "monologg/koelectra-base-v3-discriminator",
    }
    classes = {
        "kobert": (KoBertTokenizer, BertModel),
        "koelectra": (ElectraTokenizer, ElectraModel),
    }
    checkpoint = checkpoints[model_type]
    tokenizer_cls, model_cls = classes[model_type]

    tokenizer = tokenizer_cls.from_pretrained(
        checkpoint,
        add_special_tokens=True,
        additional_special_tokens=special_tokens,
    )
    model = model_cls.from_pretrained(checkpoint)
    # Keep the embedding matrix in step with any tokens added above.
    model.resize_token_embeddings(len(tokenizer))

    return model, tokenizer

In [238]:
def get_input_feature(example: SQLExample, config):
    """Encode `example` into one model input per table column.

    For every column, "<column type> <column name>" is paired with the
    sub-tokenized question and encoded to `config["max_total_length"]`
    positions using the notebook-level `tokenizer`.

    Args:
        example: The example to encode.
        config: Mapping providing "max_total_length".

    Returns:
        An InputFeature whose list attributes (tokens, word_to_subword,
        subword_to_word, input_ids, input_mask, segment_ids) hold exactly one
        entry per column of `example.column_meta`, in column order.
    """
    max_total_length = int(config["max_total_length"])

    input_feature = InputFeature(
        example.question,
        example.table_id,
        [],
        example.word_to_char_start,
        [],
        [],
        [],
        [],
        []
    )

    # A single pass over the columns. (Nesting this loop inside an identical
    # loop would append every column's features once per column.)
    for column, col_type, _ in example.column_meta:
        # get query tokens
        tokens = []
        word_to_subword = []
        subword_to_word = []
        for i, query_token in enumerate(example.tokens):
            sub_tokens = tokenizer.tokenize(query_token)
            cur_pos = len(tokens)
            if len(sub_tokens) > 0:
                word_to_subword += [(cur_pos, cur_pos + len(sub_tokens))]
                tokens.extend(sub_tokens)
                subword_to_word.extend([i] * len(sub_tokens))
        tokenize_result = tokenizer(
            col_type + " " + column,
            tokens,
            max_length=max_total_length,
            truncation="longest_first",
            padding="max_length"
        )
        input_ids = tokenize_result["input_ids"]
        input_mask = tokenize_result["attention_mask"]

        tokens = tokenizer.convert_ids_to_tokens(input_ids)
        # The column part ends at (and includes) the first separator token.
        column_token_length = 0
        for i, token_id in enumerate(input_ids):
            if token_id == tokenizer.sep_token_id:
                column_token_length = i + 1
                break
        # Segment id 1 for every unpadded position after the column part.
        segment_ids = [0] * max_total_length
        for i in range(column_token_length, max_total_length):
            if input_mask[i] == 0:
                break
            segment_ids[i] = 1

        # Shift the alignment tables so they index into the full sequence.
        subword_to_word = [0] * column_token_length + subword_to_word
        word_to_subword = [(pos[0]+column_token_length, pos[1]+column_token_length) for pos in word_to_subword]

        assert len(input_ids) == max_total_length
        assert len(input_mask) == max_total_length
        assert len(segment_ids) == max_total_length

        input_feature.tokens.append(tokens)
        input_feature.word_to_subword.append(word_to_subword)
        input_feature.subword_to_word.append(subword_to_word)
        input_feature.input_ids.append(input_ids)
        input_feature.input_mask.append(input_mask)
        input_feature.segment_ids.append(segment_ids)

    return input_feature

def fill_label_feature(example: "SQLExample", input_feature: "InputFeature", config):
    """Write the training labels of `example` onto `input_feature`.

    Sets `columns`, `agg`, `select`, `where_num`, `where`, `op`,
    `value_start` and `value_end`, each with one entry per column. For a
    column that carries a condition value, `value_start` / `value_end` are the
    first and last subword positions of that value in the column's encoded
    sequence.

    Args:
        example: Source of the annotations.
        input_feature: Feature object to fill; mutated in place.
        config: Mapping providing "max_total_length".

    Returns:
        True on success. False (after printing a message) when a value span
        cannot be mapped to unpadded positions inside the encoded sequence.
    """
    max_total_length = int(config["max_total_length"])

    columns = [c[0] for c in example.column_meta]
    col_num = len(columns)
    input_feature.columns = columns

    # Only the selected column carries the aggregation label.
    input_feature.agg = [0] * col_num
    input_feature.agg[example.select] = example.agg
    input_feature.where_num = [len(example.conditions)] * col_num

    input_feature.select = [0] * col_num
    input_feature.select[example.select] = 1

    input_feature.where = [0] * col_num
    input_feature.op = [0] * col_num
    input_feature.value_start = [0] * col_num
    input_feature.value_end = [0] * col_num

    for colidx, op, _ in example.conditions:
        input_feature.where[colidx] = 1
        input_feature.op[colidx] = op

    for colidx, column_meta in enumerate(example.column_meta):
        value = column_meta[-1]
        if value is None:
            continue
        se = example.value_start_end[value]
        try:
            spans = input_feature.word_to_subword[colidx]
            mask = input_feature.input_mask[colidx]
            s = spans[se[0]][0]
            e = spans[se[1] - 1][1] - 1
            # Explicit checks instead of `assert`, which is stripped under -O.
            in_range = (
                s < max_total_length and mask[s] == 1
                and e < max_total_length and mask[e] == 1
            )
        except IndexError:
            # The word span or subword position does not exist in the
            # encoded sequence.
            in_range = False

        if not in_range:
            print("value span is out of range")
            return False

        input_feature.value_start[colidx] = s
        input_feature.value_end[colidx] = e

    return True

In [239]:
def get_model_tokenizer(model_type: str, special_tokens=None):
    """Load a Korean pretrained model together with its tokenizer.

    Args:
        model_type: "kobert" or "koelectra".
        special_tokens: Passed to the tokenizer as
            `additional_special_tokens`.

    Returns:
        ``(model, tokenizer)``; the model's token embeddings are resized to
        the tokenizer's length.

    Raises:
        KeyError: If `model_type` is not one of the supported names.
    """
    checkpoints = {
        "kobert": "monologg/kobert",
        "koelectra": "monologg/koelectra-base-v3-discriminator",
    }
    classes = {
        "kobert": (KoBertTokenizer, BertModel),
        "koelectra": (ElectraTokenizer, ElectraModel),
    }
    checkpoint = checkpoints[model_type]
    tokenizer_cls, model_cls = classes[model_type]

    tokenizer = tokenizer_cls.from_pretrained(
        checkpoint,
        add_special_tokens=True,
        additional_special_tokens=special_tokens,
    )
    model = model_cls.from_pretrained(checkpoint)
    # Keep the embedding matrix in step with any tokens added above.
    model.resize_token_embeddings(len(tokenizer))

    return model, tokenizer

# Settings for the Korean experiment: which pretrained model to load and
# where the preprocessed training examples live.
model_type = "kobert"
device = "cpu" # "cuda" if torch.cuda.is_available() else "cpu" 
phase = "train"
data_path = Path().absolute().parent / "src" / "data"
train_data_path = data_path / phase / "hydra_train.jsonl"

# Register every company code (one per line in company_codes.txt) as an
# additional special token of the tokenizer.
with (data_path / "company_codes.txt").open("r", encoding="utf-8") as file:
    company_codes = [line.strip() for line in file.readlines()]

# Rebinds the notebook-level `model` and `tokenizer` used by later cells.
model, tokenizer = get_model_tokenizer(model_type=model_type, special_tokens=company_codes)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [251]:
# Featurize every usable training example and stack the per-column rows into
# flat arrays. `pos[i]` is the (start, end) row range of the i-th kept example.
include_label = True
input_keys = ["input_ids", "input_mask", "segment_ids"]
label_keys = ["agg", "select", "where_num", "where", "op", "value_start", "value_end"]

model_inputs = {k: [] for k in input_keys}
if include_label:
    for k in label_keys:
        model_inputs[k] = []

pos = []
input_features = []
# The context manager closes the file even if the loop is interrupted.
with open(train_data_path, encoding="utf-8") as train_file:
    # Iterating the tqdm wrapper already advances the progress bar, so no
    # manual it.update() is issued (it would count kept lines twice).
    it = tqdm(train_file)
    for line in it:
        example = SQLExample.load_from_json(line)
        if include_label and not example.valid:
            continue
        input_feature = get_input_feature(example, config)
        if include_label:
            success = fill_label_feature(example, input_feature, config)
            if not success:
                continue

        input_features.append(input_feature)

        cur_start = len(model_inputs["input_ids"])
        cur_sample_num = len(input_feature.input_ids)
        pos.append((cur_start, cur_start + cur_sample_num))

        for k in input_keys:
            model_inputs[k].extend(getattr(input_feature, k))
        if include_label:
            for k in label_keys:
                model_inputs[k].extend(getattr(input_feature, k))

for k in model_inputs:
    model_inputs[k] = np.array(model_inputs[k], dtype=np.int64)

0it [00:00, ?it/s]

KeyboardInterrupt: 

In [247]:
import pickle 

# Persist the stacked model inputs so featurization need not be repeated.
# NOTE(review): pickle files must only be loaded from trusted sources.
with (data_path / phase / "hydra_train_preprocessed.pickle").open("wb") as file:
    pickle.dump(model_inputs, file)

<__main__.InputFeature at 0x1c4ee37bc40>

In [265]:
class HydraFeaturizer(object):
    """Convert SQLExample objects into per-column transformer inputs and labels.

    For every table column of an example one (column, question) pair is encoded
    and padded to ``max_total_length``; the label vectors hold one entry per
    column.
    """

    def __init__(self, config):
        """Keep the configuration and build the tokenizer it selects."""
        self.config = config
        self.tokenizer = create_tokenizer(config)
        # NOTE(review): this mapping is not read anywhere in this class; the
        # raw column type string is what gets encoded in get_input_feature.
        self.colType2token = {
            "string": "[unused1]",
            "real": "[unused2]"}

    def _tokenize_question(self, example):
        """Split the question words into sub-tokens and build both alignments.

        Returns:
            A tuple ``(tokens, word_to_subword, subword_to_word)`` where
            ``word_to_subword`` holds one ``(start, end)`` sub-token range per
            word that produced at least one sub-token and ``subword_to_word``
            holds the word index of every sub-token.
        """
        is_roberta = self.config["base_class"] == "roberta"
        tokens = []
        word_to_subword = []
        subword_to_word = []
        for i, query_token in enumerate(example.tokens):
            if is_roberta:
                sub_tokens = self.tokenizer.tokenize(query_token, add_prefix_space=True)
            else:
                sub_tokens = self.tokenizer.tokenize(query_token)
            # NOTE(review): words without sub-tokens get no word_to_subword
            # entry, so later entries no longer line up with word indices.
            if not sub_tokens:
                continue
            cur_pos = len(tokens)
            word_to_subword.append((cur_pos, cur_pos + len(sub_tokens)))
            tokens.extend(sub_tokens)
            subword_to_word.extend([i] * len(sub_tokens))
        return tokens, word_to_subword, subword_to_word

    def get_input_feature(self, example: SQLExample, config):
        """Encode one (column, question) pair per table column of ``example``.

        Args:
            example: Parsed example; its ``question``, ``table_id``,
                ``word_to_char_start``, ``tokens`` and ``column_meta`` are read.
            config: Mapping providing ``max_total_length``.

        Returns:
            An InputFeature whose list attributes hold one entry per column.
        """
        max_total_length = int(config["max_total_length"])

        # NOTE(review): nine positional arguments, i.e. an InputFeature
        # signature without `qid`; the InputFeature defined later in this
        # notebook takes `qid` first.
        input_feature = InputFeature(
            example.question,
            example.table_id,
            [],
            example.word_to_char_start,
            [],
            [],
            [],
            [],
            []
        )

        # The question tokenization does not depend on the column, so it is
        # computed once instead of once per column.
        question_tokens, question_w2s, question_s2w = self._tokenize_question(example)
        is_roberta = self.config["base_class"] == "roberta"

        for column, col_type, _ in example.column_meta:
            if is_roberta:
                tokenize_result = self.tokenizer.encode_plus(
                    col_type + " " + column,
                    question_tokens,
                    padding="max_length",
                    max_length=max_total_length,
                    truncation=True,
                    add_prefix_space=True
                )
            else:
                # `truncation` / `padding` replace the deprecated
                # `truncation_strategy` / `pad_to_max_length` arguments that
                # triggered the truncation warning in the recorded output.
                tokenize_result = self.tokenizer.encode_plus(
                    col_type + " " + column,
                    question_tokens,
                    max_length=max_total_length,
                    truncation="longest_first",
                    padding="max_length",
                )

            input_ids = tokenize_result["input_ids"]
            input_mask = tokenize_result["attention_mask"]

            tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
            # Everything up to and including the first separator token is
            # treated as the column part of the sequence.
            column_token_length = 0
            for i, token_id in enumerate(input_ids):
                if token_id == self.tokenizer.sep_token_id:
                    column_token_length = i + 1
                    break
            # Segment 1 covers the attended positions after the column part.
            segment_ids = [0] * max_total_length
            for i in range(column_token_length, max_total_length):
                if input_mask[i] == 0:
                    break
                segment_ids[i] = 1

            # Shift the question alignments by the length of the column part.
            subword_to_word = [0] * column_token_length + question_s2w
            word_to_subword = [
                (start + column_token_length, end + column_token_length)
                for start, end in question_w2s
            ]

            assert len(input_ids) == max_total_length
            assert len(input_mask) == max_total_length
            assert len(segment_ids) == max_total_length

            input_feature.tokens.append(tokens)
            input_feature.word_to_subword.append(word_to_subword)
            input_feature.subword_to_word.append(subword_to_word)
            input_feature.input_ids.append(input_ids)
            input_feature.input_mask.append(input_mask)
            input_feature.segment_ids.append(segment_ids)

        return input_feature

    def fill_label_feature(self, example: SQLExample, input_feature: InputFeature, config):
        """Write the per-column label vectors of ``example`` onto ``input_feature``.

        Args:
            example: Parsed example providing ``column_meta``, ``select``,
                ``agg``, ``conditions`` and ``value_start_end``.
            input_feature: Feature returned by get_input_feature for the same
                example; it is modified in place.
            config: Mapping providing ``max_total_length``.

        Returns:
            True when every condition value span maps to attended sub-token
            positions; False (after printing a message) otherwise.
        """
        max_total_length = int(config["max_total_length"])

        columns = [c[0] for c in example.column_meta]
        col_num = len(columns)
        input_feature.columns = columns

        input_feature.agg = [0] * col_num
        input_feature.agg[example.select] = example.agg
        input_feature.where_num = [len(example.conditions)] * col_num

        input_feature.select = [0] * col_num
        input_feature.select[example.select] = 1

        input_feature.where = [0] * col_num
        input_feature.op = [0] * col_num
        input_feature.value_start = [0] * col_num
        input_feature.value_end = [0] * col_num

        for colidx, op, _ in example.conditions:
            input_feature.where[colidx] = 1
            input_feature.op[colidx] = op
        for colidx, column_meta in enumerate(example.column_meta):
            # The last column_meta field is the condition value, if any.
            if column_meta[-1] is None:
                continue
            se = example.value_start_end[column_meta[-1]]
            try:
                s = input_feature.word_to_subword[colidx][se[0]][0]
                input_feature.value_start[colidx] = s
                e = input_feature.word_to_subword[colidx][se[1] - 1][1] - 1
                input_feature.value_end[colidx] = e

                mask = input_feature.input_mask[colidx]
                # Explicit check instead of `assert`, which `python -O` removes.
                in_range = (
                    s < max_total_length and mask[s] == 1
                    and e < max_total_length and mask[e] == 1
                )
            except Exception:
                # Still best effort, but unlike a bare `except:` this lets
                # KeyboardInterrupt / SystemExit propagate.
                in_range = False
            if not in_range:
                print("value span is out of range")
                return False

        # feature_sq = input_feature.output_SQ(return_str=False)
        # example_sq = example.output_SQ(return_str=False)
        # if feature_sq != example_sq:
        #     print(example.qid, feature_sq, example_sq)
        return True

    def load_data(self, data_paths, config, include_label=False):
        """Featurize every example found in one or more JSON-lines files.

        Args:
            data_paths: One path, or several paths joined with ``"|"``.
            config: Mapping forwarded to the featurizing methods; when it
                contains the key ``"DEBUG"`` reading stops shortly after 100
                kept examples per file.
            include_label: When true, label vectors are produced and examples
                that are invalid or whose labels cannot be filled are skipped.

        Returns:
            ``(input_features, model_inputs, pos)``: the kept InputFeature
            objects, a dict of int64 arrays stacking the per-column rows of all
            kept examples, and one ``(start, end)`` row range per kept example.
        """
        model_inputs = {k: [] for k in ["input_ids", "input_mask", "segment_ids"]}
        if include_label:
            for k in ["agg", "select", "where_num", "where", "op", "value_start", "value_end"]:
                model_inputs[k] = []

        pos = []
        input_features = []
        for data_path in data_paths.split("|"):
            cnt = 0
            print(data_path)
            # The context manager closes the file on early exit and on errors.
            with open(data_path, encoding="utf8") as data_file:
                for line in tqdm(data_file):
                    example = SQLExample.load_from_json(line)
                    if not example.valid and include_label:
                        continue

                    input_feature = self.get_input_feature(example, config)
                    if include_label:
                        success = self.fill_label_feature(example, input_feature, config)
                        if not success:
                            continue

                    # sq = input_feature.output_SQ()
                    input_features.append(input_feature)

                    cur_start = len(model_inputs["input_ids"])
                    cur_sample_num = len(input_feature.input_ids)
                    pos.append((cur_start, cur_start + cur_sample_num))

                    model_inputs["input_ids"].extend(input_feature.input_ids)
                    model_inputs["input_mask"].extend(input_feature.input_mask)
                    model_inputs["segment_ids"].extend(input_feature.segment_ids)
                    if include_label:
                        model_inputs["agg"].extend(input_feature.agg)
                        model_inputs["select"].extend(input_feature.select)
                        model_inputs["where_num"].extend(input_feature.where_num)
                        model_inputs["where"].extend(input_feature.where)
                        model_inputs["op"].extend(input_feature.op)
                        model_inputs["value_start"].extend(input_feature.value_start)
                        model_inputs["value_end"].extend(input_feature.value_end)

                    cnt += 1
                    if cnt % 5000 == 0:
                        print(cnt)

                    if "DEBUG" in config and cnt > 100:
                        break

        for k in model_inputs:
            model_inputs[k] = np.array(model_inputs[k], dtype=np.int64)

        return input_features, model_inputs, pos

class SQLDataset(torch.utils.data.Dataset):
    """Dataset exposing the stacked arrays produced by a featurizer row by row."""

    def __init__(self, data_paths, config, featurizer, include_label=False):
        """Load ``data_paths`` through ``featurizer`` and print the array shapes."""
        self.config = config
        self.featurizer = featurizer
        loaded = featurizer.load_data(data_paths, config, include_label)
        self.input_features, self.model_inputs, self.pos = loaded

        print("{0} loaded. Data shapes:".format(data_paths))
        for name in self.model_inputs:
            print(name, self.model_inputs[name].shape)

    def __len__(self):
        """Return the number of rows of the ``input_ids`` array."""
        input_ids = self.model_inputs["input_ids"]
        return input_ids.shape[0]

    def __getitem__(self, idx):
        """Return row ``idx`` of every array, keyed by array name."""
        row = {}
        for name, values in self.model_inputs.items():
            row[name] = values[idx]
        return row

In [272]:
# data_path = Path().absolute().parent.parent / "wikisqldata"
# Configuration for the BERT-based featurizer; `data_path` must already be
# defined by an earlier cell.
train_data_path = data_path / "hydra_train.jsonl"
config = {
    "model_type":"pytorch",

    "SAVE":1,
    # Dataset locations, stored as strings.
    "train_data_path": str(data_path / "hydra_train.jsonl"),
    "dev_data_path": str(data_path / "hydra_dev.jsonl"),
    "test_data_path": str(data_path / "wikitest.jsonl"),

    # create_tokenizer maps ("bert", "base") to "bert-base-uncased".
    "base_class":"bert",
    "base_name":"base",
    # Length every encoded (column, question) pair is padded/truncated to.
    "max_total_length":96,
    # NOTE(review): the keys below are not read in this part of the notebook;
    # presumably model and training hyperparameters - confirm against the trainer.
    "where_column_num":4,
    "op_num":4,
    "agg_num":6,

    "drop_rate":0.2,
    "learning_rate":3e-5,
    "decay":0.01,
    "epochs":5,
    "batch_size":256,
    "num_warmup_steps":400,
}
# model = create_base_model(config)
# tokenizer = create_tokenizer(config)
featurizer = HydraFeaturizer(config)

In [273]:
# Featurize the training file with labels; yields the kept features, the
# stacked int64 arrays and the per-example (start, end) row ranges.
input_features, model_inputs, pos = featurizer.load_data(config["train_data_path"], config, include_label=True)

C:\Users\simon\Desktop\Codes\wikisqldata\hydra_train.jsonl


0it [00:00, ?it/s]

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


5000
10000
value span is out of range
value span is out of range
15000
20000
25000
30000
35000
40000
45000
50000
55000


In [289]:
class TransformerConfig(object):
    """Resolve checkpoint path and tokenizer/model classes for a model type.

    Supported ``config["model_type"]`` values are ``"kobert"`` and
    ``"koelectra"``; any other value raises ``KeyError``.
    """

    def __init__(self, config):
        """Read ``model_type`` and ``special_tkns_path`` from ``config``."""
        model_path_dict = {
            "kobert": "monologg/kobert",
            "koelectra": "monologg/koelectra-base-v3-discriminator"
        }
        # Built here rather than at class level so the tokenizer/model classes
        # are only looked up when an instance is created.
        cls_dict = {
            "kobert": (KoBertTokenizer, BertModel),
            "koelectra": (ElectraTokenizer, ElectraModel)
        }
        self.model_type = config["model_type"]
        self.special_tkns_path = config["special_tkns_path"]
        self.model_path = model_path_dict[self.model_type]
        self.tokenizer_cls, self.model_cls = cls_dict[self.model_type]

    def get_transfomers_model(self, tokenizer):
        """Load the pretrained model and resize its token embeddings to ``len(tokenizer)``."""
        model = self.model_cls.from_pretrained(self.model_path)
        model.resize_token_embeddings(len(tokenizer))
        return model

    def get_transfomers_tokenizer(self):
        """Load the pretrained tokenizer with the extra special tokens registered."""
        special_tkns = self.get_special_tokens()
        tokenizer = self.tokenizer_cls.from_pretrained(
            self.model_path,
            add_special_tokens=True,
            additional_special_tokens=special_tkns
        )
        return tokenizer

    def get_special_tokens(self):
        """Return the tokens listed in ``special_tkns_path`` plus the column-type markers.

        The file is read as UTF-8 with one token per line. Blank lines are
        skipped so that no empty-string token is registered.
        """
        with open(self.special_tkns_path, mode="r", encoding="utf-8") as file:
            stripped_lines = (line.strip() for line in file)
            special_tkns = [tkn for tkn in stripped_lines if tkn]

        special_tkns += ["[STRING]", "[REAL]", "[INT]"]
        return special_tkns

In [290]:
class InputFeature(object):
    """Per-example container for the encoded inputs and the per-column labels."""

    def __init__(self,
                 qid,
                 question,
                 table_id,
                 tokens,
                 word_to_char_start,
                 word_to_subword,
                 subword_to_word,
                 input_ids,
                 input_mask,
                 segment_ids):
        """Store the input fields; every label field starts out as None."""
        self.qid = qid
        self.question = question
        self.table_id = table_id
        self.tokens = tokens
        self.word_to_char_start = word_to_char_start
        self.word_to_subword = word_to_subword
        self.subword_to_word = subword_to_word
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids

        label_fields = [
            "columns", "agg", "select", "where_num", "where", "op", "value_start", "value_end"
        ]
        for field in label_fields:
            setattr(self, field, None)

        # Attributes rendered by __str__, in display order (qid is not listed).
        self.keys = [
            "question", "table_id", "tokens", "word_to_char_start", "word_to_subword",
            "subword_to_word", "input_ids", "input_mask", "segment_ids"
        ] + label_fields

    def __str__(self):
        """Render one ``name: value`` line for every attribute named in ``keys``."""
        return "".join(f"{k}: {self.__dict__[k]}\n" for k in self.keys)

    def output_SQ(self, agg = None, sel = None, conditions = None, return_str=True):
        """Render a query as text, from the stored labels or from explicit parts.

        When ``agg``, ``sel`` and ``conditions`` are all None, they are taken
        from the stored label vectors. Each condition is a tuple
        ``(column index, operator index, start sub-token, end sub-token)`` and
        its value text is cut out of the question.

        Returns:
            ``"AGG, column, cond AND cond"`` when ``return_str`` is true,
            otherwise the tuple ``(agg text, column, set of condition texts)``.
        """
        agg_ops = ['NA', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
        cond_ops = ['=', '>', '<', 'OP']

        use_stored_labels = agg is None and sel is None and conditions is None
        if use_stored_labels:
            sel = np.argmax(self.select)
            agg = self.agg[sel]
            conditions = [
                (col, self.op[col], self.value_start[col], self.value_end[col])
                for col, flagged in enumerate(self.where)
                if flagged != 0
            ]

        agg_text = agg_ops[agg]
        select_text = self.columns[sel]
        cond_texts = []
        for wc, op, vs, ve in conditions:
            column_text = self.columns[wc]
            op_text = cond_ops[op]
            first_word = self.subword_to_word[wc][vs]
            last_word = self.subword_to_word[wc][ve]
            char_start = self.word_to_char_start[first_word]
            # The value runs up to the start of the next word, or to the end of
            # the question when the span ends on the last word.
            if last_word + 1 < len(self.word_to_char_start):
                char_end = self.word_to_char_start[last_word + 1]
            else:
                char_end = len(self.question)
            value_text = self.question[char_start:char_end].rstrip()
            cond_texts.append(column_text + op_text + value_text)

        if not return_str:
            return (agg_text, select_text, set(cond_texts))
        return ", ".join([agg_text, select_text, " AND ".join(cond_texts)])

In [291]:
class HydraFeaturizer(object):
    """Convert SQLExample objects into per-column transformer inputs and labels.

    The tokenizer comes from TransformerConfig. For every table column one
    (column-type token + column name, question) pair is encoded and padded to
    ``max_total_length``; the label vectors hold one entry per column.
    """

    def __init__(self, config):
        """Keep the configuration and build the tokenizer it selects."""
        self.config = config
        self.tokenizer = TransformerConfig(config).get_transfomers_tokenizer()
        # Special token placed in front of a column name for each column type.
        self.colType2token = {
            "string": "[STRING]",
            "real": "[REAL]",
            "integer": "[INT]"
        }

    def _tokenize_question(self, example):
        """Split the question words into sub-tokens and build both alignments.

        Returns:
            A tuple ``(tokens, word_to_subword, subword_to_word)`` where
            ``word_to_subword`` holds one ``(start, end)`` sub-token range per
            word that produced at least one sub-token and ``subword_to_word``
            holds the word index of every sub-token.
        """
        tokens = []
        word_to_subword = []
        subword_to_word = []
        for i, query_token in enumerate(example.tokens):
            sub_tokens = self.tokenizer.tokenize(query_token)
            # NOTE(review): words without sub-tokens get no word_to_subword
            # entry, so later entries no longer line up with word indices.
            if not sub_tokens:
                continue
            cur_pos = len(tokens)
            word_to_subword.append((cur_pos, cur_pos + len(sub_tokens)))
            tokens.extend(sub_tokens)
            subword_to_word.extend([i] * len(sub_tokens))
        return tokens, word_to_subword, subword_to_word

    def get_input_feature(self, example: SQLExample, config):
        """Encode one (column, question) pair per table column of ``example``.

        Args:
            example: Parsed example; its ``qid``, ``question``, ``table_id``,
                ``word_to_char_start``, ``tokens`` and ``column_meta`` are read.
            config: Mapping providing ``max_total_length``.

        Returns:
            An InputFeature whose list attributes hold one entry per column.

        Raises:
            KeyError: If a column type has no entry in ``colType2token``.
        """
        max_total_length = int(config["max_total_length"])

        input_feature = InputFeature(
            example.qid,
            example.question,
            example.table_id,
            [],
            example.word_to_char_start,
            [],
            [],
            [],
            [],
            []
        )

        # The question tokenization does not depend on the column, so it is
        # computed once instead of once per column.
        question_tokens, question_w2s, question_s2w = self._tokenize_question(example)

        for column, col_type, _ in example.column_meta:
            tokenize_result = self.tokenizer(
                self.colType2token[col_type] + " " + column,
                question_tokens,
                max_length=max_total_length,
                truncation="longest_first",
                padding="max_length"
            )

            input_ids = tokenize_result["input_ids"]
            input_mask = tokenize_result["attention_mask"]

            tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
            # Everything up to and including the first separator token is
            # treated as the column part of the sequence.
            column_token_length = 0
            for i, token_id in enumerate(input_ids):
                if token_id == self.tokenizer.sep_token_id:
                    column_token_length = i + 1
                    break
            # Segment 1 covers the attended positions after the column part.
            segment_ids = [0] * max_total_length
            for i in range(column_token_length, max_total_length):
                if input_mask[i] == 0:
                    break
                segment_ids[i] = 1

            # Shift the question alignments by the length of the column part.
            subword_to_word = [0] * column_token_length + question_s2w
            word_to_subword = [
                (start + column_token_length, end + column_token_length)
                for start, end in question_w2s
            ]

            assert len(input_ids) == max_total_length
            assert len(input_mask) == max_total_length
            assert len(segment_ids) == max_total_length

            input_feature.tokens.append(tokens)
            input_feature.word_to_subword.append(word_to_subword)
            input_feature.subword_to_word.append(subword_to_word)
            input_feature.input_ids.append(input_ids)
            input_feature.input_mask.append(input_mask)
            input_feature.segment_ids.append(segment_ids)

        return input_feature

    def fill_label_feature(self, example: SQLExample, input_feature: InputFeature, config):
        """Write the per-column label vectors of ``example`` onto ``input_feature``.

        After the labels are filled, the query rendered from the feature is
        compared with the one rendered from the example; a mismatch is printed
        together with the example id but does not change the return value.

        Args:
            example: Parsed example providing ``column_meta``, ``select``,
                ``agg``, ``conditions``, ``value_start_end``, ``qid`` and
                ``output_SQ``.
            input_feature: Feature returned by get_input_feature for the same
                example; it is modified in place.
            config: Mapping providing ``max_total_length``.

        Returns:
            True when every condition value span maps to attended sub-token
            positions; False (after printing a message) otherwise.
        """
        max_total_length = int(config["max_total_length"])

        columns = [c[0] for c in example.column_meta]
        col_num = len(columns)
        input_feature.columns = columns

        input_feature.agg = [0] * col_num
        input_feature.agg[example.select] = example.agg
        input_feature.where_num = [len(example.conditions)] * col_num

        input_feature.select = [0] * col_num
        input_feature.select[example.select] = 1

        input_feature.where = [0] * col_num
        input_feature.op = [0] * col_num
        input_feature.value_start = [0] * col_num
        input_feature.value_end = [0] * col_num

        for colidx, op, _ in example.conditions:
            input_feature.where[colidx] = 1
            input_feature.op[colidx] = op
        for colidx, column_meta in enumerate(example.column_meta):
            # The last column_meta field is the condition value, if any.
            if column_meta[-1] is None:
                continue
            se = example.value_start_end[column_meta[-1]]
            try:
                s = input_feature.word_to_subword[colidx][se[0]][0]
                input_feature.value_start[colidx] = s
                e = input_feature.word_to_subword[colidx][se[1] - 1][1] - 1
                input_feature.value_end[colidx] = e

                mask = input_feature.input_mask[colidx]
                # Explicit check instead of `assert`, which `python -O` removes.
                in_range = (
                    s < max_total_length and mask[s] == 1
                    and e < max_total_length and mask[e] == 1
                )
            except Exception:
                # Still best effort, but unlike a bare `except:` this lets
                # KeyboardInterrupt / SystemExit propagate.
                in_range = False
            if not in_range:
                print("value span is out of range")
                return False

        # Diagnostic only: report examples whose labels do not round-trip.
        feature_sq = input_feature.output_SQ(return_str=False)
        example_sq = example.output_SQ(return_str=False)
        if feature_sq != example_sq:
            print(example.qid, feature_sq, example_sq)
        return True

    def load_data(self, data_paths, config, include_label=False):
        """Featurize every example found in one or more JSON-lines files.

        Args:
            data_paths: One path, or several paths joined with ``"|"``.
            config: Mapping forwarded to the featurizing methods; when it
                contains the key ``"DEBUG"`` reading stops shortly after 100
                kept examples per file.
            include_label: When true, label vectors are produced and examples
                that are invalid or whose labels cannot be filled are skipped.

        Returns:
            ``(input_features, model_inputs, pos)``: the kept InputFeature
            objects, a dict of int64 arrays stacking the per-column rows of all
            kept examples, and one ``(start, end)`` row range per kept example.
        """
        model_inputs = {k: [] for k in ["input_ids", "input_mask", "segment_ids"]}
        if include_label:
            for k in ["agg", "select", "where_num", "where", "op", "value_start", "value_end"]:
                model_inputs[k] = []

        pos = []
        input_features = []
        for data_path in data_paths.split("|"):
            cnt = 0
            print(data_path)
            # The context manager closes the file on early exit and on errors.
            with open(data_path, encoding="utf8") as data_file:
                for line in tqdm(data_file):
                    example = SQLExample.load_from_json(line)
                    if not example.valid and include_label:
                        continue

                    input_feature = self.get_input_feature(example, config)
                    if include_label:
                        success = self.fill_label_feature(example, input_feature, config)
                        if not success:
                            continue

                    # sq = input_feature.output_SQ()
                    input_features.append(input_feature)

                    cur_start = len(model_inputs["input_ids"])
                    cur_sample_num = len(input_feature.input_ids)
                    pos.append((cur_start, cur_start + cur_sample_num))

                    model_inputs["input_ids"].extend(input_feature.input_ids)
                    model_inputs["input_mask"].extend(input_feature.input_mask)
                    model_inputs["segment_ids"].extend(input_feature.segment_ids)
                    if include_label:
                        model_inputs["agg"].extend(input_feature.agg)
                        model_inputs["select"].extend(input_feature.select)
                        model_inputs["where_num"].extend(input_feature.where_num)
                        model_inputs["where"].extend(input_feature.where)
                        model_inputs["op"].extend(input_feature.op)
                        model_inputs["value_start"].extend(input_feature.value_start)
                        model_inputs["value_end"].extend(input_feature.value_end)

                    cnt += 1
                    if cnt % 5000 == 0:
                        print(cnt)

                    # The 100-example cut-off is a debugging aid; without the
                    # DEBUG gate every run would silently drop most of the data.
                    if "DEBUG" in config and cnt > 100:
                        break

        for k in model_inputs:
            model_inputs[k] = np.array(model_inputs[k], dtype=np.int64)

        return input_features, model_inputs, pos

class SQLDataset(torch.utils.data.Dataset):
    """Row-wise view over the stacked arrays that a featurizer loads from disk."""

    def __init__(self, data_paths, config, featurizer, include_label=False):
        """Run ``featurizer.load_data`` on ``data_paths`` and report the array shapes."""
        self.config = config
        self.featurizer = featurizer
        features, arrays, ranges = featurizer.load_data(data_paths, config, include_label)
        self.input_features = features
        self.model_inputs = arrays
        self.pos = ranges

        print("{0} loaded. Data shapes:".format(data_paths))
        for name in self.model_inputs:
            print(name, self.model_inputs[name].shape)

    def __len__(self):
        """Return the number of rows of the ``input_ids`` array."""
        row_count = self.model_inputs["input_ids"].shape[0]
        return row_count

    def __getitem__(self, idx):
        """Return row ``idx`` of every array, keyed by array name."""
        return dict((name, values[idx]) for name, values in self.model_inputs.items())

In [292]:
# Configure the KoBERT featurizer and featurize the labelled training file.
phase = "train"

# Data directory, resolved relative to the parent of the current working directory.
data_path = Path().absolute().parent / "src" / "data"
config = {
    # Selects the checkpoint and tokenizer/model classes in TransformerConfig.
    "model_type": "kobert",
    
    "SAVE": 1,
    "train_data_path": str(data_path / "train" / "hydra_train.jsonl"),
    # "dev_data_path": str(data_path / "hydra_dev.jsonl"),
    "test_data_path": str(data_path  / "test" / "hydra_test.jsonl"),
    # File read by TransformerConfig.get_special_tokens, one token per line.
    "special_tkns_path": str(data_path / "company_codes.txt"),
    
    # Length every encoded (column, question) pair is padded/truncated to.
    "max_total_length": 96,
    # NOTE(review): the keys below are not read in this part of the notebook;
    # presumably model and training hyperparameters - confirm against the trainer.
    "where_column_num": 4,
    "op_num": 4,
    "agg_num": 6,

    "drop_rate": 0.2,
    "learning_rate": 3e-5,
    "decay": 0.01,
    "epochs": 5,
    "batch_size": 256,
    "num_warmup_steps": 400,
}
featurizer = HydraFeaturizer(config)
input_features, model_inputs, pos = featurizer.load_data(config["train_data_path"], config, include_label=True)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


C:\Users\simon\Desktop\Codes\Text2SQL\src\data\train\hydra_train.jsonl


0it [00:00, ?it/s]

63 ('NA', 'thstrm_amount', {'account_nm=2018', 'bsns_year=2018'}) ('NA', 'thstrm_amount', {'account_nm=유동자산', 'bsns_year=2018'})
65 ('NA', 'thstrm_amount', {'account_nm=2018', 'bsns_year=2018'}) ('NA', 'thstrm_amount', {'account_nm=유동자산', 'bsns_year=2018'})
67 ('NA', 'thstrm_amount', {'account_nm=2019', 'bsns_year=2019'}) ('NA', 'thstrm_amount', {'bsns_year=2019', 'account_nm=유동자산'})
69 ('NA', 'thstrm_amount', {'account_nm=2019', 'bsns_year=2019'}) ('NA', 'thstrm_amount', {'bsns_year=2019', 'account_nm=유동자산'})
71 ('NA', 'thstrm_amount', {'account_nm=2020', 'bsns_year=2020'}) ('NA', 'thstrm_amount', {'bsns_year=2020', 'account_nm=유동자산'})
73 ('NA', 'thstrm_amount', {'account_nm=2020', 'bsns_year=2020'}) ('NA', 'thstrm_amount', {'bsns_year=2020', 'account_nm=유동자산'})
75 ('NA', 'thstrm_amount', {'thstrm_nm=제 58 기', 'account_nm=제'}) ('NA', 'thstrm_amount', {'thstrm_nm=제 58 기', 'account_nm=유동자산'})
77 ('NA', 'thstrm_amount', {'thstrm_nm=제 58 기', 'account_nm=제'}) ('NA', 'thstrm_amount', {'thstr

In [295]:
# Module-level stand-ins for the `load_data` arguments, so its body can be
# stepped through cell by cell.
data_paths = config["train_data_path"]
include_label=True

In [305]:
# Debugging copy of the start of `load_data`: scan the first data file until
# the example with qid 63 is reached, leaving `example` and `input_feature`
# bound to it for inspection in the following cells.
model_inputs = {k: [] for k in ["input_ids", "input_mask", "segment_ids"]}
if include_label:
    for k in ["agg", "select", "where_num", "where", "op", "value_start", "value_end"]:
        model_inputs[k] = []

pos = []
input_features = []
# NOTE(review): the loop variable rebinds the notebook-level `data_path` to a
# plain string.
for data_path in data_paths.split("|"):
    cnt = 0
    print(data_path)
    # The context manager closes the file when the search breaks out early.
    with open(data_path, encoding="utf8") as data_file:
        for line in tqdm(data_file):
            example = SQLExample.load_from_json(line)
            if not example.valid and include_label:
                continue

            input_feature = featurizer.get_input_feature(example, config)
            if example.qid == 63:
                break
    # Only the first path is inspected.
    break

C:\Users\simon\Desktop\Codes\Text2SQL\src\data\train\hydra_train.jsonl


0it [00:00, ?it/s]

In [311]:
# Display the field names listed on the example selected by the search cell.
example.keys

['qid',
 'question',
 'table_id',
 'column_meta',
 'agg',
 'select',
 'conditions',
 'tokens',
 'char_to_word',
 'word_to_char_start',
 'value_start_end',
 'valid']

In [308]:
# Display the question text of the selected example.
example.question

'000040의 2018연도의 유동자산은 얼마야?'

In [316]:
# Display the conditions; fill_label_feature unpacks each entry as
# (column index, operator index, value).
example.conditions

[[3, 0, '2018'], [10, 0, '유동자산']]

In [314]:
# Display the mapping from condition value to the pair that fill_label_feature
# uses as word indices (first word at [0], last word at [1] - 1).
# NOTE(review): in the recorded output both values map to the same pair, which
# matches the `account_nm=2018` mismatches printed by load_data.
example.value_start_end

{'2018': [2, 3], '유동자산': [2, 3]}

In [299]:
# First statements of `fill_label_feature`, run at module level against the
# selected `example` / `input_feature`.
max_total_length = int(config["max_total_length"])

# Column names are the first field of every column_meta entry.
columns = [c[0] for c in example.column_meta]
col_num = len(columns)
input_feature.columns = columns

In [302]:
# Display the column names that were just attached to the feature.
input_feature.columns

In [None]:

# Remainder of `fill_label_feature`, run at module level against the selected
# `example` / `input_feature` (relies on `columns`, `col_num` and
# `max_total_length` from the previous cell).
input_feature.agg = [0] * col_num
input_feature.agg[example.select] = example.agg
input_feature.where_num = [len(example.conditions)] * col_num

input_feature.select = [0] * len(columns)
input_feature.select[example.select] = 1

input_feature.where = [0] * len(columns)
input_feature.op = [0] * len(columns)
input_feature.value_start = [0] * len(columns)
input_feature.value_end = [0] * len(columns)

for colidx, op, _ in example.conditions:
    input_feature.where[colidx] = 1
    input_feature.op[colidx] = op

# `return False` is a SyntaxError outside a function, so the early exit of the
# method is expressed with a flag and `break` here.
success = True
for colidx, column_meta in enumerate(example.column_meta):
    # The last column_meta field is the condition value, if any.
    if column_meta[-1] is None:
        continue
    se = example.value_start_end[column_meta[-1]]
    try:
        s = input_feature.word_to_subword[colidx][se[0]][0]
        input_feature.value_start[colidx] = s
        e = input_feature.word_to_subword[colidx][se[1] - 1][1] - 1
        input_feature.value_end[colidx] = e

        mask = input_feature.input_mask[colidx]
        in_range = (
            s < max_total_length and mask[s] == 1
            and e < max_total_length and mask[e] == 1
        )
    except Exception:
        # Unlike a bare `except:`, this lets KeyboardInterrupt propagate.
        in_range = False
    if not in_range:
        print("value span is out of range")
        success = False
        break

# As in the method, the round-trip comparison only runs when every span was valid.
if success:
    feature_sq = input_feature.output_SQ(return_str=False)
    example_sq = example.output_SQ(return_str=False)
    if feature_sq != example_sq:
        print(example.qid, feature_sq, example_sq)