### Goal:
The idea is to generate multiple step-back questions for retrieval.

In [None]:
from dotenv import load_dotenv
load_dotenv()
import rich

In [None]:
from langchain_core.prompts import FewShotChatMessagePromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate, PromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain_core.runnables import RunnablePassthrough

from pydantic import BaseModel, Field
from typing import List

In [None]:
# Few-shot demonstrations: each pair is (specific question, more generic
# step-back question).
_qa_pairs = [
    (
        'What happens to the pressure, P, of an ideal gas if the temperature is increased by a factor of 2 and the volume is increased by a factor of 8?',
        'What are the physics principles behind this question?',
    ),
    (
        'Estella Leopold went to which school between Aug 1954 and Nov 1954?',
        "What was Estella Leopold's education history?",
    ),
]
examples = [
    {'input': question, 'output': step_back}
    for question, step_back in _qa_pairs
]

# Every example is rendered as a human turn followed by an AI turn.
example_prompt = ChatPromptTemplate.from_messages([
    ('human', '{input}'),
    ('ai', '{output}'),
])

few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    examples=examples,
)

In [None]:
# Inspect the few-shot prompt object itself; the commented-out line would
# print its formatted text instead.
# rich.print(few_shot_prompt.format())
rich.print(few_shot_prompt)

In [None]:
# Schema for the LLM's JSON answer: a single field, `queries`, holding the
# generated step-back questions as a list of strings.
# NOTE(review): be careful adding a class docstring here — pydantic may include
# it in the JSON schema and thereby change the format instructions; confirm
# before doing so.
class Multi_Step_Back(BaseModel):
    queries: List[str] = Field(description="step back and paraphrase of the original query, The number of step back questions is depend on the complexity of the original question, range from 1 to 5.")

# JSON output parser built from the schema above.
multi_step_back_parser = JsonOutputParser(pydantic_object=Multi_Step_Back)
# Format-instruction text produced by the parser, kept for use in a prompt.
multi_step_back_formater = multi_step_back_parser.get_format_instructions()

In [None]:
# System prompt for the step-back generator. It states the task, tells the
# model how many step-back questions to produce (1 to 5, scaled with the
# question's complexity), and embeds the parser's format instructions and the
# rendered few-shot examples as partial variables so they are filled in once,
# here, rather than at invoke time.
#
# The "simple question" rule now reads "it can have just one step back query";
# it previously ended with a copy of the preceding rule and contradicted
# itself.
system_message = SystemMessagePromptTemplate(
    prompt=PromptTemplate(
        template="""You are an expert at world knowledge. Your task is to step back and paraphrase a question to more generic step-back questions, which is easier to answer.

        The number of step back questions is depend on the complexity of the original question, range from 1 to 5.
        If the question need multiple steps of thinking, it should have more step back queries.
        If the question is simple, it can have just one step back query.
         
         {format_instructions}
         
         Here are a few examples:
         {few_shot_examples}
         """,
        partial_variables={
            'format_instructions': multi_step_back_formater,
            'few_shot_examples': few_shot_prompt.format(),
        },
    )
)

# The user's question is forwarded verbatim as the human turn.
human_message = HumanMessagePromptTemplate.from_template('{question}')

# Final chat prompt: system instructions first, then the question.
# `few_shot_prompt` is deliberately not added as separate chat messages here.
final_prompt = ChatPromptTemplate.from_messages([system_message, human_message])

### Multi Query generator

In [None]:
def _extract_queries(parsed):
    # The parser yields a dict; only its 'queries' entry is passed on.
    return parsed['queries']


# question string -> prompt -> LLM -> parsed JSON -> list of step-back queries
multi_step_back_queries_generator = (
    {"question": RunnablePassthrough()}
    | final_prompt
    | ChatOpenAI(temperature=0.9, model="gpt-4o-mini")
    | multi_step_back_parser
    | _extract_queries
)

In [None]:
# Try the generator on a sample question; an alternative sample is kept
# commented out below.
test = multi_step_back_queries_generator.invoke("What need to consider when using LLM to eval LLM generation?")
# test = multi_step_back_queries_generator.invoke("How to pick rock from floor?")

In [None]:
# Show how many items came back, then the items themselves.
print(len(test))
rich.print(test)

### Building Retriever

In [None]:
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter

In [None]:
# Embedding model instance, loaded by name through HuggingFaceEmbeddings.
embedding = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

In [None]:
# Load every PDF found directly under the source directory.
loader = DirectoryLoader(
    '../../pdf_files/',
    glob="*.pdf",
    loader_cls=PyPDFLoader,
)
documents = loader.load()

# Split text into chunks
text_splitter = RecursiveCharacterTextSplitter(chunk_overlap=20, chunk_size=500)
text_chunks = text_splitter.split_documents(documents)

# Embed the chunks and store them in a Chroma collection on disk.
# NOTE(review): re-running this cell presumably adds the same chunks to the
# persisted collection again — verify before re-executing.
vectorstore = Chroma.from_documents(
    persist_directory="data/vectorstore",
    embedding=embedding,
    documents=text_chunks,
)
# NOTE(review): recent Chroma versions are said to persist automatically, so
# this explicit call may be unnecessary — confirm against the installed version.
vectorstore.persist()

retriever = vectorstore.as_retriever()

### Add the retriever to the query generator

In [None]:
# Feed the generated queries into the retriever. `retriever.map()` applies the
# retriever to each query in the list, giving one result list per query.
multi_step_back_queries_chain = (
    multi_step_back_queries_generator
    | retriever.map()
)

In [None]:
# Run the generator + retriever chain on a sample question.
test = multi_step_back_queries_chain.invoke("What need to consider when using LLM to eval LLM generation?")

In [None]:
# Pretty-print the chain output held in `test`.
rich.print(test)

### RRF (Reciprocal Rank Fusion)

In [None]:
from langchain.load import dumps, loads

def rrf(results: list[list], k=60):
    """Fuse several ranked result lists with Reciprocal Rank Fusion.

    Every document contributes ``1 / (rank + k)`` for each list it appears
    in, where ``rank`` is its zero-based position in that list. Documents are
    identified by their serialized form (``dumps``), so a document returned
    for several queries accumulates a higher fused score.

    Args:
        results: One list of documents per query, each assumed to be ordered
            from most to least relevant.
        k: Constant added to the rank in the denominator.

    Returns:
        A list of ``(document, fused_score)`` tuples sorted by descending
        score. Documents with equal scores keep first-seen order. Empty
        input yields an empty list.
    """
    fused_scores = {}
    first_seen = {}
    for docs in results:
        for rank, doc in enumerate(docs):
            key = dumps(doc)
            # Remember the original object so the serialized key never has
            # to be deserialized again just to hand the document back.
            first_seen.setdefault(key, doc)
            fused_scores[key] = fused_scores.get(key, 0) + 1 / (rank + k)

    ranked = sorted(fused_scores.items(), key=lambda item: item[1], reverse=True)
    return [(first_seen[key], score) for key, score in ranked]

In [None]:
def _format_fused_docs(scored_docs):
    # Wrap each non-empty document's text in <doc_i> tags, one per line.
    # `i` is the position in the fused ranking, so skipped entries leave gaps.
    tagged = []
    for i, entry in enumerate(scored_docs):
        if entry and entry[0].page_content:
            tagged.append(f"<doc_{i}>{entry[0].page_content}</doc_{i}>")
    return "\n".join(tagged)


# queries -> per-query retrieval -> rank fusion -> single tagged context string
multi_step_back_queries_chain = (
    multi_step_back_queries_generator
    | retriever.map()
    | rrf
    | _format_fused_docs
)

In [None]:
# Run the full step-back retrieval chain (with rank fusion) on a sample question.
test = multi_step_back_queries_chain.invoke("What need to consider when using LLM to eval LLM generation?")

In [None]:
# Print the chain output held in `test`.
print(test)

In [None]:
# Answer-generation prompt. It has three placeholders: {normal_context},
# {step_back_context} and {question}. The two context sections are wrapped in
# XML-like tags, and the model is told to stay consistent with the context
# when it is relevant and to ignore it otherwise.
response_prompt_template = """You are an expert of world knowledge. 
I am going to ask you a question. Your response should be comprehensive and not contradicted with the following context if they are relevant. 
Otherwise, ignore them if they are not relevant.

<normal_context>
# {normal_context}
</normal_context>

<step_back_context>
# {step_back_context}
</step_back_context>


# Original Question: {question}
# Answer:"""

# Build a chat prompt from the template string above.
response_prompt = ChatPromptTemplate.from_template(response_prompt_template)

# Answering chain. The raw question string is fanned out to three branches:
#   normal_context    - documents retrieved for the question as asked
#   step_back_context - output of the step-back retrieval chain
#   question          - the question itself, unchanged
# The three results fill `response_prompt`, which is answered by the LLM and
# reduced to plain text.
#
# Keep this a single mapping: chaining two plain dicts with `|` is evaluated
# by Python as a dict merge, not as two pipeline steps, so an extra leading
# {"question": ...} mapping would be silently absorbed rather than executed.
step_back_and_response_chain = (
    {
        "normal_context": retriever,
        "step_back_context": multi_step_back_queries_chain,
        "question": RunnablePassthrough(),
    }
    | response_prompt
    | ChatOpenAI(model="gpt-4o-mini", temperature=0.2)
    | StrOutputParser()
)

In [None]:
# Ask the end-to-end chain a sample question.
res = step_back_and_response_chain.invoke("What need to consider when using LLM to eval LLM generation?")

In [None]:
# Pretty-print the final answer text.
rich.print(res)