# TEXT2SQL with transformers

Lee Woo Chul, Jang Ji Soo

---

In [1]:
import json
import torch
from pathlib import Path
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import transformers

from typing import Tuple, Dict, List, Union, Any
import os

from dbengine import DBEngine
# multiprocessing lib doesn’t have it implemented on Windows
# https://discuss.pytorch.org/t/cant-pickle-local-object-dataloader-init-locals-lambda/31857/14
num_workers = 0 if os.name == "nt" else 4

print(f"PyTroch Version: {torch.__version__}")
print(f"Transfomers Version: {transformers.__version__}")

PyTroch Version: 1.8.1
Transfomers Version: 4.6.1


# Data Description

`NLSQL.jsonl` and `"table.jsonl` contains the data like following format same with [WikiSQL](https://github.com/salesforce/WikiSQL), Please follow the [link](https://github.com/salesforce/WikiSQL#content-and-format) to see what are the keys mean.

```json
// example of 'NLSQL.jsonl'
{
    "phase": 1, 
    "question": "2015 삼성전자 유동자산은 어떻게 돼?", 
    "table_id": "receipts", 
    "sql": {
        "sel": 16, 
        "agg": 0, 
        "conds": [[10, 0, "유동자산"], [3, 0, 2016]]
    }
}
```

In [2]:
def load_data(sql_path, table_path):
    path_sql = Path(sql_path)
    path_table = Path(table_path)

    dataset = []
    table = {}
    with path_sql.open("r", encoding="utf-8") as f:
        for idx, line in enumerate(f):
            x = json.loads(line.strip())
            dataset.append(x)

    with path_table.open("r", encoding="utf-8") as f:
        for idx, line in enumerate(f):
            x = json.loads(line.strip())
            table[x['id']] = x
            
    return dataset, table

In [3]:
data, table = load_data("NLSQL.jsonl", "table.jsonl")
data_loader = torch.utils.data.DataLoader(
    batch_size=2,
    dataset=data,
    shuffle=True,
    num_workers=num_workers,
    collate_fn=lambda x: x # now dictionary values are not merged!
)
# Load DBEngine
db_path = Path("./private")
dbengine = DBEngine(db_path / "samsung_new.db")

In [4]:
for i, batch_data in enumerate(tqdm(data_loader, desc="Test with toy data")):
    break

Test with toy data:   0%|          | 0/21120 [00:00<?, ?it/s]

# Model

## Encoder

Used BERT in hugging Face with KoBERT

- https://github.com/SKTBrain/KoBERT
- https://github.com/monologg/KoBERT-Transformers

In [11]:
from KoBertTokenizer import KoBertTokenizer
from transformers import BertModel, BertConfig

In [15]:


def get_bert(model_path: str, output_hidden_states: bool=False):
    special_tokens = ["[S]", "[E]", "[COL]"] # sequence start, sequence end, column tokens
    tokenizer = KoBertTokenizer.from_pretrained(model_path, add_special_tokens=True, additional_special_tokens=special_tokens)
    config = BertConfig.from_pretrained(model_path)
    config.output_hidden_states = output_hidden_states
    
    model = BertModel.from_pretrained(model_path)
    model.resize_token_embeddings(len(tokenizer))
    model.config.output_hidden_states = output_hidden_states
    
    return model, tokenizer, config

In [16]:
model_path = "monologg/kobert"
device = "cpu" # "cuda" if torch.cuda.is_available() else "cpu" 

model_bert, tokenizer_bert, config_bert = get_bert(model_path=model_path)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
# def get_batch_data(data, dbengine):
#     batch_qs = [jsonl["question"] for jsonl in data]
#     tid = [jsonl["table_id"] for jsonl in data]
#     batch_sqls = [jsonl["sql"] for jsonl in data]
#     batch_ts = []
#     for table_id in tid:
#         dbengine.get_schema_info(table_id)
#         table_str = f"{table_id}" + "".join([
#             f"[COL]{col}" for col in dbengine.schema
#         ]) 
#         batch_ts.append(table_str)
    
#     return batch_qs, batch_sqls, batch_ts

def get_batch_data(data: List[Dict[str, Any]], table: Dict[str, Dict[str, List[Any]]], start_tkn="[S]", end_tkn="[E]") -> Tuple[List[str], List[str], List[Dict[str, Any]]]:
   
    batch_qs = [jsonl["question"] for jsonl in data]
    tid = [jsonl["table_id"] for jsonl in data]
    batch_sqls = [jsonl["sql"] for jsonl in data]
    batch_ts = []
    for table_id in tid:
        table_str = f"{table_id}" + "".join([
            f"[COL]{col}" for col in table[table_id]["header"]
        ])
        # TODO: [EXP] Experiment for generate column directly
        # table_str = f"{start_tkn}{table_id}{end_tkn}" + "".join([
        #     f"{col_tkn}{start_tkn}{col}{end_tkn}" for col in dbengine.schema
        # ]) 
        batch_ts.append(table_str)

    return batch_qs, batch_ts, batch_sqls

In [6]:
batch_qs, batch_ts, batch_sqls = get_batch_data(batch_data, table)

In [18]:
# Get Input
encode_input = tokenizer_bert(
    batch_qs, batch_ts, 
    max_length=512, padding=True, truncation=True, return_tensors="pt", 
    return_attention_mask=True, 
    return_special_tokens_mask=False, 
)

In [20]:
# Show an Example of Input
print(tokenizer_bert.decode(encode_input["input_ids"][0]))

[CLS] 삼성전자의 2017년의 유동자산은 어떻게 돼?[SEP] receipts [COL] index [COL] rcept_no [COL] reprt_code [COL] bsns_year [COL] corp_code [COL] stock_code [COL] fs_div [COL] fs_nm [COL] sj_div [COL] sj_nm [COL] account_nm [COL] thstrm_nm [COL] thstrm_dt [COL] thstrm_amount [COL] frmtrm_nm [COL] frmtrm_dt [COL] frmtrm_amount [COL] bfefrmtrm_nm [COL] bfefrmtrm_dt [COL] bfefrmtrm_amount[SEP]


## Prepare for decoder Inputs: Createing masks

In [21]:
def get_decoder_input_mask(input_ids, mask, batch_size, start_tkn_id, end_tkn_id):
    r"""
    input should only contains word tokens:
    """
    start_tkn_mask = input_ids == start_tkn_id
    end_tkn_mask = input_ids == end_tkn_id
    start_end_mask = torch.bitwise_or(start_tkn_mask, end_tkn_mask)
    index = torch.arange(input_ids.size(1)).repeat(batch_size)[start_end_mask.view(-1)].view(batch_size, -1)
    return mask.scatter(1, index, False)

def get_input_mask_and_answer(encode_input, tokenizer):
    r"""
    table -> database table name(id)
    header -> database header
    
    returns:
        input_question_mask, input_table_mask, input_header_mask, answer_table_tkns, answer_header_tkns
    """
    batch_size, max_length = encode_input["input_ids"].size()
    sep_tkn_mask = encode_input["input_ids"] == tokenizer.sep_token_id
    start_tkn_id, end_tkn_id, col_tkn_id = tokenizer.additional_special_tokens_ids
    
    input_question_mask = torch.bitwise_and(encode_input["token_type_ids"] == 0, encode_input["attention_mask"].bool())
    input_question_mask = torch.bitwise_and(input_question_mask, ~sep_tkn_mask) # [SEP] mask out
    input_question_mask[:, 0] = False  # [CLS] mask out

    db_mask = torch.bitwise_and(encode_input["token_type_ids"] == 1, encode_input["attention_mask"].bool())
    db_mask = torch.bitwise_xor(db_mask, sep_tkn_mask)
    col_tkn_mask = encode_input["input_ids"] == col_tkn_id
    db_mask = torch.bitwise_and(db_mask, ~col_tkn_mask)
    # split table_mask and header_mask
    input_idx = torch.arange(max_length).repeat(batch_size, 1)
    db_idx = input_idx[db_mask]
    table_header_tkn_idx = db_idx[db_idx > 0]
    table_start_idx = table_header_tkn_idx.view(batch_size, -1)[:, 0] + 1
    start_idx = table_header_tkn_idx[1:][table_header_tkn_idx.diff() == 2].view(batch_size, -1)
    table_end_sep_idx = start_idx[:, 0] - 1
    split_size = torch.stack([
        table_end_sep_idx-table_start_idx+1, table_header_tkn_idx.view(batch_size, -1).size(1)-(table_end_sep_idx-table_start_idx+1)
    ]).transpose(0, 1)

    # Token idx
    table_tkn_idx, header_tkn_idx = map(
        lambda x: torch.stack(x), 
        zip(*[torch.split(x, size.tolist()) for x, size in zip(table_header_tkn_idx.view(batch_size, -1), split_size)])
    )

    table_tkn_idx = table_tkn_idx[:, 1:]
    # Mask include [S] & [E] tokens
    table_tkn_mask = torch.zeros_like(encode_input["input_ids"], dtype=torch.bool).scatter(1, table_tkn_idx, True)
    header_tkn_mask = torch.zeros_like(encode_input["input_ids"], dtype=torch.bool).scatter(1, header_tkn_idx, True)

    # For Decoder Input, Maskout [S], [E] for table & header  
    input_table_mask = get_decoder_input_mask(
        encode_input["input_ids"], table_tkn_mask, batch_size, start_tkn_id, end_tkn_id
    )
    input_header_mask = get_decoder_input_mask(
        encode_input["input_ids"], header_tkn_mask, batch_size, start_tkn_id, end_tkn_id
    )
    # [COL] token mask: this is for attention
    col_tkn_idx = input_idx[col_tkn_mask].view(batch_size, -1)
    input_col_mask = torch.zeros_like(encode_input["input_ids"], dtype=torch.bool).scatter(1, col_tkn_idx, True)

    return input_question_mask, input_table_mask, input_header_mask, input_col_mask # , answer_table_tkns, answer_header_tkns

In [22]:
input_question_mask, input_table_mask, input_header_mask, input_col_mask = get_input_mask_and_answer(encode_input, tokenizer_bert)

In [23]:
for m, t in zip(
        [input_question_mask, input_table_mask, input_header_mask, input_col_mask], 
        ["Question Tokens for Decoder", "Table Tokens for Decoder", "Header Tokens for Decoder", "Column(Index of Headers) Tokens for Decoder"]
    ):
    print(t)
    print("-----"*5)
    print(tokenizer_bert.decode(encode_input["input_ids"][m]))
    print()

Question Tokens for Decoder
-------------------------
삼성전자의 2017년의 유동자산은 어떻게 돼? 50기 삼성전자 비유동부채는 어떻게 돼?

Table Tokens for Decoder
-------------------------
receipts receipts

Header Tokens for Decoder
-------------------------
index rcept_no reprt_code bsns_year corp_code stock_code fs_div fs_nm sj_div sj_nm account_nm thstrm_nm thstrm_dt thstrm_amount frmtrm_nm frmtrm_dt frmtrm_amount bfefrmtrm_nm bfefrmtrm_dt bfefrmtrm_amount index rcept_no reprt_code bsns_year corp_code stock_code fs_div fs_nm sj_div sj_nm account_nm thstrm_nm thstrm_dt thstrm_amount frmtrm_nm frmtrm_dt frmtrm_amount bfefrmtrm_nm bfefrmtrm_dt bfefrmtrm_amount

Column(Index of Headers) Tokens for Decoder
-------------------------
[COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL]



In [24]:
# Feed to BERT Model
encode_outputs = model_bert(**encode_input)


The `encode_outputs` will be selected by 4 types of masks
```
encode_outputs
-> Question
-> Table
-> Header
-> Column(Index of Headers)
```

And pad batches which has less tokens than max length with "\[PAD\]"  for Decoder Input


In [25]:
def pad(batches: Tuple[torch.Tensor], lengths: List[int], model: BertModel, pad_idx: int=1) -> torch.Tensor:
    padded = []
    max_length = max(lengths)
    for x in batches:
        if len(x) < max_length:
            pad_tensor = model.embeddings.word_embeddings(torch.LongTensor([pad_idx]*(max_length - len(x))))
            padded.append(torch.cat([x, pad_tensor]))
        else:
            padded.append(x)
    return torch.stack(padded)

def get_decoder_batches(encode_output, mask, model, pad_idx):
    lengths = mask.sum(1)
    tensors = encode_output.last_hidden_state[mask, :]
    batches = torch.split(tensors, lengths.tolist())
    if lengths.ne(lengths.max()).sum().item() != 0:
        # pad not same length tokens
        tensors_padded = pad(batches, lengths.tolist(), model, pad_idx=pad_idx)
    else:
        # just stack the splitted tensors
        tensors_padded = torch.stack(batches)
    return tensors_padded, lengths.tolist()

# def get_pad_mask(lengths):
#     batch_size = len(lengths)
#     max_len = max(lengths)
#     mask = torch.ones(batch_size, max_len)
#     for i, l in enumerate(lengths):
#         mask[i, :l] = 0
#     return mask

In [26]:
question_padded, question_lengths = get_decoder_batches(encode_outputs, input_question_mask, model_bert, pad_idx=tokenizer_bert.pad_token_id)
table_padded, table_lengths = get_decoder_batches(encode_outputs, input_table_mask, model_bert, pad_idx=tokenizer_bert.pad_token_id)
header_padded, header_lengths = get_decoder_batches(encode_outputs, input_header_mask, model_bert, pad_idx=tokenizer_bert.pad_token_id)
col_padded, col_lengths = get_decoder_batches(encode_outputs, input_col_mask, model_bert, pad_idx=tokenizer_bert.pad_token_id)

In [18]:
type(model_bert)

transformers.models.bert.modeling_bert.BertModel

## Create the Answers for decoder output

In [27]:
def get_sql_answers(batch_sqls, tokenizer, end_tkn_idx=1):
    """
    for backward compatibility, separated with get_g
    
    sc: select column
    sa: select agg
    wn: where number
    wc: where column
    wo: where operator
    wv: where value
    """

    get_ith_element = lambda li, i: [x[i] for x in li]
    g_sc = []
    g_sa = []
    g_wn = []
    g_wc = []
    g_wo = []
    g_wv = []
    for b, sql_dict in enumerate(batch_sqls):
        g_sc.append( sql_dict["sel"] )
        g_sa.append( sql_dict["agg"])

        conds = sql_dict["conds"]
        if not sql_dict["agg"] < 0:
            g_wn.append( len(conds) )
            g_wc.append( get_ith_element(conds, 0) )
            g_wo.append( get_ith_element(conds, 1) )
            g_wv.append( get_ith_element(conds, 2) )
        else:
            raise EnvironmentError
    
    # get where value tokenized 
    end_tkn = tokenizer.additional_special_tokens[end_tkn_idx]
    pad_tkn_id = tokenizer.pad_token_id
    g_wv_tkns = [[f"{s}{end_tkn}" for s in batch_wv] for batch_wv in g_wv]
    g_wv_tkns = [tokenizer(batch_wv, add_special_tokens=False)["input_ids"] for batch_wv in g_wv_tkns]
    # add empty list if batch has different where column number
    max_where_cols = max([len(batch_wv) for batch_wv in g_wv_tkns])
    g_wv_tkns = [batch_wv + [[]]*(max_where_cols-len(batch_wv)) if len(batch_wv) < max_where_cols else batch_wv for batch_wv in g_wv_tkns]
    temp = []
    for batch_wv in list(zip(*g_wv_tkns)):
        batch_max_len = max(map(len, batch_wv))
        batch_temp = []
        for wv_tkns in batch_wv:  # iter by number of where clause
            if len(wv_tkns) < batch_max_len:
                batch_temp.append(wv_tkns + [pad_tkn_id]*(batch_max_len - len(wv_tkns)))
            else:
                batch_temp.append(wv_tkns)
        temp.append(batch_temp)
    g_wv_tkns = list(zip(*temp))

    return g_sc, g_sa, g_wn, g_wc, g_wo, g_wv, g_wv_tkns

In [28]:
g_sc, g_sa, g_wn, g_wc, g_wo, g_wv, g_wv_tkns = get_sql_answers(batch_sqls, tokenizer_bert, 1)
g_sc, g_sa, g_wn, g_wc, g_wo, g_wv, g_wv_tkns

([16, 16],
 [0, 0],
 [2, 2],
 [[10, 3], [10, 3]],
 [[0, 0], [0, 0]],
 [['유동자산', 2018], ['비유동부채', 2019]],
 [([3574, 5872, 7162, 8003, 1, 1], [554, 115, 8003]),
  ([2514, 7063, 5872, 6398, 7405, 8003], [554, 116, 8003])])

## Decoder

Similar structure in SQLova but a little difference in here.

- SQLova is a neural semantic parser translating natural language utterance to SQL query.
- Official Github: [https://github.com/naver/sqlova](https://github.com/naver/sqlova)
- Paper: [A Comprehensive Exploration on WikiSQL with Table-Aware Word Contextualization](https://arxiv.org/abs/1902.01069)

<img src="https://drive.google.com/uc?id=1PW9oAXfW-ZI-jxGn5q9O_gzUIZnNYaet" alt="Sqlova Decoder Architecture " width="50%" height="auto">

## Attention Layers

In [7]:
class AttentionBase(nn.Module):
    def __init__(self):
        super().__init__()
    
    def wipe_out_pad_tkn_score(self, score, lengths, dim=2):
        max_len = max(lengths)
        for batch_idx, length in enumerate(lengths):
            if length < max_len:
                if dim == 2:
                    score[batch_idx, :, length:] = -10000000
                elif dim == 1:
                    score[batch_idx, length:, :] = 0.0
                else:
                    raise ValueError(f"`dim` in wipe_out_pad_tkn_score should be 1 or 2")
        return score 


class C2QAttention(AttentionBase):
    r"""Decoder Column to Question Attention Module"""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.softmax = nn.Softmax(dim=2)
        
    def forward(self, o_c, o_q, q_lengths, c_lengths=None, rt_attn=False):
        r"""
        Calculate for each column tokens, How much related to question tokens?
        
        o_c: LSTM output of column
        o_q: LSTM output of question 
        
        c_lengths: wipe out row length
        return context atttended to question tokens
        """
        sqrt_H = torch.sqrt(torch.FloatTensor([o_c.size(-1)], device=o_c.device))  # Apply Attention is All you Need Technique
        o_q_transform = self.linear(o_q)  # (B, T_q, H)
        score_c2q = torch.bmm(o_c, o_q_transform.transpose(1, 2)) / sqrt_H  # (B, T_c, H) x (B, H, T_q) = (B, T_c, T_q)
        score_c2q = self.wipe_out_pad_tkn_score(score_c2q, q_lengths, dim=2)
        
        prob_c2q = self.softmax(score_c2q)
        if c_lengths is not None:
            prob_c2q = self.wipe_out_pad_tkn_score(prob_c2q, c_lengths, dim=1)
        # prob_c2q: (B, T_c, T_q) -> (B, T_c, T_q, 1)
        # o_q: (B, 1, T_q, H)
        # p_col2question \odot o_q = (B, T_c, T_q, 1) \odot (B, 1, T_q, H) = (B, T_c, T_q, H)
        # -> reduce sum to T_q to get context for each column (B, T_c, H)
        context = torch.mul(prob_c2q.unsqueeze(3), o_q.unsqueeze(1)).sum(dim=2)
        if rt_attn:
            attn = prob_c2q
        else:
            attn = None
        return context, attn

class SelfAttention(AttentionBase):
    r"""Decoder Self Attention Module"""
    def __init__(self, in_features, out_features=1):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, o, lengths, rt_attn=False):
        r"""
        Calculate for each o tokens, How much related to o tokens?
        
        return attended summary of o
        """
        o_transform = self.linear(o)  # (B, T_o, H) -> (B, T_o, 1)
        o_transform = self.wipe_out_pad_tkn_score(o_transform, lengths) 
        o_prob = self.softmax(o_transform)  # (B, T_o, 1)
        
        o_summary = torch.mul(o, o_prob).sum(1)  # (B, T_o, H) \odot (B, T_o, 1) -> (B, H)

        if rt_attn:
            attn = o_prob
        else:
            attn = None
        return o_summary, attn

## Decoder Sub Layers

In [8]:
class SelectDecoder(nn.Module):
    r"""SELECT Decoder"""
    def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int=2, dropout_ratio:float=0.3) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout_ratio = dropout_ratio
        
        self.lstm_q = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        self.lstm_h = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        
        self.col_context_linear = nn.Linear(2*hidden_size, hidden_size)
        self.col2question_attn = C2QAttention(hidden_size, hidden_size)
        self.output_layer = nn.Sequential(
            nn.Linear(2*hidden_size, 2*hidden_size),
            nn.Tanh(),
            nn.Linear(2*hidden_size, output_size)
        )

    def forward(self, question_padded, header_padded, col_padded, question_lengths: List[int], col_lengths: List[int], rt_attn=False):
        r"""
        predict column index
        """
        batch_size, n_col, _ = col_padded.size()
        o_q, (h_q, c_q) = self.lstm_q(question_padded)  # o_q: (B, T_q, H)
        o_c, (h_c, c_c) = self.lstm_h(col_padded)  # o_c: (B, T_c, H)
        o_h, (h_h, c_h) = self.lstm_h(header_padded)  # h_h: (n_direc*num_layers, B, H/2)
        
        header_summary = torch.cat([h for h in h_h[-2:]], dim=1).unsqueeze(1).repeat(1, n_col, 1)  # (B, T_c, H)
        col_context = torch.cat([o_c, header_summary], dim=2)  # (B, T_c, 2H)
        col_context = self.col_context_linear(col_context)  # (B, T_c, H)
        col_q_context, attn = self.col2question_attn(col_context, o_q, question_lengths, col_lengths, rt_attn)  # (B, T_c, H), (B, T_c, T_q)
        
        vec = torch.cat([col_q_context, col_context], dim=2)  # (B, T_c, 2H)
        output = self.output_layer(vec)
        # TODO: add penalty for padded header(column) information
        
        return output.squeeze(-1), attn
    

class AggDecoder(nn.Module):
    r"""AGG Decoder"""
    def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int=2, dropout_ratio:float=0.3) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout_ratio = dropout_ratio
        
        self.lstm_q = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        self.lstm_h = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        
        self.col_context_linear = nn.Linear(2*hidden_size, hidden_size)
        self.col2question_attn = C2QAttention(hidden_size, hidden_size)
        self.output_layer = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, output_size)
        )
                
    def forward(self, question_padded, header_padded, col_padded, question_lengths: List[int], col_lengths: List[int], select_idxes: List[int], rt_attn=False):
        r"""
        predict agg index
        select_prob: selected argmax indices of select_output score
        """
        batch_size, n_col, _ = col_padded.size()
        o_q, (h_q, c_q) = self.lstm_q(question_padded)  # o_q: (B, T_q, H)
        o_c, (h_c, c_c) = self.lstm_h(col_padded)  # o_c: (B, T_c, H)
        o_h, (h_h, c_h) = self.lstm_h(header_padded)  # h_h: (n_direc*num_layers, B, H/2)
        
        header_summary = torch.cat([h for h in h_h[-2:]], dim=1).unsqueeze(1).repeat(1, n_col, 1)  # (B, T_c, H)
        col_context = torch.cat([o_c, header_summary], dim=2)  # (B, T_c, 2H)
        col_context = self.col_context_linear(col_context)  # (B, T_c, H)
        
        col_selected = col_context[list(range(batch_size)), select_idxes].unsqueeze(1)  # col_selected: (B, 1, H)
        
        col_q_context, attn = self.col2question_attn(col_selected, o_q, question_lengths, col_lengths, rt_attn)  # (B, 1, H), (B, 1, T_q)
        output = self.output_layer(col_q_context.squeeze(1))
        
        return output, attn
    
    
class WhereNumDecoder(nn.Module):
    r"""WHERE number Decoder"""
    def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int=2, dropout_ratio:float=0.3, max_where_conds=4) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout_ratio = dropout_ratio
        self.max_where_conds = max_where_conds
        if self.output_size > self.max_where_conds+1:
            # HERE output will be dilivered to cross-entropy loss, not guessing the real number of where clause
            raise ValueError(f"`WhereNumDecoder` only support maximum {max_where_conds} where clause")
        
        self.lstm_q = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        self.lstm_h = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        
        self.col_self_attn = SelfAttention(2*hidden_size, 1)
        self.lstm_q_hidden_init_linear = nn.Linear(2*hidden_size, 2*hidden_size)
        self.lstm_q_cell_init_linear = nn.Linear(2*hidden_size, 2*hidden_size)
        
        self.context_self_attn = SelfAttention(hidden_size, 1)
        
        self.output_layer = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, output_size)
        )
        
        
    def forward(self, question_padded, header_padded, col_padded, question_lengths: List[int], col_lengths: List[int], rt_attn=False):
        r"""
        predict agg index
        select_prob: selected argmax indices of select_output score
        """
        batch_size, n_col, _ = col_padded.size()
        o_c, (h_c, c_c) = self.lstm_h(col_padded)  # o_c: (B, T_c, H)
        o_h, (h_h, c_h) = self.lstm_h(header_padded)  # h_h: (n_direc*num_layers, B, H/2)
        
        header_summary = torch.cat([h for h in h_h[-2:]], dim=1).unsqueeze(1).repeat(1, n_col, 1)  # (B, T_c, H)
        col_context = torch.cat([o_c, header_summary], dim=2)  # (B, T_c, 2H)

        col_self_attn, col_attn = self.col_self_attn(col_context, col_lengths, rt_attn)  # (B, 2H), (B, T_c)

        h_0 = self.lstm_q_hidden_init_linear(col_self_attn)  # (B, 2H)
        h_0 = h_0.view(batch_size, 2*self.num_layers, -1).transpose(0, 1).contiguous()  # (B, n_direc*num_layers, H/2) -> (n_direc*num_layers, B, H/2)
        c_0 = self.lstm_q_cell_init_linear(col_self_attn)  # (B, 2H)
        c_0 = c_0.view(batch_size, 2*self.num_layers, -1).transpose(0, 1).contiguous()  # (B, n_direc*num_layers, H/2) -> (n_direc*num_layers, B, H/2)
        
        o_q, (h_q, c_q) = self.lstm_q(question_padded, (h_0, c_0))  # o_q: (B, T_q, H)
        o_summary, o_attn = self.context_self_attn(o_q, question_lengths, rt_attn)  # (B, H), (B, T_q)
        output = self.output_layer(o_summary)
        
        return output, (col_attn, o_attn)

    
class WhereColumnDecoder(nn.Module):
    r"""WHERE Column Decoder"""
    def __init__(self, input_size: int, hidden_size: int, output_size: int=1, num_layers: int=2, dropout_ratio:float=0.3, max_where_conds: int=4) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout_ratio = dropout_ratio
        self.max_where_conds = max_where_conds

        self.lstm_q = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        self.lstm_h = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        
        self.col_context_linear = nn.Linear(2*hidden_size, hidden_size)
        self.col2question_attn = C2QAttention(hidden_size, hidden_size)
        self.output_layer = nn.Sequential(
            nn.Linear(2*hidden_size, 2*hidden_size),
            nn.Tanh(),
            nn.Linear(2*hidden_size, output_size)
        )

    def forward(self, question_padded, header_padded, col_padded, question_lengths: List[int], col_lengths: List[int], rt_attn=False):
        r"""
        predict column index
        """
        batch_size, n_col, _ = col_padded.size()
        o_q, (h_q, c_q) = self.lstm_q(question_padded)  # o_q: (B, T_q, H)
        o_c, (h_c, c_c) = self.lstm_h(col_padded)  # o_c: (B, T_c, H)
        o_h, (h_h, c_h) = self.lstm_h(header_padded)  # h_h: (n_direc*num_layers, B, H/2)
        
        header_summary = torch.cat([h for h in h_h[-2:]], dim=1).unsqueeze(1).repeat(1, n_col, 1)  # (B, T_c, H)
        col_context = torch.cat([o_c, header_summary], dim=2)  # (B, T_c, 2H)
        col_context = self.col_context_linear(col_context)  # (B, T_c, H)
        col_q_context, attn = self.col2question_attn(col_context, o_q, question_lengths, col_lengths, rt_attn)  # (B, T_c, H), (B, T_c, T_q)
        
        vec = torch.cat([col_q_context, col_context], dim=2)  # (B, T_c, 2H)
        output = self.output_layer(vec)
        # TODO: add penalty for padded header(column) information
        
        return output.squeeze(-1), attn
    
    
class WhereOpDecoder(nn.Module):
    r"""WHERE Opperator Decoder"""
    def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int=2, dropout_ratio: float=0.3, max_where_conds: int=4) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout_ratio = dropout_ratio
        self.max_where_conds = max_where_conds
        
        self.lstm_q = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        self.lstm_h = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        
        self.col_context_linear = nn.Linear(2*hidden_size, hidden_size)
        self.col2question_attn = C2QAttention(hidden_size, hidden_size)
        self.output_layer = nn.Sequential(
            nn.Linear(2*hidden_size, 2*hidden_size),
            nn.Tanh(),
            nn.Linear(2*hidden_size, output_size)
        )
    
    def forward(self, question_padded, header_padded, col_padded, question_lengths: List[int], where_nums: List[int], where_col_idxes: List[List[int]], rt_attn=False):
        r"""
        predict agg index
        select_prob: selected argmax indices of select_output score
        max_where_col_nums is settled at WhereColumnDecoder, but it can be lower than or equal to `max_where_conds`
        """
        batch_size, n_col, _ = col_padded.size()
        o_q, (h_q, c_q) = self.lstm_q(question_padded)  # o_q: (B, T_q, H)
        o_c, (h_c, c_c) = self.lstm_h(col_padded)  # o_c: (B, T_c, H)
        o_h, (h_h, c_h) = self.lstm_h(header_padded)  # h_h: (n_direc*num_layers, B, H/2)
        
        header_summary = torch.cat([h for h in h_h[-2:]], dim=1).unsqueeze(1).repeat(1, n_col, 1)  # (B, T_c, H)
        col_context = torch.cat([o_c, header_summary], dim=2)  # (B, T_c, 2H)
        col_context = self.col_context_linear(col_context)  # (B, T_c, H)
        col_context_padded = self.get_context_padded(col_context, where_nums, where_col_idxes)  # (B, max_where_col_nums, H)
        
        col_q_context, attn = self.col2question_attn(col_context_padded, o_q, question_lengths, where_nums, rt_attn)  # (B, max_where_col_nums, H), (B, max_where_col_nums, T_q)
        
        vec = torch.cat([col_q_context, col_context_padded], dim=2)  # (B, max_where_col_nums, 2H)
        output = self.output_layer(vec)  # (B, max_where_col_nums, n_cond_ops)
        # TODO: add penalty for padded header(column) information
        return output
        
    def get_context_padded(self, col_context, where_nums, where_col_idxes):
        r"""
        Select the where column index and pad if some batch doesn't match the max length of tensor
        In case for have different where column lengths
        """
        batch_size, n_col, hidden_size = col_context.size()
        max_where_col_nums = max(where_nums) 
        batches = [col_context[i, batch_col] for i, batch_col in enumerate(where_col_idxes)]  # [(where_col_nums, hidden_size), ...]  len = B
        batches_padded = []
        for b in batches:
            where_col_nums = b.size(0)
            if where_col_nums < max_where_col_nums:
                b_padded = torch.cat([b, torch.zeros((max_where_col_nums-where_col_nums), hidden_size, device=col_context.device)], dim=0)
            else:
                b_padded = b
            batches_padded.append(b_padded)  # (max_where_col_nums, hidden_size)
            
        return torch.stack(batches_padded) # (B, max_where_col_nums, hidden_size)
    
    
class WhereValueDecoder(nn.Module):
    r"""WHERE Value Decoder"""
    def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int=2, dropout_ratio: float=0.3, max_where_conds: int=4, n_cond_ops: int=4,
                 start_tkn_id=8002, end_tkn_id=8003, embedding_layer=None) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout_ratio = dropout_ratio
        self.max_where_conds = max_where_conds
        self.n_cond_ops = n_cond_ops
        
        self.start_tkn_id = start_tkn_id
        self.end_tkn_id = end_tkn_id
        
        self.lstm_q = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        self.lstm_h = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        
        self.col_context_linear = nn.Linear(2*hidden_size, hidden_size)
        self.where_op_linear = nn.Linear(n_cond_ops, hidden_size)
        self.col2question_attn = C2QAttention(hidden_size, hidden_size)
        if embedding_layer is None:
            raise KeyError("Must initialize the embedding_layer to BertModel's word embedding layer")
        else:
            if not isinstance(embedding_layer, torch.nn.modules.sparse.Embedding):
                embedding_layer = embedding_layer.word_embeddings
            self.embedding_layer = embedding_layer
            vocab_size, bert_hidden_size = embedding_layer.weight.data.size()
            self.output_lstm_hidden_init_linear = nn.Linear(3*hidden_size, bert_hidden_size)
            self.output_lstm_cell_init_linear = nn.Linear(3*hidden_size, bert_hidden_size)
            self.output_lstm = nn.LSTM(bert_hidden_size, bert_hidden_size, 1, batch_first=True)
            self.output_linear = nn.Linear(bert_hidden_size, vocab_size)
            self.output_linear.weight.data = embedding_layer.weight.data

        
    def forward(self, question_padded, header_padded, col_padded, question_lengths: List[int], where_nums: List[int], where_col_idxes: List[List[int]], where_op_idxes: List[List[int]], value_tkn_max_len=None, g_wv_tkns=None, rt_attn=False):
        r"""
        predict agg index
        select_prob: selected argmax indices of select_output score
        max_where_col_nums is setted at WhereColumnDecoder
        value_tkn_max_len = Train if None else Test
        g_wv_tkns = When Train should not be None
        
        """
        batch_size, n_col, _ = col_padded.size()
        o_q, (h_q, c_q) = self.lstm_q(question_padded)  # o_q: (B, T_q, H)
        o_c, (h_c, c_c) = self.lstm_h(col_padded)  # o_c: (B, T_c, H)
        o_h, (h_h, c_h) = self.lstm_h(header_padded)  # h_h: (n_direc*num_layers, B, H/2)
        
        header_summary = torch.cat([h for h in h_h[-2:]], dim=1).unsqueeze(1).repeat(1, n_col, 1)  # (B, T_c, H)
        col_context = torch.cat([o_c, header_summary], dim=2)  # (B, T_c, 2H)
        col_context = self.col_context_linear(col_context)  # (B, T_c, H)
        col_context_padded = self.get_context_padded(col_context, where_nums, where_col_idxes)  # (B, max_where_col_nums, H)

        col_q_context, attn = self.col2question_attn(col_context_padded, o_q, question_lengths, where_nums, rt_attn)  # (B, max_where_col_nums, H), (B, max_where_col_nums, T_q)
        where_op_one_hot_padded = self.get_where_op_one_hot_padded(where_op_idxes, where_nums, where_col_idxes, n_cond_ops=self.n_cond_ops)#.to(o_q.device)  # (B, max_where_col_nums, n_cond_ops)
        where_op = self.where_op_linear(where_op_one_hot_padded)  # (B, max_where_col_nums, H)

        vec = torch.cat([col_q_context, col_context_padded, where_op], dim=2)  # (B, max_where_col_nums, 3H)
        max_where_col_nums = vec.size(1)
        # predict each where_col
        total_scores = []
        for i in range(max_where_col_nums):
            g_wv_tkns_i = torch.LongTensor([g_wv_tkns[b_idx][i] for b_idx in range(batch_size)]) if g_wv_tkns is not None else None  # (B, T_d_i)
            vec_i = vec[:, i, :]  # (B, 3H)
            
            h_0 = self.output_lstm_hidden_init_linear(vec_i).unsqueeze(1).transpose(0, 1).contiguous()  # (B, 3H) -> (B, bert_H) -> (1, B, bert_H)
            c_0 = self.output_lstm_cell_init_linear(vec_i).unsqueeze(1).transpose(0, 1).contiguous()  # (B, 3H) -> (B, bert_H) -> (1, B, bert_H)
            
            scores = self.decode_single_where_col(batch_size, h_0, c_0, value_tkn_max_len=value_tkn_max_len, g_wv_tkns_i=g_wv_tkns_i)  # (B, T_d_i, vocab_size)
            total_scores.append(scores)
        
        # total_scores: [(B, T_d_i, vocab_size)] x max_where_col_nums
        return total_scores
    
    def start_token(self, batch_size):
        sos = torch.LongTensor([self.start_tkn_id]*batch_size).unsqueeze(1)  # (B, 1)
        return sos
    
    def decode_single_where_col(self, batch_size, h_0, c_0, value_tkn_max_len=None, g_wv_tkns_i=None):
        if value_tkn_max_len is None:
            # [Training] set the max length to gold token max length (already padded)
            max_len = len(g_wv_tkns_i[0])
        else:
            # [Testing]  don't know the max length
            max_len = value_tkn_max_len
            
        # Version2: left_batch_size = batch_size
        
        sos = self.start_token(batch_size)  # (B, 1)
        emb = self.embedding_layer(sos)  # (B, 1, bert_H)
        scores = [] 
        for i in range(max_len):
            o, (h, c) = self.output_lstm(emb, (h_0, c_0))  # h: (1, B, bert_H)  
            s = self.output_linear(h[-1, :]) # select last layer if use multiple rnn layers, h: (1, B, bert_H) -> (B, bert_H) -> s: (B, vocab_size)
            
            if g_wv_tkns_i is not None:
                # [Training] Teacher Force model
                pred = g_wv_tkns_i[:, i]  # (B, )
                scores.append(s)
            else:
                # [Testing]
                pred = s.argmax(1)  # (B, )
                if (pred == self.end_tkn_id).sum() == batch_size:  # all stop
                    break
                else:
                    scores.append(s)
                
                # Version2: Seperate all tokens
                # if (pred == dd.end_tkn_id).sum() == left_batch_size:  # all stop
                #     scores.append(s)
                #     break
                # else:
                #     stop_mask = pred == dd.end_tkn_id
                #     pred = pred[~stop_mask]
                #     scores.append(pred)
                #     left_batch_size -= stop_mask.sum().item()
                    
            emb = self.embedding_layer(pred.unsqueeze(1))  # (B, 1, bert_H)
        
        return torch.stack(scores).transpose(0, 1).contiguous() # (T_d_i, B, vocab_size) -> (B, T_d_i, vocab_size)
        
    def get_context_padded(self, col_context: torch.Tensor, where_nums: List[int], where_col_idxes: List[List[int]]):
        r"""
        Select the where column index and pad if some batch doesn't match the max length of tensor
        In case for have different where column lengths
        """
        batch_size, n_col, hidden_size = col_context.size()
        max_where_col_nums = max(where_nums)
        batches = [col_context[i, batch_col] for i, batch_col in enumerate(where_col_idxes)]  # [(where_col_nums, hidden_size), ...]  len = B
        batches_padded = []
        for b in batches:
            where_col_nums = b.size(0)
            if where_col_nums < max_where_col_nums:
                b_padded = torch.cat([b, torch.zeros((max_where_col_nums-where_col_nums), hidden_size)], dim=0)
            else:
                b_padded = b
            batches_padded.append(b_padded)  # (max_where_col_nums, hidden_size)
            
        return torch.stack(batches_padded) # (B, max_where_col_nums, hidden_size)
    
    def get_where_op_one_hot_padded(self, where_op_idxes: List[List[int]], where_nums: List[int], where_col_idxes: List[List[int]], n_cond_ops: int):
        r"""
        Turn where operation indexs into one hot encoded vectors
        In case for have different where column lengths
        """
        max_where_col_nums = max(where_nums)
        batches = [torch.zeros(where_num, n_cond_ops).scatter(1, torch.LongTensor(batch_col).unsqueeze(1), 1) for where_num, batch_col in zip(where_nums, where_op_idxes)]  
        # batches = [(where_col_nums, n_cond_ops), ...]  len = B
        batches_padded = []
        for b in batches:
            where_col_nums = b.size(0)
            if where_col_nums < max_where_col_nums:
                b_padded = torch.cat([b, torch.zeros((max_where_col_nums-where_col_nums), n_cond_ops)], dim=0)
            else:
                b_padded = b
            batches_padded.append(b_padded)  # (max_where_col_nums, hidden_size)
        return torch.stack(batches_padded) # (B, max_where_col_nums, hidden_size)

## Decoder Module

### Variables

In [34]:
input_size = config_bert.hidden_size
hidden_size = 100
num_layers = 2
dropout_ratio = 0.3
max_where_conds = 4
n_agg_ops = len(dbengine.agg_ops)
n_cond_ops = len(dbengine.cond_ops)
start_tkn_id = tokenizer_bert.additional_special_tokens_ids[0]
end_tkn_id = tokenizer_bert.additional_special_tokens_ids[1]
embedding_layer = model_bert.embeddings.word_embeddings
train = True
if train:
    value_tkn_max_len = None
else:
    value_tkn_max_len = 20

In [9]:
class Decoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout_ratio, max_where_conds, n_agg_ops, n_cond_ops, start_tkn_id, end_tkn_id, embedding_layer):
        super().__init__()
        self.max_where_conds = max_where_conds
        
        self.select_decoder = SelectDecoder(
            input_size, hidden_size, output_size=1, num_layers=num_layers, dropout_ratio=dropout_ratio
        )
        self.agg_decoder = AggDecoder(
            input_size, hidden_size, output_size=n_agg_ops, num_layers=num_layers, dropout_ratio=dropout_ratio
        )
        self.where_num_decoder = WhereNumDecoder(
            input_size, hidden_size, output_size=(max_where_conds+1), num_layers=num_layers, dropout_ratio=dropout_ratio
        )
        self.where_col_decoder = WhereColumnDecoder(
            input_size, hidden_size, output_size=1, num_layers=num_layers, dropout_ratio=dropout_ratio, max_where_conds=max_where_conds
        )
        self.where_op_decoder = WhereOpDecoder(
            input_size, hidden_size, output_size=n_cond_ops, num_layers=num_layers, dropout_ratio=dropout_ratio, max_where_conds=max_where_conds
        )
        self.where_value_decoder = WhereValueDecoder(
            input_size, hidden_size, output_size=n_cond_ops, num_layers=num_layers, dropout_ratio=dropout_ratio, max_where_conds=max_where_conds, 
            n_cond_ops=n_cond_ops, start_tkn_id=start_tkn_id, end_tkn_id=end_tkn_id, embedding_layer=embedding_layer
        )
    
    
    def forward(self, question_padded, header_padded, col_padded, question_lengths, col_lengths, value_tkn_max_len=None, gold=None):
        """
        # Outputs Size
        # sc = (B, T_c)
        # sa = (B, n_agg_ops)
        # wn = (B, max_where_conds+1)
        # wc = (B, T_c): binary
        # wo = (B, max_where_col_nums, n_cond_ops)
        # wv = [(B, T_d_i, vocab_size)] x max_where_col_nums / T_d_i = may have different length for answer
        """
        if gold is None:
            g_sc, g_sa, g_wn, g_wc, g_wo, g_wv_tkns = [None] * 6
        else:
            g_sc, g_sa, g_wn, g_wc, g_wo, g_wv_tkns = gold
        decoder_outputs = {}

        select_outputs, _ = self.select_decoder(question_padded, header_padded, col_padded, question_lengths, col_lengths)
        select_idxes = g_sc if g_sc else self.predict_decoder("sc", select_outputs=select_outputs)

        agg_outputs, _ = self.agg_decoder(question_padded, header_padded, col_padded, question_lengths, col_lengths, select_idxes)

        where_num_outputs, _  = self.where_num_decoder(question_padded, header_padded, col_padded, question_lengths, col_lengths)
        where_nums = g_wn if g_wn else self.predict_decoder("wn", where_num_outputs=where_num_outputs)

        where_col_outputs, _ = self.where_col_decoder(question_padded, header_padded, col_padded, question_lengths, col_lengths)
        where_col_idxes = g_wc if g_wc else self.predict_decoder("wc", where_col_outputs=where_col_outputs, where_nums=where_nums)

        where_op_outputs = self.where_op_decoder(question_padded, header_padded, col_padded, question_lengths, where_nums, where_col_idxes)
        where_op_idxes = g_wo if g_wo else self.predict_decoder("wo", where_op_outputs=where_op_outputs, where_nums=where_nums)

        where_value_outputs = self.where_value_decoder(question_padded, header_padded, col_padded, question_lengths, where_nums, where_col_idxes, where_op_idxes, value_tkn_max_len, g_wv_tkns)

        decoder_outputs = {
            "sc": select_outputs,  # cross entropy
            "sa": agg_outputs,  # cross entropy
            "wn": where_num_outputs,  # cross entropy
            "wc": where_col_outputs,  # binary cross entropy
            "wo": where_op_outputs,  # cross entropy
            "wv": where_value_outputs  # cross entropy
        }
        
        return decoder_outputs
        
    def predict_decoder(self, typ, **kwargs):
        r"""
        if not using teacher force model will use this function to predict answer
        # Outputs Size
        # sc = (B, T_c)
        # sa = (B, n_agg_ops)
        # wn = (B, max_where_conds+1)
        # wc = (B, T_c): binary
        # wo = (B, max_where_col_nums, n_cond_ops)
        # wv = [(B, T_d_i, vocab_size)] x max_where_col_nums / T_d_i = may have different length for answer
        """
        if typ == "sc":  # SELECT column
            select_outputs = kwargs["select_outputs"]
            return select_outputs.argmax(1).tolist()
        elif typ == "sa":  # SELECT aggregation operator
            agg_outputs = kwargs["agg_outputs"]
            return agg_outputs.argmax(1).tolist()
        elif typ == "wn":  # WHERE number
            where_num_outputs = kwargs["where_num_outputs"]
            return where_num_outputs.argmax(1).tolist()
        elif typ == "wc":  # WHERE clause column
            where_col_outputs = kwargs["where_col_outputs"]
            where_col_argsort = torch.sigmoid(where_col_outputs).argsort(1)
            where_nums = kwargs["where_nums"]
            where_col_idxes = [where_col_argsort[b_idx, :w_num].tolist() for b_idx, w_num in enumerate(where_nums)]
            return where_col_idxes
        elif typ == "wo":  # WHERE clause operator
            where_op_outputs = kwargs["where_op_outputs"]
            where_nums = kwargs["where_nums"]
            where_op_idxes = []
            for b_idx, w_num in enumerate(where_nums):
                if w_num == 0:  # means no where number
                    where_op_idxes.append([])
                else:
                    where_op_idxes.append(where_op_outputs.argmax(2)[b_idx, :w_num].tolist())
            return where_op_idxes
        elif typ == "wv":  # WHERE clause value
            where_value_outputs = kwargs["where_value_outputs"]
            return [o.argmax(2).tolist() for o in where_value_outputs]  # iter with each where clause
        else:
            raise KeyError("`typ` must be in ['sc', 'sa', 'wn', 'wc', 'wo', 'wv']")

In [295]:
# predicts = {}
# predicts["sc"] = model_decoder.predict_decoder("sc", select_outputs=outputs["sc"])
# predicts["sa"] = model_decoder.predict_decoder("sa", agg_outputs=outputs["sa"])
# predicts["wn"] = model_decoder.predict_decoder("wn", where_num_outputs=outputs["wn"])
# predicts["wc"] = model_decoder.predict_decoder("wc", where_col_outputs=outputs["wc"], where_nums=predicts["wn"])
# predicts["wo"] = model_decoder.predict_decoder("wo", where_op_outputs=outputs["wo"], where_nums=predicts["wn"])
# predicts["wv"] = model_decoder.predict_decoder("wv", where_value_outputs=outputs["wv"])

In [184]:
model_decoder = Decoder(input_size, hidden_size, num_layers, dropout_ratio, max_where_conds, n_agg_ops, n_cond_ops, start_tkn_id, end_tkn_id, embedding_layer)

TRAIN

In [228]:
g_sc, g_sa, g_wn, g_wc, g_wo, g_wv, g_wv_tkns = get_sql_answers(batch_sqls, tokenizer_bert)
gold = g_sc, g_sa, g_wn, g_wc, g_wo, g_wv_tkns
decoder_outputs = model_decoder(question_padded, header_padded, col_padded, question_lengths, col_lengths, value_tkn_max_len=None, gold=gold)
decoder_outputs

{'sc': tensor([[-0.0512, -0.0510, -0.0482, -0.0414, -0.0402, -0.0447, -0.0397, -0.0437,
          -0.0454, -0.0486, -0.0494, -0.0447, -0.0496, -0.0523, -0.0460, -0.0452,
          -0.0435, -0.0446, -0.0462, -0.0415],
         [-0.0356, -0.0358, -0.0307, -0.0365, -0.0369, -0.0346, -0.0382, -0.0346,
          -0.0403, -0.0408, -0.0399, -0.0368, -0.0418, -0.0408, -0.0395, -0.0387,
          -0.0391, -0.0409, -0.0380, -0.0352]], grad_fn=<SqueezeBackward1>),
 'sa': tensor([[-0.0188,  0.1266,  0.0655, -0.0132, -0.0502,  0.0563],
         [-0.0298,  0.1227,  0.0649, -0.0285, -0.0581,  0.0507]],
        grad_fn=<AddmmBackward>),
 'wn': tensor([[-0.0960, -0.1163, -0.0221,  0.0192, -0.0908],
         [-0.1013, -0.1192, -0.0132,  0.0161, -0.0888]],
        grad_fn=<AddmmBackward>),
 'wc': tensor([[0.0350, 0.0374, 0.0367, 0.0411, 0.0455, 0.0437, 0.0417, 0.0375, 0.0384,
          0.0391, 0.0335, 0.0313, 0.0306, 0.0383, 0.0422, 0.0378, 0.0342, 0.0409,
          0.0386, 0.0340],
         [0.0405, 0.0

TEST

In [242]:
decoder_outputs = model_decoder(question_padded, header_padded, col_padded, question_lengths, col_lengths, value_tkn_max_len=20, gold=None)
decoder_outputs

{'sc': tensor([[-0.0386, -0.0402, -0.0374, -0.0357, -0.0362, -0.0359, -0.0380, -0.0341,
          -0.0347, -0.0388, -0.0379, -0.0317, -0.0339, -0.0398, -0.0451, -0.0404,
          -0.0395, -0.0334, -0.0362, -0.0391],
         [-0.0357, -0.0366, -0.0337, -0.0350, -0.0371, -0.0358, -0.0382, -0.0352,
          -0.0339, -0.0367, -0.0299, -0.0351, -0.0377, -0.0393, -0.0331, -0.0350,
          -0.0384, -0.0396, -0.0394, -0.0433]], grad_fn=<SqueezeBackward1>),
 'sa': tensor([[-0.0166,  0.1289,  0.0615, -0.0191, -0.0575,  0.0516],
         [-0.0289,  0.1171,  0.0701, -0.0318, -0.0553,  0.0462]],
        grad_fn=<AddmmBackward>),
 'wn': tensor([[-0.0913, -0.1125, -0.0182,  0.0102, -0.0842],
         [-0.1004, -0.1137, -0.0206,  0.0189, -0.0867]],
        grad_fn=<AddmmBackward>),
 'wc': tensor([[0.0510, 0.0406, 0.0425, 0.0470, 0.0503, 0.0506, 0.0495, 0.0474, 0.0475,
          0.0401, 0.0496, 0.0439, 0.0417, 0.0420, 0.0513, 0.0436, 0.0520, 0.0519,
          0.0538, 0.0535],
         [0.0381, 0.0

---

# Loss

TRAIN for example

In [241]:
g_sc, g_sa, g_wn, g_wc, g_wo, g_wv, g_wv_tkns = \
    ([13, 16],
     [0, 0],
     [3, 1],
     [[10, 3, 3], [10]],
     [[0, 0, 0], [0]],
     [['유동부채', 2018, 2018], ['유동자산']],
     [([3574, 5872, 6398, 7405, 8003], [554, 115, 8003], [554, 115, 8003]),
      ([3574, 5872, 7162, 8003, 1], [1, 1, 1], [1, 1, 1])])

In [315]:
cross_entropy = nn.CrossEntropyLoss(reduction="sum")
binary_cross_entropy = nn.BCEWithLogitsLoss(reduction="sum")
vocab_size = len(tokenizer_bert)
wn_penalty = 2.0  # scale up for guessing where number
wv_penalty = 5.0
# Outputs Size
# sc = (B, T_c)
# sa = (B, n_agg_ops)
# wn = (B, 5)
# wc = (B, T_c): binary
# wo = (B, max_where_col_nums, n_cond_ops)
# wv = [(B, T_d_i, vocab_size)] x max_where_col_nums / T_d_i = may have different length for answer
batch_size = decoder_outputs["sc"].size(0)
loss_sc = cross_entropy(decoder_outputs["sc"], torch.LongTensor(g_sc))
loss_sa = cross_entropy(decoder_outputs["sa"], torch.LongTensor(g_sa))
loss_wn = cross_entropy(decoder_outputs["wn"], torch.LongTensor(g_wn)) * wn_penalty

# need consider: might have different length of where numers
# So when calculate scores looping by where numbers, ignore the out of length tokens
loss_wc = 0
loss_wo = 0
loss_wv = 0
for batch_idx, where_num in enumerate(g_wn):
    one_hot_dist = torch.zeros_like(decoder_outputs["wc"][batch_idx]).scatter(0, torch.LongTensor(g_wc[batch_idx]), 1.0)
    loss_wc += binary_cross_entropy(decoder_outputs["wc"][batch_idx], one_hot_dist)
    
    batch_g_wo = g_wo[batch_idx]  # (where_num,)
    batch_wo = decoder_outputs["wo"][batch_idx, :where_num, :]  # (where_num, n_cond_ops)
    loss_wo += cross_entropy(batch_wo, torch.LongTensor(batch_g_wo))
    
    batch_g_wv = g_wv_tkns[batch_idx][:where_num]  # (where_num, T_d_i)
    batch_wv = torch.stack([wv[batch_idx] for wv in decoder_outputs["wv"]])  # (where_num, value_tkn_max_len, vocab_size)
    for wv, g_wv_i in zip(batch_wv_temp, batch_g_wv_temp):  # will by where_num
        if wv.size(0) > len(g_wv_i):
            wv_penalty = 1.0
            wv = wv[:len(g_wv_i), :]  # (T_d_gold, vocab_size)
        elif wv.size(0) < len(g_wv_i):
            # giving penalty if not generate enought tokens
            wv_penalty = 5.0
            g_wv_i = g_wv_i[:len(wv)]  # (T_d_predict,)
        else:
            wv_penalty = 1.0
        # now have the same T_d size, ignore all over lengthed
        loss_wv += cross_entropy(wv, torch.LongTensor(g_wv_i)) * wv_penalty
    
total_loss = (loss_sc + loss_sa + loss_wn + loss_wc + loss_wo + loss_wv) / batch_size
total_loss

tensor(124.2673, grad_fn=<DivBackward0>)

# Whole Model

In [19]:
class Text2SQL(nn.Module):
    def __init__(
        self,
        model_bert, 
        tokenizer_bert,
        special_end_tkn,
        input_size, 
        hidden_size, 
        num_layers, 
        dropout_ratio, 
        max_where_conds, 
        n_agg_ops, 
        n_cond_ops,
        wn_penalty,
        wo_penalty,
        wv_penalty
    ) -> None:
        super().__init__()
        # Encoder
        self.model_bert = model_bert
        self.tokenizer_bert = tokenizer_bert
        self.special_end_tkn = special_end_tkn  # str [E]
        # Decoder
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout_ratio = dropout_ratio
        self.max_where_conds = max_where_conds
        self.n_agg_ops = n_agg_ops
        self.n_cond_ops = n_cond_ops
        
        self.model_decoder = Decoder(
            input_size, 
            hidden_size, 
            num_layers, 
            dropout_ratio,
            max_where_conds, 
            n_agg_ops, 
            n_cond_ops, 
            start_tkn_id = tokenizer_bert.additional_special_tokens_ids[0],
            end_tkn_id = tokenizer_bert.additional_special_tokens_ids[1],
            embedding_layer = model_bert.embeddings.word_embeddings
        )
        # Loss
        self.cross_entropy = nn.CrossEntropyLoss(reduction="sum")
        self.binary_cross_entropy = nn.BCEWithLogitsLoss(reduction="sum")
        self.vocab_size = len(self.tokenizer_bert)
        self.wn_penalty = wn_penalty  # scale up for guessing where number
        self.wo_penalty = wo_penalty
        self.wv_penalty = wv_penalty  # giving penalty if not generate enough tokens
        
    def forward_outputs(self, batch_qs, batch_ts, batch_sqls=None, value_tkn_max_len=None, train=True):
        # --- Get Answer & Variables ---
        if train:
            assert value_tkn_max_len is None, "In train phase, `value_tkn_max_len` must be None"
            assert batch_sqls is not None, "In train phase, `batch_sqls` must not be None"
            g_sc, g_sa, g_wn, g_wc, g_wo, _, g_wv_tkns = self.get_sql_answers(batch_sqls, self.tokenizer_bert)
            gold = [g_sc, g_sa, g_wn, g_wc, g_wo, g_wv_tkns]
        else:
            assert value_tkn_max_len is not None, "In validation Phase, `value_tkn_max_len` must not be None"
            gold = None
            value_tkn_max_len = value_tkn_max_len
            
        # --- Get Inputs for Encoder --- 
        encode_inputs = tokenizer_bert(
            batch_qs, batch_ts, 
            max_length=512, padding=True, truncation=True, return_tensors="pt", 
            return_attention_mask=True, 
            return_special_tokens_mask=False, 
        )
        
        # --- Forward Encoder ---
        encode_outputs = self.model_bert(**encode_inputs)
        
        # --- Get Inputs for Decoder ---
        input_question_mask, input_table_mask, input_header_mask, input_col_mask = self.get_input_mask_and_answer(encode_inputs, self.tokenizer_bert)
        question_padded, question_lengths = self.get_decoder_batches(encode_outputs, input_question_mask, pad_idx=self.tokenizer_bert.pad_token_id)
        # table_padded, table_lengths = self.get_decoder_batches(encode_outputs, input_table_mask, pad_idx=self.tokenizer_bert.pad_token_id)  # Not used yet
        header_padded, header_lengths = self.get_decoder_batches(encode_outputs, input_header_mask, pad_idx=self.tokenizer_bert.pad_token_id)
        col_padded, col_lengths = self.get_decoder_batches(encode_outputs, input_col_mask, pad_idx=self.tokenizer_bert.pad_token_id)
        
        # --- Forward Decoder ---
        decoder_outputs = self.model_decoder(question_padded, header_padded, col_padded, question_lengths, col_lengths, value_tkn_max_len, gold)
        
        return decoder_outputs
    
    def predict(self, batch_qs, batch_ts, value_tkn_max_len):
        outputs = self.forward_outputs(batch_qs, batch_ts, batch_sqls=None, value_tkn_max_len=value_tkn_max_len, train=False)
        
        predicts = {}
        predicts["sc"] = self.model_decoder.predict_decoder("sc", select_outputs=outputs["sc"])
        predicts["sa"] = self.model_decoder.predict_decoder("sa", agg_outputs=outputs["sa"])
        predicts["wn"] = self.model_decoder.predict_decoder("wn", where_num_outputs=outputs["wn"])
        predicts["wc"] = self.model_decoder.predict_decoder("wc", where_col_outputs=outputs["wc"], where_nums=predicts["wn"])
        predicts["wo"] = self.model_decoder.predict_decoder("wo", where_op_outputs=outputs["wo"], where_nums=predicts["wn"])
        predicts["wv_tkns"] = self.model_decoder.predict_decoder("wv", where_value_outputs=outputs["wv"])  # (B, value_tkn_max_len) x where_nums
        # internally wv means wv_tkns, will convert to string here using tokenizer
        predicts["wv"] = []
        for where_idx, wv_tkns in enumerate(predicts["wv_tkns"]): # iter: (B, value_tkn_max_len)
            predicts["wv"].append([self.tokenizer_bert.decode(torch.LongTensor(batch_wv)) for batch_wv in wv_tkns])
                
        predicts["wv"] = list(zip(*predicts["wv"]))
        
        return predicts
    
    def forward(self, batch_qs, batch_ts, batch_sqls=None, value_tkn_max_len=None, train=True):
        outputs = self.forward_outputs(batch_qs, batch_ts, batch_sqls, value_tkn_max_len, train)
        g_sc, g_sa, g_wn, g_wc, g_wo, _, g_wv_tkns = self.get_sql_answers(batch_sqls, self.tokenizer_bert)
        gold = [g_sc, g_sa, g_wn, g_wc, g_wo, g_wv_tkns]
        loss = self.calculate_loss(outputs, gold)  # when calculate loss must need gold answer
        return loss, outputs
    
    def calculate_loss(self, decoder_outputs, gold):
        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv_tkns = gold
        g_wv_tkns = list(zip(*g_wv_tkns))  # (B, )
        # Outputs Size
        # sc = (B, T_c)
        # sa = (B, n_agg_ops)
        # wn = (B, 5)
        # wc = (B, T_c): binary
        # wo = (B, max_where_col_nums, n_cond_ops)
        # wv = [(B, T_d_i, vocab_size)] x max_where_col_nums / T_d_i = may have different length for answer
        batch_size = decoder_outputs["sc"].size(0)
        loss_sc = self.cross_entropy(decoder_outputs["sc"], torch.LongTensor(g_sc))
        loss_sa = self.cross_entropy(decoder_outputs["sa"], torch.LongTensor(g_sa))
        loss_wn = self.cross_entropy(decoder_outputs["wn"], torch.LongTensor(g_wn)) * self.wn_penalty

        # need consider: might have different length of where numers
        # So when calculate scores looping by where numbers, ignore the out of length tokens
        loss_wc = 0
        loss_wo = 0
        loss_wv = 0
        for batch_idx, where_num in enumerate(g_wn):

            one_hot_dist = torch.zeros_like(decoder_outputs["wc"][batch_idx]).scatter(0, torch.LongTensor(g_wc[batch_idx]), 1.0)
            loss_wc += self.binary_cross_entropy(decoder_outputs["wc"][batch_idx], one_hot_dist)

            batch_g_wo = g_wo[batch_idx]  # (where_num_gold,)
            batch_wo = decoder_outputs["wo"][batch_idx, :where_num, :]  # (where_num_predict, n_cond_ops)
            if (len(batch_wo) == 0 and where_num != 0):
                # if predict nothing where clause and answer is not, what loss should be added?
                # simply giving big loss will be enough?
                loss_wo += loss_wn * 100
            else:
                give_wo_penalty = False
                if len(batch_wo) > len(batch_g_wo): 
                    wo_penalty = self.wo_penalty / 2
                    give_wo_penalty = True
                    batch_wo = batch_wo[:len(batch_g_wo), :]  # (where_num_predict, n_cond_ops)
                elif len(batch_wo) < len(batch_g_wo):
                    # giving penalty if not guessed right where numbers
                    # It becomes problem when reduce the gold tokens but predicted corrected 
                    # Then `loss_wo_base` will be 0, if simply multiply by `loss_wv_base` to loss_base will be zero
                    wo_penalty = self.wo_penalty
                    give_wo_penalty = True
                    batch_g_wo = batch_g_wo[:len(batch_wo)]  # (where_num_gold,)
                else:
                    wo_penalty = 1.0
                    give_wo_penalty = False
                loss_wo_base = self.cross_entropy(batch_wo, torch.LongTensor(batch_g_wo))
                if give_wo_penalty:
                    loss_wo += loss_wo_base + loss_wn * wo_penalty
                else:
                    loss_wo += loss_wo_base

            batch_g_wv = g_wv_tkns[batch_idx][:where_num]  # (gold_where_num, T_d_i)
            batch_wv = [wv[batch_idx] for wv in decoder_outputs["wv"]]  # (predict_where_num, T_d_i, vocab_size)
            if len(batch_wo) == 0 and where_num != 0:
                # if predict nothing where clause and answer is not, what loss should be added?
                loss_wv += loss_wn * 100
            else:
                for wv, g_wv_i in zip(batch_wv, batch_g_wv):  # will iter by where_num
                    give_wv_penalty = False
                    if len(wv) > len(g_wv_i):
                        wv_penalty = self.wv_penalty / 2
                        give_wv_penalty = True
                        wv = wv[:len(g_wv_i), :]  # (T_d_gold, vocab_size)
                    elif len(wv) < len(g_wv_i):
                        # giving penalty if not generate enough tokens
                        # It becomes problem when reduce the gold tokens but predicted corrected 
                        # Then `loss_wv_base` will be 0, if simply multiply by `loss_wv_base` to loss_base will be zero
                        wv_penalty = self.wv_penalty
                        give_wv_penalty = True
                        g_wv_i = g_wv_i[:len(wv)]  # (T_d_predict,)
                    else:
                        wv_penalty = 1.0
                        give_wv_penalty = False
                    # now have the same T_d size, ignore all over lengthed
                    loss_wv_base = self.cross_entropy(wv, torch.LongTensor(g_wv_i))
                    if give_wv_penalty:
                        loss_wv += loss_wv_base + loss_wn * wv_penalty
                    else:
                        loss_wv += loss_wv_base
        loss = (loss_sc + loss_sa + loss_wn + loss_wc + loss_wo + loss_wv) / batch_size
        return loss
    
    def get_sql_answers(self, batch_sqls: List[Dict[str, Any]], tokenizer: KoBertTokenizer):
        """[summary]
        sc: select column
        sa: select agg
        wn: where number
        wc: where column
        wo: where operator
        wv: where value

        Args:
            batch_sqls (List[Dict[str, Any]]): [description]
            tokenizer (KoBertTokenizer): [description]

        Raises:
            EnvironmentError: [description]

        Returns:
            [type]: [description]
        """
        get_ith_element = lambda li, i: [x[i] for x in li]
        g_sc = []
        g_sa = []
        g_wn = []
        g_wc = []
        g_wo = []
        g_wv = []
        for b, sql_dict in enumerate(batch_sqls):
            g_sc.append( sql_dict["sel"] )
            g_sa.append( sql_dict["agg"])

            conds = sql_dict["conds"]
            if not sql_dict["agg"] < 0:
                g_wn.append( len(conds) )
                g_wc.append( get_ith_element(conds, 0) )
                g_wo.append( get_ith_element(conds, 1) )
                g_wv.append( get_ith_element(conds, 2) )
            else:
                raise EnvironmentError

        # get where value tokenized 
        pad_tkn_id = tokenizer.pad_token_id
        g_wv_tkns = [[f"{s}{self.special_end_tkn}" for s in batch_wv] for batch_wv in g_wv]
        g_wv_tkns = [tokenizer(batch_wv, add_special_tokens=False)["input_ids"] for batch_wv in g_wv_tkns]
        # add empty list if batch has different where column number
        max_where_cols = max([len(batch_wv) for batch_wv in g_wv_tkns])
        g_wv_tkns = [batch_wv + [[]]*(max_where_cols-len(batch_wv)) if len(batch_wv) < max_where_cols else batch_wv for batch_wv in g_wv_tkns]
        temp = []
        for batch_wv in list(zip(*g_wv_tkns)):
            batch_max_len = max(map(len, batch_wv))
            batch_temp = []
            for wv_tkns in batch_wv:  # iter by number of where clause
                if len(wv_tkns) < batch_max_len:
                    batch_temp.append(wv_tkns + [pad_tkn_id]*(batch_max_len - len(wv_tkns)))
                else:
                    batch_temp.append(wv_tkns)
            temp.append(batch_temp)
        g_wv_tkns = list(zip(*temp))

        return g_sc, g_sa, g_wn, g_wc, g_wo, g_wv, g_wv_tkns
    
    
    ## Masks
    # TODO: [EXP] Experiment for generate column directly
    # def get_answer(input_ids, mask, batch_size, start_tkn_id, end_tkn_id):
    #     r"""
    #     answer should include end token: [E]
    #     """
    #     masked_input_ids = input_ids[mask]
    #     start_tkn_mask = masked_input_ids == start_tkn_id
    #     end_tkn_mask = masked_input_ids == end_tkn_id
    #     table_col_length = masked_input_ids.view(batch_size, -1).size(1)
    #     start_end_mask = torch.bitwise_or(start_tkn_mask, end_tkn_mask)
    #     index = torch.arange(table_col_length).repeat(batch_size)[start_end_mask].view(batch_size, -1, 2)
    #     tkn_lengths = index[:, :, 1] - index[:, :, 0]
    #     answer_col_tkns = [x.split(tkn_length.tolist()) for x, tkn_length in zip(
    #         masked_input_ids[~start_tkn_mask].view(batch_size, -1), tkn_lengths)]
    #     return answer_col_tkns


    def get_decoder_input_mask(self, input_ids: torch.Tensor, mask: torch.BoolTensor, batch_size: int, start_tkn_id: int, end_tkn_id: int) -> torch.BoolTensor:
        """[summary]

        Args:
            input_ids (torch.Tensor): [description]
            mask (torch.BoolTensor): [description]
            batch_size (int): [description]
            start_tkn_id (int): [description]
            end_tkn_id (int): [description]

        Returns:
            torch.BoolTensor: [description]
        """    
        start_tkn_mask = input_ids == start_tkn_id
        end_tkn_mask = input_ids == end_tkn_id
        start_end_mask = torch.bitwise_or(start_tkn_mask, end_tkn_mask)
        index = torch.arange(input_ids.size(1)).repeat(batch_size)[start_end_mask.view(-1)].view(batch_size, -1)
        return mask.scatter(1, index, False)


    def get_input_mask_and_answer(self, encode_input: transformers.tokenization_utils_base.BatchEncoding, tokenizer: KoBertTokenizer) -> Tuple[torch.BoolTensor, torch.BoolTensor, torch.BoolTensor, torch.BoolTensor]:
        """[summary]

        In this code 'table' means database table name(id), 'header' means database header, 'col' means index of header 

        Args:
            encode_input (transformers.tokenization_utils_base.BatchEncoding): [description]
            tokenizer (KoBertTokenizer): [description]

        Returns:
            Tuple[torch.BoolTensor, torch.BoolTensor, torch.BoolTensor, torch.BoolTensor]: [description]
        """

        batch_size, max_length = encode_input["input_ids"].size()
        sep_tkn_mask = encode_input["input_ids"] == tokenizer.sep_token_id
        start_tkn_id, end_tkn_id, col_tkn_id = tokenizer.additional_special_tokens_ids

        input_question_mask = torch.bitwise_and(encode_input["token_type_ids"] == 0, encode_input["attention_mask"].bool())
        input_question_mask = torch.bitwise_and(input_question_mask, ~sep_tkn_mask) # [SEP] mask out
        input_question_mask[:, 0] = False  # [CLS] mask out

        db_mask = torch.bitwise_and(encode_input["token_type_ids"] == 1, encode_input["attention_mask"].bool())
        db_mask = torch.bitwise_xor(db_mask, sep_tkn_mask)
        col_tkn_mask = encode_input["input_ids"] == col_tkn_id
        db_mask = torch.bitwise_and(db_mask, ~col_tkn_mask)
        # split table_mask and header_mask
        input_idx = torch.arange(max_length).repeat(batch_size, 1)
        db_idx = input_idx[db_mask]
        table_header_tkn_idx = db_idx[db_idx > 0]
        table_start_idx = table_header_tkn_idx.view(batch_size, -1)[:, 0] + 1
        start_idx = table_header_tkn_idx[1:][table_header_tkn_idx.diff() == 2].view(batch_size, -1)
        table_end_sep_idx = start_idx[:, 0] - 1
        split_size = torch.stack([
            table_end_sep_idx-table_start_idx+1, table_header_tkn_idx.view(batch_size, -1).size(1)-(table_end_sep_idx-table_start_idx+1)
        ]).transpose(0, 1)

        # Token idx
        table_tkn_idx, header_tkn_idx = map(
            lambda x: torch.stack(x), 
            zip(*[torch.split(x, size.tolist()) for x, size in zip(table_header_tkn_idx.view(batch_size, -1), split_size)])
        )

        table_tkn_idx = table_tkn_idx[:, 1:]

        # TODO: [EXP] Experiment for generate column directly
        # If [EXP], `table_tkn_mask` and `header_tkn_mask` should include [S] & [E] tokens
        table_tkn_mask = torch.zeros_like(encode_input["input_ids"], dtype=torch.bool).scatter(1, table_tkn_idx, True)
        header_tkn_mask = torch.zeros_like(encode_input["input_ids"], dtype=torch.bool).scatter(1, header_tkn_idx, True)

        # TODO: [EXP] Experiment for generate column directly
        # For Decoder Input, Maskout [S], [E] for table & header -> will be done automatically
        input_table_mask = self.get_decoder_input_mask(
            encode_input["input_ids"], table_tkn_mask, batch_size, start_tkn_id, end_tkn_id
        )
        input_header_mask = self.get_decoder_input_mask(
            encode_input["input_ids"], header_tkn_mask, batch_size, start_tkn_id, end_tkn_id
        )

        # [COL] token mask: this is for attention
        col_tkn_idx = input_idx[col_tkn_mask].view(batch_size, -1)
        input_col_mask = torch.zeros_like(encode_input["input_ids"], dtype=torch.bool).scatter(1, col_tkn_idx, True)

        # TODO: [EXP] Experiment for generate column directly
        # For Answer, Maskout [S] for table & header 
        # answer_table_tkns = get_answer(
        #     encode_input["input_ids"], table_tkn_mask, batch_size, start_tkn_id, end_tkn_id
        # )
        # answer_header_tkns = get_answer(
        #     encode_input["input_ids"], header_tkn_mask, batch_size, start_tkn_id, end_tkn_id
        # )

        return input_question_mask, input_table_mask, input_header_mask, input_col_mask # , answer_table_tkns, answer_header_tkns


    ## Pad for decoder inputs
    def pad(self, batches: Tuple[torch.Tensor], lengths: List[int], pad_idx: int=1) -> torch.Tensor:
        """[summary]

        Args:
            batches (Tuple[torch.Tensor]): [description]
            lengths (List[int]): [description]
            model (transformers.models.bert.modeling_bert.BertModel): [description]
            pad_idx (int, optional): [description]. Defaults to 1.

        Returns:
            torch.Tensor: [description]
        """       
        padded = []
        max_length = max(lengths)
        for x in batches:
            if len(x) < max_length:
                pad_tensor = self.model_bert.embeddings.word_embeddings(torch.LongTensor([pad_idx]*(max_length - len(x))))
                padded.append(torch.cat([x, pad_tensor]))
            else:
                padded.append(x)
        return torch.stack(padded)

    def get_decoder_batches(self, encode_output: transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions, mask: torch.BoolTensor, pad_idx: int) -> Tuple[torch.Tensor, List[int]]:
        """[summary]

        Args:
            encode_output (transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions): [description]
            mask (torch.BoolTensor): [description]
            model (BertModel): [description]
            pad_idx (int): [description]

        Returns:
            Tuple[torch.Tensor, List[int]]: [description]
        """    
        lengths = mask.sum(1)
        tensors = encode_output.last_hidden_state[mask, :]
        batches = torch.split(tensors, lengths.tolist())
        if lengths.ne(lengths.max()).sum().item() != 0:
            # pad not same length tokens
            tensors_padded = self.pad(batches, lengths.tolist(), pad_idx=pad_idx)
        else:
            # just stack the splitted tensors
            tensors_padded = torch.stack(batches)
        return tensors_padded, lengths.tolist()


In [13]:
special_end_tkn = "[E]"
hidden_size = 100
num_layers = 2
dropout_ratio = 0.3
max_where_conds = 4
wn_penalty = 2.0
wo_penalty = 4.0
wv_penalty = 5.0
value_tkn_max_len = 20 

In [23]:
model = Text2SQL(
    model_bert=model_bert,
    tokenizer_bert=tokenizer_bert,
    special_end_tkn=special_end_tkn,
    input_size=config_bert.hidden_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout_ratio=dropout_ratio,
    max_where_conds=max_where_conds,
    n_agg_ops=len(dbengine.agg_ops),
    n_cond_ops=len(dbengine.cond_ops),
    wn_penalty=wn_penalty,
    wo_penalty=wo_penalty,
    wv_penalty=wv_penalty
)

In [24]:
loss, outputs = model(batch_qs, batch_ts, batch_sqls, value_tkn_max_len=None, train=True) # Train
loss

tensor(122.8362, grad_fn=<DivBackward0>)

In [26]:
for _ in range(10):
    loss, outputs = model(batch_qs, batch_ts, batch_sqls, value_tkn_max_len=20, train=False) # Test
    print(loss)
    # check if returns error

tensor(113.3654, grad_fn=<DivBackward0>)
tensor(113.2863, grad_fn=<DivBackward0>)
tensor(113.6325, grad_fn=<DivBackward0>)
tensor(113.3764, grad_fn=<DivBackward0>)
tensor(113.5626, grad_fn=<DivBackward0>)
tensor(113.2283, grad_fn=<DivBackward0>)
tensor(113.3978, grad_fn=<DivBackward0>)
tensor(113.2684, grad_fn=<DivBackward0>)
tensor(113.6896, grad_fn=<DivBackward0>)
tensor(113.5539, grad_fn=<DivBackward0>)


# TODO: Add accuracy

In [27]:
g_sc, g_sa, g_wn, g_wc, g_wo, g_wv, g_wv_tkns = model.get_sql_answers(batch_sqls, model.tokenizer_bert)
predicts = model.predict(batch_qs, batch_ts, value_tkn_max_len)

In [28]:
predicts

{'sc': [10, 11],
 'sa': [5, 5],
 'wn': [1, 1],
 'wc': [[5], [14]],
 'wo': [[3], [3]],
 'wv_tkns': [[[6248,
    6249,
    6977,
    5839,
    6248,
    6249,
    6977,
    5839,
    6248,
    6249,
    6977,
    5839,
    6248,
    6249,
    6977,
    5839,
    6248,
    6249,
    6977,
    5839],
   [2032,
    5467,
    5467,
    5467,
    5467,
    5467,
    5467,
    5467,
    5467,
    5467,
    5467,
    5467,
    5467,
    5467,
    5467,
    5467,
    5467,
    5467,
    5467,
    5467]]],
 'wv': [('뭔뭘옳더라뭔뭘옳더라뭔뭘옳더라뭔뭘옳더라뭔뭘옳더라',), ('면접곶곶곶곶곶곶곶곶곶곶곶곶곶곶곶곶곶곶곶',)]}

In [67]:
import os 
import pytorch_lightning as pl

In [None]:
class Model(pl.LightningModule):
    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        # Load DBEngine
        self.dbengine = DBEngine(Path(self.hparams.db_path))        
        self.model = self.create_model(dbengine=self.dbengine)
        

    def forward(self, **kwargs):
        return self.model(**kwargs)

    
    def training_step(self, batch, batch_idx):
        batch_qs, batch_ts, batch_sqls = self.get_batch_data(batch, self.table, self.hparams.special_start_tkn, self.hparams.special_end_tkn)
        loss, outputs = self(
            batch_qs=batch_qs, 
            batch_ts=batch_ts, 
            batch_sqls=batch_sqls, 
            value_tkn_max_len=None, 
            train=True
        )
        return  {'loss': loss}  
    
    def train_epoch_end(self, outputs):
        loss = torch.tensor(0, dtype=torch.float)
        for out in outputs:
            loss += out["loss"].detach().cpu()
        loss = loss / len(outputs)

        return {'loss': loss}
    
    def validation_step(self, batch, batch_idx, dataloader_idx):
        batch_qs, batch_ts, batch_sqls = self.get_batch_data(batch, self.table, self.hparams.special_start_tkn, self.hparams.special_end_tkn)
        loss, outputs = self(
            batch_qs=batch_qs, 
            batch_ts=batch_ts, 
            batch_sqls=batch_sqls, 
            value_tkn_max_len=self.hparams.value_tkn_max_len, 
            train=False
        )

        return  {'loss': loss}  
    
    def train_dataloader(self):
        return self.create_dataloader(mode="train")

    def val_dataloader(self):
        return self.create_dataloader(mode="eval")    
  
    
    def load_data(self, sql_path: Union[Path, str], table_path: Union[Path, str]) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """Load data from path

        Args:
            sql_path (Union[Path, str]): dataset path which contains NL with SQL queries (+answers)
            table_path (Union[Path, str]): table information contains table name, header and values

        Returns:
            Tuple[List[Dict[str, Any]], Dict[str, Any]]: [description]
        """    
        path_sql = Path(sql_path)
        path_table = Path(table_path)

        dataset = []
        table = {}
        with path_sql.open("r", encoding="utf-8") as f:
            for idx, line in enumerate(f):
                x = json.loads(line.strip())
                dataset.append(x)

        with path_table.open("r", encoding="utf-8") as f:
            for idx, line in enumerate(f):
                x = json.loads(line.strip())
                table[x['id']] = x

        return dataset, table
    
    def create_data_loader(self, mode):
        num_workers = 0 if os.name == "nt" else self.hparams.num_workers
        if mode == "train":
            shuffle = True
            batch_size = self.hparams.train_batch_size
            sql_file = self.hparams.train_sql_file
            table_file = self.hparams.train_table_file
        else:
            shuffle = False
            batch_size = self.hparams.eval_batch_size
            sql_file = self.hparams.eval_sql_file
            table_file = self.hparams.eval_table_file
        
        dataset, self.table = self.load_data(sql_file, table_file)
        
        data_loader = torch.utils.data.DataLoader(
            batch_size=batch_size,
            dataset=dataset,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=lambda x: x # now dictionary values are not merged!
        )
        return data_loader

    def get_batch_data(self, data: List[Dict[str, Any]], table: Dict[str, Dict[str, List[Any]]], start_tkn="[S]", end_tkn="[E]") -> Tuple[List[str], List[str], List[Dict[str, Any]]]:
        """[summary]

        Args:
            data (List[Dict[str, Any]]): [description]
            dbengine (DBEngine): [description]
            start_tkn (str, optional): [description]. Defaults to "[S]".
            end_tkn (str, optional): [description]. Defaults to "[E]".

        Returns:
            Tuple[List[str], List[str], List[Dict[str, Any]]]: [description]
        """    
        batch_qs = [jsonl["question"] for jsonl in data]
        tid = [jsonl["table_id"] for jsonl in data]
        batch_sqls = [jsonl["sql"] for jsonl in data]
        batch_ts = []
        for table_id in tid:
            table_str = f"{table_id}" + "".join([
                f"{self.hparams.special_col_tkn}{col}" for col in table[table_id]["header"]
            ])
            # TODO: [EXP] Experiment for generate column directly
            # table_str = f"{start_tkn}{table_id}{end_tkn}" + "".join([
            #     f"{col_tkn}{start_tkn}{col}{end_tkn}" for col in dbengine.schema
            # ]) 
            batch_ts.append(table_str)

        return batch_qs, batch_ts, batch_sqls
    
    def create_model(self, dbengine):
        model_bert, tokenizer_bert, config_bert = self.get_bert(model_path=self.hparams.model_bert_path)
        model = Text2SQL(
            model_bert=model_bert,
            tokenizer_bert=tokenizer_bert,
            special_end_tkn=self.hparams.special_end_tkn,
            input_size=config_bert.hidden_size,
            hidden_size=self.hparams.hidden_size,
            num_layers=self.hparams.num_layers,
            dropout_ratio=self.hparams.dropout_ratio,
            max_where_conds=self.hparams.max_where_conds,
            n_agg_ops=len(dbengine.agg_ops),
            n_cond_ops=len(dbengine.cond_ops),
            wn_penalty=self.hparams.wn_penalty,
            wo_penalty=self.hparams.wo_penalty,
            wv_penalty=self.hparams.wv_penalty
        )
        return model
    
    def get_bert(self, model_path: str, output_hidden_states: bool=False):
        self.special_tokens = [self.hparams.special_start_tkn, self.hparams.special_end_tkn, self.hparams.special_col_tkn] # sequence start, sequence end, column tokens
        tokenizer = KoBertTokenizer.from_pretrained(model_path, add_special_tokens=True, additional_special_tokens=self.special_tokens)
        config = BertConfig.from_pretrained(model_path)
        config.output_hidden_states = output_hidden_states

        model = BertModel.from_pretrained(model_path)
        model.resize_token_embeddings(len(tokenizer))
        model.config.output_hidden_states = output_hidden_states

        return model, tokenizer, config

In [None]:
args_dict = dict(
    db_path = "./private/samsung_new.db",
    model_bert_path = "monologg/kobert",
    # Dataloader
    train_sql_file = "./NLSQL_train.jsonl",
    train_table_file = "./table_train.jsonl",
    train_batch_size = 16,
    eval_file = "./NLSQL_test.jsonl",
    eval_table_file = "./table_test.jsonl",
    eval_batch_size = 16,
    num_workers = 4,
    # Model-decoder
    hidden_size = 100,
    num_layers = 2,
    dropout_ratio = 0.3,
    max_where_conds = 4,
    value_tkn_max_len = 20, 
    # Tokenizer
    special_start_tkn = "[S]", 
    special_end_tkn = "[E]",
    special_col_tkn = "[COL]",
    # Loss Function
    wn_penalty = 2.0,  # scale up for guessing where number
    wo_penalty = 4.0,
    wv_penalty = 5.0  # giving penalty if not generate enough tokens, if generate more than answer lenght will give 1/2 of it
)

In [189]:
args_dict = dict(
    db_path = "./data/samsung_new.db",
    model_bert_path = "monologg/kobert",
    # Dataloader
    train_sql_file = "./data/NLSQL_train.jsonl",
    train_table_file = "./data/table_train.jsonl",
    train_batch_size = 16,
    eval_file = "./data/NLSQL_test.jsonl",
    eval_table_file = "./data/table_test.jsonl",
    eval_batch_size = 16,
    num_workers = 4,
    # Model-decoder
    hidden_size = 100,
    num_layers = 2,
    dropout_ratio = 0.3,
    max_where_conds = 4,
)

## Traning

Stil Working on it

In [None]:
lr = 1e-3
lr_bert = 1e-5

opt = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                               lr=lr, weight_decay=0)
opt_bert = torch.optim.AdamW(filter(lambda p: p.requires_grad, model_bert.parameters()),
                            lr=lr_bert, weight_decay=0)

## Testing: Execution-guided beam decoding

Stil Working on it

In [37]:
beam_size = 4

select 

In [38]:
select_output, _ = select_decoder(question_padded, header_padded, col_padded, question_lengths)
select_output.size()

torch.Size([2, 20])

construct all possible select + (agg) score

In [39]:
batch_size, n_col = select_output.size()

select_prob = torch.softmax(select_output, 1)  # prob_sc
if n_col < beam_size:
    beam_size_max_col = n_col
else:
    beam_size_max_col = beam_size

prob_sc_sa = torch.zeros([batch_size, beam_size_max_col, n_agg_ops])
prob_sca = torch.zeros_like(prob_sc_sa)
print(prob_sca.size())  # (B, beam-size, n_agg_ops)

torch.Size([2, 4, 6])


In [42]:
# beamseacrh
_, pr_sc_beam = select_output.topk(k=beam_size_max_col)
print(f"sc top k: {pr_sc_beam.tolist()}")

for i_beam in range(beam_size_max_col):
    select_idx = pr_sc_beam[:, i_beam].tolist() # pr_sc
    agg_output, _ = agg_decoder(question_padded, col_padded, question_lengths, select_idx)
    agg_prob = torch.softmax(agg_output, dim=-1)  # prob_sa: (B, n_agg_ops)
    prob_sc_sa[:, i_beam, :] = agg_prob
    
    prob_sc_selected = select_prob[range(batch_size), select_idx]  # (B,)
    prob_sca[:, i_beam, :] = (agg_prob.t() * prob_sc_selected).t()  # (n_agg_ops, B) \odot (1, B) (broadcast) -> (B, max_col)

sc top k: [[19, 18, 0, 14], [3, 4, 6, 5]]


In [43]:
print(prob_sc_sa.data)

tensor([[[0.1765, 0.1692, 0.1639, 0.1756, 0.1588, 0.1561],
         [0.1777, 0.1687, 0.1647, 0.1742, 0.1588, 0.1558],
         [0.1764, 0.1690, 0.1643, 0.1751, 0.1593, 0.1558],
         [0.1779, 0.1693, 0.1647, 0.1732, 0.1581, 0.1568]],

        [[0.1778, 0.1697, 0.1638, 0.1724, 0.1561, 0.1601],
         [0.1778, 0.1712, 0.1638, 0.1742, 0.1554, 0.1577],
         [0.1789, 0.1702, 0.1634, 0.1741, 0.1557, 0.1578],
         [0.1774, 0.1712, 0.1640, 0.1729, 0.1564, 0.1581]]])


In [44]:
print(prob_sca.size())  # (B, beam_size, prob_sc(beam size selected) * prob_agg)
print(prob_sca.data)

torch.Size([2, 4, 6])
tensor([[[0.0088, 0.0085, 0.0082, 0.0088, 0.0079, 0.0078],
         [0.0089, 0.0084, 0.0082, 0.0087, 0.0079, 0.0078],
         [0.0088, 0.0084, 0.0082, 0.0088, 0.0080, 0.0078],
         [0.0089, 0.0085, 0.0082, 0.0087, 0.0079, 0.0078]],

        [[0.0089, 0.0085, 0.0082, 0.0086, 0.0078, 0.0080],
         [0.0089, 0.0086, 0.0082, 0.0087, 0.0078, 0.0079],
         [0.0089, 0.0085, 0.0082, 0.0087, 0.0078, 0.0079],
         [0.0089, 0.0086, 0.0082, 0.0086, 0.0078, 0.0079]]])


In [45]:
def topk_multi_dim(tensor, n_topk):
    batch_size = tensor.size(0)
    values_1d, idxes_1d = tensor.view(batch_size, -1).topk(n_topk)
    idxes = np.stack(np.unravel_index(idxes_1d, tensor.size()[1:])).transpose(1, 2, 0)
    values = tensor.view(batch_size, -1).gather(1, idxes_1d).numpy()
    return idxes, values

In [46]:
# First flatten to 1-d
if np.prod(prob_sca.shape[1:]) < beam_size:
    beam_size_sca = np.prod(prob_sca.shape[1:])
else:
    beam_size_sca = beam_size
# Now as sc_idx is already sorted, re-map them properly.
# idxes: [sc_beam_idx, sa_idx] -> sca_idxes: [sc_idx, sa_idx]
idxes, values = topk_multi_dim(prob_sca.detach().cpu(), n_topk=beam_size_sca)
sc_beam_idxes = idxes[:, :, 0]
sc_idxes = np.stack([pr_sc_beam.numpy()[i, sc_beam_idx] for i, sc_beam_idx in enumerate(sc_beam_idxes)])
sca_idxes = np.stack([sc_idxes, idxes[:, :, 1]]).transpose(1, 2, 0)

In [47]:
sca_idxes

array([[[14,  0],
        [18,  0],
        [19,  0],
        [ 0,  0]],

       [[ 6,  0],
        [ 3,  0],
        [ 4,  0],
        [ 5,  0]]], dtype=int64)

writing ...