# In [4]:
from dotenv import load_dotenv

load_dotenv()

import os
from langchain_community.utilities import SQLDatabase

from urllib.parse import quote

# Connection settings are read from the environment after load_dotenv() has
# run. Store the raw (not pre-encoded) values there; encoding happens below.
POSTGRES_USER = os.getenv('POSTGRES_USER')
POSTGRES_PASSWORD = os.getenv('POSTGRES_PASSWORD')
POSTGRES_DB = os.getenv('POSTGRES_DB')


def encode_credential(value):
    """Percent-encode a user name or password for use in a database URL.

    Every reserved character is escaped (``safe=""``) so that ``@``, ``:`` or
    ``/`` inside a credential cannot be mistaken for a URL delimiter.

    Args:
        value: The raw credential. Non-string values (including ``None`` for
            an unset environment variable) are converted with ``str()``, which
            keeps the previous rendering of a missing variable as ``"None"``.

    Returns:
        The percent-encoded credential.
    """
    return quote(str(value), safe="")


# Only the user-info part is encoded; host, port and database name are
# interpolated unchanged.
URI = (
    f"postgresql://{encode_credential(POSTGRES_USER)}:"
    f"{encode_credential(POSTGRES_PASSWORD)}@localhost:5432/{POSTGRES_DB}"
)

# Database wrapper built from the connection URI; write_sql and execute_sql
# both use it.
db = SQLDatabase.from_uri(URI)


from langchain_openai import ChatOpenAI

# Chat model shared by the SQL-writing and answer-writing nodes.
# NOTE(review): confirm that 'gpt-5-nano' honors temperature=0 — some model
# families only accept their default temperature.
llm = ChatOpenAI(model='gpt-5-nano', temperature=0)

# Define the graph state
from langgraph.graph import MessagesState

# The fields are filled in step by step as the nodes run
# (question is presumably supplied by the caller — no node here writes it).
class State(MessagesState):
    """Graph state: MessagesState plus the fields of the text-to-SQL flow."""
    question: str # user's question
    sql: str # SQL generated from the question
    result: str # result returned by the database
    answer: str # final answer produced from the question, SQL and result

from typing_extensions import Annotated, TypedDict

# Structured-output schema passed to llm.with_structured_output() in write_sql.
# NOTE(review): the docstring and the Annotated description are presumably
# forwarded to the model as schema text, so both are left untouched — confirm.
# The description string reads "a syntactically correct SQL query".
class QueryOutput(TypedDict):
    """Generate SQL query"""
    query: Annotated[str, ..., '문법적으로 올바른 SQL 쿼리']

# Node: SQL generation
def write_sql(state: State):
    """Turn the user's question into a SQL query.

    Fills the prompt template with the database dialect, a ``top_k`` of 10,
    the table information and the question, then asks the LLM for a
    structured ``QueryOutput``.

    Args:
        state: Graph state; only ``state["question"]`` is read.

    Returns:
        A partial state update ``{'sql': <generated query>}``.
    """
    # NOTE(review): query_prompt_template is not defined in this part of the
    # file; presumably it is created elsewhere — confirm.
    template_inputs = {
        "dialect": db.dialect,
        "top_k": 10,
        "table_info": db.get_table_info(),
        "input": state["question"],
    }
    filled_prompt = query_prompt_template.invoke(template_inputs)

    query_output = llm.with_structured_output(QueryOutput).invoke(filled_prompt)
    return {'sql': query_output["query"]}

# 
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

# Node: SQL execution
def execute_sql(state: State):
    """Run the SQL held in the state against the database.

    Args:
        state: Graph state; only ``state['sql']`` is read.

    Returns:
        A partial state update ``{'result': <tool output>}``.
    """
    sql_runner = QuerySQLDataBaseTool(db=db)
    return {'result': sql_runner.invoke(state['sql'])}

# Smoke test: push a fixed COUNT query through the execution node.
execute_sql(dict(sql='SELECT COUNT(*) FROM employee;'))

def generate_answer(state: State):
    """Answer the question from the information gathered so far.

    Builds a prompt containing the question, the SQL query and the SQL
    result, sends it to the LLM and returns the reply text.

    Args:
        state: Graph state; ``question``, ``sql`` and ``result`` are read.

    Returns:
        A partial state update ``{'answer': <LLM reply content>}``.
    """
    question = state['question']
    sql_query = state['sql']
    sql_result = state['result']

    # The prompt is assembled line by line so that its exact whitespace
    # (leading indentation, trailing spaces) is visible in the source.
    prompt_lines = (
        "",
        "    주어진 사용자 질문에 대해, DB에서 실행할 SQL 쿼리와 결과를 바탕으로 답변해.",
        "",
        f"    Question: {question}",
        "    ---",
        f"    SQL Query: {sql_query}",
        f"    SQL Result: {sql_result}   ",
        "    ",
    )
    reply = llm.invoke("\n".join(prompt_lines))
    return {'answer': reply.content}

# In [None]:
from langgraph.graph import START, StateGraph

# Linear pipeline: the nodes simply run one after another in a single line.
builder = StateGraph(State)
builder.add_sequence([write_sql, execute_sql, generate_answer])
builder.add_edge(START, 'write_sql')  # enter the graph at the first node

graph = builder.compile()

from IPython.display import Image, display

# Render the compiled graph as a Mermaid PNG and show it inline.
graph_png = graph.get_graph().draw_mermaid_png()
display(Image(graph_png))