In [58]:
import os
import random
import json
import requests
import io
from zipfile import ZipFile
import argparse
import multiprocessing

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

import pyserini
from pyserini.search import SimpleSearcher
from pyserini.dsearch import SimpleDenseSearcher

import transformers
# from transformers import set_seed
# set_seed(42)

from peft import LoraConfig
from transformers import (AutoTokenizer, 
                          AutoModelForCausalLM, 
                          BitsAndBytesConfig)

In [4]:
parser = argparse.ArgumentParser(description='Reranking with OPT & NFCorpus')

parser.add_argument('--model_name', type=str, default='facebook/opt-125m')
parser.add_argument('--dataset', type=str, default='nfcorpus')
parser.add_argument('--data_path', type=str, default='./collections/')
parser.add_argument('--seed',type=int, default=42)
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--max_len', type=int, default=40)
parser.add_argument('--lr', type=float, default=1e-5)
parser.add_argument('--max_epochs', type=int, default=10)
parser.add_argument('--use_cuda', type=bool, default=False)
parser.add_argument('--k', type=int, default=10, help='top k')
parser.add_argument('--k1', type=float, default=1.5, help='BM25 parameter')
parser.add_argument('--b', type=float, default=0.75, help='BM25 parameter')

parser.add_argument

config = parser.parse_args([])

### Prepare NFCorpus

In [5]:
dataset_path = os.path.join(config.data_path, config.dataset)
nfcorpus_url = 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/nfcorpus.zip'
tsv_path = os.path.join(dataset_path, config.dataset, 'queries.tsv')
queries_jsonl_path = os.path.join(dataset_path, config.dataset, 'queries.jsonl')
corpus_jsonl_path = os.path.join(dataset_path, config.dataset, 'corpus.jsonl')
pyserini_jsonl_path = os.path.join(dataset_path, 'pyserini-corpus', 'corpus.jsonl')

In [5]:
# download data
if not os.path.exists(dataset_path):
    os.mkdir(dataset_path)

response = requests.get(nfcorpus_url, stream=True)
file = ZipFile(io.BytesIO(response.content))
file.extractall(path=dataset_path)

In [8]:
with open(tsv_path, 'w') as out:
    with open(queries_jsonl_path, 'r') as f:
        for line in f:
            l = json.loads(line)
            out.write(l['_id'] + '\t' + l['text'] + '\n')

In [11]:
# https://github.com/castorini/pyserini/blob/e371ed3661e90db6b797290493d973cb6c089c43/docs/conceptual-framework2.md
with open(pyserini_jsonl_path, 'w') as out:
    with open(corpus_jsonl_path, 'r') as f:
        for line in f:
            l = json.loads(line)
            s = json.dumps({'id': l['_id'], 'contents': l['title'] + ' ' + l['text']})
            out.write(s + '\n')

### Sparse Indexing & BM25

In [121]:
from typing import List
from pyserini.search.lucene import LuceneSearcher 

class BM25Retriever:
    def __init__(self, jsonl_path, index_path, k1=1.5, b=0.75):
        self.jsonl_path = jsonl_path
        if not os.path.exists(index_path):
            self.build_sparse_index(jsonl_path, index_path)
        self.searcher = LuceneSearcher(index_path)
        self.searcher.set_bm25(k1=k1, b=b)
        # self.searcher.set_language()
    
    def build_sparse_index(self, jsonl_path, index_path):
        execute_code = os.system('python -m pyserini.index.lucene ' + 
                                 '--collection JsonCollection ' +
                                 f'--input {jsonl_path} ' +
                                 f'--index {index_path} ' +
                                 '--generator DefaultLuceneDocumentGenerator ' +
                                 '--threads 1 --storeRaw')
        if execute_code != 0:
            raise Exception('Indexing Failed!')
        else:
            print('Indexing Success!')
            
    def _get_results(self, qid, hits:List):
        results = []
        
        for i, hit in enumerate(hits):
            docid = hit.docid
            content = json.loads(hits[i].raw)['contents']
            bm25_score = hit.score
            result = {'rank': i+1,
                      'qid': qid,
                      'docid': docid, 
                      'score': bm25_score,
                      'content': content}
            results.append(result)
            
        return results
    
    def search(self, qid, query_text:str, k:int=10):
        search_results = {}
        hits = self.searcher.search(query_text, k=k,)
        search_results['query'] = query_text
        search_results['hits']  = self._get_results(qid, hits)
        
        return search_results
    
    def batch_search(self, qids, query_texts: List[str], k:int=10):
        query_dict = dict(zip(qids, query_texts))
        batch_hits = self.searcher.batch_search(query_texts, qids, k=k, threads=multiprocessing.cpu_count())
        bsearch_results = []
        bsearch_items = {}

        for qid, hits in batch_hits.items():
            bsearch_items['query'] = query_dict[qid]
            bsearch_items['hits'] = self._get_results(qid, hits)
            bsearch_results.append(bsearch_items)
            bsearch_items = {} 
       
        return bsearch_results

In [122]:
index_path = os.path.join('./indexes', 'lucene-index.nfcorpus')
bm25_retriever = BM25Retriever('collections/nfcorpus/pyserini-corpus/', index_path)

In [136]:
def read_jsonl(jsonl_path:str):
    with open(jsonl_path, 'r') as f:
        lines = f.readlines()
        result = [json.loads(line) for line in lines]
        return result

In [124]:
queries = read_jsonl(queries_jsonl_path)
results = bm25_retriever.batch_search(qids=[query['_id'] for query in queries], query_texts=[query['text'] for query in queries])

In [178]:
json.dump(results, open('./retrieved/bm25-nfcorpus.jsonl','w'))

### Rerank

In [None]:
# class QueryDocumentBatch:
#     query: Query
#     documents: List[Text]
#     output: Optional[TokenizerReturnType] = None

#     def __len__(self):
#         return len(self.documents)

In [190]:
[hit['content'] for hit in results[0]['hits']]

['Advanced Glycation End Products in Foods and a Practical Guide to Their Reduction in the Diet Modern diets are largely heat-processed and as a result contain high levels of advanced glycation end products (AGEs). Dietary advanced glycation end products (dAGEs) are known to contribute to increased oxidant stress and inflammation, which are linked to the recent epidemics of diabetes and cardiovascular disease. This report significantly expands the available dAGE database, validates the dAGE testing methodology, compares cooking procedures and inhibitory agents on new dAGE formation, and introduces practical approaches for reducing dAGE consumption in daily life. Based on the findings, dry heat promotes new dAGE formation by >10- to 100-fold above the uncooked state across food categories. Animal-derived foods that are high in fat and protein are generally AGE-rich and prone to new AGE formation during cooking. In contrast, carbohydrate-rich foods such as vegetables, fruits, whole grain

In [116]:
from typing import List 


class GPTReranker:
    def __init__(self, model_name, k, use_cuda, max_len=128):
        self.model = self.load_model(model_name, use_cuda)
        self.tokenizer = self.load_tokenizer(model_name)
        self.k = k
        self.max_len = max_len
        
        self.model.eval()
    
    def load_model(self, model_name:str, use_cuda:bool):
        device = torch.device('cuda' if torch.cuda.is_available() & use_cuda else 'cpu')
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32).to(device)
        model.config.use_cache=True
        return model
    
    def load_tokenizer(self, model_name:str):
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        return tokenizer
    
    def get_chunks(self, corpus, chunk_size=128):
        chunks = [corpus[i:i + chunk_size] for i in range(0, len(corpus), chunk_size)]
        return chunks
    
    def score(self, corpus, query):
        scores = []
        
        segments = corpus.split(' ')
        chunks = self.get_chunks(segments)
        for chunk in chunks:
            prob = []
            cids = self.tokenizer.encode(chunk + ' ' + query, return_tensors='pt')
            qids = self.tokenizer.encode(query)
            corpus_len = len(cids[0]) - len(qids)
            with torch.no_grad():
                outputs = self.model(input_ids=cids)
                logits = outputs[0].squeeze(0)
                log_softmax = F.log_softmax(logits)
                for idx, qid in enumerate(qids):
                    prob.append(log_softmax[corpus_len + idx][qid])
            
            score = torch.sum(prob)
            scores.append(score)
            
        return max(scores)

In [72]:
gpt_reranker = GPTReranker(config.model_name, k=config.k, use_cuda=config.use_cuda)

In [125]:
qids = [query['_id'] for query in queries]
query_texts = [query['text'] for query in queries]

In [135]:
corpus = read_jsonl('./retrieved/bm25-nfcorpus.jsonl')

TypeError: sequence item 0: expected str instance, list found

In [133]:
corpus[0][0]

{'query': 'Advanced Glycation End-products',
 'hits': [{'rank': 1,
   'qid': 'PLAIN-493',
   'docid': 'MED-4554',
   'score': 9.557299613952637,
   'content': 'Advanced Glycation End Products in Foods and a Practical Guide to Their Reduction in the Diet Modern diets are largely heat-processed and as a result contain high levels of advanced glycation end products (AGEs). Dietary advanced glycation end products (dAGEs) are known to contribute to increased oxidant stress and inflammation, which are linked to the recent epidemics of diabetes and cardiovascular disease. This report significantly expands the available dAGE database, validates the dAGE testing methodology, compares cooking procedures and inhibitory agents on new dAGE formation, and introduces practical approaches for reducing dAGE consumption in daily life. Based on the findings, dry heat promotes new dAGE formation by >10- to 100-fold above the uncooked state across food categories. Animal-derived foods that are high in fat 

In [126]:
gpt_reranker.score(q)

NameError: name 'gpt' is not defined

In [86]:
instance = gpt_reranker.tokenizer(qids)
instance['input_ids'][0]

[2, 7205, 33178, 12, 246]

In [88]:
input_ids = torch.tensor(instance['input_ids'][0]).unsqueeze(0)
input_ids

tensor([[    2,  7205, 33178,    12,   246]])

In [89]:
token_type_ids = torch.tensor(instance['token_type_ids'][0]).unsqueeze(0)

KeyError: 'token_type_ids'

In [None]:
def get_inputs(item, device, tokenizer):
    input_ids = torch.tensor(['input_ids'], device=device).unsqueeze(0)
    input_ids = tokenizer.decode()
    input_ids, 
    token_type 
    attn_mask
    

In [72]:
import re

def normalize_answer(s):
    """
    Taken from the official evaluation script for v1.1 of the SQuAD dataset.
    Lower text and remove punctuation, articles and extra whitespace.
    """

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


In [None]:
# top_results = torch.topk(scores, k=5).indices 
# reranked_corpus = [corpus[i] for i in top_results] 

# scored_articles = zip(articles, cosine_similarities)

# # Sort articles by cosine similarity
# sorted_articles = sorted(scored_articles, key=lambda x: x[1], reverse=True)

# scores = []
# https://github.com/amazon-science/datatuner/blob/f70369659e1c58e6ddb44d6db467978679dbdd3c/src/datatuner/lm/reranker.py#L5 

# sorted(initial_results, key=lambda x: x.score or 0.0, reverse=True)[
        #     : self.top_n
        # ]

# reranked = reranker.rerank(query, texts)
#     reranked.sort(key=lambda x: x.score, reverse=True)
#     d['documents'] = [text.metadata for text in reranked]
#   results.append(item)
# 

# json.dump(results, open("iirc_reranked.json",'w'))

In [None]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
)

tokenizer = AutoTokenizer.from_pretrained(config.model)
model = AutoModelForCausalLM.from_pretrained(config.model,
                                             quantization_config=quantization_config,
                                             trust_remote_code=True,)

model.config.use_cache=True

In [None]:
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules = ["q_proj", "v_proj"],
    inference_mode=False,
    bias='none',
    task_type='CAUSAL_LM',
)

model = get