# 03-sql.ipynb

1. 데이터베이스에서 사용 가능한 테이블과 스키마 가져오기
2. 질문과 관련된 테이블 결정
3. 해당 테이블의 스키마 가져오기
4. 질문과 스키마의 정보를 기반으로 쿼리를 생성
5. LLM을 사용하여 흔히 발생하는 오류가 있는지 SQL 확인
6. DB에서 SQL을 실행하고 결과를 확인
7. DB에서 에러 발생시, 수정 후 다시 확인
8. DB 결과를 바탕으로 LLM이 답변 생성

In [1]:
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment so that
# later cells can read them through os.environ (e.g. the database URI).
# The cell output `True` is the call's return value.
load_dotenv()

True

In [2]:
import os

from langchain_community.utilities import SQLDatabase

# Connection string for the target database, read from the environment
# (expected to be provided by the .env file loaded in the previous cell).
DB_URI = os.environ.get('DB_URI')

# Fail fast with an actionable message instead of handing None to from_uri(),
# which would surface as an error that does not name the missing variable.
if not DB_URI:
    raise RuntimeError(
        "Environment variable 'DB_URI' is not set; "
        "define it in .env or export it before running this cell."
    )

# Database handle shared by the remaining cells (toolkit, agent prompt).
db = SQLDatabase.from_uri(DB_URI)

  from pydantic.v1.fields import FieldInfo as FieldInfoV1


In [3]:
# Quick connectivity check: list the tables this connection can use ...
print(db.get_usable_table_names())
# ... and peek at the first five rows of the `sales` table.
print(db.run('SELECT * FROM sales LIMIT 5;'))

['courses', 'customers', 'dt_demo', 'members', 'sales', 'students', 'students_courses', 'userinfo']
[(1, datetime.date(2024, 1, 17), 'C021', 'P2084', '과자', '식품', 7, 9480, 66360, '정동훈', '대구'), (2, datetime.date(2024, 4, 18), 'C042', 'P8517', '음료수', '식품', 5, 2584, 12920, '이영희', '부산'), (3, datetime.date(2024, 10, 14), 'C035', 'P8019', '청소기', '생활용품', 10, 254700, 2547000, '박민수', '부산'), (4, datetime.date(2024, 3, 11), 'C033', 'P1771', '쌀', '식품', 4, 35008, 140032, '이영희', '인천'), (5, datetime.date(2024, 11, 1), 'C005', 'P8668', '음료수', '식품', 17, 2529, 42993, '정동훈', '서울')]


In [None]:
# Initialize the LLM
from langchain_openai import ChatOpenAI

# The model is chosen with `model=`. Passing `name=` only labels the runnable
# and leaves the library's default model in use, so it must not be used here.
model = ChatOpenAI(model='gpt-4.1-mini')

In [14]:
# Build the tools for the agent
from langchain_community.agent_toolkits import SQLDatabaseToolkit
# The toolkit is constructed from the database handle and the chat model
# created in the earlier cells.
toolkit = SQLDatabaseToolkit(db=db, llm=model)
# Last expression of the cell, so the notebook displays the returned tool list.
toolkit.get_tools()

[QuerySQLDatabaseTool(description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x0000027642B66A50>),
 InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x0000027642B66A50>),
 ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x0000027642B66A50>),
 QuerySQLCheckerTool(description='Use this tool to 

In [None]:
# Build the agent
from langchain.agents import create_agent
from langchain.agents.middleware import HumanInTheLoopMiddleware
from langgraph.checkpoint.memory import InMemorySaver

# Values interpolated into the system prompt below.
top_k = 5
dialect = db.dialect

system_prompt = f"""
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most {top_k} results.

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.

You MUST double check your query before executing it. If you get an error while
executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
database.

To start you should ALWAYS look at the tables in the database to see what you
can query. Do NOT skip this step.

Then you should query the schema of the most relevant tables.
"""

# Interrupt the run before the `sql_db_query` tool executes so that a human
# can decide on the pending call.
approval_gate = HumanInTheLoopMiddleware(
    interrupt_on={'sql_db_query': True},
    description_prefix='Tool 실행 전에 승인을 기다림',
)

# The checkpointer stores the paused state: a resumed run has to know where
# to pick up again.
agent = create_agent(
    model,
    toolkit.get_tools(),
    system_prompt=system_prompt,
    middleware=[approval_gate],
    checkpointer=InMemorySaver(),
)

In [None]:
from langgraph.types import Command


def _print_stream(events):
    """Print every streamed chunk and collect the pending approval requests.

    Args:
        events: Iterable of chunks produced by
            ``agent.stream(..., stream_mode='values')``.

    Returns:
        list: The action requests found in interrupt chunks, in arrival
        order. Empty when the run was never paused.
    """
    pending = []
    for event in events:
        # A chunk is handled either as an interrupt or as a state snapshot,
        # mirroring the if/elif dispatch of the original loops.
        if "__interrupt__" in event:
            print("INTERRUPTED:")
            interrupt = event["__interrupt__"][0]
            for request in interrupt.value["action_requests"]:
                print(request["description"])
                pending.append(request)
        elif "messages" in event:
            event["messages"][-1].pretty_print()
    return pending


question = '2월에 가장 많이 팔린 물건 3개와, 해당 물건들의 토요일 일요일 평균 매출액'

# The thread id ties both stream calls to the same checkpointed conversation.
config = {'configurable': {'thread_id': '1234'}}

# First pass: run until the agent finishes or pauses for approval.
pending_requests = _print_stream(
    agent.stream(
        {'messages': [{'role': 'user', 'content': question}]},
        stream_mode='values',
        config=config,
    )
)

print('---------------------------------------------------')

# Second pass: resume only if the run actually paused, and send one approval
# per pending action so the decision count matches the request count.
if pending_requests:
    decisions = [{"type": "approve"} for _ in pending_requests]
    _print_stream(
        agent.stream(
            Command(resume={"decisions": decisions}),
            config,
            stream_mode="values",
        )
    )
        


2월에 가장 많이 팔린 물건 3개와, 해당 물건들의 주말 평균 매출액
Tool Calls:
  sql_db_list_tables (call_rrplXwLZmcBA9QSSa8Ep5CHt)
 Call ID: call_rrplXwLZmcBA9QSSa8Ep5CHt
  Args:
Name: sql_db_list_tables

courses, customers, dt_demo, members, sales, students, students_courses, userinfo
Tool Calls:
  sql_db_schema (call_sFIHCHQD35tQXqt05UxSxfkv)
 Call ID: call_sFIHCHQD35tQXqt05UxSxfkv
  Args:
    table_names: sales
Name: sql_db_schema


CREATE TABLE sales (
	id INTEGER NOT NULL, 
	order_date DATE NOT NULL, 
	customer_id VARCHAR(10) NOT NULL, 
	product_id VARCHAR(10) NOT NULL, 
	product_name VARCHAR(50) NOT NULL, 
	category VARCHAR(20) NOT NULL, 
	quantity INTEGER NOT NULL, 
	unit_price INTEGER NOT NULL, 
	total_amount INTEGER NOT NULL, 
	sales_rep VARCHAR(30) NOT NULL, 
	region VARCHAR(20) NOT NULL, 
	CONSTRAINT sales_pkey PRIMARY KEY (id)
)

/*
3 rows from sales table:
id	order_date	customer_id	product_id	product_name	category	quantity	unit_price	total_amount	sales_rep	region
1	2024-01-17	C021	P2084	과자	식품	7	948