In [1]:
import os
import random
import json
import requests
import tarfile
import argparse

import numpy as np
import pandas as pd

import torch

import transformers
from peft import LoraConfig
from transformers import (AutoTokenizer, 
                          AutoModelForCausalLM, 
                          BitsAndBytesConfig)

  from .autonotebook import tqdm as notebook_tqdm
  warn("The installed version of bitsandbytes was compiled without GPU support. "


'NoneType' object has no attribute 'cadam32bit_grad_fp32'


In [17]:
def _str_to_bool(value):
    """Convert a command-line string into a real boolean.

    ``argparse`` applies ``type`` to the raw string, and ``bool('False')`` is
    ``True`` because the string is non-empty.  This converter accepts the usual
    spellings instead.

    Args:
        value: A string such as ``'true'``/``'0'``/``'no'``, or an actual bool.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Notebook-wide configuration, exposed as the `config` namespace.
parser = argparse.ArgumentParser(description='Rerank')

parser.add_argument('--model_name', type=str, default='facebook/opt-125m')
parser.add_argument('--collection', type=str, default='msmarco-passage')
parser.add_argument('--collections_path', type=str, default='./collections/')
parser.add_argument('--seed',type=int, default=42)
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--max_len', type=int, default=40)
parser.add_argument('--lr', type=float, default=1e-5)
parser.add_argument('--max_epochs', type=int, default=10)
# type=bool would turn '--use_cuda False' into True; parse the text explicitly.
parser.add_argument('--use_cuda', type=_str_to_bool, default=False)
parser.add_argument('--k', type=int, default=100, help='top k')
parser.add_argument('--k1', type=float, default=1.5, help='BM25 parameter')
parser.add_argument('--b', type=float, default=0.75, help='BM25 parameter')

# Parse an explicit empty argument list so only the defaults above are used
# and sys.argv is never consulted.
config = parser.parse_args([])

In [3]:
def get_msmarco_passage_jsonl(collections_path, ):
    """Download the MS MARCO passage archive and convert the collection to JSONL.

    The archive is streamed from the MS MARCO blob store and unpacked into
    ``<collections_path>/msmarco-passage``.  If ``collection.tsv`` is present
    afterwards, the Anserini conversion script is run to write
    ``collection_jsonl`` next to it.

    Args:
        collections_path: Directory under which ``msmarco-passage`` is created.

    Raises:
        requests.HTTPError: If the download does not return a success status.
        subprocess.CalledProcessError: If the conversion script exits non-zero.
    """
    import subprocess
    import sys

    msmarco_passage_path = os.path.join(collections_path, 'msmarco-passage')
    # https://microsoft.github.io/msmarco/Datasets
    msmarco_url = 'https://msmarco.blob.core.windows.net/msmarcoranking/collectionandqueries.tar.gz'

    # makedirs also creates a missing `collections_path` and tolerates reruns.
    os.makedirs(msmarco_passage_path, exist_ok=True)

    with requests.get(msmarco_url, stream=True) as response:
        # Fail here instead of feeding an HTML error page to tarfile.
        response.raise_for_status()
        # 'r|gz' reads the gzip stream sequentially, so nothing is buffered to disk first.
        with tarfile.open(fileobj=response.raw, mode='r|gz') as archive:
            archive.extractall(path=msmarco_passage_path)

    tsv_path = os.path.join(msmarco_passage_path, 'collection.tsv')
    jsonl_path = os.path.join(msmarco_passage_path, 'collection_jsonl')

    if os.path.exists(tsv_path):
        # Argument list (no shell) keeps paths with spaces intact; sys.executable
        # runs the script with the same interpreter as this notebook.
        subprocess.run(
            [
                sys.executable,
                'anserini-tools/scripts/msmarco/convert_collection_to_jsonl.py',
                '--collection-path', tsv_path,
                '--output-folder', jsonl_path,
            ],
            check=True,
        )
        

In [4]:
def _download_and_extract_tar_gz(url, dest_dir):
    """Stream a ``.tar.gz`` from ``url`` and unpack it into ``dest_dir``.

    Raises:
        requests.HTTPError: If the server does not return a success status.
    """
    os.makedirs(dest_dir, exist_ok=True)
    with requests.get(url, stream=True) as response:
        # Fail here instead of feeding an HTML error page to tarfile.
        response.raise_for_status()
        # 'r|gz' reads the gzip stream sequentially without seeking.
        with tarfile.open(fileobj=response.raw, mode='r|gz') as archive:
            archive.extractall(path=dest_dir)


def get_msmarco_passage_top1000_tr(dest_dir='./collections/msmarco-passage'):
    """Download and unpack the MS MARCO ``top1000.train`` archive.

    Args:
        dest_dir: Directory to extract into; created if missing.
    """
    top1000_tr_url = 'https://msmarco.blob.core.windows.net/msmarcoranking/top1000.train.tar.gz'
    _download_and_extract_tar_gz(top1000_tr_url, dest_dir)

def get_msmarco_passage_top1000_dev(dest_dir='./collections/msmarco-passage'):
    """Download and unpack the MS MARCO ``top1000.dev`` archive.

    Args:
        dest_dir: Directory to extract into; created if missing.
    """
    top1000_dev_url = 'https://msmarco.blob.core.windows.net/msmarcoranking/top1000.dev.tar.gz'
    _download_and_extract_tar_gz(top1000_dev_url, dest_dir)

In [64]:
import tarfile
# Unpack an already-downloaded archive.  The context manager closes the file
# handle even when extraction stops part-way (this cell's recorded output is
# an out-of-disk-space error).
with tarfile.open('./collections/msmarco-passage/top1000.train.tar.gz', mode='r|gz') as file:
    file.extractall(path='./collections/msmarco-passage')

OSError: [Errno 28] No space left on device

In [None]:
# Fetch the dev-split top-1000 archive and extract it under ./collections/msmarco-passage.
get_msmarco_passage_top1000_dev()

In [31]:
# Root folder of the selected collection; every data file below lives directly inside it.
collection_dir_path = os.path.join(config.collections_path, config.collection)

# Unpack one joined path per file name, in the same order as the names on the left.
(
    collection_path,
    queries_tr_path,
    qrels_tr_path,
    qrels_dev_path,
    queries_dev_path,
    queries_eval_path,
    top1000_tr_path,
    top1000_dev_path,
) = (
    os.path.join(collection_dir_path, file_name)
    for file_name in (
        'collection.tsv',
        'queries.train.tsv',
        'qrels.train.tsv',
        'qrels.dev.tsv',
        'queries.dev.tsv',
        'queries.eval.tsv',
        'top1000.train',
        'top1000.dev',
    )
)

In [20]:
# Dev queries: a header-less, tab-separated file read as (qid, query) with qid as the index.
ts = pd.read_csv(
    queries_dev_path,
    sep='\t',
    header=None,
    names=['qid', 'query'],
    index_col='qid',
)
ts

Unnamed: 0_level_0,query
qid,Unnamed: 1_level_1
1048578,cost of endless pools/swim spa
1048579,what is pcnt
1048580,what is pcb waste
1048581,what is pbis?
1048582,what is paysky
...,...
480594,"price of copper by ounce, pound"
524271,trazodone for dogs side effects
1048565,who plays sebastian michaelis
1048570,what is pearls before swine?


In [28]:
# Show the row(s) stored under the first (qid, pid) label of `qr`.
# NOTE(review): in this file `qr` is only assigned in the cell that follows,
# so on a fresh kernel that cell has to run before this one.
qr.loc[qr.index[0]]

query            foods and supplements to lower blood sugar
corpus    Watch portion sizes: ■ Even healthy foods will...
Name: (188714, 1000052), dtype: object

In [5]:
# Dev top-1000 candidates: header-less TSV rows of (qid, pid, query, corpus),
# indexed by the (qid, pid) pair.
qr = pd.read_csv(
    top1000_dev_path,
    sep='\t',
    header=None,
    names=['qid', 'pid', 'query', 'corpus'],
).set_index(["qid", "pid"])
qr

Unnamed: 0_level_0,Unnamed: 1_level_0,query,corpus
qid,pid,Unnamed: 2_level_1,Unnamed: 3_level_1
188714,1000052,foods and supplements to lower blood sugar,Watch portion sizes: ■ Even healthy foods will...
1082792,1000084,what does the golgi apparatus do to the protei...,"Start studying Bonding, Carbs, Proteins, Lipid..."
995526,1000094,where is the federal penitentiary in ind,It takes THOUSANDS of Macy's associates to bri...
199776,1000115,health benefits of eating vegetarian,The good news is that you will discover what g...
660957,1000115,what foods are good if you have gout?,The good news is that you will discover what g...
...,...,...,...
679360,999933,what is a corporate bylaws,Corporate Records for Nonprofit Corporations. ...
36388,999956,average family savings account,When it comes to average retirement savings st...
43781,999956,average savings per age group,When it comes to average retirement savings st...
28442,999956,at what age does the average person retire,When it comes to average retirement savings st...


In [None]:
# https://huggingface.co/models?search=gpt+neo
# GPT-Neo checkpoint names, generated from their size suffixes (order preserved).
GPT_PRETRAINED_MODEL_LIST = [
    f'gpt-neo-{size}' for size in ('125m', '2.7B', '1.3B')
]

In [None]:
class MarcoDataset:
    """Map-style dataset over (query, passage) candidate pairs.

    Each item is the list of passage token ids followed by the query token ids.

    NOTE(review): the original never populated ``self.queries`` or
    ``self.collection``.  They are loaded here from ``top1000.<mode>`` because
    that is the file this notebook reads with ``query`` and ``corpus`` columns,
    which are the attributes ``__getitem__`` accesses.  Confirm this is the
    intended source.
    """

    def __init__(self, collection_dir_path, tokenizer, mode='train'):
        """Load the candidate pairs for ``mode``.

        Args:
            collection_dir_path: Directory containing ``top1000.<mode>``.
            tokenizer: Callable returning an object with ``input_ids`` when
                called as ``tokenizer(text, max_length=..., truncation=True)``.
            mode: Split suffix of the top-1000 file (e.g. ``'train'``, ``'dev'``).
        """
        self.collection_dir_path = collection_dir_path
        self.tokenizer = tokenizer
        self.mode = mode

        top1000_path = os.path.join(collection_dir_path, f'top1000.{mode}')
        pairs = pd.read_csv(top1000_path, sep='\t', header=None,
                            names=['qid', 'pid', 'query', 'corpus'])
        # Both attributes view the same frame, so row `idx` yields a query and
        # the passage that was retrieved for it.
        self.queries = pairs
        self.collection = pairs

    def __len__(self):
        """Return the number of (query, passage) pairs."""
        return len(self.queries)

    def __getitem__(self, idx):
        """Return the token ids for the pair at position ``idx``."""
        query = self.queries.iloc[idx]['query']
        corpus = self.collection.iloc[idx]['corpus']

        return self.get_encoding(query, corpus, idx)

    def get_encoding(self, query, corpus, idx):
        """Tokenize a pair and return passage ids followed by query ids.

        Args:
            query: Query text, truncated to 128 tokens.
            corpus: Passage text, truncated to 512 tokens.
            idx: Position of the pair; accepted for interface compatibility
                and not used in the encoding.

        Returns:
            A list of token ids: passage ids, then query ids.
        """
        qids = self.tokenizer(query, max_length=128, truncation=True).input_ids
        cids = self.tokenizer(corpus, max_length=512, truncation=True).input_ids
        return cids + qids

In [63]:
from torch.utils.data import DataLoader, Dataset, TensorDataset, IterableDataset

class MarcoEncodeDataset(Dataset):
    """Dataset of tokenized MS MARCO candidate passages with relevance labels.

    One item corresponds to one row of ``top1000.<mode>``.  The row's passage
    is looked up in ``collection.tsv`` and tokenized to a fixed length; the
    label is 1 when the row's (qid, pid) pair appears in ``qrels.<mode>.tsv``
    and 0 otherwise.

    NOTE(review): only the passage is encoded.  ``q_max_len`` and
    ``self.queries`` are stored but not used by this class.
    """

    def __init__(self, collection_dir, tokenizer, mode='train', q_max_len=64, p_max_len=256):
        """Read the collection, queries, qrels and top-1000 files for ``mode``.

        Args:
            collection_dir: Directory holding the MS MARCO TSV files.
            tokenizer: Tokenizer called as ``tokenizer(text, max_length=...,
                truncation=..., padding=..., ...)`` and returning a mapping
                with ``input_ids`` and ``attention_mask``.
            mode: Split name used in the ``queries``/``qrels``/``top1000`` file names.
            q_max_len: Maximum query length (stored only).
            p_max_len: Length every passage is truncated/padded to.
        """
        self.collection_dir = collection_dir
        self.tokenizer = tokenizer
        self.mode = mode
        self.q_max_len = q_max_len
        self.p_max_len = p_max_len
        # load data
        passages_path = os.path.join(collection_dir, 'collection.tsv')
        queries_path = os.path.join(collection_dir, f'queries.{mode}.tsv')
        qrels_path = os.path.join(collection_dir, f'qrels.{mode}.tsv')
        # f-string: without the prefix the literal file name 'top1000.{mode}' was opened.
        top1000_path = os.path.join(collection_dir, f'top1000.{mode}')
        self.passages = pd.read_csv(passages_path, sep='\t', header=None, names=['pid', 'passage'], index_col='pid')
        self.queries = pd.read_csv(queries_path, sep='\t', header=None, names=['qid', 'query'], index_col='qid')
        self.relations = pd.read_csv(qrels_path, sep='\t', header=None, names=['qid', '0', 'pid', 'label'])
        self.top1000 = pd.read_csv(top1000_path, sep='\t', header=None, names=['qid', 'pid', 'query', 'passage'])
        # Built once so each label is a set lookup instead of a scan over all qrels rows.
        self._positive_pairs = set(zip(self.relations['qid'], self.relations['pid']))

    def __len__(self):
        """Return the number of candidate rows."""
        return len(self.top1000)

    def __getitem__(self, idx):
        """Return the encoded passage, label and position for row ``idx``.

        Returns:
            The tokenizer output with ``input_ids`` and ``attention_mask``
            converted to tensors, plus ``label`` (LongTensor of shape ``[1]``)
            and ``idx`` (scalar tensor).
        """
        row = self.top1000.iloc[idx]
        # The passage table is stored as `self.passages`; `self.collection` never existed.
        passage = self.passages.loc[row.pid, 'passage']
        label = 1 if (row.qid, row.pid) in self._positive_pairs else 0

        encoded = self.tokenizer(
            passage,
            max_length=self.p_max_len,
            truncation='only_first',
            # Must be requested: the mask is read back just below.
            return_attention_mask=True,
            return_token_type_ids=True,
            # Replaces the deprecated pad_to_max_length=True.
            padding='max_length',
        )

        encoded['attention_mask'] = torch.tensor(encoded['attention_mask'])

        encoded['input_ids'] = torch.tensor(encoded['input_ids'])

        encoded.update({'label': torch.LongTensor([label]),
                        'idx': torch.tensor(idx)})

        return encoded
    
#   feature_dict = {
#             "input_ids": passage_outputs["input_ids"],
#             "attention_mask": passage_outputs["attention_mask"],

#             "input_ids_query": query_outputs["input_ids"],
#             "attention_mask_query": query_outputs["attention_mask"],

#             "qids": qid,
#             "pids": pid,
#             "binary_labels": label,
#         }

# def collate_fn(batch):
#     max_length = 32 + 256 + 3
#     input_ids_lst = [x['query_input_ids'] + x['passage_input_ids'] for x in batch]
#     token_type_ids_lst = [[0]*len(x['query_input_ids']) + [1]*len(x['passage_input_ids']) for x in batch]
#     position_ids_lst = [list(range(len(x["query_input_ids"]) + len(x["doc_input_ids"]))) for x in batch]
#         data = {
#             "input_ids": pack_tensor_2D(input_ids_lst, default=0, dtype=torch.int64, length=max_length),
#             "token_type_ids": pack_tensor_2D(token_type_ids_lst, default=0, dtype=torch.int64, length=max_length),
#             "position_ids": pack_tensor_2D(position_ids_lst, default=0, dtype=torch.int64, length=max_length),
#         }
#         qid_lst = [x['qid'] for x in batch]
#         docid_lst = [x['docid'] for x in batch]
#         if mode == "train":

#             data["labels"] = torch.tensor([x["label"] for x in batch], dtype=torch.int64)  
#         return data, qid_lst, docid_lst
#     return collate_function
    

In [None]:
from torch.utils.data import DataLoader, Dataset, IterableDataset

# https://github.com/OpenMatch/OpenMatch/blob/ad1d6228bcf288ebe86037f93cd4ae20061ec4ea/src/openmatch/retriever/reranker.py
class RerankDataset(IterableDataset):
    """Iterable dataset pairing tokenized queries with tokenized passages.

    Declared as a class: with ``def`` this was a function whose nested
    ``__init__`` could never be reached, so no dataset could be constructed.

    NOTE(review): ``__iter__`` is not implemented here, so iterating an
    instance falls through to the ``IterableDataset`` base class.
    """

    def __init__(self, tokenizer, query_dataset, corpus_dataset):
        """Store the tokenizer and the two source datasets.

        Args:
            tokenizer: Tokenizer used to combine query and passage inputs.
            query_dataset: Source of query encodings.
            corpus_dataset: Source of passage encodings.
        """
        self.tokenizer = tokenizer
        self.query_dataset = query_dataset
        self.corpus_dataset = corpus_dataset
        
    # def __iter__(self):
    #     for qid, did in items():
    #         yield 
    #         {
    #                 "query_id": qid, 
    #                 "doc_id": did, 
    #                 **encode_pair(
    #                     self.tokenizer, 
    #                     self.query_dataset[qid]["input_ids"], 
    #                     self.corpus_dataset[did]["input_ids"], 
    #                     self.query_dataset.max_len, 
    #                     self.corpus_dataset.max_len,
    #                     encode_as_text_pair=self.encode_as_text_pair
    #                 ),
    #             }


        

In [None]:
class GPTReranker:
    """Zero-shot passage reranker backed by a causal language model.

    Each passage is placed in an instruction prompt and scored by the
    log-probability the model assigns to the query as the continuation of that
    prompt; passages are returned from highest to lowest score.
    """

    def __init__(self):
        """Load the model and tokenizer named in the notebook-level ``config``."""
        self.model = self.load_model(config.model_name, config.use_cuda)
        self.tokenizer = self.load_tokenizer(config.model_name)
        self.model.eval()

    def load_model(self, model_name:str, use_cuda:bool):
        """Load ``model_name`` in float32 and move it to the selected device.

        Args:
            model_name: Hugging Face model identifier.
            use_cuda: Use the GPU when one is available; otherwise the CPU.

        Returns:
            The loaded causal language model.
        """
        # `and` short-circuits: CUDA is only probed when it was requested.
        device = torch.device('cuda' if use_cuda and torch.cuda.is_available() else 'cpu')
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32).to(device)
        model.config.use_cache=True
        return model

    def load_tokenizer(self, model_name:str):
        """Load the tokenizer that belongs to ``model_name``."""
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        return tokenizer

    def _get_prompt(self, text):
        """Return the instruction prompt for a single passage ``text``."""
        return f"Please generate a query based on the following passage: {text}"

    def rerank(self, query, texts, max_len=512):
        """Order ``texts`` by the model's log-likelihood of ``query``.

        Args:
            query: The search query.
            texts: Candidate passages.
            max_len: Maximum number of prompt tokens kept per passage.

        Returns:
            A new list with the passages sorted from most to least likely to
            produce ``query``; ties keep their input order.
        """
        if not texts:
            return []

        device = self.model.device
        # Leading space so the query continues the prompt as a separate word;
        # no special tokens, so none end up between prompt and query.
        query_ids = list(self.tokenizer(' ' + query, add_special_tokens=False).input_ids)

        scores = []
        with torch.no_grad():
            for text in texts:
                prompt_ids = list(self.tokenizer(self._get_prompt(text),
                                                 truncation=True,
                                                 max_length=max_len).input_ids)
                input_ids = torch.tensor([prompt_ids + query_ids], dtype=torch.long, device=device)
                logits = self.model(input_ids=input_ids).logits

                # Logits at position i predict token i + 1, so the query tokens
                # are scored with the logits shifted one step to the left.
                start = max(len(prompt_ids), 1)
                log_probs = torch.log_softmax(logits[0, start - 1:-1].float(), dim=-1)
                targets = input_ids[0, start:]
                scores.append(log_probs.gather(1, targets.unsqueeze(1)).sum().item())

        # sorted() is stable, so equal scores keep their original order.
        order = sorted(range(len(texts)), key=scores.__getitem__, reverse=True)
        return [texts[i] for i in order]
        

In [None]:
# Load the tokenizer and the causal LM from the same GPT-Neo 125M checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    'EleutherAI/gpt-neo-125m',
)
model = AutoModelForCausalLM.from_pretrained(
    'EleutherAI/gpt-neo-125m',
)

In [None]:
# Use the end-of-sequence token for padding; the `rerank` cell below calls the
# tokenizer with padding=True.
# NOTE(review): presumably this tokenizer has no pad token of its own — confirm.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
def rerank(self, query, texts):
    """Sort ``texts`` by how likely the model is to produce ``query`` after them.

    The previous score was the sum of every raw logit, which grows with the
    sequence length and vocabulary size and says nothing about whether the
    passage matches the query.  Each passage is now scored by the summed
    log-probability of the query tokens when they follow the passage.

    Args:
        self: Object providing ``tokenizer``, ``model`` and ``max_len``.
        query: The search query.
        texts: Candidate passages.

    Returns:
        The passages ordered from highest to lowest score; ties keep their
        input order.
    """
    # Inputs must sit on the model's device; skipped if the model exposes none.
    device = getattr(self.model, 'device', None)

    # Leading space so the query starts a new word after the passage; no
    # special tokens, so none end up between passage and query.
    query_ids = self.tokenizer(' ' + query, return_tensors='pt', truncation=True,
                               max_length=self.max_len, add_special_tokens=False)['input_ids']

    reranked_texts = []
    for text in texts:
        text_ids = self.tokenizer(text, return_tensors='pt', truncation=True,
                                  max_length=self.max_len)['input_ids']

        # Passage first, query second: a causal model only conditions on what precedes.
        input_ids = torch.cat([text_ids, query_ids], dim=1)
        if device is not None:
            input_ids = input_ids.to(device)

        with torch.no_grad():
            logits = self.model(input_ids=input_ids,
                                attention_mask=torch.ones_like(input_ids)).logits

        # Logits at position i predict token i + 1, so the query tokens are
        # scored with the logits shifted one step to the left.
        start = max(text_ids.shape[1], 1)
        log_probs = torch.log_softmax(logits[0, start - 1:-1].float(), dim=-1)
        targets = input_ids[0, start:]
        total_score = log_probs.gather(1, targets.unsqueeze(1)).sum().item()

        reranked_texts.append({'text': text, 'total_score': total_score})

    # list.sort is stable, so equal scores keep their original order.
    reranked_texts.sort(key=lambda item: item['total_score'], reverse=True)

    return [item['text'] for item in reranked_texts]