# TEXT2SQL with transformers

Lee Woo Chul, Jang Ji Soo

---

In [2]:
import json
import torch
from pathlib import Path
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import transformers

from typing import Tuple, Dict, List, Union, Any
import os

from dbengine import DBEngine
# DataLoader workers must pickle the collate_fn on Windows (processes are spawned),
# and the lambda collate_fn used below cannot be pickled, so no workers there.
# https://discuss.pytorch.org/t/cant-pickle-local-object-dataloader-init-locals-lambda/31857/14
num_workers = 0 if os.name == "nt" else 4

print(f"PyTorch Version: {torch.__version__}")
print(f"Transformers Version: {transformers.__version__}")

PyTroch Version: 1.8.1
Transfomers Version: 4.6.1


# Data Description

`NLSQL.jsonl` and `table.jsonl` contain data in the same format as [WikiSQL](https://github.com/salesforce/WikiSQL). See the [WikiSQL content and format description](https://github.com/salesforce/WikiSQL#content-and-format) for the meaning of each key.

```json
// example of 'NLSQL.jsonl'
{
    "phase": 1, 
    "question": "2015 삼성전자 유동자산은 어떻게 돼?", 
    "table_id": "receipts", 
    "sql": {
        "sel": 16, 
        "agg": 0, 
        "conds": [[10, 0, "유동자산"], [3, 0, 2016]]
    }
}
```

In [4]:
def load_data(sql_path, table_path):
    r"""Load a WikiSQL-style dataset from two JSON-lines files.

    Args:
        sql_path: path to a JSON-lines file with one NL/SQL example per line
            (e.g. ``NLSQL.jsonl``).
        table_path: path to a JSON-lines file with one table record per line;
            every record must have an ``"id"`` key.

    Returns:
        (dataset, table): ``dataset`` is the list of example dicts in file order,
        ``table`` maps each table ``id`` to its record (a later record with the
        same id replaces an earlier one).

    Blank or whitespace-only lines (e.g. a trailing empty line) are skipped
    instead of making ``json.loads`` fail.
    """
    def _read_jsonl(path):
        # one JSON document per non-empty line
        records = []
        with Path(path).open("r", encoding="utf-8") as f:
            for line in f:
                stripped = line.strip()
                if stripped:
                    records.append(json.loads(stripped))
        return records

    dataset = _read_jsonl(sql_path)
    table = {record["id"]: record for record in _read_jsonl(table_path)}
    return dataset, table

In [6]:
data, table = load_data("NLSQL.jsonl", "table.jsonl")

# Identity collate function: every batch stays a plain list of example dicts
# (the default collate would merge the dictionary values).
data_loader = torch.utils.data.DataLoader(
    dataset=data,
    batch_size=2,
    shuffle=True,
    collate_fn=lambda examples: examples,
    num_workers=num_workers,
)

# Load DBEngine on the database file stored under ./private
db_path = Path("./private")
dbengine = DBEngine(db_path / "samsung_new.db")

In [58]:
# Take a single (shuffled) batch of toy data to walk through the pipeline below.
batch_data = next(iter(data_loader))

Test with toy data:   0%|          | 0/21120 [00:00<?, ?it/s]

# Model

## Encoder

Uses the Hugging Face `transformers` BERT model with the KoBERT weights and tokenizer:

- https://github.com/SKTBrain/KoBERT
- https://github.com/monologg/KoBERT-Transformers

In [10]:
from KoBertTokenizer import KoBertTokenizer
from transformers import BertModel, BertConfig

def get_bert(model_path: str, device: str, output_hidden_states: bool=False):
    r"""Load a KoBERT encoder and tokenizer extended with the decoder's special tokens.

    Args:
        model_path: Hugging Face model id or local directory (e.g. ``"monologg/kobert"``).
        device: device the model is moved to (e.g. ``"cpu"`` or ``"cuda"``).
        output_hidden_states: whether the model should also return all hidden states.

    Returns:
        (model, tokenizer, config). ``config`` is the model's own config object, so it
        reflects ``output_hidden_states`` and the vocabulary size after the special
        tokens ``[S]``, ``[E]``, ``[COL]`` were added.
    """
    special_tokens = ["[S]", "[E]", "[COL]"]  # sequence start, sequence end, column tokens
    tokenizer = KoBertTokenizer.from_pretrained(model_path, add_special_tokens=True, additional_special_tokens=special_tokens)
    config = BertConfig.from_pretrained(model_path)
    config.output_hidden_states = output_hidden_states

    # Build the model from this config rather than letting from_pretrained load a
    # second, independent config that the returned one would not track.
    model = BertModel.from_pretrained(model_path, config=config)
    # The tokenizer grew by the special tokens, so the embedding matrix must grow too;
    # this also updates model.config.vocab_size.
    model.resize_token_embeddings(len(tokenizer))
    model.config.output_hidden_states = output_hidden_states
    model.to(device)

    return model, tokenizer, model.config

In [11]:
model_path = "monologg/kobert"  # KoBERT checkpoint and tokenizer
device = "cpu"  # to use a GPU: "cuda" if torch.cuda.is_available() else "cpu"
model_bert, tokenizer_bert, config_bert = get_bert(model_path, device)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [55]:
def get_batch_data(data, dbengine):
    r"""Split a collated batch into questions, SQL labels and serialized table strings.

    Args:
        data: list of example dicts with ``"question"``, ``"table_id"`` and ``"sql"`` keys.
        dbengine: engine whose ``get_schema_info(table_id)`` is expected to refresh
            ``dbengine.schema`` (iterated for column names) for that table.

    Returns:
        (batch_qs, batch_sqls, batch_ts). Each table string is the table id followed by
        ``[COL]<column>`` for every column, e.g. ``"receipts[COL]index[COL]rcept_no..."``.
    """
    batch_qs = [example["question"] for example in data]
    table_ids = [example["table_id"] for example in data]
    batch_sqls = [example["sql"] for example in data]

    def _serialize_table(table_id):
        # the schema is read from the engine right after asking it for this table
        dbengine.get_schema_info(table_id)
        column_part = "".join(f"[COL]{col}" for col in dbengine.schema)
        return f"{table_id}" + column_part

    batch_ts = [_serialize_table(table_id) for table_id in table_ids]
    return batch_qs, batch_sqls, batch_ts

In [57]:
batch_qs, batch_sqls, batch_ts = get_batch_data(data=batch_data, dbengine=dbengine)  # questions, SQL labels, table strings

In [13]:
# Get Input: tokenize every (question, table string) pair as one BERT sequence pair,
# padded to the longest pair in the batch.
tokenizer_kwargs = dict(
    max_length=512,
    padding=True,
    truncation=True,
    return_tensors="pt",
    return_attention_mask=True,
    return_special_tokens_mask=False,
)
encode_input = tokenizer_bert(batch_qs, batch_ts, **tokenizer_kwargs)

In [14]:
# Show an Example of Input: decode the first pair to check the layout
first_example_ids = encode_input["input_ids"][0]
print(tokenizer_bert.decode(first_example_ids))

[CLS] 삼성전자의 2017년도 영업이익이 어때?[SEP] receipts [COL] index [COL] rcept_no [COL] reprt_code [COL] bsns_year [COL] corp_code [COL] stock_code [COL] fs_div [COL] fs_nm [COL] sj_div [COL] sj_nm [COL] account_nm [COL] thstrm_nm [COL] thstrm_dt [COL] thstrm_amount [COL] frmtrm_nm [COL] frmtrm_dt [COL] frmtrm_amount [COL] bfefrmtrm_nm [COL] bfefrmtrm_dt [COL] bfefrmtrm_amount[SEP][PAD][PAD]


In [63]:
tokenizer_bert.__class__  # check that the custom KoBertTokenizer class is in use

KoBertTokenizer.KoBertTokenizer

## Preparing Decoder Inputs: Creating Masks

In [17]:
def get_decoder_input_mask(input_ids, mask, batch_size, start_tkn_id, end_tkn_id):
    r"""
    Return a copy of ``mask`` with every [S] / [E] token position switched off.

    The decoder input should only contain word tokens, so the sequence start/end
    markers are removed from the selection.

    Args:
        input_ids: token ids, same shape as ``mask`` (e.g. (batch_size, seq_len)).
        mask: selection mask over ``input_ids``; it is not modified.
        batch_size: kept for backward compatibility; the shape is taken from the tensors.
        start_tkn_id: id of the sequence-start token ([S]).
        end_tkn_id: id of the sequence-end token ([E]).

    Works when rows contain different numbers of [S]/[E] tokens and keeps every
    intermediate tensor on the inputs' device.
    """
    start_end_mask = (input_ids == start_tkn_id) | (input_ids == end_tkn_id)
    return mask.masked_fill(start_end_mask, False)

def get_input_mask_and_answer(encode_input, tokenizer):
    r"""
    Build boolean position masks that split an encoded (question, table string)
    batch into the parts the decoders consume.

    table -> database table name(id)
    header -> database header

    Args:
        encode_input: tokenizer output holding "input_ids", "token_type_ids" and
            "attention_mask", each shaped (batch_size, max_length).
        tokenizer: tokenizer whose `additional_special_tokens_ids` are the ids of
            [S], [E] and [COL], in that order.

    NOTE(review): the `.view(batch_size, -1)` reshapes below assume every row has
    the same number of table/header tokens and of [COL] tokens (e.g. the whole
    batch uses one table); mixed-table batches presumably fail here -- confirm.

    returns:
        input_question_mask, input_table_mask, input_header_mask, input_col_mask
        (bool tensors shaped like encode_input["input_ids"])
    """
    batch_size, max_length = encode_input["input_ids"].size()
    sep_tkn_mask = encode_input["input_ids"] == tokenizer.sep_token_id  # True at every [SEP]
    start_tkn_id, end_tkn_id, col_tkn_id = tokenizer.additional_special_tokens_ids
    
    # Question: non-padding tokens of the first segment (token_type_ids == 0)
    input_question_mask = torch.bitwise_and(encode_input["token_type_ids"] == 0, encode_input["attention_mask"].bool())
    input_question_mask = torch.bitwise_and(input_question_mask, ~sep_tkn_mask) # [SEP] mask out
    input_question_mask[:, 0] = False  # [CLS] mask out

    # Table + headers: non-padding tokens of the second segment (token_type_ids == 1).
    # XOR with the [SEP] mask removes the trailing [SEP] (second segment) and adds the
    # [SEP] that closes the question (first segment); it anchors the table start below.
    db_mask = torch.bitwise_and(encode_input["token_type_ids"] == 1, encode_input["attention_mask"].bool())
    db_mask = torch.bitwise_xor(db_mask, sep_tkn_mask)
    col_tkn_mask = encode_input["input_ids"] == col_tkn_id
    db_mask = torch.bitwise_and(db_mask, ~col_tkn_mask)  # drop [COL] separators, leaving 1-position gaps
    # split table_mask and header_mask
    # NOTE(review): arange is created on the default device, not encode_input's -- confirm CPU-only use
    input_idx = torch.arange(max_length).repeat(batch_size, 1)
    # positions of all remaining db tokens, flattened row by row over the batch
    db_idx = input_idx[db_mask]
    # NOTE(review): position 0 ([CLS]) is never set in db_mask, so this filter looks like a no-op
    table_header_tkn_idx = db_idx[db_idx > 0]
    # first entry of each row is the question-closing [SEP]; the table name starts right after it
    table_start_idx = table_header_tkn_idx.view(batch_size, -1)[:, 0] + 1
    # a jump of 2 between consecutive positions means a [COL] was removed in between,
    # so the position after the jump starts a header.
    # NOTE(review): diff() runs over the flattened batch, so the step from one row's last
    # token to the next row's first token is tested too -- verify it can never equal 2.
    start_idx = table_header_tkn_idx[1:][table_header_tkn_idx.diff() == 2].view(batch_size, -1)
    # position of the first [COL] (one before the first header token) = end of the table name
    table_end_sep_idx = start_idx[:, 0] - 1
    # per-row split sizes: ([SEP] + table-name tokens, header tokens)
    split_size = torch.stack([
        table_end_sep_idx-table_start_idx+1, table_header_tkn_idx.view(batch_size, -1).size(1)-(table_end_sep_idx-table_start_idx+1)
    ]).transpose(0, 1)

    # Token idx
    # split every row into its table part and header part, then stack each part back to (batch_size, n)
    table_tkn_idx, header_tkn_idx = map(
        lambda x: torch.stack(x), 
        zip(*[torch.split(x, size.tolist()) for x, size in zip(table_header_tkn_idx.view(batch_size, -1), split_size)])
    )

    table_tkn_idx = table_tkn_idx[:, 1:]  # drop the leading question-closing [SEP]
    # Mask include [S] & [E] tokens
    table_tkn_mask = torch.zeros_like(encode_input["input_ids"], dtype=torch.bool).scatter(1, table_tkn_idx, True)
    header_tkn_mask = torch.zeros_like(encode_input["input_ids"], dtype=torch.bool).scatter(1, header_tkn_idx, True)

    # For Decoder Input, Maskout [S], [E] for table & header  
    input_table_mask = get_decoder_input_mask(
        encode_input["input_ids"], table_tkn_mask, batch_size, start_tkn_id, end_tkn_id
    )
    input_header_mask = get_decoder_input_mask(
        encode_input["input_ids"], header_tkn_mask, batch_size, start_tkn_id, end_tkn_id
    )
    # [COL] token mask: this is for attention
    col_tkn_idx = input_idx[col_tkn_mask].view(batch_size, -1)
    input_col_mask = torch.zeros_like(encode_input["input_ids"], dtype=torch.bool).scatter(1, col_tkn_idx, True)

    return input_question_mask, input_table_mask, input_header_mask, input_col_mask # , answer_table_tkns, answer_header_tkns

In [18]:
decoder_masks = get_input_mask_and_answer(encode_input=encode_input, tokenizer=tokenizer_bert)
input_question_mask, input_table_mask, input_header_mask, input_col_mask = decoder_masks

In [23]:
# Decode the tokens selected by each mask to check the split visually
mask_previews = [
    ("Question Tokens for Decoder", input_question_mask),
    ("Table Tokens for Decoder", input_table_mask),
    ("Header Tokens for Decoder", input_header_mask),
    ("Column(Index of Headers) Tokens for Decoder", input_col_mask),
]
for title, token_mask in mask_previews:
    print(title)
    print("-" * 25)
    print(tokenizer_bert.decode(encode_input["input_ids"][token_mask]))
    print()

Question Tokens for Decoder
-------------------------
삼성전자의 2017년도 영업이익이 어때? 삼성전자 2019의 이익잉여금은 어때?

Table Tokens for Decoder
-------------------------
receipts receipts

Header Tokens for Decoder
-------------------------
index rcept_no reprt_code bsns_year corp_code stock_code fs_div fs_nm sj_div sj_nm account_nm thstrm_nm thstrm_dt thstrm_amount frmtrm_nm frmtrm_dt frmtrm_amount bfefrmtrm_nm bfefrmtrm_dt bfefrmtrm_amount index rcept_no reprt_code bsns_year corp_code stock_code fs_div fs_nm sj_div sj_nm account_nm thstrm_nm thstrm_dt thstrm_amount frmtrm_nm frmtrm_dt frmtrm_amount bfefrmtrm_nm bfefrmtrm_dt bfefrmtrm_amount

Column(Index of Headers) Tokens for Decoder
-------------------------
[COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL]



In [24]:
# Feed to BERT Model
# `encode_input` holds input_ids / token_type_ids / attention_mask from the tokenizer.
# No torch.no_grad() here, so autograd records the graph and the encoder can be fine-tuned.
encode_outputs = model_bert(**encode_input)


The `encode_outputs` are split into four parts using four types of masks:
```
encode_outputs
-> Question
-> Table
-> Header
-> Column(Index of Headers)
```

Then, for the decoder input, every example with fewer tokens than the longest one in the batch is padded with the `[PAD]` token embedding.


In [25]:
def pad(batches: Tuple[torch.Tensor], lengths: List[int], model: "BertModel", pad_idx: int=1) -> torch.Tensor:
    r"""Right-pad variable-length token encodings and stack them into one tensor.

    Args:
        batches: per-example tensors shaped (n_tokens_i, hidden), e.g. from ``torch.split``.
        lengths: token counts of ``batches``; ``max(lengths)`` is the target length.
        model: BERT model whose ``embeddings.word_embeddings`` provides the padding vector.
        pad_idx: vocabulary id of the padding token (1 is ``[PAD]`` for KoBERT).

    Returns:
        Tensor shaped (len(batches), max(lengths), hidden). Shorter examples are extended
        with the word embedding of ``pad_idx`` (word embedding only -- no position or
        token-type embeddings, no LayerNorm).
    """
    word_embeddings = model.embeddings.word_embeddings
    max_length = max(lengths)
    padded = []
    for x in batches:
        n_missing = max_length - len(x)
        if n_missing > 0:
            # create the pad ids on the embedding's device so this also works on GPU
            pad_ids = torch.full((n_missing,), pad_idx, dtype=torch.long, device=word_embeddings.weight.device)
            padded.append(torch.cat([x, word_embeddings(pad_ids)]))
        else:
            padded.append(x)
    return torch.stack(padded)

def get_decoder_batches(encode_output, mask, model, pad_idx):
    r"""Gather the encoder states selected by ``mask`` into a padded per-example batch.

    Args:
        encode_output: BERT output whose ``last_hidden_state`` is indexed with ``mask``
            (assumed (B, T, hidden)).
        mask: (B, T) boolean mask of the tokens to keep.
        model: BERT model forwarded to ``pad`` for the padding embedding.
        pad_idx: vocabulary id used for padding.

    Returns:
        (tensor, lengths): the kept states stacked as (B, max_kept, hidden) and the
        number of kept tokens per example as a Python list.
    """
    kept_counts = mask.sum(1)
    count_list = kept_counts.tolist()
    # boolean indexing flattens the kept tokens row by row; split them back per example
    per_example = torch.split(encode_output.last_hidden_state[mask, :], count_list)
    ragged = bool(kept_counts.ne(kept_counts.max()).any())
    if ragged:
        # pad examples that kept fewer tokens than the longest one
        return pad(per_example, count_list, model, pad_idx=pad_idx), count_list
    # every example kept the same number of tokens: stacking is enough
    return torch.stack(per_example), count_list

def get_pad_mask(lengths):
    r"""Build a float padding mask from sequence lengths.

    Args:
        lengths: number of real tokens of each example.

    Returns:
        Tensor shaped (len(lengths), max(lengths)) holding 0.0 where a real token
        sits and 1.0 on padded positions.
    """
    n_rows, n_cols = len(lengths), max(lengths)
    pad_mask = torch.ones(n_rows, n_cols)
    for row_idx, n_tokens in enumerate(lengths):
        pad_mask[row_idx, :n_tokens].zero_()  # real tokens occupy the leading slots
    return pad_mask

In [26]:
# Per-part encoder states for the decoders, padded with the [PAD] embedding
pad_id = tokenizer_bert.pad_token_id
question_padded, question_lengths = get_decoder_batches(encode_outputs, input_question_mask, model_bert, pad_idx=pad_id)
table_padded, table_lengths = get_decoder_batches(encode_outputs, input_table_mask, model_bert, pad_idx=pad_id)
header_padded, header_lengths = get_decoder_batches(encode_outputs, input_header_mask, model_bert, pad_idx=pad_id)
col_padded, col_lengths = get_decoder_batches(encode_outputs, input_col_mask, model_bert, pad_idx=pad_id)

In [69]:
model_bert.__class__  # the bare BertModel (no task head) is used as encoder

transformers.models.bert.modeling_bert.BertModel

## Creating the Answers (Targets) for the Decoder Output

In [43]:
def get_sql_answers(batch_sqls, tokenizer, end_tkn_idx=1):
    """
    Unpack WikiSQL-style ``sql`` dicts into per-field ground-truth lists for the decoders.

    Args:
        batch_sqls: list of dicts with keys ``"sel"``, ``"agg"`` and ``"conds"``
            (``conds`` is a list of ``[column, operator, value]`` triples).
        tokenizer: callable on a list of strings (returning ``{"input_ids": ...}``) that
            also provides ``additional_special_tokens`` and ``pad_token_id``.
        end_tkn_idx: index into ``tokenizer.additional_special_tokens`` of the end token
            appended to every where-value (``[E]`` by default).

    Returns:
        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv, g_wv_tkns where
            sc: select column       sa: select agg
            wn: where number        wc: where column
            wo: where operator      wv: where value
        ``g_wv_tkns[b]`` is a tuple with, for each where slot ``k``, the token ids of the
        k-th where-value of example ``b`` followed by the end token. Within a slot all
        examples are right-padded with ``pad_token_id`` to the same length; examples
        with fewer conditions get all-pad lists for the missing slots. If no example
        has any condition, every example gets an empty tuple.

    Raises:
        EnvironmentError: if an ``"agg"`` value is negative.
    """
    g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = [], [], [], [], [], []
    for sql_dict in batch_sqls:
        g_sc.append(sql_dict["sel"])
        g_sa.append(sql_dict["agg"])

        conds = sql_dict["conds"]
        if sql_dict["agg"] < 0:
            raise EnvironmentError(f"negative aggregation index in sql: {sql_dict}")
        g_wn.append(len(conds))
        g_wc.append([cond[0] for cond in conds])
        g_wo.append([cond[1] for cond in conds])
        g_wv.append([cond[2] for cond in conds])

    # tokenize "<value><end token>" for every where-value (values may be numbers, e.g. years)
    end_tkn = tokenizer.additional_special_tokens[end_tkn_idx]
    pad_tkn_id = tokenizer.pad_token_id
    g_wv_tkns = []
    for batch_wv in g_wv:
        texts = [f"{s}{end_tkn}" for s in batch_wv]
        # examples without conditions need no tokenizer call
        g_wv_tkns.append(tokenizer(texts, add_special_tokens=False)["input_ids"] if texts else [])

    # give every example the same number of where slots by appending empty lists
    max_where_cols = max((len(wv) for wv in g_wv_tkns), default=0)
    g_wv_tkns = [list(wv) + [[] for _ in range(max_where_cols - len(wv))] for wv in g_wv_tkns]

    # pad each where slot, across the batch, to its longest value
    padded_slots = []
    for slot in zip(*g_wv_tkns):
        slot_max_len = max(map(len, slot))
        padded_slots.append([list(tkns) + [pad_tkn_id] * (slot_max_len - len(tkns)) for tkns in slot])

    if padded_slots:
        # back to example-major order: one tuple of per-slot lists per example
        g_wv_tkns = list(zip(*padded_slots))
    else:
        # no conditions anywhere: keep one (empty) entry per example instead of
        # collapsing the batch dimension to an empty list
        g_wv_tkns = [() for _ in batch_sqls]

    return g_sc, g_sa, g_wn, g_wc, g_wo, g_wv, g_wv_tkns

In [44]:
sql_answers = get_sql_answers(batch_sqls, tokenizer_bert, end_tkn_idx=1)
g_sc, g_sa, g_wn, g_wc, g_wo, g_wv, g_wv_tkns = sql_answers
sql_answers

([16, 16],
 [0, 0],
 [2, 2],
 [[10, 3], [10, 3]],
 [[0, 0], [0, 0]],
 [['영업이익', 2018], ['이익잉여금', 2020]],
 [([3383, 8003, 1, 1, 1], [554, 115, 8003]),
  ([3736, 7144, 6916, 5550, 8003], [554, 127, 8003])])

## Decoder

The decoder follows a structure similar to SQLova, with a few differences.

- SQLova is a neural semantic parser that translates natural language utterances into SQL queries.
- Official Github: [https://github.com/naver/sqlova](https://github.com/naver/sqlova)
- Paper: [A Comprehensive Exploration on WikiSQL with Table-Aware Word Contextualization](https://arxiv.org/abs/1902.01069)

<img src="https://drive.google.com/uc?id=1PW9oAXfW-ZI-jxGn5q9O_gzUIZnNYaet" alt="Sqlova Decoder Architecture " width="100%" height="auto">

## Attention Layers

In [29]:
class AttentionBase(nn.Module):
    r"""Shared padding-mask helper for the decoder attention layers."""
    def __init__(self):
        super().__init__()

    def wipe_out_pad_tkn_score(self, score, lengths, dim=2, fill_value=None):
        r"""Return a copy of ``score`` with positions past each example's length filled.

        Args:
            score: 3-D tensor (B, T_1, T_2).
            lengths: valid length of every example along ``dim``.
            dim: 2 fills trailing columns (``score[b, :, length:]``, for pre-softmax
                scores); 1 fills trailing rows (``score[b, length:, :]``).
            fill_value: value written into the masked positions; defaults to
                -10000000 for ``dim=2`` and 0.0 for ``dim=1``.

        Only examples shorter than ``max(lengths)`` are touched. The input is not
        modified in place: it can be a softmax output that autograd saved for the
        backward pass, and mutating it would make ``backward()`` fail.

        Raises:
            ValueError: if ``dim`` is not 1 or 2.
        """
        if dim not in (1, 2):
            raise ValueError(f"`dim` in wipe_out_pad_tkn_score should be 1 or 2")
        if fill_value is None:
            fill_value = -10000000 if dim == 2 else 0.0
        max_len = max(lengths)
        score = score.clone()
        for batch_idx, length in enumerate(lengths):
            if length < max_len:
                if dim == 2:
                    score[batch_idx, :, length:] = fill_value
                else:
                    score[batch_idx, length:, :] = fill_value
        return score


class C2QAttention(AttentionBase):
    r"""Decoder Column to Question Attention Module"""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, o_c, o_q, q_lengths, c_lengths=None, rt_attn=False):
        r"""
        For every column token, attend over the question tokens.

        o_c: LSTM output of column (B, T_c, H)
        o_q: LSTM output of question (B, T_q, H_q); ``self.linear`` maps H_q -> H
        q_lengths: valid question lengths; padded question tokens get zero attention
        c_lengths: valid column lengths; attention rows of padded columns are zeroed
        rt_attn: also return the attention probabilities

        returns (context attended to question tokens (B, T_c, H_q), attn (B, T_c, T_q) or None)
        """
        # Scaled dot-product ("Attention Is All You Need"). A Python float keeps the scale
        # device-agnostic; torch.FloatTensor(..., device=...) rejects non-CPU devices.
        scale = o_c.size(-1) ** 0.5
        o_q_transform = self.linear(o_q)  # (B, T_q, H)
        score_c2q = torch.bmm(o_c, o_q_transform.transpose(1, 2)) / scale  # (B, T_c, H) x (B, H, T_q) = (B, T_c, T_q)
        score_c2q = self.wipe_out_pad_tkn_score(score_c2q, q_lengths, dim=2)

        prob_c2q = self.softmax(score_c2q)
        if c_lengths is not None:
            prob_c2q = self.wipe_out_pad_tkn_score(prob_c2q, c_lengths, dim=1)
        # (B, T_c, T_q, 1) * (B, 1, T_q, H_q) -> sum over T_q -> context per column (B, T_c, H_q)
        context = torch.mul(prob_c2q.unsqueeze(3), o_q.unsqueeze(1)).sum(dim=2)
        attn = prob_c2q if rt_attn else None
        return context, attn


class SelfAttention(AttentionBase):
    r"""Decoder Self Attention Module"""
    def __init__(self, in_features, out_features=1):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, o, lengths, rt_attn=False):
        r"""
        Summarize a padded sequence with learned attention weights over its steps.

        o: (B, T_o, H); lengths: valid steps per example (padded steps get zero weight)
        return (attended summary of o (B, H), attn (B, T_o, 1) or None)
        """
        o_transform = self.linear(o)  # (B, T_o, H) -> (B, T_o, 1)
        # The softmax runs over the step axis (dim 1), so padded steps must be masked
        # along dim 1 with a large negative score before it.
        o_transform = self.wipe_out_pad_tkn_score(o_transform, lengths, dim=1, fill_value=-10000000)
        o_prob = self.softmax(o_transform)  # (B, T_o, 1)

        o_summary = torch.mul(o, o_prob).sum(1)  # (B, T_o, H) \odot (B, T_o, 1) -> (B, H)
        attn = o_prob if rt_attn else None
        return o_summary, attn


## Decoder Sub Layers

In [46]:
class SelectDecoder(nn.Module):
    r"""SELECT Decoder

    Scores every table column against the question: each [COL] encoding is fused
    with a summary of the header tokens, attended over the question tokens and
    mapped to ``output_size`` logits per column.
    """
    def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int=2, dropout_ratio:float=0.3) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout_ratio = dropout_ratio

        half_hidden = int(hidden_size / 2)  # per direction; both directions concatenate back to ~hidden_size
        self.lstm_q = nn.LSTM(input_size, half_hidden, num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        self.lstm_h = nn.LSTM(input_size, half_hidden, num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)

        self.col_context_linear = nn.Linear(2*hidden_size, hidden_size)
        self.col2question_attn = C2QAttention(hidden_size, hidden_size)
        double_hidden = 2 * hidden_size
        self.output_layer = nn.Sequential(
            nn.Linear(double_hidden, double_hidden),
            nn.Tanh(),
            nn.Linear(double_hidden, output_size),
        )

    def forward(self, question_padded, header_padded, col_padded, question_lengths: List[int], col_lengths: List[int], rt_attn=False):
        r"""
        predict column index

        Returns (per-column scores (B, T_c) when output_size == 1, column-to-question
        attention (B, T_c, T_q) or None).
        """
        _, n_col, _ = col_padded.size()
        question_enc, _ = self.lstm_q(question_padded)        # (B, T_q, H)
        column_enc, _ = self.lstm_h(col_padded)               # (B, T_c, H)
        _, (header_last_h, _) = self.lstm_h(header_padded)    # (n_direc*num_layers, B, H/2)

        # last layer's forward + backward final states, copied for every column
        header_vec = torch.cat((header_last_h[-2], header_last_h[-1]), dim=1)  # (B, H)
        header_per_col = header_vec.unsqueeze(1).repeat(1, n_col, 1)            # (B, T_c, H)
        fused_cols = self.col_context_linear(torch.cat([column_enc, header_per_col], dim=2))  # (B, T_c, H)

        attended_cols, attn = self.col2question_attn(fused_cols, question_enc, question_lengths, col_lengths, rt_attn)
        scores = self.output_layer(torch.cat([attended_cols, fused_cols], dim=2))  # (B, T_c, output_size)
        # TODO: add penalty for padded header(column) information
        return scores.squeeze(-1), attn
    

class AggDecoder(nn.Module):
    r"""AGG Decoder

    Predicts the aggregation operator applied to the selected column.
    """
    def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int=2, dropout_ratio:float=0.3) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout_ratio = dropout_ratio

        self.lstm_q = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        self.lstm_h = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)

        self.col_context_linear = nn.Linear(2*hidden_size, hidden_size)
        self.col2question_attn = C2QAttention(hidden_size, hidden_size)
        self.output_layer = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, output_size)
        )

    def forward(self, question_padded, col_padded, question_lengths: List[int], col_lengths: List[int], select_idxes: List[int], rt_attn=False, header_padded=None):
        r"""
        predict agg index

        question_padded: padded question encodings (B, T_q, input_size)
        col_padded: padded [COL] encodings (B, T_c, input_size)
        select_idxes: selected column index per example (e.g. argmax of the SELECT scores)
        header_padded: padded header-token encodings (B, T_h, input_size). This method used
            to read it implicitly from the notebook-global ``header_padded``; pass it
            explicitly. When omitted, that global is still used for backward compatibility.

        Returns (agg logits (B, output_size), attention (B, 1, T_q) or None).
        """
        if header_padded is None:
            # Backward-compatible fallback to the global the original code relied on.
            if "header_padded" not in globals():
                raise NameError("AggDecoder.forward requires `header_padded`; pass it as a keyword argument")
            header_padded = globals()["header_padded"]

        batch_size, n_col, _ = col_padded.size()
        o_q, _ = self.lstm_q(question_padded)  # o_q: (B, T_q, H)
        o_c, _ = self.lstm_h(col_padded)  # o_c: (B, T_c, H)
        _, (h_h, _) = self.lstm_h(header_padded)  # h_h: (n_direc*num_layers, B, H/2)

        # last layer's forward + backward final states, copied for every column
        header_summary = torch.cat([h_h[-2], h_h[-1]], dim=1).unsqueeze(1).repeat(1, n_col, 1)  # (B, T_c, H)
        col_context = torch.cat([o_c, header_summary], dim=2)  # (B, T_c, 2H)
        col_context = self.col_context_linear(col_context)  # (B, T_c, H)

        # encoding of the selected column of every example
        col_selected = col_context[list(range(batch_size)), select_idxes].unsqueeze(1)  # col_selected: (B, 1, H)

        # NOTE(review): col_lengths counts columns, but this attention has a single query row,
        # so the row masking inside C2QAttention effectively never triggers here.
        col_q_context, attn = self.col2question_attn(col_selected, o_q, question_lengths, col_lengths, rt_attn)  # (B, 1, H), (B, 1, T_q)
        output = self.output_layer(col_q_context.squeeze(1))  # (B, output_size)

        return output, attn
    
    
class WhereNumDecoder(nn.Module):
    r"""WHERE number Decoder

    Predicts how many WHERE conditions a query has. An attention summary of the
    table columns initializes the question LSTM; the self-attended question
    encoding is then classified into ``output_size`` classes.
    """
    def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int=2, dropout_ratio:float=0.3, max_where_conds=4) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout_ratio = dropout_ratio
        self.max_where_conds = max_where_conds
        if self.output_size > self.max_where_conds+1:
            # The output feeds a cross-entropy loss over where-clause counts (one class per
            # count), so at most max_where_conds + 1 classes make sense.
            raise ValueError(f"`WhereNumDecoder` only support maximum {max_where_conds} where clause")

        lstm_hidden = int(hidden_size / 2)  # per direction
        self.lstm_q = nn.LSTM(input_size, lstm_hidden, num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        self.lstm_h = nn.LSTM(input_size, lstm_hidden, num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)

        self.col_self_attn = SelfAttention(2*hidden_size, 1)
        # Project the column summary onto lstm_q's initial (h_0, c_0): 2 directions * num_layers
        # states of size lstm_hidden each. (A fixed 2*hidden_size only fits num_layers == 2.)
        init_state_size = 2 * num_layers * lstm_hidden
        self.lstm_q_hidden_init_linear = nn.Linear(2*hidden_size, init_state_size)
        self.lstm_q_cell_init_linear = nn.Linear(2*hidden_size, init_state_size)

        self.context_self_attn = SelfAttention(hidden_size, 1)

        self.output_layer = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, output_size)
        )

    def forward(self, question_padded, header_padded, col_padded, question_lengths: List[int], col_lengths: List[int], rt_attn=False):
        r"""
        predict the number of WHERE conditions

        Returns (logits (B, output_size), (column attention, question attention); each
        attention is None unless rt_attn).
        """
        batch_size, n_col, _ = col_padded.size()
        o_c, _ = self.lstm_h(col_padded)  # o_c: (B, T_c, H)
        _, (h_h, _) = self.lstm_h(header_padded)  # h_h: (n_direc*num_layers, B, H/2)

        # last layer's forward + backward final states, copied for every column
        header_summary = torch.cat([h_h[-2], h_h[-1]], dim=1).unsqueeze(1).repeat(1, n_col, 1)  # (B, T_c, H)
        col_context = torch.cat([o_c, header_summary], dim=2)  # (B, T_c, 2H)

        col_self_attn, col_attn = self.col_self_attn(col_context, col_lengths, rt_attn)  # (B, 2H)

        # (B, n_direc*num_layers*H/2) -> (n_direc*num_layers, B, H/2), the layout nn.LSTM expects
        h_0 = self.lstm_q_hidden_init_linear(col_self_attn)
        h_0 = h_0.view(batch_size, 2*self.num_layers, -1).transpose(0, 1).contiguous()
        c_0 = self.lstm_q_cell_init_linear(col_self_attn)
        c_0 = c_0.view(batch_size, 2*self.num_layers, -1).transpose(0, 1).contiguous()

        o_q, _ = self.lstm_q(question_padded, (h_0, c_0))  # o_q: (B, T_q, H)
        o_summary, o_attn = self.context_self_attn(o_q, question_lengths, rt_attn)  # (B, H)
        output = self.output_layer(o_summary)

        return output, (col_attn, o_attn)

    
class WhereColumnDecoder(nn.Module):
    r"""WHERE Column Decoder

    Scores every table column for use in a WHERE condition. Same architecture as the
    SELECT decoder: [COL] encodings fused with a header summary, attended over the
    question. ``max_where_conds`` is accepted for signature parity with the other
    WHERE decoders and is not used here.
    """
    def __init__(self, input_size: int, hidden_size: int, output_size: int=1, num_layers: int=2, dropout_ratio:float=0.3, max_where_conds: int=4) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout_ratio = dropout_ratio

        per_direction = int(hidden_size / 2)
        self.lstm_q = nn.LSTM(input_size, per_direction, num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        self.lstm_h = nn.LSTM(input_size, per_direction, num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)

        self.col_context_linear = nn.Linear(2*hidden_size, hidden_size)
        self.col2question_attn = C2QAttention(hidden_size, hidden_size)
        fused_size = 2 * hidden_size
        self.output_layer = nn.Sequential(
            nn.Linear(fused_size, fused_size),
            nn.Tanh(),
            nn.Linear(fused_size, output_size),
        )

    def forward(self, question_padded, header_padded, col_padded, question_lengths: List[int], col_lengths: List[int], rt_attn=False):
        r"""
        predict column index

        Returns (per-column scores (B, T_c) when output_size == 1, column-to-question
        attention (B, T_c, T_q) or None).
        """
        n_col = col_padded.size()[1]
        q_states = self.lstm_q(question_padded)[0]     # (B, T_q, H)
        c_states = self.lstm_h(col_padded)[0]          # (B, T_c, H)
        h_final = self.lstm_h(header_padded)[1][0]    # (n_direc*num_layers, B, H/2)

        last_layer_summary = torch.cat((h_final[-2], h_final[-1]), dim=1)   # (B, H)
        summary_per_col = last_layer_summary.unsqueeze(1).repeat(1, n_col, 1)  # (B, T_c, H)
        col_repr = self.col_context_linear(torch.cat([c_states, summary_per_col], dim=2))  # (B, T_c, H)

        q_context, attn = self.col2question_attn(col_repr, q_states, question_lengths, col_lengths, rt_attn)  # (B, T_c, H)
        logits = self.output_layer(torch.cat([q_context, col_repr], dim=2))  # (B, T_c, output_size)
        # TODO: add penalty for padded header(column) information
        return logits.squeeze(-1), attn
    
    
class WhereOpDecoder(nn.Module):
    r"""WHERE Operator Decoder

    Predicts the comparison operator of every WHERE condition from the encoding of
    its column attended over the question.
    """
    def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int=2, dropout_ratio: float=0.3, max_where_conds: int=4) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout_ratio = dropout_ratio
        self.max_where_conds = max_where_conds

        self.lstm_q = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        self.lstm_h = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)

        self.col_context_linear = nn.Linear(2*hidden_size, hidden_size)
        self.col2question_attn = C2QAttention(hidden_size, hidden_size)
        self.output_layer = nn.Sequential(
            nn.Linear(2*hidden_size, 2*hidden_size),
            nn.Tanh(),
            nn.Linear(2*hidden_size, output_size)
        )

    def forward(self, question_padded, col_padded, question_lengths: List[int], where_nums: List[int], where_col_idxes: List[List[int]], rt_attn=False, header_padded=None):
        r"""
        predict where operator index

        question_padded: padded question encodings (B, T_q, input_size)
        col_padded: padded [COL] encodings (B, T_c, input_size)
        where_nums: number of WHERE conditions per example
        where_col_idxes: WHERE column indices per example. The padded length is
            max(where_nums), which can be lower than or equal to `max_where_conds`.
        rt_attn: forwarded to the attention; the attention is not returned
        header_padded: padded header-token encodings (B, T_h, input_size). This method used
            to read it implicitly from the notebook-global ``header_padded``; pass it
            explicitly. When omitted, that global is still used for backward compatibility.

        Returns operator logits (B, max(where_nums), output_size).
        """
        if header_padded is None:
            # Backward-compatible fallback to the global the original code relied on.
            if "header_padded" not in globals():
                raise NameError("WhereOpDecoder.forward requires `header_padded`; pass it as a keyword argument")
            header_padded = globals()["header_padded"]

        batch_size, n_col, _ = col_padded.size()
        o_q, _ = self.lstm_q(question_padded)  # o_q: (B, T_q, H)
        o_c, _ = self.lstm_h(col_padded)  # o_c: (B, T_c, H)
        _, (h_h, _) = self.lstm_h(header_padded)  # h_h: (n_direc*num_layers, B, H/2)

        # last layer's forward + backward final states, copied for every column
        header_summary = torch.cat([h_h[-2], h_h[-1]], dim=1).unsqueeze(1).repeat(1, n_col, 1)  # (B, T_c, H)
        col_context = torch.cat([o_c, header_summary], dim=2)  # (B, T_c, 2H)
        col_context = self.col_context_linear(col_context)  # (B, T_c, H)
        col_context_padded = self.get_context_padded(col_context, where_nums, where_col_idxes)  # (B, max_where_col_nums, H)

        # where_nums as row lengths: attention rows of padded (non-existent) conditions are zeroed
        col_q_context, attn = self.col2question_attn(col_context_padded, o_q, question_lengths, where_nums, rt_attn)  # (B, max_where_col_nums, H), (B, max_where_col_nums, T_q)

        vec = torch.cat([col_q_context, col_context_padded], dim=2)  # (B, max_where_col_nums, 2H)
        output = self.output_layer(vec)  # (B, max_where_col_nums, n_cond_ops)
        # TODO: add penalty for padded header(column) information
        return output

    def get_context_padded(self, col_context, where_nums, where_col_idxes):
        r"""
        Select the WHERE column encodings of every example and zero-pad them to max(where_nums).

        col_context: (B, T_c, H) column encodings
        where_nums: number of WHERE conditions per example (max gives the padded length)
        where_col_idxes: WHERE column indices per example

        Returns (B, max(where_nums), H).
        """
        batch_size, n_col, hidden_size = col_context.size()
        max_where_col_nums = max(where_nums)
        batches = [col_context[i, batch_col] for i, batch_col in enumerate(where_col_idxes)]  # [(where_col_nums, hidden_size), ...]  len = B
        batches_padded = []
        for b in batches:
            where_col_nums = b.size(0)
            if where_col_nums < max_where_col_nums:
                b_padded = torch.cat([b, torch.zeros((max_where_col_nums-where_col_nums), hidden_size, device=col_context.device)], dim=0)
            else:
                b_padded = b
            batches_padded.append(b_padded)  # (max_where_col_nums, hidden_size)

        return torch.stack(batches_padded) # (B, max_where_col_nums, hidden_size)
    
    
class WhereValueDecoder(nn.Module):
    r"""
    WHERE Value Decoder.

    For every WHERE condition it builds a context vector from three parts: the condition column,
    that column's attention over the question, and the (one-hot) condition operator. It uses this
    vector to initialise a 1-layer LSTM, then scores the value tokens step by step over the
    vocabulary of ``embedding_layer``.
    """
    def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int=2, dropout_ratio: float=0.3, max_where_conds: int=4, n_cond_ops: int=4,
                 start_tkn_id=8002, end_tkn_id=8003, embedding_layer=None) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout_ratio = dropout_ratio
        self.max_where_conds = max_where_conds
        self.n_cond_ops = n_cond_ops

        self.start_tkn_id = start_tkn_id
        self.end_tkn_id = end_tkn_id

        # Bidirectional encoders with H/2 units per direction, so concatenated outputs are H wide.
        self.lstm_q = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        self.lstm_h = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)

        self.col_context_linear = nn.Linear(2*hidden_size, hidden_size)
        self.where_op_linear = nn.Linear(n_cond_ops, hidden_size)
        self.col2question_attn = C2QAttention(hidden_size, hidden_size)
        if embedding_layer is None:
            raise KeyError("Must initialize the embedding_layer to BertModel's word embedding layer")
        # Accept either the word-embedding layer itself or a module exposing `.word_embeddings`.
        if not isinstance(embedding_layer, torch.nn.modules.sparse.Embedding):
            embedding_layer = embedding_layer.word_embeddings
        self.embedding_layer = embedding_layer
        vocab_size, bert_hidden_size = embedding_layer.weight.data.size()
        # The value decoder's initial (h, c) are projected from the 3H condition context.
        self.output_lstm_hidden_init_linear = nn.Linear(3*hidden_size, bert_hidden_size)
        self.output_lstm_cell_init_linear = nn.Linear(3*hidden_size, bert_hidden_size)
        self.output_lstm = nn.LSTM(bert_hidden_size, bert_hidden_size, 1, batch_first=True)
        self.output_linear = nn.Linear(bert_hidden_size, vocab_size)
        # Point the vocabulary projection's weight data at the embedding matrix.
        self.output_linear.weight.data = embedding_layer.weight.data

    def forward(self, question_padded, col_padded, question_lengths: List[int], where_nums: List[int], where_col_idxes: List[List[int]], where_op_idxes: List[List[int]], value_tkn_max_len=None, g_wv_tkns=None, rt_attn=False, header_padded=None):
        r"""
        Score the value tokens of every WHERE condition.

        Args:
            question_padded: encoded question; assumed (B, T_q, input_size) (batch_first LSTM input).
            col_padded: encoded columns; assumed (B, T_c, input_size).
            question_lengths: question length per batch item (forwarded to the attention).
            where_nums: number of WHERE conditions per batch item.
            where_col_idxes: WHERE column indices per batch item.
            where_op_idxes: WHERE operator indices per batch item.
            value_tkn_max_len: ``None`` when training (the decode length is taken from the gold
                tokens); an int cap for greedy decoding at test time.
            g_wv_tkns: gold value token ids indexed ``[batch][condition] -> List[int]``. Required when
                training (teacher forcing). Assumed padded so every batch item has an entry per
                condition slot and equal token lengths per slot — TODO confirm with the data pipeline.
            rt_attn: forwarded to the column-to-question attention.
            header_padded: encoded headers. If omitted, a module/notebook-global ``header_padded``
                is used (the previous implicit behaviour). Prefer passing it explicitly.

        Returns:
            A list with one tensor of shape (B, T_d_i, vocab_size) per WHERE-condition slot.
        """
        if header_padded is None:
            # Backward compatibility: this method used to read a global `header_padded` implicitly,
            # which may be stale. Callers should pass it explicitly instead.
            header_padded = globals().get("header_padded")
            if header_padded is None:
                raise ValueError("`header_padded` must be passed to WhereValueDecoder.forward")

        device = question_padded.device
        batch_size, n_col, _ = col_padded.size()
        o_q, _ = self.lstm_q(question_padded)  # o_q: (B, T_q, H)
        o_c, _ = self.lstm_h(col_padded)  # o_c: (B, T_c, H)
        _, (h_h, _) = self.lstm_h(header_padded)  # h_h: (n_direc*num_layers, B, H/2)

        # Last layer's forward and backward final states -> (B, H), repeated for every column.
        header_summary = torch.cat([h_h[-2], h_h[-1]], dim=1).unsqueeze(1).repeat(1, n_col, 1)  # (B, T_c, H)
        col_context = torch.cat([o_c, header_summary], dim=2)  # (B, T_c, 2H)
        col_context = self.col_context_linear(col_context)  # (B, T_c, H)
        col_context_padded = self.get_context_padded(col_context, where_nums, where_col_idxes)  # (B, max_where_col_nums, H)

        col_q_context, _ = self.col2question_attn(col_context_padded, o_q, question_lengths, where_nums, rt_attn)  # (B, max_where_col_nums, H)
        where_op_one_hot_padded = self.get_where_op_one_hot_padded(where_op_idxes, where_nums, where_col_idxes, n_cond_ops=self.n_cond_ops, device=device)  # (B, max_where_col_nums, n_cond_ops)
        where_op = self.where_op_linear(where_op_one_hot_padded)  # (B, max_where_col_nums, H)

        vec = torch.cat([col_q_context, col_context_padded, where_op], dim=2)  # (B, max_where_col_nums, 3H)
        max_where_col_nums = vec.size(1)
        total_scores = []
        for i in range(max_where_col_nums):
            if g_wv_tkns is not None:
                # Gold tokens of the i-th condition for every batch item: (B, T_d_i)
                g_wv_tkns_i = torch.tensor([g_wv_tkns[b_idx][i] for b_idx in range(batch_size)], dtype=torch.long, device=device)
            else:
                g_wv_tkns_i = None
            vec_i = vec[:, i, :]  # (B, 3H)

            h_0 = self.output_lstm_hidden_init_linear(vec_i).unsqueeze(0)  # (B, bert_H) -> (1, B, bert_H)
            c_0 = self.output_lstm_cell_init_linear(vec_i).unsqueeze(0)  # (B, bert_H) -> (1, B, bert_H)

            scores = self.decode_single_where_col(batch_size, h_0, c_0, value_tkn_max_len=value_tkn_max_len, g_wv_tkns_i=g_wv_tkns_i)  # (B, T_d_i, vocab_size)
            total_scores.append(scores)

        # total_scores: [(B, T_d_i, vocab_size)] x max_where_col_nums
        return total_scores

    def start_token(self, batch_size, device=None):
        r"""Return a (B, 1) LongTensor filled with the start-token id, created on ``device``."""
        return torch.full((batch_size, 1), self.start_tkn_id, dtype=torch.long, device=device)

    def decode_single_where_col(self, batch_size, h_0, c_0, value_tkn_max_len=None, g_wv_tkns_i=None):
        r"""
        Decode the value tokens of one WHERE-condition slot for the whole batch.

        Args:
            batch_size: B.
            h_0, c_0: initial LSTM state, (1, B, bert_H).
            value_tkn_max_len: ``None`` -> decode ``len(g_wv_tkns_i[0])`` steps (training);
                int -> at most this many greedy steps (testing).
            g_wv_tkns_i: gold tokens (B, T_d_i). If given, they are fed back as the next input
                (teacher forcing). Otherwise the argmax token is fed back, and decoding stops early
                once every batch item has produced the end token.

        Returns:
            Scores of shape (B, T_d_i, vocab_size).

        Raises:
            ValueError: if neither ``value_tkn_max_len`` nor ``g_wv_tkns_i`` is given.
        """
        if value_tkn_max_len is None:
            if g_wv_tkns_i is None:
                raise ValueError("Either `value_tkn_max_len` or `g_wv_tkns_i` must be given")
            # [Training] the gold tokens are already padded to a common length
            max_len = len(g_wv_tkns_i[0])
        else:
            # [Testing] the answer length is unknown, so cap it
            max_len = value_tkn_max_len

        emb = self.embedding_layer(self.start_token(batch_size, device=h_0.device))  # (B, 1, bert_H)
        # Carry the recurrent state across steps; it is seeded from the condition context.
        h, c = h_0, c_0
        scores = []
        for i in range(max_len):
            _, (h, c) = self.output_lstm(emb, (h, c))  # h: (1, B, bert_H)
            s = self.output_linear(h[-1])  # last layer state (B, bert_H) -> (B, vocab_size)
            scores.append(s)

            if g_wv_tkns_i is not None:
                # [Training] teacher forcing
                pred = g_wv_tkns_i[:, i]  # (B,)
            else:
                # [Testing] greedy decoding; stop once every batch item emitted the end token
                pred = s.argmax(1)  # (B,)
                if bool((pred == self.end_tkn_id).all()):
                    break

            emb = self.embedding_layer(pred.unsqueeze(1))  # (B, 1, bert_H)

        if not scores:
            return h_0.new_zeros((batch_size, 0, self.output_linear.out_features))
        return torch.stack(scores).transpose(0, 1).contiguous()  # (T_d_i, B, vocab_size) -> (B, T_d_i, vocab_size)

    def get_context_padded(self, col_context: torch.Tensor, where_nums: List[int], where_col_idxes: List[List[int]]):
        r"""
        Select each batch item's WHERE-column vectors and zero-pad them to ``max(where_nums)`` rows.

        Returns:
            Tensor of shape (B, max_where_col_nums, hidden_size). Padding zeros match
            ``col_context``'s device and dtype.
        """
        batch_size, n_col, hidden_size = col_context.size()
        max_where_col_nums = max(where_nums)
        batches = [col_context[i, batch_col] for i, batch_col in enumerate(where_col_idxes)]  # [(where_col_nums, hidden_size), ...]  len = B
        batches_padded = []
        for b in batches:
            where_col_nums = b.size(0)
            if where_col_nums < max_where_col_nums:
                zeros = torch.zeros(max_where_col_nums - where_col_nums, hidden_size, dtype=col_context.dtype, device=col_context.device)
                b = torch.cat([b, zeros], dim=0)
            batches_padded.append(b)  # (max_where_col_nums, hidden_size)
        return torch.stack(batches_padded)  # (B, max_where_col_nums, hidden_size)

    def get_where_op_one_hot_padded(self, where_op_idxes: List[List[int]], where_nums: List[int], where_col_idxes: List[List[int]], n_cond_ops: int, device=None):
        r"""
        One-hot encode each batch item's WHERE operators and zero-pad to ``max(where_nums)`` rows.

        ``where_col_idxes`` is accepted for signature compatibility and is not used.

        Returns:
            Float tensor of shape (B, max_where_col_nums, n_cond_ops) on ``device``
            (CPU if ``device`` is None).
        """
        max_where_col_nums = max(where_nums)
        batches_padded = []
        for where_num, op_idxes in zip(where_nums, where_op_idxes):
            index = torch.tensor(op_idxes, dtype=torch.long, device=device).view(-1, 1)
            one_hot = torch.zeros(where_num, n_cond_ops, device=device).scatter(1, index, 1)  # (where_num, n_cond_ops)
            if where_num < max_where_col_nums:
                pad = torch.zeros(max_where_col_nums - where_num, n_cond_ops, device=device)
                one_hot = torch.cat([one_hot, pad], dim=0)
            batches_padded.append(one_hot)  # (max_where_col_nums, n_cond_ops)
        return torch.stack(batches_padded)  # (B, max_where_col_nums, n_cond_ops)

## Decoder Module

### Variables

In [47]:
# Decoder sizes: the decoder consumes BERT hidden states
input_size = config_bert.hidden_size
hidden_size = 100
num_layers = 2
dropout_ratio = 0.3
max_where_conds = 4

# Label-space sizes taken from the DB engine's operator tables
n_agg_ops = len(dbengine.agg_ops)
n_cond_ops = len(dbengine.cond_ops)

# Special tokens that start / end WHERE-value decoding, and the embedding they are looked up in
start_tkn_id = tokenizer_bert.additional_special_tokens_ids[0]
end_tkn_id = tokenizer_bert.additional_special_tokens_ids[1]
embedding_layer = model_bert.embeddings.word_embeddings

# None -> value length follows the gold tokens (training); int -> greedy-decoding cap (testing)
train = True
value_tkn_max_len = None if train else 20

In [48]:
class Decoder(nn.Module):
    r"""
    Chains the SELECT, aggregation and WHERE (number / column / operator / value) decoders.

    Any stage can be teacher-forced. When the matching gold label (``g_sc``, ``g_wn``, ``g_wc``,
    ``g_wo``) is not ``None``, it is fed to the later stages. Otherwise that stage's own prediction
    (see :meth:`predict_decoder`) is used.
    """
    def __init__(self, input_size, hidden_size, num_layers, dropout_ratio, max_where_conds, n_agg_ops, n_cond_ops, start_tkn_id, end_tkn_id, value_tkn_max_len, embedding_layer):
        super().__init__()
        # None -> WHERE-value length follows the gold tokens (training); int -> greedy cap (testing).
        # Stored on the instance so forward() does not read a notebook-global of the same name;
        # change `model.value_tkn_max_len` to switch modes.
        self.value_tkn_max_len = value_tkn_max_len
        self.select_decoder = SelectDecoder(
            input_size, hidden_size, output_size=1, num_layers=num_layers, dropout_ratio=dropout_ratio
        )
        self.agg_decoder = AggDecoder(
            input_size, hidden_size, output_size=n_agg_ops, num_layers=num_layers, dropout_ratio=dropout_ratio
        )
        self.where_num_decoder = WhereNumDecoder(
            input_size, hidden_size, output_size=(max_where_conds+1), num_layers=num_layers, dropout_ratio=dropout_ratio
        )
        self.where_col_decoder = WhereColumnDecoder(
            input_size, hidden_size, output_size=1, num_layers=num_layers, dropout_ratio=dropout_ratio, max_where_conds=max_where_conds
        )
        self.where_op_decoder = WhereOpDecoder(
            input_size, hidden_size, output_size=n_cond_ops, num_layers=num_layers, dropout_ratio=dropout_ratio, max_where_conds=max_where_conds
        )
        self.where_value_decoder = WhereValueDecoder(
            input_size, hidden_size, output_size=n_cond_ops, num_layers=num_layers, dropout_ratio=dropout_ratio, max_where_conds=max_where_conds,
            n_cond_ops=n_cond_ops, start_tkn_id=start_tkn_id, end_tkn_id=end_tkn_id, embedding_layer=embedding_layer
        )

    def forward(self, question_padded, header_padded, col_padded, question_lengths, col_lengths, g_sc, g_sa, g_wn, g_wc, g_wo, g_wv, g_wv_tkns):
        r"""
        Run every sub-decoder and collect its raw output.

        ``g_sa`` and ``g_wv`` are accepted for a uniform signature but are not used here.
        Pass ``None`` for a gold label to use the model's own prediction for that stage.

        Returns:
            Dict with keys "sc", "sa", "wn", "wc", "wo", "wv" holding the sub-decoder outputs.
        """
        select_outputs, _ = self.select_decoder(question_padded, header_padded, col_padded, question_lengths, col_lengths)
        select_idxes = g_sc if g_sc is not None else self.predict_decoder("sc", select_outputs=select_outputs)

        agg_outputs, _ = self.agg_decoder(question_padded, col_padded, question_lengths, col_lengths, select_idxes)

        where_num_outputs, _ = self.where_num_decoder(question_padded, header_padded, col_padded, question_lengths, col_lengths)
        where_nums = g_wn if g_wn is not None else self.predict_decoder("wn", where_num_outputs=where_num_outputs)

        where_col_outputs, _ = self.where_col_decoder(question_padded, header_padded, col_padded, question_lengths, col_lengths)
        if g_wc is not None:
            where_col_idxes = g_wc
        else:
            # Best-scoring columns first, so keeping the first `where_num` entries keeps the top ones.
            where_col_argsort = where_col_outputs.argsort(dim=1, descending=True)
            where_col_idxes = self.predict_decoder("wc", where_col_argsort=where_col_argsort, where_nums=where_nums)

        where_op_outputs = self.where_op_decoder(question_padded, col_padded, question_lengths, where_nums, where_col_idxes)
        where_op_idxes = g_wo if g_wo is not None else self.predict_decoder("wo", where_op_outputs=where_op_outputs, where_nums=where_nums)

        # NOTE(review): header_padded is not passed to the WHERE-value decoder here; confirm it
        # encodes this batch's headers (TODO).
        where_value_outputs = self.where_value_decoder(question_padded, col_padded, question_lengths, where_nums, where_col_idxes, where_op_idxes, self.value_tkn_max_len, g_wv_tkns)

        return {
            "sc": select_outputs,
            "sa": agg_outputs,
            "wn": where_num_outputs,
            "wc": where_col_outputs,
            "wo": where_op_outputs,
            "wv": where_value_outputs,
        }

    @staticmethod
    def predict_decoder(typ, **kwargs):
        r"""
        Turn a sub-decoder's raw output into predicted indices; used when no gold label is given.

        Args:
            typ: one of "sc", "sa", "wn", "wc", "wo", "wv".
            **kwargs: the outputs ``typ`` needs. Examples: ``select_outputs`` for "sc";
                ``where_col_argsort`` (columns sorted best-first) and ``where_nums`` for "wc";
                ``where_op_outputs`` and ``where_nums`` for "wo".

        Raises:
            KeyError: if ``typ`` is unknown or a required keyword is missing.
        """
        if typ == "sc":  # SELECT column
            return kwargs["select_outputs"].argmax(1).tolist()
        if typ == "sa":  # SELECT aggregation operator (not needed by forward)
            return kwargs["agg_outputs"].argmax(1)
        if typ == "wn":  # WHERE number
            return kwargs["where_num_outputs"].argmax(1).tolist()
        if typ == "wc":  # WHERE clause columns: first `w_num` entries of the best-first ordering
            where_col_argsort = kwargs["where_col_argsort"]
            where_nums = kwargs["where_nums"]
            return [where_col_argsort[b_idx, :w_num].tolist() for b_idx, w_num in enumerate(where_nums)]
        if typ == "wo":  # WHERE clause operators
            where_op_outputs = kwargs["where_op_outputs"]
            where_nums = kwargs["where_nums"]
            best_ops = where_op_outputs.argmax(2)
            return [best_ops[b_idx, :w_num].tolist() for b_idx, w_num in enumerate(where_nums)]
        if typ == "wv":  # WHERE clause value tokens (not needed by forward)
            return [o.argmax(2) for o in kwargs["where_value_outputs"]]
        raise KeyError("`typ` must be in ['sc', 'sa', 'wn', 'wc', 'wo', 'wv']")

In [49]:
model = Decoder(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout_ratio=dropout_ratio,
    max_where_conds=max_where_conds,
    n_agg_ops=n_agg_ops,
    n_cond_ops=n_cond_ops,
    start_tkn_id=start_tkn_id,
    end_tkn_id=end_tkn_id,
    value_tkn_max_len=value_tkn_max_len,
    embedding_layer=embedding_layer,
)

In [50]:
# g_* are gold labels; Decoder.forward uses them in place of its own predictions where given.
decoder_outputs = model(
    question_padded, header_padded, col_padded,
    question_lengths, col_lengths,
    g_sc=g_sc, g_sa=g_sa, g_wn=g_wn, g_wc=g_wc, g_wo=g_wo, g_wv=g_wv, g_wv_tkns=g_wv_tkns,
)

In [51]:
# Sub-decoder outputs keyed by SQL component: "sc", "sa", "wn", "wc", "wo", "wv"
decoder_outputs

{'sc': tensor([[0.0293, 0.0333, 0.0291, 0.0303, 0.0326, 0.0241, 0.0281, 0.0251, 0.0200,
          0.0293, 0.0290, 0.0245, 0.0304, 0.0330, 0.0350, 0.0344, 0.0328, 0.0284,
          0.0287, 0.0300],
         [0.0249, 0.0244, 0.0189, 0.0174, 0.0193, 0.0153, 0.0147, 0.0188, 0.0083,
          0.0152, 0.0211, 0.0284, 0.0196, 0.0206, 0.0239, 0.0259, 0.0250, 0.0226,
          0.0247, 0.0204]], grad_fn=<SqueezeBackward1>),
 'sa': tensor([[-0.1265, -0.0473,  0.0327,  0.0128, -0.0155,  0.0326],
         [-0.0944, -0.0563,  0.0321,  0.0297, -0.0045,  0.0324]],
        grad_fn=<AddmmBackward>),
 'wn': tensor([[ 0.0133, -0.0735,  0.0406, -0.0394,  0.0236],
         [ 0.0189, -0.0583,  0.0404, -0.0418,  0.0382]],
        grad_fn=<AddmmBackward>),
 'wc': tensor([[ 9.3779e-03,  3.8128e-03,  1.5921e-03,  2.8032e-03, -2.9681e-03,
          -2.8121e-03, -3.1518e-03, -1.6126e-03,  2.7844e-03,  6.7169e-03,
           4.4253e-03,  4.8969e-03,  6.2904e-03,  5.4647e-03,  2.8376e-03,
           3.4574e-03,  3.9

# Whole Model

In [52]:
import pytorch_lightning as pl

In [None]:
class Model(pl.LightningModule):
    r"""
    End-to-end TEXT2SQL model (BERT encoder + ``Decoder``). Work in progress.

    Subclasses ``pl.LightningModule`` (Lightning's model base class); ``LightningDataModule`` is
    meant for data loading only.

    TODO: build the encoder/decoder here and implement ``forward``, ``training_step`` and
    ``configure_optimizers``.
    """
    def __init__(self):
        super().__init__()

## Training

Still working on it.

In [None]:
lr = 1e-3       # decoder learning rate
lr_bert = 1e-5  # much smaller learning rate for the BERT encoder

# Separate optimisers so the decoder and BERT use different learning rates;
# only parameters with requires_grad=True are handed to each optimiser.
opt = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=lr, weight_decay=0
)
opt_bert = torch.optim.AdamW(
    (p for p in model_bert.parameters() if p.requires_grad), lr=lr_bert, weight_decay=0
)

## Testing: Execution-guided beam decoding

Still working on it.

In [37]:
beam_size = 4  # candidates kept per beam-search step (clamped below when fewer candidates exist)

Score the SELECT-column candidates:

In [38]:
# NOTE(review): `select_decoder` is a notebook-global here (not `model.select_decoder`) and is
# called without `col_lengths`, unlike Decoder.forward — confirm which decoder/signature is current.
select_output, _ = select_decoder(question_padded, header_padded, col_padded, question_lengths)
select_output.size()  # (B, n_col)

torch.Size([2, 20])

Construct the joint score of every candidate SELECT column and aggregation operator.

In [39]:
batch_size, n_col = select_output.size()

select_prob = torch.softmax(select_output, 1)  # prob_sc: (B, n_col)
# The beam cannot hold more SELECT candidates than there are columns.
beam_size_max_col = min(beam_size, n_col)

# prob_sc_sa[b, k, :]: aggregation probabilities for the k-th SELECT candidate
# prob_sca[b, k, :]:   the same scaled by that candidate's SELECT probability (joint score)
prob_sc_sa = torch.zeros([batch_size, beam_size_max_col, n_agg_ops])
prob_sca = torch.zeros_like(prob_sc_sa)
print(prob_sca.size())  # (B, beam-size, n_agg_ops)

torch.Size([2, 4, 6])


In [42]:
# Beam search: keep the top-k SELECT columns, then score every aggregation for each of them.
_, pr_sc_beam = select_output.topk(k=beam_size_max_col)  # (B, beam_size_max_col)
print(f"sc top k: {pr_sc_beam.tolist()}")

# NOTE(review): `agg_decoder` is a notebook-global (not `model.agg_decoder`) — confirm it is current.
for i_beam in range(beam_size_max_col):
    select_idx = pr_sc_beam[:, i_beam].tolist()  # k-th best SELECT column of each batch item
    agg_output, _ = agg_decoder(question_padded, col_padded, question_lengths, select_idx)
    agg_prob = torch.softmax(agg_output, dim=-1)  # prob_sa: (B, n_agg_ops)
    prob_sc_sa[:, i_beam, :] = agg_prob

    prob_sc_selected = select_prob[range(batch_size), select_idx]  # (B,)
    # Broadcast each item's SELECT probability over its aggregation probabilities -> (B, n_agg_ops)
    prob_sca[:, i_beam, :] = agg_prob * prob_sc_selected.unsqueeze(1)

sc top k: [[19, 18, 0, 14], [3, 4, 6, 5]]


In [43]:
# prob_sc_sa[b, k, :]: aggregation probabilities given the k-th SELECT candidate of batch item b
print(prob_sc_sa.data)

tensor([[[0.1765, 0.1692, 0.1639, 0.1756, 0.1588, 0.1561],
         [0.1777, 0.1687, 0.1647, 0.1742, 0.1588, 0.1558],
         [0.1764, 0.1690, 0.1643, 0.1751, 0.1593, 0.1558],
         [0.1779, 0.1693, 0.1647, 0.1732, 0.1581, 0.1568]],

        [[0.1778, 0.1697, 0.1638, 0.1724, 0.1561, 0.1601],
         [0.1778, 0.1712, 0.1638, 0.1742, 0.1554, 0.1577],
         [0.1789, 0.1702, 0.1634, 0.1741, 0.1557, 0.1578],
         [0.1774, 0.1712, 0.1640, 0.1729, 0.1564, 0.1581]]])


In [44]:
# prob_sca[b, k, a] = P(k-th SELECT candidate) * P(agg = a | that candidate)
print(prob_sca.size())  # (B, beam_size_max_col, n_agg_ops)
print(prob_sca.data)

torch.Size([2, 4, 6])
tensor([[[0.0088, 0.0085, 0.0082, 0.0088, 0.0079, 0.0078],
         [0.0089, 0.0084, 0.0082, 0.0087, 0.0079, 0.0078],
         [0.0088, 0.0084, 0.0082, 0.0088, 0.0080, 0.0078],
         [0.0089, 0.0085, 0.0082, 0.0087, 0.0079, 0.0078]],

        [[0.0089, 0.0085, 0.0082, 0.0086, 0.0078, 0.0080],
         [0.0089, 0.0086, 0.0082, 0.0087, 0.0078, 0.0079],
         [0.0089, 0.0085, 0.0082, 0.0087, 0.0078, 0.0079],
         [0.0089, 0.0086, 0.0082, 0.0086, 0.0078, 0.0079]]])


In [45]:
def topk_multi_dim(tensor, n_topk):
    r"""
    Top-k over all non-batch dimensions of ``tensor``, computed separately per batch item.

    Args:
        tensor: tensor of shape (B, d_1, ..., d_n). It may be non-contiguous, live on any device,
            or require grad.
        n_topk: number of entries kept per batch item (must not exceed d_1 * ... * d_n).

    Returns:
        idxes: numpy int array of shape (B, n_topk, n). Each row is the (d_1, ..., d_n) index of a
            hit, in decreasing order of value.
        values: numpy array of shape (B, n_topk) with the matching values.
    """
    batch_size = tensor.size(0)
    # reshape (unlike view) also accepts non-contiguous input; detach().cpu() keeps .numpy() safe.
    flat = tensor.detach().cpu().reshape(batch_size, -1)
    values_1d, idxes_1d = flat.topk(n_topk)  # both (B, n_topk), sorted by value
    # Unravel flat positions into coordinates over the non-batch dims: (n, B, k) -> (B, k, n)
    coords = np.unravel_index(idxes_1d.numpy(), tuple(tensor.size()[1:]))
    idxes = np.stack(coords).transpose(1, 2, 0)
    return idxes, values_1d.numpy()

In [46]:
# Flatten the (beam, agg) grid and keep the best `beam_size_sca` (sc, agg) pairs per batch item;
# the beam cannot be wider than the number of (sc, agg) candidates.
beam_size_sca = min(beam_size, np.prod(prob_sca.shape[1:]))

# idxes[..., 0] is a position inside pr_sc_beam, not a column id, so map it back:
# idxes: [sc_beam_idx, sa_idx] -> sca_idxes: [sc_idx, sa_idx]
idxes, values = topk_multi_dim(prob_sca.detach().cpu(), n_topk=beam_size_sca)
sc_beam_idxes = idxes[:, :, 0]
sc_idxes = np.take_along_axis(pr_sc_beam.numpy(), sc_beam_idxes, axis=1)
sca_idxes = np.stack([sc_idxes, idxes[:, :, 1]], axis=-1)

In [47]:
# (B, beam_size_sca, 2): each row is a [select column, aggregation] candidate, best first
sca_idxes

array([[[14,  0],
        [18,  0],
        [19,  0],
        [ 0,  0]],

       [[ 6,  0],
        [ 3,  0],
        [ 4,  0],
        [ 5,  0]]], dtype=int64)

(Still being written…)