# mediRAG Pipeline

In [None]:
import os
import torch
from prompts import *
import evaluate
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
from langchain.docstore.document import Document
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.chains import LLMChain
from langchain.schema.runnable import RunnablePassthrough

In [None]:
# Select the GPU used by this notebook; change the index to run on another device.
gpu_index = 0
torch.cuda.set_device(gpu_index)
# Last expression of the cell: echoes the index of the currently selected device.
torch.cuda.current_device()

## Load Model and Tokenizer

In [None]:
# Checkpoint used for both the tokenizer and the causal LM below.
model_name = 'mistralai/Mistral-7B-Instruct-v0.1'

# Quantization settings: 4-bit NF4 weights, bfloat16 compute,
# double quantization switched off.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

# Both artifacts are cached under the local "models" directory.
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="models")

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    cache_dir="models",
    quantization_config=bnb_config,
)

# Load Data

In [None]:
# Fetch the PubMedQA dataset, caching it under the local "data" directory.
dataset = load_dataset(path="bigbio/pubmed_qa", cache_dir="data")

# Last expression of the cell: shows the loaded splits and their columns.
dataset

In [None]:
# Dataset column whose entries are turned into retrievable documents.
page_content_column = "CONTEXTS"


def preprocess(dataset):
    """Lazily yield one ``Document`` per element of every context entry.

    Walks every split of ``dataset``, reads the ``page_content_column``
    column, and wraps each element of each entry in a ``Document``.
    """
    for split_name in dataset.keys():
        for passages in dataset[split_name][page_content_column]:
            yield from (Document(page_content=passage) for passage in passages)


data = list(preprocess(dataset))  # 655055 (count noted by the original author)

# Last expression of the cell: preview the first document.
data[0]

## Setting up FAISS

In [None]:
embedding_model = "BAAI/bge-large-en-v1.5"
model_kwargs = {'device': 'cuda'}
encode_kwargs = {'normalize_embeddings': False}

# Embedding model shared by index construction and query-time search.
embeddings = HuggingFaceEmbeddings(
    cache_folder="models",
    encode_kwargs=encode_kwargs,
    model_kwargs=model_kwargs,
    model_name=embedding_model,
)

if not os.path.exists("faiss_index_pubmed"):
    # No saved index yet: chunk the documents, embed them, and persist the result.
    text_splitter = RecursiveCharacterTextSplitter(chunk_overlap=128, chunk_size=1024)
    docs = text_splitter.split_documents(data)  # 676307 (count noted by the original author)

    db = FAISS.from_documents(docs, embeddings)
    db.save_local("faiss_index_pubmed")
else:
    # Reuse the index saved by an earlier run.
    # NOTE(review): newer LangChain releases may require
    # allow_dangerous_deserialization=True here - confirm against the installed version.
    db = FAISS.load_local("faiss_index_pubmed", embeddings)

In [None]:
# Sanity check: compare the gold context of the first training question
# with what the index returns for that question.
train_split = dataset['train']
question = train_split["QUESTION"][0]
context = train_split["CONTEXTS"][0]

# Alternative that also returns distances: db.similarity_search_with_score(question)
retrieved_docs = db.similarity_search(question)

print(f"Question:\n{question}")
print(f"\nContext:\n{context}")
# Show the three best-ranked passages.
print(f"\nRetrieved document:\n{retrieved_docs[0].page_content}\n{retrieved_docs[1].page_content}\n{retrieved_docs[2].page_content}")

## Initializing Pipeline


In [None]:
# Prompt with two slots: the retrieved context and the user question.
prompt = PromptTemplate(
    template=PROMPT_TEMPLATE_QA_ANSWER_START,
    input_variables=["context", "question"],
)

# Greedy decoding (sampling disabled), capped at 300 newly generated tokens.
text_generation_pipeline = pipeline(
    task="text-generation",
    tokenizer=tokenizer,
    model=model,
    do_sample=False,
    max_new_tokens=300,
)

# Expose the Hugging Face pipeline through LangChain's LLM interface.
mistral_llm = HuggingFacePipeline(pipeline=text_generation_pipeline)

# Chain that fills the prompt and sends it to the model.
llm_chain = LLMChain(prompt=prompt, llm=mistral_llm)

## Running Queries

In [None]:
idx = 0

# Rows 2..11 of the training split form the evaluation batch; slice once
# and pull out the columns the following cells need.
eval_rows = dataset['train'][2:12]

questions = eval_rows["QUESTION"]
contexts = eval_rows["CONTEXTS"]
long_answers = eval_rows["LONG_ANSWER"]
final_decisions = eval_rows["final_decision"]

## QA without Retrieval

In [None]:
# Baseline: answer every question with the bare model, no retrieved context.
# A prompt variant tried by the original author appended a request to
# explicitly state whether the answer is yes, no or maybe.
pred_no_ret = []
for question in questions:
    print("Question: ", question)
    # Place the prompt on the same device as the model instead of assuming "cuda".
    input_ids = tokenizer.encode(question, return_tensors="pt").to(model.device)

    with torch.no_grad():
        generation = model.generate(
            input_ids=input_ids,
            do_sample=False,
            return_dict_in_generate=True,
            max_new_tokens=300,
        )

    # Keep only the newly generated tokens (drop the echoed prompt) and strip
    # special tokens such as the end-of-sequence marker so they do not leak
    # into the stored answers and skew the text-overlap metrics.
    prompt_length = input_ids.shape[-1]
    output = tokenizer.decode(
        generation.sequences[0][prompt_length:],
        skip_special_tokens=True,
    )
    print("Generated Answer: ", output)

    pred_no_ret.append(output)

In [None]:
def find_citations(predictions, vector_store=None, max_distance=0.5):
    """Rewrite each prediction so that its sentences carry numbered citations.

    Every line of a prediction with more than 10 space-separated tokens is
    split on "."; each non-empty fragment is looked up in the vector store and,
    when the best match is closer than ``max_distance``, the fragment is kept
    and tagged with the 1-based number of the matched passage. Fragments
    without a close enough match, and shorter lines, are left out of the
    rewritten answer.

    Args:
        predictions: Iterable of generated answer strings.
        vector_store: Object exposing ``similarity_search_with_score(query)``
            that returns ``(document, distance)`` pairs, best match first.
            Defaults to the notebook-level ``db`` index.
        max_distance: Exclusive upper bound on the returned distance for a
            match to count as a citation (the score is an L2 distance, so
            lower is better).

    Returns:
        A list with one string per prediction: the cited sentences followed by
        a "Citations:" section listing each cited passage once.
    """
    if vector_store is None:
        vector_store = db

    finalpred = []
    for output in predictions:
        cited_sentences = []
        # Position i holds the passage printed as citation number i + 1.
        citation_list = []

        for raw_line in output.split("\n"):
            stripped_line = raw_line.strip()
            if len(stripped_line.split(" ")) <= 10:
                continue

            for sentence in stripped_line.split("."):
                sentence = sentence.strip()
                if not sentence:
                    # Splitting on "." leaves empty fragments (e.g. after the
                    # final period); there is nothing to cite for those.
                    continue

                # Only the single most relevant document is considered.
                doc, distance = vector_store.similarity_search_with_score(sentence)[0]
                if distance >= max_distance:
                    continue

                doc_content = doc.page_content
                if doc_content in citation_list:
                    # list.index is 0-based while citation numbers are 1-based.
                    number = citation_list.index(doc_content) + 1
                else:
                    citation_list.append(doc_content)
                    number = len(citation_list)

                cited_sentences.append(f"{sentence} [{number}]. ")

        citations = "".join(
            f"[{number}] {content}\n"
            for number, content in enumerate(citation_list, start=1)
        )
        finalpred.append("".join(cited_sentences) + "\n\nCitations:\n" + citations)
    return finalpred

# Add citations to the answers produced without retrieval.
final_pred_citations = find_citations(predictions=pred_no_ret)

In [None]:
# Inspect the first answer after citations were attached.
print(final_pred_citations[0])

## QA with Retrieval

In [None]:
def _format_docs(docs):
    """Join the text of the retrieved documents into a single context string."""
    return "\n\n".join(doc.page_content for doc in docs)


# The retriever and the chain do not depend on the question, so build them once
# instead of on every loop iteration.
# A "similarity_score_threshold" search type (score_threshold=.5) was also
# considered by the original author.
retriever = db.as_retriever(
    search_type="similarity",
    search_kwargs={'k': 3}
)

# Pipe the retrieved documents through _format_docs so the prompt receives the
# passage text rather than the Python repr of a list of Document objects.
rag_chain = (
    {"context": retriever | _format_docs, "question": RunnablePassthrough()}
    | llm_chain
)

pred_ret = []
for question in questions:
    print("Question: ", question)

    # QA with retrieval
    qa_retrieval_result = rag_chain.invoke(question)
    print("Generated Answer: ", qa_retrieval_result['text'])
    pred_ret.append(qa_retrieval_result["text"])

# Evaluation

In [None]:
# BLEU ranges from 0 to 1; higher is better.
bleu = evaluate.load("bleu", cache_dir="evaluation_metrics")


def _bleu_against_long_answers(predictions):
    """Score ``predictions`` against the reference long answers with BLEU."""
    return bleu.compute(predictions=predictions, references=long_answers)


bleu_score = _bleu_against_long_answers(pred_no_ret)
print(f"Vanilla QA: BLEU Score: {bleu_score}")

bleu_score = _bleu_against_long_answers(final_pred_citations)
print(f"QA with Citations: BLEU Score: {bleu_score}")

bleu_score = _bleu_against_long_answers(pred_ret)
print(f"QA with Retrieval: BLEU Score: {bleu_score}")

In [None]:
bertscore = evaluate.load("bertscore", cache_dir="evaluation_metrics")


def _bertscore_against_long_answers(predictions):
    """Score ``predictions`` against the reference long answers with BERTScore."""
    return bertscore.compute(
        predictions=predictions,
        references=long_answers,
        lang="en",
        batch_size=1,
    )


bert_score = _bertscore_against_long_answers(pred_no_ret)
print(f"Vanilla QA: BERTScore: {bert_score}")

bert_score = _bertscore_against_long_answers(final_pred_citations)
print(f"QA with Citations: BERTScore: {bert_score}")

bert_score = _bertscore_against_long_answers(pred_ret)
print(f"QA with Retrieval: BERTScore: {bert_score}")

In [None]:
def acc_calc_final(predictions, references):
    """Return the fraction of predictions that mention their reference label.

    A prediction counts as correct when the reference label (e.g. "yes", "no",
    "maybe") occurs in it as a whole word, case-insensitively. Whole-word
    matching keeps "no" from being credited for answers that merely contain
    "not", "know" or "diagnosis". This remains a heuristic: an answer that
    mentions several labels is credited if any of them is the reference.

    Args:
        predictions: Sequence of generated answer strings.
        references: Sequence of gold label strings, aligned with ``predictions``.

    Returns:
        Accuracy in [0, 1]; 0.0 when ``predictions`` is empty.

    Raises:
        IndexError: If ``references`` is shorter than ``predictions``.
    """
    import re  # local import keeps this cell self-contained

    if not predictions:
        # Avoid dividing by zero when there is nothing to score.
        return 0.0

    hits = 0
    for position, prediction in enumerate(predictions):
        label = re.escape(references[position].lower())
        # Lookarounds enforce word boundaries on both sides of the label.
        if re.search(rf"(?<!\w){label}(?!\w)", prediction.lower()):
            hits += 1
    return hits / len(predictions)

# Label accuracy of each answer set against the gold final decisions.
acc = acc_calc_final(pred_no_ret, final_decisions)
print(f"Vanilla QA: acc: {acc}")

acc = acc_calc_final(final_pred_citations, final_decisions)
print(f"QA with Citations: acc: {acc}")

acc = acc_calc_final(pred_ret, final_decisions)
print(f"QA with Retrieval: acc: {acc}")
