In [None]:
import os
from pathlib import Path
from fdua_competition.vectorstore import build_vectorstore, get_document_list, get_documents_dir
from fdua_competition.enums import Mode, VectorStoreOption, ChatModelOption
from fdua_competition.rag import get_chat_model, read_prompt
from langchain_core.prompts import ChatPromptTemplate
from fdua_competition.utils import get_queries
from tqdm import tqdm
from pydantic import BaseModel, Field

In [None]:
# Ad-hoc query string. None of the cells visible below read this global: each function
# either defines its own `query` or receives it as a parameter.
query = "extract name of organization"

In [None]:
# Shared vector store used by the retrieval helpers below, created with the TEST mode
# and the CHROMA option. OUTPUT_NAME comes from the environment (os.getenv yields
# None when the variable is unset).
db = build_vectorstore(
    os.getenv("OUTPUT_NAME"),
    Mode.TEST,
    VectorStoreOption.CHROMA,
)

In [None]:
# Structured-output schema requested from the chat model in extract_organization_names.
# NOTE(review): documented with comments rather than a class docstring on purpose --
# presumably a docstring would become part of the schema sent to the model; confirm
# before adding one.
class Organizations(BaseModel):
    # Company names found in the retrieved documents.
    organizations: list[str] = Field(description="[company names]")
    # Source file the names were extracted from (filled in by the model).
    source: str = Field(description="source file")


def extract_organization_names(path: Path) -> Organizations:
    """Extract the company names mentioned in a single source document.

    Retrieves chunks from the module-level vector store ``db``, restricted to the
    given document, and asks the Azure chat model for a structured answer.

    Args:
        path: Path of the document; its string form is used as the ``source``
            metadata filter for retrieval.

    Returns:
        The ``Organizations`` structured output produced by the chat model.
    """
    query = "extract all the company names mentioned in the documents"
    prompt_template = ChatPromptTemplate.from_messages(
        [
            ("system", read_prompt("information_extractor")),
            ("system", "documents:\n{documents}"),
            ("user", "{query}"),
        ]
    )
    # Configure the metadata filter on the retriever itself. Keyword arguments given
    # to retriever.invoke() are not forwarded to the vector store search on every
    # langchain-core release, in which case the filter would be silently ignored and
    # chunks from other documents would be returned.
    retriever = db.as_retriever(search_kwargs={"filter": {"source": str(path)}})
    documents = retriever.invoke(query)
    chat_model = get_chat_model(ChatModelOption.AZURE).with_structured_output(Organizations)
    chain = prompt_template | chat_model
    # NOTE(review): the retrieved documents are interpolated into the prompt as-is, and
    # `source` in the result is whatever the model reports -- consider overwriting it
    # with str(path) if downstream filtering relies on an exact match.
    return chain.invoke({"documents": documents, "query": query})

In [None]:
# List the test-mode documents and run organization-name extraction on each one
# (one extract_organization_names call per document).
docs = get_document_list(get_documents_dir(Mode.TEST))
orgs = [extract_organization_names(doc) for doc in tqdm(docs, desc="extracting organization names...")]

In [None]:
# For every extraction result, show the source file, then its organization names,
# then a blank separator line.
for org in orgs:
    print(org.source, org.organizations, "", sep="\n")

In [None]:
# Structured-output schema requested from the chat model in search_source_to_lookup.
# NOTE(review): kept docstring-free on purpose -- presumably a class docstring would
# become part of the schema sent to the model; confirm before adding one.
class SourceToLookup(BaseModel):
    # The query being routed.
    query: str = Field(description="query")
    # Source file that should be consulted for this query.
    source: str = Field(description="source file to look up")


def search_source_to_lookup(query, orgs):
    """Ask the chat model which source file should be consulted for ``query``.

    Args:
        query: The question to route to a source file.
        orgs: Iterable of objects exposing ``source`` and ``organizations``; each one
            is rendered as a ``- source: organizations`` entry in the prompt.

    Returns:
        The result of invoking the chain configured for ``SourceToLookup``
        structured output.
    """
    messages = [
        ("system", read_prompt("retrieval_assistant")),
        ("system", "organizations:\n{organization}"),
        ("user", "query: {query}"),
    ]
    routing_prompt = ChatPromptTemplate.from_messages(messages)
    structured_model = get_chat_model(ChatModelOption.AZURE).with_structured_output(SourceToLookup)
    pipeline = routing_prompt | structured_model

    # One entry per extraction result, separated by blank lines.
    entries = []
    for entry in orgs:
        entries.append(f"- {entry.source}: {entry.organizations}")
    organization_context = "\n\n".join(entries)

    return pipeline.invoke({"organization": organization_context, "query": query})

In [None]:
# Route every test-mode query to the source file the model picks for it, based on
# the organization names extracted earlier.
docs_to_lookup = [
    search_source_to_lookup(question, orgs)
    for question in tqdm(get_queries(Mode.TEST), desc="finding source to reference...")
]

In [None]:
# Bare expression: displays the per-query routing results as the cell output.
docs_to_lookup

In [None]:
# Structured-output schema requested from the chat model in answer_query.
# NOTE(review): kept docstring-free on purpose -- presumably a class docstring would
# become part of the schema sent to the model; confirm before adding one.
class ResearchAssistantResponse(BaseModel):
    # Echo of the question that was asked.
    query: str = Field(description="the query that was asked.")
    # The answer itself.
    response: str = Field(description="the answer for the given query")
    # Justification for the answer.
    reason: str = Field(description="the reason for the response.")
    # Organization the question refers to.
    organization_name: str = Field(description="the organization name that the query is about.")
    # Supporting passages, each with its file path and page number.
    contexts: list[str] = Field(description="the context that the response was based on with its file path and page number.")


def answer_query(query: str, reference: str) -> ResearchAssistantResponse:
    """Answer ``query`` using only chunks retrieved from one source document.

    Args:
        query: The question to answer.
        reference: Value of the ``source`` metadata that retrieved chunks must match.

    Returns:
        The ``ResearchAssistantResponse`` structured output produced by the chat model.
    """
    prompt_template = ChatPromptTemplate.from_messages(
        [
            ("system", read_prompt("research_assistant")),
            ("system", "context:\n{context}"),
            # Use a placeholder rather than the raw text: a query used directly as the
            # template would have any literal braces parsed as template variables.
            ("user", "{query}"),
        ]
    )
    # Configure the metadata filter on the retriever itself. Keyword arguments given
    # to retriever.invoke() are not forwarded to the vector store search on every
    # langchain-core release, in which case the filter would be silently ignored.
    retriever = db.as_retriever(search_kwargs={"filter": {"source": reference}})
    context = retriever.invoke(query)
    chat_model = get_chat_model(ChatModelOption.AZURE).with_structured_output(ResearchAssistantResponse)
    chain = prompt_template | chat_model
    # "language" is presumably a placeholder in the research_assistant prompt -- confirm.
    return chain.invoke({"context": context, "query": query, "language": "japanese"})


# Answer each query against the source document selected for it and print the
# structured response.
for lookup in docs_to_lookup:
    answer = answer_query(lookup.query, lookup.source)
    print(answer)