In [None]:
import argparse
import json
import os
import random
import subprocess
import sys
import tarfile

import numpy as np
import pandas as pd
import requests

import torch

import pyserini
from pyserini.search import SimpleSearcher
from pyserini.dsearch import SimpleDenseSearcher

import transformers
# from transformers import set_seed
# set_seed(42)

from peft import LoraConfig
from transformers import (AutoTokenizer, 
                          AutoModelForCausalLM, 
                          BitsAndBytesConfig)

In [None]:
def _str2bool(value):
    """Parse a command-line boolean such as 'true' / 'False' / '1' / 'no'.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including 'False') would become True.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Experiment configuration. Inside the notebook the defaults are used
# (parse_args([]) ignores sys.argv, which belongs to the Jupyter kernel).
parser = argparse.ArgumentParser(description='Reranking with LLaMA2')

parser.add_argument('--model_name', type=str, default='Llama-2-7b-hf')
parser.add_argument('--dataset', type=str, default='msmarco-passage')
parser.add_argument('--data_path', type=str, default='./collections/')
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--max_len', type=int, default=40)
parser.add_argument('--lr', type=float, default=1e-5)
parser.add_argument('--max_epochs', type=int, default=10)
parser.add_argument('--use_cuda', type=_str2bool, default=False)
parser.add_argument('--k', type=int, default=100, help='top k')
parser.add_argument('--k1', type=float, default=1.5, help='BM25 parameter')
parser.add_argument('--b', type=float, default=0.75, help='BM25 parameter')

config = parser.parse_args([])

In [None]:
# Remote archive that holds the collection and queries
msmarco_url = 'https://msmarco.blob.core.windows.net/msmarcoranking/collectionandqueries.tar.gz'

# Everything dataset-specific lives under <data_path>/<dataset>
dataset_path = os.path.join(config.data_path, config.dataset)
targz_path, tsv_path, jsonl_path = (
    os.path.join(dataset_path, leaf)
    for leaf in ('collectionandqueries.tar.gz', 'collection.tsv', 'collection_jsonl')
)

# Lucene index location (kept outside dataset_path)
index_path = os.path.join('./indexes', 'lucene-index-msmarco-passage')

In [None]:
# download data
# https://github.com/castorini/pyserini/blob/master/docs/experiments-msmarco-passage.md

# makedirs (rather than mkdir) also creates a missing parent directory and
# does not fail when the cell is re-run.
os.makedirs(dataset_path, exist_ok=True)

# Skip the download when the extracted collection is already on disk, so that
# Restart & Run All does not fetch the archive again.
# NOTE(review): an interrupted extraction can leave a partial collection.tsv;
# delete it manually to force a fresh download.
if not os.path.exists(tsv_path):
    with requests.get(msmarco_url, stream=True) as response:
        # Fail loudly on HTTP errors instead of feeding an error page to tarfile
        response.raise_for_status()
        # 'r|gz' reads the gzip stream sequentially, so the archive is never
        # written to disk; the context managers close the tar and the connection.
        with tarfile.open(fileobj=response.raw, mode='r|gz') as archive:
            archive.extractall(path=dataset_path)

In [None]:
# tsv to jsonl
# The command is passed as an argument list (no shell), so paths containing
# spaces or shell metacharacters are handled correctly. check=True raises
# subprocess.CalledProcessError when the script fails instead of silently
# continuing with a missing/incomplete JSONL folder. sys.executable runs the
# script with the same interpreter as the notebook kernel.
subprocess.run(
    [sys.executable, 'anserini-tools/scripts/msmarco/convert_collection_to_jsonl.py',
     '--collection-path', tsv_path,
     '--output-folder', jsonl_path],
    check=True,
)

In [None]:
# indexing for BM25
# https://github.com/castorini/pyserini/blob/master/docs/usage-index.md#building-a-bm25-index-direct-java-implementation
#
# Indexer options:
# --storePositions: builds a standard positional index
# --storeDocvectors: stores doc vectors (required for relevance feedback)
# --storeRaw: stores raw documents
#
# Only build when the index path does not exist yet (same guard as
# BM25Retriever.__init__), so re-running the notebook does not re-index.
# check=True raises subprocess.CalledProcessError if the indexer fails.
if not os.path.exists(index_path):
    subprocess.run(
        [sys.executable, '-m', 'pyserini.index.lucene',
         '--collection', 'JsonCollection',
         '--input', jsonl_path,
         '--index', index_path,
         '--generator', 'DefaultLuceneDocumentGenerator',
         '--threads', '1', '--storeRaw'],
        check=True,
    )

In [None]:
from pyserini.search.lucene import LuceneSearcher

# Quick sanity check of the freshly built index: run one query and show the top hits.
# Uses index_path (the location the indexing step wrote to) rather than a
# second hard-coded copy of the path.
searcher = LuceneSearcher(index_path)

hits = searcher.search('what is rba')

# zip stops at the shorter input, so fewer than five hits no longer raises IndexError
for rank, hit in zip(range(1, 6), hits):
    print(f'{rank:2} {hit.docid:7} {hit.score:.6f}')
    print(json.loads(hit.raw)['contents'])

In [None]:
# indexing for ANCE 
# os.system('python -m pyserini.encode ' + 
#           'input   --corpus tests/resources/simple_cacm_corpus.json ' +
#                    '--fields text ' +
#           'output  --embeddings {index_path} ' + 
#                     '--to-faiss ' + 
#           'encoder --encoder castorini/ance-msmarco-doc-maxp ' + #  --encoder castorini/tct_colbert-v2-hnp-msmarco TCT ColBERT
#                    '--fields text ' + 
#                    '--batch 64 --device cpu ')

# python -m pyserini.index.faiss \
#   --input path/to/encoded/corpus \  # in jsonl format
#   --output path/to/output/index \
    
# from pyserini.search import FaissSearcher

# searcher = FaissSearcher(
#     'indexes/dindex-sample-dpr-multi',
#     'facebook/dpr-question_encoder-multiset-base'
# )
# hits = searcher.search('what is a lobster roll')

# for i in range(0, 10):
#     print(f'{i+1:2} {hits[i].docid:7} {hits[i].score:.5f}')

In [None]:
from typing import List
from pyserini.search.lucene import LuceneSearcher 

# Indexer # Retriever # TODO: create a BaseRetriever, then BM25, ANCE, Hybrid
# build_sparse_index 
# build_dense_index
# ssearch, dsearch, hsearch

class BM25Retriever:
    """BM25 retriever on top of a Pyserini ``LuceneSearcher``.

    If ``index_path`` does not exist when the retriever is created, the index
    is first built from the JSONL collection at ``jsonl_path``.
    """

    def __init__(self, jsonl_path, index_path, k1=1.5, b=0.75):
        """Open the index (building it if missing) and set the BM25 parameters.

        Args:
            jsonl_path: Collection folder handed to the indexer as ``--input``.
            index_path: Location of the Lucene index.
            k1: BM25 ``k1`` parameter.
            b: BM25 ``b`` parameter.

        Raises:
            Exception: If the index has to be built and the indexer fails.
        """
        self.jsonl_path = jsonl_path
        if not os.path.exists(index_path):
            self.build_sparse_index(jsonl_path, index_path)
        self.searcher = LuceneSearcher(index_path) # searcher = SimpleSearcher.from_prebuilt_index('msmarco-passage')
        self.searcher.set_bm25(k1=k1, b=b)
        # self.searcher.set_language()

    def build_sparse_index(self, jsonl_path, index_path): # TODO: for dense / hybrid retrieval, add build_dense_index in a new class later
        """Build a Lucene index for ``jsonl_path`` at ``index_path``.

        Runs ``pyserini.index.lucene`` in a subprocess with the raw documents
        stored (``--storeRaw``), which ``_get_results`` relies on.

        Raises:
            Exception: If the indexer exits with a non-zero status.
        """
        # Argument list instead of a shell string: paths with spaces or shell
        # metacharacters are passed through unchanged. sys.executable runs the
        # indexer with the same interpreter that imported this class.
        command = [sys.executable, '-m', 'pyserini.index.lucene',
                   '--collection', 'JsonCollection',
                   '--input', str(jsonl_path),
                   '--index', str(index_path),
                   '--generator', 'DefaultLuceneDocumentGenerator',
                   '--threads', '1', '--storeRaw']
        completed = subprocess.run(command)
        if completed.returncode != 0:
            raise Exception('Indexing Failed!')
        print('Indexing Success!')

    def _get_results(self, qid, hits:List):
        """Convert search hits into plain dictionaries.

        Args:
            qid: Query id copied into every result.
            hits: Iterable of hits exposing ``docid``, ``score`` and ``raw``
                (a JSON string with a ``'contents'`` field).

        Returns:
            One dict per hit with keys ``rank`` (0-based position in ``hits``),
            ``qid``, ``docid``, ``bm25_score`` and ``content``.
        """
        # Read everything from the loop variable (the original mixed `hit`
        # and `hits[i]`), so any iterable of hits works, not only indexable ones.
        return [{'rank': rank,
                 'qid': qid,
                 'docid': hit.docid,
                 'bm25_score': hit.score,
                 'content': json.loads(hit.raw)['contents']}
                for rank, hit in enumerate(hits)]

    def search(self, qid, query:str, k:int=10):
        """Retrieve the top-``k`` documents for a single query.

        Returns:
            List of result dicts as produced by ``_get_results``.
        """
        hits = self.searcher.search(query, k=k)
        return self._get_results(qid, hits)

    def batch_search(self, queries: List[str], qids: List[str], k:int=10):
        """Retrieve the top-``k`` documents for several queries at once.

        Args:
            queries: Query strings.
            qids: Query ids, passed to the searcher alongside ``queries``.
            k: Number of hits per query.

        Returns:
            Dict mapping each qid returned by the searcher to its list of
            result dicts.
        """
        batch_hits = self.searcher.batch_search(queries, qids, k=k)
        return {qid: self._get_results(qid, hits)
                for qid, hits in batch_hits.items()}

In [None]:
# top_results = torch.topk(scores, k=5).indices 
# reranked_corpus = [corpus[i] for i in top_results] 

# scored_articles = zip(articles, cosine_similarities)

# # Sort articles by cosine similarity
# sorted_articles = sorted(scored_articles, key=lambda x: x[1], reverse=True)

# scores = []
# https://github.com/amazon-science/datatuner/blob/f70369659e1c58e6ddb44d6db467978679dbdd3c/src/datatuner/lm/reranker.py#L5 
