In [1]:
import os
import json
import pandas as pd
import pickle
from transformers import BertTokenizer, BertModel
import numpy as np
from tqdm import trange, tqdm
import random

  from .autonotebook import tqdm as notebook_tqdm


###### Download TriviaQA dataset from https://nlp.cs.washington.edu/triviaqa/data/triviaqa-rc.tar.gz

In [None]:
evidence_dir = os.path.join("evidence/wikipedia")
qa_train_path = os.path.join("qa/wikipedia-train.json")
qa_dev_path = os.path.join("qa/wikipedia-dev.json")
qa_test_path = os.path.join("qa/wikipedia-test-without-answers.json")

def txt2title(text):
    return os.path.splitext(text)[0].replace('_', ' ')

full_doc = {}
global_index = 0
for file_name in os.listdir(evidence_dir):
    if file_name.endswith(".txt"):
        title = txt2title(file_name)
        assert title not in full_doc, f"dup title for {file_name}"
        with open(os.path.join(evidence_dir, file_name)) as f:
            body = f.read().replace('\n', ' ').replace('\t', ' ').replace('\r', ' ')[:10240]
        full_doc[file_name] = (str(global_index), title, body)
        global_index += 1
lines = ['\t'.join(v) + '\n' for v in full_doc.values()]
with open("trivia_qa_fulldoc.csv", "w") as f:
    f.writelines(lines)


for in_file, out_file in zip((qa_train_path, qa_dev_path, qa_test_path), ("trivia_qa_train.csv", "trivia_qa_dev.csv", "trivia_qa_test.csv")):
    with open(in_file) as f:
        raw = json.load(f)

    lines = []
    for item in raw["Data"]:
        query = item["Question"]
        doc_ids = [full_doc[doc["Filename"]][0] for doc in item["EntityPages"]]
        lines.append(query + "\t" + ','.join(doc_ids) + "\n")

    with open(out_file, "w") as f:
        f.writelines(lines)

In [20]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def lower(x):
    try:
        text = tokenizer.tokenize(x)[:512]
        id_ = tokenizer.convert_tokens_to_ids(text)
        return tokenizer.decode(id_)
    except:
        return x

In [2]:
df_train = pd.read_csv('trivia_qa_train.csv',
                        names=["query", "docid"],
                        encoding='utf-8', header=None, sep='\t')
df_train

Unnamed: 0,query,docid
0,Where in England was Dame Judi Dench born?,2091534792
1,From which country did Angola achieve independ...,4630544684470
2,Which city does David Soul come from?,17309
3,Who won Super Bowl XX?,62618
4,Which was the first European country to abolis...,2148911677
...,...,...
61883,"For a point each, name the 5 countries surroun...",35341
61884,"March 9, 1959 saw the introduction of what Mat...",43025223636991
61885,"On January 16, 2001, President Bill Clinton aw...",527304329366965
61886,"July 27, 1940 saw the introduction of what bel...",2069


In [9]:
df_val = pd.read_csv('trivia_qa_dev.csv',
                        names=["query", "docid"],
                        encoding='utf-8', header=None, sep='\t').loc[:, ['query', 'docid']]
df_val

Unnamed: 0,query,docid
0,Which Lloyd Webber musical premiered in the US...,4290
1,Who was the next British Prime Minister after ...,528375562
2,Who had a 70s No 1 hit with Kiss You All Over?,35959
3,What claimed the life of singer Kathleen Ferrier?,35295
4,Which actress was voted Miss Greenwich Village...,37010
...,...,...
7988,Whose backing band is known as The Miami Sound...,43835
7989,"With a motto of Always Ready, Always There, wh...",46365
7990,Who tried to steal Christmas from the town of ...,72237
7991,What is the name of the parson mentioned in th...,7276124534


In [8]:
df_test = pd.read_csv('trivia_qa_test.csv',
                        names=["query", "docid"],
                        encoding='utf-8', header=None, sep='\t').loc[:, ['query', 'docid']]
df_test

Unnamed: 0,query,docid
0,Asmara international airport is in which country?,5749
1,At whose concert were 11 people trampled to de...,4829266780
2,Andy Warhol/'s 60s exhibition featured cans of...,4371
3,San Giusto international airport is in which c...,51591
4,Who had a 60s No 1 with Travelin' Man?,45668539
...,...,...
7696,"December 15, 1773 saw The Sons of Liberty boar...",57457
7697,The Decepticons are the mortal enemies of whom?,17567
7698,"Next appearing locally in 2013, the Richard Wa...",55255170805945927583
7699,"In South Park, what alter ego does Butters ado...",10967


In [17]:
df_full = pd.read_csv('trivia_qa_fulldoc.csv',
                         names=["docid", "title", "content"],
                        encoding='utf-8', header=None, sep='\t')
df_full

Unnamed: 0,docid,title,content
0,0,Crocodile Dundee,Crocodile Dundee is a 1986 Australian comedy f...
1,1,Heroes (David Bowie song),"Heroes"""" is a song by English musician David B..."
2,2,Weird Al Yankovic,"Alfred Matthew ""Weird Al"" Yankovic ( ; born Oc..."
3,3,Wild Bill Hickok,"James Butler ""Wild Bill"" Hickok (May 27, 1837 ..."
4,4,'Allo 'Allo!,Allo Allo! is a BBC television British sitcom ...
...,...,...,...
73965,73965,Škoda Roomster,The Škoda Roomster (Type 5J) is a MPV-styled l...
73966,73966,Škoda Superb,The current Škoda Superb is a large family car...
73967,73967,Škoda Yeti,"The Škoda Yeti (codenamed Typ 5L)ETKA, accesse..."
73968,73968,’O sole mio,is a globally known Neapolitan song written i...


In [21]:
df_full['title'] = df_full['title'].map(lower)
df_drop_title = df_full.drop_duplicates('title').reset_index(drop=True) 

In [22]:
df_drop_title

Unnamed: 0,docid,title,content
0,0,crocodile dundee,Crocodile Dundee is a 1986 Australian comedy f...
1,1,heroes ( david bowie song ),"Heroes"""" is a song by English musician David B..."
2,2,weird al yankovic,"Alfred Matthew ""Weird Al"" Yankovic ( ; born Oc..."
3,3,wild bill hickok,"James Butler ""Wild Bill"" Hickok (May 27, 1837 ..."
4,4,' allo'allo!,Allo Allo! is a BBC television British sitcom ...
...,...,...,...
73912,73965,skoda roomster,The Škoda Roomster (Type 5J) is a MPV-styled l...
73913,73966,skoda superb,The current Škoda Superb is a large family car...
73914,73967,skoda yeti,"The Škoda Yeti (codenamed Typ 5L)ETKA, accesse..."
73915,73968,’ o sole mio,is a globally known Neapolitan song written i...


In [23]:
title_doc_id = {}
for i in trange(len(df_drop_title)):
    title_doc_id[df_drop_title['title'][i]] = i

origin_new_id = {}
for i in trange(len(df_full)):
    origin_new_id[df_full['docid'][i]] = title_doc_id[df_full['title'][i]]

100%|██████████| 73917/73917 [00:00<00:00, 236845.20it/s]
100%|██████████| 73970/73970 [00:00<00:00, 119021.15it/s]


In [None]:
## doc pool

In [59]:
df_drop_title

Unnamed: 0,docid,title,content
0,0,crocodile dundee,Crocodile Dundee is a 1986 Australian comedy f...
1,1,heroes ( david bowie song ),"Heroes"""" is a song by English musician David B..."
2,2,weird al yankovic,"Alfred Matthew ""Weird Al"" Yankovic ( ; born Oc..."
3,3,wild bill hickok,"James Butler ""Wild Bill"" Hickok (May 27, 1837 ..."
4,4,' allo'allo!,Allo Allo! is a BBC television British sitcom ...
...,...,...,...
73912,73965,skoda roomster,The Škoda Roomster (Type 5J) is a MPV-styled l...
73913,73966,skoda superb,The current Škoda Superb is a large family car...
73914,73967,skoda yeti,"The Škoda Yeti (codenamed Typ 5L)ETKA, accesse..."
73915,73968,’ o sole mio,is a globally known Neapolitan song written i...


In [60]:
file_pool = open("Trivia_doc_content.tsv", 'w') 

for i in trange(len(df_drop_title)):
    file_pool.write('\t'.join([str(df_drop_title['docid'][i]), str(origin_new_id[df_drop_title['docid'][i]]), str(df_drop_title['title'][i]), str(df_drop_title['content'][i]), str(df_drop_title['title'][i]) + str(df_drop_title['content'][i])]) + '\n')
    file_pool.flush()

100%|██████████| 73917/73917 [00:11<00:00, 6685.29it/s]


## Generate BERT embeddings for each document

In [1]:
GPU_NUM = 8

In [59]:
## Execute the following command to get bert embedding pkl file
## NEED 8 GPU
!./bert/Trivia_bert.sh

In [23]:
output_bert_base_tensor = []
output_bert_base_id_tensor = []
for num in trange(GPU_NUM):
    with open(f'bert/pkl/Trivia_outpt_tensor_512_content_{num}.pkl', 'rb') as f:
        data = pickle.load(f)
    f.close()
    output_bert_base_tensor.extend(data)

    with open(f'bert/pkl/Trivia_outpt_tensor_512_content_{num}_id.pkl', 'rb') as f:
        data = pickle.load(f)
    f.close()
    output_bert_base_id_tensor.extend(data)


train_file = open(f"bert/Trivia_doc_content_embedding_bert_512.tsv", 'w') 

for idx, doc_tensor in enumerate(output_bert_base_tensor):
    embedding = '|'.join([str(elem) for elem in doc_tensor])
    train_file.write('\t'.join([str(output_bert_base_id_tensor[idx]), '', '', '', '', '', 'en', embedding]) + '\n')
    train_file.flush()

100%|██████████| 8/8 [00:08<00:00,  1.02s/it]


## Apply Hierarchical K-Means on it to generate semantic IDs

In [24]:
## Execute the following command to get kmeans id of the documents
## NEED 8 GPU
!./kmeans/kmeans_Trivia.sh

In [6]:
with open('kmeans/IDMapping_Trivia_bert_512_k30_c30_seed_7.pkl', 'rb') as f:
    kmeans_trivia_doc_dict = pickle.load(f)
## random id : newid
new_kmeans_trivia_doc_dict_512 = {}
for old_docid in kmeans_trivia_doc_dict.keys():
    new_kmeans_trivia_doc_dict_512[str(old_docid)] = '-'.join(str(elem) for elem in kmeans_trivia_doc_dict[old_docid])

new_kmeans_trivia_doc_dict_512_int_key = {}
for key in new_kmeans_trivia_doc_dict_512:
    new_kmeans_trivia_doc_dict_512_int_key[int(key)] = new_kmeans_trivia_doc_dict_512[key]

## Query Generation

#### Download docT5query to './qg/' from https://huggingface.co/castorini/doc2query-t5-base-msmarco, like './qg/doc2query-t5-base-msmarco/'

In [None]:
## Execute the following command to generate queries for the documents
## NEED 8 GPU
!./qg/Trivia_qg.sh

In [28]:
## merge parallel results
output_bert_base_tensor_qg = []
output_bert_base_id_tensor_qg = []
for num in trange(GPU_NUM):
    with open(f'qg/pkl/Trivia_outpt_tensor_512_content_64_15_{num}.pkl', 'rb') as f:
        data = pickle.load(f)
    f.close()
    output_bert_base_tensor_qg.extend(data)

    with open(f'qg/pkl/Trivia_outpt_tensor_512_content_64_15_{num}_id.pkl', 'rb') as f:
        data = pickle.load(f)
    f.close()
    output_bert_base_id_tensor_qg.extend(data)

100%|██████████| 8/8 [00:00<00:00, 48.45it/s]


In [29]:
qg_dict = {}
for i in trange(len(output_bert_base_tensor_qg)):
    if(output_bert_base_id_tensor_qg[i] not in qg_dict):
        qg_dict[output_bert_base_id_tensor_qg[i]] = [output_bert_base_tensor_qg[i]]
    else:
        qg_dict[output_bert_base_id_tensor_qg[i]].append(output_bert_base_tensor_qg[i])

100%|██████████| 1108755/1108755 [00:00<00:00, 2191365.98it/s]


## Genarate training data

In [31]:
train_query_docid = {}
for i in trange(len(df_train)):
    if(len(df_train['query'][i].split('\n')) == 1):
        train_query_docid[df_train['query'][i]] = [int(elem) for elem in df_train['docid'][i].split(',')]

file_train = open("train.tsv", 'w')

count = 0
for query in tqdm(train_query_docid.keys()):
    for i in range(len(train_query_docid[query])):
        id_ori = train_query_docid[query][i]
        new_id = origin_new_id[id_ori]
        file_train.write('\t'.join([query, str(id_ori), str(new_id), new_kmeans_trivia_doc_dict_512_int_key[int(new_id)]]) + '\n')
        file_train.flush()

100%|██████████| 61888/61888 [00:01<00:00, 37774.79it/s]
100%|██████████| 61888/61888 [00:05<00:00, 11848.73it/s]


In [11]:
val_query_docid[df_val['query'][i]]

[30198]

In [30]:
val_query_docid = {}
for i in trange(len(df_val)):
    if(len(df_val['query'][i].split('\n')) == 1):
        val_query_docid[df_val['query'][i]] = [int(elem) for elem in df_val['docid'][i].split(',')]

file_val = open("dev.tsv", 'w')

count = 0
for query in tqdm(val_query_docid.keys()):
    id_ori_ = []
    new_id_ = []
    kmeans_ = []
    for i in range(len(val_query_docid[query])):
        id_ori = str(val_query_docid[query][i])
        new_id = str(origin_new_id[int(id_ori)])
        id_ori_.append(id_ori)
        new_id_.append(new_id)
        kmeans_.append(new_kmeans_trivia_doc_dict_512_int_key[int(new_id)])
    
    id_ori_ = ','.join(id_ori_)
    new_id_ = ','.join(new_id_)
    kmeans_ = ','.join(kmeans_)
    
    file_val.write('\t'.join([query, str(id_ori_), str(new_id_), kmeans_]) + '\n')
    file_val.flush()

100%|██████████| 7993/7993 [00:00<00:00, 74024.43it/s]
100%|██████████| 7993/7993 [00:00<00:00, 20740.57it/s]


In [29]:
kmeans_

'6-29-3-1'

In [31]:
test_query_docid = {}
for i in trange(len(df_test)):
    if(len(df_test['query'][i].split('\n')) == 1):
        test_query_docid[df_val['query'][i]] = [int(elem) for elem in df_test['docid'][i].split(',')]

file_test = open("test.tsv", 'w')

count = 0
for query in tqdm(test_query_docid.keys()):
    id_ori_ = []
    new_id_ = []
    kmeans_ = []
    for i in range(len(test_query_docid[query])):
        id_ori = str(test_query_docid[query][i])
        new_id = str(origin_new_id[int(id_ori)])
        id_ori_.append(id_ori)
        new_id_.append(new_id)
        kmeans_.append(new_kmeans_trivia_doc_dict_512_int_key[int(new_id)])
    
    id_ori_ = ','.join(id_ori_)
    new_id_ = ','.join(new_id_)
    kmeans_ = ','.join(kmeans_)
    
    file_test.write('\t'.join([query, str(id_ori_), str(new_id_), kmeans_]) + '\n')
    file_test.flush()

100%|██████████| 7701/7701 [00:00<00:00, 75884.73it/s]
100%|██████████| 7701/7701 [00:00<00:00, 22174.56it/s]


In [None]:
QG_NUM = 15

In [37]:
qg_file = open("trivia_512_qg.tsv", 'w') 

for queryid in tqdm(qg_dict):
    for query in qg_dict[queryid][:QG_NUM]:
        qg_file.write('\t'.join([query, queryid, new_kmeans_trivia_doc_dict_512_int_key[int(queryid)]]) + '\n')
        qg_file.flush()

100%|██████████| 73917/73917 [00:47<00:00, 1549.99it/s]


In [52]:
df_drop_title['new_id'] = df_drop_title['docid'].map(origin_new_id)

df_drop_title['kmeas_id'] = df_drop_title['new_id'].map(new_kmeans_trivia_doc_dict_512_int_key)


df_drop_title['tc'] = df_drop_title['title'] + ' ' + df_drop_title['content']

df_drop_title_ = df_drop_title.loc[:, ['tc', 'docid', 'new_id', 'kmeas_id']]  

df_drop_title_.to_csv('trivia_title_cont.tsv', sep='\t', header=None, index=False, encoding='utf-8')

In [53]:
df_drop_title

Unnamed: 0,tc,docid,new_id,kmeas_id
0,crocodile dundee Crocodile Dundee is a 1986 Au...,0,0,11-23-23
1,"heroes ( david bowie song ) Heroes"""" is a song...",1,1,29-6-25-0
2,"weird al yankovic Alfred Matthew ""Weird Al"" Ya...",2,2,24-12-4-0
3,"wild bill hickok James Butler ""Wild Bill"" Hick...",3,3,3-6-21-0
4,' allo'allo! Allo Allo! is a BBC television Br...,4,4,26-19-15-0
...,...,...,...,...
73912,skoda roomster The Škoda Roomster (Type 5J) is...,73965,73912,13-20-10-2
73913,skoda superb The current Škoda Superb is a lar...,73966,73913,13-20-4-7
73914,skoda yeti The Škoda Yeti (codenamed Typ 5L)ET...,73967,73914,13-20-10-3
73915,’ o sole mio is a globally known Neapolitan s...,73968,73915,29-12-2-8


In [63]:
queryid_oldid_dict = {}
bertid_oldid_dict = {}
map_file = "trivia_title_cont.tsv"
with open(map_file, 'r') as f:
    for line in f.readlines():
        query, queryid, oldid, bert_k30_c30 = line.split("\t")
        queryid_oldid_dict[oldid] = queryid
        bertid_oldid_dict[oldid] = bert_k30_c30

train_file = "Trivia_doc_content.tsv"
doc_aug_file = open("trivia_doc_aug.tsv", 'w') 
with open(train_file, 'r') as f:
    for line in f.readlines():
        _, docid, _, _, content = line.split("\t")
        content = content.split(' ')
        add_num = max(0, len(content)-3000) / 3000
        for i in range(10+int(add_num)):
            begin = random.randrange(0, len(content))
            # if begin >= (len(content)-64):
            #     begin = max(0, len(content)-64)
            end = begin + 64 if len(content) > begin + 64 else len(content)
            doc_aug = content[begin:end]
            doc_aug = ' '.join(doc_aug)
            queryid = queryid_oldid_dict[docid]
            bert_k30_c30 = bertid_oldid_dict[docid]
            # doc_aug_file.write('\t'.join([doc_aug, str(queryid), str(docid), str(bert_k30_c30)]) + '\n')
            doc_aug_file.write('\t'.join([doc_aug, str(queryid), str(docid), str(bert_k30_c30)]))
            doc_aug_file.flush()