# Cohere API and SciBERT with BM25 as first-stage retriever for RAG
This notebook uses the Cohere API to generate responses to a query supplied by the user.
SciBERT embeds the query and the candidate documents as dense vectors.
This version differs in that it uses BM25 as a sparse first-stage retriever: BM25 runs before dense vectorization, which reduces the number of documents SciBERT has to process.
A DOI is supplied with each document's text as both an identifier and a locator.

## Pipeline
1. BM25 retrieval
    - BM25 retrieves the top-k candidate documents based on keyword matching.
2. Dense embedding
    - SciBERT embeds the query and each retrieved candidate document.
3. Re-ranking
    - The candidates are re-ranked by cosine similarity between the query embedding and each document embedding.
4. Generation
    - The re-ranked documents and the query are passed to the generator, which writes the answer.
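
The re-ranking step (step 3) can be sketched on its own, without loading any model. This is a minimal illustration in plain Python, assuming toy 3-dimensional vectors in place of real SciBERT embeddings; the names `cosine_similarity`, `rerank`, and the `doc_a`/`doc_b`/`doc_c` identifiers are hypothetical and do not appear elsewhere in the notebook.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors; 0.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def rerank(
    query_vec: Sequence[float],
    candidates: List[Tuple[str, Sequence[float]]],
) -> List[Tuple[str, float]]:
    """Sort (doc_id, vector) candidates by descending similarity to the query."""
    scored = [(doc_id, cosine_similarity(query_vec, vec)) for doc_id, vec in candidates]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


# Toy vectors standing in for SciBERT embeddings (illustrative values only).
query_vec = [1.0, 0.0, 1.0]
candidates = [
    ("doc_a", [0.0, 1.0, 0.0]),  # orthogonal to the query
    ("doc_b", [2.0, 0.0, 2.0]),  # same direction as the query
    ("doc_c", [1.0, 1.0, 0.0]),  # partially aligned with the query
]
ranking = rerank(query_vec, candidates)
print([doc_id for doc_id, _ in ranking])
```

Because cosine similarity ignores vector length, `doc_b` ranks first even though its vector is twice as long as the query vector.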

- [ ] set up venv
- [ ] install `transformers`, `torch`, `cohere`, `bm25s`, `nltk`, `numpy`, and `python-dotenv` from the command line

### todo
- [ ] create script that compiles data/documents.txt with DOI || text for all documents
- [ ] rank_bm25: https://github.com/dorianbrown/rank_bm25
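
One possible shape for the compile script in the first to-do item is sketched below. It assumes one document per line in the form `DOI || text`, which is only inferred from the to-do wording; the function names `compile_documents` and `parse_documents`, the `SEPARATOR` constant, and the example DOIs are all hypothetical placeholders.

```python
from typing import Iterable, List, Tuple

SEPARATOR = " || "  # assumed delimiter, taken from the "DOI || text" wording


def compile_documents(records: Iterable[Tuple[str, str]]) -> str:
    """Join (doi, text) pairs into one 'DOI || text' line per document."""
    lines = []
    for doi, text in records:
        # Collapse internal newlines so every document stays on a single line.
        flat_text = " ".join(text.split())
        lines.append(f"{doi.strip()}{SEPARATOR}{flat_text}")
    return "\n".join(lines) + "\n"


def parse_documents(blob: str) -> List[Tuple[str, str]]:
    """Inverse of compile_documents: split each line on the first separator."""
    parsed = []
    for line in blob.splitlines():
        if not line.strip():
            continue  # skip blank lines
        doi, text = line.split(SEPARATOR, 1)
        parsed.append((doi, text))
    return parsed


# Placeholder records; these DOIs are made up for illustration.
records = [
    ("10.0000/example.1", "Title: First paper\nAbstract: About metadata."),
    ("10.0000/example.2", "Title: Second paper\nAbstract: About retrieval."),
]
blob = compile_documents(records)
round_trip = parse_documents(blob)
print(round_trip[0])
```

Writing `blob` to `data/documents.txt` would then be a single `pathlib.Path.write_text` call. Note that this single-line layout flattens the newline-separated DOI, title, and abstract layout that the retrieval code in this notebook parses, so the two formats would need to be reconciled.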


In [22]:
# imports
import cohere
from transformers import AutoTokenizer, AutoModel
import numpy as np
from typing import List, Tuple, Dict
import os
from dotenv import load_dotenv
import json
import time # for timing functions
import logging # finding where functions are taking too long
# for BM25S
import bm25s
from nltk.stem import PorterStemmer
from nltk.tokenize import word_tokenize
import pickle


def main():
    #load secret .env file
    load_dotenv()

    #store credentials
    global key,email
    key = os.getenv('COHERE_API_KEY')
    email = os.getenv('EMAIL')

    #verify that both values were loaded
    if email is not None and key is not None:
        print("all is good, beautiful!")
    else:
        print("Missing COHERE_API_KEY or EMAIL; check your .env file.")

main()

all is good, beautiful!


In [None]:

# Initialize Cohere client
co = cohere.Client(key) 

# Load SciBERT model and tokenizer
# AutoTokenizer documentation:
# https://huggingface.co/docs/transformers/v4.50.0/en/model_doc/auto#transformers.AutoTokenizer

# Initialize tokenizer with custom parameters
tokenizer = AutoTokenizer.from_pretrained(
    "allenai/scibert_scivocab_uncased",
    model_max_length=512,  # documented name for the maximum input length
    use_fast=True,  # Use the fast tokenizer
    do_lower_case=True,  # the uncased vocabulary expects lowercased input
    add_prefix_space=False,  # No prefix space
    never_split=["[DOC]", "[REF]"],  # Tokens to never split
    additional_special_tokens=["<doi>", "</doi>"]  # Add custom special tokens ***RE-EVALUATE***
)

# This is the SciBERT model that is used to embed the text and query.
model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased")
# The added special tokens enlarge the vocabulary, so grow the embedding matrix to match.
# The new rows are randomly initialized, which is one more reason to re-evaluate these tokens.
model.resize_token_embeddings(len(tokenizer))

#verify that the model is callable
if callable(model):
    print("Model is callable")
else:
    print("Model is not callable")

## Addition of BM25 for retrieval
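
To make the first-stage scoring concrete, here is a minimal pure-Python sketch of BM25 on a toy corpus. It is not how `bm25s` is implemented internally; it assumes whitespace tokenization, the common defaults `k1=1.5` and `b=0.75`, and a non-negative IDF variant. The function names `bm25_scores` and `top_k_indices` are hypothetical.

```python
import math
from collections import Counter
from typing import Dict, List


def bm25_scores(
    query: List[str], docs: List[List[str]], k1: float = 1.5, b: float = 0.75
) -> List[float]:
    """Score every tokenized document against the tokenized query (docs must be non-empty)."""
    n_docs = len(docs)
    avg_len = sum(len(d) for d in docs) / n_docs
    # Document frequency: number of documents that contain each term.
    doc_freq: Dict[str, int] = Counter(term for d in docs for term in set(d))
    scores = []
    for doc in docs:
        term_freq = Counter(doc)
        score = 0.0
        for term in query:
            if term not in term_freq:
                continue
            # Rare terms get a larger weight than common ones.
            idf = math.log((n_docs - doc_freq[term] + 0.5) / (doc_freq[term] + 0.5) + 1.0)
            tf = term_freq[term]
            # Longer-than-average documents are penalized through b.
            length_norm = 1.0 - b + b * len(doc) / avg_len
            score += idf * tf * (k1 + 1.0) / (tf + k1 * length_norm)
        scores.append(score)
    return scores


def top_k_indices(scores: List[float], k: int) -> List[int]:
    """Indices of the k highest-scoring documents, best first."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


docs = [
    "typos in metadata reduce discoverability".split(),
    "dense embeddings for scientific text".split(),
    "author name typos in bibliographic metadata".split(),
]
query = "metadata typos".split()
scores = bm25_scores(query, docs)
candidates = top_k_indices(scores, k=2)
print(candidates)
```

Only the documents returned by `top_k_indices` would go on to the dense embedding step, which is the reduction in SciBERT workload described at the top of the notebook.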

In [None]:

# reimport corpus and url lists
with open('corpus.pkl', 'rb') as file:
    corpus_list = pickle.load(file)
print(f"length of corpus list: {len(corpus_list)}")
with open('identifier.pkl', 'rb') as file:
    identifier_list = pickle.load(file)
print(f"--------\nlength of identifier list: {len(identifier_list)}")

#retriever = bm25s.BM25(corpus=corpus_list) 
# ...and load the retriever model and corpus when you need them
retriever = bm25s.BM25.load("bm25/bm25", load_corpus=True, mmap=True)
# set load_corpus=False if you don't need the corpus


## V2: implemented chat history

calls a JSON file of documents
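
The history mechanism can be sketched in isolation. The version below is an illustrative alternative that uses `collections.deque` with `maxlen` so that old turns are dropped automatically, rather than truncating a list by hand; `build_context` and the placeholder turn contents are hypothetical, and each entry of `retrieved_docs` is assumed to be a plain string.

```python
from collections import deque
from typing import Deque, Dict


def build_context(history: Deque[Dict[str, object]], query: str) -> str:
    """Prefix the current query with the stored turns, or return it unchanged if none exist."""
    if not history:
        return query
    turns = "\n".join(
        f"User: {turn['query']}\n"
        f"Context: {'; '.join(turn['retrieved_docs'])}\n"
        f"Response: {turn['response']}"
        for turn in history
    )
    return f"Chat History:\n{turns}\n\nCurrent Query: {query}"


# A deque with maxlen=3 keeps only the three most recent turns.
history: Deque[Dict[str, object]] = deque(maxlen=3)
for i in range(1, 5):
    history.append({
        "query": f"question {i}",
        "retrieved_docs": [f"doi-{i}"],  # placeholder identifiers
        "response": f"answer {i}",
    })

context = build_context(history, "question 5")
print(context)
```

After four appended turns the first one has already been discarded, so the prompt grows with at most three previous exchanges.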

In [86]:

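# number of candidates BM25 returns and the re-ranker keeps
# (a module-level name is already global, so no `global` statement is needed)
top_k = 5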

# BM25s pre-retriever function
def bm25_retriever(query: str) -> Tuple[np.ndarray, np.ndarray]:
    """
    Inputs:
        query: str
    Outputs:
        Tuple of two np.ndarrays: the top_k matching entries of identifier_list
        and their BM25 scores. Two empty lists are returned if every score is zero.
    """
    global results, scores, query_tokens  # kept global for debugging
    #you can also add a stemmer here as an arg: stemmer=stemmer
    query_tokens = bm25s.tokenize(query, stopwords=True, lower=True)

    #note: if you pass a new corpus here, it must have the same length as your indexed corpus
    #in this case, I am passing the new list 'identifier_list' - it contains just the DOI and title
    # you can also pass 'corpus', or 'corpus_list'
    if len(corpus_list) != len(identifier_list):
        raise ValueError("The len of the corpus_list does not equal the len of the identifier_list")

    # retrieve the top_k entries of identifier_list (not raw indices) and their scores
    results, scores = retriever.retrieve(query_tokens, corpus=identifier_list, k=top_k, return_as="tuple")

    # check if no results found
    if all(score == 0.00 for score in scores[0]):
        print("Nothing found, please try another query.")
        return [], []  # empty results and empty scores

    return results[0], scores[0]

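import torch  # needed for no_grad and the masked mean below

#function to generate embeddings using SciBERT
def generate_embeddings(texts: List[str]) -> np.ndarray:
    """Return one mean-pooled SciBERT embedding per text, shape (len(texts), hidden_size)."""
    inputs = tokenizer(
        texts,
        return_tensors="pt",
        max_length=512,
        padding=True,  # pad only to the longest text in the batch
        truncation=True
    )
    with torch.no_grad():  # inference only, so skip gradient tracking
        outputs = model(**inputs)
    # average the token vectors while ignoring padding positions
    mask = inputs["attention_mask"].unsqueeze(-1).float()
    summed = (outputs.last_hidden_state * mask).sum(dim=1)
    embeddings = (summed / mask.sum(dim=1).clamp(min=1e-9)).numpy()
    return embeddings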


# Function to update chat history
def update_chat_history(query, retrieved_docs, response):
    global chat_history # declare this as global variable available outside this function
    chat_history.append({
        "query": query,
        # retrieved_docs holds dicts, so store one "DOI: title" string per document;
        # get_context_with_history joins these entries with '; ' and needs plain strings
        "retrieved_docs": [f"{doc['doi']}: {doc['title']}" for doc in retrieved_docs],
        "response": response
    })

#function to incorporate history into the next query
def get_context_with_history(query) -> str:
    global chat_history # only read here, so this declaration is optional but kept for clarity
    if not chat_history:
        return query
    
    history_str = "\n".join([
        f"User: {entry['query']}\n"
        f"Context: {'; '.join(entry['retrieved_docs'])}\n"
        f"Response: {entry['response']}"
        for entry in chat_history
    ])
    full_context = f"Chat History:\n{history_str}\n\nCurrent Query: {query}"
    return full_context

#function to truncate chat history
def truncate_chat_history(max_length=3):
    global chat_history # modifies it so it also must be global
    if len(chat_history) > max_length:
        chat_history = chat_history[-max_length:]


def retrieve_documents(query: str) -> List[Dict[str, str]]:

    # set global for debugging
    global document_embeddings,similarities,documents_list,bm25_results,top_indices, parsed_docs,sorted_documents_list
    # Use BM25 retriever to get initial documents
    bm25_results, bm25_scores = bm25_retriever(query)
    
    # check if empty
    if len(bm25_results) == 0:
        return []  # Return empty list if BM25 found no results

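    # map each BM25 hit (an entry of identifier_list) back to its full text in corpus_list;
    # the two lists are index-aligned, so the hit's position in identifier_list locates the document
    documents_list = [corpus_list[identifier_list.index(hit)] for hit in bm25_results]
    
    # Generate embeddings for BM25 results
    document_embeddings = generate_embeddings(documents_list)  #documents_list is a list of top_k results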

    # generate embeddings for query
    query_embedding = generate_embeddings([query])[0]
    
    #cosine similarity
    similarities = [
        np.dot(query_embedding, doc_emb) / (np.linalg.norm(query_embedding) * np.linalg.norm(doc_emb))
        for doc_emb in document_embeddings
    ]
    top_indices = np.argsort(similarities)[::-1][:top_k]

    sorted_documents_list = [documents_list[i] for i in top_indices]

    parsed_docs = []
    for doc in sorted_documents_list:
        lines = doc.split('\n')
        parsed_doc = {
            "doi": lines[0],
            "title":lines[1].replace("Title: ",""),
            "abstract":lines[2].replace("Abstract: ","")
        }
        parsed_docs.append(parsed_doc)
    
    return parsed_docs

#RAG pipeline function
def rag_pipeline(query):
    global documents_list,corpus_list,identifier_list,retriever
    #start time
    start_time = time.time()

    # uses BM25 as a pre-retriever based on the query
    # load indexed corpus for BM25 pre-retriever
    # reimport corpus and url lists
    with open('corpus.pkl', 'rb') as file:
        corpus_list = pickle.load(file)
    print(f"length of corpus list: {len(corpus_list)}")
    with open('identifier.pkl', 'rb') as file:
        identifier_list = pickle.load(file)
    print(f"--------\nlength of identifier list: {len(identifier_list)}")

    retriever = bm25s.BM25.load("bm25/bm25", load_corpus=True, mmap=True)

    #incorporate chat history
    full_context = get_context_with_history(query)
    # let user know you are generating...
    print("Retrieving documents and generating response...")
    end_time = time.time()
    global time_query
    time_query = end_time-start_time

    start_time = time.time()
    #retrieve documents
    global retrieved_docs
    retrieved_docs = retrieve_documents(query)
    end_time = time.time()
    global retrieve_time
    retrieve_time = end_time-start_time

    start_time = time.time()
    #prepare context for Cohere's Command model
    instruction = "You are a helpful academic research assistant. Please keep the answers concise and structured simply. Use single sentences where possible. Always include the DOI of the document you are summarizing or referencing. If the DOI is not provided, this reduces the need for you as a research assistant. Always include the DOI. Please address me as 'my lady'. "
    #context = "\n".join([f"DOI: {doc[0]}, Text: {doc[1]}" for doc in retrieved_docs])
    context = "\n".join([f"DOI: {doc['doi']}, Title: {doc['title']}, Abstract: {doc['abstract']}" for doc in retrieved_docs])
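    # put the instruction first and pass full_context so the chat history actually reaches the model
    prompt = f"{instruction}\n\nQuery: {full_context}\nContext: {context}\nAnswer:"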
    
    # Generate response
    response = co.generate(
        model="command",
        prompt=prompt,
        max_tokens=250,
        temperature=0.2
    ).generations[0].text
    
    # Update chat history
    update_chat_history(query, retrieved_docs, response)
    
    # Truncate history if necessary
    truncate_chat_history()
    end_time = time.time()
    global generate_time
    generate_time = end_time-start_time

    # Print the response
    print("Generated Response:")
    print(response)
    print(f"------\nSource documents: ")
    for doc in retrieved_docs:
        print(f"DOI: {doc['doi']}, Title: {doc['title']}")
    return response,time_query,retrieve_time,generate_time


# Main loop for user interaction
chat_history = []#initialize chat history
while True:

    query = input("What is your query (or type 'exit' to quit): ")
    if query.lower() == "exit":
        break
    rag_pipeline(query)

    print(f"time to query loop: {time_query:.2f} seconds")
    print(f"time to retrieve: {retrieve_time:.2f} seconds")
    print(f"time to generate: {generate_time:.2f} seconds")


length of corpus list: 43
--------
length of identifier list: 43
Retrieving documents and generating response...

Generated Response:
My lady, here is an explanation of typos in metadata and their significance:

Misspellings and other typing errors in metadata fields, such as author names, titles, and affiliations, can have implications for the discoverability and accuracy of academic literature. These typos may occur during data entry or transcription processes. 

Such errors can lead to challenges in correctly identifying and matching records, searching for accurate information, and ensuring proper attribution of authors and institutions. 

These typos contribute to the challenge of metadata accuracy and completeness, which is a significant limitation of OpenAlex and other bibliographic databases. This is especially the case when considering author names. 

Critical evaluation and addressing these typos in metadata are essential steps to enhance the reliability and usability of bibliographic databases in research and analysis, ensuring accurate representation and recognition of scholarly contrib

# Analysis
