In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [2]:
import os
import json
from typing import List, Any, Tuple
from datasets import Dataset
from tqdm import tqdm
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from enum import Enum
import openai
import pickle

In [4]:
from InstructorEmbedding import INSTRUCTOR
from gritlm import GritLM
import sentence_transformers
from datasets import load_dataset
query_data = load_dataset("princeton-nlp/LitSearch", "query", split="full")
corpus_clean_data = load_dataset("princeton-nlp/LitSearch", "corpus_clean", split="full")
corpus_s2orc_data = load_dataset("princeton-nlp/LitSearch", "corpus_s2orc", split="full")

Downloading readme:   0%|          | 0.00/1.81k [00:00<?, ?B/s]

In [5]:
query_df = query_data.to_pandas()
corpus_clean_df = corpus_clean_data.to_pandas()
corpus_s2orc_df = corpus_s2orc_data.to_pandas()

In [6]:
query_df.head()

Unnamed: 0,query_set,query,specificity,quality,corpusids
0,inline_acl,Are there any research papers on methods to co...,0,2,[202719327]
1,inline_acl,Are there any resources available for translat...,1,2,[227231792]
2,inline_acl,Are there any studies that explore post-hoc te...,0,2,"[226254579, 204976362]"
3,inline_acl,Are there any tools or studies that have focus...,1,2,"[10961392, 12160022]"
4,inline_acl,Are there papers that propose contextualized c...,1,2,[233296182]


In [7]:
corpus_clean_df.head()


Unnamed: 0,corpusid,title,abstract,citations,full_paper
0,252715594,PHENAKI: VARIABLE LENGTH VIDEO GENERATION FROM...,"We present Phenaki, a model capable of realist...","[6628106, 174802916, 238582653]",PHENAKI: VARIABLE LENGTH VIDEO GENERATION FROM...
1,13002849,MODE REGULARIZED GENERATIVE ADVERSARIAL NETWORKS,Although Generative Adversarial Networks achie...,[],MODE REGULARIZED GENERATIVE ADVERSARIAL NETWOR...
2,239998253,What Do We Mean by Generalization in Federated...,Federated learning data is drawn from a distri...,"[235613568, 231924480, 211678094, 195798643, 4...",What Do We Mean by Generalization in Federated...
3,62841605,SPREADING VECTORS FOR SIMILARITY SEARCH,Discretizing multi-dimensional data distributi...,[],SPREADING VECTORS FOR SIMILARITY SEARCH\n\n\nA...
4,253237531,MACHINE UNLEARNING OF FEDERATED CLUSTERS,Federated clustering (FC) is an unsupervised l...,[],MACHINE UNLEARNING OF FEDERATED CLUSTERS\n\n\n...


In [8]:
corpus_s2orc_df.head()

Unnamed: 0,corpusid,externalids,content,year
0,252715594,"{'acl': None, 'arxiv': '2210.02399', 'dblp': '...","{'annotations': {'abstract': '[{""end"":2761,""st...",2023.0
1,13002849,"{'acl': None, 'arxiv': '1612.02136', 'dblp': '...","{'annotations': {'abstract': '[{""end"":1726,""st...",2017.0
2,239998253,"{'acl': None, 'arxiv': None, 'dblp': 'conf/icl...","{'annotations': {'abstract': '[{""end"":1082,""st...",2022.0
3,62841605,"{'acl': None, 'arxiv': '1806.03198', 'dblp': '...","{'annotations': {'abstract': '[{""end"":1503,""st...",2019.0
4,253237531,"{'acl': None, 'arxiv': '2210.16424', 'dblp': '...","{'annotations': {'abstract': '[{""end"":2886,""st...",2023.0


In [9]:
#for loading json files into 'data'
def read_json(filename: str, silent: bool = False) -> List[Any]:
    with open(filename, 'r') as file:
        if filename.endswith(".json"):
            data = json.load(file)
        elif filename.endswith(".jsonl"):
            data = [json.loads(line) for line in file]
        else:
            raise ValueError("Input file must be either a .json or .jsonl file")

    if not silent:
        print(f"Loaded {len(data)} records from {filename}")
    return data

In [12]:
#for writing data into json file format
def write_json(data: List[Any], filename: str, silent: bool = False) -> None:
    if filename.endswith(".json"):
        with open(filename, 'w') as file:
            json.dump(data, file, indent=4)
    elif filename.endswith(".jsonl"):
        with open(filename, 'w') as file:
            for item in data:
                file.write(json.dumps(item) + "\n")
    else:
        raise ValueError("Output file must be either a .json or .jsonl file")

    if not silent:
        print(f"Saved {len(data)} records to {filename}")

In [13]:
def calculate_recall(retrieved: List[int], relevant_docs: List[int]) -> float:
    """Calculate recall metric."""
    num_relevant_retrieved = len(set(retrieved).intersection(set(relevant_docs)))
    num_relevant = len(relevant_docs)
    return num_relevant_retrieved / num_relevant if num_relevant > 0 else 0

In [14]:
def calculate_ndcg(retrieved: List[int], relevant_docs: List[int]) -> float:
    """Calculate Normalized Discounted Cumulative Gain."""
    dcg = sum(1 / (idx + 1) for idx, docid in enumerate(retrieved) if docid in relevant_docs)
    idcg = sum(1 / (idx + 1) for idx in range(len(relevant_docs)))
    return dcg / idcg if idcg > 0 else 0

In [15]:
from enum import Enum

class TextType(Enum):
    KEY = 1
    QUERY = 2

# KVStore Class
class KVStore:
    """Key-Value store base class for retrieval systems."""
    def __init__(self, index_name: str, index_type: str) -> None:
        self.index_name = index_name
        self.index_type = index_type
        self.keys = []
        self.encoded_keys = []
        self.values = []

    def __len__(self) -> int:
        return len(self.keys)

    def _encode(self, text: str, type: TextType) -> Any:
        return self._encode_batch([text], type, show_progress_bar=False)[0]

    def clear(self) -> None:
        self.keys = []
        self.encoded_keys = []
        self.values = []

    def create_index(self, key_value_pairs: List[Tuple[str, Any]]) -> None:
        if len(self.keys) > 0:
            raise ValueError("Index is not empty. Please create a new index or clear the existing one.")

        for key, value in tqdm(key_value_pairs.items(), desc=f"Creating {self.index_name} index"):
            self.keys.append(key)
            self.values.append(value)
        self.encoded_keys = self._encode_batch(self.keys, TextType.KEY)

    def save(self, dir_name: str) -> None:
        save_dict = {}
        for key, value in self.__dict__.items():
            if key[0] != "_":  # Ignore private attributes
                save_dict[key] = value

        print(f"Saving index to {os.path.join(dir_name, f'{self.index_name}.{self.index_type}')}")
        os.makedirs(dir_name, exist_ok=True)
        with open(os.path.join(dir_name, f"{self.index_name}.{self.index_type}"), 'wb') as file:
            pickle.dump(save_dict, file, protocol=pickle.HIGHEST_PROTOCOL)

    def query(self, query_text: str, n: int, return_keys: bool = False) -> List[Any]:
        """
        Query the index and return the top-N results.

        :param query_text: The query text.
        :param n: The number of top results to return.
        :param return_keys: Whether to return keys with values.
        :return: A list of results.
        """
        encoded_query = self._encode(query_text, TextType.QUERY)
        indices = self._query(encoded_query, n)
        if return_keys:
            results = [(self.keys[i], self.values[i]) for i in indices]
        else:
            results = [self.values[i] for i in indices]
        return results




In [16]:
class GTR(KVStore):
    def __init__(self, index_name: str, model_path: str = "sentence-transformers/gtr-t5-large"):
        super().__init__(index_name, 'gtr')
        self.model_path = model_path
        self._model = sentence_transformers.SentenceTransformer(model_path, device="cpu")



    def _encode_batch(self, texts: List[str], type: TextType, show_progress_bar: bool = True) -> List[Any]:
        return self._model.encode(texts, batch_size=256, show_progress_bar=show_progress_bar).astype(np.float16)

    def load(self, path: str):
        super().load(path)
        self._model = sentence_transformers.SentenceTransformer(model_path, device="cpu")

        return self

    def _query(self, encoded_query: Any, n: int) -> List[int]:
      cosine_similarities = cosine_similarity([encoded_query], self.encoded_keys)[0]
      top_indices = cosine_similarities.argsort()[-n:][::-1]
      return top_indices


In [17]:
import nltk
import numpy as np
from tqdm import tqdm
from rank_bm25 import BM25Okapi
from typing import List, Tuple, Any


class BM25(KVStore):
    def __init__(self, index_name: str):
        super().__init__(index_name, 'bm25')

        nltk.download('punkt')
        nltk.download('stopwords')

        self._tokenizer = nltk.word_tokenize
        self._stop_words = set(nltk.corpus.stopwords.words('english'))
        self._stemmer = nltk.stem.PorterStemmer().stem
        self.index = None   # BM25 index

    def _encode_batch(self, texts: str, type: TextType, show_progress_bar: bool = True) -> List[str]:
        # lowercase, tokenize, remove stopwords, and stem
        tokens_list = []
        for text in tqdm(texts, disable=not show_progress_bar):
            tokens = self._tokenizer(text.lower())
            tokens = [token for token in tokens if token not in self._stop_words]
            tokens = [self._stemmer(token) for token in tokens]
            tokens_list.append(tokens)
        return tokens_list

    def _query(self, encoded_query: List[str], n: int) -> List[int]:
        top_indices = np.argsort(self.index.get_scores(encoded_query))[::-1][:n].tolist()
        return top_indices

    def clear(self) -> None:
        super().clear()
        self.index = None
    
    def create_index(self, key_value_pairs: List[Tuple[str, Any]]) -> None:
        super().create_index(key_value_pairs)
        self.index = BM25Okapi(self.encoded_keys)
    
    def load(self, dir_name: str) -> None:
        super().load(dir_name)
        self._tokenizer = nltk.word_tokenize
        self._stop_words = set(nltk.corpus.stopwords.words('english'))
        self._stemmer = nltk.stem.PorterStemmer().stem
        return self

In [18]:
class E5(KVStore):
    def __init__(self, index_name: str, model_path: str = "intfloat/e5-large-v2"):
        super().__init__(index_name, 'e5')
        self.model_path = model_path
        self._model = sentence_transformers.SentenceTransformer(model_path, device="cuda", cache_folder=utils.get_cache_dir()).bfloat16()
    
    def _format_text(self, text: str, type: TextType) -> str:
        if type == TextType.KEY:
            text = "passage: " + text
        elif type == TextType.QUERY:
            text = "query: " + text
        else:
            raise ValueError("Invalid TextType")
        return text
    
    def _encode_batch(self, texts: List[str], type: TextType, show_progress_bar: bool = True) -> List[Any]:
        texts = [self._format_text(text, type) for text in texts]
        return self._model.encode(texts, batch_size=256, normalize_embeddings=True, show_progress_bar=show_progress_bar).astype(np.float16)
    
    def _query(self, encoded_query: Any, n: int) -> List[int]:
        cosine_similarities = cosine_similarity([encoded_query], self.encoded_keys)[0]
        top_indices = cosine_similarities.argsort()[-n:][::-1]
        return top_indices
    
    def load(self, path: str):
        super().load(path)
        self._model = sentence_transformers.SentenceTransformer(self.model_path, device="cuda", cache_folder=utils.get_cache_dir())
        return self

In [19]:
class GRIT(KVStore):
    def __init__(self, index_name: str, raw_instruction: str, model_path: str = "GritLM/GritLM-7B"):
        super().__init__(index_name, 'grit')
        self.model_path = model_path
        self.raw_instruction = raw_instruction
        self._model = GritLM(model_path, torch_dtype="auto", device_map="auto", mode="embedding")
    
    def _get_instruction(self, type: TextType) -> str:
        if type == TextType.KEY:
            return "<|embed|>\n"
        elif type == TextType.QUERY:
            return "<|user|>\n" + self.raw_instruction + "\n<|embed|>\n"
        else:
            raise ValueError("Invalid TextType")
    
    def _encode_batch(self, texts: List[str], type: TextType, show_progress_bar: bool = True) -> List[Any]:
        return self._model.encode(texts, batch_size=256, instruction=self._get_instruction(type), show_progress_bar=show_progress_bar).astype(np.float16)
    
    def _query(self, encoded_query: Any, n: int) -> List[int]:
        try:
            cosine_similarities = cosine_similarity([encoded_query], self.encoded_keys)[0]
        except:
            for i, encoded_key in enumerate(self.encoded_keys):
                if np.any(np.isnan(encoded_key)):
                    self.encoded_keys[i] = np.zeros_like(encoded_key)
            cosine_similarities = cosine_similarity([encoded_query], self.encoded_keys)[0]
        top_indices = cosine_similarities.argsort()[-n:][::-1]
        return top_indices
    
    def load(self, path: str):
        super().load(path)
        self._model = GritLM(self.model_path, torch_dtype="auto", device_map="auto", mode="embedding")
        return self

In [20]:
class Instructor(KVStore):
    def __init__(self, index_name: str, key_instruction: str, query_instruction: str, model_path: str = "hkunlp/instructor-xl"):
        super().__init__(index_name, 'instructor')
        self.model_path = model_path
        self.key_instruction = key_instruction
        self.query_instruction = query_instruction
        self._model = INSTRUCTOR(model_path, device="cuda", cache_folder=utils.get_cache_dir())
    
    def _format_text(self, text: str, type: TextType) -> List[str]:
        if type == TextType.KEY:
            return [self.key_instruction, text]
        elif type == TextType.QUERY:
            return [self.query_instruction, text]
        else:
            raise ValueError("Invalid TextType")
    
    def _encode_batch(self, texts: List[str], type: TextType, show_progress_bar: bool = True) -> List[Any]:
        texts = [self._format_text(text, type) for text in texts]
        return self._model.encode(texts, batch_size=128, normalize_embeddings=True, show_progress_bar=show_progress_bar).astype(np.float16)
    
    def _query(self, encoded_query: Any, n: int) -> List[int]:
        cosine_similarities = cosine_similarity([encoded_query], self.encoded_keys)[0]
        top_indices = cosine_similarities.argsort()[-n:][::-1]
        return top_indices
    
    def load(self, path: str):
        super().load(path)
        self._model = INSTRUCTOR(self.model_path, device="cuda", cache_folder=utils.get_cache_dir())
        return self
        

In [31]:
def evaluate_index(index_name: str, dataset_path: str, top_k: int = 200):
    query_set = [query for query in datasets.load_dataset(dataset_path, "query", split="full")]
    index = load_index(index_name)
    for query in tqdm(query_set):
        query_text = query["query"]
        top_k_results = index.query(query_text, top_k)
        query["retrieved"] = top_k_results
    return query_set


In [32]:
len(query_data)

597

In [33]:
# Test KVStore save method
kv_store = KVStore("test_index", "gtr")
kv_store.keys = ["Document 1", "Document 2"]
kv_store.values = [{"content": "This is the first document."}, {"content": "This is the second document."}]
kv_store.save("retrieval_indices")


Saving index to retrieval_indices/test_index.gtr


In [34]:
query_data['query_set'].nunique()

AttributeError: 'list' object has no attribute 'nunique'

In [35]:
query_data=query_data.to_pandas()

In [36]:
query_data['query_set'].nunique()

4