### 使用官方 SQL 链实现数据查询

In [2]:
# Setup cell: library imports for the notebook, then load environment variables.
# NOTE(review): initialize_agent/AgentType/Tool, BaseTool, BaseModel/Field and the typing
# names are not used by the cells visible here — presumably needed by agent/tool cells
# elsewhere in the notebook; confirm before removing any of them.
from langchain.chat_models import  ChatOpenAI
from langchain.agents import initialize_agent,AgentType, Tool
from langchain.utilities import SQLDatabase
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
from typing import Type, List, Optional, Any
import os
from dotenv import load_dotenv
# Load variables from a .env file into the process environment so the os.getenv(...)
# calls in the next cell can find them. Its return value (True in the saved output)
# is what this cell displays.
load_dotenv()

True

In [3]:
# Read runtime configuration from environment variables (for example from the .env
# file loaded by load_dotenv()). Any variable that is not set comes back as None.

# LLM credentials and the base URL of the OpenAI-compatible endpoint
api_key = os.environ.get('api_key')
base_url = os.environ.get('base_url')

# API key for the weather-query service
juheAppKey = os.environ.get('juheAppKey')

# Database connection URI, passed to SQLDatabase.from_uri
mysql_info = os.environ.get('mysql_info')

In [4]:
# Chat-model client shared by every chain below. It talks to the endpoint at
# base_url, which serves the "qwen-plus" model here.
online_model = ChatOpenAI(
    model="qwen-plus",
    temperature=0,  # minimise sampling randomness so generated SQL is as repeatable as possible
    api_key=api_key,
    base_url=base_url,
)

  online_model = ChatOpenAI(


In [5]:
# Wrap the database behind the URI from the environment in LangChain's SQLDatabase.
# NOTE(review): later cells execute LLM-generated SQL through this connection
# (QuerySQLDataBaseTool); consider connecting with a read-only database user.
db = SQLDatabase.from_uri(mysql_info)


In [6]:
# List the tables SQLDatabase makes available (saved output: ['city_stats']).
db.get_usable_table_names()

['city_stats']

In [7]:
# Show the schema description SQLDatabase generates: the CREATE TABLE statement plus
# sample rows (3 rows in the saved output). print() is used so the embedded newlines
# render instead of appearing as "\n".
print(db.get_table_info())


CREATE TABLE city_stats (
	city_name VARCHAR(255), 
	population VARCHAR(255), 
	area VARCHAR(255)
)DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci

/*
3 rows from city_stats table:
city_name	population	area
北京	20000000	400000234
上海	17689300	29867200
河北	19999944	39999821
*/


In [8]:
# Show the context dict built from the database; the saved output has the keys
# 'table_info' (same text as the previous cell) and 'table_names'.
print(db.get_context())

{'table_info': '\nCREATE TABLE city_stats (\n\tcity_name VARCHAR(255), \n\tpopulation VARCHAR(255), \n\tarea VARCHAR(255)\n)DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci\n\n/*\n3 rows from city_stats table:\ncity_name\tpopulation\tarea\n北京\t20000000\t400000234\n上海\t17689300\t29867200\n河北\t19999944\t39999821\n*/', 'table_names': 'city_stats'}


In [9]:
# Build LangChain's built-in text-to-SQL chain. It only *writes* a SQL string for a
# question and does not run it (see the string outputs of the test cells below).
# k=2 makes the prompt ask for at most 2 rows via LIMIT unless the question specifies a count.
from langchain.chains.sql_database.query import create_sql_query_chain

sql_chain = create_sql_query_chain(online_model, db, k=2)

In [10]:
# Print the first prompt template used by the SQL chain. The k=2 setting shows up
# as "query for at most 2 results using the LIMIT clause".
sql_chain.get_prompts()[0].pretty_print()

You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 2 results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the S

In [11]:
# Smoke test: invoking the chain returns only the generated SQL text
# (saved output: 'SHOW TABLES;'); nothing is executed against the database here.
response = sql_chain.invoke({"question":'数据库中有哪些表？'})
response

'SHOW TABLES;'

In [12]:
# The chain also writes data-modifying statements (an INSERT in the saved output).
# This cell only generates the SQL string; nothing here runs it against the database.
# NOTE(review): if output like this is piped into a SQL executor, the LLM can modify data.
# Validate the statement or use a read-only DB user before doing that.
response = sql_chain.invoke({"question":'请你向city_stats表中插入一条数据，其中city为成都，population为2000，area为1000'})

response

"INSERT INTO `city_stats` (`city_name`, `population`, `area`) VALUES ('成都', '2000', '1000')"

In [13]:
# Run a raw SQL statement through SQLDatabase.run; the result comes back as a string.
# NOTE(review): all three columns are varchar(255), so comparisons/ORDER BY on population
# or area are string-based in MySQL unless the query CASTs them.
db.run('SHOW COLUMNS FROM `city_stats`')

"[('city_name', 'varchar(255)', 'YES', '', None, ''), ('population', 'varchar(255)', 'YES', '', None, ''), ('area', 'varchar(255)', 'YES', '', None, '')]"

In [None]:
import re

# Matches a whole-string Markdown code fence such as ```sql ... ``` (language tag optional).
_SQL_FENCE_RE = re.compile(
    r"^```[ \t]*(?:(?:my)?sql\b)?\s*(.*?)\s*```$",
    re.DOTALL | re.IGNORECASE,
)


def format_sql(sql: str) -> str:
    """Turn raw text from the SQL-writing chain into a bare, executable SQL statement.

    The model does not always answer with SQL alone. It may echo the prompt's
    "SQLQuery:" label, append a "SQLResult:" section, or wrap the statement in a
    Markdown code fence. Any of these would make the database reject the query.

    Args:
        sql: Raw model output, e.g. "SQLQuery: SELECT ..." or a fenced ```sql block.

    Returns:
        The SQL statement with those labels, fences and surrounding whitespace removed.
    """
    text = sql.strip()
    # Keep everything after the first "SQLQuery:" label, if the model echoed it.
    if 'SQLQuery:' in text:
        text = text.split('SQLQuery:', 1)[1]
    # Drop a hallucinated "SQLResult:" section; only the statement should be executed.
    text = text.split('SQLResult:', 1)[0].strip()
    # Unwrap ```sql ... ``` / ``` ... ``` fences.
    fenced = _SQL_FENCE_RE.match(text)
    if fenced:
        text = fenced.group(1)
    return text.strip()


# QuerySQLDataBaseTool: LangChain's built-in tool that executes a SQL string against `db`
# and returns the result as a string.
# NOTE(review): newer langchain_community versions presumably rename this class to
# QuerySQLDatabaseTool (the old name emits a warning in the saved output) — confirm and migrate.
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_core.runnables import RunnableLambda

execute_query = QuerySQLDataBaseTool(db=db)
# Same SQL-writing chain as before, but without k, so create_sql_query_chain's default row limit applies.
write_query = create_sql_query_chain(online_model, db)

# question -> generated SQL -> cleaned SQL -> executed result.
# format_sql is passed directly; wrapping it in a lambda added nothing.
# NOTE(review): whatever SQL the LLM writes (including INSERT/UPDATE/DELETE) is executed
# as-is. Restrict the DB user's privileges or validate the statement for untrusted questions.
chain = write_query | RunnableLambda(format_sql) | execute_query

result = chain.invoke({"question":'查询city_stats的表中全部信息'})


  execute_query = QuerySQLDataBaseTool(db=db)


In [15]:
# Display the executed query result (a string produced by the SQL tool).
result

"[('北京', '20000000', '400000234'), ('上海', '17689300', '29867200'), ('河北', '19999944', '39999821')]"

- 第一个 `RunnablePassthrough.assign(query=...)` 原样保留输入字典中的 `question`，并新增键 `query`，其值为 `write_query` 根据该问题生成的 SQL 语句。因此这一步的输出是字典 `{"question": question, "query": <生成的SQL>}`。
- 第二个 `.assign(result=itemgetter("query") | execute_query)` 再新增第三个键 `result`：取出上一步的 `query`，交给 `execute_query` 执行，得到查询结果。
- 这三个字段（`question`、`query`、`result`）被填入 `answer_prompt` 格式化成提示，再传给 LLM。
- `StrOutputParser()` 从 LLM 返回的消息中提取字符串内容，作为最终答案。

In [16]:
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

# Prompt that turns (question, SQL, SQL result) into a natural-language answer.
answer_prompt = PromptTemplate.from_template(
"""Given the following user question, corresponding SQL query, and SQL result, answer the user question.
Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

# 1. add "query": SQL written by write_query, cleaned by format_sql. This is the same
#    post-processing the previous chain applies, so the raw model text is never executed.
# 2. add "result": that cleaned SQL executed by execute_query.
# 3. fill answer_prompt, ask the LLM, and return its reply as plain text.
chain = (
    RunnablePassthrough.assign(query=write_query | RunnableLambda(format_sql))
    .assign(result=itemgetter("query") | execute_query)
    | answer_prompt
    | online_model
    | StrOutputParser()
)
chain.invoke({"question": "上海有多少人,城市有多大"})

'上海的人口为17,689,300人，城市的面积为29,867,200平方米。'