# LLM <-> DB
- LangGraph agent <-> PostgreSQL DB 연결
- 사용자 요청 -> LLM -> SQL 쿼리 변환 -> DB -> LLM 답변 생성 -> 사용자

In [None]:
# Install the psycopg2 package quietly (-q); presumably it is the driver used
# for the postgresql:// connection opened later in this notebook.
%pip install -q psycopg2

In [6]:
from dotenv import load_dotenv

# Load variables from a .env file into the process environment; the
# POSTGRES_* settings read with os.getenv() later are presumably defined there.
load_dotenv()

True

In [22]:
from langchain_openai import ChatOpenAI

# Chat model shared by the rest of the notebook. temperature=0 is used so the
# generated SQL is as repeatable as the API allows.
llm = ChatOpenAI(model='gpt-4.1', temperature=0)

In [23]:
import os
from urllib.parse import quote

from langchain_community.utilities import SQLDatabase

# Connection settings come from the environment (populated by load_dotenv()).
POSTGRES_USER = os.getenv('POSTGRES_USER')
POSTGRES_PASSWORD = os.getenv('POSTGRES_PASSWORD')
POSTGRES_DB = os.getenv('POSTGRES_DB')

# Fail early with a clear message: an unset variable would otherwise be
# formatted into the URI as the literal text "None" and only surface later as
# a confusing authentication or "database does not exist" error.
_missing_settings = [
    name
    for name, value in (
        ('POSTGRES_USER', POSTGRES_USER),
        ('POSTGRES_PASSWORD', POSTGRES_PASSWORD),
        ('POSTGRES_DB', POSTGRES_DB),
    )
    if value is None
]
if _missing_settings:
    raise RuntimeError(
        f"Missing required environment variable(s): {', '.join(_missing_settings)}"
    )

# Percent-encode the credentials so characters such as '@', ':' or '/' in the
# user name or password cannot be mistaken for URL delimiters. safe='' also
# encodes '/', and quote() (unlike quote_plus()) writes a space as %20.
URI = (
    f"postgresql://{quote(POSTGRES_USER, safe='')}:{quote(POSTGRES_PASSWORD, safe='')}"
    f"@localhost:5432/{POSTGRES_DB}"
)

# LangChain wrapper around the database connection, used for schema
# inspection and for running queries.
db = SQLDatabase.from_uri(URI)

In [27]:
print(db.dialect)  # database dialect in use
print(db.get_table_info())  # CREATE TABLE statements plus a few sample rows per table
print(db.get_usable_table_names())  # names of the tables available for querying
db.run('SELECT * FROM artist LIMIT 10;')  # last expression, so the notebook displays the returned rows

postgresql

CREATE TABLE album (
	album_id INTEGER NOT NULL, 
	title VARCHAR(160) NOT NULL, 
	artist_id INTEGER NOT NULL, 
	CONSTRAINT album_pkey PRIMARY KEY (album_id), 
	CONSTRAINT album_artist_id_fkey FOREIGN KEY(artist_id) REFERENCES artist (artist_id)
)

/*
3 rows from album table:
album_id	title	artist_id
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE artist (
	artist_id INTEGER NOT NULL, 
	name VARCHAR(120), 
	CONSTRAINT artist_pkey PRIMARY KEY (artist_id)
)

/*
3 rows from artist table:
artist_id	name
1	AC/DC
2	Accept
3	Aerosmith
*/


CREATE TABLE customer (
	customer_id INTEGER NOT NULL, 
	first_name VARCHAR(40) NOT NULL, 
	last_name VARCHAR(20) NOT NULL, 
	company VARCHAR(80), 
	address VARCHAR(70), 
	city VARCHAR(40), 
	state VARCHAR(40), 
	country VARCHAR(40), 
	postal_code VARCHAR(10), 
	phone VARCHAR(24), 
	fax VARCHAR(24), 
	email VARCHAR(60) NOT NULL, 
	support_rep_id INTEGER, 
	CONSTRAINT customer_pkey PRIMARY KE

"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [25]:
from langchain_core.prompts import ChatPromptTemplate

# System prompt for text-to-SQL generation. Placeholders filled at invoke time:
# {dialect}, {top_k} and {table_info}.
system_message = """
Given an input question, create a syntactically correct {dialect} query to
run to help find the answer. Unless the user specifies in his question a
specific number of examples they wish to obtain, always limit your query to
at most {top_k} results. You can order the results by a relevant column to
return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the
few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema
description. Be careful to not query for columns that do not exist. Also,
pay attention to which column is in which table.

Only use the following tables:
{table_info}
"""

# User turn; {input} is replaced with the natural-language question.
user_prompt = "Question: {input}"

# Two-message chat prompt: the system instructions followed by the user question.
query_prompt_template = ChatPromptTemplate(
    [
        ('system', system_message),
        ('user', user_prompt),
    ]
)

# Print each message template to check the text and its placeholders.
for msg in query_prompt_template.messages:
    msg.pretty_print()



Given an input question, create a syntactically correct [33;1m[1;3m{dialect}[0m query to
run to help find the answer. Unless the user specifies in his question a
specific number of examples they wish to obtain, always limit your query to
at most [33;1m[1;3m{top_k}[0m results. You can order the results by a relevant column to
return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the
few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema
description. Be careful to not query for columns that do not exist. Also,
pay attention to which column is in which table.

Only use the following tables:
[33;1m[1;3m{table_info}[0m


Question: [33;1m[1;3m{input}[0m


In [None]:
from langgraph.graph import MessagesState

class State(MessagesState):
    """Graph state: the MessagesState fields plus the text-to-SQL fields below."""
    question: str  # natural-language question; read by write_sql
    sql: str  # SQL query generated for the question; written by write_sql
    result: str  # presumably the raw output of running `sql` (not populated in the code shown here)
    answer: str  # presumably the final LLM answer (not populated in the code shown here)


In [38]:
from typing_extensions import Annotated, TypedDict

# Structured-output schema handed to llm.with_structured_output() in write_sql.
# NOTE(review): the docstring and the Annotated description are presumably sent
# to the model as part of the schema, so their text is deliberately unchanged.
class QueryOutput(TypedDict):
    """Generate SQL query"""
    # The description string is Korean for "a syntactically correct SQL query".
    # NOTE(review): `...` presumably marks the field as required — confirm.
    query: Annotated[str, ..., '문법적으로 올바른 SQL 쿼리']


# SQL generation node
def write_sql(state: State):
    """Turn the question stored in ``state`` into a SQL query.

    The query prompt template is filled with the database dialect, a result
    cap of 10, the table info and ``state["question"]``; the LLM is then asked
    for structured output that follows ``QueryOutput``.

    Args:
        state: Mapping that provides the ``"question"`` key.

    Returns:
        A dict with the single key ``'sql'`` holding the generated query.
    """
    prompt_inputs = {
        "dialect": db.dialect,
        "top_k": 10,
        "table_info": db.get_table_info(),
        "input": state["question"],
    }
    filled_prompt = query_prompt_template.invoke(prompt_inputs)
    query_output = llm.with_structured_output(QueryOutput).invoke(filled_prompt)
    return {'sql': query_output["query"]}


# Smoke-test the node outside the graph: a plain dict is enough because
# write_sql only reads state["question"]. The question is Korean for
# "How many employees are there?".
sql = write_sql({"question": "직원은 몇명이야?"})['sql']
print(sql)
# NOTE(review): this runs model-generated SQL as-is; nothing here restricts it
# to read-only statements.
db.run(sql)

SELECT COUNT(*) AS employee_count FROM employee;


'[(8,)]'