In [1]:
# Enable the IPython autoreload extension (mode 2) so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Step 1: Load the evaluation data source

In [2]:
from datasets import load_dataset

# Evaluation queries: the "query" configuration ("full" split) of princeton-nlp/LitSearch.
# Later cells read each row's 'query', 'corpusids', 'quality', 'specificity' and 'query_set' fields.
query_data = load_dataset("princeton-nlp/LitSearch", "query", split="full")

## Step 2: Load the Corpus and Build the index

### Step 2.1: Load the Corpus

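Note: The corpus contains duplicate papers, so we deduplicate on the assembled chunk text (`Title: ...` / `Abstract: ...`). Rows whose chunk text is identical are collapsed into a single entry.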

In [3]:
from datasets import Dataset


# Load the cleaned corpus and attach to every paper the text "chunk" that will be embedded.
corpus_clean_data = load_dataset("princeton-nlp/LitSearch", "corpus_clean", split="full")
# Alternative corpus configuration (unused here):
# corpus_s2orc_data = load_dataset("princeton-nlp/LitSearch", "corpus_s2orc", split="full")
corpus_clean_data_with_assembled_title_abstract = corpus_clean_data.map(
    lambda paper: {'chunk': f"Title: {paper['title']}\nAbstract: {paper['abstract']}"}
)

# Deduplicate on the chunk text. A dict keeps one entry per distinct chunk: the
# position of its first occurrence, holding the row of its last occurrence.
kv = {
    example['chunk']: example
    for example in corpus_clean_data_with_assembled_title_abstract
}

# Both names refer to the deduplicated dataset from here on.
corpus_clean_data_with_assembled_title_abstract = Dataset.from_list(list(kv.values()))
corpus_clean_data = corpus_clean_data_with_assembled_title_abstract
print(corpus_clean_data)

Map:   0%|          | 0/64183 [00:00<?, ? examples/s]

Dataset({
    features: ['corpusid', 'title', 'abstract', 'citations', 'full_paper', 'chunk'],
    num_rows: 57657
})


In [None]:
# corpus_clean_data.to_parquet('corpus_clean_dedup.parquet')

### Step 2.2: Load the Indexer

In [19]:
from gritlm import GritLM

# Load GritLM-7B in embedding-only mode: this notebook only needs embeddings, and
# `mode="embedding"` saves memory because the LM head is not loaded.
# To load the model for both capabilities (embedding and generation), drop `mode`:
# model = GritLM("GritLM/GritLM-7B", torch_dtype="auto", device_map="auto")
model = GritLM("GritLM/GritLM-7B", torch_dtype="auto", device_map="auto", mode='embedding')
# Loading the 8x7B model will likely require multiple GPUs.
# All kwargs are passed through to HF `from_pretrained`, so the call below loads it across multiple GPUs:
# model = GritLM("GritLM/GritLM-8x7B", torch_dtype="auto", device_map="auto")
# Other models can be loaded as well, e.g.:
# model = GritLM("Muennighoff/SGPT-125M-weightedmean-nli-bitfit", pooling_method="weighted_mean", attn=None)
# model = GritLM("hkunlp/instructor-base", pooling_method="mean", attn=None)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Created GritLM: torch.bfloat16 dtype, mean pool, embedding mode, bbcc attn


### Step 2.3: Build the Index

In [20]:
import chromadb
import torch
import chromadb.utils.embedding_functions as embedding_functions
from chromadb import Documents, EmbeddingFunction, Embeddings
import pickle
import os
from tqdm.notebook import tqdm
import math

# Chroma client shared by the collection helpers below. It is created with './chroma'
# (presumably the on-disk storage directory; confirm against the chromadb docs).
chroma_client = chromadb.PersistentClient('./chroma')

## Embedding/Representation ###
# Task instruction prepended to queries by _gritlm_encode_queries; documents are
# embedded with an empty instruction (see _gritlm_encode_documents).
instruction = "Given a research query, retrieve the title and abstract of the relevant research paper"

def gritlm_instruction(instruction):
    """Build the prompt prefix passed to the model as its `instruction`.

    A truthy `instruction` is wrapped as a user turn followed by the embed
    marker; a falsy one (empty string, None) yields the bare embed marker.
    """
    embed_marker = "<|embed|>\n"
    if not instruction:
        return embed_marker
    return "<|user|>\n" + instruction + "\n" + embed_marker

def _gritlm_encode_queries(queries: list[str]):
    """Embed `queries` with the global model, prefixed by the notebook-level retrieval `instruction`."""
    query_prompt = gritlm_instruction(instruction)
    return model.encode(queries, instruction=query_prompt)

def _gritlm_encode_documents(documents: list[str]):
    """Embed `documents` with the global model using the instruction-free embed prompt."""
    document_prompt = gritlm_instruction("")
    return model.encode(documents, instruction=document_prompt)

def write_cache(cache_dir, start_idx, end_idx, embeddings):
    """Pickle one batch of embeddings to `<cache_dir>/<start_idx>_to_<end_idx>.pkl`.

    The directory is created if needed. The data is first written to a
    temporary sibling file and then moved into place, so an interrupted or
    failed write never leaves a truncated cache file that `load_cache` would
    later pick up as if it were valid.

    Args:
        cache_dir: Directory holding the cache files.
        start_idx: Start offset of the batch (used in the file name).
        end_idx: End offset of the batch (used in the file name).
        embeddings: Picklable object to store.

    Raises:
        Whatever `pickle.dump` or the file system raises; in that case no
        cache file is created or modified for this batch.
    """
    os.makedirs(cache_dir, exist_ok=True)
    cachefile = os.path.join(cache_dir, f'{start_idx}_to_{end_idx}.pkl')
    tmpfile = f'{cachefile}.tmp'
    try:
        with open(tmpfile, 'wb') as f:
            pickle.dump(embeddings, f)
        # Move into place only after the dump completed successfully.
        os.replace(tmpfile, cachefile)
    finally:
        # After a successful replace the temp file is gone; this only cleans up failures.
        if os.path.exists(tmpfile):
            os.remove(tmpfile)

def load_cache(cache_dir, start_idx, end_idx):
    """Return the object cached for the batch `[start_idx, end_idx)`, or None on a cache miss.

    The file looked up is `<cache_dir>/<start_idx>_to_<end_idx>.pkl`.
    """
    cachefile = os.path.join(cache_dir, f'{start_idx}_to_{end_idx}.pkl')
    if not os.path.exists(cachefile):
        return None
    # NOTE: unpickling can execute arbitrary code; only point this at cache
    # directories written by this notebook.
    with open(cachefile, 'rb') as f:
        return pickle.load(f)

def gritlm_encode(docs: list[str], encoding_fn, batch_size=256, cache_dir=None):
    """Encode `docs` batch by batch with `encoding_fn`, optionally caching each batch on disk.

    Args:
        docs: Texts to encode.
        encoding_fn: Callable that takes a list of texts and returns their embeddings.
        batch_size: Number of texts handed to `encoding_fn` per call.
        cache_dir: If truthy, each batch is looked up in / written to this
            directory, keyed only by its start and end offsets. The cache does
            not know about the texts themselves, so use a new directory
            whenever `docs` or `batch_size` change.

    Returns:
        A flat list built by extending with every batch's embeddings, in input order.
    """
    all_embeddings = []
    for start in tqdm(range(0, len(docs), batch_size), total=math.ceil(len(docs) / batch_size)):
        stop = min(len(docs), start + batch_size)
        batch = docs[start:stop]
        cached = load_cache(cache_dir, start, stop) if cache_dir else None
        if cached is not None:
            batch_embeddings = cached
        else:
            batch_embeddings = encoding_fn(batch)
            if cache_dir:
                write_cache(cache_dir, start, stop, batch_embeddings)
        all_embeddings.extend(batch_embeddings)
    return all_embeddings

def gritlm_encode_queries(queries: list[str], batch_size=256, cache_dir=None):
    """Batch-encode `queries` with the query instruction; see `gritlm_encode` for batching and caching."""
    return gritlm_encode(
        queries,
        _gritlm_encode_queries,
        batch_size=batch_size,
        cache_dir=cache_dir,
    )

def gritlm_encode_documents(documents: list[str], batch_size=256, cache_dir=None):
    """Batch-encode `documents` without an instruction; see `gritlm_encode` for batching and caching."""
    return gritlm_encode(
        documents,
        _gritlm_encode_documents,
        batch_size=batch_size,
        cache_dir=cache_dir,
    )

# Name of the Chroma collection that holds the LitSearch corpus.
index_name = "litsearch_corpus"
def gritlm_build_index(documents: list[str], index_name: str, embeddings=None, batch_size=4096):
    """Create a new Chroma collection and fill it with `documents` and their embeddings.

    Args:
        documents: Texts to index. Their positions (as strings) become the ids.
        index_name: Name of the collection. It is created with
            `get_or_create=False`, i.e. it is not expected to exist yet.
        embeddings: Optional precomputed embeddings, one per document. When
            None they are computed with `gritlm_encode_documents`.
        batch_size: Number of documents added to the collection per call.

    Returns:
        The populated collection.

    Raises:
        ValueError: If `embeddings` is given but its length differs from
            `len(documents)`. This is checked before the collection is
            created so a mismatch cannot leave a partly filled collection.
    """
    if embeddings is not None and len(embeddings) != len(documents):
        raise ValueError(
            f"got {len(embeddings)} embeddings for {len(documents)} documents"
        )
    collection = chroma_client.create_collection(index_name, get_or_create=False)
    if embeddings is None:
        embeddings = gritlm_encode_documents(documents)
    id_list = [str(i) for i in range(len(documents))]
    for start in tqdm(range(0, len(documents), batch_size)):
        stop = start + batch_size
        collection.add(
            ids=id_list[start:stop],
            documents=documents[start:stop],
            embeddings=embeddings[start:stop],
        )
    return collection

def delete_index(index_name:str):
    """Delete the Chroma collection called `index_name` and return the client's result."""
    deletion_result = chroma_client.delete_collection(index_name)
    return deletion_result

In [21]:
import numpy as np


# Interactive guard before the (expensive) embedding pass. Batches already present
# in the cache directory are reused, so switch to a new cache dir to force a rebuild.
input("going to run indexing again, hit anything to continue... if need new index change cache dir")
documents = corpus_clean_data_with_assembled_title_abstract['chunk'][:]
embeddings = np.array(
    gritlm_encode_documents(documents, cache_dir='./embeddings/litsearch')
)
# No normalization here: the vectors are expected to already have unit norm.

  0%|          | 0/226 [00:00<?, ?it/s]

In [22]:
# build a faiss index
import faiss
import numpy as np

# Inner-product FAISS index (IndexFlatIP) over all document embeddings.
d = embeddings.shape[-1]  # embedding dimensionality
index = faiss.IndexFlatIP(d)
# Normalization is skipped because the embeddings are already unit-norm; the
# print below shows the norm of the first vector as a sanity check.
# embeddings = embeddings / np.linalg.norm(embeddings, axis=-1, keepdims=True)
print(embeddings.shape, np.linalg.norm(embeddings[0]))
index.add(embeddings)
print(index.ntotal)

def faiss_retrieve(query_embeddings, index, k=1):
    """Search `index` for the `k` nearest neighbours of every query embedding.

    Returns a dict of nested lists: 'distances' (the scores returned by
    `index.search`), 'ids' (row positions in the global `embeddings` array) and
    'embeddings' (the stored vectors of those rows).
    """
    scores, neighbor_ids = index.search(query_embeddings, k)
    # Indexing with the id matrix gathers the vectors into an array of shape
    # neighbor_ids.shape + (dim,).
    neighbor_vectors = embeddings[neighbor_ids]
    return {
        'distances': scores.tolist(),
        'ids': neighbor_ids.tolist(),
        'embeddings': neighbor_vectors.tolist(),
    }

(57657, 4096) 1.0
57657


In [26]:
import os

# faiss.write_index does not create missing directories, so make sure the
# target folder exists before saving (otherwise a fresh checkout fails here).
os.makedirs('faiss', exist_ok=True)
faiss.write_index(index, 'faiss/litsearch.index')
# Read the file back to confirm it round-trips; the loaded index is the cell output.
faiss.read_index('faiss/litsearch.index')

<faiss.swigfaiss_avx2.IndexFlatIP; proxy of <Swig Object of type 'faiss::IndexFlatIP *' at 0x7fb786b9ff60> >

In [None]:
# Build (or load) the Chroma collection.
# NOTE: this rebinds `index`, replacing the FAISS index created earlier in the
# notebook; gritlm_retrieve dispatches on whichever type `index` currently holds.
build=True
if build:
    input("going to overwrite existing index, hit anything to continue...")
    try:
        delete_index(index_name)
    except Exception:
        # Best effort: the collection may simply not exist yet. Catching
        # Exception (not a bare except) lets Ctrl-C / kernel interrupts through.
        pass
    index = gritlm_build_index(documents, index_name, embeddings=embeddings)
else:
    index = chroma_client.get_collection(index_name)

## Step 3: Make Predictions

In [23]:
from chromadb import Documents, EmbeddingFunction, Embeddings, Collection
from datasets import Dataset
from tqdm.notebook import tqdm
from faiss import Index
# Fields requested from Chroma queries and merged across per-query results.
result_fields = ('metadatas', 'documents', 'distances', 'embeddings')
def gritlm_retrieve(queries, index: Collection | Index, k=5):
    """Embed `queries` and return the top-`k` hits from `index`.

    Args:
        queries: List of query strings.
        index: Either a Chroma `Collection` or a FAISS `Index`.
        k: Number of results per query.

    Returns:
        For a Chroma collection, the result of `index.query` with
        `result_fields` included; for a FAISS index, the dict produced by
        `faiss_retrieve`.

    Raises:
        TypeError: If `index` is neither a Chroma collection nor a FAISS
            index. This is checked before the queries are encoded, instead
            of failing afterwards on an unbound result variable.
    """
    if not isinstance(index, (Collection, Index)):
        raise TypeError(
            f"index must be a chromadb Collection or a faiss Index, got {type(index).__name__}"
        )
    query_embeddings = _gritlm_encode_queries(queries)
    if isinstance(index, Collection):
        return index.query(query_embeddings=query_embeddings, n_results=k, include=list(result_fields))
    return faiss_retrieve(query_embeddings, index, k=k)

def merge_results(batched_results, result_fields=result_fields):
    """Concatenate per-call retrieval results into one dict of flat lists.

    Args:
        batched_results: List of result dicts; each value is a list with one
            entry per query of that call.
        result_fields: Field names to merge ('ids' is always merged as well).
            Any iterable of names is accepted.

    Returns:
        A dict with, for every requested field present in the first result,
        the entries of all results concatenated in order. An empty
        `batched_results` yields an empty dict.
    """
    if not batched_results:
        return {}
    out = {}
    for field in tuple(result_fields) + ('ids',):
        if field in batched_results[0]:  # only include existing keys
            # Linear-time flatten (sum(lists, []) re-copies the accumulator on every step).
            out[field] = [entry for res in batched_results for entry in res[field]]
    return out

# Retrieve the top 20 corpus entries for every query, one query per call.
results = [
    gritlm_retrieve(queries=[query_data[i]['query']], index=index, k=20)
    for i in tqdm(range(len(query_data)))
]

# Flatten the per-query dicts so that row i of the saved table belongs to query i.
results = merge_results(results)
Dataset.from_dict(results).to_parquet('retrieval_results/litsearch.parquet')

  0%|          | 0/597 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

391495884

## Step 4: Evaluate Predictions

In [None]:
from collections import defaultdict
import numpy as np
from tqdm.notebook import tqdm

def map_str_ids_to_corpus_ids(ids, source=corpus_clean_data):
    """Translate index ids (row positions, possibly given as strings) into the 'corpusid' of each row in `source`."""
    return [source[position]['corpusid'] for position in map(int, ids)]

def recall(predictions, relevant_set):
    """Compute recall@k for every prefix of a ranked prediction list.

    Args:
        predictions: Ranked list of predicted ids (best first).
        relevant_set: Collection of relevant ids.

    Returns:
        A numpy array whose entry k-1 is the fraction of `relevant_set`
        found among the first k predictions. A relevant id that is predicted
        more than once is counted only once, so values never exceed 1.

    Raises:
        ZeroDivisionError: If `relevant_set` is empty and `predictions` is not.
    """
    found = set()
    recall_scores = []
    for pred in predictions:
        if pred in relevant_set:
            found.add(pred)
        recall_scores.append(len(found) / len(relevant_set))
    return np.array(recall_scores)

def print_recall_scores(recall_scores, ks=None):
    """Print one `recall@k: value` line per cutoff.

    `recall_scores[k-1]` is taken as recall@k. When `ks` is None every cutoff
    from 1 to `len(recall_scores)` is printed.
    """
    cutoffs = ks if ks is not None else range(1, len(recall_scores) + 1)
    for k in cutoffs:
        print(f'recall@{k}: {recall_scores[k-1]}')

In [13]:

retrieval_results = load_dataset('parquet', data_files='retrieval_results/litsearch.parquet', split='train')
assert len(retrieval_results) == len(query_data)


def _mean_recall_curve(keep_query):
    """Average the per-query recall curves over the queries selected by `keep_query`.

    `keep_query` receives a row of `query_data` and returns True when the
    query belongs to the subset. Row i of `retrieval_results` is taken to be
    the retrieval result for query i.
    """
    per_query_curves = []
    for i in tqdm(range(len(retrieval_results))):
        query = query_data[i]
        if not keep_query(query):
            continue
        corpus_ids = map_str_ids_to_corpus_ids(retrieval_results[i]['ids'])
        per_query_curves.append(recall(corpus_ids, set(query['corpusids'])))
    return np.vstack(per_query_curves).mean(axis=0)


# (label, selector) for every reported subset. Only "overall" filters on
# quality; the two "specific" subsets select on specificity and query_set,
# exactly as the original three copies of this loop did.
_query_subsets = (
    ("overall", lambda q: q['quality'] != 0),
    ("inline specific", lambda q: q['specificity'] == 1 and "inline" in q['query_set']),
    ("author specific", lambda q: q['specificity'] == 1 and "manual" in q['query_set']),
)

for subset_label, keep_query in _query_subsets:
    recall_scores_all = _mean_recall_curve(keep_query)
    print(subset_label)
    print_recall_scores(recall_scores_all, ks=[5,20])

  0%|          | 0/597 [00:00<?, ?it/s]

overall
recall@5: 0.6913456169737577
recall@20: 0.8001395868230039


  0%|          | 0/597 [00:00<?, ?it/s]

inline specific
recall@5: 0.6774891774891775
recall@20: 0.7792207792207793


  0%|          | 0/597 [00:00<?, ?it/s]

author specific
recall@5: 0.8246445497630331
recall@20: 0.8909952606635071


# Augment the query data with its GritLM recall@20

In [8]:
from tqdm import tqdm 
retrieval_results = load_dataset('parquet', data_files='retrieval_results/litsearch.parquet', split='train')
# One recall curve per query with quality != 0, stacked into a 2-D array.
# Skipped queries get no row, so row numbers are NOT query indices.
recall_scores_all = np.vstack([
    recall(
        map_str_ids_to_corpus_ids(retrieval_results[i]['ids']),
        set(query_data[i]['corpusids']),
    )
    for i in tqdm(range(len(retrieval_results)))
    if query_data[i]['quality'] != 0
])

100%|██████████| 597/597 [00:07<00:00, 75.66it/s]


In [11]:
# recall_scores_all has one row per query with quality != 0 (queries with
# quality 0 were skipped when it was built), so its row numbers are not query
# indices. Map every kept query index to its row before attaching the scores.
GRIT_RECALL_K = 20  # attach recall@20, i.e. column K-1 of the recall curve
scored_query_indices = [
    query_idx
    for query_idx, quality in enumerate(query_data['quality'])
    if quality != 0
]
assert len(scored_query_indices) == len(recall_scores_all), (
    "recall_scores_all is expected to hold one row per query with quality != 0"
)
row_of_query = {query_idx: row for row, query_idx in enumerate(scored_query_indices)}


def _grit_recall_for(example, query_idx):
    """Return the recall@K column for one query; NaN for queries that were not scored."""
    row = row_of_query.get(query_idx)
    if row is None:
        return {'grit_recall': float('nan')}
    return {'grit_recall': float(recall_scores_all[row, GRIT_RECALL_K - 1])}


query_data_with_score = query_data.map(_grit_recall_for, with_indices=True)
query_data_with_score.to_parquet('query_with_score.parquet')

Map:   0%|          | 0/597 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

117734