# Text2SQL Test Notebook

In [1]:
import numpy as np
from pathlib import Path
from main import argument_parsing
from src.model import Text2SQL
# import pytorch_lightning as pl
# import numpy as np

# pl.seed_everything(np.random.randint(0, 100))
# Checkpoint file to evaluate, looked up inside the directory given by args.ckpt_dir.
ckpt_name = "epochepoch=00-val_loss=63.102-val_acc_sc=0.949-val_acc_sa=1.000-val_acc_wn=0.985-val_acc_wo=0.992.ckpt"

# Build the project's argument parser and keep only the parsed namespace
# (element 0 of parse_known_args()); the leftover arguments are discarded.
args_parser = argument_parsing(preparse=True)
known_args = args_parser.parse_known_args()
args = known_args[0]

# Restore the model from the checkpoint and switch it to evaluation mode.
ckpt_path = Path(args.ckpt_dir) / ckpt_name
model = Text2SQL.load_from_checkpoint(ckpt_path)
model.eval()
# Bare print(): presumably here so the cell does not end on model.eval()'s return value.
print()




```
  | Name                 | Type              | Params
-----------------------------------------------------------
0 | model_bert           | BertModel         | 92.2 M
1 | model_decoder        | Decoder           | 22.5 M
2 | cross_entropy        | CrossEntropyLoss  | 0
3 | binary_cross_entropy | BCEWithLogitsLoss | 0
4 | cross_entropy_wv     | CrossEntropyLoss  | 0
5 | acc_sc               | Accuracy          | 0
6 | acc_sa               | Accuracy          | 0
7 | acc_wn               | Accuracy          | 0
8 | acc_wo               | Accuracy          | 0
9 | pp_wv                | Perplexity        | 0
-----------------------------------------------------------
108 M     Trainable params
0         Non-trainable params
108 M     Total params
434.292   Total estimated model params size (MB)
```

In [17]:
# Load the table metadata, then pick one table id at random to experiment with.
# NOTE: no random seed is set in the visible cells, so the chosen table varies between runs.
table = model.load_tables(Path(args.train_table_file))
candidate_ids = list(table.keys())
table_id = np.random.choice(candidate_ids)
header = table[table_id]["header"]
# Show the selected id and its column names on separate lines.
print(table_id, header, sep="\n")

002100
['index', 'rcept_no', 'reprt_code', 'bsns_year', 'corp_code', 'stock_code', 'fs_div', 'fs_nm', 'sj_div', 'sj_nm', 'account_nm', 'thstrm_nm', 'thstrm_dt', 'thstrm_amount']


In [18]:
# Fetch every row of the selected table through the model's DB engine and show the last rows.
# The table name is double-quoted in the query text.
select_all = f'SELECT * FROM "{table_id}"'
res = model.dbengine.db.query(select_all).export("df")  # presumably a pandas DataFrame — confirm
res.tail()

Unnamed: 0,index,rcept_no,reprt_code,bsns_year,corp_code,stock_code,fs_div,fs_nm,sj_div,sj_nm,account_nm,thstrm_nm,thstrm_dt,thstrm_amount
34,34,20210322000474,11011,2020,101433,2100,CFS,연결재무제표,BS,재무상태표,자본총계,제 64 기,2020.12.31 현재,918835950
35,35,20210322000474,11011,2020,101433,2100,CFS,연결재무제표,IS,손익계산서,매출액,제 64 기,2020.01.01 ~ 2020.12.31,-812579367
36,36,20210322000474,11011,2020,101433,2100,CFS,연결재무제표,IS,손익계산서,영업이익,제 64 기,2020.01.01 ~ 2020.12.31,1397074739
37,37,20210322000474,11011,2020,101433,2100,CFS,연결재무제표,IS,손익계산서,법인세차감전 순이익,제 64 기,2020.01.01 ~ 2020.12.31,422284310
38,38,20210322000474,11011,2020,101433,2100,OFS,재무제표,BS,재무상태표,유동자산,제 64 기,2020.12.31 현재,-1983377738


In [19]:
def get_predict_sql(table_id, header, model, predicts):
    """Render the first prediction in ``predicts`` as a SQL string.

    Args:
        table_id: Table name placed after ``FROM`` (inserted verbatim, unquoted).
        header: Sequence of column names; predicted column indices index into it.
        model: Object exposing ``dbengine.agg_ops`` and ``dbengine.cond_ops``,
            each indexable by the predicted aggregation / operator id.
        predicts: Mapping with the keys ``'sc'``, ``'sa'``, ``'wn'``, ``'wc'``,
            ``'wo'`` and ``'wv'``. Only element ``[0]`` of each entry is used.

    Returns:
        ``SELECT <agg>(<column>) FROM <table_id>``, followed by
        `` WHERE <cond> AND <cond> ...`` when at least one condition is
        available. At most ``wn`` conditions are rendered; the WHERE clause is
        omitted entirely when there are none, so the result never ends in a
        dangling ``WHERE`` or ``AND``.
    """
    p_sc = header[predicts['sc'][0]]
    p_sa = model.dbengine.agg_ops[predicts['sa'][0]]
    # Clamp so that a non-positive predicted count simply means "no conditions".
    where_num = max(int(predicts['wn'][0]), 0)

    # zip() stops at the shortest of the three sequences; on top of that, keep
    # only the first `where_num` triples so surplus predictions are ignored.
    triples = list(zip(predicts['wc'][0], predicts['wo'][0], predicts['wv'][0]))
    conditions = []
    for wc, wo, wv in triples[:where_num]:
        p_wc = header[wc]
        p_wo = model.dbengine.cond_ops[wo]
        # Drop the "[E]" marker from the decoded value text.
        p_wv = wv.replace("[E]", "").strip()
        conditions.append(f"{p_wc} {p_wo} '{p_wv}'")

    predict_SQL = f"SELECT {p_sa}({p_sc}) FROM {table_id}"
    if conditions:
        predict_SQL += " WHERE " + " AND ".join(conditions)
    return predict_SQL

## Example 1

In [27]:
# Question about the selected table (Korean, roughly: "What is the 64th-term
# operating profit of <table_id>?"), wrapped in a one-element list of dicts.
Q = f"{table_id}의 제 64기 영업이익은 얼마니 ???"
data = [dict(question=Q, table_id=table_id)]
data

[{'question': '002100의 제 64기 영업이익은 얼마니 ???', 'table_id': '002100'}]

In [28]:
# Run the model on the question (also requesting attention outputs via rt_attn=True),
# then render the prediction as SQL.
outputs = model.predict_outputs(data, table, rt_attn=True)
predicts, attns = outputs
# ANSWER: 'SELECT thstrm_amount FROM "003490" WHERE account_nm = "법인세차감전 순이익" AND bsns_year = 2020'
# NOTE(review): the expected answer above names table 003490, while the recorded run used
# table 002100 — it may belong to a different question; confirm.
get_predict_sql(table_id, header, model, predicts)

"SELECT (thstrm_amount) FROM 002100 WHERE fs_div = '2019' AND stock_code = '2019'"

In [29]:
# Display the prediction dict assigned from model.predict_outputs, to compare it with the rendered SQL.
predicts

{'sc': [13],
 'sa': [0],
 'wn': [2],
 'wc': [[6, 5]],
 'wo': [[0, 0]],
 'wv_tkns': [[[554, 116, 8003]],
  [[554, 116, 8003]],
  [[2514, 7063, 5872, 6398, 7405, 8003]],
  [[2514, 7063, 5872, 6398, 7405, 8003]]],
 'wv': [('2019 [E]', '2019 [E]', '비유동부채 [E]', '비유동부채 [E]')]}

In [14]:
# Second question (Korean, roughly: "When is the 59th term of <table_id>?"),
# wrapped in a one-element list of dicts.
Q = f"{table_id} 제 59 기는 언제야 ?"
data = [dict(question=Q, table_id=table_id)]
data

[{'question': '003490 제 59 기는 언제야 ?', 'table_id': '003490'}]

In [15]:
# Run the model on the second question (with attention outputs) and render the prediction as SQL.
outputs = model.predict_outputs(data, table, rt_attn=True)
predicts, attns = outputs
get_predict_sql(table_id, header, model, predicts)

"SELECT (thstrm_amount) FROM 003490 WHERE fs_div = '2019' AND stock_code = '2019'"

In [16]:
# Display the prediction dict for the second question.
predicts

{'sc': [13],
 'sa': [0],
 'wn': [2],
 'wc': [[6, 5]],
 'wo': [[0, 0]],
 'wv_tkns': [[[554, 116, 8003]],
  [[554, 116, 8003]],
  [[2514, 7063, 5872, 6398, 7405, 8003]],
  [[2514, 7063, 5872, 6398, 7405, 8003]]],
 'wv': [('2019 [E]', '2019 [E]', '비유동부채 [E]', '비유동부채 [E]')]}

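Likely reasons the model predicts the same SQL regardless of the question:

- The training data is too simple (it needs more diversity)
- There may be an issue with the loss function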

TODO:

- May need execution-guided decoding (+ beam search)
- Extend the dataset to all companies
- Improve the model
- Refactor the code
- Return to the original purpose: resolve the ambiguous parts of questions
- Build an application with Streamlit

In [None]:
# NOTE(review): neither `tensor_to_img_array` nor `window_grid_mask` is defined or imported in the
# visible cells of this notebook, and the cell has no execution count — presumably a leftover
# from another notebook; confirm and remove.
tensor_to_img_array(window_grid_mask)