In [220]:
import numpy as np
import pandas as pd
import os
import datetime
import tqdm
import pickle
import re

# 1. Load MNLI data

In [159]:
# List the files available in the MNLI data directory.
! ls ../data/MNLI

dev_matched.tsv
dev_mismatched.tsv
diagnostic.tsv
diagnostic-full.tsv
original
README.txt
test_matched.tsv
test_mismatched.tsv
train.tsv


In [161]:
# Read the matched dev split as raw rows: every line becomes a list of its
# tab-separated fields, and the column-name header stays at index 0.
with open('../data/MNLI/dev_matched.tsv', 'r', encoding="utf8") as tsv_file:
    matched = [line.replace("\n", "").split('\t') for line in tsv_file]

len(matched), matched[0]

(9816,
 ['index',
  'promptID',
  'pairID',
  'genre',
  'sentence1_binary_parse',
  'sentence2_binary_parse',
  'sentence1_parse',
  'sentence2_parse',
  'sentence1',
  'sentence2',
  'label1',
  'label2',
  'label3',
  'label4',
  'label5',
  'gold_label'])

In [162]:
# Read the mismatched dev split the same way: one list of tab-separated
# fields per line, with the header row first.
with open('../data/MNLI/dev_mismatched.tsv', 'r', encoding="utf8") as tsv_file:
    mismatched = [line.replace("\n", "").split('\t') for line in tsv_file]

len(mismatched), mismatched[0]

(9833,
 ['index',
  'promptID',
  'pairID',
  'genre',
  'sentence1_binary_parse',
  'sentence2_binary_parse',
  'sentence1_parse',
  'sentence2_parse',
  'sentence1',
  'sentence2',
  'label1',
  'label2',
  'label3',
  'label4',
  'label5',
  'gold_label'])

# 2. Load LRP data and merge with MNLI data

In [163]:
# List the pickle files in the current working directory.
! ls *.pkl

data_list.pkl
data_list-new.pkl


In [164]:
# Load the pickled list of LRP records. The file is opened in a `with` block
# so the handle is closed as soon as loading finishes (the previous
# `pickle.load(open(...))` form left it open).
# NOTE(review): unpickling can execute arbitrary code; only load trusted files.
with open('data_list-new.pkl', 'rb') as pkl_file:
    data = pickle.load(pkl_file)
len(data), data[0].keys()

(19647,
 dict_keys(['final_prob', 'answer', 'prediction', 'text_lst', 'text', 'question', 'sentence', 'answer_list', 'reward0', 'reward_norm0', 'reward_norm_dict0', 'reward1', 'reward_norm1', 'reward_norm_dict1', 'reward2', 'reward_norm2', 'reward_norm_dict2']))

## 2.1 Merge with MNLI data

In [167]:
# Attach every matched MNLI row, converted to a {column name: value} dict,
# to the record at the same position in `data` under the 'src' key.
lbl = matched[0]
for i, row in enumerate(matched[1:]):
    src_fields = {name: row[col] for col, name in enumerate(lbl)}
    if i >= len(data):
        # More rows than records: stop instead of indexing past the end.
        break
    data[i]['src'] = src_fields

In [168]:
# Attach every mismatched MNLI row to its record in `data` under 'src'.
# The records for the mismatched split are written after those of the
# matched split, so the target index is shifted by the number of matched
# data rows (header excluded). This was previously hard-coded as 9815.
lbl = mismatched[0]
offset = len(matched) - 1
for i, m in enumerate(mismatched[1:]):
    if len(data) <= i + offset:
        # The bounds check has to use the shifted index that is actually
        # written; checking `i` alone could not prevent an IndexError.
        break
    # Index explicitly (rather than zip) so a short, malformed row fails loudly.
    data[i + offset]['src'] = {name: m[col] for col, name in enumerate(lbl)}

In [169]:
# Spot-check the last record: its merged MNLI row ('src') shown next to the
# 'question' / 'sentence' fields, the length of the first 'reward0' entry,
# and the record's keys.
i = -1
data[i]['src'], '||||', data[i]['question'], data[i]['sentence'], len(data[i]['reward0'][0]), data[i].keys()


({'index': '9831',
  'promptID': '8693',
  'pairID': '8693e',
  'genre': 'verbatim',
  'sentence1_binary_parse': "( ( ( ( ( ( ( ( ( ( ( Bloomer ( -LRB- ( ( for ( ` ( flower ' ) ) ) -RRB- ) ) ) , ) ( butter ( -LRB- ( ( for ( ` ( ram ' ) ) ) -RRB- ) ) ) ) , ) or ) ( even flower ) ) ( -LRB- ( ( for ( ` ( river ' ) ) ) -RRB- ) ) ) ( are ( recurrent examples ) ) ) , ) but ) ( solvers ( ( must always ) ( be ( on ( ( the alert ) ( for ( ( new traps ) ( of this ) ) ) ) ) ) ) ) )",
  'sentence2_binary_parse': '( ( Bloomer ( is ( ( another word ) ( for flower ) ) ) ) ( , ( butter ( ( is ( for ( ( ( ram and ) flower ) ( for river ) ) ) ) . ) ) ) )',
  'sentence1_parse': "(ROOT (FRAG (S (S (NP (NP (NP (NP (NNP Bloomer)) (PRN (-LRB- -LRB-) (PP (IN for) (NP (`` `) (NN flower) ('' '))) (-RRB- -RRB-))) (, ,) (NP (NP (NN butter)) (PRN (-LRB- -LRB-) (PP (IN for) (NP (`` `) (NN ram) ('' '))) (-RRB- -RRB-))) (, ,) (CC or) (NP (RB even) (NN flower))) (PRN (-LRB- -LRB-) (PP (IN for) (NP (`` `) (NN river) ('

## 2.2 Reconnect split words

In [264]:
def add_space_to_punc(txt):
    """Pad selected punctuation and clitics with spaces so `str.split` isolates them.

    Every occurrence of one of the characters ``. / : { } ( )`` is wrapped in
    single spaces; afterwards ``'s`` and ``n't`` are wrapped the same way.

    Returns the padded string.
    """
    # One pass over the text: each listed character maps to " <char> ".
    padding = {ord(ch): ' ' + ch + ' ' for ch in './:{}()'}
    spaced = txt.translate(padding)
    for clitic in ("'s", "n't"):
        spaced = spaced.replace(clitic, ' ' + clitic + ' ')
    return spaced

In [265]:
# Align the model tokens in data[i]['text_lst'] with the whitespace-split
# words of the original MNLI sentences. For every record this stores:
#   'question_list' / 'sentence_list': words of sentence1 / sentence2
#   'question_ids'  / 'sentence_ids' : one list of token indices per matched word
# With debug enabled the index lists also contain the matched strings, so
# they are only suitable for inspection, not for downstream indexing.
debug = False

for i in tqdm.tqdm(range(len(data))):
    question = add_space_to_punc(data[i]['src']['sentence1']).split()
    sentence = add_space_to_punc(data[i]['src']['sentence2']).split()
    question_id = []
    sentence_id = []
    cnt = 0  # index of the word in `src` that the next token is compared with
    is_q = True  # NOTE(review): assigned here and below but never read in this cell
    ii = -1  # token index; advanced manually so several tokens can be consumed at once

    while ii < len(data[i]['text_lst']) - 1:
        ii += 1
        w = data[i]['text_lst'][ii]
        if w == '[CLS]':
            # Start of the sequence: collect matches for sentence1.
            # NOTE(review): `src` / `target` are only bound here or at '[SEP]',
            # so this presumably relies on '[CLS]' being the first token.
            src = question
            target = question_id
            cnt = 0
            continue

        if debug is True:
            print(w, ii)

        if w == '[SEP]':
            # Every '[SEP]' (including the final one) switches to sentence2
            # and restarts the word counter.
            is_q = False
            src = sentence
            target = sentence_id

            if debug is True:
                print('-' * 70)

            cnt = 0
        elif w == src[cnt].lower().strip():
            # The token is exactly the current word: a one-token group.
            if debug is True:
                target.append([ii, w])
            else:
                target.append([ii])

            cnt +=1
        elif w == src[cnt].lower().strip()[:len(w)]:
            # The token is only a prefix of the current word: keep joining
            # following tokens while the joined text is still a prefix.
            if debug is True:
                print(w, src[cnt])

            target.append([])
            for iii in range(1, 8):  # a word is assembled from at most 7 tokens
                # '#' characters are removed before comparing (e.g. the '##'
                # prefix seen on tokens such as '##able').
                w_new = ''.join(data[i]['text_lst'][ii:ii+iii]).replace("#", "")
                if w_new != src[cnt].lower().strip()[:len(w_new)]:
                    # Token ii+iii-1 no longer belongs to this word. Step to the
                    # last consumed token minus the pending `ii += 1`, so the
                    # while loop resumes on the first unconsumed token.
                    ii += (iii - 2)
                    cnt += 1
                    if debug is True:
                        print(ii, cnt)
                        if ii < len(data[i]['text_lst']) and cnt < len(src):
                            print("===>", w_new, data[i]['text_lst'][ii] , src[cnt].lower().strip())

                    break
                else:
                    if debug is True:
                        target[-1] += [ii + iii - 1, w_new]
                    else:
                        target[-1] += [ii + iii - 1]
        # NOTE(review): a token matching neither branch is skipped without
        # advancing `cnt`, so later tokens keep being compared with the same
        # word; `src[cnt]` is also not bounds-checked.

    data[i]['question_ids'] = question_id
    data[i]['sentence_ids'] = sentence_id
    data[i]['question_list'] = question
    data[i]['sentence_list'] = sentence
    

100%|██████████| 19647/19647 [00:01<00:00, 19588.71it/s]


## 2.3 Extract POS

In [286]:
debug = False


def _parse_leaves(parse):
    """Return the leaves of a bracketed parse string as token lists.

    A leaf is the text between a '(' and the next ')', e.g. ``(NNP Bloomer)``
    yields ``['NNP', 'Bloomer']``. Leaves whose text starts with '-' or whose
    last token has no ``[0-9a-z]`` character after lower-casing are dropped.
    """
    leaves = []
    st = ''
    for t in parse:
        if t == '(':
            st = ''
        elif t == ')':
            parts = st.split()
            # `parts` is empty for closing brackets of inner nodes.
            if parts and st[0] != '-' and re.sub("[^0-9a-z]", "", parts[-1].lower()):
                leaves.append(parts)
            st = ''
        else:
            st += t
    return leaves


def _align_pos(words, leaves, debug=False):
    """Return one tag per entry of `words` ('' where no leaf was matched).

    Words and leaves are walked in order and compared after lower-casing and
    removing every character outside ``[0-9a-z]``. The leaf pointer only
    advances on a match, and words without such characters are skipped.
    """
    tags = [''] * len(words)
    iii = 0
    for ii, w in enumerate(words):
        w_new = re.sub("[^0-9a-z]", "", w.lower())
        if debug is True:
            print(ii, iii, w, w_new)

        # The bounds check avoids an IndexError once every leaf is consumed.
        if len(w_new) > 0 and iii < len(leaves):
            pos_w = re.sub("[^0-9a-z]", "", leaves[iii][-1].lower())
            if debug is True:
                print('===>', w_new, pos_w)

            if w_new == pos_w:
                tags[ii] = leaves[iii][0]
                iii += 1
    return tags


# For every record, tag the words of both sentences from their parses.
# Fixes relative to the previous version of this cell:
#   * results are stored in data[i][...]; the old `target = sentence1_pos`
#     only rebound a local name and left the stored lists empty;
#   * sentence2 is aligned against 'sentence_list' (the split of sentence2)
#     instead of 'answer_list'. The output key 'answer_pos' is unchanged.
for i in tqdm.tqdm(range(len(data))):
    for parse_key, words_key, pos_key in (
        ('sentence1_parse', 'question_list', 'question_pos'),
        ('sentence2_parse', 'sentence_list', 'answer_pos'),
    ):
        leaves = _parse_leaves(data[i]['src'][parse_key])
        data[i][pos_key] = _align_pos(data[i][words_key], leaves, debug)

100%|██████████| 19647/19647 [00:04<00:00, 4042.55it/s]


## 2.4 Recalculate rewards

In [297]:
debug = False

# Collapse the per-token rewards into one value per aligned word: for each
# index group in 'question_ids' / 'sentence_ids', sum the entries of
# reward0/1/2 (first element) at those indices.
REWARD_KEYS = ('reward0', 'reward1', 'reward2')
GROUPED_KEYS = (
    ('question_ids', ('question_id_reward0', 'question_id_reward1', 'question_id_reward2')),
    ('sentence_ids', ('sentence_id_reward0', 'sentence_id_reward1', 'sentence_id_reward2')),
)

for i in tqdm.tqdm(range(len(data))):
    if debug is True:
        print(i)

    record = data[i]
    for ids_key, out_keys in GROUPED_KEYS:
        index_groups = record[ids_key]
        for out_key, reward_key in zip(out_keys, REWARD_KEYS):
            token_rewards = record[reward_key][0]
            record[out_key] = [
                sum([token_rewards[l] for l in group]) for group in index_groups
            ]

        

100%|██████████| 19647/19647 [00:01<00:00, 19148.96it/s]


In [299]:
# Show the grouped sentence rewards next to the per-token rewards they were
# summed from. `i` is whatever value an earlier cell left behind.
data[i]['sentence_id_reward0'], data[i]['reward0'][0]
#, data[i]['sentence_id_reward1'], data[i]['sentence_id_reward2']

([0.002134892623871565,
  -0.0014312209095805883,
  -0.0038709824439138174,
  -0.009584102779626846,
  0.002653901930898428,
  -0.002251827740110457,
  0.00963689386844635,
  -0.0017088635358959436,
  -0.00034058024175465107,
  -0.01138190645724535,
  -0.0016721236752346158,
  -0.0015907816123217344,
  -0.000903072883374989,
  -0.00022454402642324567,
  0.005671633407473564],
 [0.02040218375623226,
  -0.0010138035286217928,
  0.0014954593498259783,
  0.0030733866151422262,
  -0.0005306612583808601,
  0.006268484517931938,
  -0.0027839052490890026,
  -0.00012536458962131292,
  -0.0043652961030602455,
  0.0018189530819654465,
  0.0050583165138959885,
  0.0031399717554450035,
  -0.00022344780154526234,
  0.002818953013047576,
  0.0021987920626997948,
  -0.0008798292838037014,
  -0.0014544443693012,
  0.0034465391654521227,
  -0.0006571013946086168,
  -0.004202545620501041,
  -0.002362002618610859,
  0.0037485056091099977,
  -0.0020993894431740046,
  0.001844980288296938,
  -0.002464176854

In [294]:
# Print each question index group of record `i` with its summed reward0.
for index_group in data[i]['question_ids']:
    group_total = sum([data[i]['reward0'][0][l] for l in index_group])
    print(index_group, group_total)

[1, 2] 0.0004816558212041855
[3] 0.0030733866151422262
[4] -0.0005306612583808601
[5, 6, 7] 0.0033592146792216226
[8] -0.0043652961030602455
[9] 0.0018189530819654465
[10] 0.0050583165138959885
[11] 0.0031399717554450035
[12] -0.00022344780154526234
[13, 14, 15] 0.004137915791943669
[16] -0.0014544443693012
[17] 0.0034465391654521227
[18] -0.0006571013946086168
[19] -0.004202545620501041
[20] -0.002362002618610859
[21] 0.0037485056091099977
[22] -0.0020993894431740046
[23, 24, 25] -0.002374849747866392
[26] 2.998911077156663e-05
[27] -0.00023667566711083055
[28, 29] 0.008569346508011222
[30, 31] 0.006389181362465024
[32] -0.004167092498391867
[33, 34] 0.004194544395431876
[35] 0.0007390343816950917
[36] 0.0005092096398584545
[37] 0.0008171062218025327
[38] -0.0014446277637034655
[39] -0.0010226150043308735
[40] -0.005621697753667831
[41] -0.0026817647740244865
[42] -0.0007365758065134287
[43] -0.006110175047069788
[44] -0.0002788978163152933
[45] 0.0011149966157972813


In [282]:
# Debug the POS extraction for a single record (`i`) and one side
# (`j == 0`: sentence1, otherwise sentence2), printing every comparison.
debug = True
i = 122
j = 0
if j == 0:
    txt = data[i]['src']['sentence1_parse']
    txt_list = data[i]['question_list']
    pos_key = 'question_pos'
else:
    txt = data[i]['src']['sentence2_parse']
    # Align against the split of sentence2; 'answer_list' is not that list.
    txt_list = data[i]['sentence_list']
    pos_key = 'answer_pos'

# Collect the text between each '(' and the next ')' (e.g. "NNP Bloomer"),
# dropping entries that start with '-' or whose last token has no [0-9a-z].
pos_list = []
st = ''
for t in txt:
    if t == '(':
        st = ''
    elif t == ')':
        if len(st.split()) > 0:
            st_1 = re.sub("[^0-9a-z]", "", st.lower().split()[-1])
            if st[0] != '-' and len(st_1) > 0:
                pos_list.append(st)
        st = ''
    else:
        st += t
pos_list2 = [p.split() for p in pos_list]
sentence1_pos = [''] * len(txt_list)

# Walk words and leaves in order; the leaf pointer only moves on a match.
iii = 0
for ii, w in enumerate(txt_list):
    w_new = re.sub("[^0-9a-z]", "", w.lower())
    if debug is True:
        print(ii, iii, w, w_new)

    # The bounds check avoids an IndexError once every leaf is consumed.
    if len(w_new) > 0 and iii < len(pos_list2):
        pos_w = re.sub("[^0-9a-z]", "", pos_list2[iii][-1].lower())
        if debug is True:
            print('===>', w_new, pos_w)

        if w_new == pos_w:
            sentence1_pos[ii] = pos_list2[iii][0]
            iii += 1

# Store the tags on the record. The previous `target = sentence1_pos` only
# rebound a local name and left data[i][pos_key] as an empty list.
data[i][pos_key] = sentence1_pos


0 0 Candidates candidates
===> candidates candidates
1 1 must must
===> must must
2 2 submit submit
===> submit submit
3 3 a a
===> a a
4 4 set set
===> set set
5 5 of of
===> of of
6 6 fingerprints fingerprints
===> fingerprints fingerprints
7 7 for for
===> for for
8 8 review review
===> review review
9 9 by by
===> by by
10 10 the the
===> the the
11 11 FBI fbi
===> fbi fbi
12 12 . 


In [202]:
# Total number of records.
len(data)

19647

In [93]:
# For record `i` (value left by an earlier cell): word count of 'sentence',
# the 'text' and 'sentence' strings, and the number of entries in 'text_lst'.
len(data[i]['sentence'].split()), data[i]['text'],data[i]['sentence'], len(data[i]['text_lst'])

(17,
 "uh i don ' t know i i have mixed emotions about him uh sometimes i like him but at the same times i love to see somebody beat him i like him for the most part , but would still enjoy seeing someone beat him .",
 'i like him for the most part , but would still enjoy seeing someone beat him .',
 50)

In [196]:
# Debug the token-to-word alignment for a single record with verbose
# printing. In debug mode the collected groups interleave token indices with
# the matched strings.
i = 14895
debug = True

question = add_space_to_punc(data[i]['src']['sentence1']).split()
sentence = add_space_to_punc(data[i]['src']['sentence2']).split()
question_id = []
sentence_id = []
cnt = 0  # index of the word in `src` that the next token is compared with
is_q = True  # NOTE(review): assigned but never read in this cell
ii = -1  # token index; advanced manually so several tokens can be consumed at once

while ii < len(data[i]['text_lst']) - 1:
    ii += 1
    w = data[i]['text_lst'][ii]
    if w == '[CLS]':
        # Start of the sequence: collect matches for sentence1.
        src = question
        target = question_id
        cnt = 0
        continue

    if debug is True:
        print(w, ii)

    if w == '[SEP]':
        # Every '[SEP]' switches to sentence2 and restarts the word counter.
        is_q = False
        src = sentence
        target = sentence_id

        if debug is True:
            print('-' * 70)

        cnt = 0
    elif w == src[cnt].lower().strip():
        # The token is exactly the current word: a one-token group.
        if debug is True:
            target.append([ii, w])
        else:
            target.append([ii])

        cnt +=1
    elif w == src[cnt].lower().strip()[:len(w)]:
        # The token is only a prefix of the current word: keep joining
        # following tokens while the joined text is still a prefix.
        if debug is True:
            print(w, src[cnt])

        target.append([])
        for iii in range(1, 8):  # a word is assembled from at most 7 tokens
            w_new = ''.join(data[i]['text_lst'][ii:ii+iii]).replace("#", "")
            if w_new != src[cnt].lower().strip()[:len(w_new)]:
                # Token ii+iii-1 no longer belongs to this word; position `ii`
                # so the while loop's `ii += 1` lands on that token.
                ii += (iii - 2)
                cnt += 1
                if debug is True:
                    print(ii, cnt)
                    if ii < len(data[i]['text_lst']) and cnt < len(src):
                        print("===>", w_new, data[i]['text_lst'][ii] , src[cnt].lower().strip())

                break
            else:
                if debug is True:
                    target[-1] += [ii + iii - 1, w_new]
                else:
                    target[-1] += [ii + iii - 1]

question_id, sentence_id, '||||', question, sentence, '|||', data[i]['src'], data[i]['text_lst']

- 1
- --the
3 1
===> --theteachers the teachers
teachers 4
were 5
much 6
more 7
more more,
8 5
===> more,i , i
i 9
guess 10
guess guess,
11 7
===> guess,i , i
i 12
don 13
don don't
15 9
===> don'tknow t know???
know 16
know know???
19 10
[SEP] 20
----------------------------------------------------------------------
i 21
don 22
don don't
24 2
===> don'tknow t know,
know 25
know know,
26 3
===> know,the , the
the 27
teachers 28
were 29
- 30
- ---??
34 7
[SEP] 35
----------------------------------------------------------------------


([[1, '-', 2, '--', 3, '--the'],
  [4, 'teachers'],
  [5, 'were'],
  [6, 'much'],
  [7, 'more', 8, 'more,'],
  [9, 'i'],
  [10, 'guess', 11, 'guess,'],
  [12, 'i'],
  [13, 'don', 14, "don'", 15, "don't"],
  [16, 'know', 17, 'know?', 18, 'know??', 19, 'know???']],
 [[21, 'i'],
  [22, 'don', 23, "don'", 24, "don't"],
  [25, 'know', 26, 'know,'],
  [27, 'the'],
  [28, 'teachers'],
  [29, 'were'],
  [30, '-', 31, '--', 32, '---', 33, '---?', 34, '---??']],
 '||||',
 ['--the',
  'teachers',
  'were',
  'much',
  'more,',
  'I',
  'guess,',
  'I',
  "don't",
  'know???'],
 ['I', "don't", 'know,', 'the', 'teachers', 'were', '---??'],
 '|||',
 {'index': '5080',
  'promptID': '38542',
  'pairID': '38542e',
  'genre': 'facetoface',
  'sentence1_binary_parse': "( -- ( ( the teachers ) ( were ( ( ( ( much more ) ( , ( I guess ) ) ) , ) ( I ( ( do n't ) ( know ??? ) ) ) ) ) ) )",
  'sentence2_binary_parse': "( I ( ( do n't ) ( ( ( ( know , ) ( ( the teachers ) were ) ) -- ) ?? ) ) )",
  'sentence1_

In [197]:
# Show the alignment variables again, this time with the token count and the
# token list placed before the source row.
question_id, sentence_id, len(data[i]['text_lst']), '||||', question, sentence, '|||', data[i]['text_lst'], data[i]['src']

([[1, '-', 2, '--', 3, '--the'],
  [4, 'teachers'],
  [5, 'were'],
  [6, 'much'],
  [7, 'more', 8, 'more,'],
  [9, 'i'],
  [10, 'guess', 11, 'guess,'],
  [12, 'i'],
  [13, 'don', 14, "don'", 15, "don't"],
  [16, 'know', 17, 'know?', 18, 'know??', 19, 'know???']],
 [[21, 'i'],
  [22, 'don', 23, "don'", 24, "don't"],
  [25, 'know', 26, 'know,'],
  [27, 'the'],
  [28, 'teachers'],
  [29, 'were'],
  [30, '-', 31, '--', 32, '---', 33, '---?', 34, '---??']],
 36,
 '||||',
 ['--the',
  'teachers',
  'were',
  'much',
  'more,',
  'I',
  'guess,',
  'I',
  "don't",
  'know???'],
 ['I', "don't", 'know,', 'the', 'teachers', 'were', '---??'],
 '|||',
 ['[CLS]',
  '-',
  '-',
  'the',
  'teachers',
  'were',
  'much',
  'more',
  ',',
  'i',
  'guess',
  ',',
  'i',
  'don',
  "'",
  't',
  'know',
  '?',
  '?',
  '?',
  '[SEP]',
  'i',
  'don',
  "'",
  't',
  'know',
  ',',
  'the',
  'teachers',
  'were',
  '-',
  '-',
  '-',
  '?',
  '?',
  '[SEP]'],
 {'index': '5080',
  'promptID': '38542',
 

In [46]:
# Token list of record `i` (value left by an earlier cell).
data[i]['text_lst']

['[CLS]',
 'this',
 'site',
 'includes',
 'a',
 'list',
 'of',
 'all',
 'award',
 'winners',
 'and',
 'a',
 'search',
 '##able',
 'database',
 'of',
 'government',
 'executive',
 'articles',
 '.',
 '[SEP]',
 'the',
 'government',
 'executive',
 'articles',
 'housed',
 'on',
 'the',
 'website',
 'are',
 'not',
 'able',
 'to',
 'be',
 'searched',
 '.',
 '[SEP]']

In [301]:
# Keys present on record `i` after the processing cells have run.
data[i].keys()

dict_keys(['final_prob', 'answer', 'prediction', 'text_lst', 'text', 'question', 'sentence', 'answer_list', 'reward0', 'reward_norm0', 'reward_norm_dict0', 'reward1', 'reward_norm1', 'reward_norm_dict1', 'reward2', 'reward_norm2', 'reward_norm_dict2', 'src', 'question_ids', 'sentence_ids', 'question_list', 'sentence_list', 'question_pos', 'answer_pos', 'question_id_reward0', 'question_id_reward1', 'question_id_reward2', 'sentence_id_reward0', 'sentence_id_reward1', 'sentence_id_reward2'])

In [302]:
# Grouped 'reward1' sums for the sentence side of record `i`.
data[i]['sentence_id_reward1']

[-0.0037531480193138123,
 0.006491946056485176,
 0.00605211453512311,
 0.0017841905355453491,
 0.0005510298069566488,
 0.007249016081914306,
 -0.01052998099476099,
 0.008899006992578506,
 0.012983223423361778,
 0.0051713059656322,
 -0.009235817939043045,
 -0.004486431367695332,
 -0.008614360354840755,
 -0.015913477167487144,
 -0.006748727988451719]