# Build a SQL agent
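At a high level, the agent will:

1. Fetch the available tables and schemas from the database
2. Decide which tables are relevant to the question
3. Fetch the schemas for the relevant tables
4. Generate a query based on the question and the schema information
5. Use an LLM to double-check the query for common mistakes
6. Execute the query and return the results
7. Correct mistakes surfaced by the database engine until the query succeeds
8. Formulate a response based on the results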
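Steps 1 and 3 boil down to reading the database catalog. The following is a minimal sketch of what that means for SQLite, using only the standard-library `sqlite3` module and SQLite's `sqlite_master` table; `list_tables` and `get_schemas` are hypothetical helpers for illustration, not the LangChain tools used later, and the in-memory database is a tiny stand-in for `Chinook.db`.

```python
import sqlite3


def list_tables(conn: sqlite3.Connection) -> list[str]:
    """Step 1: return user table names, sorted, skipping SQLite's internal tables."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
    ).fetchall()
    return [name for (name,) in rows]


def get_schemas(conn: sqlite3.Connection, tables: list[str]) -> dict[str, str]:
    """Step 3: return the stored CREATE TABLE statement for each requested table that exists."""
    schemas = {}
    for table in tables:
        row = conn.execute(
            "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?",
            (table,),
        ).fetchone()
        if row is not None:
            schemas[table] = row[0]
    return schemas


# In-memory stand-in with two of Chinook's tables (simplified columns)
conn = sqlite3.connect(":memory:")
conn.executescript(
    "CREATE TABLE Genre (GenreId INTEGER PRIMARY KEY, Name TEXT);"
    "CREATE TABLE Track (TrackId INTEGER PRIMARY KEY, Name TEXT, GenreId INTEGER, Milliseconds INTEGER);"
)
print(list_tables(conn))                      # ['Genre', 'Track']
print(get_schemas(conn, ["Genre"])["Genre"])
```

Step 2 (choosing the relevant tables) is the part that needs the LLM; everything here is deterministic catalog lookup.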

We will cover the following concepts:

- Tools for reading from SQL databases
- LangChain agents
- Human-in-the-loop review

In [1]:
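# Install dependencies (langchain-openai provides ChatOpenAI; requests downloads the database)
%pip install langchain langgraph langchain-community langchain-openai requests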

Note: you may need to restart the kernel to use updated packages.


# Configure the model

In [2]:
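# Environment variable configuration
# Read the keys from the machine's environment variables
import os
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")
print("OPENAI_API_KEY:", OPENAI_API_KEY is not None)
print("LANGCHAIN_API_KEY:", LANGCHAIN_API_KEY is not None)

# Enable LangChain tracing only when a LangChain API key is available
# (assigning None to os.environ would raise a TypeError)
if LANGCHAIN_API_KEY:
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_API_KEY"] = LANGCHAIN_API_KEY
    # Set a project name to tell projects apart in the LangChain tracing UI
    os.environ["LANGCHAIN_PROJECT"] = "llm-langchain-learn"

# DashScope compatibility (optional): copy the same key to DASHSCOPE_API_KEY
# (delete these lines if you configured DASHSCOPE_API_KEY separately)
# os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
# os.environ["DASHSCOPE_API_KEY"] = OPENAI_API_KEY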

OPENAI_API_KEY: True
LANGCHAIN_API_KEY: True
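The cell above only prints whether each key is present. If a key is missing, `os.getenv` returns `None` and the problem tends to surface later, for example on the first model call. A minimal sketch of failing fast instead; `require_env` is a hypothetical helper, not part of LangChain.

```python
import os


def require_env(name: str) -> str:
    """Return an environment variable's value, or raise a clear error if it is unset or empty."""
    value = os.getenv(name)
    if not value:
        raise RuntimeError(f"Environment variable {name} is not set")
    return value


# Usage in this notebook would be, e.g.:
# OPENAI_API_KEY = require_env("OPENAI_API_KEY")
```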


In [3]:
# Configure the model: qwen-plus through DashScope's OpenAI-compatible endpoint
from langchain_openai import ChatOpenAI
model = ChatOpenAI(
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    api_key=OPENAI_API_KEY,
    model="qwen-plus"
)

# Configure the database
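In this tutorial you will use a SQLite database. SQLite is a lightweight database that is easy to set up and use. We will load the `chinook` database, a sample database that represents a digital media store.
For convenience, the database (`Chinook.db`) is hosted in a public GCS bucket.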

In [5]:
import pathlib
import requests

# Public copy of the Chinook sample database
url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"
local_path = pathlib.Path("Chinook.db")

if local_path.exists():
    print(f"{local_path} already exists, skipping download.")
else:
    # Time out instead of hanging indefinitely if the bucket is unreachable
    response = requests.get(url, timeout=60)
    if response.status_code == 200:
        local_path.write_bytes(response.content)
        print(f"File downloaded and saved as {local_path}")
    else:
        print(f"Failed to download the file. Status code: {response.status_code}")

File downloaded and saved as Chinook.db
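The download cell skips the download whenever `Chinook.db` already exists, so a truncated or otherwise bad file from an earlier run would be reused silently. Every SQLite 3 database file begins with the 16-byte header `SQLite format 3\0`, which gives a cheap sanity check. A minimal sketch; `looks_like_sqlite` is a hypothetical helper.

```python
import pathlib

# First 16 bytes of every SQLite 3 database file
SQLITE_MAGIC = b"SQLite format 3\x00"


def looks_like_sqlite(path: pathlib.Path) -> bool:
    """Return True if the file exists and starts with the SQLite 3 header string."""
    if not path.is_file():
        return False
    with path.open("rb") as f:
        return f.read(16) == SQLITE_MAGIC


print(looks_like_sqlite(pathlib.Path("Chinook.db")))
```

If this prints `False`, deleting `Chinook.db` and re-running the download cell is the simplest fix.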


In [6]:
# Connect to the database and run a simple query
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///Chinook.db")

print(f"Dialect: {db.dialect}")
print(f"Available tables: {db.get_usable_table_names()}")
print(f'Sample output: {db.run("SELECT * FROM Artist LIMIT 5;")}')

Dialect: sqlite
Available tables: ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
Sample output: [(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains')]


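# Add tools for database interaction
Use the `SQLDatabase` wrapper from the `langchain_community` package to interact with the database. The wrapper provides a simple interface for executing SQL queries and fetching results. `SQLDatabaseToolkit` turns `db` into a set of tools the agent can call: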

In [7]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit
toolkit = SQLDatabaseToolkit(db=db, llm=model)

tools = toolkit.get_tools()

for tool in tools:
    print(f"Tool name: {tool.name}, description: {tool.description}")

Tool name: sql_db_query, description: Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.
Tool name: sql_db_schema, description: Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3
Tool name: sql_db_list_tables, description: Input is an empty string, output is a comma-separated list of tables in the database.
Tool name: sql_db_query_checker, description: Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!
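The `sql_db_query` description says that an incorrect query returns an error message rather than failing outright. That contract is what lets the agent read the error and retry. A minimal sketch of the same behavior with the standard-library `sqlite3` module; `run_query` is a hypothetical helper and not LangChain's implementation.

```python
import sqlite3


def run_query(conn: sqlite3.Connection, query: str) -> str:
    """Execute a query and return the rows as a string.

    On failure, return the error text instead of raising, so a caller
    (such as an agent) can read it and try a corrected query.
    """
    try:
        rows = conn.execute(query).fetchall()
    except sqlite3.Error as exc:
        return f"Error: {exc}"
    return str(rows)


conn = sqlite3.connect(":memory:")
conn.executescript(
    "CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT);"
    "INSERT INTO Artist (Name) VALUES ('AC/DC'), ('Accept');"
)
print(run_query(conn, "SELECT * FROM Artist LIMIT 5;"))  # [(1, 'AC/DC'), (2, 'Accept')]
print(run_query(conn, "SELECT Title FROM Artist;"))      # Error: no such column: Title
```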


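# Use create_agent
Use `create_agent` to build a ReAct agent with minimal code. The agent interprets the request and generates SQL commands, which the tools execute. If a command has an error, the error message is returned to the model, which can then examine the original request together with the new error message and generate a new command. This can continue until the LLM produces a command that succeeds or a limit on the number of steps is reached. This pattern of feeding information back to the model (in this case, error messages) is very powerful.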
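The feedback loop described above can be sketched without any framework. This is an illustration of the pattern only, not how `create_agent` works internally: `generate` and `execute` are hypothetical callables standing in for the LLM and the `sql_db_query` tool, `max_attempts` plays the role of the step limit, and treating results that start with `Error:` as failures is an assumption of this sketch.

```python
from typing import Callable, Optional


def solve_with_feedback(
    question: str,
    generate: Callable[[str, Optional[str]], str],
    execute: Callable[[str], str],
    max_attempts: int = 3,
) -> tuple[str, int]:
    """Generate a query, run it, and feed any error back to the generator until success or the limit."""
    error: Optional[str] = None
    for attempt in range(1, max_attempts + 1):
        query = generate(question, error)  # the "model" sees the previous error, if any
        result = execute(query)
        if not result.startswith("Error:"):
            return result, attempt         # success: stop looping
        error = result                     # feed the error message into the next attempt
    raise RuntimeError(f"No successful query after {max_attempts} attempts; last error: {error}")


def fake_generate(question: str, error: Optional[str]) -> str:
    # Hypothetical stand-in for the LLM: guesses a wrong column, then fixes it after seeing the error
    return "SELECT Title FROM Artist" if error is None else "SELECT Name FROM Artist"


def fake_execute(query: str) -> str:
    # Hypothetical stand-in for sql_db_query
    return "Error: no such column: Title" if "Title" in query else "[('AC/DC',)]"


print(solve_with_feedback("List artists", fake_generate, fake_execute))
# ("[('AC/DC',)]", 2)
```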

In [9]:
# Initialize the agent with a descriptive system prompt to customize its behavior:
system_prompt = """
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most {top_k} results.

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.

You MUST double check your query before executing it. If you get an error while
executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
database.

To start you should ALWAYS look at the tables in the database to see what you
can query. Do NOT skip this step.

Then you should query the schema of the most relevant tables.
""".format(
    dialect=db.dialect,
    top_k=5,
)

In [10]:
# Create an agent with the model, tools, and system prompt
from langchain.agents import create_agent

agent = create_agent(
    model=model,
    tools=tools,
    system_prompt=system_prompt,
)

# Run the agent

In [None]:
question = "Which genre on average has the longest tracks?"

for step in agent.stream(
    {"messages": [{"role": "user", "content": question}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()


Which genre on average has the longest tracks?

I need to find which genre has the longest average track duration. Let me start by exploring the database schema to understand the available tables and their structure.
Tool Calls:
  sql_db_list_tables (call_00648c5a173f42ce82830e)
 Call ID: call_00648c5a173f42ce82830e
  Args:
Name: sql_db_list_tables

Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track

Now that I know the tables, I need to examine the relevant ones. The Genre and Track tables are likely the most important for this question. Let me check their schemas.
Tool Calls:
  sql_db_schema (call_0eb921522d0f4c65983f63)
 Call ID: call_0eb921522d0f4c65983f63
  Args:
    table_names: Genre, Track
Name: sql_db_schema


CREATE TABLE "Genre" (
	"GenreId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("GenreId")
)

/*
3 rows from Genre table:
GenreId	Name
1	Rock
2	Jazz
3	Metal
*/


CREATE TABLE "Track" (
	"TrackId" INTEGER 

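# Implement human-in-the-loop review
Before executing the agent's SQL queries, it is wise to check them for any unintended actions or inefficiencies.
LangChain agents support built-in human-in-the-loop middleware that adds oversight to the agent's tool calls. Let's configure the agent to pause for human review whenever it calls the `sql_db_query` tool: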
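Conceptually, the review step is a gate in front of selected tools. The plain-Python sketch below only illustrates that idea and is not how `HumanInTheLoopMiddleware` is implemented (the real middleware pauses the graph with an interrupt, stores state in the checkpointer, and resumes from a `Command`). `gated_call`, the reviewer callables, and `fake_sql_db_query` are hypothetical; the `{"type": "approve"}` shape mirrors the resume decision used later in this notebook, while `{"type": "reject"}` is an assumption for illustration.

```python
from typing import Callable


def gated_call(
    tool_name: str,
    args: dict,
    tool: Callable[..., str],
    review: Callable[[str, dict], dict],
    interrupt_on: set[str],
) -> str:
    """Run a tool, but ask a reviewer first when the tool name is in interrupt_on."""
    if tool_name in interrupt_on:
        decision = review(tool_name, args)
        if decision["type"] != "approve":
            return f"Tool call {tool_name} was rejected by a human reviewer."
    return tool(**args)


def fake_sql_db_query(query: str) -> str:
    # Hypothetical stand-in for the real sql_db_query tool
    return f"ran: {query}"


def approve_all(name: str, args: dict) -> dict:
    return {"type": "approve"}


def reject_all(name: str, args: dict) -> dict:
    return {"type": "reject"}


print(gated_call("sql_db_query", {"query": "SELECT 1"}, fake_sql_db_query, approve_all, {"sql_db_query"}))
# ran: SELECT 1
print(gated_call("sql_db_query", {"query": "SELECT 1"}, fake_sql_db_query, reject_all, {"sql_db_query"}))
```

Tools not listed in `interrupt_on` (such as `sql_db_list_tables` here) run without review, which matches the configuration below, where only `sql_db_query` is gated.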

In [12]:
from langchain.agents import create_agent
from langchain.agents.middleware import HumanInTheLoopMiddleware
from langgraph.checkpoint.memory import InMemorySaver


agent = create_agent(
    model,
    tools,
    system_prompt=system_prompt,
    middleware=[
        # Pause before every sql_db_query call and wait for a human decision
        HumanInTheLoopMiddleware(
            interrupt_on={"sql_db_query": True},
            description_prefix="Tool execution pending approval",
        ),
    ],
    # A checkpointer stores the paused run so it can be resumed later
    checkpointer=InMemorySaver(),
)

In [None]:
# Running the agent now pauses for review before it executes the sql_db_query tool
question = "Which genre on average has the longest tracks?"
# The thread_id identifies this conversation so the checkpointer can resume it
config = {"configurable": {"thread_id": "1"}}

for step in agent.stream(
    {"messages": [{"role": "user", "content": question}]},
    config,
    stream_mode="values",
):
    if "__interrupt__" in step:
        # Print the pending tool call(s) awaiting approval
        print("INTERRUPTED:")
        interrupt = step["__interrupt__"][0]
        for request in interrupt.value["action_requests"]:
            print(request["description"])
    elif "messages" in step:
        step["messages"][-1].pretty_print()


Which genre on average has the longest tracks?

I'll help you find which genre has the longest average track duration. Let me start by exploring the database schema to understand the available tables.
Tool Calls:
  sql_db_list_tables (call_768f126abfb34a11a5990f)
 Call ID: call_768f126abfb34a11a5990f
  Args:
Name: sql_db_list_tables

Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track

Now I'll examine the schema of the Genre and Track tables since they're most relevant to finding average track durations by genre.
Tool Calls:
  sql_db_schema (call_8b521c559bde4d359aece7)
 Call ID: call_8b521c559bde4d359aece7
  Args:
    table_names: Genre, Track
Name: sql_db_schema


CREATE TABLE "Genre" (
	"GenreId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("GenreId")
)

/*
3 rows from Genre table:
GenreId	Name
1	Rock
2	Jazz
3	Metal
*/


CREATE TABLE "Track" (
	"TrackId" INTEGER NOT NULL, 
	"Name" NVARCHAR(200) NOT NULL, 
	"AlbumId"

In [None]:
# Resume execution, approving the pending query, by streaming a Command
from langgraph.types import Command

for step in agent.stream(
    Command(resume={"decisions": [{"type": "approve"}]}),
    config,
    stream_mode="values",
):
    if "messages" in step:
        step["messages"][-1].pretty_print()
    elif "__interrupt__" in step:
        print("INTERRUPTED:")
        interrupt = step["__interrupt__"][0]
        for request in interrupt.value["action_requests"]:
            print(request["description"])
    else:
        pass


The query looks correct. Now I'll execute it to find which genre has the longest average track duration.
Tool Calls:
  sql_db_query (call_5efe6e9504f845629a2f3d)
 Call ID: call_5efe6e9504f845629a2f3d
  Args:
    query: SELECT g.Name, AVG(t.Milliseconds) as avg_duration FROM Track t JOIN Genre g ON t.GenreId = g.GenreId GROUP BY g.GenreId, g.Name ORDER BY avg_duration DESC LIMIT 1

Name: sql_db_query

[('Sci Fi & Fantasy', 2911783.0384615385)]

Based on the query results, the genre with the longest average track duration is "Sci Fi & Fantasy" with an average duration of approximately 2,911,783 milliseconds (ab