# Sean's adaptation
#### Goal:
Use `JsonOutputParser` to get more stable output format when generate multi-queiries

In [1]:
from dotenv import load_dotenv
load_dotenv(override=True)

import rich

In [2]:
from langchain_openai import ChatOpenAI
from langchain_huggingface import HuggingFaceEmbeddings

embedding = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

from langchain_community.vectorstores import Chroma

  from tqdm.autonotebook import tqdm, trange


In [3]:
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter

loader = DirectoryLoader('../../pdf_files/',glob="*.pdf",loader_cls=PyPDFLoader)
documents = loader.load()

# Split text into chunks

text_splitter  = RecursiveCharacterTextSplitter(chunk_size=500,chunk_overlap=20)
text_chunks = text_splitter.split_documents(documents)

vectorstore = Chroma.from_documents(documents=text_chunks, 
                                    embedding=embedding,
                                    persist_directory="data/vectorstore")
vectorstore.persist()

retriever = vectorstore.as_retriever()

  vectorstore.persist()


In [7]:
# from langchain import hub
# prompt = hub.pull("langchain-ai/rag-fusion-query-generation")

In [4]:
from pydantic import BaseModel, Field
from typing import List

class Multi_Queries(BaseModel):
    multi_queries: List[str]=Field(description="The new queries that rephrase user's query with different perspectives.")

In [5]:
from langchain_core.output_parsers import JsonOutputParser

multi_queries_parser = JsonOutputParser(pydantic_object=Multi_Queries)
multi_queries_format = multi_queries_parser.get_format_instructions()

In [7]:
from langchain_core.prompts import (
    PromptTemplate,
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate)

system_prompt = """You are a helpful assistant that generates multiple search queries based on a single input query.
Generate 4 queries.

{format_instructions}
"""

system_message = SystemMessagePromptTemplate(
    prompt=PromptTemplate(
        template=system_prompt,
        # input_variables=['format_instructions']
        partial_variables={'format_instructions': multi_queries_format}
    )
)
human_message = HumanMessagePromptTemplate(
    prompt=PromptTemplate(
        template="Generate multiple search queries related to: {original_query}",
        input_variables=['original_query']
    )
)

prompt = ChatPromptTemplate.from_messages(
    [
        system_message,
        human_message
    ]
)

In [8]:
from langchain_core.runnables import RunnablePassthrough

generate_multi_queries =(
{"original_query": RunnablePassthrough()}
| prompt
| ChatOpenAI(model="gpt-4o-mini", temperature=0.5)
| multi_queries_parser
| (lambda x: x['multi_queries'])
)

In [9]:
generate_multi_queries.invoke({"original_query": "What need to consider when using LLM to eval LLM generation?"})

['What factors should be considered when evaluating LLM outputs?',
 'How to assess LLM generation quality when using LLMs?',
 'What are the key considerations for using LLMs to evaluate their own outputs?',
 'What should I keep in mind when using an LLM for evaluating generated text?']

In [10]:
from langchain.load import dumps, loads

def rrf(results: list[list], k=60):
    fused_scores = {}
    for docs in results:
        # assumes the docs are returned in the order of relevance
        for rank, doc in enumerate(docs):
            doc_str = dumps(doc)
            if doc_str not in fused_scores:
                fused_scores[doc_str] = 0
            fused_scores[doc_str] += 1/(rank+k)

    reranked_results = [
        (loads(doc_str), score) for doc_str, score in sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)
    ]
    
    return reranked_results

In [11]:
rrf_chain = generate_multi_queries | retriever.map() | rrf

In [12]:
input = {"original_query": "What need to consider when using LLM to eval LLM generation?"}
final_result = rrf_chain.invoke(input)

  (loads(doc_str), score) for doc_str, score in sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)


In [13]:
rich.print(final_result)