In [5]:
import pandas as pd
from decimal import Decimal
import gradio as gr
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI 
import os

# Set the OpenAI API key
os.environ['OPENAI_API_KEY'] = ''

# Database URI
mysql_uri = 'mysql+mysqlconnector://root:1234@localhost:3306/chinook'
db = SQLDatabase.from_uri(mysql_uri)

# Functions
def get_schema(_):
    schema = db.get_table_info()
    return schema

def expand_prompt(user_question):
    llm = ChatOpenAI(model='gpt-3.5-turbo')
    expand_template = """In context of the table schema below, expand the prompt don't give me SQL Code, but expand it to be more SQL Friendly, and more detailed. Don't just write the SQL, give me a prompt to be given to an LLM. ONLY give me the prompt, and nothing else.
    {schema}

    Question: {question}"""
    ex_prompt = ChatPromptTemplate.from_template(expand_template)
    expand_chain = (
        RunnablePassthrough.assign(schema=get_schema)
        | ex_prompt
        | llm
        | StrOutputParser()
    )
    return expand_chain.invoke({"question": user_question})

def run_query(query):
    print("Query being used is:- ", query)
    print("Generated answer is:- ", db.run(query), "EOL")
    return db.run(query)

def generate_response(user_question):
    # Expand the user question
    expanded_prompt = expand_prompt(user_question)
    print("The Expanded Prompt is:- ", expanded_prompt)

    # Create the SQL chain
    llm = ChatOpenAI(model='gpt-3.5-turbo')
    sql_template = """Based on the table schema below, write an SQL query keeping in mind the foreign keys that would answer the user's question. ONLY generate the SQL Query and NOTHING else:
    {schema}

    Question: {question}
    SQL Query:"""
    sql_prompt = ChatPromptTemplate.from_template(sql_template)
    sql_chain = (
        RunnablePassthrough.assign(schema=get_schema)
        | sql_prompt
        | llm.bind(stop=["\nSQLResult:"])
        | StrOutputParser()
    )

    # Generate the SQL query
    sql_query = sql_chain.invoke({"question": expanded_prompt})

    # Run the SQL query
    sql_result = run_query(sql_query)

    data = sql_result
    
    num_columns = len(data[0])

    column_names = [f'Column{i+1}' for i in range(num_columns)]
    df = pd.DataFrame(data, column_names)
    print(df)
    # Generate the final response
    response_template = """Write a natural language response which asks the user to refer to the table based on their question, and find the answer. Also summarize from the table the answer to the point as much as possible and write it as 'Summary:-':
    {schema}

    Question: {question}
    SQL Query: {query}
    SQL Response: {response}"""
    prompt_response = ChatPromptTemplate.from_template(response_template)
    full_chain = (
        RunnablePassthrough.assign(query=sql_chain).assign(
            schema=get_schema,
            response=lambda vars: run_query(vars["query"]),
        )
        | prompt_response
        | llm
    )

    # Generate the response
    final_response = full_chain.invoke({"question": expanded_prompt})
    return final_response.content

# Gradio Interface
def gradio_interface(user_question):
    return generate_response(user_question)

interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.inputs.Textbox(lines=2, placeholder="Enter your question here..."),
    outputs="text",
    title="SQL Query Generator",
    description="Ask a question and get a detailed SQL query and response based on the Chinook database schema."
)

if __name__ == "__main__":
    interface.launch(debug=True)


  super().__init__(
  super().__init__(


IMPORTANT: You are using gradio version 3.31.0, however version 4.29.0 is available, please upgrade.
--------
Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


The Expanded Prompt is:  Can you provide me with the top 5 customers based on their total invoice amount?
Query being used is:  SELECT c.CustomerId, c.FirstName, c.LastName, SUM(i.Total) AS TotalAmount
FROM customer c
JOIN invoice i ON c.CustomerId = i.CustomerId
GROUP BY c.CustomerId
ORDER BY TotalAmount DESC
LIMIT 5;
Generated answer is:  [(6, 'Helena', 'Holý', Decimal('49.62')), (26, 'Richard', 'Cunningham', Decimal('47.62')), (57, 'Luis', 'Rojas', Decimal('46.62')), (45, 'Ladislav', 'Kovács', Decimal('45.62')), (46, 'Hugh', "O'Reilly", Decimal('45.62'))] EOL
Error creating DataFrame: SQL result is not in the expected format (list of tuples).
Query being used is:  SELECT c.CustomerId, c.FirstName, c.LastName, SUM(i.Total) AS TotalAmount
FROM customer c
JOIN invoice i ON c.CustomerId = i.CustomerId
GROUP BY c.CustomerId
ORDER BY TotalAmount DESC
LIMIT 5;
Generated answer is:  [(6, 'Helena', 'Holý', Decimal('49.62')), (26, 'Richard', 'Cunningham', Decimal('47.62')), (57, 'Luis', 'Roja