# BiDAF

* SQuAD: https://rajpurkar.github.io/SQuAD-explorer/
* BiDAF Paper: https://arxiv.org/abs/1611.01603 
* Score: 
    - BiDAF++: EM) 77.573, F1) 84.858
    - BiDAF: EM) 67.974, F1) 77.323

## Data EDA: SQuAD

**squad-json**

* version
* data: 442
    * title: (str)
    * paragraphs: (m, list(dict))
        * context: (n, str)
        * qas: (k, dict)
            * answers: (list(dict))
                * answer_start: (int)
                * text: (str)
            * question: (str)
            * id: (str)

In [42]:
import json
squad_data = json.load(open('../data/SQuAD/squad/train-v1.1.json'))
print('lenght of data: {}'.format(len(squad_data['data'])))

lenght of data: 442


In [46]:
def show_ex(data, data_idx):
    ex = data[data_idx]
    # select idxes
    pa_idx = int(input('insert paragraph idx (total: {}):'.format(len(ex['paragraphs']))))
    assert (pa_idx >= 1) & (pa_idx <= len(ex['paragraphs'])), 'error'
    pa_idx -= 1
    ex_pa = ex['paragraphs'][pa_idx]
    qa_idx = int(input('insert qa idx (total: {}):'.format(len(ex_pa['qas']))))
    assert (qa_idx >= 1) & (qa_idx <= len(ex_pa['qas'])), 'error'
    qa_idx -= 1
    # hightlight
    qas = ex_pa['qas'][qa_idx]
    highlight_idxes = [(x['answer_start'], 
                        x['answer_start']+len(x['text']),
                        x['text']) for x in qas['answers']]
    highlight_context = ex_pa['context']
    for (*_, t) in highlight_idxes:
        temp = highlight_context.split(t)
        temp.insert(1, '\033[40;33m'+t+'\033[m')
        highlight_context = ''.join(temp)
    print('-'*20)
    print('Title: {}'.format(ex['title']))
    print('-'*20)
    print('paragraph({}) context:'.format(pa_idx+1))
    print('{}'.format(highlight_context))
    print('-'*20)
    print('1st qa:')
    print('question: {}'.format(qas['question']))
    for i, ans in enumerate(qas['answers']):
        print('answers: {}'.format(ans['text']))
        print('answers start {} , end {}'.format(highlight_idxes[i][0], highlight_idxes[i][1]))
    return highlight_idxes  # (start_idx, end_idx, text)

In [49]:
highlight_idxes = show_ex(squad_data['data'], 1)

insert paragraph idx (total: 66):1
insert qa idx (total: 20):1
--------------------
Title: Beyoncé
--------------------
paragraph(1) context:
Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she performed in various singing and dancing competitions as a child, and rose to fame [40;33min the late 1990s[m as lead singer of R&B girl-group Destiny's Child. Managed by her father, Mathew Knowles, the group became one of the world's best-selling girl groups of all time. Their hiatus saw the release of Beyoncé's debut album, Dangerously in Love (2003), which established her as a solo artist worldwide, earned five Grammy Awards and featured the Billboard Hot 100 number-one singles "Crazy in Love" and "Baby Boy".
--------------------
1st qa:
question: When did Beyonce start becoming popular?
answers: in the late 1990s
answers start 269 , end 286


Expected data form

**inputs**

* context: $[x_1, x_2, \cdots, x_T ]$
* question: $[q_1, q_2, \cdots, q_J]$

**outputs**

* start idx, end idx


In [19]:
import os
import nltk
import torch
from torchtext import data
from torchtext import datasets
from torchtext.vocab import GloVe
def word_tokenize(tokens):
    return [token.replace("''", '"').replace("``", '"') for token in nltk.word_tokenize(tokens)]

In [6]:
RAW = data.RawField()
CHAR_NESTING = data.Field(batch_first=True, tokenize=list, lower=True)
CHAR = data.NestedField(CHAR_NESTING, tokenize=word_tokenize)
WORD = data.Field(batch_first=True, tokenize=word_tokenize, lower=True, include_lengths=True)
LABEL = data.Field(sequential=False, unk_token=None, use_vocab=False)

In [30]:
def preprocess_file(path):
    dump = []
    abnormals = [' ', '\n', '\u3000', '\u202f', '\u2009']

    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
        data = data['data']

        for article in data:
            for paragraph in article['paragraphs']:
                context = paragraph['context']
                tokens = word_tokenize(context)
                for qa in paragraph['qas']:
                    id = qa['id']
                    question = qa['question']
                    for ans in qa['answers']:
                        answer = ans['text']
                        s_idx = ans['answer_start']
                        e_idx = s_idx + len(answer)

                        l = 0
                        s_found = False
                        for i, t in enumerate(tokens):
                            while l < len(context):
                                if context[l] in abnormals:
                                    l += 1
                                else:
                                    break
                            # exceptional cases
                            if t[0] == '"' and context[l:l + 2] == '\'\'':
                                t = '\'\'' + t[1:]
                            elif t == '"' and context[l:l + 2] == '\'\'':
                                t = '\'\''

                            l += len(t)
                            if l > s_idx and s_found == False:
                                s_idx = i
                                s_found = True
                            if l >= e_idx:
                                e_idx = i
                                break

                        dump.append(dict([('id', id),
                                          ('context', context),
                                          ('question', question),
                                          ('answer', answer),
                                          ('s_idx', s_idx),
                                          ('e_idx', e_idx)]))
    
    filename = 'prepro_' + path.split('/')[-1]
    path = os.path.join(os.path.split(path)[0], filename)
    with open(f'{path}', 'w', encoding='utf-8') as f:
            for line in dump:
                json.dump(line, f)
                print('', file=f)

In [33]:
path = '../data/SQuAD/squad/train-v1.1.json'
path2 = '../data/SQuAD/squad/dev-v1.1.json'

In [32]:
preprocess_file(path)

In [34]:
preprocess_file(path2)

In [35]:
dict_fields = {'id': ('id', RAW),
               's_idx': ('s_idx', LABEL),
               'e_idx': ('e_idx', LABEL),
               'context': [('c_word', WORD), ('c_char', CHAR)],
               'question': [('q_word', WORD), ('q_char', CHAR)]}

list_fields = [('id', RAW), ('s_idx', LABEL), ('e_idx', LABEL),
               ('c_word', WORD), ('c_char', CHAR),
               ('q_word', WORD), ('q_char', CHAR)]

In [36]:
train, dev = data.TabularDataset.splits(
                path='../data/SQuAD/squad/',
                train='prepro_train-v1.1.json',
                validation='prepro_dev-v1.1.json',
                format='json',
                fields=dict_fields)

In [50]:
a = train.examples[1]

### Character Embedding Layer

### Word Embedding Layer

### Contextual Embedding Layer

### Attention flow Layer

In [15]:
import torch
import torch.nn as nn
vocab_size = 30
B = 1
T = 3
J = 2
d = 5

In [16]:
shape = (B, T, J, 2*d)

In [17]:
H = torch.randn(B, T, 2*d)
U = torch.randn(B, J, 2*d)

In [18]:
H.unsqueeze(2).size()

torch.Size([1, 3, 1, 10])

In [19]:
H_ex = H.unsqueeze(2).expand(shape)
U_ex = U.unsqueeze(1).expand(shape)
H_ex.size(), U_ex.size()

(torch.Size([1, 3, 2, 10]), torch.Size([1, 3, 2, 10]))

In [20]:
w_cat = torch.cat([H_ex, U_ex, H_ex * U_ex], dim=-1)

In [21]:
WS = nn.Linear(6*d, 1)

In [22]:
S = WS(w_cat).squeeze(-1)
S

tensor([[[-0.1604, -0.4624],
         [ 0.6376,  0.2100],
         [-0.1872, -0.3641]]], grad_fn=<SqueezeBackward1>)

#### context2query

In [23]:
S_soft = S.softmax(2)
S_soft

tensor([[[0.5749, 0.4251],
         [0.6053, 0.3947],
         [0.5441, 0.4559]]], grad_fn=<SoftmaxBackward>)

In [24]:
U.size(), S_soft.size()

(torch.Size([1, 2, 10]), torch.Size([1, 3, 2]))

In [25]:
c2q = torch.bmm(S_soft, U)
c2q, c2q.size()

(tensor([[[-0.6446,  0.7566,  0.0930,  0.2170, -0.5246, -0.3520, -0.7119,
            0.2180, -0.5943, -0.3211],
          [-0.6382,  0.6575,  0.1115,  0.1388, -0.5668, -0.3541, -0.6949,
            0.2342, -0.5836, -0.3690],
          [-0.6511,  0.8572,  0.0742,  0.2963, -0.4818, -0.3499, -0.7291,
            0.2015, -0.6052, -0.2724]]], grad_fn=<BmmBackward>),
 torch.Size([1, 3, 10]))

#### query2context

In [26]:
S_max = S.max(2)[0]  # B, T
S_max

tensor([[-0.1604,  0.6376, -0.1872]], grad_fn=<MaxBackward0>)

In [27]:
b = S_max.softmax(1)
b

tensor([[0.2384, 0.5295, 0.2321]], grad_fn=<SoftmaxBackward>)

In [28]:
H.size(), b.size()

(torch.Size([1, 3, 10]), torch.Size([1, 3]))

In [29]:
q2c = H * b.unsqueeze(2)
q2c, q2c.size()

(tensor([[[ 0.1319,  0.0136, -0.0862, -0.0593, -0.2027,  0.1671, -0.3964,
            0.0915,  0.0247, -0.1300],
          [-0.0866,  0.6776, -0.6119,  0.4322, -0.0576, -1.1274,  0.0776,
            0.5618,  0.3047,  0.6047],
          [-0.0753, -0.2872, -0.1330,  0.3314,  0.3718,  0.3554,  0.0554,
            0.0717, -0.3402,  0.1866]]], grad_fn=<ThMulBackward>),
 torch.Size([1, 3, 10]))

In [30]:
H.size(), c2q.size(), q2c.size()

(torch.Size([1, 3, 10]), torch.Size([1, 3, 10]), torch.Size([1, 3, 10]))

In [31]:
d_G = 8*d
beta = nn.Linear(4*2*d, d_G)

In [32]:
G = beta(torch.cat([H, c2q, H*c2q, H*q2c], dim=-1))
G.size()

torch.Size([1, 3, 40])

### Modeling Layer

In [33]:
modeling_layer = nn.LSTM(input_size=d_G,
                         hidden_size=d,
                         num_layers=3,
                         batch_first=True,
                         bidirectional=True)

In [34]:
M, _ = modeling_layer(G)
M.size()

torch.Size([1, 3, 10])

### Output Layer

application-specific: QA task

In [35]:
M.size(), G.size()

(torch.Size([1, 3, 10]), torch.Size([1, 3, 40]))

In [36]:
start_linear = nn.Linear(10*d, 1, bias=False)
output_lstm = nn.LSTM(input_size=2*d,
                      hidden_size=d,
                      num_layers=3,
                      batch_first=True,
                      bidirectional=True)
end_linear = nn.Linear(10*d, 1)

In [91]:
p1_cat = torch.cat([G, M], dim=-1)
p1 = torch.log_softmax(start_linear(p1_cat).squeeze(2), dim=-1)
p1.size()

torch.Size([1, 3])

In [87]:
M2, _ = p2_lstm(M)

In [92]:
p2_cat = torch.cat([G, M2], dim=-1)
p2 = torch.log_softmax(end_linear(p2_cat).squeeze(2), dim=-1)
p2.size()

torch.Size([1, 3])

### Loss

In [98]:
loss_f1 = nn.NLLLoss()
loss_f2 = nn.NLLLoss()
loss_f1(p1, torch.LongTensor([0])) + loss_f2(p2, torch.LongTensor([2]))

tensor(2.1541, grad_fn=<ThAddBackward>)