# TEXT2SQL with transformers

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import transformers

from typing import Tuple, Dict, List, Union, Any
# Report the library versions this notebook was run against.
print(f"PyTorch Version: {torch.__version__}")
print(f"Transformers Version: {transformers.__version__}")

PyTroch Version: 1.8.1
Transfomers Version: 4.6.1


## Data to build

In [10]:
from pathlib import Path
import re
import records

# Captures everything between the first "(" and the last ")" of a CREATE TABLE statement.
schema_re = re.compile(r'\((.+)\)')
# Matches an optionally signed decimal number, or a bare run of digits.
num_re = re.compile(r'[-+]?\d*\.\d+|\d+')

db_path = Path("./private")
db = records.Database(f"sqlite:///{db_path / 'samsung_new.db'}")

table_id = "receipts"
# sqlite_master stores the CREATE TABLE text of each table in its `sql` column.
table_info = db.query('SELECT sql from sqlite_master WHERE tbl_name = :name', name=table_id).all()[0].sql
schema_str = schema_re.findall(table_info.replace("\n", ""))[0]

# Map column name -> declared type.
# Split only on commas that are not inside parentheses, so a type such as DECIMAL(10, 2)
# stays in one piece, and a missing space after a comma does not break the split.
schema = {}
for column_def in re.split(r',(?![^()]*\))', schema_str):
    parts = column_def.split()
    if not parts:
        continue
    # The first token is the (possibly quoted) column name and the second is its type.
    # Trailing constraint tokens (PRIMARY KEY, NOT NULL, ...) are ignored, and a column
    # declared without a type gets an empty string.
    schema[parts[0].strip('"')] = parts[1] if len(parts) > 1 else ""

In [11]:
# SQLite storage classes are NULL, INTEGER, REAL, TEXT and BLOB.
# Display the column-name -> declared-type mapping parsed above.
schema

{'index': 'INTEGER',
 'rcept_no': 'TEXT',
 'reprt_code': 'TEXT',
 'bsns_year': 'INTEGER',
 'corp_code': 'TEXT',
 'stock_code': 'TEXT',
 'fs_div': 'TEXT',
 'fs_nm': 'TEXT',
 'sj_div': 'TEXT',
 'sj_nm': 'TEXT',
 'account_nm': 'TEXT',
 'thstrm_nm': 'TEXT',
 'thstrm_dt': 'TEXT',
 'thstrm_amount': 'INTEGER',
 'frmtrm_nm': 'TEXT',
 'frmtrm_dt': 'TEXT',
 'frmtrm_amount': 'INTEGER',
 'bfefrmtrm_nm': 'TEXT',
 'bfefrmtrm_dt': 'TEXT',
 'bfefrmtrm_amount': 'INTEGER'}

In [12]:
# Distinct declared column types that occur in this table.
set(schema.values())

{'INTEGER', 'TEXT'}

In [13]:
question = "제 51 기에 삼성전자의 유동자산은 어떻게 돼?"
cls_token = "[CLS]"
table_token = "[T]"
column_token = "[C]"

# Flatten question + table + typed columns into a single string:
#   [CLS] <question> [T] <table> [C] <column> [<TYPE>] [C] <column> [<TYPE>] ...
# The table marker comes from `table_token` rather than a repeated "[T]" literal,
# so changing the marker in one place changes it everywhere.
f"{cls_token} {question} {table_token} {table_id} " + " ".join(
    f"{column_token} {col} [{typ}]" for col, typ in schema.items()
)

'[CLS] 제 51 기에 삼성전자의 유동자산은 어떻게 돼? [T] receipts [C] index [INTEGER] [C] rcept_no [TEXT] [C] reprt_code [TEXT] [C] bsns_year [INTEGER] [C] corp_code [TEXT] [C] stock_code [TEXT] [C] fs_div [TEXT] [C] fs_nm [TEXT] [C] sj_div [TEXT] [C] sj_nm [TEXT] [C] account_nm [TEXT] [C] thstrm_nm [TEXT] [C] thstrm_dt [TEXT] [C] thstrm_amount [INTEGER] [C] frmtrm_nm [TEXT] [C] frmtrm_dt [TEXT] [C] frmtrm_amount [INTEGER] [C] bfefrmtrm_nm [TEXT] [C] bfefrmtrm_dt [TEXT] [C] bfefrmtrm_amount [INTEGER]'

## Build a custom Tokenizer

https://huggingface.co/docs/tokenizers/python/latest/pipeline.html

### Normalization

- `normalizers`는 raw text를 더 깨끗하게 만드는 과정이다.
- `NFD` 사용하게 되면 한글은 자음 모음으로 분리된다.
    - NFD(Normalization Form Canonical Decomposition) = 조합형
    - NFC(Normalization Form Canonical Composition) = 완성형

In [7]:
from tokenizers import normalizers
from tokenizers.normalizers import NFD, StripAccents

# Chain two normalizers: NFD (canonical decomposition) followed by StripAccents.
normalizer = normalizers.Sequence([NFD(), StripAccents()])
normalized_question = normalizer.normalize_str(question)

# Printed as a whole, the decomposed jamo still render as composed syllables;
# joining character by character makes the decomposition visible.
print(normalized_question)
print("/".join(normalized_question))

제 51 기에 삼성전자의 유동자산은 어떻게 돼?
ᄌ/ᅦ/ /5/1/ /ᄀ/ᅵ/ᄋ/ᅦ/ /ᄉ/ᅡ/ᆷ/ᄉ/ᅥ/ᆼ/ᄌ/ᅥ/ᆫ/ᄌ/ᅡ/ᄋ/ᅴ/ /ᄋ/ᅲ/ᄃ/ᅩ/ᆼ/ᄌ/ᅡ/ᄉ/ᅡ/ᆫ/ᄋ/ᅳ/ᆫ/ /ᄋ/ᅥ/ᄄ/ᅥ/ᇂ/ᄀ/ᅦ/ /ᄃ/ᅫ/?


### Pre-Tokenization

- `pre_tokenizers`는 텍스트를 더 작은 토큰으로 분리하는 과정이다.
- `Whitespace`는 텍스트를 공백과 단어/문장 부호 경계를 기준으로 나누며, 각 토큰의 string 위치 (start, end)를 함께 반환한다. - 기준 regex: `\w+|[^\w\s]+`
- `Punctuation`은 문장 부호를 각각 분리한다. 
- `Digits`를 통해 숫자를 각각 분리할지 말지 결정할 수 있다. 

In [8]:
from tokenizers import pre_tokenizers
from tokenizers.pre_tokenizers import Whitespace, Punctuation, Digits

# Extra question marks make the effect of Punctuation() easy to see.
padded_question = question + "????"

# 1) Whitespace alone.
print("Whitespace: ")
print(Whitespace().pre_tokenize_str(question))

# 2) Whitespace, then Digits with runs of digits kept together.
print("Digits individual: False")
pre_tokenizer = pre_tokenizers.Sequence(
    [Whitespace(), Digits(individual_digits=False)]
)
print(pre_tokenizer.pre_tokenize_str(padded_question))

# 3) Whitespace, Punctuation, then Digits with every digit split off on its own.
print("Punctuation + Individual Digits")
pre_tokenizer = pre_tokenizers.Sequence(
    [Whitespace(), Punctuation(), Digits(individual_digits=True)]
)
print(pre_tokenizer.pre_tokenize_str(padded_question))

Whitespace: 
[('제', (0, 1)), ('51', (2, 4)), ('기에', (5, 7)), ('삼성전자의', (8, 13)), ('유동자산은', (14, 19)), ('어떻게', (20, 23)), ('돼', (24, 25)), ('?', (25, 26))]
Digits individual: False
[('제', (0, 1)), ('51', (2, 4)), ('기에', (5, 7)), ('삼성전자의', (8, 13)), ('유동자산은', (14, 19)), ('어떻게', (20, 23)), ('돼', (24, 25)), ('?????', (25, 30))]
Punctuation + Individual Digits
[('제', (0, 1)), ('5', (2, 3)), ('1', (3, 4)), ('기에', (5, 7)), ('삼성전자의', (8, 13)), ('유동자산은', (14, 19)), ('어떻게', (20, 23)), ('돼', (24, 25)), ('?', (25, 26)), ('?', (26, 27)), ('?', (27, 28)), ('?', (28, 29)), ('?', (29, 30))]


### The Model

- `BPE`: Byte-Pair Encoding 토큰화, 자주 등장하는 character를 합쳐서 표현하는 알고리즘
- `Unigram`: 확률적으로 최적의 subword 토큰을 결정
- `WordLevel`: 단어 단위의 토큰화
- `WordPiece`: Google WordPiece 토큰화

In [38]:
from tokenizers.models import BPE, Unigram, WordLevel, WordPiece

### Post-Processing

- `processors` 에서 후처리를 할 수 있다.
- `TemplateProcessing`을 이용해 원하는 형태로 토큰을 분리할 수 있다.

In [42]:
from tokenizers.processors import PostProcessor

In [None]:
from tokenizers.processors import TemplateProcessing

# Draft post-processor that wraps the question ($A) and the table/column text ($B)
# with [CLS], [T] and [C] markers. This cell has no execution count, so it has not been run.
# NOTE(review): the templates use "[T]" and "[C]", but only "[CLS]" and "[SEP]" are listed in
# `special_tokens` (the remaining entries are commented out). Presumably every special token
# named in a template must be registered there — confirm before running.
# NOTE(review): `single` references $B, yet a single-sequence input has no second sequence,
# and both templates repeat $B twice — verify this is the intended layout.
post_processor = TemplateProcessing(
    single="[CLS] $A [T] $B [C] $B",
    pair="[CLS] $A [T] $B:1 [C] $B:1",
    special_tokens=[("[CLS]", 1), ("[SEP]", 2)]#, ("[T]", 3), ("[C]", 4), ("[INTEGER]", 5), ("[REAL]", 6), ("[TEXT]", 7), ("[BLOB]", 8)],
)

---

- https://huggingface.co/docs/tokenizers/python/latest/index.html
- https://huggingface.co/transformers/main_classes/tokenizer.html?highlight=pretrainedtokenizer#transformers.PreTrainedTokenizer

In [14]:
# Single-string layout: "<question> [T] <table> " followed by one "[SEP]<column>" per column.
sep_joined_columns = "".join(f"[SEP]{col}" for col in schema)
input_str = f"{question} [T] {table_id} " + sep_joined_columns
input_str

'제 51 기에 삼성전자의 유동자산은 어떻게 돼? [T] receipts [SEP]index[SEP]rcept_no[SEP]reprt_code[SEP]bsns_year[SEP]corp_code[SEP]stock_code[SEP]fs_div[SEP]fs_nm[SEP]sj_div[SEP]sj_nm[SEP]account_nm[SEP]thstrm_nm[SEP]thstrm_dt[SEP]thstrm_amount[SEP]frmtrm_nm[SEP]frmtrm_dt[SEP]frmtrm_amount[SEP]bfefrmtrm_nm[SEP]bfefrmtrm_dt[SEP]bfefrmtrm_amount'

In [15]:
# Two sample questions asked against the same table.
batch_qs = [
    "제 51 기에 삼성전자의 유동자산은 어떻게 돼?",
    "2020년도 삼성전자의 유동자산은 얼마?",
]
# Table description: "[T]<table>" followed by one "[SEP]<column>" per column.
table_str = f"[T]{table_id}" + "".join(f"[SEP]{col}" for col in schema)
# Every question is paired with the same table description.
batch_ts = [table_str for _ in batch_qs]

In [16]:
from KoBertTokenizer import KoBertTokenizer

In [17]:
# Disabled experiment: register the table/column/type markers as additional special tokens.
# new_special_tokens = ["[T]", "[C]", "[INTEGER]", "[REAL]", "[TEXT]", "[BLOB]"]
# new_special_tokens = list(map(lambda x: AddedToken(x, single_word=True, normalized=False), new_special_tokens))
# tokenizer = KoBertTokenizer.from_pretrained('monologg/kobert', add_special_tokens=True, additional_special_tokens=new_special_tokens)
# Active path: load the stock KoBERT tokenizer without any extra special tokens.
tokenizer = KoBertTokenizer.from_pretrained('monologg/kobert')

In [18]:
# Map each additional special token to its vocabulary id.
# The tokenizer above was loaded without additional special tokens, so this mapping is empty.
special_tokenes2idx = dict(zip(tokenizer.additional_special_tokens, tokenizer.additional_special_tokens_ids))
special_tokenes2idx

{}

In [19]:
# Encode each (question, table description) pair as a two-segment input.
# Rows are padded to a common length, truncated at 512 tokens, and returned as PyTorch tensors.
encode_input = tokenizer(
    batch_qs, batch_ts, 
    max_length=512, padding=True, truncation=True, return_tensors="pt", 
    return_attention_mask=True, 
    return_special_tokens_mask=False, 
)
# Show which tensors the tokenizer returned.
encode_input.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

In [20]:
batch_size, max_length = encode_input["input_ids"].size()

encoded_ids = encode_input["input_ids"]
# Positions holding [CLS] or [SEP], wherever they fall in each row.
is_cls_or_sep = (encoded_ids == tokenizer.cls_token_id) | (encoded_ids == tokenizer.sep_token_id)

# Question tokens only: segment 0, not padding, and not a [CLS]/[SEP] marker.
# Masking by token id is required because the [SEP] that closes the question sits at the
# end of segment 0, at a row-dependent position. Column -1 is the end of the padded pair,
# so clearing it does not remove that [SEP].
question_mask = (
    (encode_input["token_type_ids"] == 0)
    & encode_input["attention_mask"].bool()
    & ~is_cls_or_sep
)

# Disabled: masks for the [T] / [C] marker positions. The ids 8002 / 8003 presumably
# correspond to tokens that would be appended to the vocabulary — confirm before enabling.
# table_mask = torch.where(
#     (encode_input["input_ids"] == 8002), 
#     torch.ones_like(encode_input["input_ids"], dtype=torch.bool), 
#     torch.zeros_like(encode_input["input_ids"], dtype=torch.bool)
# )
# column_mask = torch.where(
#     (encode_input["input_ids"] == 8003),
#     torch.ones_like(encode_input["input_ids"], dtype=torch.bool), 
#     torch.zeros_like(encode_input["input_ids"], dtype=torch.bool)
# )

In [21]:
# Decode every encoded row back to text, keeping special tokens visible.
for row_ids in encode_input["input_ids"]:
    decoded_row = tokenizer.decode(row_ids, skip_special_tokens=False)
    print(decoded_row)
    print()

[CLS] 제 51 기에 삼성전자의 유동자산은 어떻게 돼?[SEP] [T]receipts[SEP] index[SEP] rcept_no[SEP] reprt_code[SEP] bsns_year[SEP] corp_code[SEP] stock_code[SEP] fs_div[SEP] fs_nm[SEP] sj_div[SEP] sj_nm[SEP] account_nm[SEP] thstrm_nm[SEP] thstrm_dt[SEP] thstrm_amount[SEP] frmtrm_nm[SEP] frmtrm_dt[SEP] frmtrm_amount[SEP] bfefrmtrm_nm[SEP] bfefrmtrm_dt[SEP] bfefrmtrm_amount[SEP]

[CLS] 2020년도 삼성전자의 유동자산은 얼마?[SEP] [T]receipts[SEP] index[SEP] rcept_no[SEP] reprt_code[SEP] bsns_year[SEP] corp_code[SEP] stock_code[SEP] fs_div[SEP] fs_nm[SEP] sj_div[SEP] sj_nm[SEP] account_nm[SEP] thstrm_nm[SEP] thstrm_dt[SEP] thstrm_amount[SEP] frmtrm_nm[SEP] frmtrm_dt[SEP] frmtrm_amount[SEP] bfefrmtrm_nm[SEP] bfefrmtrm_dt[SEP] bfefrmtrm_amount[SEP][PAD][PAD]



In [22]:
# Decode the two segments of each row separately to check how token_type_ids split them.
for row_ids, row_type_ids in zip(encode_input["input_ids"], encode_input["token_type_ids"]):
    first_segment = tokenizer.decode(row_ids[row_type_ids == 0])
    second_segment = tokenizer.decode(row_ids[row_type_ids == 1])

    print("--- type_ids = 0")
    # Padding positions also fall under type id 0, so strip their text from the display.
    print(first_segment.replace("[PAD]", ""))
    print("--- type_ids = 1")
    print(second_segment)
    print()

--- type_ids = 0
[CLS] 제 51 기에 삼성전자의 유동자산은 어떻게 돼?[SEP]
--- type_ids = 1
[T]receipts[SEP] index[SEP] rcept_no[SEP] reprt_code[SEP] bsns_year[SEP] corp_code[SEP] stock_code[SEP] fs_div[SEP] fs_nm[SEP] sj_div[SEP] sj_nm[SEP] account_nm[SEP] thstrm_nm[SEP] thstrm_dt[SEP] thstrm_amount[SEP] frmtrm_nm[SEP] frmtrm_dt[SEP] frmtrm_amount[SEP] bfefrmtrm_nm[SEP] bfefrmtrm_dt[SEP] bfefrmtrm_amount[SEP]

--- type_ids = 0
[CLS] 2020년도 삼성전자의 유동자산은 얼마?[SEP]
--- type_ids = 1
[T]receipts[SEP] index[SEP] rcept_no[SEP] reprt_code[SEP] bsns_year[SEP] corp_code[SEP] stock_code[SEP] fs_div[SEP] fs_nm[SEP] sj_div[SEP] sj_nm[SEP] account_nm[SEP] thstrm_nm[SEP] thstrm_dt[SEP] thstrm_amount[SEP] frmtrm_nm[SEP] frmtrm_dt[SEP] frmtrm_amount[SEP] bfefrmtrm_nm[SEP] bfefrmtrm_dt[SEP] bfefrmtrm_amount[SEP]



## Encoder

In [23]:
from transformers import BertModel, BertConfig
# Load the pretrained KoBERT weights into a BertModel encoder.
model = BertModel.from_pretrained("monologg/kobert")
# Size the token-embedding matrix to the tokenizer's vocabulary length; the call returns
# the resulting embedding module, which is what the cell displays.
model.resize_token_embeddings(len(tokenizer))

Embedding(8002, 768, padding_idx=1)

In [52]:
# Ask the model to return the hidden states alongside its regular outputs.
model.config.output_hidden_states = True

In [53]:
# Run the encoder on the tokenized (question, table) batch.
encode_output = model(**encode_input)

In [58]:
# Number of hidden-state tensors returned.
# NOTE(review): presumably the embedding output plus one tensor per encoder layer — confirm.
len(encode_output.hidden_states)

13

## Decoder

<img src="https://drive.google.com/uc?id=1PW9oAXfW-ZI-jxGn5q9O_gzUIZnNYaet" alt="Sqlova Decoder Architecture " width="100%" height="auto">

### Select Column

$$\begin{aligned} 
s(n\vert c) &= D_c^T W E_n \\
p(n\vert c) &= \text{softmax}(s(n\vert c))\\
C_c &= \sum_n p(n \vert c) E_n \\
s_{sc}(c) &= W  \tanh ([WD_c; WC_c]) \\
p_{sc}(c) &= \text{softmax}(s_{sc}(c))
\end{aligned}$$

$E_n$ is the LSTM output for the $n$-th question token, $D_c$ is the encoding of header $c$, $C_c$ is the question context vector for column $c$, $[\cdot ; \cdot]$ denotes the concatenation of two vectors, and $p_{sc}(c)$ is the probability of selecting column $c$.

In [1]:
class DBEngine:
    """Run WikiSQL-style logical forms against a SQLite database.

    A logical form is a select-column index, an aggregation index and a list of
    ``(column_index, operator_index, value)`` conditions. Columns are addressed
    positionally as ``col<index>`` and tables as ``table_<id>``.

    NOTE(review): the methods rely on names defined outside this class
    (``records``, ``schema_re``, ``num_re``, ``agg_ops``, ``cond_ops``,
    ``parse_decimal``, ``NumberFormatError``); they must be in scope at call time.
    """

    def __init__(self, fdb):
        """Open the SQLite database file at path ``fdb``."""
        self.db = records.Database('sqlite:///{}'.format(fdb))

    def execute_query(self, table_id, query, *args, **kwargs):
        """Run ``query`` through ``execute``.

        ``query`` must expose ``sel_index``, ``agg_index`` and ``conditions``;
        extra positional/keyword arguments are forwarded to ``execute``.
        """
        return self.execute(table_id, query.sel_index, query.agg_index, query.conditions, *args, **kwargs)

    def execute(self, table_id, select_index, aggregation_index, conditions, lower=True):
        """Build and run the query, returning the list of ``result`` values.

        A condition value on a ``real`` column that cannot be turned into a number
        is left unchanged (best effort) instead of raising.

        Args:
            table_id: table id; prefixed with ``table_`` (dashes become underscores)
                unless it already starts with ``table``.
            select_index: index of the selected column.
            aggregation_index: index into ``agg_ops``; a falsy entry means no aggregation.
            conditions: iterable of ``(column_index, operator_index, value)``.
            lower: when true, string condition values are lower-cased.
        """
        results, _ = self._run(table_id, select_index, aggregation_index, conditions, lower, strict=False)
        return results

    def execute_return_query(self, table_id, select_index, aggregation_index, conditions, lower=True):
        """Like ``execute`` but returns ``(results, query_string)``.

        Unlike ``execute``, a value on a ``real`` column that cannot be turned
        into a number raises (the exception of the regex fallback propagates).
        """
        return self._run(table_id, select_index, aggregation_index, conditions, lower, strict=True)

    def show_table(self, table_id):
        """Print every row of the table."""
        table_id = self._table_name(table_id)
        # Table names cannot be bound as parameters, hence the string concatenation.
        rows = self.db.query('select * from ' + table_id)
        print(rows.dataset)

    @staticmethod
    def _table_name(table_id):
        """Return the physical table name for ``table_id``."""
        if not table_id.startswith('table'):
            table_id = 'table_{}'.format(table_id.replace('-', '_'))
        return table_id

    def _load_schema(self, table_id):
        """Return ``{column_name: declared_type}`` parsed from the table's CREATE statement."""
        table_info = self.db.query(
            'SELECT sql from sqlite_master WHERE tbl_name = :name', name=table_id
        ).all()[0].sql.replace('\n', '')
        schema_str = schema_re.findall(table_info)[0]
        schema = {}
        for column_def in schema_str.split(', '):
            parts = column_def.split()
            if not parts:
                continue
            # Tolerate quoted names, trailing constraint tokens and a missing type.
            schema[parts[0].strip('"')] = parts[1] if len(parts) > 1 else ''
        return schema

    @staticmethod
    def _coerce_real(val, strict):
        """Convert ``val`` to float for comparison against a ``real`` column.

        Tries ``parse_decimal`` first, then the first number found by ``num_re``.
        If both fail, re-raises when ``strict`` is true and otherwise returns
        ``val`` unchanged.
        """
        try:
            return float(parse_decimal(val, locale='en_US'))
        except NumberFormatError:
            pass
        try:
            return float(num_re.findall(val)[0])
        except (IndexError, ValueError, TypeError):
            if strict:
                raise
            # The column is numeric but the value is not; leave it as is.
            return val

    def _run(self, table_id, select_index, aggregation_index, conditions, lower, strict):
        """Shared implementation: build the SQL, execute it, return ``(results, query)``."""
        table_id = self._table_name(table_id)
        schema = self._load_schema(table_id)

        select = 'col{}'.format(select_index)
        agg = agg_ops[aggregation_index]
        if agg:
            select = '{}({})'.format(agg, select)

        where_clause = []
        where_map = {}
        for col_index, op, val in conditions:
            column = 'col{}'.format(col_index)
            if lower and isinstance(val, str):
                val = val.lower()
            # SQLite type names are case-insensitive, so compare in lower case.
            if schema[column].lower() == 'real' and not isinstance(val, (int, float)):
                val = self._coerce_real(val, strict)

            # Each condition needs its own bind parameter: reusing ``:colN`` for a second
            # condition on the same column would overwrite the first value. The first
            # condition on a column keeps the plain name; later ones get a numeric suffix.
            param = column
            suffix = 0
            while param in where_map:
                suffix += 1
                param = '{}_{}'.format(column, suffix)

            where_clause.append('{} {} :{}'.format(column, cond_ops[op], param))
            where_map[param] = val

        where_str = ''
        if where_clause:
            where_str = 'WHERE ' + ' AND '.join(where_clause)
        query = 'SELECT {} AS result FROM {} {}'.format(select, table_id, where_str)

        out = self.db.query(query, **where_map)
        return [o.result for o in out], query

---