## Chain-based text-to-SQL pipeline

In [1]:
# Import necessary libraries
from langchain.utilities import SQLDatabase
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.chains import create_sql_query_chain
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_google_genai import ChatGoogleGenerativeAI


In [2]:
# Connect to the MySQL database.
# Requires the `pymysql` driver (selected by the `mysql+pymysql` URI scheme below).
import os
from urllib.parse import quote_plus

# Connection settings are read from the environment so real credentials never
# have to be saved in the notebook; the fallbacks are the local-development
# values this notebook originally hard-coded.
host = os.environ.get("MYSQL_HOST", "localhost")
port = os.environ.get("MYSQL_PORT", "3306")
username = os.environ.get("MYSQL_USER", "root")
password = os.environ.get("MYSQL_PASSWORD", "root")
database_schema = os.environ.get("MYSQL_DATABASE", "regional_sales_data")

# Percent-encode the credentials: characters such as '@', ':' or '/' inside a
# password would otherwise be parsed as URI delimiters.
mysql_uri = (
    f"mysql+pymysql://{quote_plus(username)}:{quote_plus(password)}"
    f"@{host}:{port}/{database_schema}"
)
db = SQLDatabase.from_uri(mysql_uri, sample_rows_in_table_info=2)

In [3]:
# Rebind `db` to a connection that includes a single sample row per table
# in the schema description.
db = SQLDatabase.from_uri(
    mysql_uri,
    sample_rows_in_table_info=1,
)

In [4]:
# Inspect the schema description (CREATE TABLE statements plus sample rows)
# that `get_table_info()` returns for the connected database.
db.get_table_info()

'\nCREATE TABLE `2017_budgets` (\n\t`Product Name` TEXT, \n\t`2017 Budgets` DOUBLE\n)ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci\n\n/*\n1 rows from 2017_budgets table:\nProduct Name\t2017 Budgets\nProduct 1\t3016489.2089999998\n*/\n\n\nCREATE TABLE customers (\n\t`Customer Index` INTEGER, \n\t`Customer Names` TEXT\n)ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci\n\n/*\n1 rows from customers table:\nCustomer Index\tCustomer Names\n1\tGeiss Company\n*/\n\n\nCREATE TABLE products (\n\t`Index` INTEGER, \n\t`Product Name` TEXT\n)ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci\n\n/*\n1 rows from products table:\nIndex\tProduct Name\n1\tProduct 1\n*/\n\n\nCREATE TABLE regions (\n\tid INTEGER, \n\tname TEXT, \n\tcounty TEXT, \n\tstate_code TEXT, \n\tstate TEXT, \n\ttype TEXT, \n\tlatitude DOUBLE, \n\tlongitude DOUBLE, \n\tarea_code INTEGER, \n\tpopulation INTEGER, \n\thouseholds INTEGER, \n\tmedian_income INTEGER, \n\tland_area INTEGER, \

In [5]:
# Create the LLM prompt template.
# `{schema}` and `{question}` are placeholders filled in when the prompt is
# invoked; the wording asks the model for a bare, single-line SQL statement.
from langchain_core.prompts import ChatPromptTemplate

template = """Based on the table schema below, write a SQL query that would answer the user's question:
Remember : Only provide me the sql query dont include anything else.
           Provide me sql query in a single line dont add line breaks.
Table Schema:
{schema}

Question: {question}
SQL Query:
"""
prompt = ChatPromptTemplate.from_template(template)

In [6]:
# Schema lookup helper
def get_schema(db):
    """Return whatever ``db.get_table_info()`` reports for the given database."""
    return db.get_table_info()


In [7]:
# Instantiate the Gemini chat model used to generate SQL.
import os
from getpass import getpass

# Never save the key in the notebook: take it from the GOOGLE_API_KEY
# environment variable, and prompt for it (without echo) only when unset.
google_api_key = os.environ.get("GOOGLE_API_KEY") or getpass("Google API key: ")

llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    api_key=google_api_key,
)

In [8]:
# Assemble the question -> SQL pipeline step by step.
# 1. Add a `schema` key (the current table info) to the incoming inputs.
with_schema = RunnablePassthrough.assign(schema=lambda _inputs: get_schema(db))
# 2. Model call with a stop sequence so generation ends before any "SQLResult:" text.
sql_llm = llm.bind(stop=["\nSQLResult:"])
# 3. Fill the prompt, call the model, and reduce the reply to a plain string.
sql_chain = with_schema | prompt | sql_llm | StrOutputParser()

In [16]:
# Smoke-test the SQL query chain with a sample question
sample_input = {"question": "What was the budget of Product 12"}
resp = sql_chain.invoke(sample_input)
print(resp)

```sql
SELECT `2017 Budgets` FROM `2017_budgets` WHERE `Product Name` = 'Product 12'
```


In [26]:
import re


def extract_sql(llm_text):
    """Pull the SQL statement out of a raw LLM reply.

    The model may wrap the statement in a Markdown code fence (```sql ... ```,
    ```mysql ... ``` or an unlabelled ``` ... ```), or it may return the bare
    statement as the prompt requests. Both shapes are accepted.

    Args:
        llm_text: The text returned by the SQL chain.

    Returns:
        The SQL statement with surrounding whitespace and fences removed.

    Raises:
        ValueError: If no SQL text is left after stripping.
    """
    fenced = re.search(
        r"```(?:sql|mysql)?\s*(.*?)\s*```",
        llm_text,
        re.DOTALL | re.IGNORECASE,
    )
    # Without a fence the whole reply is taken as the statement; previously
    # this case left `query` as None and the next cell failed.
    sql = (fenced.group(1) if fenced else llm_text).strip()
    if not sql:
        raise ValueError("LLM response did not contain a SQL query")
    return sql


query = extract_sql(resp)


In [27]:
# NOTE(review): this executes the model-generated SQL verbatim with whatever
# privileges the `db` connection has — consider a read-only database user.
db.run(query)

'[(1356976.996,)]'