# Sean's adaptation
#### Goal:
Use `JsonOutputParser` to get a more stable output format when generating multiple queries.

In [1]:
from dotenv import load_dotenv
# Load variables from a local .env file (presumably OPENAI_API_KEY for ChatOpenAI — confirm);
# override=True lets the .env values replace variables already set in the environment.
load_dotenv(override=True)

import rich  # used for pretty-printing results

In [3]:
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI

# Sentence-transformers embedding model used both to index the documents and to
# embed the generated queries at retrieval time.
# NOTE(review): `langchain_community.vectorstores.Chroma` is deprecated in favour of
# the `langchain-chroma` package; consider migrating.
embedding = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

  from tqdm.autonotebook import tqdm, trange


In [4]:
# Toy corpus: ten short climate-change titles keyed "doc1".."doc10" in list order.
all_documents = {
    f"doc{number}": text
    for number, text in enumerate(
        [
            "Climate change and economic impact.",
            "Public health concerns due to climate change.",
            "Climate change: A social perspective.",
            "Technological solutions to climate change.",
            "Policy changes needed to combat climate change.",
            "Climate change and its impact on biodiversity.",
            "Climate change: The science and models.",
            "Global warming: A subset of climate change.",
            "How climate change affects daily weather.",
            "The history of climate change activism.",
        ],
        start=1,
    )
}

In [6]:
# Index the corpus in a persistent Chroma collection.
# - `texts` is declared as a list of strings, so materialise the dict views
#   instead of passing `dict_values`.
# - Without explicit ids, Chroma generates a random id per text. Because the
#   collection is persisted to ./sean_db, every Restart & Run All would then add
#   another copy of all documents, and the duplicates would show up in
#   retrieval. Using the dict keys as ids makes re-runs write the same ids again
#   instead of appending.
# NOTE(review): a ./sean_db created by earlier runs may already hold duplicates;
# delete it once so the collection starts clean.
vectorstore = Chroma.from_texts(
    texts=list(all_documents.values()),
    embedding=embedding,
    ids=list(all_documents.keys()),
    collection_name="sean",
    persist_directory="./sean_db",
)

In [7]:
# from langchain import hub
# prompt = hub.pull("langchain-ai/rag-fusion-query-generation")

In [86]:
from pydantic import BaseModel, Field
from typing import List

# Expected shape of the LLM's JSON reply. Deliberately no class docstring:
# pydantic copies it into the JSON schema, and JsonOutputParser builds the
# format instructions shown to the LLM from that schema, so a docstring would
# change the prompt.
class Multi_Queries(BaseModel):
    multi_queries: list[str] = Field(
        description="The new queries that rephrase user's query with different perspectives."
    )

In [87]:
from langchain_core.output_parsers import JsonOutputParser

# The pydantic model is only used to describe the expected JSON: JsonOutputParser
# parses the reply as JSON but does not validate it against Multi_Queries.
multi_queries_parser = JsonOutputParser(pydantic_object=Multi_Queries)
# Text describing the expected JSON schema; substituted into the system prompt.
multi_queries_format = multi_queries_parser.get_format_instructions()

In [88]:
from langchain_core.prompts import (
    PromptTemplate,
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate)

system_prompt = """You are a helpful assistant that generates multiple search queries based on a single input query.
Generate 4 queries.

{format_instructions}
"""

# System turn: the task description (the query count is fixed in the text), with
# the parser's JSON format instructions pre-filled as a partial variable so callers
# never have to supply them.
system_message = SystemMessagePromptTemplate(
    prompt=PromptTemplate.from_template(
        system_prompt,
        partial_variables={"format_instructions": multi_queries_format},
    )
)

# Human turn: `original_query` is the only variable left for callers to provide.
human_message = HumanMessagePromptTemplate.from_template(
    "Generate multiple search queries related to: {original_query}"
)

prompt = ChatPromptTemplate.from_messages([system_message, human_message])

In [106]:
from langchain_core.runnables import RunnablePassthrough

# Query-expansion chain: original query -> list of rephrased search queries.
#
# The prompt's only free variable is `original_query` (`format_instructions` is
# pre-filled), so the prompt can be invoked directly with either
# {"original_query": "..."} or a bare string, which the prompt maps onto that
# single variable. A leading {"original_query": RunnablePassthrough()} step would
# wrap a dict input a second time. The LLM would then be shown the dict's repr
# ("{'original_query': ...}") instead of the query text.
generate_multi_queries = (
    prompt
    | ChatOpenAI(model="gpt-4o-mini", temperature=0.5)
    # expected to yield a dict with a "multi_queries" key, per the format instructions
    | multi_queries_parser
    | (lambda parsed: parsed["multi_queries"])  # keep only the list of query strings
)

In [107]:
# Pass the query as a plain string. It becomes the prompt's `original_query`
# whether or not the chain starts with a {"original_query": RunnablePassthrough()}
# step. A dict would be nested one level deeper by such a step.
generate_multi_queries.invoke("Climate change and economic impact.")

['How does climate change affect the economy?',
 'Economic consequences of climate change.',
 'Impact of climate change on global markets.',
 'Climate change and its effects on economic stability.']

In [91]:
# Plain vector-similarity retriever over the Chroma collection, using the default result count.
retriever = vectorstore.as_retriever(search_type="similarity")

In [108]:
from langchain.load import dumps, loads

def rrf(results: list[list], k: int = 60, key=None) -> list[tuple]:
    """Fuse several ranked result lists with Reciprocal Rank Fusion (RRF).

    Every item earns ``1 / (k + rank)`` from each list it appears in, where
    ``rank`` is its 1-based position in that list (Cormack et al., 2009).
    Contributions are summed per item, and items are sorted by total score.

    Args:
        results: One ranked list per query, each assumed to be ordered from most
            to least relevant (e.g. the output of ``retriever.map()``).
        k: Smoothing constant; larger values shrink the advantage of top ranks.
            60 is the value used in the original paper.
        key: Optional function mapping an item to a hashable identity, used to
            merge the same item across lists. Defaults to LangChain's ``dumps``
            (JSON serialization), so two Documents are the same item when they
            serialize identically.

    Returns:
        ``(item, score)`` tuples, highest score first. Ties keep the order in
        which items were first encountered. For each identity, the first object
        seen is returned.
    """
    identity = dumps if key is None else key
    fused_scores = {}
    representative = {}
    for ranked_docs in results:
        # 1-based rank as in the RRF paper; this also avoids 1/0 when k == 0.
        for rank, doc in enumerate(ranked_docs, start=1):
            doc_id = identity(doc)
            # Keep the first object seen for this identity; returning it avoids a
            # dumps/loads round trip just to rebuild an equal Document.
            representative.setdefault(doc_id, doc)
            fused_scores[doc_id] = fused_scores.get(doc_id, 0.0) + 1 / (k + rank)

    # sorted() is stable (also with reverse=True), so ties keep first-seen order.
    ranked = sorted(fused_scores.items(), key=lambda item: item[1], reverse=True)
    return [(representative[doc_id], score) for doc_id, score in ranked]

In [109]:
# RAG-fusion pipeline: expand the query, retrieve for each variant, fuse the rankings.
rrf_chain = (
    generate_multi_queries  # query -> list of rephrased query strings
    | retriever.map()  # run the retriever once per query string
    | rrf  # Reciprocal Rank Fusion over the per-query result lists
)

In [110]:
# Pass the query as a plain string. It becomes the prompt's `original_query`
# whether or not the chain starts with a {"original_query": RunnablePassthrough()}
# step, whereas a dict would be nested one level deeper by such a step.
# Named `original_query` rather than `input` to avoid shadowing the built-in.
original_query = "Climate change and economic impact."
final_result = rrf_chain.invoke(original_query)

In [111]:
# Fused retrieval results as (document, score) tuples, highest score first.
rich.print(final_result)