In [None]:
import pandas as pd
import pickle
import torch
import os
import re
import random
import csv
import jsonlines
import numpy as np
import pickle
import time
import gzip
from tqdm import tqdm, trange
from sklearn.cluster import KMeans
from typing import Any, List, Sequence, Callable
from itertools import islice, zip_longest
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModelForSeq2SeqLM
from sklearn.cluster import MiniBatchKMeans

## Original data transformation

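###### Download the NQ train and dev datasets from https://ai.google.com/research/NaturalQuestions/download
###### NQ Train: https://storage.cloud.google.com/natural_questions/v1.0-simplified/simplified-nq-train.jsonl.gz (save it in the working directory as `v1.0-simplified_simplified-nq-train.jsonl.gz`)
###### NQ Dev: https://storage.cloud.google.com/natural_questions/v1.0-simplified/nq-dev-all.jsonl.gz (save it in the working directory as `v1.0-simplified_nq-dev-all.jsonl.gz`)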

In [None]:
## Parse the NQ dev split into nq_dev.tsv, one row per example:
## query, example id, long answer, short answer(s), title, abstract, content,
## doc_tac (title + abstract + content), language.


def nq_dev_row(item):
    """Convert one record of the NQ dev JSONL file into a list of nine TSV fields.

    Only the first annotation is used for the long and short answers. Annotation token
    offsets are applied to the space-separated document text rebuilt from
    ``item['document_tokens']``.
    """
    def strip_html(text):
        # Drop HTML tags and newlines, then trim surrounding whitespace.
        return re.sub('<[^<]+?>', '', text).replace('\n', '').strip()

    document_text = ' '.join(tok['token'] for tok in item['document_tokens'])
    # Built for every item so answer spans never index into another document's tokens.
    words = document_text.split(' ')
    annotation = item['annotations'][0]

    # Long answer: a negative start_token marks "no long answer".
    long_span = annotation['long_answer']
    long_answer = ''
    if long_span['start_token'] >= 0:
        long_answer = strip_html(' '.join(words[long_span['start_token']:long_span['end_token']]))

    # Short answer(s): '|'-joined spans. Yes/no-only annotations carry no span and give ''
    # rather than re-using the short answer left over from an earlier example.
    short_answer = ''
    if annotation['short_answers']:
        spans = [' '.join(words[span['start_token']:span['end_token']])
                 for span in annotation['short_answers']]
        short_answer = strip_html('|'.join(spans))

    title = item['document_title']

    # Abstract: raw text of the first <P>...</P>; content starts right after that '</P>'.
    # Without a <P>, content starts at offset 0 instead of an offset from an earlier document.
    if '<P>' in document_text:
        abs_start = document_text.index('<P>')
        abs_end = document_text.index('</P>')
        abstract = document_text[abs_start + 3:abs_end]
        content_start = abs_end + len('</P>')
    else:
        abstract = ''
        content_start = 0

    # Content ends at the second-to-last '</Ul>' if present, else at the last '</Ul>',
    # else at the end of the document; everything after that point is dropped.
    last_ul = document_text.rfind('</Ul>')
    if last_ul == -1:
        content_end = len(document_text)
    else:
        prev_ul = document_text.rfind('</Ul>', 0, last_ul)
        content_end = prev_ul if prev_ul != -1 else last_ul
    content = re.sub(' +', ' ', strip_html(document_text[content_start:content_end]))

    return [
        item['question_text'],
        str(item['example_id']),
        long_answer,
        short_answer,
        title,
        abstract,
        content,
        title + abstract + content,
        'en',
    ]


nq_dev = []
# "rb", not "r+": the archive is only read, and "r+" demands write access to it.
with gzip.open("v1.0-simplified_nq-dev-all.jsonl.gz", "rb") as f:
    for item in jsonlines.Reader(f):
        nq_dev.append(nq_dev_row(item))

nq_dev_df = pd.DataFrame(nq_dev)
nq_dev_df.to_csv(r"nq_dev.tsv", sep="\t", mode='w', header=None, index=False)

In [None]:
## Parse the NQ train split into nq_train.tsv with the same nine columns as nq_dev.tsv.
## Unlike the dev file, the simplified train file ships 'document_text' directly and has no
## separate title field, so the title is taken from the first <H1>...</H1>.


def nq_train_row(item):
    """Convert one record of the simplified NQ train JSONL file into a list of nine TSV fields.

    Only the first annotation is used for the long and short answers. Annotation token
    offsets are applied to ``item['document_text'].split(' ')``.
    """
    def strip_html(text):
        # Drop HTML tags and newlines, then trim surrounding whitespace.
        return re.sub('<[^<]+?>', '', text).replace('\n', '').strip()

    document_text = item['document_text']
    # Built for every item so answer spans never index into another document's tokens.
    words = document_text.split(' ')
    annotation = item['annotations'][0]

    # Long answer: a negative start_token marks "no long answer".
    long_span = annotation['long_answer']
    long_answer = ''
    if long_span['start_token'] >= 0:
        long_answer = strip_html(' '.join(words[long_span['start_token']:long_span['end_token']]))

    # Short answer(s): '|'-joined spans. Yes/no-only annotations carry no span and give ''
    # rather than re-using the short answer left over from an earlier example.
    short_answer = ''
    if annotation['short_answers']:
        spans = [' '.join(words[span['start_token']:span['end_token']])
                 for span in annotation['short_answers']]
        short_answer = strip_html('|'.join(spans))

    # Title: raw text of the first <H1>...</H1>, or '' when the document has none.
    if '<H1>' in document_text:
        title_start = document_text.index('<H1>')
        title_end = document_text.index('</H1>')
        title = document_text[title_start + 4:title_end]
    else:
        title = ''

    # Abstract: raw text of the first <P>...</P>; content starts right after that '</P>'.
    # Without a <P>, content starts at offset 0 instead of an offset from an earlier document.
    if '<P>' in document_text:
        abs_start = document_text.index('<P>')
        abs_end = document_text.index('</P>')
        abstract = document_text[abs_start + 3:abs_end]
        content_start = abs_end + len('</P>')
    else:
        abstract = ''
        content_start = 0

    # Content ends at the second-to-last '</Ul>' if present, else at the last '</Ul>',
    # else at the end of the document; everything after that point is dropped.
    last_ul = document_text.rfind('</Ul>')
    if last_ul == -1:
        content_end = len(document_text)
    else:
        prev_ul = document_text.rfind('</Ul>', 0, last_ul)
        content_end = prev_ul if prev_ul != -1 else last_ul
    content = re.sub(' +', ' ', strip_html(document_text[content_start:content_end]))

    return [
        item['question_text'],
        str(item['example_id']),
        long_answer,
        short_answer,
        title,
        abstract,
        content,
        title + abstract + content,
        'en',
    ]


nq_train = []
# "rb", not "r+": the archive is only read, and "r+" demands write access to it.
with gzip.open("v1.0-simplified_simplified-nq-train.jsonl.gz", "rb") as f:
    for item in jsonlines.Reader(f):
        nq_train.append(nq_train_row(item))

nq_train_df = pd.DataFrame(nq_train)
nq_train_df.to_csv(r"nq_train.tsv", sep="\t", mode='w', header=None, index=False)

In [None]:
## Mapping tool: a title normaliser built on the bert-base-uncased tokenizer.
## Presumably used so that dev titles (document_title) and train titles (<H1> text)
## agree when they are compared and used as dictionary keys.

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


def lower(x):
    """Return ``x`` after a tokenize -> token ids -> decode round trip through ``tokenizer``.

    For the uncased vocabulary this lower-cases the text and re-joins word pieces.
    Presumably fails on non-string input such as NaN read back from a TSV — TODO confirm.
    """
    token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(x))
    return tokenizer.decode(token_ids)

In [None]:
## doc_tac denotes the concatenation of title, abstract and content

# Column layout shared by nq_dev.tsv and nq_train.tsv (both written without a header).
NQ_TSV_COLUMNS = ['query', 'id', 'long_answer', 'short_answer', 'title',
                  'abstract', 'content', 'doc_tac', 'language']

nq_dev = pd.read_csv('nq_dev.tsv', names=NQ_TSV_COLUMNS, header=None, sep='\t')
nq_train = pd.read_csv('nq_train.tsv', names=NQ_TSV_COLUMNS, header=None, sep='\t')

# Normalise titles of both splits through the same tokenizer round trip.
for split_df in (nq_dev, nq_train):
    split_df['title'] = split_df['title'].map(lower)


In [None]:
## Concat train doc and validation doc to obtain full document collection.
## pd.concat replaces DataFrame.append (removed in pandas 2.0). Train rows come first.
## reset_index() turns the duplicated per-split index into an 'index' column and
## gives the frame a fresh 0..N-1 index.
nq_all_doc = pd.concat([nq_train, nq_dev]).reset_index()

In [None]:
## Remove duplicated documents based on (normalised) titles, keeping the first occurrence,
## so a document seen in both splits keeps its train row. reset_index(drop=True) gives a
## clean 0..N-1 index, which the mapping cell uses as the document id. It also keeps this cell
## safe to re-run, whereas repeated reset_index() calls eventually raise on an existing column.
nq_all_doc = nq_all_doc.drop_duplicates('title').reset_index(drop=True)

In [None]:
## Number of distinct documents after title de-duplication (the original run reported 109739).

nq_all_doc.shape[0]

In [None]:
## Construct mapping relations keyed by the positional document id (row number in nq_all_doc):
##   title_doc     : normalised title -> doc_tac
##   title_doc_id  : normalised title -> document id
##   id_doc        : document id -> doc_tac
##   ran_id_old_id : document id -> NQ example id of the row the document was taken from

title_doc, title_doc_id, id_doc, ran_id_old_id = {}, {}, {}, {}
rows = zip(nq_all_doc['title'], nq_all_doc['doc_tac'], nq_all_doc['id'])
for doc_id, (title, doc_tac, example_id) in enumerate(rows):
    title_doc[title] = doc_tac
    title_doc_id[title] = doc_id
    id_doc[doc_id] = doc_tac
    ran_id_old_id[doc_id] = example_id

In [None]:
## Construct Document Content File: NQ_doc_content.tsv, one line per document with the
## 7 tab-separated fields docid, '', '', doc_tac, '', '', 'en'. The doc-augmentation cell
## later unpacks exactly these 7 fields. The `with` block closes the file, flushing everything
## to disk before the shell scripts below read it.
with open("NQ_doc_content.tsv", 'w') as doc_content_file:
    for docid, doc_text in id_doc.items():
        doc_content_file.write('\t'.join([str(docid), '', '', doc_text, '', '', 'en']) + '\n')

## Generate BERT embeddings for each document

In [None]:
## Execute the following command to get bert embedding pkl file
## Use 4 GPUs. The argument is presumably the GPU/shard count: the next cell reads shards 0-3
## (bert/pkl/NQ_output_tensor_512_content_{0..3}.pkl plus the matching *_id.pkl files).
## NOTE(review): the script presumably reads NQ_doc_content.tsv — confirm in bert_NQ.sh.
!./bert/bert_NQ.sh 4

In [None]:
## Concat bert embedding: merge the per-GPU shards written by bert_NQ.sh and dump them to
## bert/NQ_doc_content_embedding_bert_512.tsv as: docid, 5 filler fields ('' x5 with 'en'
## as the 6th), then the embedding values joined by '|'.
## NOTE(review): despite the "_nq_qg" suffix, these lists hold document embeddings and ids here.
NUM_BERT_SHARDS = 4  # must match the GPU count passed to bert_NQ.sh

output_bert_base_tensor_nq_qg = []
output_bert_base_id_tensor_nq_qg = []
for num in trange(NUM_BERT_SHARDS):
    # Locally generated pickles only: pickle.load can execute arbitrary code.
    with open(f'bert/pkl/NQ_output_tensor_512_content_{num}.pkl', 'rb') as f:
        output_bert_base_tensor_nq_qg.extend(pickle.load(f))
    with open(f'bert/pkl/NQ_output_tensor_512_content_{num}_id.pkl', 'rb') as f:
        output_bert_base_id_tensor_nq_qg.extend(pickle.load(f))

# Each embedding row must pair with exactly one document id; a mismatch means a shard is out of sync.
assert len(output_bert_base_tensor_nq_qg) == len(output_bert_base_id_tensor_nq_qg), (
    len(output_bert_base_tensor_nq_qg), len(output_bert_base_id_tensor_nq_qg))

with open("bert/NQ_doc_content_embedding_bert_512.tsv", 'w') as embedding_file:
    for doc_id, doc_tensor in zip(output_bert_base_id_tensor_nq_qg, output_bert_base_tensor_nq_qg):
        # assumes doc_tensor iterates to plain numbers whose str() is the desired text — TODO confirm
        embedding = '|'.join(str(elem) for elem in doc_tensor)
        embedding_file.write('\t'.join([str(doc_id), '', '', '', '', '', 'en', embedding]) + '\n')

## Apply hierarchical k-means to the BERT document embeddings to generate semantic IDs

In [None]:
## Execute the following command to get kmeans id of the documents
## The next cell loads its result from kmeans/IDMapping_NQ_bert_512_k30_c30_seed_7.pkl
## (k30_c30 / seed_7 presumably encode the clustering settings — confirm in kmeans_NQ.sh).
!./kmeans/kmeans_NQ.sh

In [None]:
# Locally generated pickle only: pickle.load can execute arbitrary code.
with open('kmeans/IDMapping_NQ_bert_512_k30_c30_seed_7.pkl', 'rb') as f:
    kmeans_nq_doc_dict = pickle.load(f)

## random id : newid
# The semantic id is the sequence of values from the pickle joined with '-', keyed by str(docid).
new_kmeans_nq_doc_dict_512 = {
    str(old_docid): '-'.join(str(elem) for elem in cluster_path)
    for old_docid, cluster_path in kmeans_nq_doc_dict.items()
}

# Same mapping keyed by int document id, for use with pandas .map on integer id columns.
new_kmeans_nq_doc_dict_512_int_key = {
    int(doc_key): semantic_id for doc_key, semantic_id in new_kmeans_nq_doc_dict_512.items()
}

## Query Generation

In [None]:
## Execute the following command to generate queries for the documents
## Use 4 GPUs. The next cell merges the 4 resulting shards
## qg/pkl/NQ_output_tensor_512_content_64_15_{0..3}.pkl (plus the matching *_id.pkl files).
!./qg/NQ_qg.sh 4

In [None]:
## merge parallel results of query generation.
## NOTE(review): despite the names, output_bert_base_tensor_nq_qg now holds generated queries and
## output_bert_base_id_tensor_nq_qg the ids of the documents they were generated for.
output_bert_base_tensor_nq_qg = []
output_bert_base_id_tensor_nq_qg = []
for shard in trange(4):
    shard_files = (
        (f'qg/pkl/NQ_output_tensor_512_content_64_15_{shard}.pkl', output_bert_base_tensor_nq_qg),
        (f'qg/pkl/NQ_output_tensor_512_content_64_15_{shard}_id.pkl', output_bert_base_id_tensor_nq_qg),
    )
    for path, sink in shard_files:
        with open(path, 'rb') as shard_file:
            sink.extend(pickle.load(shard_file))

In [None]:
# Group generated queries by the document id they were generated for, preserving shard order.
qg_dict = {}
for pos in trange(len(output_bert_base_tensor_nq_qg)):
    doc_key = output_bert_base_id_tensor_nq_qg[pos]
    qg_dict.setdefault(doc_key, []).append(output_bert_base_tensor_nq_qg[pos])

## Generate training data

In [None]:
## Maximum number of generated queries kept per document when writing NQ_512_qg.tsv
## (an older note called the output "nq_512_qg20.tsv", but the value used here is 15).
QG_NUM = 15

In [None]:
# NQ_512_qg.tsv: up to QG_NUM generated queries per document, one per line:
#   query \t NQ example id \t document id \t semantic (k-means) id
# The `with` block guarantees the file is flushed and closed.
with open("NQ_512_qg.tsv", 'w') as qg_file:
    for docid, queries in tqdm(qg_dict.items()):
        # Per-document lookups hoisted out of the per-query loop.
        example_id = str(ran_id_old_id[int(docid)])
        semantic_id = new_kmeans_nq_doc_dict_512[docid]
        for query in queries[:QG_NUM]:
            qg_file.write('\t'.join([query, example_id, docid, semantic_id]) + '\n')

In [None]:
# Rebuild the int-keyed view of new_kmeans_nq_doc_dict_512. This is the same construction as
# right after the k-means pickle is loaded, presumably repeated so this cell can be re-run on its own.
new_kmeans_nq_doc_dict_512_int_key = {
    int(doc_key): semantic_id for doc_key, semantic_id in new_kmeans_nq_doc_dict_512.items()
}

In [None]:
## Replace Original IDs with Semantic IDs

In [None]:
## nq_train_doc_newid.tsv

In [None]:
# Attach to every training query its document id (looked up by normalised title) and that
# document's semantic id. Both columns are added to nq_train itself, which later cells reuse.
train_doc_ids = nq_train['title'].map(title_doc_id)
nq_train['randomid'] = train_doc_ids
nq_train['id_512'] = train_doc_ids.map(new_kmeans_nq_doc_dict_512_int_key)

# query \t NQ example id \t document id \t semantic id
nq_train_ = nq_train[['query', 'id', 'randomid', 'id_512']]
nq_train_.to_csv('nq_train_doc_newid.tsv', sep='\t', header=None, index=False, encoding='utf-8')

In [None]:
## nq_dev_doc_newid.tsv

In [None]:
# Attach to every dev query its document id (looked up by normalised title) and that
# document's semantic id. Both columns are added to nq_dev itself, which later cells reuse.
dev_doc_ids = nq_dev['title'].map(title_doc_id)
nq_dev['randomid'] = dev_doc_ids
nq_dev['id_512'] = dev_doc_ids.map(new_kmeans_nq_doc_dict_512_int_key)

# query \t NQ example id \t document id \t semantic id
nq_dev_ = nq_dev[['query', 'id', 'randomid', 'id_512']]
nq_dev_.to_csv('nq_dev_doc_newid.tsv', sep='\t', header=None, index=False, encoding='utf-8')

In [None]:
## title+abs oldid newid

In [None]:
## nq_title_abs.tsv

In [None]:
# nq_title_abs.tsv: "title abstract" text per example with its NQ example id, document id and
# semantic id. pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
# NOTE(review): despite the variable name, nothing is de-duplicated here. Every train and dev
# example contributes a row, so a document shared by several queries appears several times.
# Confirm this is intended.
nq_all_doc_non_duplicate = pd.concat([nq_train, nq_dev]).reset_index()

nq_all_doc_non_duplicate['id_512'] = nq_all_doc_non_duplicate['randomid'].map(new_kmeans_nq_doc_dict_512_int_key)

# Documents without a <P> have an empty abstract that read_csv loads as NaN. fillna('') keeps
# their title instead of turning the whole 'ta' field into NaN (written as an empty field).
nq_all_doc_non_duplicate['ta'] = (
    nq_all_doc_non_duplicate['title'] + ' ' + nq_all_doc_non_duplicate['abstract'].fillna('')
)

nq_all_doc_non_duplicate = nq_all_doc_non_duplicate.loc[:, ['ta', 'id', 'randomid', 'id_512']]
nq_all_doc_non_duplicate.to_csv('nq_title_abs.tsv', sep='\t', header=None, index=False, encoding='utf-8')

In [None]:
## all_doc_aug_query

In [None]:
## NQ_doc_aug.tsv

In [None]:
# NQ_doc_aug.tsv: for every document in NQ_doc_content.tsv, sample random windows of its
# content as pseudo-queries, tagged with the ids recorded for it in nq_title_abs.tsv.
# Output line: window text \t NQ example id \t document id \t semantic id
DOC_AUG_WINDOW = 64              # max tokens per sampled window
DOC_AUG_BASE_SAMPLES = 10        # windows sampled for every document
DOC_AUG_TOKENS_PER_EXTRA = 3000  # +1 window per full 3000 tokens beyond the first 3000
DOC_AUG_SEED = None              # set an int to make NQ_doc_aug.tsv reproducible

queryid_oldid_dict = {}  # document id (str) -> NQ example id (str)
bertid_oldid_dict = {}   # document id (str) -> semantic id (str)
map_file = "./nq_title_abs.tsv"
# nq_title_abs.tsv is produced by DataFrame.to_csv (utf-8, csv quoting), so parse it with csv
# instead of a bare split. The parsed fields carry no trailing newline.
with open(map_file, 'r', encoding='utf-8', newline='') as f:
    for _title_abs, queryid, oldid, bert_k30_c30 in csv.reader(f, delimiter='\t'):
        queryid_oldid_dict[oldid] = queryid
        bertid_oldid_dict[oldid] = bert_k30_c30

rng = random.Random(DOC_AUG_SEED)
train_file = "./NQ_doc_content.tsv"
with open(train_file, 'r') as f, open("./NQ_doc_aug.tsv", 'w') as doc_aug_file:
    for line in f:
        # NQ_doc_content.tsv is written with a plain '\t'.join (no csv quoting), so split it plainly.
        docid, _, _, content, _, _, _ = line.split("\t")
        words = content.split(' ')
        extra = max(0, len(words) - DOC_AUG_TOKENS_PER_EXTRA) // DOC_AUG_TOKENS_PER_EXTRA
        queryid = queryid_oldid_dict[docid]
        bert_k30_c30 = bertid_oldid_dict[docid]
        for _ in range(DOC_AUG_BASE_SAMPLES + extra):
            begin = rng.randrange(0, len(words))
            # Slicing clamps at the end, so windows near the end are simply shorter.
            doc_aug = ' '.join(words[begin:begin + DOC_AUG_WINDOW])
            doc_aug_file.write('\t'.join([doc_aug, str(queryid), str(docid), str(bert_k30_c30)]) + '\n')