In [2]:
import os
import random
import json
import requests
import tarfile
import argparse

import numpy as np
import pandas as pd

import torch

import pyserini
from pyserini.search import SimpleSearcher
from pyserini.dsearch import SimpleDenseSearcher

import transformers
# from transformers import set_seed
# set_seed(42)

from peft import LoraConfig
from transformers import (AutoTokenizer, 
                          AutoModelForCausalLM, 
                          BitsAndBytesConfig)

In [3]:
def _str2bool(value):
    """Parse a command-line boolean such as 'true' / 'False' / '1' / 'no'.

    argparse's ``type=bool`` is a trap: ``bool('False')`` is ``True`` because
    any non-empty string is truthy, so the flag could never be switched off.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Experiment configuration. Every tunable lives here so later cells read
# `config.<name>` instead of hard-coding values.
parser = argparse.ArgumentParser(description='Reranking with LLaMA2')

parser.add_argument('--model_name', type=str, default='Llama-2-7b-hf')
parser.add_argument('--dataset', type=str, default='msmarco-passage')
parser.add_argument('--data_path', type=str, default='./collections/')
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--max_len', type=int, default=40)
parser.add_argument('--lr', type=float, default=1e-5)
parser.add_argument('--max_epochs', type=int, default=10)
parser.add_argument('--use_cuda', type=_str2bool, default=True)
parser.add_argument('--k', type=int, default=10, help='top k')
parser.add_argument('--k1', type=float, default=1.5, help='BM25 parameter')
parser.add_argument('--b', type=float, default=0.75, help='BM25 parameter')

# An explicit empty argv is parsed so only the defaults above are used.
config = parser.parse_args([])

In [4]:
# Filesystem layout: everything belonging to the dataset lives under
# <data_path>/<dataset>; the Lucene index is kept separately under ./indexes.
msmarco_url = 'https://msmarco.blob.core.windows.net/msmarcoranking/collectionandqueries.tar.gz'
dataset_path = os.path.join(config.data_path, config.dataset)
targz_path, tsv_path, jsonl_path = (
    os.path.join(dataset_path, entry)
    for entry in ('collectionandqueries.tar.gz', 'collection.tsv', 'collection_jsonl')
)
index_path = os.path.join('./indexes', 'lucene-index-msmarco-passage')

In [None]:
# download data
# https://github.com/castorini/pyserini/blob/master/docs/experiments-msmarco-passage.md
# makedirs (not mkdir) so a missing parent directory is created as well, and
# exist_ok so re-running the cell is harmless.
os.makedirs(dataset_path, exist_ok=True)

# Skip the download when the extracted collection is already on disk, so
# Restart & Run All does not fetch the archive again.
if not os.path.exists(tsv_path):
    with requests.get(msmarco_url, stream=True) as response:
        # Fail loudly on an HTTP error instead of feeding an error page to tarfile.
        response.raise_for_status()
        # 'r|gz' reads the gzip stream sequentially, so the archive never has
        # to be stored on disk before extraction.
        # NOTE(review): extractall trusts the member paths inside the archive.
        with tarfile.open(fileobj=response.raw, mode='r|gz') as file:
            file.extractall(path=dataset_path)

In [None]:
# tsv to jsonl
# os.system returns the command's exit status; check it so a failed
# conversion stops the notebook here instead of surfacing later as an
# indexing problem.
convert_status = os.system('python anserini-tools/scripts/msmarco/convert_collection_to_jsonl.py ' +
                           f'--collection-path {tsv_path} ' +
                           f'--output-folder {jsonl_path}')
if convert_status != 0:
    raise RuntimeError(f'Converting {tsv_path} to jsonl failed (exit status {convert_status})')

In [5]:
# indexing for BM25
# https://github.com/castorini/pyserini/blob/master/docs/usage-index.md#building-a-bm25-index-direct-java-implementation
# The index is only built when it is not on disk yet (same check as
# BM25Retriever.__init__); delete the directory to force a rebuild.
if not os.path.exists(index_path):
    index_status = os.system('python -m pyserini.index.lucene ' +
                             '--collection JsonCollection ' +
                             f'--input {jsonl_path} ' +
                             f'--index {index_path} ' +
                             '--generator DefaultLuceneDocumentGenerator ' +
                             '--threads 1 --storeRaw')
    # A non-zero exit status means the indexer failed; do not continue silently.
    if index_status != 0:
        raise RuntimeError(f'Indexing Failed! (exit status {index_status})')

# --storePositions: builds a standard positional index
# --storeDocvectors: stores doc vectors (required for relevance feedback)
# --storeRaw: stores raw documents

0

In [30]:
from pyserini.search.lucene import LuceneSearcher

# Open the index built above (index_path, not a second hard-coded copy of it).
searcher = LuceneSearcher(index_path)
hits = searcher.search('what is rba', 10)

# Print at most the first five hits; min() guards against queries that
# return fewer than five results, which would otherwise raise IndexError.
for i in range(min(5, len(hits))):
    print(f'{i+1:2} {hits[i].docid:7} {hits[i].score:.6f}')
    print(json.loads(hits[i].raw)['contents'])

 1 1569393 12.336000
RBA stands for retrobulbar anaesthesia. Q: A: How to abbreviate retrobulbar anaesthesia? retrobulbar anaesthesia can be abbreviated as RBA. Q: A: What is the meaning of RBA abbreviation? The meaning of RBA abbreviation is retrobulbar anaesthesia. Q: A: What is RBA abbreviation? One of the definitions of RBA is retrobulbar anaesthesia. Q: A: What does RBA mean? RBA as abbreviation means retrobulbar anaesthesia. Q: A: What is shorthand of retrobulbar anaesthesia? The most common shorthand of retrobulbar anaesthesia is RBA.
 2 1569389 10.709900
What does Medical & Science RBA stand for? Hop on to get the meaning of RBA. The Medical & Science Acronym /Abbreviation/Slang RBA means retrobulbar anaesthesia. by AcronymAndSlang.com
 3 5358241 9.721600
What are the differences between RBA, RDA and RTA? Our no-nonsense guide to the world of building your own! The acronym RBA stands for ReBuildable Atomizers, an important category of vaping atomizer systems. RTA's and RDA's ar

In [62]:
from pyserini.search.lucene import LuceneSearcher

# Same query as the previous demo, but relying on the searcher's default
# number of hits instead of passing it explicitly.
searcher = LuceneSearcher(index_path)

hits = searcher.search('what is rba')

# Print at most the first five hits; min() guards against queries that
# return fewer than five results, which would otherwise raise IndexError.
for i in range(min(5, len(hits))):
    print(f'{i+1:2} {hits[i].docid:7} {hits[i].score:.6f}')
    print(json.loads(hits[i].raw)['contents'])

 1 1569393 12.336000
RBA stands for retrobulbar anaesthesia. Q: A: How to abbreviate retrobulbar anaesthesia? retrobulbar anaesthesia can be abbreviated as RBA. Q: A: What is the meaning of RBA abbreviation? The meaning of RBA abbreviation is retrobulbar anaesthesia. Q: A: What is RBA abbreviation? One of the definitions of RBA is retrobulbar anaesthesia. Q: A: What does RBA mean? RBA as abbreviation means retrobulbar anaesthesia. Q: A: What is shorthand of retrobulbar anaesthesia? The most common shorthand of retrobulbar anaesthesia is RBA.
 2 1569389 10.709900
What does Medical & Science RBA stand for? Hop on to get the meaning of RBA. The Medical & Science Acronym /Abbreviation/Slang RBA means retrobulbar anaesthesia. by AcronymAndSlang.com
 3 5358241 9.721600
What are the differences between RBA, RDA and RTA? Our no-nonsense guide to the world of building your own! The acronym RBA stands for ReBuildable Atomizers, an important category of vaping atomizer systems. RTA's and RDA's ar

In [None]:
# indexing for ANCE 
# os.system('python -m pyserini.encode ' + 
#           'input   --corpus tests/resources/simple_cacm_corpus.json ' +
#                    '--fields text ' +
#           'output  --embeddings {index_path} ' + 
#                     '--to-faiss ' + 
#           'encoder --encoder castorini/ance-msmarco-doc-maxp ' + #  --encoder castorini/tct_colbert-v2-hnp-msmarco TCT ColBERT
#                    '--fields text ' + 
#                    '--batch 64 --device cpu ')

# python -m pyserini.index.faiss \
#   --input path/to/encoded/corpus \  # in jsonl format
#   --output path/to/output/index \
    
# from pyserini.search import FaissSearcher

# searcher = FaissSearcher(
#     'indexes/dindex-sample-dpr-multi',
#     'facebook/dpr-question_encoder-multiset-base'
# )
# hits = searcher.search('what is a lobster roll')

# for i in range(0, 10):
#     print(f'{i+1:2} {hits[i].docid:7} {hits[i].score:.5f}')

In [60]:
from typing import Optional, List
from pyserini.search.lucene import LuceneSearcher 


class BM25Retriever:
    """BM25 first-stage retriever backed by a Pyserini ``LuceneSearcher``.

    The Lucene index is built from ``jsonl_path`` on construction when
    ``index_path`` does not exist yet.
    """

    def __init__(self, jsonl_path, index_path, k1=config.k1, b=config.b):
        """Open (building first if necessary) the index and set BM25 weights.

        Args:
            jsonl_path: Folder with the collection in jsonl form; only used
                when the index has to be built.
            index_path: Location of the Lucene index.
            k1: BM25 ``k1`` parameter.
            b: BM25 ``b`` parameter.

        Raises:
            Exception: If the index has to be built and indexing fails.
        """
        self.jsonl_path = jsonl_path
        if not os.path.exists(index_path):
            # Fixed: this used to call the non-existent `self.build_index`,
            # which raised AttributeError whenever the index was missing.
            self.build_sparse_index(jsonl_path, index_path)
        self.searcher = LuceneSearcher(index_path) # searcher = SimpleSearcher.from_prebuilt_index('msmarco-passage')
        self.searcher.set_bm25(k1=k1, b=b)
        # self.searcher.set_language()

    def build_sparse_index(self, jsonl_path, index_path): # TODO: for dense / hybrid retrieval later, add build_dense_index and a new class
        """Build the Lucene index by shelling out to ``pyserini.index.lucene``.

        Raises:
            Exception: If the indexing command exits with a non-zero status.
        """
        execute_code = os.system('python -m pyserini.index.lucene ' +
                                 '--collection JsonCollection ' +
                                 f'--input {jsonl_path} ' +
                                 f'--index {index_path} ' +
                                 '--generator DefaultLuceneDocumentGenerator ' +
                                 '--threads 1 --storeRaw')
        if execute_code != 0:
            raise Exception('Indexing Failed!')
        else:
            print('Indexing Success!')

    def _get_results(self, qid, hits):
        """Convert the hits of one query into a list of plain dicts.

        Each dict has the keys ``rank`` (0-based position in ``hits``),
        ``qid``, ``docid``, ``bm25_score`` and ``content`` (the ``contents``
        field of the hit's raw JSON document).
        """
        results = []

        for rank, hit in enumerate(hits):
            results.append({'rank': rank,
                            'qid': qid,
                            'docid': hit.docid,
                            'bm25_score': hit.score,
                            'content': json.loads(hit.raw)['contents']})

        return results

    def search(self, qid, query, k:int=config.k):
        """Retrieve up to ``k`` hits for one query as a list of result dicts."""
        hits = self.searcher.search(query, k=k)
        return self._get_results(qid, hits)

    def batch_search(self, queries: List[str], qids: List[str], k:int=config.k):
        """Retrieve up to ``k`` hits for each query.

        Returns:
            A dict mapping each qid to its list of result dicts.
        """
        batch_hits = self.searcher.batch_search(queries, qids, k=k)
        return {qid: self._get_results(qid, hits)
                for qid, hits in batch_hits.items()}

In [61]:
# Smoke-test the retriever on two sample queries, each paired with its query id.
bm25_retriever = BM25Retriever(jsonl_path=jsonl_path, index_path=index_path)
quries = [
    'what is rba',
    'what is oilskin fabric',
]

results = bm25_retriever.batch_search(qids=['19699', '19717'], queries=quries)
results

{'19699': [{'rank': 0,
   'qid': '19699',
   'docid': '1569393',
   'bm25_score': 11.199000358581543,
   'content': 'RBA stands for retrobulbar anaesthesia. Q: A: How to abbreviate retrobulbar anaesthesia? retrobulbar anaesthesia can be abbreviated as RBA. Q: A: What is the meaning of RBA abbreviation? The meaning of RBA abbreviation is retrobulbar anaesthesia. Q: A: What is RBA abbreviation? One of the definitions of RBA is retrobulbar anaesthesia. Q: A: What does RBA mean? RBA as abbreviation means retrobulbar anaesthesia. Q: A: What is shorthand of retrobulbar anaesthesia? The most common shorthand of retrobulbar anaesthesia is RBA.'},
  {'rank': 1,
   'qid': '19699',
   'docid': '1569389',
   'bm25_score': 10.021699905395508,
   'content': 'What does Medical & Science RBA stand for? Hop on to get the meaning of RBA. The Medical & Science Acronym /Abbreviation/Slang RBA means retrobulbar anaesthesia. by AcronymAndSlang.com'},
  {'rank': 2,
   'qid': '19699',
   'docid': '6107671',
 

In [None]:
from typing import List 
from pygaggle.rerank.base import Reranker, Query, Text

# https://github.com/castorini/pygaggle/blob/08339dd31f58ef40fbaa109726402e164eeba125/pygaggle/rerank/transformer.py#L14
# https://github.com/informagi/EMBERT/blob/f89efeeeef53d4dc9e2cc1f2b547aa34aa4f7945/Code/pygaggle/rerank/transformer.py
class LLaMAReranker(Reranker):
    """Reranker backed by a LLaMA-2 causal language model (work in progress).

    Only construction is implemented so far; ``build_model`` is a placeholder
    and ``rerank`` has no LLaMA-specific scoring yet.
    """

    def __init__(self, model_name, tokenizer, max_len):
        """Load ``meta-llama/<model_name>`` and keep the tokenizer settings.

        Args:
            model_name: Checkpoint name; must be one of the supported names.
            tokenizer: Tokenizer to use with the model.
            max_len: Maximum length setting, stored for later use.
        """
        assert model_name in ['Llama-2-7b-hf'], 'Wrong Model Name'
        self.model = AutoModelForCausalLM.from_pretrained(f'meta-llama/{model_name}')
        self.tokenizer = tokenizer
        self.max_len = max_len

    def build_model(self):
        """Placeholder; does nothing yet."""
        pass

    def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
        """Rerank ``texts`` for ``query``; currently defers to the base class.

        TODO: prepare the query / prompt for the LLaMA model here. The
        original line was an unfinished ``query =`` assignment, a SyntaxError
        that stopped the whole cell from being parsed, so it was removed.
        """
        return super().rerank(query, texts)

In [None]:
# Load the tokenizer and the 4-bit quantized causal LM.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
)

# Fixed: the parser defines `--model_name`, not `--model`, so `config.model`
# raised AttributeError. The repo id is built the same way LLaMAReranker does.
model_id = f'meta-llama/{config.model_name}'

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id,
                                             quantization_config=quantization_config,
                                             trust_remote_code=True,)

model.config.use_cache=True

In [None]:
# LoRA adapter settings: rank-8 adapters on the q_proj / v_proj modules of a
# causal LM, configured for training (inference_mode=False).
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=['q_proj', 'v_proj'],
    task_type='CAUSAL_LM',
    inference_mode=False,
)