# TEXT2SQL with transformers

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import transformers

from typing import Tuple, Dict, List, Union, Any
print(f"PyTroch Version: {torch.__version__}")
print(f"Transfomers Version: {transformers.__version__}")

PyTroch Version: 1.8.1
Transfomers Version: 4.6.1


## Data to build

In [10]:
from pathlib import Path
import re
import records

schema_re = re.compile(r'\((.+)\)')
num_re = re.compile(r'[-+]?\d*\.\d+|\d+')

db_path = Path("./private")
db = records.Database(f"sqlite:///{db_path / 'samsung_new.db'}")

table_id = "receipts"
table_info = db.query('SELECT sql from sqlite_master WHERE tbl_name = :name', name=table_id).all()[0].sql
schema_str = schema_re.findall(table_info.replace("\n", ""))[0]
schema = {}
for tup in schema_str.split(', '):
    c, t = tup.split()
    schema[c.strip('"')] = t

In [11]:
# sqlite can have NULL, INTEGER, REAL, TEXT, BLOB data type
schema

{'index': 'INTEGER',
 'rcept_no': 'TEXT',
 'reprt_code': 'TEXT',
 'bsns_year': 'INTEGER',
 'corp_code': 'TEXT',
 'stock_code': 'TEXT',
 'fs_div': 'TEXT',
 'fs_nm': 'TEXT',
 'sj_div': 'TEXT',
 'sj_nm': 'TEXT',
 'account_nm': 'TEXT',
 'thstrm_nm': 'TEXT',
 'thstrm_dt': 'TEXT',
 'thstrm_amount': 'INTEGER',
 'frmtrm_nm': 'TEXT',
 'frmtrm_dt': 'TEXT',
 'frmtrm_amount': 'INTEGER',
 'bfefrmtrm_nm': 'TEXT',
 'bfefrmtrm_dt': 'TEXT',
 'bfefrmtrm_amount': 'INTEGER'}

In [12]:
set(schema.values())

{'INTEGER', 'TEXT'}

In [13]:
question = "제 51 기에 삼성전자의 유동자산은 어떻게 돼?"
cls_token = "[CLS]"
table_token = "[T]"
column_token = "[C]"

f"{cls_token} {question} [T] {table_id} " + " ".join([f"{column_token} {col} [{typ}]" for col, typ in schema.items()])

'[CLS] 제 51 기에 삼성전자의 유동자산은 어떻게 돼? [T] receipts [C] index [INTEGER] [C] rcept_no [TEXT] [C] reprt_code [TEXT] [C] bsns_year [INTEGER] [C] corp_code [TEXT] [C] stock_code [TEXT] [C] fs_div [TEXT] [C] fs_nm [TEXT] [C] sj_div [TEXT] [C] sj_nm [TEXT] [C] account_nm [TEXT] [C] thstrm_nm [TEXT] [C] thstrm_dt [TEXT] [C] thstrm_amount [INTEGER] [C] frmtrm_nm [TEXT] [C] frmtrm_dt [TEXT] [C] frmtrm_amount [INTEGER] [C] bfefrmtrm_nm [TEXT] [C] bfefrmtrm_dt [TEXT] [C] bfefrmtrm_amount [INTEGER]'

## Build a custom Tokenizer

https://huggingface.co/docs/tokenizers/python/latest/pipeline.html

### Normalization

- `normalizers`는 raw text를 더 깨끗하게 만드는 과정이다.
- `NFD` 사용하게 되면 한글은 자음 모음으로 분리된다.
    - NFD(Normalization Form Canonical Decomposition) = 조합형
    - NFC(Normalization Form Canonical Composition) = 완성형

In [7]:
from tokenizers import normalizers
from tokenizers.normalizers import NFD, StripAccents

normalizer = normalizers.Sequence([NFD(), StripAccents()])
# 일반 출력시 합쳐져서 보이지만, for문을 사용하면 분리된다. 
print(normalizer.normalize_str(question))
print("/".join([x for x in normalizer.normalize_str(question)]))

제 51 기에 삼성전자의 유동자산은 어떻게 돼?
ᄌ/ᅦ/ /5/1/ /ᄀ/ᅵ/ᄋ/ᅦ/ /ᄉ/ᅡ/ᆷ/ᄉ/ᅥ/ᆼ/ᄌ/ᅥ/ᆫ/ᄌ/ᅡ/ᄋ/ᅴ/ /ᄋ/ᅲ/ᄃ/ᅩ/ᆼ/ᄌ/ᅡ/ᄉ/ᅡ/ᆫ/ᄋ/ᅳ/ᆫ/ /ᄋ/ᅥ/ᄄ/ᅥ/ᇂ/ᄀ/ᅦ/ /ᄃ/ᅫ/?


### Pre-Tokenization

- `pre_tokenizers`는 텍스트를 더 작은 토큰으로 분리하는 과정이다.
- `Whitespace`는 정규식 `\w+|[^\w\s]+`를 기준으로(공백 및 문장 부호 경계에서) 텍스트를 토큰으로 나누며, 각 토큰의 문자열 내 위치 `(start, end)`를 함께 반환한다.
- `Punctuation`은 문장 부호를 각각 분리한다. 
- `Digits`를 통해 숫자를 각각 분리할지 말지 결정할 수 있다. 

In [8]:
from tokenizers import pre_tokenizers
from tokenizers.pre_tokenizers import Whitespace, Punctuation, Digits

print("Whitespace: ")
print(Whitespace().pre_tokenize_str(question))

print("Digits individual: False")
pre_tokenizer = pre_tokenizers.Sequence([Whitespace(), Digits(individual_digits=False)])
print(pre_tokenizer.pre_tokenize_str(question + "????"))

print("Punctuation + Individual Digits")
pre_tokenizer = pre_tokenizers.Sequence([Whitespace(), Punctuation(), Digits(individual_digits=True)])
print(pre_tokenizer.pre_tokenize_str(question + "????"))

Whitespace: 
[('제', (0, 1)), ('51', (2, 4)), ('기에', (5, 7)), ('삼성전자의', (8, 13)), ('유동자산은', (14, 19)), ('어떻게', (20, 23)), ('돼', (24, 25)), ('?', (25, 26))]
Digits individual: False
[('제', (0, 1)), ('51', (2, 4)), ('기에', (5, 7)), ('삼성전자의', (8, 13)), ('유동자산은', (14, 19)), ('어떻게', (20, 23)), ('돼', (24, 25)), ('?????', (25, 30))]
Punctuation + Individual Digits
[('제', (0, 1)), ('5', (2, 3)), ('1', (3, 4)), ('기에', (5, 7)), ('삼성전자의', (8, 13)), ('유동자산은', (14, 19)), ('어떻게', (20, 23)), ('돼', (24, 25)), ('?', (25, 26)), ('?', (26, 27)), ('?', (27, 28)), ('?', (28, 29)), ('?', (29, 30))]


### The Model

- `BPE`: Byte-Pair Encoding 토큰화, 자주 등장하는 character를 합쳐서 표현하는 알고리즘
- `Unigram`: 확률적으로 최적의 subword 토큰을 결정
- `WordLevel`: 단어 단위의 토큰화
- `WordPiece`: Google WordPiece 토큰화

In [38]:
from tokenizers.models import BPE, Unigram, WordLevel, WordPiece

### Post-Processing

- `processors` 에서 후처리를 할 수 있다.
- `TemplateProcessing`을 이용해 원하는 형태로 토큰을 분리할 수 있다.

In [42]:
from tokenizers.processors import PostProcessor

In [None]:
from tokenizers.processors import TemplateProcessing

post_processor = TemplateProcessing(
    single="[CLS] $A [T] $B [C] $B",
    pair="[CLS] $A [T] $B:1 [C] $B:1",
    special_tokens=[("[CLS]", 1), ("[SEP]", 2)]#, ("[T]", 3), ("[C]", 4), ("[INTEGER]", 5), ("[REAL]", 6), ("[TEXT]", 7), ("[BLOB]", 8)],
)

---

- https://huggingface.co/docs/tokenizers/python/latest/index.html
- https://huggingface.co/transformers/main_classes/tokenizer.html?highlight=pretrainedtokenizer#transformers.PreTrainedTokenizer

In [14]:
input_str = f"{question} [T] {table_id} " + "".join([f"[SEP]{col}" for col in schema]) 
# " ".join([f"{column_token} {col} [{typ}]" for col, typ in schema.items()])
input_str

'제 51 기에 삼성전자의 유동자산은 어떻게 돼? [T] receipts [SEP]index[SEP]rcept_no[SEP]reprt_code[SEP]bsns_year[SEP]corp_code[SEP]stock_code[SEP]fs_div[SEP]fs_nm[SEP]sj_div[SEP]sj_nm[SEP]account_nm[SEP]thstrm_nm[SEP]thstrm_dt[SEP]thstrm_amount[SEP]frmtrm_nm[SEP]frmtrm_dt[SEP]frmtrm_amount[SEP]bfefrmtrm_nm[SEP]bfefrmtrm_dt[SEP]bfefrmtrm_amount'

In [15]:
batch_qs = ["제 51 기에 삼성전자의 유동자산은 어떻게 돼?", "2020년도 삼성전자의 유동자산은 얼마?"]
table_str = f"[T]{table_id}" + "".join([f"[SEP]{col}" for col in schema]) 
batch_ts = [table_str] * len(batch_qs)

In [16]:
from KoBertTokenizer import KoBertTokenizer

In [17]:
# new_special_tokens = ["[T]", "[C]", "[INTEGER]", "[REAL]", "[TEXT]", "[BLOB]"]
# new_special_tokens = list(map(lambda x: AddedToken(x, single_word=True, normalized=False), new_special_tokens))
# tokenizer = KoBertTokenizer.from_pretrained('monologg/kobert', add_special_tokens=True, additional_special_tokens=new_special_tokens)
tokenizer = KoBertTokenizer.from_pretrained('monologg/kobert')

In [18]:
special_tokenes2idx = dict(zip(tokenizer.additional_special_tokens, tokenizer.additional_special_tokens_ids))
special_tokenes2idx

{}

In [19]:
encode_input = tokenizer(
    batch_qs, batch_ts, 
    max_length=512, padding=True, truncation=True, return_tensors="pt", 
    return_attention_mask=True, 
    return_special_tokens_mask=False, 
)
encode_input.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

In [20]:
batch_size, max_length = encode_input["input_ids"].size()
# Question tokens are the attended (non-padding) tokens of segment 0.
question_mask = torch.bitwise_and(encode_input["token_type_ids"] == 0, encode_input["attention_mask"].bool())
question_mask[:, 0] = False  # [CLS] mask out
# The question's closing [SEP] belongs to segment 0 and sits in the middle of
# the sequence, so it must be removed by token id; clearing the last position
# only touches the end of the (padded) table segment, which is already False.
question_mask = torch.bitwise_and(question_mask, encode_input["input_ids"] != tokenizer.sep_token_id)  # [SEP] mask out

# table_mask = torch.where(
#     (encode_input["input_ids"] == 8002), 
#     torch.ones_like(encode_input["input_ids"], dtype=torch.bool), 
#     torch.zeros_like(encode_input["input_ids"], dtype=torch.bool)
# )
# column_mask = torch.where(
#     (encode_input["input_ids"] == 8003),
#     torch.ones_like(encode_input["input_ids"], dtype=torch.bool), 
#     torch.zeros_like(encode_input["input_ids"], dtype=torch.bool)
# )

In [21]:
for x in encode_input["input_ids"]:
    print(tokenizer.decode(x, skip_special_tokens=False))
    print()

[CLS] 제 51 기에 삼성전자의 유동자산은 어떻게 돼?[SEP] [T]receipts[SEP] index[SEP] rcept_no[SEP] reprt_code[SEP] bsns_year[SEP] corp_code[SEP] stock_code[SEP] fs_div[SEP] fs_nm[SEP] sj_div[SEP] sj_nm[SEP] account_nm[SEP] thstrm_nm[SEP] thstrm_dt[SEP] thstrm_amount[SEP] frmtrm_nm[SEP] frmtrm_dt[SEP] frmtrm_amount[SEP] bfefrmtrm_nm[SEP] bfefrmtrm_dt[SEP] bfefrmtrm_amount[SEP]

[CLS] 2020년도 삼성전자의 유동자산은 얼마?[SEP] [T]receipts[SEP] index[SEP] rcept_no[SEP] reprt_code[SEP] bsns_year[SEP] corp_code[SEP] stock_code[SEP] fs_div[SEP] fs_nm[SEP] sj_div[SEP] sj_nm[SEP] account_nm[SEP] thstrm_nm[SEP] thstrm_dt[SEP] thstrm_amount[SEP] frmtrm_nm[SEP] frmtrm_dt[SEP] frmtrm_amount[SEP] bfefrmtrm_nm[SEP] bfefrmtrm_dt[SEP] bfefrmtrm_amount[SEP][PAD][PAD]



In [22]:
for inputs, type_ids in zip(encode_input["input_ids"], encode_input["token_type_ids"]):
    print("--- type_ids = 0")
    print(tokenizer.decode(inputs[type_ids == 0]).replace("[PAD]", ""))
    print("--- type_ids = 1")
    print(tokenizer.decode(inputs[type_ids == 1]))
    print()

--- type_ids = 0
[CLS] 제 51 기에 삼성전자의 유동자산은 어떻게 돼?[SEP]
--- type_ids = 1
[T]receipts[SEP] index[SEP] rcept_no[SEP] reprt_code[SEP] bsns_year[SEP] corp_code[SEP] stock_code[SEP] fs_div[SEP] fs_nm[SEP] sj_div[SEP] sj_nm[SEP] account_nm[SEP] thstrm_nm[SEP] thstrm_dt[SEP] thstrm_amount[SEP] frmtrm_nm[SEP] frmtrm_dt[SEP] frmtrm_amount[SEP] bfefrmtrm_nm[SEP] bfefrmtrm_dt[SEP] bfefrmtrm_amount[SEP]

--- type_ids = 0
[CLS] 2020년도 삼성전자의 유동자산은 얼마?[SEP]
--- type_ids = 1
[T]receipts[SEP] index[SEP] rcept_no[SEP] reprt_code[SEP] bsns_year[SEP] corp_code[SEP] stock_code[SEP] fs_div[SEP] fs_nm[SEP] sj_div[SEP] sj_nm[SEP] account_nm[SEP] thstrm_nm[SEP] thstrm_dt[SEP] thstrm_amount[SEP] frmtrm_nm[SEP] frmtrm_dt[SEP] frmtrm_amount[SEP] bfefrmtrm_nm[SEP] bfefrmtrm_dt[SEP] bfefrmtrm_amount[SEP]



## Encoder

In [23]:
from transformers import BertModel, BertConfig
model = BertModel.from_pretrained("monologg/kobert")
model.resize_token_embeddings(len(tokenizer))

Embedding(8002, 768, padding_idx=1)

In [52]:
model.config.output_hidden_states = True

In [53]:
encode_output = model(**encode_input)

In [58]:
len(encode_output.hidden_states)

13

## Decoder

<img src="https://drive.google.com/uc?id=1PW9oAXfW-ZI-jxGn5q9O_gzUIZnNYaet" alt="Sqlova Decoder Architecture " width="100%" height="auto">

### Select Column

$$\begin{aligned} 
s(n\vert c) &= D_c^T W E_n \\
p(n\vert c) &= \text{softmax}(s(n\vert c))\\
C_c &= \sum_n p(n \vert c) E_n \\
s_{sc}(c) &= W  \tanh ([WD_c; WC_c]) \\
p_{sc}(c) &= \text{softmax}(s_{sc}(c))
\end{aligned}$$

$E_n$ is the LSTM output of the $n$-th question token, $D_c$ is the encoding of header $c$, $C_c$ is the question context vector for column $c$, $[\cdot ; \cdot]$ denotes the concatenation of two vectors, and $p_{sc}(c)$ is the probability of generating column $c$.

In [1]:
class DBEngine:
    """Execute WikiSQL-style structured queries against a SQLite database.

    A query is described by a select-column index, an aggregation index into
    ``agg_ops`` and a list of ``(column_index, operator_index, value)``
    conditions. Columns are addressed as ``col<index>`` and table names are
    normalized to ``table_<id>``.

    The class relies on names defined outside of it: ``records``,
    ``schema_re``, ``num_re``, ``agg_ops``, ``cond_ops``, ``parse_decimal``
    and ``NumberFormatError``.
    """

    def __init__(self, fdb):
        """Open the SQLite database file located at ``fdb``."""
        self.db = records.Database('sqlite:///{}'.format(fdb))

    def execute_query(self, table_id, query, *args, **kwargs):
        """Run ``query`` (an object exposing ``sel_index``, ``agg_index`` and
        ``conditions``) through :meth:`execute`."""
        return self.execute(table_id, query.sel_index, query.agg_index, query.conditions, *args, **kwargs)

    def execute(self, table_id, select_index, aggregation_index, conditions, lower=True):
        """Build and run the query; return the list of ``result`` values."""
        query, where_map = self._build_query(table_id, select_index, aggregation_index, conditions, lower)
        out = self.db.query(query, **where_map)
        return [o.result for o in out]

    def execute_return_query(self, table_id, select_index, aggregation_index, conditions, lower=True):
        """Same as :meth:`execute`, but also return the SQL text that was run.

        Returns:
            ``(results, query)`` where ``query`` still contains the ``:colN``
            bind placeholders.
        """
        query, where_map = self._build_query(table_id, select_index, aggregation_index, conditions, lower)
        out = self.db.query(query, **where_map)
        return [o.result for o in out], query

    def show_table(self, table_id):
        """Print every row of the (normalized) table."""
        table_id = self._normalize_table_id(table_id)
        # Table names cannot be bound as SQL parameters, hence the concatenation.
        rows = self.db.query('select * from ' + table_id)
        print(rows.dataset)

    @staticmethod
    def _normalize_table_id(table_id):
        """Return ``table_id`` in its ``table_<id>`` form ('-' becomes '_')."""
        if not table_id.startswith('table'):
            table_id = 'table_{}'.format(table_id.replace('-', '_'))
        return table_id

    def _load_schema(self, table_id):
        """Return ``{column: type}`` parsed from the table's stored CREATE statement."""
        table_info = self.db.query(
            'SELECT sql from sqlite_master WHERE tbl_name = :name', name=table_id
        ).all()[0].sql.replace('\n', '')
        schema_str = schema_re.findall(table_info)[0]
        schema = {}
        for tup in schema_str.split(', '):
            c, t = tup.split()
            schema[c] = t
        return schema

    @staticmethod
    def _coerce_real(val):
        """Best-effort conversion of a condition value for a ``real`` column.

        The value is returned unchanged when no number can be extracted: the
        column is numeric but the supplied value is not, and the query is still
        attempted with the raw value.
        """
        try:
            return float(parse_decimal(val, locale='en_US'))
        except NumberFormatError:
            try:
                # Fall back to the first number-looking substring.
                return float(num_re.findall(val)[0])
            except (IndexError, TypeError, ValueError):
                return val

    def _build_query(self, table_id, select_index, aggregation_index, conditions, lower):
        """Build the SQL text and bind parameters shared by both execute methods.

        Returns:
            ``(query, where_map)``: the SQL string and the dict of ``colN``
            bind values for its WHERE clause.
        """
        table_id = self._normalize_table_id(table_id)
        schema = self._load_schema(table_id)

        select = 'col{}'.format(select_index)
        agg = agg_ops[aggregation_index]
        if agg:
            select = '{}({})'.format(agg, select)

        where_clause = []
        where_map = {}
        for col_index, op, val in conditions:
            col_name = 'col{}'.format(col_index)
            if lower and isinstance(val, str):
                val = val.lower()
            if schema[col_name] == 'real' and not isinstance(val, (int, float)):
                val = self._coerce_real(val)
            where_clause.append('col{} {} :col{}'.format(col_index, cond_ops[op], col_index))
            where_map[col_name] = val

        where_str = ''
        if where_clause:
            where_str = 'WHERE ' + ' AND '.join(where_clause)
        query = 'SELECT {} AS result FROM {} {}'.format(select, table_id, where_str)
        return query, where_map

---

# SQLova Model

In [1]:
from sqlova.model.nl2sql.wikisql_models import *

In [2]:
import json
import torch
from pathlib import Path
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import transformers

from typing import Tuple, Dict, List, Union, Any
print(f"PyTroch Version: {torch.__version__}")
print(f"Transfomers Version: {transformers.__version__}")

PyTroch Version: 1.8.1
Transfomers Version: 4.6.1


In [3]:
def load_data(path_sql="NLSQL.jsonl", path_table="table.jsonl"):
    """Load the NL-to-SQL examples and the table descriptions from JSONL files.

    Args:
        path_sql: JSONL file with one example per line (str or Path).
        path_table: JSONL file with one table description per line; each
            record must carry an ``"id"`` key.

    Returns:
        ``(data, table)``: the list of example dicts in file order, and a
        dict mapping each table's ``"id"`` to its record.
    """
    def _read_jsonl(path):
        # Blank lines (e.g. a trailing newline) are skipped instead of being
        # handed to json.loads, which would raise on an empty string.
        with Path(path).open("r", encoding="utf-8") as f:
            return [json.loads(line.strip()) for line in f if line.strip()]

    data = _read_jsonl(path_sql)
    table = {record['id']: record for record in _read_jsonl(path_table)}
    return data, table

In [4]:
data, table = load_data()

In [5]:
import os

# multiprocessing lib doesn’t have it implemented on Windows
# https://discuss.pytorch.org/t/cant-pickle-local-object-dataloader-init-locals-lambda/31857/14
num_workers = 0 if os.name == "nt" else 4

In [6]:
data_loader = torch.utils.data.DataLoader(
    batch_size=2,
    dataset=data,
    shuffle=True,
    num_workers=num_workers,
    collate_fn=lambda x: x # now dictionary values are not merged!
)

db

In [7]:
from tsv2jsonl import DBEngine

In [8]:
db_path = Path("./private")
dbengine = DBEngine(db_path / "samsung_new.db")

In [9]:
table_schema_lengths = []
for t in dbengine.db.get_table_names():
    dbengine.get_schema_info(t)
    table_schema_lengths.append(len(dbengine.schema))

In [10]:
# engine = DBEngine(db_path / "samsung_new.db")
# engine.execute(table_id="receipts", query)

get models

In [11]:
from KoBertTokenizer import KoBertTokenizer
from transformers import BertModel, BertConfig

def get_bert(model_path: str, device: str, max_col_length: int, output_hidden_states: bool=False):
    """Load the pretrained BERT encoder together with its tokenizer and config.

    The tokenizer is extended with three additional special tokens and the
    model's embedding matrix is resized to the new vocabulary size.

    Args:
        model_path: name or path passed to every ``from_pretrained`` call.
        device: device the model is moved to.
        max_col_length: accepted for interface compatibility; not used here.
        output_hidden_states: flag written to both the returned config and
            the model's own config.

    Returns:
        ``(model, tokenizer, config)``. ``config`` is loaded separately from
        the model, so it is not the same object as ``model.config``.
    """
    # sequence start, sequence end, column tokens
    extra_tokens = ["[S]", "[E]", "[COL]"]

    bert_config = BertConfig.from_pretrained(model_path)
    bert_config.output_hidden_states = output_hidden_states

    bert_tokenizer = KoBertTokenizer.from_pretrained(
        model_path,
        add_special_tokens=True,
        additional_special_tokens=extra_tokens,
    )

    bert = BertModel.from_pretrained(model_path)
    # Make room in the embedding matrix for the newly added special tokens.
    bert.resize_token_embeddings(len(bert_tokenizer))
    bert.config.output_hidden_states = output_hidden_states
    bert.to(device)

    return bert, bert_tokenizer, bert_config

In [12]:
model_path = "monologg/kobert"
device = "cpu" # "cuda" if torch.cuda.is_available() else "cpu" 

agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
cond_ops = ['=', '>', '<', 'OP']

num_target_layers = 1  # "The Number of final layers of BERT to be used in downstream task."
# Get BERT
model_bert, tokenizer_bert, config_bert = get_bert(model_path=model_path, device=device, max_col_length=max(table_schema_lengths))

# Get Seq-to-SQL
iS = config_bert.hidden_size * num_target_layers  # Seq-to-SQL input vector dimenstion
hS = 100
lS = 2
dr = 0.3

n_cond_ops = len(cond_ops)
n_agg_ops = len(agg_ops)
model = Seq2SQL_v1(iS, hS, lS, dr, n_cond_ops, n_agg_ops)
model = model.to(device)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


get parameters

In [13]:
lr = 1e-3
lr_bert = 1e-5

opt = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                               lr=lr, weight_decay=0)
opt_bert = torch.optim.AdamW(filter(lambda p: p.requires_grad, model_bert.parameters()),
                            lr=lr_bert, weight_decay=0)

In [14]:
model.train()
model_bert.train()

ave_loss = 0
cnt = 0  # count the # of examples
cnt_sc = 0  # count the # of correct predictions of select column
cnt_sa = 0  # of selectd aggregation
cnt_wn = 0  # of where number
cnt_wc = 0  # of where column
cnt_wo = 0  # of where operator
cnt_wv = 0  # of where-value
cnt_wvi = 0  # of where-value index (on question tokens)
cnt_lx = 0  # of logical form acc
cnt_x = 0  # of execution acc

In [15]:
st_pos = 0

for iB, data in enumerate(tqdm(data_loader, desc='TRAIN')):
    cnt += len(data)
    
    if cnt < st_pos:
        continue
    break

TRAIN:   0%|          | 0/21120 [00:00<?, ?it/s]

In [16]:
start_tkn = tokenizer_bert.additional_special_tokens[0]
start_tkn_id = tokenizer_bert.additional_special_tokens_ids[0]
end_tkn = tokenizer_bert.additional_special_tokens[1]
end_tkn_id = tokenizer_bert.additional_special_tokens_ids[1]

batch_qs = [jsonl["question"] for jsonl in data]
tid = [jsonl["table_id"] for jsonl in data]
batch_sqls = [jsonl["sql"] for jsonl in data]
batch_ts = []
for table_id in tid:
    dbengine.get_schema_info(table_id)
    table_str = f"{start_tkn}{table_id}{end_tkn}" + "".join([
        f"[COL]{start_tkn}{col}{end_tkn}" for col in dbengine.schema
    ]) 
    batch_ts.append(table_str)
    
encode_input = tokenizer_bert(
    batch_qs, batch_ts, 
    max_length=512, padding=True, truncation=True, return_tensors="pt", 
    return_attention_mask=True, 
    return_special_tokens_mask=False, 
)

In [17]:
print(tokenizer_bert.decode(encode_input["input_ids"][0]))

[CLS] 삼성전자의 2017년 법인세차감전 순이익은 얼마야?[SEP] [S] receipts [E] [COL] [S] index [E] [COL] [S] rcept_no [E] [COL] [S] reprt_code [E] [COL] [S] bsns_year [E] [COL] [S] corp_code [E] [COL] [S] stock_code [E] [COL] [S] fs_div [E] [COL] [S] fs_nm [E] [COL] [S] sj_div [E] [COL] [S] sj_nm [E] [COL] [S] account_nm [E] [COL] [S] thstrm_nm [E] [COL] [S] thstrm_dt [E] [COL] [S] thstrm_amount [E] [COL] [S] frmtrm_nm [E] [COL] [S] frmtrm_dt [E] [COL] [S] frmtrm_amount [E] [COL] [S] bfefrmtrm_nm [E] [COL] [S] bfefrmtrm_dt [E] [COL] [S] bfefrmtrm_amount [E] [SEP]


In [18]:
def get_answer(input_ids, mask, batch_size, start_tkn_id, end_tkn_id):
    r"""
    Split the masked tokens of every row into per-span answer token tensors.

    A span starts at a start token and ends at the next end token. Each
    returned piece drops the start token and keeps the end token
    (answer should include end token: [E]).

    Args:
        input_ids: (batch_size, seq_len) token ids.
        mask: boolean tensor of the same shape selecting the tokens to use.
            Every row must select the same number of tokens and contain the
            same number of start/end pairs, because the selection is reshaped
            with ``view(batch_size, -1)``.
        batch_size: number of rows in ``input_ids``.
        start_tkn_id: id of the span start token.
        end_tkn_id: id of the span end token.

    Returns:
        A list with one tuple per row; each tuple holds one 1-D tensor per span.
    """
    masked_input_ids = input_ids[mask]
    start_tkn_mask = masked_input_ids == start_tkn_id
    end_tkn_mask = masked_input_ids == end_tkn_id
    table_col_length = masked_input_ids.view(batch_size, -1).size(1)
    start_end_mask = torch.bitwise_or(start_tkn_mask, end_tkn_mask)
    # Build the position index on the same device as the ids; a CPU index
    # cannot be indexed with a mask that lives on another device.
    positions = torch.arange(table_col_length, device=input_ids.device).repeat(batch_size)
    # (batch, n_spans, 2): in-row positions of each (start, end) marker pair.
    index = positions[start_end_mask].view(batch_size, -1, 2)
    # end - start == number of tokens after the start marker, end marker included.
    tkn_lengths = index[:, :, 1] - index[:, :, 0]
    answer_col_tkns = [x.split(tkn_length.tolist()) for x, tkn_length in zip(
        masked_input_ids[~start_tkn_mask].view(batch_size, -1), tkn_lengths)]
    return answer_col_tkns

# def get_decoder_input_mask(input_ids, mask, batch_size, end_tkn_id):
#     r"""
#     input should only include end token: [S]
#     """
#     end_tkn_mask = input_ids == end_tkn_id
#     end_index = torch.arange(input_ids.size(1)).repeat(batch_size)[end_tkn_mask.view(-1)].view(batch_size, -1)
#     return mask.scatter(1, end_index, False)

def get_decoder_input_mask(input_ids, mask, batch_size, end_tkn_id, start_tkn_id=None):
    r"""
    Return a copy of ``mask`` with every start/end marker position cleared,
    so that the result only contains word tokens.

    Args:
        input_ids: (batch_size, seq_len) token ids.
        mask: boolean tensor of the same shape; it is not modified.
        batch_size: kept for interface compatibility; not needed any more.
        end_tkn_id: id of the span end token.
        start_tkn_id: id of the span start token. When omitted, the
            module-level ``start_tkn_id`` is used, which is what this
            function implicitly read before the parameter existed.

    Raises:
        NameError: if ``start_tkn_id`` is neither passed nor defined globally.
    """
    if start_tkn_id is None:
        try:
            start_tkn_id = globals()["start_tkn_id"]
        except KeyError:
            raise NameError("name 'start_tkn_id' is not defined") from None
    marker_mask = torch.bitwise_or(input_ids == start_tkn_id, input_ids == end_tkn_id)
    # Element-wise masking needs no per-row index tensor, so rows may hold
    # different numbers of markers and no cross-device indexing takes place.
    return torch.bitwise_and(mask, ~marker_mask)

In [19]:
def get_input_mask_and_answer(encode_input, tokenizer):
    r"""
    Derive the decoder input masks and the answer token splits from an
    encoded (question, table description) batch.

    table -> database table name(id)
    header -> database header

    Args:
        encode_input: mapping with "input_ids", "token_type_ids" and
            "attention_mask" tensors of shape (batch_size, max_length).
        tokenizer: tokenizer providing ``sep_token_id`` and exactly three
            ``additional_special_tokens_ids`` (start, end, column).

    returns:
        input_question_mask, input_table_mask, input_header_mask, input_col_mask,
        answer_table_tkns, answer_header_tkns

    Several steps reshape with ``view(batch_size, -1)`` or ``torch.stack``,
    which requires every row of the batch to yield the same number of
    selected positions.
    """
    batch_size, max_length = encode_input["input_ids"].size()
    sep_tkn_mask = encode_input["input_ids"] == tokenizer.sep_token_id
    # Unpacking fails unless the tokenizer has exactly three additional special tokens.
    start_tkn_id, end_tkn_id, col_tkn_id = tokenizer.additional_special_tokens_ids
    
    # Question tokens: segment 0, attended, and neither [SEP] nor position 0.
    input_question_mask = torch.bitwise_and(encode_input["token_type_ids"] == 0, encode_input["attention_mask"].bool())
    input_question_mask = torch.bitwise_and(input_question_mask, ~sep_tkn_mask) # [SEP] mask out
    input_question_mask[:, 0] = False  # [CLS] mask out

    # Attended segment-1 tokens. The XOR flips every [SEP] position: [SEP]s
    # inside segment 1 are cleared, while a [SEP] outside of it becomes True.
    # NOTE(review): presumably that extra position is the question's closing
    # [SEP], which the `[:, 1:]` slice further down drops again - confirm.
    db_mask = torch.bitwise_and(encode_input["token_type_ids"] == 1, encode_input["attention_mask"].bool())
    db_mask = torch.bitwise_xor(db_mask, sep_tkn_mask)
    col_tkn_mask = encode_input["input_ids"] == col_tkn_id
    db_mask = torch.bitwise_and(db_mask, ~col_tkn_mask)
    # split table_mask and header_mask
    # input_idx[b, i] == i; indexing it with a mask yields the selected positions.
    input_idx = torch.arange(max_length).repeat(batch_size, 1)
    db_idx = input_idx[db_mask]
    table_header_tkn_idx = db_idx[db_idx > 0]
    # Position right after the first selected position of each row.
    table_start_idx = table_header_tkn_idx.view(batch_size, -1)[:, 0] + 1
    # Selected positions that lie exactly 2 after the previously selected one,
    # i.e. one position was skipped in between.
    # NOTE(review): the skipped position is presumably a removed [COL] token,
    # making these the [S] tokens of the columns - confirm.
    start_tkn_idx = table_header_tkn_idx[1:][table_header_tkn_idx.diff() == 2].view(batch_size, -1)
    table_end_sep_idx = start_tkn_idx[:, 0] - 1
    # Per row: [size of the table part, size of the remaining (header) part].
    split_size = torch.stack([
        table_end_sep_idx-table_start_idx+1, table_header_tkn_idx.view(batch_size, -1).size(1)-(table_end_sep_idx-table_start_idx+1)
    ]).transpose(0, 1)

    # Token idx
    # Split every row into its table part and header part, then stack each part over the batch.
    table_tkn_idx, header_tkn_idx = map(
        lambda x: torch.stack(x), 
        zip(*[torch.split(x, size.tolist()) for x, size in zip(table_header_tkn_idx.view(batch_size, -1), split_size)])
    )

    # Drop the first selected position of the table part.
    table_tkn_idx = table_tkn_idx[:, 1:]
    # Mask include [S] & [E] tokens
    table_tkn_mask = torch.zeros_like(encode_input["input_ids"], dtype=torch.bool).scatter(1, table_tkn_idx, True)
    header_tkn_mask = torch.zeros_like(encode_input["input_ids"], dtype=torch.bool).scatter(1, header_tkn_idx, True)

    # For Decoder Input, Maskout [S], [E] for table & header  
    input_table_mask = get_decoder_input_mask(
        encode_input["input_ids"], table_tkn_mask, batch_size, end_tkn_id
    )
    input_header_mask = get_decoder_input_mask(
        encode_input["input_ids"], header_tkn_mask, batch_size, end_tkn_id
    )
    # [COL] token mask: this is for attention
    col_tkn_idx = input_idx[col_tkn_mask].view(batch_size, -1)
    input_col_mask = torch.zeros_like(encode_input["input_ids"], dtype=torch.bool).scatter(1, col_tkn_idx, True)
        
    # For Answer, Maskout [S] for table & header 
    answer_table_tkns = get_answer(
        encode_input["input_ids"], table_tkn_mask, batch_size, start_tkn_id, end_tkn_id
    )
    answer_header_tkns = get_answer(
        encode_input["input_ids"], header_tkn_mask, batch_size, start_tkn_id, end_tkn_id
    )
    
    return input_question_mask, input_table_mask, input_header_mask, input_col_mask, answer_table_tkns, answer_header_tkns

In [20]:
input_question_mask, input_table_mask, input_header_mask, input_col_mask, answer_table_tkns, answer_header_tkns \
    = get_input_mask_and_answer(encode_input, tokenizer_bert)

In [21]:
print(tokenizer_bert.decode(encode_input["input_ids"][input_question_mask]))
print(tokenizer_bert.decode(encode_input["input_ids"][input_table_mask]))
print(tokenizer_bert.decode(encode_input["input_ids"][input_header_mask]))
print(tokenizer_bert.decode(encode_input["input_ids"][input_col_mask]))
print("---answer table tokens---")
print(answer_table_tkns)
print("---answer column tokens---")
for b in answer_header_tkns:
    print("--- batch ---")
    for x in b:
        print(x)

삼성전자의 2017년 법인세차감전 순이익은 얼마야? 삼성전자 50기 유동자산이 어때?
receipts receipts
index rcept_no reprt_code bsns_year corp_code stock_code fs_div fs_nm sj_div sj_nm account_nm thstrm_nm thstrm_dt thstrm_amount frmtrm_nm frmtrm_dt frmtrm_amount bfefrmtrm_nm bfefrmtrm_dt bfefrmtrm_amount index rcept_no reprt_code bsns_year corp_code stock_code fs_div fs_nm sj_div sj_nm account_nm thstrm_nm thstrm_dt thstrm_amount frmtrm_nm frmtrm_dt frmtrm_amount bfefrmtrm_nm bfefrmtrm_dt bfefrmtrm_amount
[COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL] [COL]
---answer table tokens---
[(tensor([ 517,  437,  382,  389,  405,  432,  442,  440, 8003]),), (tensor([ 517,  437,  382,  389,  405,  432,  442,  440, 8003]),)]
---answer column tokens---
--- batch ---
tensor([ 517,  409,  388,  389,  457, 8003])
tensor([ 517,  435,  382,  389,  432,

In [22]:
encode_outputs = model_bert(**encode_input)

In [23]:
encode_outputs.keys()

odict_keys(['last_hidden_state', 'pooler_output'])

In [24]:
def pad(batches: Tuple[torch.Tensor], lengths: List[int], model: "BertModel", pad_idx: int=1) -> torch.Tensor:
    """Pad variable-length sequences of vectors to a common length and stack them.

    Missing rows are filled with the model's word embedding of ``pad_idx``.

    Args:
        batches: per-example tensors; each is indexed along its first axis.
        lengths: lengths whose maximum is the padded length.
        model: object exposing ``embeddings.word_embeddings``.
        pad_idx: id of the padding token.

    Returns:
        The padded tensors stacked along a new first axis.

    Raises:
        ValueError: if ``lengths`` is empty.
    """
    max_length = max(lengths)
    embedding = model.embeddings.word_embeddings
    padded = []
    for x in batches:
        n_missing = max_length - len(x)
        if n_missing > 0:
            # Create the ids on the embedding's device; a CPU index tensor
            # cannot be looked up in an embedding that lives on another device.
            pad_ids = torch.full((n_missing,), pad_idx, dtype=torch.long, device=embedding.weight.device)
            x = torch.cat([x, embedding(pad_ids)])
        padded.append(x)
    return torch.stack(padded)

def get_decoder_batches(encode_output, mask, model, pad_idx):
    """Gather the masked encoder states of every example into one batch tensor.

    Args:
        encode_output: object exposing ``last_hidden_state``.
        mask: boolean tensor selecting, per row, the tokens to keep.
        model: forwarded to ``pad`` when rows select different token counts.
        pad_idx: padding token id forwarded to ``pad``.

    Returns:
        ``(tensors, lengths)``: the stacked (and, if necessary, padded)
        per-example states and the number of selected tokens per row.
    """
    token_counts = mask.sum(1)
    selected_states = encode_output.last_hidden_state[mask, :]
    per_example = torch.split(selected_states, token_counts.tolist())

    needs_padding = bool((token_counts != token_counts.max()).any())
    if needs_padding:
        # rows differ in length: pad them to the longest one
        stacked = pad(per_example, token_counts.tolist(), model, pad_idx=pad_idx)
    else:
        # every row already has the same length: stack directly
        stacked = torch.stack(per_example)
    return stacked, token_counts.tolist()

def get_pad_mask(lengths):
    """Build a (batch, max_len) float mask that is 1 on padding and 0 on tokens.

    Row ``i`` holds zeros in its first ``lengths[i]`` positions and ones in
    the remaining ones; ``max_len`` is the largest entry of ``lengths``.
    """
    max_len = max(lengths)
    n_rows = len(lengths)
    positions = torch.arange(max_len).expand(n_rows, max_len)
    row_limits = torch.tensor(lengths).unsqueeze(1)
    # positions at or beyond a row's length are padding
    return (positions >= row_limits).to(torch.get_default_dtype())

In [25]:
question_padded, question_lengths = get_decoder_batches(encode_outputs, input_question_mask, model_bert, pad_idx=tokenizer_bert.pad_token_id)
table_padded, table_lengths = get_decoder_batches(encode_outputs, input_table_mask, model_bert, pad_idx=tokenizer_bert.pad_token_id)
header_padded, header_lengths = get_decoder_batches(encode_outputs, input_header_mask, model_bert, pad_idx=tokenizer_bert.pad_token_id)
col_padded, col_lengths = get_decoder_batches(encode_outputs, input_col_mask, model_bert, pad_idx=tokenizer_bert.pad_token_id)

In [26]:
question_padded.size(), table_padded.size(), header_padded.size(), col_padded.size()

(torch.Size([2, 14, 768]),
 torch.Size([2, 8, 768]),
 torch.Size([2, 181, 768]),
 torch.Size([2, 20, 768]))

In [27]:
print(question_lengths)
print(table_lengths)
print(header_lengths)
print(col_lengths)

[14, 10]
[8, 8]
[181, 181]
[20, 20]


In [28]:
def get_g(batch_sqls):
    """
    Unpack a batch of SQL dicts into parallel ground-truth lists.

    Each dict must provide "sel", "agg" and "conds"; every condition is
    indexed as (column, operator, value).

    Returns, in order:
        sc: select column
        sa: select agg
        wn: where number
        wc: where column
        wo: where operator
        wv: where value

    Raises EnvironmentError when an "agg" value is negative.
    """
    g_sc, g_sa = [], []
    g_wn, g_wc, g_wo, g_wv = [], [], [], []

    for sql_dict in batch_sqls:
        g_sc.append(sql_dict["sel"])
        g_sa.append(sql_dict["agg"])
        conds = sql_dict["conds"]

        if sql_dict["agg"] < 0:
            raise EnvironmentError

        g_wn.append(len(conds))
        g_wc.append([cond[0] for cond in conds])
        g_wo.append([cond[1] for cond in conds])
        g_wv.append([cond[2] for cond in conds])

    return g_sc, g_sa, g_wn, g_wc, g_wo, g_wv

In [29]:
g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(batch_sqls)

In [30]:
g_sc, g_sa, g_wn, g_wc, g_wo, g_wv

([16, 16],
 [0, 0],
 [2, 2],
 [[10, 3], [10, 3]],
 [[0, 0], [0, 0]],
 [['법인세차감전 순이익', 2018], ['유동자산', 2019]])

In [31]:
class AttentionBase(nn.Module):
    """Shared base for the attention modules; provides padding masking."""
    def __init__(self):
        super().__init__()
    
    def wipe_out_pad_tkn_score(self, score, lengths, dim=2):
        """Neutralize the entries of ``score`` that belong to padding tokens.

        Padding is defined relative to ``max(lengths)``: for every example
        shorter than that, the trailing entries are overwritten.

        Args:
            score: 3-D tensor whose first axis is the batch.
            lengths: valid length of every example.
            dim: 2 writes -10000000 into ``score[b, :, length:]`` (for use
                before a softmax over the last axis); 1 writes 0.0 into
                ``score[b, length:, :]``.

        Returns:
            ``score`` itself when nothing needs masking, otherwise a masked
            copy. The input is never modified in place, so the helper is safe
            on tensors autograd still needs (e.g. a softmax output).

        Raises:
            ValueError: if ``dim`` is neither 1 nor 2.
        """
        if dim not in (1, 2):
            raise ValueError("`dim` in wipe_out_pad_tkn_score should be 1 or 2")
        max_len = max(lengths)
        masked = score
        for batch_idx, length in enumerate(lengths):
            if length < max_len:
                if masked is score:
                    # Copy lazily, only once something has to be overwritten.
                    masked = score.clone()
                if dim == 2:
                    masked[batch_idx, :, length:] = -10000000
                else:
                    masked[batch_idx, length:, :] = 0.0
        return masked


class C2QAttention(AttentionBase):
    r"""Decoder Column to Question Attention Module"""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.softmax = nn.Softmax(dim=2)
        
    def forward(self, o_c, o_q, q_lengths, c_lengths=None, rt_attn=False):
        r"""
        For every column token, compute how strongly it relates to each
        question token and return the attention-weighted question context.

        o_c: LSTM output of column, (B, T_c, H)
        o_q: LSTM output of question, (B, T_q, H)
        q_lengths: valid question lengths; padded question positions are
            masked out before the softmax
        c_lengths: optional valid column lengths; rows of padded columns are
            zeroed in the attention matrix
        rt_attn: when True, also return the attention probabilities

        return (context, attn): context is (B, T_c, H); attn is the
        (B, T_c, T_q) probability tensor or None
        """
        # Scaled dot-product ("Attention Is All You Need"). A plain Python
        # float works on any device, unlike the legacy torch.FloatTensor
        # constructor, which rejects non-CPU devices.
        sqrt_H = o_c.size(-1) ** 0.5
        o_q_transform = self.linear(o_q)  # (B, T_q, H)
        score_c2q = torch.bmm(o_c, o_q_transform.transpose(1, 2)) / sqrt_H  # (B, T_c, H) x (B, H, T_q) = (B, T_c, T_q)
        score_c2q = self.wipe_out_pad_tkn_score(score_c2q, q_lengths, dim=2)
        
        prob_c2q = self.softmax(score_c2q)
        if c_lengths is not None:
            # Mask a copy: the softmax output is needed for its own backward
            # pass and must not be overwritten in place.
            prob_c2q = self.wipe_out_pad_tkn_score(prob_c2q.clone(), c_lengths, dim=1)
        # (B, T_c, T_q) x (B, T_q, H) = (B, T_c, H): the probability-weighted
        # sum over question tokens, without materializing (B, T_c, T_q, H).
        context = torch.bmm(prob_c2q, o_q)
        attn = prob_c2q if rt_attn else None
        return context, attn

class SelfAttention(AttentionBase):
    r"""Decoder self-attention pooling module.

    Scores every position of a sequence with a linear layer, normalises the
    scores over the sequence axis and returns the weighted sum of the inputs.
    """
    def __init__(self, in_features, out_features=1):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, o, lengths, rt_attn=False):
        r"""
        Summarise the sequence `o` with attention weights computed from `o` itself.

        o: (B, T_o, H) sequence
        lengths: one length per batch element, forwarded to `wipe_out_pad_tkn_score`
        rt_attn: when True the (B, T_o, 1) weights are returned as well

        return: (summary, attn) with summary (B, H); attn is None unless `rt_attn`
        """
        # NOTE(review): the masking call relies on the default `dim` of
        # `wipe_out_pad_tkn_score`, which is defined outside this block - confirm that the
        # default masks the T_o axis with a large negative score before the softmax.
        scores = self.wipe_out_pad_tkn_score(self.linear(o), lengths)  # (B, T_o, 1)
        weights = self.softmax(scores)  # normalised over T_o

        # (B, T_o, H) * (B, T_o, 1) -> sum over T_o -> (B, H)
        summary = (o * weights).sum(1)

        return summary, (weights if rt_attn else None)


In [39]:
class SelectDecoder(nn.Module):
    r"""SELECT column decoder.

    Encodes the question with `lstm_q`, the column and header sequences with
    `lstm_h`, lets every column attend over the question and scores each
    column with `output_layer`.
    """
    def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int=2, dropout_ratio:float=0.3) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout_ratio = dropout_ratio

        # Bidirectional LSTMs with hidden_size / 2 units per direction, so outputs are hidden_size wide.
        self.lstm_q = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        self.lstm_h = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)

        self.col_context_linear = nn.Linear(2*hidden_size, hidden_size)
        self.col2question_attn = C2QAttention(hidden_size, hidden_size)
        self.output_layer = nn.Sequential(
            nn.Linear(2*hidden_size, 2*hidden_size),
            nn.Tanh(),
            nn.Linear(2*hidden_size, output_size)
        )

    def forward(self, question_padded, header_padded, col_padded, question_lengths: List[int], col_lengths: Union[List[int], None]=None, rt_attn=False):
        r"""
        Predict a score for every column.

        question_padded: padded question batch fed to `lstm_q`
        header_padded: padded header batch fed to `lstm_h`
        col_padded: padded column batch fed to `lstm_h`, (B, T_c, input_size)
        question_lengths: question length per batch element (masks the attention scores)
        col_lengths: optional column count per batch element; when None the attention
            rows of padded columns are not zeroed (the attention module accepts None)
        rt_attn: when True the column-to-question attention is returned as well

        return: (scores, attn); scores is `output_layer`'s result with a trailing
            size-1 dimension squeezed away, attn is None unless `rt_attn`
        """
        batch_size, n_col, _ = col_padded.size()
        o_q, (h_q, c_q) = self.lstm_q(question_padded)  # o_q: (B, T_q, H)
        o_c, (h_c, c_c) = self.lstm_h(col_padded)  # o_c: (B, T_c, H)
        o_h, (h_h, c_h) = self.lstm_h(header_padded)  # h_h: (n_direc*num_layers, B, H/2)

        # Final hidden states of both directions of the last layer, copied to every column.
        header_summary = torch.cat([h for h in h_h[-2:]], dim=1).unsqueeze(1).repeat(1, n_col, 1)  # (B, T_c, H)
        col_context = torch.cat([o_c, header_summary], dim=2)  # (B, T_c, 2H)
        col_context = self.col_context_linear(col_context)  # (B, T_c, H)
        col_q_context, attn = self.col2question_attn(col_context, o_q, question_lengths, col_lengths, rt_attn)  # (B, T_c, H), (B, T_c, T_q)

        vec = torch.cat([col_q_context, col_context], dim=2)  # (B, T_c, 2H)
        output = self.output_layer(vec)
        # TODO: add penalty for padded header(column) information

        return output.squeeze(-1), attn
    

class AggDecoder(nn.Module):
    r"""Aggregation-operator decoder.

    Builds the same column context as the SELECT decoder, picks the context of
    the selected column for every batch element, attends it over the question
    and classifies the aggregation operator.
    """
    def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int=2, dropout_ratio:float=0.3) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout_ratio = dropout_ratio

        self.lstm_q = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        self.lstm_h = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)

        self.col_context_linear = nn.Linear(2*hidden_size, hidden_size)
        self.col2question_attn = C2QAttention(hidden_size, hidden_size)
        self.output_layer = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, output_size)
        )

    def forward(self, question_padded, col_padded, question_lengths: List[int], col_lengths: List[int], select_idxes: List[int], rt_attn=False, header_padded=None):
        r"""
        Predict the aggregation operator for the selected column.

        question_padded: padded question batch fed to `lstm_q`
        col_padded: padded column batch fed to `lstm_h`, (B, T_c, input_size)
        question_lengths: question length per batch element
        col_lengths: column count per batch element (forwarded to the attention module)
        select_idxes: index of the selected column for every batch element
        rt_attn: when True the attention over the question is returned as well
        header_padded: padded header batch fed to `lstm_h`. Earlier versions read this
            tensor from a notebook-level variable of the same name; that lookup is kept
            as a fallback so existing calls keep working, but passing it explicitly is
            the supported way.

        return: (output, attn); output is (B, output_size)
        raises: ValueError when no header tensor is passed and none exists at module level
        """
        if header_padded is None:
            # Backward compatibility with calls that relied on the hidden notebook global.
            header_padded = globals().get("header_padded")
            if header_padded is None:
                raise ValueError("`header_padded` must be passed to `AggDecoder.forward`")

        batch_size, n_col, _ = col_padded.size()
        o_q, (h_q, c_q) = self.lstm_q(question_padded)  # o_q: (B, T_q, H)
        o_c, (h_c, c_c) = self.lstm_h(col_padded)  # o_c: (B, T_c, H)
        o_h, (h_h, c_h) = self.lstm_h(header_padded)  # h_h: (n_direc*num_layers, B, H/2)

        # Final hidden states of both directions of the last layer, copied to every column.
        header_summary = torch.cat([h for h in h_h[-2:]], dim=1).unsqueeze(1).repeat(1, n_col, 1)  # (B, T_c, H)
        col_context = torch.cat([o_c, header_summary], dim=2)  # (B, T_c, 2H)
        col_context = self.col_context_linear(col_context)  # (B, T_c, H)

        # One context row per batch element: the row of its selected column.
        col_selected = col_context[list(range(batch_size)), select_idxes].unsqueeze(1)  # col_selected: (B, 1, H)

        col_q_context, attn = self.col2question_attn(col_selected, o_q, question_lengths, col_lengths, rt_attn)  # (B, 1, H), (B, 1, T_q)
        output = self.output_layer(col_q_context.squeeze(1))

        return output, attn
    
    
class WhereNumDecoder(nn.Module):
    r"""WHERE number decoder.

    Pools the column/header context with self-attention, uses it to initialise
    the question LSTM state, pools the question outputs with a second
    self-attention and classifies the number of WHERE conditions.
    """
    def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int=2, dropout_ratio:float=0.3, max_where_conds=4) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout_ratio = dropout_ratio
        self.max_where_conds = max_where_conds
        if self.output_size > self.max_where_conds+1:
            # HERE output will be delivered to cross-entropy loss, not guessing the real number of where clause
            raise ValueError(f"`WhereNumDecoder` only support maximum {max_where_conds} where clause")

        self.lstm_q = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        self.lstm_h = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)

        # The projected vector is reshaped into the LSTM initial state
        # (2 * num_layers, B, hidden_size / 2), so it must hold exactly that many values.
        # The previous hard-coded 2 * hidden_size only matched when num_layers == 2
        # (for which the size below is identical).
        lstm_state_size = 2 * num_layers * int(hidden_size / 2)

        self.col_self_attn = SelfAttention(2*hidden_size, 1)
        self.lstm_q_hidden_init_linear = nn.Linear(2*hidden_size, lstm_state_size)
        self.lstm_q_cell_init_linear = nn.Linear(2*hidden_size, lstm_state_size)

        self.context_self_attn = SelfAttention(hidden_size, 1)

        self.output_layer = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, output_size)
        )


    def forward(self, question_padded, header_padded, col_padded, question_lengths: List[int], col_lengths: List[int], rt_attn=False):
        r"""
        Predict the number of WHERE conditions.

        question_padded: padded question batch fed to `lstm_q`
        header_padded: padded header batch fed to `lstm_h`
        col_padded: padded column batch fed to `lstm_h`, (B, T_c, input_size)
        question_lengths: question length per batch element
        col_lengths: column count per batch element
        rt_attn: when True both self-attention weight tensors are returned

        return: (output, (col_attn, o_attn)); output is (B, output_size), the
            attention entries are None unless `rt_attn`
        """
        batch_size, n_col, _ = col_padded.size()
        o_c, (h_c, c_c) = self.lstm_h(col_padded)  # o_c: (B, T_c, H)
        o_h, (h_h, c_h) = self.lstm_h(header_padded)  # h_h: (n_direc*num_layers, B, H/2)

        # Final hidden states of both directions of the last layer, copied to every column.
        header_summary = torch.cat([h for h in h_h[-2:]], dim=1).unsqueeze(1).repeat(1, n_col, 1)  # (B, T_c, H)
        col_context = torch.cat([o_c, header_summary], dim=2)  # (B, T_c, 2H)

        col_self_attn, col_attn = self.col_self_attn(col_context, col_lengths, rt_attn)  # (B, 2H), attention weights

        # Project the pooled column context into the initial hidden / cell state of the question LSTM.
        h_0 = self.lstm_q_hidden_init_linear(col_self_attn)  # (B, 2*num_layers*H/2)
        h_0 = h_0.view(batch_size, 2*self.num_layers, -1).transpose(0, 1).contiguous()  # (2*num_layers, B, H/2)
        c_0 = self.lstm_q_cell_init_linear(col_self_attn)  # (B, 2*num_layers*H/2)
        c_0 = c_0.view(batch_size, 2*self.num_layers, -1).transpose(0, 1).contiguous()  # (2*num_layers, B, H/2)

        o_q, (h_q, c_q) = self.lstm_q(question_padded, (h_0, c_0))  # o_q: (B, T_q, H)
        o_summary, o_attn = self.context_self_attn(o_q, question_lengths, rt_attn)  # (B, H), attention weights
        output = self.output_layer(o_summary)

        return output, (col_attn, o_attn)

    
class WhereColumnDecoder(nn.Module):
    r"""WHERE column decoder.

    Same architecture as the SELECT decoder: every column attends over the
    question and receives a score from `output_layer`.
    """
    def __init__(self, input_size: int, hidden_size: int, output_size: int=1, num_layers: int=2, dropout_ratio:float=0.3, max_where_conds: int=4) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout_ratio = dropout_ratio
        # Stored for parity with the other WHERE decoders; not used inside this module.
        self.max_where_conds = max_where_conds

        self.lstm_q = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        self.lstm_h = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)

        self.col_context_linear = nn.Linear(2*hidden_size, hidden_size)
        self.col2question_attn = C2QAttention(hidden_size, hidden_size)
        self.output_layer = nn.Sequential(
            nn.Linear(2*hidden_size, 2*hidden_size),
            nn.Tanh(),
            nn.Linear(2*hidden_size, output_size)
        )

    def forward(self, question_padded, header_padded, col_padded, question_lengths: List[int], col_lengths: Union[List[int], None]=None, rt_attn=False):
        r"""
        Predict a WHERE-membership score for every column.

        question_padded: padded question batch fed to `lstm_q`
        header_padded: padded header batch fed to `lstm_h`
        col_padded: padded column batch fed to `lstm_h`, (B, T_c, input_size)
        question_lengths: question length per batch element
        col_lengths: optional column count per batch element; when None the attention
            rows of padded columns are not zeroed (the attention module accepts None)
        rt_attn: when True the column-to-question attention is returned as well

        return: (scores, attn); scores is `output_layer`'s result with a trailing
            size-1 dimension squeezed away, attn is None unless `rt_attn`
        """
        batch_size, n_col, _ = col_padded.size()
        o_q, (h_q, c_q) = self.lstm_q(question_padded)  # o_q: (B, T_q, H)
        o_c, (h_c, c_c) = self.lstm_h(col_padded)  # o_c: (B, T_c, H)
        o_h, (h_h, c_h) = self.lstm_h(header_padded)  # h_h: (n_direc*num_layers, B, H/2)

        # Final hidden states of both directions of the last layer, copied to every column.
        header_summary = torch.cat([h for h in h_h[-2:]], dim=1).unsqueeze(1).repeat(1, n_col, 1)  # (B, T_c, H)
        col_context = torch.cat([o_c, header_summary], dim=2)  # (B, T_c, 2H)
        col_context = self.col_context_linear(col_context)  # (B, T_c, H)
        col_q_context, attn = self.col2question_attn(col_context, o_q, question_lengths, col_lengths, rt_attn)  # (B, T_c, H), (B, T_c, T_q)

        vec = torch.cat([col_q_context, col_context], dim=2)  # (B, T_c, 2H)
        output = self.output_layer(vec)
        # TODO: add penalty for padded header(column) information

        return output.squeeze(-1), attn
    
    
class WhereOpDecoder(nn.Module):
    r"""WHERE operator decoder.

    Gathers the context of the WHERE columns of every batch element, attends
    them over the question and classifies one comparison operator per column.
    """
    def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int=2, dropout_ratio: float=0.3, max_where_conds: int=4) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout_ratio = dropout_ratio
        self.max_where_conds = max_where_conds

        self.lstm_q = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        self.lstm_h = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)

        self.col_context_linear = nn.Linear(2*hidden_size, hidden_size)
        self.col2question_attn = C2QAttention(hidden_size, hidden_size)
        self.output_layer = nn.Sequential(
            nn.Linear(2*hidden_size, 2*hidden_size),
            nn.Tanh(),
            nn.Linear(2*hidden_size, output_size)
        )

    def forward(self, question_padded, col_padded, question_lengths: List[int], where_nums: List[int], where_col_idxes: List[List[int]], rt_attn=False, header_padded=None):
        r"""
        Predict the operator of every WHERE condition.

        question_padded: padded question batch fed to `lstm_q`
        col_padded: padded column batch fed to `lstm_h`, (B, T_c, input_size)
        question_lengths: question length per batch element
        where_nums: number of WHERE conditions per batch element
        where_col_idxes: column indices of the WHERE conditions per batch element
        rt_attn: forwarded to the attention module (the attention is not returned)
        header_padded: padded header batch fed to `lstm_h`. Earlier versions read this
            tensor from a notebook-level variable of the same name; that lookup is kept
            as a fallback so existing calls keep working, but passing it explicitly is
            the supported way.

        return: operator scores, (B, max(where_nums), output_size)
        raises: ValueError when no header tensor is passed and none exists at module level
        """
        if header_padded is None:
            # Backward compatibility with calls that relied on the hidden notebook global.
            header_padded = globals().get("header_padded")
            if header_padded is None:
                raise ValueError("`header_padded` must be passed to `WhereOpDecoder.forward`")

        batch_size, n_col, _ = col_padded.size()
        o_q, (h_q, c_q) = self.lstm_q(question_padded)  # o_q: (B, T_q, H)
        o_c, (h_c, c_c) = self.lstm_h(col_padded)  # o_c: (B, T_c, H)
        o_h, (h_h, c_h) = self.lstm_h(header_padded)  # h_h: (n_direc*num_layers, B, H/2)

        # Final hidden states of both directions of the last layer, copied to every column.
        header_summary = torch.cat([h for h in h_h[-2:]], dim=1).unsqueeze(1).repeat(1, n_col, 1)  # (B, T_c, H)
        col_context = torch.cat([o_c, header_summary], dim=2)  # (B, T_c, 2H)
        col_context = self.col_context_linear(col_context)  # (B, T_c, H)
        col_context_padded = self.get_context_padded(col_context, where_nums, where_col_idxes)  # (B, max_where_col_nums, H)

        col_q_context, attn = self.col2question_attn(col_context_padded, o_q, question_lengths, where_nums, rt_attn)  # (B, max_where_col_nums, H), (B, max_where_col_nums, T_q)

        vec = torch.cat([col_q_context, col_context_padded], dim=2)  # (B, max_where_col_nums, 2H)
        output = self.output_layer(vec)  # (B, max_where_col_nums, n_cond_ops)
        # TODO: add penalty for padded header(column) information
        return output

    def get_context_padded(self, col_context, where_nums, where_col_idxes):
        r"""
        Select the context rows of the WHERE columns and zero-pad every batch
        element up to `max(where_nums)` rows, so that batch elements with a
        different number of WHERE columns can be stacked.

        col_context: (B, T_c, H)
        where_nums: number of WHERE conditions per batch element
        where_col_idxes: column indices to gather per batch element

        return: (B, max(where_nums), H)
        """
        batch_size, n_col, hidden_size = col_context.size()
        max_where_col_nums = max(where_nums)
        batches_padded = []
        for i, batch_col in enumerate(where_col_idxes):
            b = col_context[i, batch_col]  # (where_col_nums, hidden_size)
            n_missing = max_where_col_nums - b.size(0)
            if n_missing > 0:
                # new_zeros inherits dtype and device from col_context, so the padding
                # never changes the dtype of the result (e.g. under half precision).
                b = torch.cat([b, col_context.new_zeros(n_missing, hidden_size)], dim=0)
            batches_padded.append(b)  # (max_where_col_nums, hidden_size)

        return torch.stack(batches_padded) # (B, max_where_col_nums, hidden_size)
    
    
class WhereValueDecoder(nn.Module):
    r"""WHERE value decoder (work in progress).

    Only the shared encoding / attention part is implemented; `forward` does
    not produce an output yet and `start_token` / `decode` are placeholders.
    """
    def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int=2, dropout_ratio: float=0.3, max_where_conds: int=4, n_cond_ops: int=4,
                 start_tkn_id=8002, end_tkn_id=8003, embedding_layer=None) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout_ratio = dropout_ratio
        self.max_where_conds = max_where_conds
        self.n_cond_ops = n_cond_ops

        self.start_tkn_id = start_tkn_id
        self.end_tkn_id = end_tkn_id
        if embedding_layer is None:
            # Exception type kept as KeyError so existing handlers keep working.
            raise KeyError("Must initialize the embedding_layer to BertModel's word embedding layer")
        else:
            self.embedding_layer = embedding_layer
        self.lstm_q = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
        self.lstm_h = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)

        self.col_context_linear = nn.Linear(2*hidden_size, hidden_size)
        self.col2question_attn = C2QAttention(hidden_size, hidden_size)
        self.output_layer = nn.Sequential(
            nn.Linear(2*hidden_size, 2*hidden_size),
            nn.Tanh(),
            nn.Linear(2*hidden_size, output_size)
        )

    def forward(self, question_padded, col_padded, question_lengths: List[int], where_nums: List[int], where_col_idxes: List[List[int]], where_op_idxes: List[List[int]], rt_attn=False, header_padded=None):
        r"""
        Encode the WHERE columns and attend them over the question.

        question_padded: padded question batch fed to `lstm_q`
        col_padded: padded column batch fed to `lstm_h`, (B, T_c, input_size)
        question_lengths: question length per batch element
        where_nums: number of WHERE conditions per batch element
        where_col_idxes: column indices of the WHERE conditions per batch element
        where_op_idxes: operator indices of the WHERE conditions (not used yet)
        rt_attn: forwarded to the attention module
        header_padded: padded header batch fed to `lstm_h`. Earlier versions read this
            tensor from a notebook-level variable of the same name; that lookup is kept
            as a fallback so existing calls keep working, but passing it explicitly is
            the supported way.

        return: None - the value decoding step is not implemented yet
        raises: ValueError when no header tensor is passed and none exists at module level
        """
        if header_padded is None:
            # Backward compatibility with calls that relied on the hidden notebook global.
            header_padded = globals().get("header_padded")
            if header_padded is None:
                raise ValueError("`header_padded` must be passed to `WhereValueDecoder.forward`")

        batch_size, n_col, _ = col_padded.size()
        o_q, (h_q, c_q) = self.lstm_q(question_padded)  # o_q: (B, T_q, H)
        o_c, (h_c, c_c) = self.lstm_h(col_padded)  # o_c: (B, T_c, H)
        o_h, (h_h, c_h) = self.lstm_h(header_padded)  # h_h: (n_direc*num_layers, B, H/2)

        # Final hidden states of both directions of the last layer, copied to every column.
        header_summary = torch.cat([h for h in h_h[-2:]], dim=1).unsqueeze(1).repeat(1, n_col, 1)  # (B, T_c, H)
        col_context = torch.cat([o_c, header_summary], dim=2)  # (B, T_c, 2H)
        col_context = self.col_context_linear(col_context)  # (B, T_c, H)
        col_context_padded = self.get_context_padded(col_context, where_nums, where_col_idxes)  # (B, max_where_col_nums, H)

        col_q_context, attn = self.col2question_attn(col_context_padded, o_q, question_lengths, where_nums, rt_attn)  # (B, max_where_col_nums, H), (B, max_where_col_nums, T_q)
        # TODO: decode the condition values from col_q_context; nothing is returned until then.


    def get_context_padded(self, col_context, where_nums, where_col_idxes):
        r"""
        Select the context rows of the WHERE columns and zero-pad every batch
        element up to `max(where_nums)` rows, so that batch elements with a
        different number of WHERE columns can be stacked.

        col_context: (B, T_c, H)
        where_nums: number of WHERE conditions per batch element
        where_col_idxes: column indices to gather per batch element

        return: (B, max(where_nums), H)
        """
        batch_size, n_col, hidden_size = col_context.size()
        max_where_col_nums = max(where_nums)
        batches_padded = []
        for i, batch_col in enumerate(where_col_idxes):
            b = col_context[i, batch_col]  # (where_col_nums, hidden_size)
            n_missing = max_where_col_nums - b.size(0)
            if n_missing > 0:
                # new_zeros inherits dtype and device from col_context, so the padding
                # never changes the dtype of the result (e.g. under half precision).
                b = torch.cat([b, col_context.new_zeros(n_missing, hidden_size)], dim=0)
            batches_padded.append(b)  # (max_where_col_nums, hidden_size)

        return torch.stack(batches_padded) # (B, max_where_col_nums, hidden_size)


    def start_token(self):
        # Placeholder: not implemented yet.
        pass

    def decode(self):
        # Placeholder: not implemented yet.
        pass

In [None]:
embedding_layer = model_bert.embeddings.word_embeddings

In [91]:
end_tkn_id

8003

In [34]:
input_size = config_bert.hidden_size
hidden_size = 100
num_layers = 2
dropout_ratio = 0.3
max_where_num = 4
n_agg_ops = len(dbengine.agg_ops)
n_cond_ops = len(dbengine.cond_ops)

In [40]:
# One decoder per SQL sub-task; all share the same encoder input size and LSTM settings.
# output_size=1: one score per column (the decoder squeezes the trailing dimension).
select_decoder = SelectDecoder(
    input_size, hidden_size, output_size=1, num_layers=num_layers, dropout_ratio=dropout_ratio
)
# One logit per aggregation operator.
agg_decoder = AggDecoder(
    input_size, hidden_size, output_size=n_agg_ops, num_layers=num_layers, dropout_ratio=dropout_ratio
)
# Classes 0..max_where_num; the decoder rejects output_size > max_where_conds + 1.
where_num_decoder = WhereNumDecoder(
    input_size, hidden_size, output_size=(max_where_num+1), num_layers=num_layers, dropout_ratio=dropout_ratio
)
# output_size=1: one score per column (the decoder squeezes the trailing dimension).
where_col_decoder = WhereColumnDecoder(
    input_size, hidden_size, output_size=1, num_layers=num_layers, dropout_ratio=dropout_ratio
)
# One logit per condition operator for every WHERE column.
where_op_decoder = WhereOpDecoder(
    input_size, hidden_size, output_size=n_cond_ops, num_layers=num_layers, dropout_ratio=dropout_ratio
)

In [70]:
g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(batch_sqls)
g_sc, g_sa, g_wn, g_wc, g_wo, g_wv

([16, 16],
 [0, 0],
 [2, 2],
 [[10, 3], [10, 3]],
 [[0, 0], [0, 0]],
 [['법인세차감전 순이익', 2018], ['유동자산', 2019]])

In [74]:
g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = [13, 2], [0, 1], [3, 1], [[1, 2, 3], [1]], [[0, 0, 1], [0]], []

## Training: Decoding with max probability

In [87]:
def predict_decoder(typ, **kwargs):
    r"""
    Turn raw decoder outputs into predicted indices; used when the gold
    labels are not fed to the next decoder (no teacher forcing).

    typ: which sub-task to decode, each reading its own keyword arguments
        "sc": `select_outputs` (B, n_col)          -> list of column indices
        "sa": `agg_outputs` (B, n_agg_ops)         -> tensor of operator indices
        "wn": `where_num_outputs` (B, n_classes)   -> list of condition counts
        "wc": `where_col_argsort` (B, n_col), `where_nums`
              -> per batch element, the first `where_num` entries of the ranking
        "wo": `where_op_outputs` (B, n_conds, n_ops), `where_nums`
              -> per batch element, the operator index of its first `where_num` conditions
        "wv": not implemented yet, returns None

    raises: KeyError for an unknown `typ` or a missing keyword argument
    """
    if typ == "sc":  # SELECT column
        select_outputs = kwargs["select_outputs"]
        return select_outputs.argmax(1).tolist()
    elif typ == "sa":  # SELECT aggregation operator
        # not need actually; returned as a tensor (not a list), unlike the other branches
        agg_outputs = kwargs["agg_outputs"]
        return agg_outputs.argmax(1)
    elif typ == "wn":  # WHERE number
        where_num_outputs = kwargs["where_num_outputs"]
        return where_num_outputs.argmax(1).tolist()
    elif typ == "wc":  # WHERE clause column
        where_col_argsort = kwargs["where_col_argsort"]
        where_nums = kwargs["where_nums"]
        where_col_idxes = [where_col_argsort[b_idx, :w_num].tolist() for b_idx, w_num in enumerate(where_nums)]
        return where_col_idxes
    elif typ == "wo":  # WHERE clause operator
        where_op_outputs = kwargs["where_op_outputs"]
        where_nums = kwargs["where_nums"]
        # The argmax does not depend on the batch element: compute it once, not once per element.
        where_op_argmax = where_op_outputs.argmax(2)
        where_op_idxes = [where_op_argmax[b_idx, :w_num].tolist() for b_idx, w_num in enumerate(where_nums)]
        return where_op_idxes
    elif typ == "wv":  # WHERE clause value
        # TODO: value decoding is not implemented yet.
        return None
    else:
        raise KeyError("`typ` must be in ['sc', 'sa', 'wn', 'wc', 'wo', 'wv']")

In [75]:
decoder_outputs = {}

# Each decoder feeds the next one: gold labels (g_*) are used when available,
# otherwise the previous decoder's own prediction is used.
select_outputs, _ = select_decoder(question_padded, header_padded, col_padded, question_lengths, col_lengths)
select_idxes = g_sc if g_sc else predict_decoder("sc", select_outputs=select_outputs)

agg_outputs, _ = agg_decoder(question_padded, col_padded, question_lengths, col_lengths, select_idxes)

where_num_outputs, _  = where_num_decoder(question_padded, header_padded, col_padded, question_lengths, col_lengths)
where_nums = g_wn if g_wn else predict_decoder("wn", where_num_outputs=where_num_outputs)

where_col_outputs, _ = where_col_decoder(question_padded, header_padded, col_padded, question_lengths, col_lengths)
# Rank columns from most to least probable: predict_decoder("wc") keeps the FIRST
# `where_num` entries, so an ascending sort would select the least likely columns.
where_col_argsort = torch.sigmoid(where_col_outputs).argsort(1, descending=True)
where_col_idxes = g_wc if g_wc else predict_decoder("wc", where_col_argsort=where_col_argsort, where_nums=where_nums)

where_op_outputs = where_op_decoder(question_padded, col_padded, question_lengths, where_nums, where_col_idxes)
where_op_idxes = g_wo if g_wo else predict_decoder("wo", where_op_outputs=where_op_outputs, where_nums=where_nums)


In [88]:
g_wo

[[0, 0, 1], [0]]

In [76]:
where_col_idxes

[[1, 2, 3], [1]]

In [89]:
predict_decoder("wo", where_op_outputs=where_op_outputs, where_nums=where_nums)

[[0, 0, 0], [0]]

In [86]:
[where_op_outputs.argmax(2)[b_idx, :w_num].tolist() for b_idx, w_num in enumerate(where_nums)]

[[0, 0, 0], [0]]

In [69]:
where_col_idxes

[[0, 4, 1, 3], [7, 6, 5, 0]]

In [320]:
agg_outputs, _ = agg_decoder(question_padded, col_padded, question_lengths, col_lengths, select_idxes)
select_outputs, _ = select_decoder(question_padded, header_padded, col_padded, question_lengths, col_lengths)
where_num_outputs, _  = where_num_decoder(question_padded, header_padded, col_padded, question_lengths, col_lengths)
where_col_outputs, _ = where_col_decoder(question_padded, header_padded, col_padded, question_lengths, col_lengths)


In [322]:
agg_outputs.argmax(1)

tensor([4, 4])

In [66]:
values, indices = [where_col_prob.topk(where_num)]

TypeError: topk(): argument 'k' (position 1) must be int, not Tensor

In [62]:
values

tensor([[0.5125, 0.5125, 0.5125, 0.5125, 0.5125, 0.5125],
        [0.5120, 0.5120, 0.5120, 0.5120, 0.5120, 0.5120]],
       grad_fn=<TopkBackward>)

In [61]:
indices

tensor([[ 2,  3, 19,  7,  1,  8],
        [ 0,  7,  4,  6,  8,  3]])

## Testing: Execution-guided beam decoding

In [37]:
beam_size = 4

select 

In [38]:
# col_lengths is passed so padded columns are masked, matching the training-time call.
select_output, _ = select_decoder(question_padded, header_padded, col_padded, question_lengths, col_lengths)
select_output.size()

torch.Size([2, 20])

Construct the score of every candidate SELECT column + aggregation operator pair.

In [39]:
batch_size, n_col = select_output.size()

# Probability of every column being the SELECT column.
select_prob = torch.softmax(select_output, 1)  # prob_sc

# The beam cannot be wider than the number of columns.
beam_size_max_col = min(n_col, beam_size)

# Buffers filled beam by beam: aggregation probabilities and joint select * agg probabilities.
prob_sc_sa = torch.zeros(batch_size, beam_size_max_col, n_agg_ops)
prob_sca = torch.zeros_like(prob_sc_sa)
print(prob_sca.size())  # (B, beam-size, n_agg_ops)

torch.Size([2, 4, 6])


In [42]:
# Beam search over the SELECT column: keep the top-k columns per batch element.
_, pr_sc_beam = select_output.topk(k=beam_size_max_col)
print(f"sc top k: {pr_sc_beam.tolist()}")

for i_beam in range(beam_size_max_col):
    select_idx = pr_sc_beam[:, i_beam].tolist() # pr_sc
    # AggDecoder.forward takes col_lengths before the selected indices; omitting it raises a TypeError.
    agg_output, _ = agg_decoder(question_padded, col_padded, question_lengths, col_lengths, select_idx)
    agg_prob = torch.softmax(agg_output, dim=-1)  # prob_sa: (B, n_agg_ops)
    prob_sc_sa[:, i_beam, :] = agg_prob

    prob_sc_selected = select_prob[range(batch_size), select_idx]  # (B,)
    # Joint probability: (B, n_agg_ops) * (B, 1) (broadcast) -> (B, n_agg_ops)
    prob_sca[:, i_beam, :] = agg_prob * prob_sc_selected.unsqueeze(1)

sc top k: [[19, 18, 0, 14], [3, 4, 6, 5]]


In [43]:
print(prob_sc_sa.data)

tensor([[[0.1765, 0.1692, 0.1639, 0.1756, 0.1588, 0.1561],
         [0.1777, 0.1687, 0.1647, 0.1742, 0.1588, 0.1558],
         [0.1764, 0.1690, 0.1643, 0.1751, 0.1593, 0.1558],
         [0.1779, 0.1693, 0.1647, 0.1732, 0.1581, 0.1568]],

        [[0.1778, 0.1697, 0.1638, 0.1724, 0.1561, 0.1601],
         [0.1778, 0.1712, 0.1638, 0.1742, 0.1554, 0.1577],
         [0.1789, 0.1702, 0.1634, 0.1741, 0.1557, 0.1578],
         [0.1774, 0.1712, 0.1640, 0.1729, 0.1564, 0.1581]]])


In [44]:
print(prob_sca.size())  # (B, beam_size, prob_sc(beam size selected) * prob_agg)
print(prob_sca.data)

torch.Size([2, 4, 6])
tensor([[[0.0088, 0.0085, 0.0082, 0.0088, 0.0079, 0.0078],
         [0.0089, 0.0084, 0.0082, 0.0087, 0.0079, 0.0078],
         [0.0088, 0.0084, 0.0082, 0.0088, 0.0080, 0.0078],
         [0.0089, 0.0085, 0.0082, 0.0087, 0.0079, 0.0078]],

        [[0.0089, 0.0085, 0.0082, 0.0086, 0.0078, 0.0080],
         [0.0089, 0.0086, 0.0082, 0.0087, 0.0078, 0.0079],
         [0.0089, 0.0085, 0.0082, 0.0087, 0.0078, 0.0079],
         [0.0089, 0.0086, 0.0082, 0.0086, 0.0078, 0.0079]]])


In [45]:
def topk_multi_dim(tensor, n_topk):
    r"""
    Top-k over all non-batch dimensions of `tensor`.

    tensor: (B, d_1, ..., d_m) tensor
    n_topk: number of entries to keep per batch element

    return: (idxes, values) as numpy arrays; idxes is (B, n_topk, m) and holds the
        multi-dimensional index of every kept entry, values is (B, n_topk) and holds
        the corresponding values in descending order
    """
    batch_size = tensor.size(0)
    # detach + reshape: also works for tensors that require grad or are not contiguous.
    flat = tensor.detach().reshape(batch_size, -1)
    values_1d, idxes_1d = flat.topk(n_topk)
    # Map the flat positions back to coordinates in the trailing dimensions.
    coords = np.unravel_index(idxes_1d.cpu().numpy(), tuple(tensor.size()[1:]))
    idxes = np.stack(coords, axis=-1)  # (B, n_topk, m)
    # topk already returns the selected values; no second gather is needed.
    values = values_1d.cpu().numpy()
    return idxes, values

In [46]:
# First flatten to 1-d
if np.prod(prob_sca.shape[1:]) < beam_size:
    beam_size_sca = np.prod(prob_sca.shape[1:])
else:
    beam_size_sca = beam_size
# Now as sc_idx is already sorted, re-map them properly.
# idxes: [sc_beam_idx, sa_idx] -> sca_idxes: [sc_idx, sa_idx]
idxes, values = topk_multi_dim(prob_sca.detach().cpu(), n_topk=beam_size_sca)
sc_beam_idxes = idxes[:, :, 0]
sc_idxes = np.stack([pr_sc_beam.numpy()[i, sc_beam_idx] for i, sc_beam_idx in enumerate(sc_beam_idxes)])
sca_idxes = np.stack([sc_idxes, idxes[:, :, 1]]).transpose(1, 2, 0)

In [47]:
sca_idxes

array([[[14,  0],
        [18,  0],
        [19,  0],
        [ 0,  0]],

       [[ 6,  0],
        [ 3,  0],
        [ 4,  0],
        [ 5,  0]]], dtype=int64)

select agg

In [301]:


# argmax
idxes = select_output.argmax(dim=1)
print(idxes.tolist())

# beamseacrh
ith_beam = 0
v, idxes = select_output.topk(k=4)
print(idxes[:, ith_beam].tolist())

[7, 3]
[7, 3]


In [None]:
if 

In [290]:
select_idx = idxes[:, ith_beam]

In [291]:
# AggDecoder.forward takes col_lengths before the selected indices; omitting it raises a TypeError.
agg_output, _  = agg_decoder(question_padded, col_padded, question_lengths, col_lengths, select_idx)
agg_output.size()

torch.Size([2, 6])

In [292]:
where_num_output, _  = where_num_decoder(question_padded, header_padded, col_padded, question_lengths, col_lengths)
where_num_output.size()

torch.Size([2, 5])

In [295]:
# col_lengths is passed so padded columns are masked, matching the training-time call.
where_col_output, _ = where_col_decoder(question_padded, header_padded, col_padded, question_lengths, col_lengths)
where_col_output.size()

torch.Size([2, 20])

In [297]:
where_col_prob = torch.sigmoid(where_col_output)

In [92]:
lstm_q = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)
lstm_h = nn.LSTM(input_size, int(hidden_size / 2), num_layers, dropout=dropout_ratio, batch_first=True, bidirectional=True)

col2question_attn = C2QAttention(hidden_size, hidden_size)
col_context_linear = nn.Linear(2*hidden_size, hidden_size)

In [93]:
batch_size, n_col, _ = col_padded.size()
o_q, (h_q, c_q) = lstm_q(question_padded)  # o_q: (B, T_q, H)
o_c, (h_c, c_c) = lstm_h(col_padded)  # o_c: (B, T_c, H)
o_h, (h_h, c_h) = lstm_h(header_padded)  # h_h: (n_direc*num_layers, B, H/2)

header_summary = torch.cat([h for h in h_h[-2:]], dim=1).unsqueeze(1).repeat(1, n_col, 1) # (B, T_c, H)
col_context = torch.cat([o_c, header_summary], dim=2)  # (B, T_c, 2H)
col_context = col_context_linear(col_context)

In [94]:
col_context.size()

torch.Size([2, 20, 100])

In [98]:
where_nums, where_col_idxes

([3, 1], [[1, 2, 3], [1]])

In [99]:
def get_context_padded(col_context, where_nums, where_col_idxes):
    r"""
    Select the context rows of the WHERE columns and zero-pad every batch
    element up to `max(where_nums)` rows, so that batch elements with a
    different number of WHERE columns can be stacked.

    col_context: (B, T_c, H)
    where_nums: number of WHERE conditions per batch element
    where_col_idxes: column indices to gather per batch element

    return: (B, max(where_nums), H)
    """
    batch_size, n_col, hidden_size = col_context.size()
    max_where_col_nums = max(where_nums)
    batches_padded = []
    for i, batch_col in enumerate(where_col_idxes):
        b = col_context[i, batch_col]  # (where_col_nums, hidden_size)
        n_missing = max_where_col_nums - b.size(0)
        if n_missing > 0:
            # new_zeros inherits dtype and device from col_context, so the padding
            # never changes the dtype of the result (e.g. under half precision).
            b = torch.cat([b, col_context.new_zeros(n_missing, hidden_size)], dim=0)
        batches_padded.append(b)  # (max_where_col_nums, hidden_size)
    return torch.stack(batches_padded) # (B, max_where_col_nums, hidden_size)

In [103]:
col_context_padded = get_context_padded(col_context, where_nums, where_col_idxes)
col_context_padded.size()

torch.Size([2, 3, 100])

In [104]:
x, a = col2question_attn(col_context_padded, o_q, question_lengths, where_nums, rt_attn=True)

In [108]:
max_where_col_nums = max(where_nums)

In [124]:
where_nums

[3, 1]

In [112]:
where_op_idxes

[[0, 0, 1], [0]]

In [127]:
where_op_one_hot = torch.zeros(batch_size, max_where_col_nums, n_cond_ops)

In [129]:
torch.scatter(where_op_one_hot, 2, where_op_idxes, 1)

TypeError: scatter() received an invalid combination of arguments - got (Tensor, int, list, int), but expected one of:
 * (Tensor input, name dim, Tensor index, Tensor src)
      didn't match because some of the arguments have invalid types: (Tensor, !int!, !list!, !int!)
 * (Tensor input, int dim, Tensor index, Tensor src)
      didn't match because some of the arguments have invalid types: (Tensor, int, !list!, !int!)
 * (Tensor input, name dim, Tensor index, Number value)
      didn't match because some of the arguments have invalid types: (Tensor, !int!, !list!, !int!)
 * (Tensor input, int dim, Tensor index, Number value)
      didn't match because some of the arguments have invalid types: (Tensor, int, !list!, !int!)


In [133]:
for i, batch_col in enumerate(where_op_idxes):
    break

In [144]:
n_cond_ops

4

In [142]:
def get_where_op_one_hot_padded(
    where_op_idxes: List[List[int]],
    where_nums: List[int],
    where_col_idxes: Any = None,
    num_ops: Union[int, None] = None,
    device: Union[torch.device, str, None] = None,
) -> torch.Tensor:
    r"""
    One-hot encode the WHERE operator indices of every example and zero-pad
    them to a common number of conditions.

    Args:
        where_op_idxes: per-example lists of operator indices, e.g. ``[[0, 0, 1], [0]]``.
        where_nums: number of WHERE conditions of each example, e.g. ``[3, 1]``.
        where_col_idxes: unused; kept so existing call sites keep working.
        num_ops: number of operator classes (size of the one-hot axis).
            Falls back to the notebook-level ``n_cond_ops`` when omitted.
        device: device on which the result is allocated (``torch.zeros`` default when omitted).

    Returns:
        Float tensor of shape ``(B, max(where_nums), num_ops)``. Rows beyond an
        example's own ``where_num`` are all zeros.
    """
    if num_ops is None:
        num_ops = n_cond_ops  # notebook-level constant

    batch_size = len(where_nums)
    max_where_col_nums = max(where_nums, default=0)

    # Allocate the padded result once. The padding width must be num_ops (the one-hot axis),
    # not the encoder hidden size, otherwise padded and unpadded examples cannot be combined.
    one_hot = torch.zeros(batch_size, max_where_col_nums, num_ops, device=device)

    for i, (where_num, op_idxes) in enumerate(zip(where_nums, where_op_idxes)):
        if where_num == 0:
            continue  # nothing to encode; the row block stays all-zero
        index = torch.tensor(op_idxes, dtype=torch.long, device=device).unsqueeze(1)  # (k, 1)
        # The slice is a view into `one_hot`, so the in-place scatter fills the result directly.
        one_hot[i, :where_num].scatter_(1, index, 1.0)

    return one_hot  # (B, max_where_col_nums, num_ops)

tensor([[0],
        [0],
        [1]])

In [141]:
# One-hot encode a single example's operator indices: row k gets a 1 in column batch_col[k].
torch.zeros(len(batch_col), n_cond_ops).scatter(1, torch.LongTensor(batch_col).unsqueeze(1), 1)

tensor([[1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.]])

In [143]:
# Same one-hot encoding for every example in the batch. The resulting tensors have
# where_num rows each, so their row counts differ and they are not padded yet.
[torch.zeros(where_num, n_cond_ops).scatter(1, torch.LongTensor(batch_col).unsqueeze(1), 1) for where_num, batch_col in zip(where_nums, where_op_idxes)]

[tensor([[1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [0., 1., 0., 0.]]),
 tensor([[1., 0., 0., 0.]])]

In [121]:
# Build, per example, a (max_where_col_nums, n_cond_ops) one-hot matrix of WHERE operators
# and collect the matrices in wenc_op (stacked in a later cell).
wenc_op = []
for b in range(batch_size):
    # [[...], [...]]
    # Pad list to maximum number of selections
    wenc_op1 = torch.zeros(max_where_col_nums, n_cond_ops)
    wo1 = where_op_idxes[b]
    # One single-element index row per padded position, as required by scatter along dim 1.
    idx_scatter = []
    l_wo1 = len(wo1)
    for i_wo11 in range(max_where_col_nums):
        if i_wo11 < l_wo1:
            wo11 = wo1[i_wo11]
            idx_scatter.append([int(wo11)])
        else:
            # Padding rows are given operator index 3, so they come out as a one-hot row for
            # class 3 rather than as an all-zero row.
            # NOTE(review): presumably these rows are masked downstream — confirm.
            idx_scatter.append([3]) # not used anyway

    # Writes a 1 at column idx_scatter[k][0] of every row k, padded rows included.
    wenc_op1 = wenc_op1.scatter(1, torch.tensor(idx_scatter), 1)

    wenc_op.append(wenc_op1)

In [122]:
torch.stack(wenc_op).size()

torch.Size([2, 3, 4])

In [123]:
idx_scatter

[[0], [3], [3]]