In [None]:
from dotenv import load_dotenv

load_dotenv()

In [None]:
from langchain_chroma import Chroma
from langchain_ollama import OllamaEmbeddings
from langchain_core.prompts import ChatPromptTemplate,PromptTemplate

# Embedding model used both to index and to query the collection.
embeddings = OllamaEmbeddings(model="nomic-embed-text:latest")

# Chroma collection stored on disk under ./income_tax_collection.
vector_store = Chroma(
    collection_name="income_tax_collection",
    persist_directory="./income_tax_collection",
    embedding_function=embeddings,
)

# Retriever over the collection, configured with k=3 results per query.
retriever = vector_store.as_retriever(
    search_kwargs={"k": 3},
)

In [None]:
from typing_extensions import TypedDict, List
from langgraph.graph import StateGraph


class AgentState(TypedDict):
    """State shared by every node of the graph.

    Nodes return partial dicts containing only the keys they update.
    """

    # The user's question; replaced by `rewrite` before a web search.
    query: str
    # Supporting material: retriever results or web-search results.
    context: list
    # Final text produced by `generate`.
    answer: str

# Builder that nodes and edges are registered on in the cells below;
# it uses AgentState as the graph's state schema.
graph_builder = StateGraph(AgentState)


In [None]:
def retrieve(state: AgentState):
    """Look up the state's query in the vector-store retriever.

    Args:
        state: Current graph state; only ``query`` is read.

    Returns:
        A partial state update with the retriever's results under ``context``.
    """
    matched_docs = retriever.invoke(state["query"])
    return {"context": matched_docs}

In [None]:
from langchain_ollama import ChatOllama

# Chat model shared by the generate, grade and rewrite steps (temperature 0).
llm = ChatOllama(temperature=0, model="deepseek-r1:1.5b")

In [None]:
from langchain import hub
from langchain_core.output_parsers import StrOutputParser

# RAG answer prompt pulled from the LangChain hub; `generate` feeds it the
# `question` and `context` inputs.
generate_prompt = hub.pull("rlm/rag-prompt")


def generate(state: AgentState) -> AgentState:
    """Produce the final answer from the current query and context.

    Pipes the RAG prompt into the chat model and a string output parser,
    then invokes that chain with the state's ``context`` and its ``query``
    (passed to the prompt as ``question``).

    Args:
        state: Current graph state; ``context`` and ``query`` are read.

    Returns:
        A partial state update holding the generated text under ``answer``.
    """
    payload = {"context": state["context"], "question": state["query"]}
    answer_chain = generate_prompt | llm | StrOutputParser()
    return {"answer": answer_chain.invoke(payload)}

In [None]:
from langchain import hub
from typing import Literal

# Document-relevance grading prompt pulled from the LangChain hub; used by
# `check_doc_relevance`, which reads a `Score` field from the chain output.
doc_relevance_prompt = hub.pull("langchain-ai/rag-document-relevance")


def check_doc_relevance(state: AgentState) -> Literal["relevant", "irrelevant"]:
    """Route the graph depending on whether the retrieved context fits the query.

    The hub grading prompt is piped into the chat model and invoked with the
    state's query and context. The chain output is expected to be a mapping
    with a ``Score`` entry, where 1 means the documents are relevant.

    Args:
        state: Current graph state; ``query`` and ``context`` are read.

    Returns:
        ``"relevant"`` when the grader scores the context 1, otherwise
        ``"irrelevant"``. Output that cannot be read as a score is treated
        as ``"irrelevant"`` so the graph falls back to the web-search branch
        instead of crashing.
    """
    query = state["query"]
    context = state["context"]
    doc_relevance_chain = doc_relevance_prompt | llm
    # NOTE(review): the hub prompt appears to read its facts from `documents`,
    # not `context` - confirm against the prompt's input variables. Both keys
    # are supplied; prompt templates only complain about missing inputs, so
    # the extra key is harmless whichever name is used.
    response = doc_relevance_chain.invoke(
        {"question": query, "documents": context, "context": context}
    )

    # A small model may not honour the structured-output schema; anything
    # that is not a mapping carries no usable score.
    score = response.get("Score") if isinstance(response, dict) else None

    # Accept the numeric score as well as a stringified "1".
    if score == 1 or str(score).strip() == "1":
        return "relevant"
    return "irrelevant"

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Prompt (Korean) asking the model to rephrase the user's question.
# English gloss: "Look at the user's question and revise it so that it is
# easier to use for a web search. / Question: {query}"
rewrite_prompt = PromptTemplate.from_template("""사용자의 질문을 보고, 웹 검색에 용이하게 사용자의 질문을 수정해주세요
질문: {query}
""")

def rewrite(state: AgentState) -> AgentState:
    """Rephrase the user's query so it is better suited to a web search.

    Runs the rewrite prompt through the chat model and a string parser, then
    cleans the model output before storing it as the new query.

    Args:
        state: Current graph state; only ``query`` is read.

    Returns:
        A partial state update with the rewritten text under ``query``. If the
        model yields no usable text, the original query is kept.
    """
    query = state["query"]
    rewrite_chain = rewrite_prompt | llm | StrOutputParser()
    response = rewrite_chain.invoke({"query": query})

    # Reasoning models such as deepseek-r1 can emit a `<think>...</think>`
    # trace ahead of the actual reply. Keep only what follows the last closing
    # tag so the trace is not sent to the search tool; text without the tag
    # passes through unchanged (apart from whitespace trimming).
    rewritten = response.rpartition("</think>")[2].strip()

    # Never hand an empty query to the next node.
    return {"query": rewritten or query}


In [None]:
from langchain_community.tools import TavilySearchResults

# Tavily web-search tool used on the fallback branch: up to 3 results with
# advanced search depth, with answer, raw content and images all enabled.
tavily_search_tool = TavilySearchResults(
    search_depth="advanced",
    max_results=3,
    include_answer=True,
    include_images=True,
    include_raw_content=True,
)

def web_search(state: AgentState):
    """Search the web for the state's query with the Tavily tool.

    Args:
        state: Current graph state; only ``query`` is read.

    Returns:
        A partial state update that replaces ``context`` with whatever the
        search tool returned.
    """
    search_hits = tavily_search_tool.invoke(state["query"])
    return {"context": search_hits}



In [None]:
# Register every step under the name the edge definitions refer to.
for node_name, node_fn in (
    ("retrieve", retrieve),
    ("generate", generate),
    ("rewrite", rewrite),
    ("web_search", web_search),
):
    graph_builder.add_node(node_name, node_fn)

In [None]:
from langgraph.graph import START, END

# Entry point: every run starts with vector-store retrieval.
graph_builder.add_edge(START, "retrieve")

# After retrieval, check_doc_relevance picks the branch: answer directly when
# the documents are relevant, otherwise rewrite the query for a web search.
graph_builder.add_conditional_edges(
    'retrieve',
    check_doc_relevance,
    {
        'relevant': 'generate',
        'irrelevant': 'rewrite'
    }
)

# Fallback branch: rewritten query -> web search -> answer generation.
graph_builder.add_edge('rewrite', 'web_search')
graph_builder.add_edge('web_search', 'generate')

# Terminate explicitly once the answer is generated, so `generate` is not
# left without an outgoing edge and the rendered graph shows where it ends.
graph_builder.add_edge('generate', END)

In [None]:
# Compile the registered nodes and edges into the graph object that the
# cells below draw and invoke.
graph = graph_builder.compile()

In [None]:
from IPython.display import Image, display

# Draw the compiled graph as a Mermaid PNG and show it inline.
display(Image(graph.get_graph().draw_mermaid_png()))

In [None]:
# Query (Korean): "How much income tax does a resident with an annual salary
# of 80 million won pay?" - presumably answerable from the income-tax
# collection; verify the route taken from the output.
initial_state = {"query": "연봉 8천만원 거주자의 소득세는 얼마인가요?"}
graph.invoke(initial_state)

In [None]:
# Query (Korean): "Please recommend good restaurants near Yeoksam Station."
# Unrelated to income tax, so this presumably exercises the
# rewrite -> web_search branch; verify from the output.
initial_state = {'query' : '역삼역 맛집을 추천해주세요'}
graph.invoke(initial_state)