# Notes

Four things are needed to reproduce the reported numbers:
* use the same chunk format (`Title: <title>\nAbstract: <abstract>`)
* use the same instruction prompt for GritLM
* deduplicate the corpus, keeping the last entry of each group of identical chunks as the canonical one!
* use exact nearest-neighbor search (do not use ChromaDB).

In [18]:
from collections import defaultdict
import numpy as np
from tqdm.notebook import tqdm
from datasets import load_dataset

# Load the cleaned LitSearch corpus and render every paper in the chunk format
# used for retrieval: "Title: <title>\nAbstract: <abstract>".
corpus_clean_data = load_dataset('princeton-nlp/LitSearch', "corpus_clean", split="full")
corpus_clean_data_with_assembled_title_abstract = corpus_clean_data.map(
    lambda example: {'chunk': f"Title: {example['title']}\nAbstract: {example['abstract']}"}
)

# Group corpus ids by identical chunk text: papers whose rendered title+abstract
# coincide cannot be told apart by a retriever.
doc_2_ids = defaultdict(list)
for row in corpus_clean_data_with_assembled_title_abstract:
    chunk_text, corpus_id = row['chunk'], row['corpusid']
    doc_2_ids[chunk_text].append(corpus_id)

# The canonical id of each chunk is the LAST id seen for it; map that id to the
# number of corpus ids sharing the chunk.
multiplicity = {}
for ids in doc_2_ids.values():
    multiplicity[ids[-1]] = len(ids)

print(len(multiplicity))
print(max(multiplicity.values()))
print({len(ids) for ids in doc_2_ids.values()})
def recall(predictions, relevant_set, expand_multiplicity=False, multiplicity_map=None):
    """Return the recall curve of a ranked list of predictions.

    Element ``k-1`` of the result is recall@k: the fraction of ``relevant_set``
    found among the first ``k`` predictions. Each relevant id is counted at
    most once, so repeated ids in ``predictions`` (including the repeats
    produced by ``expand_multiplicity``) can never push recall above 1.

    Args:
        predictions: ranked list of retrieved corpus ids.
        relevant_set: collection of gold corpus ids; must be non-empty.
        expand_multiplicity: if True, repeat each prediction as many times as
            its multiplicity and truncate back to the original length, so a
            duplicated chunk occupies that many rank positions.
        multiplicity_map: id -> multiplicity mapping used when
            ``expand_multiplicity`` is True; defaults to the module-level
            ``multiplicity`` table.

    Returns:
        1-D float ``np.ndarray`` with one entry per (possibly expanded and
        truncated) prediction.

    Raises:
        KeyError: if expansion is requested and a prediction is missing from
            the multiplicity mapping.
        ZeroDivisionError: if ``relevant_set`` is empty and there is at least
            one prediction.
    """
    if expand_multiplicity:
        if multiplicity_map is None:
            multiplicity_map = multiplicity
        total = len(predictions)
        predictions = [
            pred for pred in predictions for _ in range(multiplicity_map[pred])
        ][:total]

    n_relevant = len(relevant_set)
    found = set()  # distinct relevant ids seen so far
    recall_scores = []
    for pred in predictions:
        if pred in relevant_set:
            found.add(pred)
        recall_scores.append(len(found) / n_relevant)
    return np.array(recall_scores)

def print_recall_scores(recall_scores, ks=None):
    """Print ``recall@k: <value>`` for each requested cutoff.

    Args:
        recall_scores: recall curve where index ``k-1`` holds recall@k (as
            returned by ``recall``, or its mean over queries).
        ks: 1-based cutoffs to report; defaults to every rank in the curve.

    A cutoff beyond the end of the curve reports the final value: with no
    further predictions, recall cannot increase past the last rank.

    Raises:
        ValueError: if any cutoff is < 1 (``k=0`` would otherwise index
            position -1 and silently print the last value).
    """
    n = len(recall_scores)
    ks = list(range(1, n + 1)) if ks is None else list(ks)
    bad = [k for k in ks if k < 1]
    if bad:
        raise ValueError(f"recall cutoffs must be >= 1, got {bad}")
    for k in ks:
        # Clamp to the last available rank (see docstring).
        print(f'recall@{k}: {recall_scores[min(k, n) - 1]}')

# Retrieval outputs to score, in report order: (header printed, JSONL path).
RETRIEVAL_RUNS = [
    ("BM25", 'LitSearch/results/retrieval/LitSearch.title_abstract.bm25.jsonl'),
    ("GRITLM", 'LitSearch/results/retrieval/LitSearch.title_abstract.grit.jsonl'),
]

# Query subsets reported for every run, in report order: (label, row predicate).
# Every subset drops rows with quality == 0; "specific" subsets also drop
# specificity == 0, "broad" subsets drop specificity == 1. "inline" / "manual"
# are substring tests on the row's query_set field ("author" = 'manual').
QUERY_SUBSETS = [
    ("overall",
     lambda row: row['quality'] != 0),
    ("inline specific",
     lambda row: (row['quality'] != 0 and row['specificity'] != 0
                  and 'inline' in row['query_set'])),
    ("inline broad",
     lambda row: (row['quality'] != 0 and row['specificity'] != 1
                  and 'inline' in row['query_set'])),
    ("author specific",
     lambda row: (row['quality'] != 0 and row['specificity'] != 0
                  and 'manual' in row['query_set'])),
    ("author broad",
     lambda row: (row['quality'] != 0 and row['specificity'] != 1
                  and 'manual' in row['query_set'])),
]

# Recall cutoffs printed for every subset.
REPORT_KS = [5, 20]


def _mean_recall_curve(rows, keep, label):
    """Return the per-rank mean recall curve over the rows accepted by ``keep``.

    Each accepted row's ``retrieved`` ids are scored against its ``corpusids``
    with ``recall``; the per-query curves are stacked and averaged per rank.

    Raises:
        ValueError: if ``keep`` accepts no rows (``np.vstack`` would otherwise
            fail with a message that does not name the subset).
    """
    curves = [
        recall(row['retrieved'], set(row['corpusids']))
        for row in tqdm(rows)
        if keep(row)
    ]
    if not curves:
        raise ValueError(f"no queries selected for subset {label!r}")
    return np.vstack(curves).mean(axis=0)


for header, path in RETRIEVAL_RUNS:
    retrieval_results = load_dataset('json', data_files=path, split='train')
    print(header)
    # Decode the dataset rows once; every subset below filters this list
    # instead of re-indexing the dataset per field access.
    rows = list(retrieval_results)
    for label, keep in QUERY_SUBSETS:
        recall_scores_all = _mean_recall_curve(rows, keep, label)
        print(label)
        print_recall_scores(recall_scores_all, ks=REPORT_KS)


57657
6197
{1, 2, 3, 4, 5, 6, 8, 6197}
BM25


  0%|          | 0/597 [00:00<?, ?it/s]

overall
recall@5: 0.4380792853154662
recall@20: 0.5793969849246231


  0%|          | 0/597 [00:00<?, ?it/s]

inline specific
recall@5: 0.3852813852813853
recall@20: 0.5584415584415584


  0%|          | 0/597 [00:00<?, ?it/s]

inline broad
recall@5: 0.22944444444444442
recall@20: 0.37416666666666665


  0%|          | 0/597 [00:00<?, ?it/s]

author specific
recall@5: 0.6255924170616114
recall@20: 0.7345971563981043


  0%|          | 0/597 [00:00<?, ?it/s]

author broad
recall@5: 0.37142857142857144
recall@20: 0.4857142857142857
GRITLM


  0%|          | 0/597 [00:00<?, ?it/s]

overall
recall@5: 0.6913456169737577
recall@20: 0.8001395868230039


  0%|          | 0/597 [00:00<?, ?it/s]

inline specific
recall@5: 0.6774891774891775
recall@20: 0.7792207792207793


  0%|          | 0/597 [00:00<?, ?it/s]

inline broad
recall@5: 0.5269444444444444
recall@20: 0.6973611111111111


  0%|          | 0/597 [00:00<?, ?it/s]

author specific
recall@5: 0.8246445497630331
recall@20: 0.8909952606635071


  0%|          | 0/597 [00:00<?, ?it/s]

author broad
recall@5: 0.5428571428571428
recall@20: 0.7428571428571429
