# Database Operation

In [1]:
from langchain_community.utilities import SQLDatabase

# Define database connection parameters
db_user = "root"
db_password = "q1w2e3r4"
db_host = "localhost"
db_name = "llm_asst"

# 数据库连接
# db = SQLDatabase.from_uri("sqlite:///nba_roster.db", sample_rows_in_table_info=0)
db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}")

print(db.dialect)
print(db.get_usable_table_names())

# 获取数据库架构信息的函数
def get_schema(_):
    return db.get_table_info()

# 执行 SQL 查询的函数
def run_query(query):
    return db.run(query)


mysql
['memos', 'schedules']


In [2]:
# how to convert \n in string print out
schema_str = get_schema(["schedules"])
print(schema_str)


CREATE TABLE memos (
	id INTEGER NOT NULL AUTO_INCREMENT, 
	memo TEXT NOT NULL, 
	created_at TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP, 
	PRIMARY KEY (id)
)DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB

/*
3 rows from memos table:
id	memo	created_at
1	Remember to buy milk	2024-10-19 16:18:35
2	Prepare materials for tomorrow's meeting	2024-10-19 16:18:35
3	Call mom	2024-10-19 16:18:35
*/


CREATE TABLE schedules (
	id INTEGER NOT NULL AUTO_INCREMENT, 
	event VARCHAR(255) NOT NULL, 
	event_date DATE NOT NULL, 
	start_time TIME, 
	end_time TIME, 
	description TEXT, 
	PRIMARY KEY (id)
)DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB

/*
3 rows from schedules table:
id	event	event_date	start_time	end_time	description
1	Meeting	2024-01-10	10:00:00	11:00:00	Discuss project progress with the team
2	Birthday Party	2024-02-05	18:00:00	21:00:00	Celebrate a friend's birthday
3	Gym	2024-01-15	07:00:00	08:00:00	Morning workout
*/


# Main Job

In [7]:
# Update the template based on the SQL database type and schema (e.g., MySQL, Microsoft SQL Server, etc.)

from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama.llms import OllamaLLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

model = OllamaLLM(model="llama3.1")

template1 = """
    Based on the table schema below, write an MySQL query to answer the user's question:
    {schema}
    Question: {question}
    SQL Query:
"""

prompt1 = ChatPromptTemplate.from_messages(
    [
        ("system", "Given an input question, convert it to an SQL query. No preamble."),
        ("human", template1),
    ]
)


# Execute the SQL query and get the response

sql_response = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt1
    | model.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)


response = sql_response.invoke({"question":"What do I have planned for today?"})

In [8]:
response

'SELECT event FROM schedules WHERE DATE(event_date) = CURDATE()'

In [3]:

template2 = """
    Based on the table schema, question, SQL query, and SQL response below, write a natural language answer:
    {schema}

    Question: {question}
    SQL Query: {query}
    SQL Response: {response}
"""

prompt2 = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Given a question and SQL response, convert it to a natural language answer. No preamble.",
        ),
        ("human", template2),
    ]
)

# Build the complete chain and call
full_chain = (
    RunnablePassthrough.assign(query=sql_response)
    | RunnablePassthrough.assign(
        schema=get_schema,
        response=lambda x: db.run(x["query"]),
    )
    | prompt2
    | model
)

full_chain.invoke({"question": "When is the birthday party?"})

'The birthday party is scheduled to take place on February 5th, 2024.'

In [34]:
# gradio for web interface(multi-turn conversation)

import gradio as gr

def chat(message, history):
    response = full_chain.invoke({"question": message})
    return response

# chatbot = gr.Chatbot(height=300)
iface = gr.ChatInterface(fn=chat,
                         type="messages")

iface.launch(share=True)

  from .autonotebook import tqdm as notebook_tqdm


* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://46014bff61ff12f9d2.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




Traceback (most recent call last):
  File "/Users/yzey/miniconda3/envs/llm-asst/lib/python3.10/site-packages/sqlalchemy/engine/base.py", line 1967, in _exec_single_context
    self.dialect.do_execute(
  File "/Users/yzey/miniconda3/envs/llm-asst/lib/python3.10/site-packages/sqlalchemy/engine/default.py", line 941, in do_execute
    cursor.execute(statement, parameters)
  File "/Users/yzey/miniconda3/envs/llm-asst/lib/python3.10/site-packages/pymysql/cursors.py", line 153, in execute
    result = self._query(query)
  File "/Users/yzey/miniconda3/envs/llm-asst/lib/python3.10/site-packages/pymysql/cursors.py", line 322, in _query
    conn.query(q)
  File "/Users/yzey/miniconda3/envs/llm-asst/lib/python3.10/site-packages/pymysql/connections.py", line 563, in query
    self._affected_rows = self._read_query_result(unbuffered=unbuffered)
  File "/Users/yzey/miniconda3/envs/llm-asst/lib/python3.10/site-packages/pymysql/connections.py", line 825, in _read_query_result
    result.read()
  File 