In [1]:
import os, json
from dataclasses import dataclass
from typing import List, TypedDict

import openai
from dotenv import load_dotenv
from IPython.display import display, Markdown
from langchain.prompts import ChatPromptTemplate
from langchain_anthropic import ChatAnthropic
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import OpenAIEmbeddings
from langgraph.graph import StateGraph, END

from get_embedding_function import get_embedding_function
# Pull settings from a local .env file (python-dotenv) into the process
# environment before any key is read below.
load_dotenv()
# Indexing os.environ (rather than .get) raises KeyError right here if the
# OpenAI key is not configured.
openai.api_key = os.environ['OPENAI_API_KEY']
# Directory handed to Chroma as `persist_directory` when the store is opened.
CHROMA_PATH = "data/chroma"

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
@dataclass
class QueryResponse:
    """Value object returned by ``query_rag`` for a single question.

    Bundles the question, the generated answer, and the identifiers of the
    retrieved chunks the answer was built from.
    """

    # The question exactly as it was handed to query_rag.
    query_text: str
    # The generated answer text.
    response_text: str
    # The "id" metadata value of each retrieved document, in retrieval order.
    sources: List[str]

def query_rag(query_text: str) -> QueryResponse:
    """Answer ``query_text`` with a retrieve -> format -> generate pipeline.

    The persisted Chroma store at ``CHROMA_PATH`` is opened, the 7 nearest
    chunks are retrieved, rendered into a context string, and passed together
    with the question to the prompt in ``prompts/answer_query.md``. The three
    steps run as nodes of a linear LangGraph ``StateGraph``.

    Side effects:
        * prints the number of documents in the store;
        * appends the question and the answer to ``logs.md`` (UTF-8).

    Args:
        query_text: The user's question, in any language.

    Returns:
        A ``QueryResponse`` holding the question, the generated answer, and
        the ``id`` metadata value of every retrieved chunk.

    Raises:
        OSError: If the prompt file cannot be read or the log file cannot be
            written. Errors from the vector store, the graph, or the model
            call are not caught here and propagate unchanged.
    """
    db = Chroma(
        persist_directory=CHROMA_PATH,
        embedding_function=get_embedding_function(),
    )

    # NOTE(review): db.get() fetches the whole collection just to count it;
    # kept because the count is part of this cell's visible output.
    print(f'''Number of docs in db: {len(db.get()['ids'])}''')

    retriever = db.as_retriever(search_kwargs={"k": 7})

    # Read the prompt through a context manager so the handle is closed, and
    # pin the encoding so non-ASCII prompt text loads the same on every OS.
    with open('prompts/answer_query.md', encoding='utf-8') as prompt_file:
        prompt_template = ChatPromptTemplate.from_template(prompt_file.read())

    model = ChatAnthropic(model='claude-3-5-sonnet-20240620')

    chain_with_prompt = prompt_template | model | StrOutputParser()

    class AgentState(TypedDict):
        """State carried between the graph nodes."""
        question: str
        # The retriever yields Document objects, not chat messages; the
        # previous BaseMessage annotation did not match what is stored here.
        raw_docs: List[Document]
        # One joined context string, not a list of strings.
        formatted_docs: str
        generation: str
        sources: List[str]

    def _metadata_text(doc: Document, key: str) -> str:
        """Return ``doc.metadata[key]`` as text, or "" when it is absent."""
        value = doc.metadata.get(key)
        return "" if value is None else str(value)

    # Each node returns only the keys it produces instead of mutating and
    # returning the incoming state; the graph merges these partial updates.

    def get_docs(state: AgentState) -> dict:
        """Retrieve the chunks for the question and record their ids."""
        docs = retriever.invoke(state["question"])
        return {
            "raw_docs": docs,
            "sources": [doc.metadata.get("id") for doc in docs],
        }

    def format_docs(state: AgentState) -> dict:
        """Render the retrieved chunks into a single context string."""
        # Missing or non-string metadata no longer breaks the concatenation;
        # an absent title/date is rendered as an empty field.
        excerpts = [
            "Talk Title:" + _metadata_text(doc, "vid_title")
            + "\nExcerpt:" + doc.page_content
            + "\nPublished time:" + _metadata_text(doc, "published_dt")
            for doc in state["raw_docs"]
        ]
        return {"formatted_docs": "\n\n---\n\n".join(excerpts)}

    def generate(state: AgentState) -> dict:
        """Run the prompt | model | parser chain on the question and context."""
        answer = chain_with_prompt.invoke(
            {"question": state["question"], "context": state["formatted_docs"]}
        )
        return {"generation": answer}

    workflow = StateGraph(AgentState)
    workflow.add_node("get_docs", get_docs)
    workflow.add_node("format_docs", format_docs)
    workflow.add_node("generate", generate)
    workflow.add_edge("get_docs", "format_docs")
    workflow.add_edge("format_docs", "generate")
    workflow.add_edge("generate", END)
    workflow.set_entry_point("get_docs")

    rag_app = workflow.compile()

    result = rag_app.invoke({"question": query_text})

    # Questions and answers may contain non-ASCII text (e.g. Vietnamese), so
    # the log is written as UTF-8 rather than the platform default encoding.
    log_file = 'logs.md'
    with open(log_file, 'a', encoding='utf-8') as file:
        file.write(f"Question:\n {result['question']}\n")
        file.write(f"Response:\n {result['generation']}\n")
        file.write('\n----------------------------------------------------------\n')

    return QueryResponse(
        query_text=query_text,
        response_text=result['generation'],
        sources=result['sources'],
    )


In [13]:
# Smoke-test the pipeline with a Vietnamese question ("Can silicon-based life
# exist?") and render the generated answer as Markdown in the cell output.
response = query_rag("Sự sống silicon có thể tồn tại không?")
display(Markdown(response.response_text))

Number of docs in db: 2803


AttributeError: 'Document' object has no attribute 'response_metadata'