# Rewrite-Retrieve-Read

**Rewrite-Retrieve-Read** is a method proposed in the paper [Query Rewriting for Retrieval-Augmented Large Language Models](https://arxiv.org/pdf/2305.14283.pdf)

> Because the original query can not be always optimal to retrieve for the LLM, especially in the real world... we first prompt an LLM to rewrite the queries, then conduct retrieval-augmented reading

We show how you can easily do that with LangChain Expression Language

## Baseline

Baseline RAG (**Retrieve-and-read**) can be done like the following:

In [7]:
import os
import torch
import numpy as np
from prompts import *
import evaluate
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
from langchain.docstore.document import Document
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.chains import LLMChain

from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough

In [8]:
torch.cuda.set_device(0)  # have to change depending on which device u wana use
torch.cuda.current_device()

0

In [6]:
model_name='mistralai/Mistral-7B-Instruct-v0.1'

tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="models")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    cache_dir="models"
)

bin c:\Users\ptejd\anaconda3\envs\medrag3\lib\site-packages\bitsandbytes\libbitsandbytes_cuda118.dll


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

KeyboardInterrupt: 

# Load Data

In [9]:
dataset = load_dataset("bigbio/pubmed_qa", cache_dir="data")

dataset

DatasetDict({
    train: Dataset({
        features: ['QUESTION', 'CONTEXTS', 'LABELS', 'MESHES', 'YEAR', 'reasoning_required_pred', 'reasoning_free_pred', 'final_decision', 'LONG_ANSWER'],
        num_rows: 200000
    })
    validation: Dataset({
        features: ['QUESTION', 'CONTEXTS', 'LABELS', 'MESHES', 'YEAR', 'reasoning_required_pred', 'reasoning_free_pred', 'final_decision', 'LONG_ANSWER'],
        num_rows: 11269
    })
})

In [10]:
page_content_column = "CONTEXTS"

def preprocess(dataset):
    for split in dataset.keys():
        for contexts in dataset[split][page_content_column]:
            for sentence in contexts:
                yield Document(page_content=sentence)

data = list(preprocess(dataset))  # 655055

data[0]

Document(page_content='In previous work we (Fisher et al., 2011) examined the emergence of neurobehavioral disinhibition (ND) in adolescents with prenatal substance exposure. We computed ND factor scores at three age points (8/9, 11 and 13/14 years) and found that both prenatal substance exposure and early adversity predicted ND. The purpose of the current study was to determine the association between these ND scores and initiation of substance use between ages 8 and 16 in this cohort as early initiation of substance use has been related to later substance use disorders. Our hypothesis was that prenatal cocaine exposure predisposes the child to ND, which, in turn, is associated with initiation of substance use by age 16.')

## Setting up FAISS

In [7]:
embedding_model = "BAAI/bge-large-en-v1.5"
model_kwargs = {'device':'cuda'}
encode_kwargs = {'normalize_embeddings': False}

# Initialize an instance of HuggingFaceEmbeddings with the specified parameters
embeddings = HuggingFaceEmbeddings(
    model_name=embedding_model,   
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs, 
    cache_folder="models"
)

if os.path.exists("faiss_index_pubmed"):
    db = FAISS.load_local("faiss_index_pubmed", embeddings)
else:
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=128)
    docs = text_splitter.split_documents(data)  # 676307

    db = FAISS.from_documents(docs, embeddings)
    db.save_local("faiss_index_pubmed")

In [11]:
template = """Answer the users question based only on the following context:

<context>
{context}
</context>

Question: {question}
"""

# Create prompt from prompt template 
prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=template,
)

# model = ChatOpenAI(temperature=0)

# search = DuckDuckGoSearchAPIWrapper()

text_generation_pipeline = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    max_new_tokens=300,
    do_sample=False,
)

mistral_llm = HuggingFacePipeline(pipeline=text_generation_pipeline)


retriever = db.as_retriever(
        search_type="similarity",
        search_kwargs={'k': 3}
    )

NameError: name 'db' is not defined

## Running Queries

In [19]:
idx = 0

questions = dataset['train'][2:102]["QUESTION"]
contexts = dataset['train'][2:102]["CONTEXTS"]
long_answers = dataset['train'][2:102]["LONG_ANSWER"]
final_decisions = dataset['train'][2:102]["final_decision"]

# Simple Chain

In [10]:
chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | mistral_llm
    | StrOutputParser()
)

In [11]:
# query = questions[0]

In [12]:
pred_simple = []
for query in questions:
    output_simple = chain.invoke(query)
    pred_simple.append(output_simple)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for o

## Rewrite-Retrieve-Read Implementation

The main part is a rewriter to rewrite the search query

In [13]:
template = """Provide a better search query for \
web search engine to answer the given question, end \
the queries with ’**’. Question: \
{x} Answer:"""
rewrite_prompt = ChatPromptTemplate.from_template(template)

In [14]:
from langchain import hub

rewrite_prompt = hub.pull("langchain-ai/rewrite")

In [15]:
print(rewrite_prompt.template)

Provide a better search query for web search engine to answer the given question, end the queries with ’**’.  Question {x} Answer:


In [16]:
# Parser to remove the `**`
def _parse(text):
    return text.strip("**")

In [17]:
rewriter = rewrite_prompt | mistral_llm | StrOutputParser() | _parse

In [18]:
query = questions[0]
rewriter.invoke({"x": query})

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


' **Heart failure with preserved ejection fraction and dynamic impairment of active relaxation and contraction of the left ventricle on exercise associated with myocardial energy deficiency**.'

In [19]:
rewrite_retrieve_read_chain = (
    {
        "context": {"x": RunnablePassthrough()} | rewriter | retriever,
        "question": RunnablePassthrough(),
    }
    | prompt
    | mistral_llm
    | StrOutputParser()
)

In [20]:
pred_rewrite = []
for query in questions:
    output_rewrite = rewrite_retrieve_read_chain.invoke(query)
    pred_rewrite.append(output_rewrite)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for o

# Stepback

In [None]:
from langchain.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate
# Few Shot Examples
examples = [
    {
        "input": "Could the members of The Police perform lawful arrests?",
        "output": "what can the members of The Police do?",
    },
    {
        "input": "Jan Sindel’s was born in what country?",
        "output": "what is Jan Sindel’s personal history?",
    },
]
# We now transform these to example messages
example_prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "{input}"),
        ("ai", "{output}"),
    ]
)
few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    examples=examples,
)

In [22]:
# step_back_prompt = ChatPromptTemplate.from_messages(
#     [
#         (
#             "system",
#             """You are an expert at world knowledge. Your task is to step back and paraphrase a question to a more generic step-back question, which is easier to answer. Only Give one paraphrased question as the output. Here are a few examples:""",
#         ),
#         # Few shot examples
#         few_shot_prompt,
#         # New question
#         ("user", "{question}"),
#     ]
# )
# """You are an expert at world knowledge. Your task is to step back and paraphrase a question to a more generic step-back question, which is easier to answer. Only Give one paraphrased question as the output. Here are a few examples:"""

template = """You are an expert at biomedical knowledge. Your task is to step back and paraphrase a question to a more generic step-back question, which is easier to answer. Only Give one paraphrased question as the output.
Example:
"human": "Could the members of The Police perform lawful arrests?",
"ai": "what can the members of The Police do?",
Question:
"human": {question} 
"ai" :"""
stepback_prompt = ChatPromptTemplate.from_template(template)

In [23]:
question_gen = (
    {
        "question": RunnablePassthrough(),
    }
    | stepback_prompt 
    | mistral_llm 
    | StrOutputParser()
)


In [24]:
question_gen.invoke(query)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


' What is the relationship between optic nerve blood flow and primary open-angle glaucoma suspects?'

In [33]:
from langchain import hub

# response_prompt1 = hub.pull("langchain-ai/stepback-answer")

response_prompt_template = """Answer the users question based only on the following context and the background Context:

<context>
{normal_context}
</context>

<Background Context>
{step_back_context}
</Background Context>

Question: {question}
"""

# Create prompt from prompt template 
response_prompt = PromptTemplate(
    input_variables=["normal_context","step_back_context","question"],
    template=response_prompt_template,
)

In [34]:
stepback_chain = (
    {
        # Retrieve context using the normal question
        "normal_context": RunnablePassthrough() | retriever,
        # Retrieve context using the step-back question
        "step_back_context": question_gen | retriever,
        # Pass on the question
        "question": RunnablePassthrough(),
    }
    | response_prompt
    | mistral_llm
    | StrOutputParser()
)

In [35]:
pred_stepback = []
for query in questions:
    output_stepback = stepback_chain.invoke(query)
    pred_stepback.append(output_stepback)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for o

In [28]:
# print(response_prompt1)

input_variables=['normal_context', 'question', 'step_back_context'] template='You are an expert of world knowledge. I am going to ask you a question. Your response should be comprehensive and not contradicted with the following context if they are relevant. Otherwise, ignore them if they are not relevant.\n\n{normal_context}\n{step_back_context}\n\nOriginal Question: {question}\nAnswer:'


# Evaluation

In [None]:
import json
data = {"simple": {"prompt": str(prompt),'results': pred_simple}, "rewrite_prompt": {'prompt':str(prompt), "rewrite_prompt": str(rewrite_prompt), 'results':pred_rewrite}, "stepback_prompt": {'prompt':str(response_prompt), "stepback_prompt": str(stepback_prompt), 'results':pred_stepback}, "ground_truth": long_answers}
json_data = json.dumps(data, indent=2)

with open('prompt_testing_results.json', 'w') as json_file:
    json_file.write(json_data)

In [12]:
# import json

# # Read the JSON data from the file
# with open('prompt_testing_results.json', 'r') as json_file:
#     loaded_data = json.load(json_file)

# # Access the lists from the loaded data
# pred_simple = loaded_data['simple']['results']
# pred_rewrite = loaded_data['rewrite_prompt']['results']
# pred_stepback = loaded_data['stepback_prompt']['results']
# long_answers = loaded_data['ground_truth']


## BLEU

In [13]:
import evaluate
bleu = evaluate.load("bleu", cache_dir="evaluation_metrics")  # value ranges from 0 to 1. score of 1 is better

bleu_score_simple = bleu.compute(predictions=pred_simple, references=long_answers)
print(f"BLEU Score: {bleu_score_simple}")

bleu_score_rewrite = bleu.compute(predictions=pred_rewrite, references=long_answers)
print(f"BLEU Score: {bleu_score_rewrite}")

bleu_score_stepback = bleu.compute(predictions=pred_stepback, references=long_answers)
print(f"BLEU Score: {bleu_score_stepback}")

BLEU Score: {'bleu': 0.04296830138091172, 'precisions': [0.17676767676767677, 0.055970149253731345, 0.025991917707567966, 0.013255469039673711], 'brevity_penalty': 1.0, 'length_ratio': 2.5495516210623133, 'translation_length': 11088, 'reference_length': 4349}
BLEU Score: {'bleu': 0.042713500077162, 'precisions': [0.17177144904173383, 0.05548527019699982, 0.026171803757067298, 0.013344376955641451], 'brevity_penalty': 1.0, 'length_ratio': 2.567486778569786, 'translation_length': 11166, 'reference_length': 4349}
BLEU Score: {'bleu': 0.035810324471057066, 'precisions': [0.1457910014513788, 0.0472953216374269, 0.02187039764359352, 0.010905044510385757], 'brevity_penalty': 1.0, 'length_ratio': 3.1685444929868933, 'translation_length': 13780, 'reference_length': 4349}


## BertScore

In [15]:
bertscore = evaluate.load("bertscore", cache_dir="evaluation_metrics")

bert_score_simple = bertscore.compute(predictions=pred_simple, references=long_answers, lang="en", batch_size =1 )
bert_score_simple = {key: np.mean(value) if key!= "hashcode" else value for key, value in bert_score_simple.items()}
print(f"BERTScore: {bert_score_simple}")

bert_score_rewrite = bertscore.compute(predictions=pred_rewrite, references=long_answers, lang="en", batch_size =1)
bert_score_rewrite = {key: np.mean(value) if key!= "hashcode" else value for key, value in bert_score_rewrite.items()}
print(f"BERTScore: {bert_score_rewrite}")

bert_score_stepback = bertscore.compute(predictions=pred_stepback, references=long_answers, lang="en", batch_size =1)
bert_score_stepback = {key: np.mean(value) if key!= "hashcode" else value for key, value in bert_score_stepback.items()}
print(f"BERTScore: {bert_score_stepback}")

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BERTScore: {'precision': 0.8410256075859069, 'recall': 0.8782067668437957, 'f1': 0.8589409524202347, 'hashcode': 'roberta-large_L17_no-idf_version=0.3.12(hug_trans=4.35.2)'}
BERTScore: {'precision': 0.8392905700206756, 'recall': 0.8767524176836013, 'f1': 0.8572965705394745, 'hashcode': 'roberta-large_L17_no-idf_version=0.3.12(hug_trans=4.35.2)'}
BERTScore: {'precision': 0.8287627708911895, 'recall': 0.8732502144575119, 'f1': 0.8501276648044587, 'hashcode': 'roberta-large_L17_no-idf_version=0.3.12(hug_trans=4.35.2)'}


# Perplexity

In [23]:
perplexity = evaluate.load("perplexity", module_type="metric")

perp_simple = perplexity.compute(model_id='gpt2',
                             add_start_token=False,
                             predictions=pred_simple)
print(perp_simple['mean_perplexity'])

perp_rewrite = perplexity.compute(model_id='gpt2',
                             add_start_token=False,
                             predictions=pred_rewrite)
print(perp_rewrite['mean_perplexity'])

perp_stepback = perplexity.compute(model_id='gpt2',
                             add_start_token=False,
                             predictions=pred_stepback)
print(perp_stepback['mean_perplexity'])

  0%|          | 0/7 [00:00<?, ?it/s]

23.022315759658813


  0%|          | 0/7 [00:00<?, ?it/s]

24.287154097557067


  0%|          | 0/7 [00:00<?, ?it/s]

15.132294914722443


## Accuracy

In [26]:
def acc_calc_final(predictions, references):
    acc = 0
    for i in range(len(predictions)):
        # print(references[i].lower(), predictions[i].lower())
        if references[i].lower() in predictions[i][:15].lower():
            acc += 1
    return acc / len(predictions)

acc_simple = acc_calc_final(predictions=pred_simple, references=final_decisions)
print(f"Vanilla prompt: acc: {acc_simple}")

acc_rewrite = acc_calc_final(predictions=pred_rewrite, references=final_decisions)
print(f"rewrite-retrieve-read prompting: acc: {acc_rewrite}")

acc_stepback = acc_calc_final(predictions=pred_stepback, references=final_decisions)
print(f"Stepback prompting: acc: {acc_stepback}")

Vanilla prompt: acc: 0.69
rewrite-retrieve-read prompting: acc: 0.68
Stepback prompting: acc: 0.75


In [30]:
# loaded_data['simple']['metrics'] = {'BertScore': bert_score_simple, 'BLEU': bleu_score_simple, "Accuracy": acc_simple, "Perplexity":perp_simple['mean_perplexity']}
# loaded_data['rewrite_prompt']['metrics'] = {'BertScore': bert_score_rewrite, 'BLEU': bleu_score_rewrite, "Accuracy": acc_rewrite, "Perplexity":perp_rewrite['mean_perplexity']}
# loaded_data['stepback_prompt']['metrics'] = {'BertScore': bert_score_stepback, 'BLEU': bleu_score_stepback, "Accuracy": acc_stepback, "Perplexity":perp_stepback['mean_perplexity']}

In [None]:
import json
data = {"simple": {"prompt": str(prompt),'Answers': pred_simple, 'results': {'BertScore': bert_score_simple, 'BLEU': bleu_score_simple, "Accuracy": acc_simple, "Perplexity":perp_simple}}, "rewrite_prompt": {'prompt':str(prompt), "rewrite_prompt": str(rewrite_prompt), 'Answers':pred_rewrite, 'results': {'BertScore': bert_score_rewrite, 'BLEU': bleu_score_rewrite, "Accuracy": acc_rewrite, "Perplexity":perp_rewrite}}, "stepback_prompt": {'prompt':str(response_prompt), "stepback_prompt": str(stepback_prompt), 'Answers':pred_stepback, 'results': {'BertScore': bert_score_stepback, 'BLEU': bleu_score_stepback, "Accuracy": acc_stepback, "Perplexity":perp_stepback}}, "ground_truth": long_answers}

json_data = json.dumps(data, indent=2)

with open('prompt_testing_results_metrics.json', 'w') as json_file:
    json_file.write(json_data)