In [82]:
import getpass
import os

# Never store a real API key in the notebook. Reuse a key that is already
# exported in the environment and only prompt (without echoing the input)
# when none is present, so an existing valid key is never overwritten.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API key: ")

In [83]:
# Natural-language question passed to the chains in the cells below.
user_question = (
    "How many tables are in this schema "
    "and name each table with the number of rows in each table"
)

In [84]:
from langchain_core.prompts import ChatPromptTemplate

# Prompt asking the model for a SQL query. It has two placeholders:
# {schema} (the table definitions) and {question} (the user's request).
# The text is assembled from adjacent literals, one per template line.
template = (
    "\n"
    "Based on the table schema below, write a SQL query that would answer the user's question:\n"
    "{schema}\n"
    "\n"
    "\n"
    "Question: {question}\n"
    "SQL Query:\n"
)

prompt = ChatPromptTemplate.from_template(template=template)

In [85]:
# Render the prompt with a placeholder schema to inspect the final text.
prompt.format(**{"schema": "my_schema", "question": user_question})

"Human: \nBased on the table schema below, write a SQL query that would answer the user's question:\nmy_schema\n\n\nQuestion: How many tables are in this schema and name each table with the number of rows in each table\nSQL Query:\n"

In [86]:
from langchain_community.utilities import SQLDatabase

import getpass
import os
from urllib.parse import quote_plus

# Keep database credentials out of the notebook. A complete SQLAlchemy URI
# may be supplied through CHINOOK_DB_URI; otherwise the password for the
# local root account is requested interactively (input is not echoed).
db_uri = os.environ.get("CHINOOK_DB_URI")
if not db_uri:
    # quote_plus escapes characters such as '@', ':' or '/' that would
    # otherwise corrupt the URI when they appear in the password.
    db_password = quote_plus(getpass.getpass("MySQL password for root@localhost: "))
    db_uri = f"mysql+mysqlconnector://root:{db_password}@localhost:3306/Chinook"
db = SQLDatabase.from_uri(db_uri)

In [87]:
def get_schema(_):
    """Return the table info reported by the module-level ``db`` object.

    The single positional argument is ignored; it exists so the function
    can be called with the chain's input value.
    """
    table_info = db.get_table_info()
    return table_info

In [88]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

# Chat model constructed with no explicit arguments.
llm = ChatOpenAI()

# Question -> SQL text:
#   1. attach the database schema to the input under the "schema" key,
#   2. fill the SQL-writing prompt,
#   3. call the model with a bound stop sequence,
#   4. reduce the model message to a plain string.
# `stop` is given as a list because the chat-model interface documents it
# as a list of stop strings rather than a single string.
sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | llm.bind(stop=["\nSQL_Result:"])
    | StrOutputParser()
)


In [89]:
# Generate the SQL text for the question; the string is shown as the cell output.
sql_chain.invoke(dict(question=user_question))

"SELECT 'Album' AS Table_Name, COUNT(*) AS Row_Count FROM Album\nUNION\nSELECT 'Artist' AS Table_Name, COUNT(*) AS Row_Count FROM Artist\nUNION\nSELECT 'Customer' AS Table_Name, COUNT(*) AS Row_Count FROM Customer\nUNION\nSELECT 'Employee' AS Table_Name, COUNT(*) AS Row_Count FROM Employee\nUNION\nSELECT 'Genre' AS Table_Name, COUNT(*) AS Row_Count FROM Genre\nUNION\nSELECT 'InvoiceLine' AS Table_Name, COUNT(*) AS Row_Count FROM InvoiceLine\nUNION\nSELECT 'Invoice' AS Table_Name, COUNT(*) AS Row_Count FROM Invoice\nUNION\nSELECT 'MediaType' AS Table_Name, COUNT(*) AS Row_Count FROM MediaType\nUNION\nSELECT 'PlaylistTrack' AS Table_Name, COUNT(*) AS Row_Count FROM PlaylistTrack\nUNION\nSELECT 'Playlist' AS Table_Name, COUNT(*) AS Row_Count FROM Playlist\nUNION\nSELECT 'Track' AS Table_Name, COUNT(*) AS Row_Count FROM Track;"

In [90]:
# Prompt asking the model to phrase the SQL result as a natural language
# answer. Placeholders: {schema}, {question}, {query} and {response}.
response_template = (
    "\n"
    "Based on the table schema below, question, sql query, and sql response, write a natural language response:\n"
    "{schema}\n"
    "\n"
    "Question: {question}\n"
    "SQL Query: {query}\n"
    "SQL Response: {response}\n"
)

new_prompt = ChatPromptTemplate.from_template(template=response_template)



In [91]:
import re

# Matches a query that is entirely wrapped in a Markdown code fence,
# optionally tagged with a language (for example ```sql ... ```).
_SQL_FENCE = re.compile(r"^```(?:\w+)?\s*\n(.*?)\n?```\s*$", re.DOTALL)


def run_query(query):
    """Execute a SQL string against the module-level ``db`` and return its result.

    Chat models frequently wrap generated SQL in a Markdown code fence.
    Sent verbatim, the backticks would make the database reject the
    statement, so a surrounding fence (and outer whitespace) is removed
    first. Queries without a fence are executed as written.

    Args:
        query: SQL text, optionally wrapped in a Markdown code fence.

    Returns:
        Whatever ``db.run`` returns for the cleaned query.
    """
    cleaned = query.strip()
    fenced = _SQL_FENCE.match(cleaned)
    if fenced:
        cleaned = fenced.group(1).strip()
    return db.run(cleaned)

In [92]:
# Question -> natural language answer:
#   * first assign: run sql_chain and store its output under "query";
#   * second assign: add "schema" and "response", where "response" is the
#     result of running the generated query through run_query;
#   * then fill the answer prompt, call the model and keep the text.
full_chain = (
    RunnablePassthrough.assign(query=sql_chain)
    .assign(
        schema=get_schema,
        response=lambda state: run_query(state["query"]),
    )
    | new_prompt
    | llm
    | StrOutputParser()
)

In [93]:
# Run the whole pipeline; the answer text is shown as the cell output.
full_chain.invoke(dict(question=user_question))

'There are 11 tables in this schema. The tables are:\n1. Album (347 rows)\n2. Artist (275 rows)\n3. Customer (59 rows)\n4. Employee (8 rows)\n5. Genre (25 rows)\n6. InvoiceLine (2240 rows)\n7. Invoice (412 rows)\n8. MediaType (5 rows)\n9. PlaylistTrack (8715 rows)\n10. Playlist (18 rows)\n11. Track (3503 rows)'