# mediRAG Pipeline

In [1]:
import os
import torch
from prompts import *
import evaluate
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
from langchain.docstore.document import Document
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.chains import LLMChain
from langchain.schema.runnable import RunnablePassthrough

In [2]:
torch.cuda.set_device(3)  # have to change depending on which device u wana use
torch.cuda.current_device()

3

## Load Model and Tokenizer

In [3]:
model_name='mistralai/Mistral-7B-Instruct-v0.1'

tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="models")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    cache_dir="models"
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

# Load Data

In [4]:
dataset = load_dataset("bigbio/pubmed_qa", cache_dir="data")

dataset

DatasetDict({
    train: Dataset({
        features: ['QUESTION', 'CONTEXTS', 'LABELS', 'MESHES', 'YEAR', 'reasoning_required_pred', 'reasoning_free_pred', 'final_decision', 'LONG_ANSWER'],
        num_rows: 200000
    })
    validation: Dataset({
        features: ['QUESTION', 'CONTEXTS', 'LABELS', 'MESHES', 'YEAR', 'reasoning_required_pred', 'reasoning_free_pred', 'final_decision', 'LONG_ANSWER'],
        num_rows: 11269
    })
})

In [5]:
page_content_column = "CONTEXTS"

def preprocess(dataset):
    for split in dataset.keys():
        for contexts in dataset[split][page_content_column]:
            for sentence in contexts:
                yield Document(page_content=sentence)

data = list(preprocess(dataset))  # 655055

data[0]

Document(page_content='In previous work we (Fisher et al., 2011) examined the emergence of neurobehavioral disinhibition (ND) in adolescents with prenatal substance exposure. We computed ND factor scores at three age points (8/9, 11 and 13/14 years) and found that both prenatal substance exposure and early adversity predicted ND. The purpose of the current study was to determine the association between these ND scores and initiation of substance use between ages 8 and 16 in this cohort as early initiation of substance use has been related to later substance use disorders. Our hypothesis was that prenatal cocaine exposure predisposes the child to ND, which, in turn, is associated with initiation of substance use by age 16.')

## Setting up FAISS

In [6]:
embedding_model = "BAAI/bge-large-en-v1.5"
model_kwargs = {'device':'cuda'}
encode_kwargs = {'normalize_embeddings': False}

# Initialize an instance of HuggingFaceEmbeddings with the specified parameters
embeddings = HuggingFaceEmbeddings(
    model_name=embedding_model,   
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs, 
    cache_folder="models"
)

if os.path.exists("faiss_index_pubmed"):
    db = FAISS.load_local("faiss_index_pubmed", embeddings)
else:
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=128)
    docs = text_splitter.split_documents(data)  # 676307

    db = FAISS.from_documents(docs, embeddings)
    db.save_local("faiss_index_pubmed")

In [7]:
question = dataset['train']["QUESTION"][0]
context = dataset['train']["CONTEXTS"][0]

retrieved_docs = db.similarity_search(question)  # db.similarity_search_with_score(question)

print(f"Question:\n{question}")
print(f"\nContext:\n{context}")
print(f"\nRetrieved document:\n{retrieved_docs[0].page_content}\n{retrieved_docs[1].page_content}\n{retrieved_docs[2].page_content}")

Question:
Does neurobehavioral disinhibition predict initiation of substance use in children with prenatal cocaine exposure?

Context:
['In previous work we (Fisher et al., 2011) examined the emergence of neurobehavioral disinhibition (ND) in adolescents with prenatal substance exposure. We computed ND factor scores at three age points (8/9, 11 and 13/14 years) and found that both prenatal substance exposure and early adversity predicted ND. The purpose of the current study was to determine the association between these ND scores and initiation of substance use between ages 8 and 16 in this cohort as early initiation of substance use has been related to later substance use disorders. Our hypothesis was that prenatal cocaine exposure predisposes the child to ND, which, in turn, is associated with initiation of substance use by age 16.', "We studied 386 cocaine exposed and 517 unexposed children followed since birth in a longitudinal study. Five dichotomous variables were computed based 

## Initializing Pipeline


In [8]:
text_generation_pipeline = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    max_new_tokens=300,
    do_sample=False,
)

mistral_llm = HuggingFacePipeline(pipeline=text_generation_pipeline)

# Create prompt from prompt template 
prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=PROMPT_TEMPLATE_QA_EXPLAINER,
)

# Create llm chain 
llm_chain = LLMChain(llm=mistral_llm, prompt=prompt)

## Running Queries

In [9]:
idx = 0

question = dataset['train'][2]["QUESTION"]
context = dataset['train'][2]["CONTEXTS"]
long_answer = dataset['train'][2]["LONG_ANSWER"]
final_decision = dataset['train'][2]["final_decision"]

## QA without Retrieval

In [10]:
input_ids = tokenizer.encode(question, return_tensors="pt").to("cuda")

with torch.no_grad():
    output = model.generate(
        input_ids=input_ids,
        do_sample=False,
        return_dict_in_generate=True,
        max_new_tokens=300,
    )

output = tokenizer.decode(output.sequences[0][len(input_ids[0]):])

output

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


'\n\n## Abstract\n\nHeart failure with preserved ejection fraction (HFpEF) is a common form of heart failure characterized by impaired left ventricular (LV) diastolic function and normal or near-normal LV ejection fraction. The pathophysiology of HFpEF is complex and involves both structural and functional changes in the heart. Recent studies have shown that HFpEF is associated with dynamic impairment of active relaxation and contraction of the LV on exercise, which may be related to myocardial energy deficiency. This review summarizes the current evidence on the dynamic impairment of LV function in HFpEF and its association with myocardial energy deficiency.\n\n## Introduction\n\nHeart failure with preserved ejection fraction (HFpEF) is a common form of heart failure characterized by impaired left ventricular (LV) diastolic function and normal or near-normal LV ejection fraction. HFpEF is estimated to affect 50% of all patients with heart failure, and its prevalence is increasing with

In [11]:
output_with_citations = ""
citations = ""
citation_list = []

for lines in output.split("\n"):
    lines = lines.strip()
    if len(lines.split(" ")) > 10:
        for line in lines.split("."):
            line = line.strip()
            docs_and_scores = db.similarity_search_with_score(line)[0]  # choosing top 1 relevant document
            if docs_and_scores[1] < 0.5:  # returned distance score is L2 distance, a lower score is better
                doc_content = docs_and_scores[0].page_content
                if doc_content in citation_list:
                    idx = citation_list.index(doc_content)

                else:
                    citation_list.append(doc_content)
                    idx = len(citation_list)
                    citations += f"[{idx}] {doc_content}\n"

                output_with_citations += line + f" [{idx}]. "

final_output_with_citations = output_with_citations + "\n\nCitations:\n" + citations

In [12]:
print(final_output_with_citations)

Heart failure with preserved ejection fraction (HFpEF) is a common form of heart failure characterized by impaired left ventricular (LV) diastolic function and normal or near-normal LV ejection fraction [1]. The pathophysiology of HFpEF is complex and involves both structural and functional changes in the heart [2]. Recent studies have shown that HFpEF is associated with dynamic impairment of active relaxation and contraction of the LV on exercise, which may be related to myocardial energy deficiency [3]. This review summarizes the current evidence on the dynamic impairment of LV function in HFpEF and its association with myocardial energy deficiency [1]. Heart failure with preserved ejection fraction (HFpEF) is a common form of heart failure characterized by impaired left ventricular (LV) diastolic function and normal or near-normal LV ejection fraction [0]. HFpEF is estimated to affect 50% of all patients with heart failure, and its prevalence is increasing with age [4]. The pathophy

## QA with Retrieval

In [13]:
retriever = db.as_retriever(
    search_type="similarity",
    search_kwargs={'k': 3}
)

# retriever = db.as_retriever(search_type="similarity_score_threshold", 
#                                  search_kwargs={"score_threshold": .5, 
#                                                 "k": top_k})

rag_chain = ({"context": retriever, "question": RunnablePassthrough()} | llm_chain)

# QA with retrieval
qa_retrieval_result = rag_chain.invoke(question)

qa_retrieval_result

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


{'context': [Document(page_content='We sought to evaluate the role of exercise-related changes in left ventricular (LV) relaxation and of LV contractile function and vasculoventricular coupling (VVC) in the pathophysiology of heart failure with preserved ejection fraction (HFpEF) and to assess myocardial energetic status in these patients.'),
  Document(page_content='Nearly half of patients with heart failure have a preserved ejection fraction (HFpEF). Symptoms of exercise intolerance and dyspnea are most often attributed to diastolic dysfunction; however, impaired systolic and/or arterial vasodilator reserve under stress could also play an important role.'),
  Document(page_content='To investigate the associations between glucose metabolism, left ventricular (LV) contractile reserve, and exercise capacity in patients with chronic systolic heart failure (HF).')],
 'question': 'Is heart failure with preserved ejection fraction characterized by dynamic impairment of active relaxation and

# Evaluation

In [14]:
bleu = evaluate.load("bleu", cache_dir="evaluation_metrics")  # value ranges from 0 to 1. score of 1 is better

bleu_score = bleu.compute(predictions=[output_with_citations], references=[long_answer])
print(f"BLEU Score: {bleu_score}")

bleu_score = bleu.compute(predictions=[qa_retrieval_result["text"]], references=[long_answer])
print(f"BLEU Score: {bleu_score}")

BLEU Score: {'bleu': 0.0, 'precisions': [0.05627705627705628, 0.013043478260869565, 0.004366812227074236, 0.0], 'brevity_penalty': 1.0, 'length_ratio': 9.625, 'translation_length': 231, 'reference_length': 24}
BLEU Score: {'bleu': 0.0, 'precisions': [0.07075471698113207, 0.018957345971563982, 0.004761904761904762, 0.0], 'brevity_penalty': 1.0, 'length_ratio': 8.833333333333334, 'translation_length': 212, 'reference_length': 24}


In [15]:
bertscore = evaluate.load("bertscore", cache_dir="evaluation_metrics")

bert_score = bertscore.compute(predictions=[output_with_citations], references=[long_answer], lang="en")
print(f"BERTScore: {bert_score}")

bert_score = bertscore.compute(predictions=[qa_retrieval_result["text"]], references=[long_answer], lang="en")
print(f"BERTScore: {bert_score}")

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BERTScore: {'precision': [0.7966367602348328], 'recall': [0.8586264848709106], 'f1': [0.8264709115028381], 'hashcode': 'roberta-large_L17_no-idf_version=0.3.12(hug_trans=4.35.1)'}
BERTScore: {'precision': [0.8198319673538208], 'recall': [0.8714783191680908], 'f1': [0.8448666334152222], 'hashcode': 'roberta-large_L17_no-idf_version=0.3.12(hug_trans=4.35.1)'}
