# MultiQueryRetriever
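- Distance-based vector database retrieval embeds a query in a high-dimensional space and finds similar documents based on "distance".
- Retrieval results can change when the wording of the query changes subtly, or when the embeddings do not capture the semantics of the data well.
- Prompt engineering/tuning is sometimes done to address these problems manually, but it can be tedious.
- MultiQueryRetriever automates the prompt tuning process by using an LLM to generate multiple queries from different perspectives.
- For each query, it retrieves a set of relevant documents, then takes the unique union across all queries to obtain a larger set.
- Generating multiple perspectives can overcome some of the limitations of distance-based retrieval and yield a richer set of results.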
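
The retrieve-then-union step described in the list above can be sketched in plain Python. This is a minimal illustration and not LangChain's implementation: `CORPUS`, `toy_search`, `unique_union`, and the hand-written query variants are all hypothetical stand-ins for a vector store, a similarity search, and LLM-generated queries.

```python
from typing import Callable, Dict, List

# Hypothetical corpus: document id -> text (stands in for a vector store).
CORPUS: Dict[str, str] = {
    "d1": "Task decomposition can be done with chain of thought prompting.",
    "d2": "Tree of thoughts explores multiple reasoning paths per step.",
    "d3": "Agents can use external tools through API calls.",
}


def toy_search(query: str) -> List[str]:
    """Toy stand-in for similarity search: match documents on long keywords."""
    keywords = [w for w in query.lower().split() if len(w) > 4]
    return [
        doc_id
        for doc_id, text in CORPUS.items()
        if any(k in text.lower() for k in keywords)
    ]


def unique_union(queries: List[str], search: Callable[[str], List[str]]) -> List[str]:
    """Run every query and keep each document once, in first-seen order."""
    seen: Dict[str, None] = {}
    for q in queries:
        for doc_id in search(q):
            seen.setdefault(doc_id, None)
    return list(seen)


# Hand-written variants; in MultiQueryRetriever an LLM would generate these.
variants = [
    "task decomposition approaches",
    "reasoning paths for planning",
    "chain of thought decomposition",
]

merged = unique_union(variants, toy_search)
print(merged)
```

Each variant on its own finds only part of the relevant material; the union covers more of it without returning any document twice.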

In [10]:
# Build a sample vectorDB
from langchain_chroma import Chroma
from langchain_community.document_loaders import WebBaseLoader
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from dotenv import load_dotenv
load_dotenv('../dot.env')

# Load blog post
loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
data = loader.load()

# Split
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
splits = text_splitter.split_documents(data)

# VectorDB
embedding = OpenAIEmbeddings()
vectordb = Chroma.from_documents(documents=splits, embedding=embedding)

In [11]:
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_openai import ChatOpenAI

question = "What are the approaches to Task Decomposition?"
llm = ChatOpenAI(temperature=0)
retriever_from_llm = MultiQueryRetriever.from_llm(
    retriever=vectordb.as_retriever(), llm=llm
)

In [12]:
# Set logging for the queries
import logging

logging.basicConfig()
logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)
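
The cell above raises the level of a single named logger instead of the root logger, so only that module's INFO messages are printed. A minimal standard-library sketch of that behaviour follows; the `demo.*` logger names are hypothetical and used only for illustration.

```python
import logging

logging.basicConfig()  # installs a handler; the root level is left unchanged

# Hypothetical logger names, chosen only to mirror the module-style naming.
noisy = logging.getLogger("demo.retrievers.multi_query")
quiet = logging.getLogger("demo.other_module")

# Lower the threshold for one logger only.
noisy.setLevel(logging.INFO)

info_enabled_noisy = noisy.isEnabledFor(logging.INFO)

# A logger without its own level inherits the effective level of its ancestors.
quiet_level = quiet.getEffectiveLevel()
root_level = logging.getLogger().getEffectiveLevel()

noisy.info("Generated queries would be logged here")
quiet.info("This message is emitted only if the root level allows INFO")
```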

In [13]:
# The log output below shows the variants generated for the query "What are the approaches to Task Decomposition?"
unique_docs = retriever_from_llm.invoke(question)
len(unique_docs)

INFO:langchain.retrievers.multi_query:Generated queries: ['1. How can Task Decomposition be achieved through different methods?', '2. What strategies can be used for breaking down tasks in Task Decomposition?', '3. What are the various techniques available for approaching Task Decomposition?']


1

# Supplying your own prompt
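- You can also supply your own prompt together with an output parser that splits the LLM result into a list of queries. (In contrast, the cell above only showed the generated queries through logging.)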

In [14]:
from typing import List

from langchain.chains import LLMChain
from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from pydantic import BaseModel, Field


# Output parser will split the LLM result into a list of queries
class LineList(BaseModel):
    # "lines" is the key (attribute name) of the parsed output
    lines: List[str] = Field(description="Lines of text")


class LineListOutputParser(PydanticOutputParser):
    def __init__(self) -> None:
        super().__init__(pydantic_object=LineList)

    def parse(self, text: str) -> LineList:
        lines = text.strip().split("\n")
        return LineList(lines=lines)


output_parser = LineListOutputParser()

QUERY_PROMPT = PromptTemplate(
    input_variables=["question"],
    # Ask the model to generate five versions of the query
    template="""You are an AI language model assistant. Your task is to generate five 
    different versions of the given user question to retrieve relevant documents from a vector 
    database. By generating multiple perspectives on the user question, your goal is to help
    the user overcome some of the limitations of the distance-based similarity search. 
    Provide these alternative questions separated by newlines.
    Original question: {question}""",
)
llm = ChatOpenAI(temperature=0)

# Chain
llm_chain = LLMChain(llm=llm, prompt=QUERY_PROMPT, output_parser=output_parser)

# Other inputs
question = "What are the approaches to Task Decomposition?"

In [20]:
# Run
retriever = MultiQueryRetriever(
    retriever=vectordb.as_retriever(), 
    llm_chain=llm_chain, parser_key="lines"
)  # "lines" is the key (attribute name) of the parsed output

# Results
unique_docs = retriever.invoke(input="What does the course say about agent?")
len(unique_docs)

OutputParserException: Failed to parse LineList from completion 1. Got: 1 validation error for LineList
  Input should be a valid dictionary or instance of LineList [type=model_type, input_value=1, input_type=int]
    For further information visit https://errors.pydantic.dev/2.7/v/model_type
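
The `input_value=1` in the traceback suggests that the parser validated the integer `1` rather than the newline-separated text. One plausible explanation, which is an assumption about the installed langchain-core version and not something confirmed here, is that `PydanticOutputParser` first tries to read the completion as JSON. In that case the overridden `parse` is bypassed, and a numbered completion beginning with `1. ...` is read as the number 1. A plain line splitter avoids JSON altogether. The sketch below shows only the splitting logic in standalone Python; `split_queries` is a hypothetical helper, and wiring it into LangChain (for example through a `BaseOutputParser` subclass whose `parse` returns this list) would need to be checked against the installed version.

```python
import re
from typing import List


def split_queries(text: str) -> List[str]:
    """Split an LLM completion into one query per non-empty line.

    Leading list numbering such as "1. " or "2) " is removed so that the
    numbering is not sent to the retriever as part of the query.
    """
    queries: List[str] = []
    for line in text.strip().split("\n"):
        cleaned = re.sub(r"^\s*\d+[.)]\s*", "", line).strip()
        if cleaned:  # skip blank lines between queries
            queries.append(cleaned)
    return queries


# Hypothetical completion, shaped like the numbered output logged earlier.
completion = (
    "1. How can tasks be decomposed?\n"
    "\n"
    "2. Which strategies break tasks into steps?\n"
)

queries = split_queries(completion)
print(queries)
```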