# Sean's adaptation
#### Goal:
Use `JsonOutputParser` to get a more stable output format when generating multiple queries.

In [None]:
from dotenv import load_dotenv
# Load settings from a .env file; override=True lets values from the file
# replace variables that are already set in the environment.
# NOTE(review): presumably this supplies the API key used by ChatOpenAI below — confirm.
load_dotenv(override=True)

import rich

In [None]:
from langchain_openai import ChatOpenAI
from langchain_huggingface import HuggingFaceEmbeddings

# Embedding model handed to the Chroma vector store when the PDF chunks are indexed.
embedding = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

from langchain_community.vectorstores import Chroma

In [None]:
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Load every file matching *.pdf in the source directory.
loader = DirectoryLoader(
    '../../pdf_files/',
    glob="*.pdf",
    loader_cls=PyPDFLoader,
)
documents = loader.load()

# Split the loaded documents into chunks (size 500, overlap 20).
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=20,
)
text_chunks = text_splitter.split_documents(documents)

# Embed the chunks and store them in a Chroma collection kept on disk.
vectorstore = Chroma.from_documents(
    documents=text_chunks,
    embedding=embedding,
    persist_directory="data/vectorstore",
)
vectorstore.persist()

# Retriever over the vector store, using its default search settings.
retriever = vectorstore.as_retriever()

In [None]:
# from langchain import hub
# prompt = hub.pull("langchain-ai/rag-fusion-query-generation")

In [None]:
from pydantic import BaseModel, Field
from typing import List

# Output schema for the query-generation step: a single list of rephrased queries.
# Documented with comments rather than a class docstring, because a docstring
# would be added to the JSON schema generated from this model.
class Multi_Queries(BaseModel):
    # The Field description is part of the schema that the parser's format
    # instructions show to the LLM.
    multi_queries: List[str]=Field(description="The new queries that rephrase user's query with different perspectives.")

In [None]:
from langchain_core.output_parsers import JsonOutputParser

# Parser for the model's JSON reply; downstream code reads its 'multi_queries' key.
multi_queries_parser = JsonOutputParser(pydantic_object=Multi_Queries)
# Text describing the expected JSON output, injected into the system prompt.
multi_queries_format = multi_queries_parser.get_format_instructions()

In [None]:
from langchain_core.prompts import (
    PromptTemplate,
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate)

# System instructions; {format_instructions} is filled in once with the parser's
# format text, so callers only have to provide the user's query.
system_prompt = (
    "You are a helpful assistant that generates multiple search queries based on a single input query.\n"
    "Generate 4 queries.\n"
    "\n"
    "{format_instructions}\n"
)

system_message = SystemMessagePromptTemplate(
    prompt=PromptTemplate(
        template=system_prompt,
        partial_variables={'format_instructions': multi_queries_format},
    )
)

# The only runtime variable of the chat prompt is `original_query`.
human_message = HumanMessagePromptTemplate(
    prompt=PromptTemplate(
        template="Generate multiple search queries related to: {original_query}",
        input_variables=['original_query'],
    )
)

prompt = ChatPromptTemplate.from_messages([system_message, human_message])

In [None]:
from langchain_core.runnables import RunnablePassthrough

def _as_query_input(user_input):
    """Normalize the chain input to ``{"original_query": <question>}``.

    Accepts either the bare question or a dict that already carries it under
    ``"original_query"`` (possibly nested more than once). Wrapping a dict
    unconditionally would make the prompt render the dict's repr instead of
    the question itself.
    """
    while isinstance(user_input, dict) and "original_query" in user_input:
        user_input = user_input["original_query"]
    return {"original_query": user_input}


# Query-generation chain: normalize input -> prompt -> LLM -> JSON parser ->
# plain list of generated queries.
generate_multi_queries = (
    _as_query_input
    | prompt
    | ChatOpenAI(model="gpt-4o-mini", temperature=0.5)
    | multi_queries_parser
    | (lambda x: x['multi_queries'])
)

In [None]:
# Run the query generator on its own and display the generated queries.
generate_multi_queries.invoke({"original_query": "What need to consider when using LLM to eval LLM generation?"})

In [None]:
from langchain.load import dumps, loads

def rrf(results: list[list], k=60):
    """Fuse several ranked result lists into one ranking.

    Every document adds ``1 / (rank + k)`` to its fused score for each list it
    appears in, where ``rank`` is its zero-based position in that list.
    Documents are identified by their serialized form, so the same document
    returned for several queries accumulates a single score.

    Args:
        results: One ranked list of documents per query, most relevant first.
        k: Constant added to the rank in the denominator.

    Returns:
        A list of ``(document, {"rrf_score": score})`` tuples sorted by
        descending fused score.
    """
    score_by_doc = {}
    for ranked_docs in results:
        # The position within each list is taken as the relevance rank.
        for position, item in enumerate(ranked_docs):
            key = dumps(item)
            score_by_doc[key] = score_by_doc.get(key, 0) + 1/(position+k)

    ordered = sorted(score_by_doc.items(), key=lambda pair: pair[1], reverse=True)
    return [(loads(key), {"rrf_score": total}) for key, total in ordered]

In [None]:
# Retrieval pipeline: generate queries -> run the retriever on each query -> fuse the result lists with rrf.
rrf_chain = generate_multi_queries | retriever.map() | rrf

In [None]:
# NOTE(review): `input` shadows the Python builtin of the same name for the rest of the session.
input = {"original_query": "What need to consider when using LLM to eval LLM generation?"}
# Fused retrieval result: a list of (document, {"rrf_score": ...}) tuples.
final_result = rrf_chain.invoke(input)

In [None]:
# Pretty-print the fused documents together with their scores.
rich.print(final_result)

In [None]:
# Print only the fused score of each entry; doc[1] is the {"rrf_score": ...} dict.
for doc in final_result:
    print(doc[1]['rrf_score'])

In [None]:
from langchain_core.runnables import RunnableLambda

def filter_rrf_score(docs, threshold=0.1):
    """Keep the entries whose fused score is strictly greater than ``threshold``.

    Args:
        docs: Sequence of ``(document, {"rrf_score": score})`` entries.
        threshold: Exclusive lower bound on ``rrf_score``.

    Returns:
        A new list with the surviving entries, in their original order.
    """
    kept = []
    for entry in docs:
        if entry[1]['rrf_score'] > threshold:
            kept.append(entry)
    return kept

In [None]:
# Apply the filter directly to the fused results with an explicit threshold and display what is kept.
test = filter_rrf_score(final_result, threshold=0.04)
test

In [None]:
# Filter inside the chain with an explicit threshold. Piping the bare function
# would apply its default threshold of 0.1, which fused scores cannot reach
# here: with k=60 and the 4 queries requested by the prompt, a document scores
# at most 4/60 (about 0.067), so the chain would always return an empty list.
filter_chain_test = rrf_chain | (lambda docs: filter_rrf_score(docs, threshold=0.04))

In [None]:
input = {"original_query": "What need to consider when using LLM to eval LLM generation?"}
# Pass the input dict as it is: `input` already has the "original_query" key,
# so wrapping it again would hand the chain a nested dict instead of the question.
test = filter_chain_test.invoke(input)
test

## RAG

In [None]:
def concatenate_docs(docs):
    """Join the text of the ranked documents into a single context string.

    Args:
        docs: Iterable of ``(document, metadata)`` entries; only
            ``document.page_content`` is read.

    Returns:
        The page contents in input order, each followed by two newlines.
        An empty input yields an empty string.
    """
    # str.join builds the result in one pass instead of repeated `+=` copies.
    return "".join(doc[0].page_content + "\n\n" for doc in docs)

In [None]:
# Retrieval pipeline followed by concatenation: produces one context string for the answer prompt.
concatenate_chain = rrf_chain | concatenate_docs
input = {"original_query": "What need to consider when using LLM to eval LLM generation?"}

# Display the concatenated context to inspect what the answer model will receive.
test = concatenate_chain.invoke(input)
test

In [None]:
# Answer-generation prompt: the retrieved context and the user's question are
# placed in separate tagged sections of one prompt string.
# NOTE(review): "retreived" is misspelled, but the spelling is also the template
# variable name that callers must use as a key, so it can only be renamed
# everywhere at once.
prompt_template = """You are a helpful assistant that generates answer based on user's input query and retrieved documents.

<retreived_documents>
{retreived_documents}
</retreived_documents>

<user_query>
{user_query}
</user_query>
"""

from langchain_core.prompts import PromptTemplate

rag_prompt = PromptTemplate(
    template=prompt_template,
    input_variables=['retreived_documents', 'user_query']
    )

# Model that writes the final answer.
respond_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.5)

# Answer chain: fill the prompt, then call the model.
response_chain = rag_prompt | respond_llm


In [None]:
input = {"original_query": "What need to consider when using LLM to eval LLM generation?"}

# Step 1: retrieve, fuse and concatenate the supporting documents.
docs = concatenate_chain.invoke(input)

# Step 2: answer the original question using that context.
rag_result = response_chain.invoke({"retreived_documents": docs, "user_query": input['original_query']})

In [None]:
# Pretty-print the model's response object.
rich.print(rag_result)