In [4]:
from langchain.utilities import SQLDatabase
from langchain.llms import OpenAI
from langchain_experimental.sql import SQLDatabaseChain

In [5]:
import os
from getpass import getpass

# Never hard-code the key: assigning a literal here would also overwrite any
# OPENAI_API_KEY already exported in the shell. Prompt only when it is missing
# (getpass keeps the key out of the saved cell source and outputs).
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("OPENAI_API_KEY: ")

In [12]:
# The connection string can be supplied via CHINOOK_DB_URI so real credentials
# need not live in the notebook; the fallback is the local Chinook sample DB.
DB_URI = os.environ.get(
    "CHINOOK_DB_URI",
    "mysql+pymysql://lang:lang@127.0.0.1:3306/Chinook?charset=utf8",
)
db = SQLDatabase.from_uri(DB_URI)
llm = OpenAI(temperature=0, verbose=True)
# NOTE(review): this chain executes LLM-generated SQL against `db`; a
# read-only database user is advisable.
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)

In [13]:
# Ask a plain-English question; the chain writes SQL, runs it, and phrases the answer.
question = "How many employees are there?"
db_chain.run(question)



[1m> Entering new SQLDatabaseChain chain...[0m
How many employees are there?
SQLQuery:[32;1m[1;3mSELECT COUNT(*) AS `NumberOfEmployees` FROM `Employee`;[0m
SQLResult: [33;1m[1;3m[(8,)][0m
Answer:[32;1m[1;3mThere are 8 employees.[0m
[1m> Finished chain.[0m


'There are 8 employees.'

In [15]:
from langchain.chat_models import ChatOpenAI
from langchain.chains import create_sql_query_chain

# This chain only generates the SQL text for a question; the next cell runs it.
chat_model = ChatOpenAI(temperature=0)
chain = create_sql_query_chain(chat_model, db)

query_input = {"question": "How many employees are there"}
response = chain.invoke(query_input)
print(response)

SELECT COUNT(*) FROM Employee


In [16]:
# Execute the generated SQL directly and display the raw result.
sql_result = db.run(response)
sql_result

'[(8,)]'

In [19]:
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType

# Reuse the `db` connection created earlier instead of reopening it with a
# duplicated URI, and share one temperature-0 LLM between the agent and its
# toolkit rather than constructing two separate instances.
agent_llm = OpenAI(temperature=0)

agent_executor = create_sql_agent(
    llm=agent_llm,
    toolkit=SQLDatabaseToolkit(db=db, llm=agent_llm),
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)
agent_executor.run(
    "List the total sales per country. Which country's customers spent the most?"
)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: [0m
Observation: [38;5;200m[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track[0m
Thought:[32;1m[1;3m I should query the schema of the Invoice and Customer tables.
Action: sql_db_schema
Action Input: Invoice, Customer[0m
Observation: [33;1m[1;3m
CREATE TABLE `Customer` (
	`CustomerId` INTEGER(11) NOT NULL, 
	`FirstName` VARCHAR(40) CHARACTER SET utf8 NOT NULL, 
	`LastName` VARCHAR(20) CHARACTER SET utf8 NOT NULL, 
	`Company` VARCHAR(80) CHARACTER SET utf8, 
	`Address` VARCHAR(70) CHARACTER SET utf8, 
	`City` VARCHAR(40) CHARACTER SET utf8, 
	`State` VARCHAR(40) CHARACTER SET utf8, 
	`Country` VARCHAR(40) CHARACTER SET utf8, 
	`PostalCode` VARCHAR(10) CHARACTER SET utf8, 
	`Phone` VARCHAR(24) CHARACTER SET utf8, 
	`Fax` VARCHAR(24) CHARACTER SET utf8, 
	`Email` VARCHAR(60) CHARACTER SET utf8 NOT NULL, 
	`SupportRepId` I

'USA customers spent the most, with a total of $523.06.'

In [20]:
# Same employee-count question, asked in Chinese ("How many employees are there?").
question_zh = "共有多少员工?"
db_chain.run(question_zh)



[1m> Entering new SQLDatabaseChain chain...[0m
共有多少员工?
SQLQuery:[32;1m[1;3mSELECT COUNT(*) FROM Employee;[0m
SQLResult: [33;1m[1;3m[(8,)][0m
Answer:[32;1m[1;3m共有8名员工。[0m
[1m> Finished chain.[0m


'共有8名员工。'

In [21]:
# Same sales-per-country question as the English agent run, asked in Chinese.
sales_question_zh = "列出每个国家/地区的总销售额。哪个国家的客户花费最多？"
agent_executor.run(sales_question_zh)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: [0m
Observation: [38;5;200m[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track[0m
Thought:[32;1m[1;3m I should query the schema of the Customer and Invoice tables.
Action: sql_db_schema
Action Input: Customer, Invoice[0m
Observation: [33;1m[1;3m
CREATE TABLE `Customer` (
	`CustomerId` INTEGER(11) NOT NULL, 
	`FirstName` VARCHAR(40) CHARACTER SET utf8 NOT NULL, 
	`LastName` VARCHAR(20) CHARACTER SET utf8 NOT NULL, 
	`Company` VARCHAR(80) CHARACTER SET utf8, 
	`Address` VARCHAR(70) CHARACTER SET utf8, 
	`City` VARCHAR(40) CHARACTER SET utf8, 
	`State` VARCHAR(40) CHARACTER SET utf8, 
	`Country` VARCHAR(40) CHARACTER SET utf8, 
	`PostalCode` VARCHAR(10) CHARACTER SET utf8, 
	`Phone` VARCHAR(24) CHARACTER SET utf8, 
	`Fax` VARCHAR(24) CHARACTER SET utf8, 
	`Email` VARCHAR(60) CHARACTER SET utf8 NOT NULL, 
	`SupportRepId` I

'美国的客户花费最多，总销售额为523.06美元。'