In [None]:
from dotenv import load_dotenv

# Load environment variables from a .env file before any client is constructed
# (presumably the API keys for the OpenAI / LangSmith / Tavily clients created below).
load_dotenv()

In [None]:
from langchain_openai import ChatOpenAI

# Shared chat model: used for routing (structured output), web RAG generation,
# and direct answers in the cells below.
llm = ChatOpenAI(model='gpt-4o-mini')

In [None]:
from langsmith import Client

# LangSmith client; used below to pull the "rlm/rag-prompt" prompt.
client = Client()

In [None]:
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings

# Embedding model for the persisted Chroma collection opened below.
embedding_model = OpenAIEmbeddings(model='text-embedding-3-large')

# Open the on-disk Chroma collection holding the income-tax documents.
vector_store = Chroma(
    collection_name='income_tax_collections',
    persist_directory='./income_tax_collections',
    embedding_function=embedding_model,
)

# Retriever over the collection with k=3.
# NOTE(review): `retriever` is not referenced anywhere else in this notebook; the
# 'income_tax_agent' node comes from graph.income_tax_graph — confirm this cell is still needed.
retriever = vector_store.as_retriever(search_kwargs={'k': 3})

In [None]:
from typing import Literal
from typing_extensions import TypedDict
from langchain_core.prompts import ChatPromptTemplate

class AgentState(TypedDict):
    # Shared state of the routing graph; every node returns a partial dict of these keys.
    query: str     # the user's question (the only key set in the initial state)
    context: list  # search results written by 'web_search' and read by 'web_generate'
    answer: str    # final answer written by the generate nodes
    # NOTE(review): web_generate / basic_generate store the raw return value of the LLM call
    # under 'answer' (presumably a chat message object, not a str) — confirm which type is intended.

In [None]:
from langchain_community.tools import TavilySearchResults

# Tavily web-search tool invoked by the 'web_search' node with the raw query string.
tavily_search_tool = TavilySearchResults(
    max_results=3,
    search_depth="advanced",
    # NOTE(review): when the tool is invoked with a plain string (as in web_search), the
    # answer/images extras may only be exposed in the tool's artifact, not its returned
    # content — verify they are actually used.
    include_answer=True,
    # NOTE(review): raw page content can make the web-RAG prompt very large — confirm it is needed.
    include_raw_content=True,
    include_images=True,
)

In [None]:
def web_search(state: AgentState) -> AgentState:
    """
    'web_search' node: run a web search for the user's question.

    Args:
        - state (AgentState): current agent state; only 'query' is read.

    Returns:
        - AgentState: partial update {'context': ...} holding whatever
          `tavily_search_tool.invoke` returned for the query.
    """
    search_results = tavily_search_tool.invoke(state['query'])
    return {'context': search_results}

In [None]:
# Pull only the prompt template. The chain in web_generate pipes it into this notebook's
# own `llm`, so any model stored with the hub prompt must not be attached here
# (otherwise the chain would become prompt | hub model | llm).
rag_prompt = client.pull_prompt("rlm/rag-prompt")

def web_generate(state: AgentState) -> AgentState:
    """
    'web_generate' node: answer the user's question from the web search results.

    Args:
        - state (AgentState): current agent state; reads 'query' and the
          'context' written by the 'web_search' node.

    Returns:
        - AgentState: partial update {'answer': ...} holding the value returned by
          the prompt | llm chain.
          NOTE(review): that value is the LLM's return object, not a plain str as
          AgentState declares — confirm which is intended.
    """
    query = state['query']
    context = state['context']

    # The RAG prompt is filled with 'question' and 'context', then sent to the shared llm.
    rag_chain = rag_prompt | llm
    ai_message = rag_chain.invoke({'question': query, 'context': context})

    return {'answer': ai_message}

In [None]:
def basic_generate(state: AgentState) -> AgentState:
    """
    'basic_generate' node: answer the question with the LLM alone (no retrieval).

    Args:
        - state (AgentState): current agent state; only 'query' is read.

    Returns:
        - AgentState: partial update {'answer': ...} holding the value returned by
          `llm.invoke` for the raw question.
    """
    return {'answer': llm.invoke(state['query'])}

In [None]:
from pydantic import BaseModel, Field

# The three route names; they must match the keys of the conditional-edge mapping.
RouteTarget = Literal['vector_store', 'llm', 'web_search']

# Structured-output schema for the routing call.
# Intentionally left without a class docstring: it would presumably become part of the
# schema description sent to the model — verify before adding one.
class Router(BaseModel):
    # Required field: the route chosen by the LLM.
    target: RouteTarget = Field(..., description="The target for the query to answer")

# LLM whose replies are parsed into a Router instance.
structured_llm = llm.with_structured_output(Router)

In [None]:
# System prompt for the router. Each of the three targets gets an explicit selection rule.
# The original text only described what 'vector_store' contains, so a simple income-tax
# question could be sent to 'llm' instead of the vector store.
route_system_prompt = """
You are an expert at routing a user's question to 'vector_store', 'llm', or 'web_search'.

- 'vector_store' contains information about income tax up to December 2024.
  If the question is about income tax, use 'vector_store'.
- If the question is simple enough to answer directly, use 'llm'.
- If you need to search the web to answer the question, use 'web_search'.
"""

route_prompt = ChatPromptTemplate.from_messages([
    ('system', route_system_prompt),
    ('user', '{query}')
])

def route(state: AgentState) -> Literal['vector_store', 'llm', 'web_search']:
    """
    Routing function for the conditional edge out of START (not itself a graph node).
    Asks the structured-output LLM which target should handle the user's question.

    Args:
        - state (AgentState): current agent state; only 'query' is read.

    Returns:
        - Literal['vector_store', 'llm', 'web_search']: the chosen target, which the
          conditional edge maps to the next node.
    """
    query = state['query']

    # Prompt -> LLM constrained to the Router schema, so the reply carries a `target` field.
    route_chain = route_prompt | structured_llm
    ai_message = route_chain.invoke({'query': query})

    # Show the routing decision in the cell output.
    print(f"target=={ai_message.target}")

    return ai_message.target

In [None]:
from langgraph.graph import StateGraph, START, END
from graph.income_tax_graph import graph as income_tax_graph

graph_builder = StateGraph(AgentState)

# Nodes, registered in order under the names the edges below refer to.
node_handlers = {
    'web_search': web_search,
    'web_generate': web_generate,
    'basic_generate': basic_generate,
    'income_tax_agent': income_tax_graph,  # imported graph used as a node
}
for node_name, handler in node_handlers.items():
    graph_builder.add_node(node_name, handler)

# Entry: `route` picks a target, which this mapping turns into the next node name.
route_to_node = {
    'vector_store': 'income_tax_agent',
    'llm': 'basic_generate',
    'web_search': 'web_search'
}
graph_builder.add_conditional_edges(START, route, route_to_node)

# Fixed transitions after each branch.
fixed_edges = [
    ('web_search', 'web_generate'),
    ('web_generate', END),
    ('basic_generate', END),
    ('income_tax_agent', END),
]
for edge_source, edge_target in fixed_edges:
    graph_builder.add_edge(edge_source, edge_target)

In [None]:
# Compile the builder into a runnable graph.
graph = graph_builder.compile()

In [None]:
from IPython.display import Image, display

# Render the compiled graph's structure as a Mermaid PNG in the notebook.
# NOTE(review): draw_mermaid_png() may use an external rendering service by default —
# confirm network access is acceptable where this notebook runs.
display(Image(graph.get_graph().draw_mermaid_png()))

In [None]:
# Sample questions. Only query1 is run; query2 / query3 are alternatives to put into
# initial_state (presumably chosen to exercise the 'web_search' and 'llm' routes).
query1 = "연봉 5천만원인 거주자가 납부해야 하는 소득세는 얼마인가요?"  # income tax owed by a resident earning 50M KRW per year
query2 = "군자역 맛집을 알려주세요."  # restaurant recommendations near Gunja station
query3 = "대한민국의 수도는 어디인가요?"  # what is the capital of South Korea
initial_state = {'query': query1}

# Run the graph; the returned value is shown as the cell output.
graph.invoke(initial_state)