# Document Retrieval using Wikipedia Dataset

**General Approach:** Implemented 3 Bag-of-Words- and Embeddings-based models and picked the best for final retrieval. Drew reference from the framework provided here: https://medium.com/mlearning-ai/enhancing-information-retrieval-via-semantic-and-relevance-matching-64973ff81818

## 1. Pre-processing
### 1.1 Initialization

In [9]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
from datasets import load_dataset
ds = load_dataset('wikipedia','20220301.simple',trust_remote_code=True)

In [91]:
import os
import pandas as pd
import numpy as np
from functools import partial
from tqdm import tqdm
from sklearn.model_selection import KFold, train_test_split
from simpletransformers.retrieval import RetrievalModel, RetrievalArgs
import torch
import warnings

path = r'C:\Users\yongz\Personal Projects\ahrefs_task'
os.chdir(path)
tqdm = partial(tqdm,position=0,leave=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


### 1.2 Data Processing

In [12]:
train = pd.read_json('train.jsonl',lines=True)
doc_base = pd.DataFrame(ds['train'][:10000])
train = train.rename(columns={'article':'title'})
df = train.merge(doc_base[['title','text']],how='left',on='title')
df.head()

Unnamed: 0,question,points,title,text
0,how do living organisms in a natural environme...,57,Environment,Environment means anything that surrounds us. ...
1,what is the name of the poem written by juan r...,72,Marcha Real,La Marcha Real (English translation: The Royal...
2,what body parts can mri scans study?,45,Magnetic resonance imaging,"Magnetic resonance imaging (MRI), or nuclear m..."
3,what are the names of the 12 boroughs in berli...,46,Boroughs of Berlin,The German capital Berlin is subdivided into 1...
4,what was the cause of charles dickens' death?,87,Heinrich Rudolf Hertz,Heinrich Rudolf Hertz (22 February 1857 – 1 Ja...


In [13]:
doc_base = doc_base.rename(columns={'text':'gold_passage'})
doc_base.head()

Unnamed: 0,id,url,title,gold_passage
0,1,https://simple.wikipedia.org/wiki/April,April,April is the fourth month of the year in the J...
1,2,https://simple.wikipedia.org/wiki/August,August,August (Aug.) is the eighth month of the year ...
2,6,https://simple.wikipedia.org/wiki/Art,Art,Art is a creative activity that expresses imag...
3,8,https://simple.wikipedia.org/wiki/A,A,A or a is the first letter of the English alph...
4,9,https://simple.wikipedia.org/wiki/Air,Air,Air refers to the Earth's atmosphere. Air is a...


In [14]:
# Test for NA values -> impute if necessary
def NA_test(df):
    for c in df.columns:
        if df[c].isnull().values.any():
            print('null',c)
    print('test done')
NA_test(df)

test done


In [64]:
# Utils
n_docs = 10

def gen_splits(df):
    X, X_test = train_test_split(df,test_size=0.2,random_state=27)
    X_train, X_val = train_test_split(df,test_size=0.2,random_state=27)
    return X_train, X_val, X_test

'''
For model evaluation functions, use EITHER ONE of the functions below
'''

def compute_simple_accuracy(query_df,predicted_passages):
    '''
    query_df: DataFrame['query_text','title','gold_passage']
    predicted_passages: List[List[str]] (query x top k docs per query)
    '''
    query_df['retrieved_passage'] = [docs[0] for docs in predicted_passages]
    query_df['correct'] = query_df.apply(lambda row: row.gold_passage == row.retrieved_passage,axis=1)
    accuracy = len(query_df[query_df['correct']==True])/len(query_df)
    return accuracy

def compute_performance(query_df,predicted_passages,top_k=n_docs):
    '''
    n_docs should be >= 10
    Metrics covered:
    - accuracy: top 1, 3, 5, 10
    - mrr: (1/N)Sum(1/rank_i)
    '''
    query_df['retrieved_passages'] = predicted_passages
    def compute_rank(row):
        if row.gold_passage in row.retrieved_passages:
            return row.retrieved_passages.index(row.gold_passage) + 1
        else: return 0
    query_df['rank'] = query_df.apply(lambda row:compute_rank(row),axis=1)
#     print(query_df['rank'].head(20))
    for n in (1,3,5,10):
        query_df[f'top_{n}'] = query_df.apply(lambda row:row['rank']>0 and row['rank']<=n,axis=1)
#     print(query_df['top_3'].head(20),query_df['top_10'].head(20))
    query_df['mrr_component'] = query_df.apply(lambda row:1/row['rank'] if row['rank']>0 else 0,axis=1)
    top_1_accuracy = query_df['top_1'].sum()/len(query_df)
    top_3_accuracy = query_df['top_3'].sum()/len(query_df)
    top_5_accuracy = query_df['top_5'].sum()/len(query_df)
    top_10_accuracy = query_df['top_10'].sum()/len(query_df)
    mrr = query_df['mrr_component'].sum()/len(query_df)
    results_dict = {
        'top 1 accuracy': top_1_accuracy,
        'top 3 accuracy': top_3_accuracy,
        'top 5 accuracy': top_5_accuracy,
        'top 10 accuracy': top_10_accuracy,
        'mrr': mrr
    }
    return results_dict

In [16]:
# Final preparation of training, validation and test sets
'''
doc_base columns: id (not used), url, title, gold_passage
data split columns: query_text, title, gold_passage
'''
df1 = df.rename(columns={'question':'query_text','text':'gold_passage'})
data = df1[['query_text','title','gold_passage']]
x_train, x_val, x_test = gen_splits(data)
x_train = x_train.reset_index(drop=True)
x_val = x_val.reset_index(drop=True)
x_test = x_test.reset_index(drop=True)
queries_train = x_train['query_text'].tolist()
queries_test = x_test['query_text'].tolist()
x_train.head()

Unnamed: 0,query_text,title,gold_passage
0,what is the title of the french nobility that ...,Cardinal Richelieu,"Armand Jean du Plessis, better known as Cardin..."
1,what is the therapeutic index of warfarin?,Pharmacology,Pharmacology is the study of how medicine and ...
2,what is the top-selling 2010s automobile in th...,Fiat Ulysse,The Fiat Ulysse was a large car with seven sea...
3,when did pope innocent i succeed pope anastasi...,401,401 (CDI) was a common year starting on Tuesda...
4,what is the name of euterpe's son according to...,Muse,For the British musical group of the same name...


## 2. Model Training & Evaluation
**NOTE:** All models to be evaluated using test set (i.e. x_test)

### 2.1 TF-IDF Retrieval (Baseline)

In [49]:
from langchain.schema import Document
from langchain_core.retrievers import BaseRetriever
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from typing import Any, Dict, Iterable, List, Optional

In [92]:
n_docs_tfidf = n_docs

class Custom_TFIDFRetriever(BaseRetriever):
    vectorizer: Any
    tfidf_array: Any
    docs: Any
    
    class Config:
        arbitary_types_allowed = True
    
    @classmethod
    def from_documents(cls,
        documents: Iterable[Document],
        tfidf_params: Optional[Dict[str, Any]] = None,
    ):
        vectorizer = TfidfVectorizer()
        docs = [doc.page_content for doc in documents]
        tfidf_array = vectorizer.fit_transform(docs)
        return cls(vectorizer=vectorizer,docs=docs,tfidf_array=tfidf_array)
    
    def get_relevant_documents(
        self,
        queries: Iterable[str],
        n_docs = n_docs_tfidf
    ):
        return_docs = []
        for i in tqdm(queries):
            query_vec = self.vectorizer.transform([i])
            results = cosine_similarity(self.tfidf_array,query_vec).reshape((-1,))
            return_docs_i = [self.docs[i] for i in results.argsort()[-n_docs:][::-1]]
            return_docs.append(return_docs_i)
        return return_docs        

In [57]:
queries_test = x_test['query_text'].tolist()
doc_list = [Document(page_content=doc_base.at[i,'gold_passage']) for i in range(len(doc_base))]
retriever = Custom_TFIDFRetriever().from_documents(doc_list)
tfidf_passages = retriever.get_relevant_documents(queries_test)

100%|██████████████████████████████████████████████████████████████████████████████| 4353/4353 [02:21<00:00, 30.82it/s]


In [53]:
# Evaluate model
tfidf_df = x_test.copy()
# tfidf_accuracy = compute_simple_accuracy(tfidf_df,tfidf_passages)
# print(tfidf_accuracy) # 0.51
tfidf_results = compute_performance(tfidf_df,tfidf_passages)
print(tfidf_results)

{'top 1 accuracy': 0.5171146335860326, 'top 3 accuracy': 0.6772340914311968, 'top 5 accuracy': 0.726165862623478, 'top 10 accuracy': 0.7783138065701815, 'mrr': 0.6071221635143069}

{'top 1 accuracy': 0.5171146335860326, 'top 3 accuracy': 0.5171146335860326, 'top 5 accuracy': 0.5171146335860326, 'top 10 accuracy': 0.5171146335860326, 'mrr': 0.5171146335860326}


### 2.2 Deep Passage Retrieval
**Notes on Implementation:** Requires edit(s) to SimpleTransformers repo, as non-training actions are not supported by the max_seq_length limit:
- retrieval/retrieval_utils.py: add the following arguments to embed():
    - padding ='max_length', max_length = 512

In [94]:
# Init
n_docs_dpr = n_docs
Ep = "facebook/dpr-ctx_encoder-single-nq-base"
Eq = "facebook/dpr-question_encoder-single-nq-base"

In [23]:
# Utils
def parallelize_model(model,device_ids=None):
    '''
    device_ids: List[int]
    '''
    if torch.cuda.device_count() > 1:
        model.args.train_batch_size *= torch.cuda.device_count()
        model = torch.nn.DataParallel(model,device_ids=device_ids)
    return model

def build_dpr_model(mode='train',hard_negatives=False,model_path=None):
    '''
    mode = 'train' or 'load'
    hard_negatives = True -> train with hard_negatives
    The other arguments can be toggled directly within the function
    '''
    model_args = RetrievalArgs(
        max_seq_length = 512,
        overwrite_output_dir = True,
        train_batch_size = 16
    )
    if mode == 'train':
        model = RetrievalModel(
            model_type='dpr',
            context_encoder_name = Ep,
            query_encoder_name = Eq,
            args = model_args,
            use_cuda = torch.cuda.is_available(),
            hard_negatives = True
        )
    elif mode == 'load':
        model = RetrievalModel(
            model_type = 'dpr',
            model_name = model_path
        )
    else: raise ValueError ("Input a valid mode. Must be 'train' or 'load'.")
    model = parallelize_model(model)
    return model

##### Training Round 1: Without Hard Negatives

In [None]:
dpr_model1 = build_dpr_model(mode='train')
dpr_model1.train_model(x_train,eval_data=x_val,
                 output_dir = 'DPR_model/',
                 show_running_loss=True)

**DPR Checkpoint 1** - After the first training run above, skip the above cell and proceed directly to the following cells:

In [None]:
# Load trained model, for convenience
dpr_model1_path = 'DPR_model/checkpoint-2176-epoch-1'
dpr_model1 = build_dpr_model(mode='load',model_path=dpr_model1_path)

In [37]:
# Evaluate model
predicted_passages, _, _, _ = dpr_model1.predict(
    to_predict = x_test['query_text'].tolist(),
    prediction_passages = doc_base['gold_passage'].tolist(),
    retrieve_n_docs = n_docs_dpr
)
dpr_df = x_test.copy()
# dpr_accuracy1 = compute_simple_accuracy(dpr_df,predicted_passages)
# print(dpr_accuracy1) # 0.42706179646221
dpr_results1 = compute_performance(dpr_df,predicted_passages)
print(dpr_results1)

##### Training Round 2: With Generated Hard Negatives

In [None]:
def row_hard_negative(row,k):
    if row.gold_passage != row.retrieved_passage[-k]:
        return row.retrieved_passage[-k]
    else:
        return row_hard_negative(row,-k+1)

def generate_hard_negatives(train_df,passage_base,train_predictions):
    train_df['retrieved_passage'] = train_predictions # or first item per list in list
    train_df['hard_negative'] = train_df.apply(lambda row:row_hard_negative(row,n_docs_dpr),axis=1)
    return train_df.drop(columns=['retrieved_passage'])

In [None]:
# Generate hard_negatives for subsequent training
predicted_passages_hn, _, _, _ = dpr_model1.predict(
    to_predict = x_train['query_text'].tolist(),
    prediction_passages = doc_base['gold_passage'].tolist(),
    retrieve_n_docs = n_docs_dpr
)

x_train_hn = generate_hard_negatives(x_train,doc_base,predicted_passages_hn)
x_train_hn.to_json(r'SBERT/x_train_hn.jsonl',orient='records',lines=True)

**DPR Checkpoint 2** - For the 2nd training round (with hard negatives), start from this cell right after #Init and #Utils (if/when kernel is restarted)

In [18]:
x_train_hn = pd.read_json('SBERT/x_train_hn.jsonl',lines=True)
x_train_hn.head()

dpr_model2 = build_dpr_model(mode='train',hard_negatives=True)
dpr_model2.train_model(x_train_hn,eval_data=x_val,
                       output_dir = 'DPR_model/',
                       show_running_loss = True,
                       hard_negatives = True
                      )   

  return self.fget.__get__(instance, owner)()
Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRContextEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this functi

Map:   0%|          | 0/17408 [00:00<?, ? examples/s]

Map:   0%|          | 0/17408 [00:00<?, ? examples/s]

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Running Epoch 0 of 1:   0%|          | 0/2176 [00:00<?, ?it/s]

  softmax_score = torch.nn.functional.log_softmax(similarity_score, dim=-1)


(2176, 0.6494938274765386)

**DPR Checkpoint 3** - After the second training run with hard negatives, start all DPR activity from the following cell right after #Init and #Utils (if/when kernel is restarted)

In [24]:
# Load trained model, for convenience
dpr_model2_path = 'DPR_model/checkpoint-2176-epoch-1'
dpr_model2 = build_dpr_model(mode='load',model_path=dpr_model2_path)

In [65]:
# Evaluate model
predicted_passages2, _, _, _ = dpr_model2.predict(
    to_predict = x_test['query_text'].tolist(),
    prediction_passages = doc_base['gold_passage'].tolist(),
    retrieve_n_docs = n_docs_dpr
)

dpr_df = x_test.copy()
# dpr_accuracy2 = compute_simple_accuracy(dpr_df,predicted_passages2)
# print(dpr_accuracy2)
dpr_results2 = compute_performance(dpr_df,predicted_passages2)
print(dpr_results2)
{'top 1 accuracy': 0.41557546519641625, 'top 3 accuracy': 0.5662761314036296, 'top 5 accuracy': 0.6250861474844934, 'top 10 accuracy': 0.6944635883298874, 'mrr': 0.5053413263613126}

{'top 1 accuracy': 0.41557546519641625, 'top 3 accuracy': 0.5662761314036296, 'top 5 accuracy': 0.6250861474844934, 'top 10 accuracy': 0.6944635883298874, 'mrr': 0.5053413263613126}


### 2.3 S-BERT

In [169]:
import sys
from sentence_transformers import SentenceTransformer, InputExample
from sentence_transformers.losses import MultipleNegativesRankingLoss as MNRL
from torch.utils.data import DataLoader,Dataset,SequentialSampler
from nltk import sent_tokenize
import nltk
import faiss
from statistics import multimode
import itertools

In [39]:
class SBERT_Dataset(Dataset):
    def __init__(self,df):
        self.df = df
    
    def __len__(self):
        return(len(self.df))
    
    def __getitem__(self,idx):
        row = self.df.loc[idx,'gold_passage']
        if not isinstance(row,str):
            text = []
            for i in row.tolist():
                text.extend(sent_tokenize(i))
        else: text = sent_tokenize(row)
        return text

def SBERT_tokenizer(batch):
    return batch

In [127]:
# Functions to generate embeddings
def get_embeddings(model,dataloader):
    data_array = torch.empty((len(dataloader.dataset),768))
    batch_size = dataloader.batch_size
    for batch_idx, data in enumerate(dataloader): #sentence batches
        model_output = model.encode(data)
        data_array[batch_idx*batch_size: (batch_idx*batch_size) + len(model_output)] = torch.from_numpy(model_output)
    return data_array

def save_embeddings(embeddings,directory,filename):
    if os.path.isdir(directory) == False:
        os.mkdir(directory)
    filename = filename + '.pt'
    torch.save(embeddings,os.path.join(directory,filename))

def generate_embeddings(model,tokenizer,dataset,directory,filename):
    embeddings = torch.Tensor()
    if os.path.isfile(os.path.join(directory,'sent2doc_hash.csv')) == False:
        '''
        sent2doc_hash
        index: doc idx
        value: first sentence idx
        '''
        sent2doc_hash = pd.Series()
        for i in tqdm(range(len(dataset))): #docs
            sent2doc_hash = pd.concat([sent2doc_hash,pd.Series(embeddings.shape[0])],ignore_index=True)
            dataloader = DataLoader(dataset[i], batch_size=4, shuffle=False, pin_memory=True,  
                                    sampler = SequentialSampler(dataset[i]),
                                    drop_last=False, collate_fn=tokenizer)
            embeddings_1 = get_embeddings(model, dataloader)
            embeddings = torch.cat((embeddings,embeddings_1),0)
        save_embeddings(embeddings, directory, filename) 
        sent2doc_hash.to_csv(os.path.join(directory,'sent2doc_hash.csv'))
    else: 
        for i in tqdm(range(len(dataset))): #docs
            dataloader = DataLoader(dataset[i], batch_size=4, shuffle=False, pin_memory=True,  
                                    sampler = SequentialSampler(dataset[i]),
                                    drop_last=False, collate_fn=tokenizer)
            embeddings_1 = get_embeddings(model, dataloader)
            embeddings = torch.cat((embeddings,embeddings_1),0)
        save_embeddings(embeddings, directory, filename) 
# Sentence to doc mapping
def generate_sent2doc(sent_embeddings): # assumes that hash file has been saved
    sent2doc_hash = pd.read_csv('SBERT/sent2doc_hash.csv')
    sent2doc_hash = sent2doc_hash.set_index('Unnamed: 0',drop=True)
    sent2doc_hash.index.names = ['']
    sent2doc_hash = sent2doc_hash.squeeze()
    sent2doc = pd.DataFrame(index=(range(sent_embeddings.shape[0]))) # index: sentence idx
    sent2doc['doc_id'] = (pd.Series(sent2doc.index.map(dict(zip(sent2doc_hash,sent2doc_hash.index))),
                               index=sent2doc.index).ffill())
    sent2doc['doc_id'] = pd.to_numeric(sent2doc['doc_id'],downcast='integer')
    '''
    index: sentence idx
    value: doc idx ('doc_id')
    '''
    return sent2doc 

In [66]:
# Init
SBERT_model = SentenceTransformer('stsb-distilbert-base')
SBERT_model.to(device)
SBERT_directory = 'SBERT'
filename = 'SBERT_embeddings_v1'
passages = SBERT_Dataset(doc_base)
n_docs_sbert = n_docs

##### Round 1: Out-of-the-box use

In [None]:
# Generate and load embeddings
generate_embeddings(SBERT_model,SBERT_tokenizer,passages,SBERT_directory,filename)
'''
Comment out the above line after the running it once
'''
sent_embeddings = torch.load(os.path.join(SBERT_directory,filename+'.pt'))

In [189]:
class FAISS_DocSearch:
    def __init__(self,model,pretok_doc_dataset,sent_embeddings):
        sent2doc_hash = pd.read_csv('SBERT/sent2doc_hash.csv')
        sent2doc_hash = sent2doc_hash.set_index('Unnamed: 0',drop=True)
        sent2doc_hash.index.names = ['']
        self.sent2doc_hash = sent2doc_hash
        self.sent2doc = generate_sent2doc(sent_embeddings)
        index = faiss.IndexIDMap(faiss.IndexFlatIP(768))
        sent_embeddings = sent_embeddings.numpy()
        faiss.normalize_L2(sent_embeddings)
        index.add_with_ids(sent_embeddings,np.array(range(0,sent_embeddings.shape[0])))
        self.model = model
        self.data = pretok_doc_dataset
        self.index = index
        self.sent_embeddings = sent_embeddings
#         gpu = faiss.StandardGpuResources() # faiss-gpu not supported for Windows
#         gpu_index = faiss.index_cpu_to_gpu(gpu,0,index)
#         self.index = gpu_index

    def search_doc(self,query,k=n_docs_sbert):
        query_vector = self.model.encode([query])
        faiss.normalize_L2(query_vector)
        top_k = self.index.search(query_vector,k)
        doc_ids = [self.sent2doc.at[idx,'doc_id'] for idx in top_k[1].tolist()[0]]
        return self.data.loc[doc_ids,'gold_passage'].tolist()

    def search_sentence(self,query,posttok_doc_dataset,full_train_dataset,k=5):
        '''
        For training. 
        sent_embeddings: torch.Tensor() -> Sentence embeddings from the corresponding gold_passage
        '''
        # Create faiss index of sentences for the corresponding document
            # query -> doc-> doc_idx -> get sentence _ids from sent2doc_hash
        doc_id = self.data.index[self.data['gold_passage']==full_train_dataset.loc[full_train_dataset['query_text']==query,'gold_passage'].to_numpy()[0]][0]
        if doc_id < len(self.data)-1:
            sent_emb_ids = list(range(self.sent2doc_hash.at[doc_id,'0'],self.sent2doc_hash.at[doc_id+1,'0']))
        else: sent_emb_ids = list(range(self.sent2doc_hash.at[doc_id,'0'],len(self.data))) # list of sentence_ids for doc
        index = faiss.IndexIDMap(faiss.IndexFlatIP(768))
        index.add_with_ids(self.sent_embeddings[sent_emb_ids],np.array(range(0,len(sent_emb_ids))))
        
        # Generate top sentences from document based on query
        query_vector = self.model.encode([query])
        faiss.normalize_L2(query_vector)
        top_k = index.search(query_vector,k)
        sent_ids = top_k[1].tolist()[0]
        sents = [posttok_doc_dataset[doc_id][sent_id] for sent_id in sent_ids]
        return sents

In [78]:
SBERT_search1 = FAISS_DocSearch(SBERT_model,doc_base,sent_embeddings)
sbert_predictions = []
for i in tqdm(queries_test):
    results = SBERT_search1.search_doc(i)
    sbert_predictions.append(results)

sbert_df = x_test.copy()
# sbert_accuracy = compute_simple_accuracy(sbert_df,sbert_predictions)
# print(sbert_accuracy)
sbert_results = compute_performance(sbert_df,sbert_predictions)
print(sbert_results)

{'top 1 accuracy': 0.35377900298644616, 'top 3 accuracy': 0.4985067769354468, 'top 5 accuracy': 0.5580059728922582, 'top 10 accuracy': 0.6292212267401792, 'mrr': 0.44139746717279454}

100%|██████████████████████████████████████████████████████████████████████████████| 4353/4353 [04:57<00:00, 14.62it/s]


{'top 1 accuracy': 0.35377900298644616, 'top 3 accuracy': 0.4985067769354468, 'top 5 accuracy': 0.5580059728922582, 'top 10 accuracy': 0.6292212267401792, 'mrr': 0.44139746717279454}


##### Round 2: Fine-tuning with Wikipedia Dataset

In [203]:
class SBERT_New_Dataset(Dataset):
    def __init__(self,df):
        '''
        df: List[List[InputExample]] -> List of sentence pairs, wh
        InputExample is a class under sentence_transformers (1 per sentence)
        '''
        self.df = df
    
    def __len__(self):
        return(len(self.df))
    
    def __getitem__(self,idx):
        return self.df[idx]

In [197]:
# Generate dataset for fine-tuning
SBERT_search = FAISS_DocSearch(SBERT_model,doc_base,sent_embeddings)
eg_list = []
for i in tqdm(queries_train):
    sentences = SBERT_search.search_sentence(i,passages,x_train)
    query_repeat = [i]*n_docs_sbert
    eg_list_i = [InputExample(texts=list(x)) for x in zip(query_repeat,sentences)]
    eg_list.extend(eg_list_i)
eg_list = list(x for x,_ in itertools.groupby(eg_list))

100%|████████████████████████████████████████████████████████████████████████████| 17408/17408 [08:04<00:00, 35.89it/s]


In [205]:
# Train with new dataset
sbert_train_data = SBERT_New_Dataset(eg_list)
sbert_train_dataloader = DataLoader(sbert_train_data,shuffle=False,batch_size=8)    
sbert_train_loss = MNRL(model=SBERT_model)
SBERT_model.fit(train_objectives=[(sbert_train_dataloader,sbert_train_loss)],epochs=1)

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/10880 [00:00<?, ?it/s]

In [206]:
# Generate and load embeddings
filename2 = 'SBERT_embeddings_v2'
generate_embeddings(SBERT_model,SBERT_tokenizer,passages,SBERT_directory,filename2)
'''
Comment out the above line after running it once
'''
sent_embeddings2 = torch.load(os.path.join(SBERT_directory,filename2+'.pt'))

100%|████████████████████████████████████████████████████████████████████████████| 10000/10000 [36:29<00:00,  4.57it/s]


In [210]:
# Evaluate model
SBERT_search = FAISS_DocSearch(SBERT_model,doc_base,sent_embeddings2)
sbert_predictions2 = []
for i in tqdm(queries_test):
    results = SBERT_search.search_doc(i)
    sbert_predictions2.append(results)

sbert_df = x_test.copy()
# sbert_accuracy2 = compute_simple_accuracy(sbert_df,sbert_predictions2)
# print(sbert_accuracy2)
sbert_results2 = compute_performance(sbert_df,sbert_predictions2)
print(sbert_results2)
{'top 1 accuracy': 0.3496439237307604, 'top 3 accuracy': 0.48334481966459913, 'top 5 accuracy': 0.5375603032391454, 'top 10 accuracy': 0.5910866069377441, 'mrr': 0.4284863750232461}

100%|██████████████████████████████████████████████████████████████████████████████| 4353/4353 [04:47<00:00, 15.14it/s]


{'top 1 accuracy': 0.3496439237307604, 'top 3 accuracy': 0.48334481966459913, 'top 5 accuracy': 0.5375603032391454, 'top 10 accuracy': 0.5910866069377441, 'mrr': 0.4284863750232461}


## 3. Consolidation
Generate final predictions and export results

In [212]:
# Compare performance
tfidf_results['model'] = 'tfidf'
dpr_results2['model'] = 'dpr'
sbert_results2['model'] = 'sbert'
model_results = pd.DataFrame(tfidf_results,index=[0])
model_results = pd.concat([model_results,pd.DataFrame(dpr_results2,index=[1])],ignore_index=True)
model_results = pd.concat([model_results,pd.DataFrame(sbert_results2,index=[2])],ignore_index=True)
cols = model_results.columns.tolist()
cols = cols[-1:] + cols[:-1]
model_results = model_results[cols]
model_results.head()

Unnamed: 0,model,top 1 accuracy,top 3 accuracy,top 5 accuracy,top 10 accuracy,mrr
0,tfidf,0.517115,0.677234,0.726166,0.778314,0.607122
1,dpr,0.415575,0.566276,0.625086,0.694464,0.505341
2,sbert,0.349644,0.483345,0.53756,0.591087,0.428486


In [111]:
final_test = pd.read_json('test.jsonl',lines=True)
queries_final = final_test['question'].tolist()

In [112]:
# Build model and generate predictions: TFIDF
doc_list = [Document(page_content=doc_base.at[i,'gold_passage']) for i in range(len(doc_base))]
retriever = Custom_TFIDFRetriever().from_documents(doc_list)
predictions_final = retriever.get_relevant_documents(queries_final)

100%|██████████████████████████████████████████████████████████████████████████████| 5468/5468 [02:54<00:00, 31.29it/s]


In [113]:
# Organize results
'''
final_test: DataFrame(columns=['question','points'])
predictions_final: List[List[str]]
submission.jsonl
'''
final_test['gold_passage'] = [docs[0] for docs in predictions_final]
final_test = final_test.merge(doc_base[['title','gold_passage']],how='left',on='gold_passage')
if final_test['title'].isnull().values.any(): print('null exists')
final_test = final_test.drop(columns=['gold_passage']).rename(columns={'title':'article'})
final_test.head(10)

Unnamed: 0,question,points,article
0,what is the percentage of the population of gu...,53,Guernsey
1,what is the traditional method of making vineg...,50,Vinegar
2,what were the primary ways that medieval towns...,54,Early Middle Ages
3,what are the names of the nine native american...,56,South Dakota
4,who was the young woman who was the inspiratio...,53,Leoš Janáček
5,what is the former capital of the duchy of lor...,66,Capital city
6,what is the process of pattern welding used to...,83,Welding
7,what is the total population of sydney includi...,56,Sydney
8,what is the name of the national recreation ar...,76,Lake Chaubunagungamaug
9,what is the best way we have of understanding ...,49,Bioinformatics
