In [2]:
import argparse
import io
import json
import multiprocessing
import os
import random
import subprocess
import sys
from zipfile import ZipFile

import requests

import numpy as np
import pandas as pd

import torch

import pyserini
from pyserini.search import SimpleSearcher
from pyserini.dsearch import SimpleDenseSearcher

import transformers
# from transformers import set_seed
# set_seed(42)

from peft import LoraConfig
from transformers import (AutoTokenizer, 
                          AutoModelForCausalLM, 
                          BitsAndBytesConfig)

In [3]:
def _str2bool(value):
    """Parse a command-line boolean such as 'true', 'False', '1' or 'no'.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including 'False') would be parsed as True.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Every tunable setting of the notebook lives on this parser.
parser = argparse.ArgumentParser(description='Reranking with OPT & NFCorpus')

parser.add_argument('--model_name', type=str, default='facebook/opt-125m')
parser.add_argument('--dataset', type=str, default='nfcorpus')
parser.add_argument('--data_path', type=str, default='./collections/')
parser.add_argument('--seed',type=int, default=42)
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--max_len', type=int, default=40)
parser.add_argument('--lr', type=float, default=1e-5)
parser.add_argument('--max_epochs', type=int, default=10)
parser.add_argument('--use_cuda', type=_str2bool, default=False)
parser.add_argument('--k', type=int, default=10, help='top k')
parser.add_argument('--k1', type=float, default=1.5, help='BM25 parameter')
parser.add_argument('--b', type=float, default=0.75, help='BM25 parameter')

# An empty argv is passed on purpose: inside a notebook sys.argv holds the
# kernel's own arguments, so only the defaults above are used.
config = parser.parse_args([])

### Prepare NFCorpus

In [5]:
# Root folder for this dataset and the URL of the BEIR archive.
dataset_path = os.path.join(config.data_path, config.dataset)
nfcorpus_url = 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/nfcorpus.zip'

# The archive is extracted into dataset_path; the paths below expect its
# files in a sub-folder named after the dataset.
extracted_dir = os.path.join(dataset_path, config.dataset)
tsv_path = os.path.join(extracted_dir, 'queries.tsv')
queries_jsonl_path = os.path.join(extracted_dir, 'queries.jsonl')
corpus_jsonl_path = os.path.join(extracted_dir, 'corpus.jsonl')

# Corpus converted to the id/contents JSONL records written for Pyserini.
pyserini_jsonl_path = os.path.join(dataset_path, 'pyserini-corpus', 'corpus.jsonl')

In [5]:
# download data
# Skipped when the corpus is already on disk, so re-running the notebook
# does not fetch the archive again.
if not os.path.exists(corpus_jsonl_path):
    # makedirs also creates the parent data folder, which os.mkdir would not.
    os.makedirs(dataset_path, exist_ok=True)

    response = requests.get(nfcorpus_url, timeout=60)
    # Fail loudly on an HTTP error instead of trying to unzip an error page.
    response.raise_for_status()

    with ZipFile(io.BytesIO(response.content)) as archive:
        archive.extractall(path=dataset_path)

In [8]:
# Convert queries.jsonl into a two-column TSV: query id, query text.
# UTF-8 is set explicitly so the result does not depend on the platform's
# default encoding.
with open(queries_jsonl_path, 'r', encoding='utf-8') as f, \
        open(tsv_path, 'w', encoding='utf-8') as out:
    for line in f:
        record = json.loads(line)
        out.write(record['_id'] + '\t' + record['text'] + '\n')

In [11]:
# https://github.com/castorini/pyserini/blob/e371ed3661e90db6b797290493d973cb6c089c43/docs/conceptual-framework2.md
# Rewrite the corpus as one {'id', 'contents'} JSON object per line.
# The target folder is not part of the downloaded archive, so it is created
# here; otherwise open(..., 'w') raises FileNotFoundError on a fresh run.
os.makedirs(os.path.dirname(pyserini_jsonl_path), exist_ok=True)

with open(corpus_jsonl_path, 'r', encoding='utf-8') as f, \
        open(pyserini_jsonl_path, 'w', encoding='utf-8') as out:
    for line in f:
        record = json.loads(line)
        # Title and body are concatenated; a missing title becomes ''.
        contents = record.get('title', '') + ' ' + record['text']
        s = json.dumps({'id': record['_id'], 'contents': contents})
        out.write(s + '\n')

### Sparse Indexing & BM25

In [6]:
from typing import List
from pyserini.search.lucene import LuceneSearcher 

class BM25Retriever:
    """BM25 retrieval over a Lucene index, built on first use if missing.

    Args:
        jsonl_path: value passed to the indexer's ``--input`` option. The
            notebook passes the folder holding the Pyserini-style corpus
            JSONL file.
        index_path: location of the Lucene index. When this path does not
            exist, the index is built from ``jsonl_path`` first.
        k1: BM25 ``k1`` parameter.
        b: BM25 ``b`` parameter.
    """

    def __init__(self, jsonl_path, index_path, k1=1.5, b=0.75):
        self.jsonl_path = jsonl_path
        if not os.path.exists(index_path):
            self.build_sparse_index(jsonl_path, index_path)
        self.searcher = LuceneSearcher(index_path)
        self.searcher.set_bm25(k1=k1, b=b)
        # self.searcher.set_language()

    def build_sparse_index(self, jsonl_path, index_path):
        """Run ``pyserini.index.lucene`` in a subprocess to build the index.

        Raises:
            RuntimeError: if the indexing process exits with a non-zero code.
        """
        # An argument list (no shell) keeps paths with spaces or shell
        # metacharacters intact; sys.executable keeps the indexer in the
        # same Python environment as the running kernel.
        command = [
            sys.executable, '-m', 'pyserini.index.lucene',
            '--collection', 'JsonCollection',
            '--input', str(jsonl_path),
            '--index', str(index_path),
            '--generator', 'DefaultLuceneDocumentGenerator',
            '--threads', '1',
            '--storeRaw',
        ]
        completed = subprocess.run(command)
        if completed.returncode != 0:
            raise RuntimeError('Indexing Failed!')
        print('Indexing Success!')

    def _get_results(self, qid, hits:List):
        """Turn search hits into plain dictionaries.

        Each hit must expose ``docid``, ``score`` and ``raw``, where ``raw``
        is a JSON string with a ``contents`` field.

        Returns:
            A list of dicts with keys ``rank`` (0-based position in
            ``hits``), ``qid``, ``docid``, ``bm25_score`` and ``content``.
        """
        return [
            {'rank': rank,
             'qid': qid,
             'docid': hit.docid,
             'bm25_score': hit.score,
             'content': json.loads(hit.raw)['contents']}
            for rank, hit in enumerate(hits)
        ]

    def search(self, qid, query_text:str, k:int=10):
        """Retrieve up to ``k`` documents for one query as result dicts."""
        hits = self.searcher.search(query_text, k=k)
        return self._get_results(qid, hits)

    def batch_search(self, qids, query_texts: List[str], k:int=10):
        """Retrieve up to ``k`` documents for each query in the batch.

        Returns:
            A dict mapping each query id to its list of result dicts.
        """
        batch_hits = self.searcher.batch_search(
            query_texts, qids, k=k, threads=multiprocessing.cpu_count())
        return {qid: self._get_results(qid, hits)
                for qid, hits in batch_hits.items()}

In [7]:
index_path = os.path.join('./indexes', 'lucene-index.nfcorpus')
# The corpus folder is derived from the configured paths rather than
# hard-coded, so changing --data_path or --dataset keeps indexing in step
# with the folder the corpus was written to.
pyserini_corpus_dir = os.path.dirname(pyserini_jsonl_path)
bm25_retriever = BM25Retriever(pyserini_corpus_dir, index_path)

In [8]:
def read_jsonl(jsonl_path:str):
    """Read a JSON Lines file into a list of parsed objects.

    Blank or whitespace-only lines (such as a trailing empty line) are
    skipped instead of raising ``json.JSONDecodeError``.

    Args:
        jsonl_path: path to a UTF-8 encoded file with one JSON value per line.

    Returns:
        The parsed values, in file order.
    """
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        # Iterating the handle avoids holding a second list of raw lines.
        return [json.loads(line) for line in f if line.strip()]

In [25]:
# Load every query record, then show the first one to check its fields.
queries = read_jsonl(jsonl_path=queries_jsonl_path)
queries[0]

{'_id': 'PLAIN-3',
 'text': 'Breast Cancer Cells Feed on Cholesterol',
 'metadata': {'url': 'http://nutritionfacts.org/2015/07/14/breast-cancer-cells-feed-on-cholesterol/'}}

In [40]:
# Single-query BM25 retrieval for the first three queries, one call each.
results = [
    bm25_retriever.search(qid=query['_id'], query_text=query['text'])
    for query in queries[:3]
]

results

[[{'rank': 0,
   'qid': 'PLAIN-3',
   'docid': 'MED-2434',
   'bm25_score': 6.411600112915039,
   'content': 'High ACAT1 expression in estrogen receptor negative basal-like breast cancer cells is associated with LDL-induced proliferation. The specific role of dietary fat in breast cancer progression is unclear, although a low-fat diet was associated with decreased recurrence of estrogen receptor alpha negative (ER(-)) breast cancer. ER(-) basal-like MDA-MB-231 and MDA-MB-436 breast cancer cell lines contained a greater number of cytoplasmic lipid droplets compared to luminal ER(+) MCF-7 cells. Therefore, we studied lipid storage functions in these cells. Both triacylglycerol and cholesteryl ester (CE) concentrations were higher in the ER(-) cells, but the ability to synthesize CE distinguished the two types of breast cancer cells. Higher baseline, oleic acid- and LDL-stimulated CE concentrations were found in ER(-) compared to ER(+) cells. The differences corresponded to greater mRNA a

In [42]:
# The same three queries, retrieved in one batched call; results are keyed
# by query id.
sample_queries = queries[:3]
results = bm25_retriever.batch_search(
    qids=[q['_id'] for q in sample_queries],
    query_texts=[q['text'] for q in sample_queries],
)
results['PLAIN-4'][0]

{'rank': 0,
 'qid': 'PLAIN-4',
 'docid': 'MED-2646',
 'bm25_score': 8.17710018157959,
 'content': 'Do fast foods cause asthma, rhinoconjunctivitis and eczema? Global findings from the International Study of Asthma and Allergies in Childhood (ISAA... BACKGROUND: Certain foods may increase or decrease the risk of developing asthma, rhinoconjunctivitis and eczema. We explored the impact of the intake of types of food on these diseases in Phase Three of the International Study of Asthma and Allergies in Childhood. METHODS: Written questionnaires on the symptom prevalence of asthma, rhinoconjunctivitis and eczema and types and frequency of food intake over the past 12 months were completed by 13-14-year-old adolescents and by the parents/guardians of 6-7-year-old children. Prevalence ORs were estimated using logistic regression, adjusting for confounders, and using a random (mixed) effects model. RESULTS: For adolescents and children, a potential protective effect on severe asthma was assoc

### Rerank

In [30]:
from typing import List 
from base import Reranker, Query, Text # pygaggle


class GPTReranker(Reranker):
    """Reranker backed by a causal language model (work in progress).

    Args:
        model_name: identifier passed to ``from_pretrained`` for both the
            model and the tokenizer.
        use_cuda: request GPU execution; honoured only when CUDA is
            available.
        max_len: stored on the instance as ``self.max_len``; not used by the
            code in this class yet.
    """

    def __init__(self, model_name, use_cuda, max_len):
        self.model = self.load_model(model_name, use_cuda)
        self.tokenizer = self.load_tokenizer(model_name)
        self.max_len = max_len

    def load_model(self, model_name:str, use_cuda:bool):
        """Load the causal LM and move it to the selected device.

        The model runs on CUDA only when ``use_cuda`` is truthy and CUDA is
        available; otherwise it runs on CPU.
        """
        # Logical `and` rather than bitwise `&`, which only behaves like a
        # boolean test when both operands are real bools.
        on_gpu = bool(use_cuda) and torch.cuda.is_available()
        device = torch.device('cuda' if on_gpu else 'cpu')
        # Half precision is kept for the GPU path only; on CPU the weights
        # are loaded in float32 because float16 CPU inference can be slow or
        # unsupported for some operators, depending on the PyTorch version.
        dtype = torch.float16 if on_gpu else torch.float32
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device)
        model.config.use_cache = True
        return model

    def load_tokenizer(self, model_name:str):
        """Load the tokenizer that matches ``model_name``."""
        return AutoTokenizer.from_pretrained(model_name)

    def score(self, dataset):
        """Tokenize ``dataset``; scoring itself is not implemented yet.

        TODO: unfinished. The method currently returns None after measuring
        the tokenized length.
        """
        encodings = self.tokenizer(dataset, add_special_tokens=False, return_tensors='pt')

        # Length of the tokenized input along dimension 1.
        dataset_len = encodings.input_ids.size(1)
        
    
    # def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
    #     for text in texts:
    #         result = self.tokenizer(query.text)
    #         input_ids
    #         attn_mask
            
        
            
        
    #     return super().rerank(query, texts)
    
    # # similarity score
    # # deft score(self, input_ids, item)

In [31]:
# Build the reranker from the configured model and device choice.
gpt_reranker = GPTReranker(
    model_name=config.model_name,
    use_cuda=config.use_cuda,
    max_len=512,
)

Downloading pytorch_model.bin: 100%|██████████| 251M/251M [02:09<00:00, 1.93MB/s]
Downloading (…)neration_config.json: 100%|██████████| 137/137 [00:00<00:00, 14.5kB/s]
Downloading (…)okenizer_config.json: 100%|██████████| 685/685 [00:00<00:00, 52.6kB/s]
Downloading (…)olve/main/vocab.json: 100%|██████████| 899k/899k [00:00<00:00, 1.27MB/s]
Downloading (…)olve/main/merges.txt: 100%|██████████| 456k/456k [00:00<00:00, 2.02MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 441/441 [00:00<00:00, 55.0kB/s]


In [None]:
def get_inputs(item, device, tokenizer):
    """Tokenize one item and move the resulting tensors to ``device``.

    NOTE(review): the original body could not run (it built a tensor from a
    list of strings, called ``tokenizer.decode()`` with no argument, and
    referenced undefined names). This version is a best guess at the intent;
    confirm it against the planned reranking code.

    Args:
        item: the text to encode, or a mapping with a ``'content'`` key
            (assumed to be a record like those built by
            ``BM25Retriever._get_results`` - TODO confirm).
        device: target ``torch.device`` (or device string) for the tensors.
        tokenizer: callable invoked as ``tokenizer(text, return_tensors='pt')``
            that returns a mapping with ``'input_ids'`` and
            ``'attention_mask'`` tensors, and optionally ``'token_type_ids'``.

    Returns:
        Tuple ``(input_ids, token_type, attn_mask)``. ``token_type`` is None
        when the tokenizer does not produce token type ids.
    """
    text = item['content'] if isinstance(item, dict) else item
    encoding = tokenizer(text, return_tensors='pt')

    input_ids = encoding['input_ids'].to(device)
    attn_mask = encoding['attention_mask'].to(device)

    # Not every tokenizer emits token type ids, so this one is optional.
    token_type = encoding['token_type_ids'] if 'token_type_ids' in encoding else None
    if token_type is not None:
        token_type = token_type.to(device)

    return input_ids, token_type, attn_mask
    

In [None]:
# top_results = torch.topk(scores, k=5).indices 
# reranked_corpus = [corpus[i] for i in top_results] 

# scored_articles = zip(articles, cosine_similarities)

# # Sort articles by cosine similarity
# sorted_articles = sorted(scored_articles, key=lambda x: x[1], reverse=True)

# scores = []
# https://github.com/amazon-science/datatuner/blob/f70369659e1c58e6ddb44d6db467978679dbdd3c/src/datatuner/lm/reranker.py#L5 

# sorted(initial_results, key=lambda x: x.score or 0.0, reverse=True)[
        #     : self.top_n
        # ]


In [None]:
# Load the model with 4-bit quantized weights.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
)

# The parser defines --model_name; there is no `model` attribute on config,
# so `config.model` raised AttributeError.
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
# NOTE(review): trust_remote_code=True executes Python code shipped with the
# checkpoint; keep it only for model repositories you trust.
model = AutoModelForCausalLM.from_pretrained(config.model_name,
                                             quantization_config=quantization_config,
                                             trust_remote_code=True,)

model.config.use_cache=True

In [None]:
# LoRA adapter settings for causal-LM fine-tuning.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    # Adapters are attached to the modules with these names.
    target_modules=["q_proj", "v_proj"],
)

model = get