In [1]:
import tensorflow as tf
print(tf.__version__)

import tensorflow_hub as hub

import numpy as np
import os
import pandas as pd
import re

import gzip
from tqdm import tqdm
import json
from nltk.tokenize import sent_tokenize

# Reduce logging output.
tf.logging.set_verbosity(tf.logging.ERROR)

1.14.0


In [2]:
def open_file(filename) :
    if filename.split('.')[-1] == "gz" :
        file = gzip.open(filename,'rt')
    else :
        file = open(filename, 'rt')
    return file

def collect_target_samples(filename) :
    
    num_lines = sum(1 for line in open_file(filename))    
    data = open_file(filename)
    
    # collected samples
    target_samples = []
    
    for i, line in tqdm(enumerate(data), total = num_lines) :
    
        jsondata = json.loads(line)
    
        if i == 0 :
            continue
            
        context = jsondata['context']
        context_len = len(context)
        
        # preventing sentence split according to . ! ? in answer texts
        for q in jsondata['qas'] :

            ans =  q['detected_answers'][0]['text']
            pos =  q['detected_answers'][0]['char_spans'][0][0]

            if '.' in ans :
                context = context[:pos] + re.sub('\.', '♬', context[pos : pos + len(ans)]) + context[pos + len(ans):]
                
            if '!' in ans :
                context = context[:pos] + re.sub('!', '♪', context[pos : pos + len(ans)]) + context[pos + len(ans):]
                
            if '?' in ans :
                context = context[:pos] + re.sub('\?', '♩', context[pos : pos + len(ans)]) + context[pos + len(ans):]

        # split passage to sentences
        sen_text = sent_tokenize(context)
        sen_pos = []
        
        sample = dict()
        
        for j in range(len(sen_text)) :

            pos = sum(sen_pos) 
            pos += len(sen_text[j])
                
            # adjusting sentence position because sent_tokenize automatically strips splited sentence texts
            while pos < context_len and ord(context[pos]) in [10, 32, 160] : 
                pos += 1

            sen_pos.append(pos - sum(sen_pos))

            # remove tag texts for improving sentence embedding quality
            sen_text[j] = re.sub('\[TLE\]', ' ', sen_text[j])
            sen_text[j] = re.sub('\[DOC\]', ' ', sen_text[j])
            sen_text[j] = re.sub('\[PAR\]', ' ', sen_text[j])
            
            #sen_text[j] = re.sub('<P>', ' ', sen_text[j])
            #sen_text[j] = re.sub('</P>', ' ', sen_text[j])

            # restore replaced characters
            sen_text[j] = re.sub('♬', '.', sen_text[j])
            sen_text[j] = re.sub('♪', '!', sen_text[j])
            sen_text[j] = re.sub('♩', '?', sen_text[j])

        sample['sentence_text'] = sen_text
        sample['sentence_pos'] = sen_pos
            
        questions = []
        for q in jsondata['qas'] :
            
            # only use first detected answer
            answer = q['detected_answers'][0]
            
            spans = []
            unique_sen = set()
            for j in range(len(answer['char_spans'])) :
                pos = answer['char_spans'][j][0]
                
                # find sentence index including target span
                idx = 0
                for k in range(len(sen_text)) :
                    if pos < (sum(sen_pos[:k+1])) :
                        idx = k
                        break
                
                # ignoring spans with duplicated sentence index
                if idx not in unique_sen :
                    spans.append((answer['char_spans'][j], idx))
                    unique_sen.add(idx)
              
            if len(spans) > 1 :
                questions.append({
                     'qid' : q['qid']
                    ,'question' : q['question']
                    ,'answer'   : answer['text']
                    ,'ans_spans': spans
                })
            
        if len(questions) > 0 :
            sample['question'] = questions
            target_samples.append(sample)

    return target_samples

In [3]:
import pickle

if True :

    train_files = [os.path.join("./MRQA-Shared-Task-2019/download_train", file) for file in os.listdir("./MRQA-Shared-Task-2019/download_train")]
    dev_files = [os.path.join("./MRQA-Shared-Task-2019/download_in_domain_dev", file) for file in os.listdir("./MRQA-Shared-Task-2019/download_in_domain_dev")]

    print(train_files)
    print(dev_files)

    all_target_samples = dict()
    for file in train_files[:-1] + dev_files[:-1] :
        print(file)
        target_samples = collect_target_samples(file)
        all_target_samples[file] = target_samples
        questions = []
        for l in target_samples :
            questions = questions + l['question']
        print("Num. collected samples :", len(questions))

    with open("all_target_samples.pickle", 'wb') as handle:
        pickle.dump(all_target_samples, handle)
    
else :
    
    with open("all_target_samples.pickle", 'rb') as handle:
        all_target_samples = pickle.load(handle)
        
print(all_target_samples.keys())

['./MRQA-Shared-Task-2019/download_train/NaturalQuestions.jsonl.gz', './MRQA-Shared-Task-2019/download_train/HotpotQA.jsonl.gz', './MRQA-Shared-Task-2019/download_train/NewsQA.jsonl.gz', './MRQA-Shared-Task-2019/download_train/SQuAD.jsonl.gz', './MRQA-Shared-Task-2019/download_train/TriviaQA.jsonl.gz', './MRQA-Shared-Task-2019/download_train/SearchQA.jsonl.gz']
['./MRQA-Shared-Task-2019/download_in_domain_dev/NaturalQuestions.jsonl.gz', './MRQA-Shared-Task-2019/download_in_domain_dev/HotpotQA.jsonl.gz', './MRQA-Shared-Task-2019/download_in_domain_dev/NewsQA.jsonl.gz', './MRQA-Shared-Task-2019/download_in_domain_dev/SQuAD.jsonl.gz', './MRQA-Shared-Task-2019/download_in_domain_dev/NaturalQuestionsrevised.jsonl', './MRQA-Shared-Task-2019/download_in_domain_dev/SQuADrevised.jsonl', './MRQA-Shared-Task-2019/download_in_domain_dev/TriviaQA.jsonl.gz', './MRQA-Shared-Task-2019/download_in_domain_dev/SearchQA.jsonl.gz']
./MRQA-Shared-Task-2019/download_train/NaturalQuestions.jsonl.gz


100%|██████████| 104072/104072 [00:31<00:00, 3287.54it/s]


Num. collected samples : 0
./MRQA-Shared-Task-2019/download_train/HotpotQA.jsonl.gz


100%|██████████| 72929/72929 [00:32<00:00, 2267.89it/s]


Num. collected samples : 26715
./MRQA-Shared-Task-2019/download_train/NewsQA.jsonl.gz


100%|██████████| 11429/11429 [00:17<00:00, 655.62it/s]


Num. collected samples : 41
./MRQA-Shared-Task-2019/download_train/SQuAD.jsonl.gz


100%|██████████| 18886/18886 [00:06<00:00, 2722.10it/s]


Num. collected samples : 598
./MRQA-Shared-Task-2019/download_train/TriviaQA.jsonl.gz


100%|██████████| 61689/61689 [02:02<00:00, 502.26it/s]


Num. collected samples : 45983
./MRQA-Shared-Task-2019/download_in_domain_dev/NaturalQuestions.jsonl.gz


100%|██████████| 12837/12837 [00:03<00:00, 3380.58it/s]


Num. collected samples : 0
./MRQA-Shared-Task-2019/download_in_domain_dev/HotpotQA.jsonl.gz


100%|██████████| 5905/5905 [00:02<00:00, 2281.97it/s]


Num. collected samples : 1681
./MRQA-Shared-Task-2019/download_in_domain_dev/NewsQA.jsonl.gz


100%|██████████| 639/639 [00:00<00:00, 709.90it/s]
  0%|          | 0/2068 [00:00<?, ?it/s]

Num. collected samples : 0
./MRQA-Shared-Task-2019/download_in_domain_dev/SQuAD.jsonl.gz


100%|██████████| 2068/2068 [00:00<00:00, 2430.85it/s]
0it [00:00, ?it/s]
100%|██████████| 1/1 [00:00<00:00, 3297.41it/s]


Num. collected samples : 63
./MRQA-Shared-Task-2019/download_in_domain_dev/NaturalQuestionsrevised.jsonl
Num. collected samples : 0
./MRQA-Shared-Task-2019/download_in_domain_dev/SQuADrevised.jsonl
Num. collected samples : 0
./MRQA-Shared-Task-2019/download_in_domain_dev/TriviaQA.jsonl.gz


100%|██████████| 7786/7786 [00:15<00:00, 559.27it/s]


Num. collected samples : 5860
dict_keys(['./MRQA-Shared-Task-2019/download_train/NaturalQuestions.jsonl.gz', './MRQA-Shared-Task-2019/download_train/HotpotQA.jsonl.gz', './MRQA-Shared-Task-2019/download_train/NewsQA.jsonl.gz', './MRQA-Shared-Task-2019/download_train/SQuAD.jsonl.gz', './MRQA-Shared-Task-2019/download_train/TriviaQA.jsonl.gz', './MRQA-Shared-Task-2019/download_in_domain_dev/NaturalQuestions.jsonl.gz', './MRQA-Shared-Task-2019/download_in_domain_dev/HotpotQA.jsonl.gz', './MRQA-Shared-Task-2019/download_in_domain_dev/NewsQA.jsonl.gz', './MRQA-Shared-Task-2019/download_in_domain_dev/SQuAD.jsonl.gz', './MRQA-Shared-Task-2019/download_in_domain_dev/NaturalQuestionsrevised.jsonl', './MRQA-Shared-Task-2019/download_in_domain_dev/SQuADrevised.jsonl', './MRQA-Shared-Task-2019/download_in_domain_dev/TriviaQA.jsonl.gz'])


In [4]:
from sklearn.metrics.pairwise import cosine_similarity

def correct_target_samples(data) :
    
    qid = []
    span_num = []
    spans = []
    texts = []
    ans = []
    for d in data :
        sentence_text = d['sentence_text']
        for q in d['question'] :
            texts.append(q['question'])
            spans.append(q['ans_spans'])
            span_texts = [sentence_text[t[1]] for t in q['ans_spans']]
            texts = texts + span_texts      
            qid.append(q['qid'])
            ans.append(q['answer'])
            span_num.append(len(span_texts))
            
    print("Num. quesitons :", len(qid))
    print("Num. texts     :", len(texts))

    # Import the Universal Sentence Encoder's TF Hub module
    module_url = "https://tfhub.dev/google/universal-sentence-encoder/2" #@param ["https://tfhub.dev/google/universal-sentence-encoder/2", "https://tfhub.dev/google/universal-sentence-encoder-large/3"]
    embed = hub.Module(module_url)

    session = tf.Session()
    session.run([tf.global_variables_initializer(), tf.tables_initializer()])
    
    batch_size = 512
    text_embed = []
    for i in tqdm(range(0, len(texts), batch_size)) :
        text_embed.append(session.run(embed(texts[i:i+batch_size])))
    text_embed = np.concatenate(text_embed)
    #text_embed = np.zeros((len(texts), 512))
        
    correct_spans = []
    for i in range(len(qid)) :
        # get accumulated previous text number
        acc_num = sum(span_num[:i]) + i
        q_embed = text_embed[acc_num].reshape(-1, 512)
        s_embed = text_embed[acc_num+1:acc_num+1+span_num[i]].reshape(-1, 512)
        sim = cosine_similarity(q_embed, s_embed)
        most_sim_idx = np.argsort(sim)[0][-1]
        if most_sim_idx != 0 :
            correct_spans.append({
                 "qid" : qid[i]
                ,"question" : texts[acc_num]
                ,"answer"   : ans[i]
                ,"origin_span" : (texts[acc_num + 1] , spans[i][0])
                ,"revise_span" : (texts[acc_num + 1 + most_sim_idx] , spans[i][most_sim_idx])
            })
        
    return correct_spans

correct_spans = correct_target_samples(all_target_samples['./MRQA-Shared-Task-2019/download_train/SQuAD.jsonl.gz'].copy())
len(correct_spans)

Num. quesitons : 598
Num. texts     : 2247


100%|██████████| 5/5 [00:15<00:00,  3.27s/it]


230

In [5]:
import pandas as pd

data = []

for s in correct_spans :
    data.append([
         s['qid']
        ,s['question']
        ,s['answer']
        ,s['origin_span']
        ,s['revise_span']
    ])
    
df = pd.DataFrame(data)
df.columns = ['qid', 'question', 'answer', 'origin_span', 'revise_span']
df.to_csv("revise_spans.csv", index = False)