In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
import os
# Load OpenAI API key from environment variable
openai_api_key = os.getenv('OPENAI_API_KEY')

In [25]:
template = """
    Based on the table schema below, write a sql query that woulf anser the user's question.
    The table schema is as follows:
    {schema}
    
    Question: {question}
    SQL Query
"""

In [26]:
prompt = ChatPromptTemplate.from_template(template)
prompt.format(schema='my schema', question='how many users are there?')

"Human: \n    Based on the table schema below, write a sql query that woulf anser the user's question.\n    The table schema is as follows:\n    my schema\n\n    Question: how many users are there?\n    SQL Query\n"

In [None]:
db_uri = 'mysql+mysqlconnector://root:{password}@localhost:3306/chinook'
db = SQLDatabase.from_uri(db_uri)


In [28]:
db.run("select * from album limit 5;")

"[(1, 'For Those About To Rock We Salute You', 1), (2, 'Balls to the Wall', 2), (3, 'Restless and Wild', 2), (4, 'Let There Be Rock', 1), (5, 'Big Ones', 3)]"

In [29]:
def get_schema(_):
    return db.get_table_info()

In [37]:
llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo")

In [38]:
chain = RunnablePassthrough.assign(schema=get_schema) | prompt | llm.bind(stop='\nSQL Result:') | StrOutputParser()

'SELECT COUNT(*) AS total_albums FROM album;'

In [42]:
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}"""
prompt_response = ChatPromptTemplate.from_template(template)

In [43]:
def run_query(query):
    return db.run(query)


In [45]:
full_chain = (
    RunnablePassthrough.assign(query=chain).assign(
        schema=get_schema,
        response=lambda vars: run_query(vars["query"]),
    )
    | prompt_response
    | llm
)


In [53]:
user_question = 'Get me Nancy Edwards\'s manager name'
query = chain.invoke({
    'question': user_question
})
print(query)
result = full_chain.invoke({"question": user_question})

print(result.content)


SELECT e1.FirstName, e1.LastName
FROM employee e1
JOIN employee e2 ON e1.EmployeeId = e2.ReportsTo
WHERE e2.FirstName = 'Nancy' AND e2.LastName = 'Edwards';
Nancy Edwards's manager's name is Andrew Adams.
