# mediRAG Pipeline

In [None]:
import os
import torch
import numpy as np
from prompts import *
import evaluate
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
from langchain.docstore.document import Document
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.chains import LLMChain
from langchain.schema.runnable import RunnablePassthrough

In [None]:
torch.cuda.set_device(0)  # have to change depending on which device u wana use
torch.cuda.current_device()

## Load Model and Tokenizer

In [None]:
# # Use a pipeline as a high-level helper
# from transformers import pipeline

# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
#     bnb_4bit_use_double_quant=False,
# )

# pipe = pipeline("text-generation", model="epfl-llm/meditron-7b", model_kwargs={'cache_dir':"models", 'quantization_config':bnb_config}, device_map = "auto")

In [None]:
# def get_prompt(question):
#     prompt = f"""<|im_start|>system
#     Answer the users questions:<|im_end|>
#     <|im_start|>question
#     {question}<|im_end|>
#     <|im_start|>answer  
#     """
#     return prompt

# Question=  'Is heart failure with preserved ejection fraction characterized by dynamic impairment of active relaxation and contraction of the left ventricle on exercise and associated with myocardial energy deficiency?'
# pipe(get_prompt(Question))

In [None]:
from config import HF_TOKEN
model_name= 'mistralai/Mistral-7B-Instruct-v0.1'
# model_name= 'epfl-llm/meditron-7b'

tokenizer = AutoTokenizer.from_pretrained(model_name,   cache_dir="models") #, token = HF_TOKEN)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    cache_dir="models",
    device_map = "auto",# token = HF_TOKEN
    # load_in_8bit = True,
    # load_in_8bit_fp32_cpu_offload=True
)

# Load Data

In [None]:
dataset = load_dataset("bigbio/pubmed_qa", cache_dir="data")

dataset

In [None]:
page_content_column = "CONTEXTS"

def preprocess(dataset):
    for split in dataset.keys():
        for contexts in dataset[split][page_content_column]:
            for sentence in contexts:
                yield Document(page_content=sentence)

data = list(preprocess(dataset))  # 655055

data[0]

## Setting up FAISS

In [None]:
embedding_model = "BAAI/bge-large-en-v1.5"
model_kwargs = {'device':'cuda'}
encode_kwargs = {'normalize_embeddings': False}

# Initialize an instance of HuggingFaceEmbeddings with the specified parameters
embeddings = HuggingFaceEmbeddings(
    model_name=embedding_model,   
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs, 
    cache_folder="models"
)

if os.path.exists("faiss_index_pubmed"):
    db = FAISS.load_local("faiss_index_pubmed", embeddings)
else:
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=128)
    docs = text_splitter.split_documents(data)  # 676307

    db = FAISS.from_documents(docs, embeddings)
    db.save_local("faiss_index_pubmed")

In [None]:
question = dataset['train']["QUESTION"][0]
context = dataset['train']["CONTEXTS"][0]
long_answers = dataset['train']["LONG_ANSWER"][0]
final_decisions = dataset['train']["final_decision"][0]

retrieved_docs = db.similarity_search(question)  # db.similarity_search_with_score(question)

print(f"Question:\n{question}")
print(f"\nContext:\n{context}")
print(f"Answer:\n{long_answers}")
print(f"\nFinal_Decisions:\n{final_decisions}")
print(f"\nRetrieved document:\n{retrieved_docs[0].page_content}\n{retrieved_docs[1].page_content}\n{retrieved_docs[2].page_content}")

## Initializing Pipeline


In [None]:
text_generation_pipeline = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    max_new_tokens=300,
    do_sample=False,
)

mistral_llm = HuggingFacePipeline(pipeline=text_generation_pipeline)

# Create prompt from prompt template 
prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=FEW_SHOT_PROMPT,
)

# Create llm chain 
llm_chain = LLMChain(llm=mistral_llm, prompt=prompt)

## Running Queries

In [None]:
idx = 0

questions = dataset['train'][2:102]["QUESTION"]
contexts = dataset['train'][2:102]["CONTEXTS"]
long_answers = dataset['train'][2:102]["LONG_ANSWER"]
final_decisions = dataset['train'][2:102]["final_decision"]

## QA without Retrieval

In [None]:
from tqdm.notebook import tqdm
def get_prompt(question):
    # prompt = """Think Step by Step and Answer the users question:

    # Question: {}  
    # """.format(question)
    prompt = f"""Answer the users question. An example is given below:

    <example>
    Question:
    Does neurobehavioral disinhibition predict initiation of substance use in children with prenatal cocaine exposure?
    Answer:
    Yes, Prenatal drug exposure appears to be a risk pathway to ND, which by 8/9 years portends substance use initiation.
    </example>
    Question: {question}  
    """
    return prompt    
    # return f"""Answer the Users Questions with explainations
    # Question: {question}
    # Answer:
    # """

# def get_prompt(question):
#     prompt = f"""<|im_start|>system
#     Answer the users questions:<|im_end|>
#     <|im_start|>question
#     {question}<|im_end|>
#     <|im_start|>answer  
#     """
#     return prompt


pred_no_ret = []
for question in tqdm(questions):
    # question = f'{question}. ALong with the answer, Explicitly state where the answer to the question is yes, no or maybe'
    print("Question: ", question)
    prompt = get_prompt(question)
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cpu")
    # print(input_ids)
    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            do_sample=False,
            return_dict_in_generate=True,
            max_new_tokens=300,
        )

    output = tokenizer.decode(output.sequences[0][len(input_ids[0]):])
    # output = text_generation_pipeline(prompt, CUDA_LAUNCH_BLOCKING=1)
    # output = pipe(prompt)
    print("Generated Answer: ", output)

    pred_no_ret.append(output)
# output

In [None]:
def find_citations(predictions):
    finalpred = []
    for output in predictions:
        output_with_citations = ""
        citations = ""
        citation_list = []

        for lines in output.split("\n"):
            lines = lines.strip()
            if len(lines.split(" ")) > 10:
                for line in lines.split("."):
                    line = line.strip()
                    docs_and_scores = db.similarity_search_with_score(line)[0]  # choosing top 1 relevant document
                    if docs_and_scores[1] < 0.5:  # returned distance score is L2 distance, a lower score is better
                        doc_content = docs_and_scores[0].page_content
                        if doc_content in citation_list:
                            idx = citation_list.index(doc_content)

                        else:
                            citation_list.append(doc_content)
                            idx = len(citation_list)
                            citations += f"[{idx}] {doc_content}\n"

                        output_with_citations += line + f" [{idx}]. "

        final_output_with_citations = output_with_citations + "\n\nCitations:\n" + citations
        finalpred.append(final_output_with_citations)
    return finalpred

final_pred_citations = find_citations(pred_no_ret)

In [None]:
print(final_pred_citations[0])

## QA with Retrieval

In [None]:
from tqdm.notebook import tqdm
pred_ret = []
for question in tqdm(questions):
    print("Question: ", question)
    retriever = db.as_retriever(
        search_type="similarity",
        search_kwargs={'k': 3}
    )

    # retriever = db.as_retriever(search_type="similarity_score_threshold", 
    #                                  search_kwargs={"score_threshold": .5, 
    #                                                 "k": top_k})

    rag_chain = ({"context": retriever, "question": RunnablePassthrough()} | llm_chain)

    # QA with retrieval
    qa_retrieval_result = rag_chain.invoke(question)
    print("Generated Answer: ", qa_retrieval_result['text'])
    pred_ret.append(qa_retrieval_result["text"])

In [None]:
import json
data = {"model": model_name,
        "prompt": FEW_SHOT_PROMPT,
        # "TOT": {'results':pred_ret},
        "Vanilla": {'results': pred_no_ret}, #, 'metrics': {'BertScore': bert_score_no_ret, 'BLEU': bleu_score_no_ret, "Accuracy": acc_no_ret, "Perplexity":perp_no_ret}},
         "With Citations": { 'results':final_pred_citations}, #, 'metrics': {'BertScore': bert_score_citations, 'BLEU': bleu_score_citations, "Accuracy": acc_citations, "Perplexity":perp_citations}},
           "With Retrieval": {'results':pred_ret}, #, 'metrics': {'BertScore': bert_score_ret, 'BLEU': bleu_score_ret, "Accuracy": acc_ret, "Perplexity":perp_ret}}, 
           "ground_truth": {"long_answers": long_answers, "final_decisions": final_decisions}}

json_data = json.dumps(data, indent=2)

with open('oneshot_Full_results.json', 'w') as json_file:
    json_file.write(json_data)

# Evaluation

In [1]:
import json
import evaluate

# Read the JSON data from the file
with open('oneshot_Full_results.json', 'r') as json_file:
    loaded_data = json.load(json_file)

# Access the lists from the loaded data
pred_no_ret = loaded_data['Vanilla']['results']
final_pred_citations = loaded_data['With Citations']['results']
pred_ret = loaded_data['With Retrieval']['results']
long_answers = loaded_data['ground_truth']['long_answers']
final_decisions = loaded_data['ground_truth']['final_decisions']

# def preprocess(preds):
#     new_pred = []
#     for pred in preds:
#         new_pred.append(pred[5:])
#     return new_pred

def clean_cit(preds):
    new_pred = []
    for pred in preds:
        new_pred.append(pred.split("\n\nCitations:\n")[0])
    return new_pred

def rem_comment(preds):
    new_pred = []
    for pred in preds:
        new_pred.append(pred.split("\n\nComment")[0])
    return new_pred

# pred_no_ret = preprocess(pred_no_ret)
final_pred_citations = clean_cit(final_pred_citations)
# pred_ret = rem_comment(pred_ret)


In [2]:
bleu = evaluate.load("bleu", cache_dir="evaluation_metrics")  # value ranges from 0 to 1. score of 1 is better

bleu_score_no_ret = bleu.compute(predictions=pred_no_ret, references=long_answers)
print(f"Vanilla QA: BLEU Score: {bleu_score_no_ret}")

bleu_score_citations = bleu.compute(predictions=final_pred_citations, references=long_answers)
print(f"QA with Citations: BLEU Score: {bleu_score_citations}")

bleu_score_ret = bleu.compute(predictions=pred_ret, references=long_answers)
print(f"QA with Retrieval: BLEU Score: {bleu_score_ret}")

Vanilla QA: BLEU Score: {'bleu': 0.053022689170644156, 'precisions': [0.344866491162091, 0.12270418132082844, 0.06588043920292802, 0.03603221704111912], 'brevity_penalty': 0.5296296529571533, 'length_ratio': 0.6114049206714187, 'translation_length': 2659, 'reference_length': 4349}
QA with Citations: BLEU Score: {'bleu': 0.04331043272050909, 'precisions': [0.364613476643241, 0.11929371231696813, 0.06157303370786517, 0.03195488721804511], 'brevity_penalty': 0.45029590241911394, 'length_ratio': 0.5562198206484249, 'translation_length': 2419, 'reference_length': 4349}
QA with Retrieval: BLEU Score: {'bleu': 0.02859670911874478, 'precisions': [0.12182410423452769, 0.03769595333576376, 0.01696658097686375, 0.008583055863854976], 'brevity_penalty': 1.0, 'length_ratio': 3.17659232007358, 'translation_length': 13815, 'reference_length': 4349}


In [3]:
import numpy as np

bertscore = evaluate.load("bertscore", cache_dir="evaluation_metrics")

bert_score_no_ret = bertscore.compute(predictions=pred_no_ret, references=long_answers , lang="en", batch_size =1)
bert_score_no_ret = {key: np.mean(value) if key!= "hashcode" else value for key, value in bert_score_no_ret.items()}
print(f"Vanilla QA: BERTScore: {bert_score_no_ret}")

bert_score_citations = bertscore.compute(predictions=final_pred_citations, references=long_answers , lang="en" , batch_size =1)
bert_score_citations = {key: np.mean(value) if key!= "hashcode" else value for key, value in bert_score_citations.items()}
print(f"QA with Citations: BERTScore: {bert_score_citations}")

bert_score_ret = bertscore.compute(predictions=pred_ret, references=long_answers , lang="en", batch_size =1 )
bert_score_ret = {key: np.mean(value) if key!= "hashcode" else value for key, value in bert_score_ret.items()}
print(f"QA with Retrieval: BERTScore: {bert_score_ret}")

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Vanilla QA: BERTScore: {'precision': 0.8677240765094757, 'recall': 0.8686194777488708, 'f1': 0.8679354679584503, 'hashcode': 'roberta-large_L17_no-idf_version=0.3.12(hug_trans=4.35.2)'}




QA with Citations: BERTScore: {'precision': 0.8474992269277573, 'recall': 0.8440714913606644, 'f1': 0.845581641793251, 'hashcode': 'roberta-large_L17_no-idf_version=0.3.12(hug_trans=4.35.2)'}
QA with Retrieval: BERTScore: {'precision': 0.8225240308046341, 'recall': 0.8677024567127227, 'f1': 0.8440739339590073, 'hashcode': 'roberta-large_L17_no-idf_version=0.3.12(hug_trans=4.35.2)'}


# Perplexity

In [4]:
perplexity = evaluate.load("perplexity", module_type="metric", cache_dir="evaluation_metrics")

perp_no_ret = perplexity.compute(model_id='gpt2',
                             add_start_token=False,
                             predictions=pred_no_ret, 
                             batch_size =2)
print(perp_no_ret['mean_perplexity'])

# perp_citations = perplexity.compute(model_id='gpt2',
#                              add_start_token=False,
#                              predictions=final_pred_citations, 
#                              batch_size =2)
# print(perp_citations['mean_perplexity'])

perp_ret = perplexity.compute(model_id='gpt2',
                             add_start_token=False,
                             predictions=pred_ret, 
                             batch_size =2)
print(perp_ret['mean_perplexity'])

  0%|          | 0/50 [00:00<?, ?it/s]

113.20494256973267


  0%|          | 0/50 [00:00<?, ?it/s]

22.3957336640358


In [7]:
meteor = evaluate.load('meteor', cache_dir="evaluation_metrics")
met_no_ret = meteor.compute(predictions=pred_no_ret, references=long_answers)
print(met_no_ret)
met_citations = meteor.compute(predictions=final_pred_citations, references=long_answers)
print(met_citations)
met_ret = meteor.compute(predictions=pred_ret, references=long_answers)
print(met_ret)


[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\ptejd\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\ptejd\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     C:\Users\ptejd\AppData\Roaming\nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


{'meteor': 0.1715961297070393}
{'meteor': 0.19560881179840028}
{'meteor': 0.23214835260102568}


In [6]:
def acc_calc_final(predictions, references):
    acc = 0
    for i in range(len(predictions)):
        # print(references[i].lower(), predictions[i].lower())
        if references[i].lower() in predictions[i][:15].lower():
            acc += 1
    return acc / len(predictions)

acc_no_ret = acc_calc_final(predictions=pred_no_ret, references=final_decisions)
print(f"Vanilla QA: acc: {acc_no_ret}")

acc_citations = acc_calc_final(predictions=final_pred_citations, references=final_decisions)
print(f"QA with Citations: acc: {acc_citations}")

acc_ret = acc_calc_final(predictions=pred_ret, references=final_decisions)
print(f"QA with Retrieval: acc: {acc_ret}")


Vanilla QA: acc: 0.69
QA with Citations: acc: 0.66
QA with Retrieval: acc: 0.85


In [8]:
import json
data = {"model": loaded_data['model'],
        "prompt": loaded_data['prompt'],
        # "COT": {'results':pred_ret}, 'metrics': {'BertScore': bert_score_ret, 'BLEU': bleu_score_ret, "Accuracy": acc_ret}, #, "Perplexity":perp_ret}},
        "Vanilla": {'results': pred_no_ret, 'metrics': {'BertScore': bert_score_no_ret, 'BLEU': bleu_score_no_ret, "Accuracy": acc_no_ret,  "Perplexity":perp_no_ret['mean_perplexity'], "Meteor": met_no_ret['meteor']}},  #
         "With Citations": { 'results':final_pred_citations, 'metrics': {'BertScore': bert_score_citations, 'BLEU': bleu_score_citations, "Accuracy": acc_citations, "Meteor": met_citations['meteor']}},  # "Perplexity":perp_citations
           "With Retrieval": {'results':pred_ret, 'metrics': {'BertScore': bert_score_ret, 'BLEU': bleu_score_ret, "Accuracy": acc_ret, "Perplexity":perp_ret['mean_perplexity'], "Meteor": met_ret['meteor']}},  # 
           "ground_truth": {"long_answers": long_answers, "final_decisions": final_decisions}}

json_data = json.dumps(data, indent=2)

with open('oneshot_Full_results_metrics_cleanret.json', 'w') as json_file:
    json_file.write(json_data)