In [75]:
import os
import random
import json
import requests
import tarfile
import argparse
import multiprocessing

import numpy as np
import pandas as pd

import torch

import pyserini
from pyserini.search import SimpleSearcher
from pyserini.dsearch import SimpleDenseSearcher
from pyserini.search.lucene import LuceneSearcher

import transformers

from peft import LoraConfig
from transformers import (AutoTokenizer, 
                          AutoModelForCausalLM, 
                          BitsAndBytesConfig)

In [76]:
parser = argparse.ArgumentParser(description='Rerank')

parser.add_argument('--model_name', type=str, default='facebook/opt-125m')
parser.add_argument('--collection', type=str, default='msmarco-passage')
parser.add_argument('--collections_path', type=str, default='./collections/')
parser.add_argument('--seed',type=int, default=42)
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--max_len', type=int, default=40)
parser.add_argument('--lr', type=float, default=1e-5)
parser.add_argument('--max_epochs', type=int, default=10)
parser.add_argument('--use_cuda', type=bool, default=False)
parser.add_argument('--k', type=int, default=100, help='top k')
parser.add_argument('--k1', type=float, default=1.5, help='BM25 parameter')
parser.add_argument('--b', type=float, default=0.75, help='BM25 parameter')

parser.add_argument

config = parser.parse_args([])

In [77]:
def get_msmarco_passage_jsonl(collections_path, ):
    msmarco_passage_path = os.path.join(collections_path, 'msmarco-passage')
    msmarco_url = 'https://msmarco.blob.core.windows.net/msmarcoranking/collectionandqueries.tar.gz'
    
    if not os.path.exists(msmarco_passage_path):
        os.mkdir(msmarco_passage_path)
        
    response = requests.get(msmarco_url, stream=True)
    file = tarfile.open(fileobj=response.raw, mode='r|gz')
    file.extractall(path=msmarco_passage_path)
    
    tsv_path = os.path.join(msmarco_passage_path, 'collection.tsv')
    jsonl_path = os.path.join(msmarco_passage_path, 'collection_jsonl')
    
    if os.path.exists(tsv_path):
        os.system(f'python anserini-tools/scripts/msmarco/convert_collection_to_jsonl.py ' +
                  f'--collection-path {tsv_path} ' +
                  f'--output-folder {jsonl_path}')

In [78]:
from typing import List

class Indexer:
    def __init__(self, jsonl_path, index_path):
        self.jsonl_path = jsonl_path
        self.index_path = index_path
    
    def build_sparse_index(self):
        execute_code = os.system('python -m pyserini.index.lucene ' + 
                                 '--collection JsonCollection ' +
                                 f'--input {self.jsonl_path} ' +
                                 f'--index {self.index_path} ' +
                                 '--generator DefaultLuceneDocumentGenerator ' +
                                 '--threads 1 --storeRaw')
        if execute_code != 0:
            raise Exception('Indexing Failed!')
        else:
            print('Indexing Success!')
    
    def build_dense_index(self):
        pass 

class BM25Retriever:
    def __init__(self, index_path, k, k1=1.5, b=0.75):
        self.searcher = LuceneSearcher(index_path)
        self.searcher.set_bm25(k1=k1, b=b)
        self.k = k
            
    def _get_results(self, qid, hits:List):
        results = []
        
        for i, hit in enumerate(hits):
            docid = hit.docid
            content = json.loads(hits[i].raw)['contents']
            bm25_score = hit.score
            result = {'rank': i+1,
                      'qid': qid,
                      'docid': docid, 
                      'score': bm25_score,
                      'content': content}
            results.append(result)
            
        return results
    
    def _save_json(self, results:List[dict]):
        json_path = os.path.join('./retrieved/', f'bm25-{config.collection}.json')
        json_file = open(json_path, 'w', encoding='utf-8', newline='\n')
        for result in results:
            json_file.write(json.dumps(result) + '\n')
        
        json_file.close()
    
    def search(self, qid, query_text:str):
        search_results = {}
        hits = self.searcher.search(query_text, k=self.k,)
        search_results['query'] = query_text
        search_results['hits']  = self._get_results(qid, hits)
        
        return search_results
    
    def batch_search(self, qids, query_texts: List[str], is_save:bool):
        query_dict = dict(zip(qids, query_texts))
        batch_hits = self.searcher.batch_search(query_texts, qids, k=self.k, threads=multiprocessing.cpu_count())
        bsearch_results = []
        bsearch_items = {}

        for qid, hits in batch_hits.items():
            bsearch_items['query'] = query_dict[qid]
            bsearch_items['hits'] = self._get_results(qid, hits)
            bsearch_results.append(bsearch_items)
            bsearch_items = {}
            
        if is_save:
            self._save_json(bsearch_results)
       
        return bsearch_results
    
# if not os.path.exists(index_path):
#             indexer = Indexer()
#             self.build_sparse_index(jsonl_path, index_path)     

In [79]:
def read_json(json_path:str):
    with open(json_path, 'r') as f:
        lines = f.readlines()
        result = [json.loads(line) for line in lines]
        return result

In [80]:
collection_dir_path = os.path.join(config.collections_path, config.collection)
collection_path = os.path.join(collection_dir_path, 'collection.tsv')
queries_train_path = os.path.join(collection_dir_path, 'queries.train.tsv')
queries_dev_path = os.path.join(collection_dir_path, 'queries.dev.tsv')

In [81]:
queries_train = pd.read_csv(queries_train_path, sep='\t', header=None, names=['qid', 'query'])
queries_train.head()

Downloading pytorch_model.bin:   1%|          | 21.0M/2.63G [00:15<09:53, 4.40MB/s]

ParserError: Error tokenizing data. C error: Calling read(nbytes) on source failed. Try engine='python'.

In [None]:
num = 100
# 후에 query 후보를 여러 개 생성해서..? 더 많이.?
# llama2 chat 활용해서
# https://huggingface.co/TheBloke/Llama-2-13B-chat-GPTQ/discussions/5
llama2_prompt = [
    {
        "role": "system",
        "content": "You are an intelligent assistant capable of generatig queries for given passages.",
    },
    {
        "role": "user",
        "content": f"Please generate a query for the following {num} passages based on its content. \nThe task is to generate a query that summarizes the main points of each passage. \nThe query should be relevant to the content of the passage."
    },
    {"role": "assistant", "content": "Okay, please provide the passages to generate a query."},
]

In [None]:
class GPTReranker:
    def __init__(self):
        self.model = self.load_model(config.model_name, config.use_cuda)
        self.tokenizer = self.load_tokenizer(config.model_name)
        self.model.eval()
    
    def load_model(self, model_name:str, use_cuda:bool):
        device = torch.device('cuda' if torch.cuda.is_available() & use_cuda else 'cpu')
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32).to(device)
        model.config.use_cache=True
        return model
    
    def load_tokenizer(self, model_name:str):
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        return tokenizer
    
    def _get_prompt(self, query)
    
    def rerank(self, query, texts):
        prompt =  f"Please generate a query based on the following passage: {texts}"
        

In [103]:
model = AutoModelForCausalLM.from_pretrained('EleutherAI/gpt-neo-125m')
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-125m')

Downloading (…)lve/main/config.json: 100%|██████████| 1.01k/1.01k [00:00<00:00, 64.7kB/s]
Downloading model.safetensors: 100%|██████████| 526M/526M [03:20<00:00, 2.62MB/s] 
Downloading (…)okenizer_config.json: 100%|██████████| 560/560 [00:00<00:00, 78.7kB/s]
Downloading (…)olve/main/vocab.json: 100%|██████████| 899k/899k [00:01<00:00, 623kB/s]
Downloading (…)olve/main/merges.txt: 100%|██████████| 456k/456k [00:00<00:00, 744kB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 357/357 [00:00<00:00, 84.7kB/s]


In [126]:
tokenizer.pad_token = tokenizer.eos_token

In [127]:
passages = "Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal. These RBA subsidiaries were involved in bribing overseas officials so that Australia might win lucrative note-printing contracts. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site."
prompt =f"Please generate a question for the following passages: {passages}"

prompt

"Please generate a question for the following passages: Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal. These RBA subsidiaries were involved in bribing overseas officials so that Australia might win lucrative note-printing contracts. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site."

In [122]:
inputs = tokenizer(prompt, return_tensors='pt')

In [123]:
generate_ids = model.generate(inputs.input_ids, num_return_sequences=1, do_sample=True, num_beams=1, max_new_tokens=32)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [125]:
tokenizer.decode(generate_ids[0], skip_special_tokens=True,)

"Please generate a question for the following passages: Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal. These RBA subsidiaries were involved in bribing overseas officials so that Australia might win lucrative note-printing contracts. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site.\n\nWhen the first Australian Prime Minister, Andrew stress the importance of keeping the Australian dollar in the currency, he would be seen as representing an important development in"

In [89]:
prompt

'Please generate a query for the following passages: The presence of communication amid scientific minds was equally important to the success of the Manhattan Project as scientific intellect was.'

In [102]:
inputs = tokenizer(prompt, return_tensors="pt")
generate_ids = model.generate(inputs.input_ids, max_length=128)
tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

'Please generate a query for the following passages: RBA Recognized with the 2014 Microsoft US Regional Partner of the ... by PR Newswire. Contract Awarded for supply and support the. Securitisations System used for risk management and analysis.                                                                             '

In [None]:
def rerank(self, query, texts):
        reranked_texts = []

        # Encode the query text
        query_inputs = self.tokenizer(query, return_tensors='pt', truncation=True, max_length=self.max_len, padding=True)

        for text in texts:
            # Encode the text
            text_inputs = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=self.max_len, padding=True)

            # Generate the reranking input by concatenating query and text
            rerank_input = {
                'input_ids': torch.cat([query_inputs['input_ids'], text_inputs['input_ids']], dim=1),
                'attention_mask': torch.cat([query_inputs['attention_mask'], text_inputs['attention_mask']], dim=1)
            }

            # Generate reranking scores using the GPT model
            with torch.no_grad():
                logits = self.model(**rerank_input).logits

            # Calculate the total score by summing logits
            total_score = logits.sum().item()

            # Append text and total score to the reranked_texts
            reranked_texts.append({'text': text, 'total_score': total_score})

        # Sort texts based on total_score in descending order
        reranked_texts.sort(key=lambda x: x['total_score'], reverse=True)

        # Extract the sorted texts
        sorted_texts = [item['text'] for item in reranked_texts]

        return sorted_texts