In [None]:
import os
import re
from typing import Any, Dict, List, Optional, TypedDict

from dotenv import load_dotenv
from langchain_community.utilities import SQLDatabase
from langchain_core.messages import AIMessage, HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import END, StateGraph

In [23]:
# Pull settings from a local .env file (if one exists) into the process
# environment. load_dotenv was imported but never called, so on a fresh kernel
# nothing from .env was actually available.
load_dotenv()

# Keep credentials out of the notebook: the connection string (user, password,
# host, database) is read from the MYSQL_URI environment variable instead of
# being committed alongside the code and its outputs.
mysql_uri = os.getenv("MYSQL_URI")
if not mysql_uri:
    raise RuntimeError(
        "MYSQL_URI is not set. Define it in the environment or in a .env file, "
        "e.g. mysql+pymysql://<user>:<password>@localhost:3306/sales_text2sql"
    )

# sample_rows_in_table_info=2: include two sample rows per table in the schema
# text that is later pasted into the LLM prompt.
db = SQLDatabase.from_uri(mysql_uri, sample_rows_in_table_info=2)

# temperature=0 keeps the generated SQL as repeatable as the model allows.
# NOTE(review): the model client presumably picks up its API key from the
# environment (e.g. GOOGLE_API_KEY) — confirm against the library docs.
llm = ChatGoogleGenerativeAI(model="gemma-3-27b-it", temperature=0)

In [24]:
class AgentState(TypedDict):
    """State dictionary shared by every node of the text-to-SQL graph.

    Callers start a run with only ``question``; the remaining keys are filled
    in by the nodes as the graph executes.
    """

    question: str  # natural-language question supplied by the caller
    schema: str  # table info text written by get_schema
    sql_query: str  # most recent SQL written by generate_sql
    db_result: Any  # whatever db.run returned for the most recent query
    # execute_sql stores None here after a successful run and the exception
    # text after a failed one, so the annotation must admit None.
    error_message: Optional[str]
    iteration_count: int  # number of generate_sql attempts made so far

In [25]:
def get_schema(state: AgentState):
    """Load the database schema into the graph state.

    The full table info is returned as-is. A production system would more
    likely prune the schema first (for example with a vector store lookup) so
    that only the relevant tables reach the prompt.

    Returns:
        Partial state update with ``schema`` and ``iteration_count``; the
        counter keeps any value already in ``state`` and otherwise starts at 0.
    """
    table_info = db.get_table_info()
    attempts_so_far = state.get("iteration_count", 0)
    return {"schema": table_info, "iteration_count": attempts_so_far}

In [26]:
def generate_sql(state: AgentState):
    """Ask the LLM for a MySQL query that answers the question in ``state``.

    When ``error_message`` is set (a previous attempt failed to execute), the
    failed query and its error are appended to the prompt so the model can
    correct its own output rather than guess what it wrote before.

    Args:
        state: Graph state. Reads ``schema`` and ``question``; also reads
            ``error_message``, ``sql_query`` and ``iteration_count`` when present.

    Returns:
        Partial state update: ``sql_query`` with the extracted SQL text and
        ``iteration_count`` incremented by one.
    """
    error_context = ""
    if state.get("error_message"):
        previous_sql = state.get("sql_query")
        if previous_sql:
            # Show the model the exact statement that produced the error.
            error_context += f"\nPREVIOUS SQL:\n{previous_sql}"
        error_context += (
            f"\nPREVIOUS ERROR: {state['error_message']}"
            "\nPlease fix the SQL based on this error."
        )

    prompt = f"""
    You are a MySQL expert. Given the schema below, write a query for the question.
    SCHEMA:
    {state['schema']}
    
    QUESTION: {state['question']}
    {error_context}

    IMPORTANT:
    - Only generate valid MySQL syntax.
    - Enclose string literals in single quotes. 
    - Use backticks for columns with spaces (e.g., `Total Sale`).
    - Use only columns listed in the schema.
    - Return only the SQL query, with no explanation.
    """

    response = llm.invoke([HumanMessage(content=prompt)])

    # Accept ```sql, ```mysql and bare ``` fences. Matching only ```sql let the
    # other fence styles fall through to the raw reply, backticks included.
    match = re.search(
        r"```(?:(?:my)?sql\b)?\s*(.*?)\s*```",
        response.content,
        re.DOTALL | re.IGNORECASE,
    )
    # No fenced block at all: treat the whole reply as the query.
    sql = match.group(1).strip() if match else response.content.strip()

    # Default to 0 so the node also works when it is invoked without get_schema
    # having seeded the counter.
    return {"sql_query": sql, "iteration_count": state.get("iteration_count", 0) + 1}

In [27]:
def execute_sql(state: AgentState):
    """Run ``state["sql_query"]`` against the database and record the outcome.

    Returns:
        A partial state update that always carries both keys:
        ``{"db_result": <result>, "error_message": None}`` on success, or
        ``{"db_result": None, "error_message": <exception text>}`` on failure.
        Setting ``db_result`` on failure too means the final state always has
        the key, even when every attempt failed.
    """
    # NOTE(review): the query text is produced by the LLM and executed as-is.
    # Consider connecting with a read-only database user so a generated
    # UPDATE/DELETE/DROP cannot modify data.
    try:
        result = db.run(state["sql_query"])
    except Exception as e:
        # Broad on purpose: any execution failure is handed back to the
        # generator as text so it can attempt a corrected query.
        return {"db_result": None, "error_message": str(e)}
    return {"db_result": result, "error_message": None}

In [28]:
def router(state: AgentState):
    """Pick the next edge after a query execution.

    Returns:
        ``"retry"`` when the last execution left an error and fewer than three
        generation attempts have been made; ``"end"`` otherwise (success, or
        the retry budget is used up).
    """
    # The attempt counter is only consulted when there is an error to fix.
    if state["error_message"] is not None and state["iteration_count"] < 3:
        return "retry"
    return "end"

In [29]:
# Graph layout: get_schema -> generate_sql -> execute_sql, with execute_sql
# looping back to generate_sql whenever router answers "retry".
workflow = StateGraph(AgentState)

# Nodes
workflow.add_node("get_schema", get_schema)
workflow.add_node("generate_sql", generate_sql)
workflow.add_node("execute_sql", execute_sql)

# Fixed edges
workflow.add_edge("get_schema", "generate_sql")
workflow.add_edge("generate_sql", "execute_sql")
workflow.set_entry_point("get_schema")

# Conditional edge: router's return value is looked up in this mapping.
workflow.add_conditional_edges(
    "execute_sql",
    router,
    {"retry": "generate_sql", "end": END},
)

app = workflow.compile()

In [30]:
if __name__ == "__main__":
    test_question = "What was the total sale for the East region?"
    final_state = app.invoke({"question": test_question})

    print(f"QUESTION: {test_question}")
    print(f"SQL GENERATED: {final_state['sql_query']}")
    # .get(): when every attempt fails, the final state may have no
    # "db_result" key; indexing it would raise KeyError before the actual
    # error message below gets printed.
    print(f"RESULT: {final_state.get('db_result')}")
    if final_state.get('error_message'):
        print(f"FINAL ERROR: {final_state['error_message']}")

QUESTION: What was the total sale for the East region?
SQL GENERATED: SELECT
  SUM(`TotalSale`)
FROM sales_data
WHERE
  `Region` = 'East';
RESULT: [(112629.82999999999,)]


In [31]:
if __name__ == "__main__":
    test_question = "What was the total sale for the Category 3 in East region?"
    final_state = app.invoke({"question": test_question})

    print(f"QUESTION: {test_question}")
    print(f"SQL GENERATED: {final_state['sql_query']}")
    # .get(): when every attempt fails, the final state may have no
    # "db_result" key; indexing it would raise KeyError before the actual
    # error message below gets printed.
    print(f"RESULT: {final_state.get('db_result')}")
    if final_state.get('error_message'):
        print(f"FINAL ERROR: {final_state['error_message']}")

QUESTION: What was the total sale for the Category 3 in East region?
SQL GENERATED: SELECT
  SUM(T1.`TotalSale`)
FROM sales_data AS T1
INNER JOIN product_catalog AS T2
  ON T1.`ProductID` = T2.id
WHERE
  T2.category = 'Category 3' AND T1.`Region` = 'East';
RESULT: [(16640.199999999997,)]
