In [1]:
import os
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from langchain_community.utilities.sql_database import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_core.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_groq import ChatGroq


In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
load_dotenv()

True

In [4]:
import os
from langchain_community.utilities import SQLDatabase
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
from typing_extensions import TypedDict, Annotated

In [5]:
db_uri = os.getenv("DB_URI") 

In [None]:
from langchain_community.utilities.sql_database import SQLDatabase

# Connect to the MySQL database
db = SQLDatabase.from_uri(db_uri)   

In [41]:
# Print the database dialect and usable table names
print(db.dialect)

mysql


In [8]:
from langchain import hub
from typing_extensions import TypedDict, Annotated
from concurrent.futures import ThreadPoolExecutor, TimeoutError as ThreadTimeout
import re

In [9]:
print(db.get_table_info())


CREATE TABLE actor (
	actor_id SMALLINT UNSIGNED NOT NULL AUTO_INCREMENT, 
	first_name VARCHAR(45) NOT NULL, 
	last_name VARCHAR(45) NOT NULL, 
	last_update TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, 
	PRIMARY KEY (actor_id)
)COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB DEFAULT CHARSET=utf8mb4

/*
3 rows from actor table:
actor_id	first_name	last_name	last_update
1	PENELOPE	GUINESS	2006-02-15 04:34:33
2	NICK	WAHLBERG	2006-02-15 04:34:33
3	ED	CHASE	2006-02-15 04:34:33
*/


CREATE TABLE address (
	address_id SMALLINT UNSIGNED NOT NULL AUTO_INCREMENT, 
	address VARCHAR(50) NOT NULL, 
	address2 VARCHAR(50), 
	district VARCHAR(20) NOT NULL, 
	city_id SMALLINT UNSIGNED NOT NULL, 
	postal_code VARCHAR(10), 
	phone VARCHAR(20) NOT NULL, 
	last_update TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, 
	PRIMARY KEY (address_id), 
	CONSTRAINT fk_address_city FOREIGN KEY(city_id) REFERENCES city (city_id) ON DELETE RESTRICT ON UPDATE CASCADE
)COL

In [10]:
table_name = db.get_usable_table_names()

In [11]:
table_name

['actor',
 'address',
 'category',
 'city',
 'country',
 'customer',
 'film',
 'film_actor',
 'film_category',
 'film_text',
 'inventory',
 'language',
 'payment',
 'rental',
 'staff',
 'store']

In [12]:
table_info = db.get_table_info(table_name)

In [13]:
table_info

"\nCREATE TABLE actor (\n\tactor_id SMALLINT UNSIGNED NOT NULL AUTO_INCREMENT, \n\tfirst_name VARCHAR(45) NOT NULL, \n\tlast_name VARCHAR(45) NOT NULL, \n\tlast_update TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, \n\tPRIMARY KEY (actor_id)\n)COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB DEFAULT CHARSET=utf8mb4\n\n/*\n3 rows from actor table:\nactor_id\tfirst_name\tlast_name\tlast_update\n1\tPENELOPE\tGUINESS\t2006-02-15 04:34:33\n2\tNICK\tWAHLBERG\t2006-02-15 04:34:33\n3\tED\tCHASE\t2006-02-15 04:34:33\n*/\n\n\nCREATE TABLE address (\n\taddress_id SMALLINT UNSIGNED NOT NULL AUTO_INCREMENT, \n\taddress VARCHAR(50) NOT NULL, \n\taddress2 VARCHAR(50), \n\tdistrict VARCHAR(20) NOT NULL, \n\tcity_id SMALLINT UNSIGNED NOT NULL, \n\tpostal_code VARCHAR(10), \n\tphone VARCHAR(20) NOT NULL, \n\tlast_update TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, \n\tPRIMARY KEY (address_id), \n\tCONSTRAINT fk_address_city FOREIGN KEY(city_id) REFERENCES 

In [15]:
# Initialize Groq LLM
llm = ChatGroq(api_key=os.getenv("GROQ_API_KEY"), model="meta-llama/llama-4-scout-17b-16e-instruct")

In [16]:
from langchain import hub

# Pull the query prompt template from the LangChain hub
query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")

In [17]:
query_prompt_template

ChatPromptTemplate(input_variables=['dialect', 'input', 'table_info', 'top_k'], input_types={}, partial_variables={}, metadata={'lc_hub_owner': 'langchain-ai', 'lc_hub_repo': 'sql-query-system-prompt', 'lc_hub_commit_hash': '360a0e9d0f0f5da0ee9810a2a0ea3c9dc3de31b3ae9d50272420c33e48e6e323'}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=['dialect', 'table_info', 'top_k'], input_types={}, partial_variables={}, template='Given an input question, create a syntactically correct {dialect} query to run to help find the answer. Unless the user specifies in his question a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.\n\nNever query for all the columns from a specific table, only ask for a the few relevant columns given the question.\n\nPay attention to use only the column names that you can see in the schema

In [18]:
class QueryRequest(BaseModel):
    question: str

In [19]:
from typing_extensions import TypedDict

# Define a TypedDict for the query output
class QueryOutput(TypedDict):
    """Generated SQL query."""
    query: Annotated[str, ..., "Syntactically valid SQL query."]

In [20]:
# 🧠 Table extraction logic
def extract_relevant_tables(user_input: str, all_tables: list[str]) -> list[str]:
    user_input_lower = user_input.lower() 
    return [table for table in all_tables if table.lower() in user_input_lower]


In [21]:
# 🔒 Query validation
def validate_query(query: str):
    disallowed = ["drop", "delete", "truncate", "alter", "update"]
    if any(danger in query.lower() for danger in disallowed):
        raise HTTPException(status_code=400, detail="Unsafe SQL query detected!")


In [22]:
user_input ="fetch me phone number and address of the id who resides in the Mandalay district"

In [23]:
all_tables = db.get_usable_table_names()
relevant_tables = extract_relevant_tables(user_input, all_tables)

In [24]:
relevant_tables

['address']

In [25]:
db.get_table_info(table_names=relevant_tables)

'\nCREATE TABLE address (\n\taddress_id SMALLINT UNSIGNED NOT NULL AUTO_INCREMENT, \n\taddress VARCHAR(50) NOT NULL, \n\taddress2 VARCHAR(50), \n\tdistrict VARCHAR(20) NOT NULL, \n\tcity_id SMALLINT UNSIGNED NOT NULL, \n\tpostal_code VARCHAR(10), \n\tphone VARCHAR(20) NOT NULL, \n\tlast_update TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, \n\tPRIMARY KEY (address_id), \n\tCONSTRAINT fk_address_city FOREIGN KEY(city_id) REFERENCES city (city_id) ON DELETE RESTRICT ON UPDATE CASCADE\n)COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB DEFAULT CHARSET=utf8mb4\n\n/*\n3 rows from address table:\naddress_id\taddress\taddress2\tdistrict\tcity_id\tpostal_code\tphone\tlast_update\n1\t47 MySakila Drive\tNone\tAlberta\t300\t\t\t2014-09-25 22:30:27\n2\t28 MySQL Boulevard\tNone\tQLD\t576\t\t\t2014-09-25 22:30:09\n3\t23 Workhaven Lane\tNone\tAlberta\t300\t\t14033335568\t2014-09-25 22:30:27\n*/'

In [26]:
def generate_sql_query(question: str) -> dict:
    all_tables = db.get_usable_table_names()
    relevant_tables = extract_relevant_tables(question, all_tables)
    table_info = db.get_table_info(table_names=relevant_tables) if relevant_tables else db.get_table_info()

    prompt = query_prompt_template.invoke({
        "dialect": db.dialect,
        "top_k": 10,
        "table_info": table_info,
        "input": question,
    })
    structured_llm = llm.with_structured_output(QueryOutput)
    
    result = structured_llm.invoke(prompt)

    query = result["query"]
    validate_query(query)
    return {"query": query}

In [27]:
def execute_sql_query(query: str) -> dict:
    """Execute the provided SQL query."""
    validate_query(query)

    tool = QuerySQLDatabaseTool(db=db) 
    execute_query_tool = QuerySQLDatabaseTool(db=db)
    return {"result": execute_query_tool.invoke(query)}

In [28]:
# Generate the SQL query
generated_query = generate_sql_query("fetch me phone number and address who resides in the Mandalay district")

In [29]:
generated_query

{'query': "SELECT phone, address FROM address WHERE district = 'Mandalay' LIMIT 10"}

In [30]:
# Execute the generated SQL query
final_result = execute_sql_query(generated_query["query"])

# Print the final result
print(final_result)

{'result': "[('705814003527', '1566 Inegöl Manor')]"}


In [31]:
final_result

{'result': "[('705814003527', '1566 Inegöl Manor')]"}

In [32]:
# Generate the SQL query
generated_query = generate_sql_query("Which actor has appeared in the highest-grossing films (based on total rental revenue), and what is the average rental duration for films they appeared in?")

# Execute the generated SQL query
print(generated_query)
final_result = execute_sql_query(generated_query["query"])

# Print the final result
print(final_result)

{'query': 'SELECT T1.first_name ,  T1.last_name ,  SUM(T3.rental_rate * T3.rental_duration) ,  AVG(T3.rental_duration) FROM actor AS T1 INNER JOIN film_actor AS T2 ON T1.actor_id = T2.actor_id INNER JOIN film AS T3 ON T2.film_id = T3.film_id GROUP BY T1.first_name ,  T1.last_name ORDER BY SUM(T3.rental_rate * T3.rental_duration) DESC LIMIT 1'}
{'result': "[('SUSAN', 'DAVIS', Decimal('723.58'), Decimal('4.4815'))]"}


In [34]:
# # Generate the SQL query
# generated_query = generate_sql_query("Identify the top 3 cities with the highest average rental revenue per customer over the last year of available rental data. For each of these cities, provide the total number of customers, total revenue, and the top 2 film categories (based on rental frequency) preferred by customers from those cities. Also, indicate the average number of rentals per customer in those cities.")

# # Execute the generated SQL query
# print(generated_query)
# final_result = execute_sql_query(generated_query["query"])

# # Print the final result
# print(final_result)

In [None]:
# Generate the SQL query
generated_query = generate_sql_query("Which country has the highest average customer lifetime value (total payments divided by rental count per customer), and what are the top 3 cities in that country by total revenue?")

# Execute the generated SQL query
print(generated_query)
final_result = execute_sql_query(generated_query["query"])

# Print the final result
print(final_result)

BadRequestError: Error code: 400 - {'error': {'message': "Failed to call a function. Please adjust your prompt. See 'failed_generation' for more details.", 'type': 'invalid_request_error', 'code': 'tool_use_failed', 'failed_generation': ' \n  SELECT \n    c.country, \n    AVG(p.total_payments / r.rental_count) AS avg_lifetime_value\n  FROM \n    country c\n  JOIN \n    customer cu ON c.country_id = cu.customer_id\n  JOIN \n    (\n      SELECT \n        customer_id, \n        SUM(amount) AS total_payments\n      FROM \n        payment\n      GROUP BY \n        customer_id\n    ) p ON cu.customer_id = p.customer_id\n  JOIN \n    (\n      SELECT \n        customer_id, \n        COUNT(rental_id) AS rental_count\n      FROM \n        rental\n      GROUP BY \n        customer_id\n    ) r ON cu.customer_id = r.customer_id\n  GROUP BY \n    c.country\n  ORDER BY \n    avg_lifetime_value DESC\n  LIMIT 1;\n</function>\n\n<function=QueryOutput> \n  SELECT \n    a.city, \n    SUM(p.amount) AS total_revenue\n  FROM \n    address a\n  JOIN \n    customer cu ON a.address_id = cu.address_id\n  JOIN \n    payment p ON cu.customer_id = p.customer_id\n  WHERE \n    cu.country_id = (\n      SELECT \n        c.country_id\n      FROM \n        country c\n      JOIN \n        customer cu ON c.country_id = cu.customer_id\n      JOIN \n        (\n          SELECT \n            customer_id, \n            SUM(amount) AS total_payments\n          FROM \n            payment\n          GROUP BY \n            customer_id\n        ) p ON cu.customer_id = p.customer_id\n      JOIN \n        (\n          SELECT \n            customer_id, \n            COUNT(rental_id) AS rental_count\n          FROM \n            rental\n          GROUP BY \n            customer_id\n        ) r ON cu.customer_id = r.customer_id\n      GROUP BY \n        c.country\n      ORDER BY \n        AVG(p.total_payments / r.rental_count) DESC\n      LIMIT 1\n    )\n  GROUP BY \n    a.city\n  ORDER BY \n    total_revenue DESC\n  LIMIT 3;\n</function>'}}

In [35]:
# Generate the SQL query
generated_query = generate_sql_query("fetch me release year,length,rating, language of this movie ARABIA DOGMA ")

# Execute the generated SQL query
print(generated_query)
final_result = execute_sql_query(generated_query["query"])

# Print the final result 
print(final_result)

{'query': "SELECT release_year, length, rating, language FROM film WHERE title = 'ARABIA DOGMA' LIMIT 10"}
{'result': 'Error: (pymysql.err.OperationalError) (1054, "Unknown column \'language\' in \'field list\'")\n[SQL: SELECT release_year, length, rating, language FROM film WHERE title = \'ARABIA DOGMA\' LIMIT 10]\n(Background on this error at: https://sqlalche.me/e/20/e3q8)'}


In [36]:
# Generate the SQL query
generated_query = generate_sql_query("fetch me count movies released in 2006 with NC-17 rating ")

# Execute the generated SQL query
print(generated_query)
final_result = execute_sql_query(generated_query["query"])

# Print the final result
print(final_result)

{'query': "SELECT COUNT(film_id) FROM film WHERE release_year = 2006 AND rating = 'NC-17'"}
{'result': '[(210,)]'}


In [37]:
# Generate the SQL query
generated_query = generate_sql_query("Which staff member has processed the most payments, and what is the total amount they have handled?")

# Execute the generated SQL query
print(generated_query)
final_result = execute_sql_query(generated_query["query"])

# Print the final result
print(final_result)

{'query': 'SELECT s.first_name, s.last_name, COUNT(p.payment_id) as total_payments, SUM(p.amount) as total_amount FROM payment p JOIN staff s ON p.staff_id = s.staff_id GROUP BY s.staff_id ORDER BY total_payments DESC LIMIT 1'}
{'result': "[('Mike', 'Hillyer', 8054, Decimal('33482.50'))]"}


In [38]:
# Generate the SQL query
generated_query = generate_sql_query("What is the total revenue generated by each store location from film rentals?")

# Execute the generated SQL query
print(generated_query)
final_result = execute_sql_query(generated_query["query"])

# Print the final result
print(final_result)

{'query': 'SELECT s.store_id, SUM(f.rental_rate) AS total_revenue FROM film f JOIN inventory i ON f.film_id = i.film_id JOIN rental r ON i.inventory_id = r.inventory_id JOIN store s ON i.store_id = s.store_id GROUP BY s.store_id LIMIT 10'}
{'result': "[(1, Decimal('23509.77')), (2, Decimal('23701.79'))]"}


In [39]:
# Generate the SQL query
generated_query = generate_sql_query("List all stores with their staff counts. Which store has the most staff?")

# Execute the generated SQL query
print(generated_query)
final_result = execute_sql_query(generated_query["query"])

# Print the final result
print(final_result)

{'query': 'SELECT s.store_id, COUNT(st.staff_id) as staff_count FROM store s JOIN staff st ON s.store_id = st.store_id GROUP BY s.store_id ORDER BY staff_count DESC LIMIT 10;'}
{'result': '[(1, 1), (2, 1)]'}


In [40]:
# Generate the SQL query
generated_query = generate_sql_query("List all films that have never been rented, including their descriptions and categories.")

# Execute the generated SQL query
print(generated_query)
final_result = execute_sql_query(generated_query["query"])

# Print the final result
print(final_result)

{'query': 'SELECT f.description, f.title, l.name AS category FROM film f JOIN language l ON f.language_id = l.language_id WHERE f.film_id NOT IN (SELECT film_id FROM rental) LIMIT 10'}
{'result': ''}
