In [None]:
import pandas as pd
import mysql.connector
from sqlalchemy import create_engine, text

# Connect to MySQL
conn = mysql.connector.connect(
    host="",
    user="",
    password=""
)
cursor = conn.cursor()

# Create database if it doesn't exist
cursor.execute("CREATE DATABASE IF NOT EXISTS customer_products")
cursor.execute("USE customer_products")

# Create tables
cursor.execute("""
CREATE TABLE IF NOT EXISTS customer_purchases (
    id INT AUTO_INCREMENT PRIMARY KEY,
    msisdn VARCHAR(15) NOT NULL,
    product_id INT NOT NULL
)
""")

cursor.execute("""
CREATE TABLE IF NOT EXISTS products (
    product_id INT PRIMARY KEY,
    product_name VARCHAR(100) NOT NULL,
    validity_days INT NOT NULL,
    price DECIMAL(10, 2) NOT NULL
)
""")

# Insert sample data into products
product_data = [
    (1, 'Data Pack 1GB', 7, 49.00),
    (2, 'Data Pack 3GB', 28, 149.00),
    (3, 'Unlimited Calls', 30, 129.00),
    (4, 'Family Pack', 28, 299.00),
    (5, 'International Roaming', 10, 499.00)
]

cursor.executemany(
    "INSERT INTO products (product_id, product_name, validity_days, price) VALUES (%s, %s, %s, %s)",
    product_data
)

# Insert sample customer purchases
customer_data = [
    ('9123456789', 1),
    ('9123456789', 3),
    ('9123456789', 5),
    ('9087654321', 2),
    ('9087654321', 4),
    ('9555555555', 1),
    ('9555555555', 2),
    ('9555555555', 3),
    ('9444444444', 2),
    ('9999999999', 5)
]

cursor.executemany(
    "INSERT INTO customer_purchases (msisdn, product_id) VALUES (%s, %s)",
    customer_data
)

conn.commit()
cursor.close()
conn.close()

print("Database and sample data created successfully!")

Database and sample data created successfully!


In [2]:
# Step 1: Install required packages
import os
import re
from typing import List, Dict, Any, Optional
import pandas as pd
from pprint import pprint

In [3]:

# LangChain imports
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain_groq import ChatGroq
from langchain.sql_database import SQLDatabase
from langchain_core.pydantic_v1 import BaseModel, Field


For example, replace imports like: `from langchain_core.pydantic_v1 import BaseModel`
with: `from pydantic import BaseModel`
or the v1 compatibility namespace if you are working in a code base that has not been fully upgraded to pydantic 2 yet. 	from pydantic.v1 import BaseModel

  exec(code_obj, self.user_global_ns, self.user_ns)


In [None]:
# LangGraph imports
from langgraph.graph import StateGraph, END

# Database connection
from sqlalchemy import create_engine, inspect, MetaData, text

# Set Groq API key
os.environ["GROQ_API_KEY"] = ""  # Replace with actual key

# Initialize the LLM
llm = ChatGroq(model="llama3-70b-8192", temperature=0)

# Set up database connection
db_uri = ""
engine = create_engine(db_uri)
db = SQLDatabase.from_uri(db_uri)

In [13]:
# State definition
class AgentState(BaseModel):
    query: str = Field(description="The user's natural language query")
    available_tables: Optional[List[str]] = Field(default=None, description="List of available tables in the database")
    relevant_tables: Optional[List[str]] = Field(default=None, description="List of tables relevant to the query")
    table_schemas: Optional[Dict[str, str]] = Field(default=None, description="DDL definitions of relevant tables")
    generated_sql: Optional[str] = Field(default=None, description="SQL query generated from the user question")
    execution_result: Optional[Any] = Field(default=None, description="Result of executing the SQL query")
    execution_error: Optional[str] = Field(default=None, description="Error message if SQL execution failed")
    corrected_sql: Optional[str] = Field(default=None, description="Corrected SQL query if original query failed")
    final_response: Optional[str] = Field(default=None, description="Final response to the user")

# Step 1: Fetch available tables
def fetch_available_tables(state: AgentState) -> AgentState:
    """Fetch all available tables from the database"""
    inspector = inspect(engine)
    available_tables = inspector.get_table_names()
    
    return AgentState(
        query=state.query,
        available_tables=available_tables
    )

# Step 2: Determine relevant tables
def determine_relevant_tables(state: AgentState) -> AgentState:
    """Determine which tables are relevant to the user's query"""
    query = state.query
    available_tables = state.available_tables
    
    # Prompt to determine relevant tables
    prompt = PromptTemplate.from_template("""
    You are an expert SQL database analyst. Given the user's question and the list of available tables,
    determine which tables are most likely needed to answer the question.
    
    User question: {query}
    
    Available tables: {available_tables}
    
    Return only a comma-separated list of the relevant table names, nothing else.
    """)
    
    chain = prompt | llm | StrOutputParser()
    tables_string = chain.invoke({
        "query": query,
        "available_tables": ", ".join(available_tables)
    })
    
    # Extract table names (handling potential LLM formatting variations)
    table_pattern = re.compile(r'[\w_]+')
    potential_tables = table_pattern.findall(tables_string)
    
    # Filter to only include actual tables that exist in the database
    relevant_tables = [table for table in potential_tables if table in available_tables]
    
    # If no tables were found but we have available tables, default to all tables
    if not relevant_tables and available_tables:
        relevant_tables = available_tables
    
    return AgentState(
        query=state.query,
        available_tables=state.available_tables,
        relevant_tables=relevant_tables
    )

# Step 3: Retrieve table DDL
def retrieve_table_schemas(state: AgentState) -> AgentState:
    """Retrieve the DDL (schema) information for the relevant tables"""
    relevant_tables = state.relevant_tables
    table_schemas = {}
    
    for table in relevant_tables:
        try:
            # Get column information 
            with engine.connect() as connection:
                inspector = inspect(engine)
                columns = inspector.get_columns(table)
                column_details = []
                
                for column in columns:
                    col_type = str(column['type'])
                    nullable = "NULL" if column.get('nullable', True) else "NOT NULL"
                    default = f"DEFAULT {column.get('default')}" if column.get('default') is not None else ""
                    column_details.append(f"{column['name']} {col_type} {nullable} {default}".strip())
                
                # Get primary key information
                pk_info = inspector.get_pk_constraint(table)
                if pk_info and pk_info.get('constrained_columns'):
                    pk_cols = ", ".join(pk_info['constrained_columns'])
                    column_details.append(f"PRIMARY KEY ({pk_cols})")
                
                # Get foreign key information
                fk_info = inspector.get_foreign_keys(table)
                for fk in fk_info:
                    src_cols = ", ".join(fk['constrained_columns'])
                    ref_table = fk['referred_table']
                    ref_cols = ", ".join(fk['referred_columns'])
                    column_details.append(f"FOREIGN KEY ({src_cols}) REFERENCES {ref_table}({ref_cols})")
                
                # Construct CREATE TABLE statement
                create_stmt = f"CREATE TABLE {table} (\n  " + ",\n  ".join(column_details) + "\n);"
                table_schemas[table] = create_stmt
                
                # Get sample data (first 5 rows)
                sample_data_query = f"SELECT * FROM {table} LIMIT 5"
                sample_df = pd.read_sql(sample_data_query, connection)
                if not sample_df.empty:
                    sample_str = f"\n\n-- Sample data from {table}:\n"
                    sample_str += sample_df.to_string(index=False)
                    table_schemas[table] += sample_str
                    
        except Exception as e:
            table_schemas[table] = f"Error retrieving schema: {str(e)}"
    
    return AgentState(
        query=state.query,
        available_tables=state.available_tables,
        relevant_tables=state.relevant_tables,
        table_schemas=table_schemas
    )

# Step 4: Generate SQL query
def generate_sql_query(state: AgentState) -> AgentState:
    """Generate a SQL query based on the user's question and table schemas"""
    query = state.query
    table_schemas = state.table_schemas
    
    # Prompt for SQL generation
    prompt = PromptTemplate.from_template("""
    You are an expert SQL developer specializing in MySQL syntax. Generate a SQL query to answer this question:
    
    User question: {query}
    
    Here are the relevant database table definitions:
    
    {table_schemas}
    
    Guidelines for query generation:
    1. Use table joins when needed (prefer INNER JOIN over comma joins)
    2. Use appropriate WHERE clauses for filtering
    3. Include ORDER BY for sorting requirements
    4. Add LIMIT clauses when user asks for "top N" results
    5. Use GROUP BY with aggregate functions when appropriate
    6. Ensure all column references are fully qualified (table.column)
    7. Keep the query efficient and focused on answering the specific question
    
    Return only the SQL query, no additional text or explanation.
    """)
    
    schemas_text = "\n\n".join([f"-- Table: {t}\n{schema}" for t, schema in table_schemas.items()])
    
    chain = prompt | llm | StrOutputParser()
    sql_query = chain.invoke({
        "query": query,
        "table_schemas": schemas_text
    })
    
    # Clean up the query (remove SQL codeblocks if present)
    sql_query = re.sub(r'```sql\s*|\s*```', '', sql_query).strip()
    
    return AgentState(
        query=state.query,
        available_tables=state.available_tables,
        relevant_tables=state.relevant_tables,
        table_schemas=state.table_schemas,
        generated_sql=sql_query
    )

# Step 5: Validate and optimize SQL
def validate_sql(state: AgentState) -> AgentState:
    """Validate and optimize the generated SQL query"""
    query = state.query
    sql_query = state.generated_sql
    table_schemas = state.table_schemas
    
    # Prompt for SQL validation
    prompt = PromptTemplate.from_template("""
    You are an expert SQL reviewer. Review this SQL query for correctness and optimization:
    
    Original User Question: {query}
    Relevant Tables Schemas:
    {table_schemas}
    
    Generated SQL Query: 
    {sql_query}
    
    Check for:
    1. Syntax errors or invalid table/column references
    2. Missing JOIN conditions that could cause cartesian products
    3. Inefficient query patterns (missing indexes, unnecessary operations)
    4. Logic errors that would not answer the user's question correctly
    5. Proper use of GROUP BY, HAVING, and aggregation functions
    
    If the query looks correct, return it unchanged. If it needs corrections, fix it and return the corrected query.
    Return only the final SQL query with no explanation.
    """)
    
    schemas_text = "\n\n".join([f"-- Table: {t}\n{schema}" for t, schema in table_schemas.items()])
    
    chain = prompt | llm | StrOutputParser()
    validated_sql = chain.invoke({
        "query": query,
        "sql_query": sql_query,
        "table_schemas": schemas_text
    })
    
    # Clean up the query (remove SQL codeblocks if present)
    validated_sql = re.sub(r'```sql\s*|\s*```', '', validated_sql).strip()
    
    # Only update if the validated SQL is different
    if validated_sql != sql_query:
        return AgentState(
            query=state.query,
            available_tables=state.available_tables,
            relevant_tables=state.relevant_tables,
            table_schemas=state.table_schemas,
            generated_sql=validated_sql
        )
    return state

# Step 6: Execute SQL query
def execute_sql_query(state: AgentState) -> AgentState:
    """Execute the generated SQL query"""
    sql_query = state.generated_sql
    
    try:
        # Execute the query
        result = db.run(sql_query)
        
        # If result is empty, try to provide a more informative message
        if not result or result.strip() == "":
            result = "Query executed successfully, but returned no results."
            
        return AgentState(
            query=state.query,
            available_tables=state.available_tables,
            relevant_tables=state.relevant_tables,
            table_schemas=state.table_schemas,
            generated_sql=sql_query,
            execution_result=result,
            execution_error=None
        )
        
    except Exception as e:
        error_message = str(e)
        return AgentState(
            query=state.query,
            available_tables=state.available_tables,
            relevant_tables=state.relevant_tables,
            table_schemas=state.table_schemas,
            generated_sql=sql_query,
            execution_result=None,
            execution_error=error_message
        )

# Step 7: Correct SQL on error (if needed)
def correct_sql_on_error(state: AgentState) -> AgentState:
    """Attempt to correct the SQL query if execution failed"""
    if state.execution_error is None:
        return state
    
    query = state.query
    error_message = state.execution_error
    original_sql = state.generated_sql
    table_schemas = state.table_schemas
    
    # Prompt to fix SQL errors
    prompt = PromptTemplate.from_template("""
    You are an expert SQL debugger. The following SQL query failed with an error:
    
    Original user question: {query}
    
    Failed SQL query: 
    {original_sql}
    
    Error message: 
    {error_message}
    
    Database schema information:
    {table_schemas}
    
    Please fix the query to address the error. Common issues include:
    1. Invalid column or table names
    2. Syntax errors in JOIN conditions or WHERE clauses
    3. Type mismatches or invalid operators
    4. Missing GROUP BY columns
    5. Referencing non-existent tables or columns
    
    Return only the corrected SQL query with no explanation.
    """)
    
    schemas_text = "\n\n".join([f"-- Table: {t}\n{schema}" for t, schema in table_schemas.items()])
    
    chain = prompt | llm | StrOutputParser()
    corrected_sql = chain.invoke({
        "query": query,
        "original_sql": original_sql,
        "error_message": error_message,
        "table_schemas": schemas_text
    })
    
    # Clean up the query (remove SQL codeblocks if present)
    corrected_sql = re.sub(r'```sql\s*|\s*```', '', corrected_sql).strip()
    
    # Execute the corrected query
    try:
        result = db.run(corrected_sql)
        if not result or result.strip() == "":
            result = "Query executed successfully after correction, but returned no results."
            
        return AgentState(
            query=state.query,
            available_tables=state.available_tables,
            relevant_tables=state.relevant_tables,
            table_schemas=state.table_schemas,
            generated_sql=state.generated_sql,
            execution_result=result,
            execution_error=None,
            corrected_sql=corrected_sql
        )
        
    except Exception as e:
        # If correction still fails, return both errors
        secondary_error = str(e)
        return AgentState(
            query=state.query,
            available_tables=state.available_tables,
            relevant_tables=state.relevant_tables,
            table_schemas=state.table_schemas,
            generated_sql=state.generated_sql,
            execution_result=None,
            execution_error=f"Original error: {state.execution_error}\n\nCorrection attempt failed with: {secondary_error}",
            corrected_sql=corrected_sql
        )

# Step 8: Formulate response
def formulate_response(state: AgentState) -> AgentState:
    """Formulate a human-readable response based on query results"""
    query = state.query
    sql_query = state.corrected_sql if state.corrected_sql else state.generated_sql
    result = state.execution_result
    error = state.execution_error
    
    if error:
        # Handle the case where even the correction failed
        prompt = PromptTemplate.from_template("""
        The database query to answer the user's question encountered errors that could not be fixed automatically.
        
        User question: {query}
        
        SQL query attempted: {sql_query}
        
        Error message: {error}
        
        Provide a helpful response that:
        1. Acknowledges the technical issue
        2. Explains in simple terms what might be wrong
        3. Suggests how the user might rephrase their question
        """)
        
        chain = prompt | llm | StrOutputParser()
        response = chain.invoke({
            "query": query,
            "sql_query": sql_query,
            "error": error
        })
        
    else:
        # Format successful results
        prompt = PromptTemplate.from_template("""
        You are a helpful database assistant. Format the following SQL results into a clear, concise answer.
        
        User question: {query}
        
        SQL query executed: {sql_query}
        
        Query results: {result}
        
        Provide a natural language response that:
        1. Directly answers the user's question
        2. Presents the data in an easily digestible format
        3. Highlights key insights from the results
        4. Is conversational and helpful
        
        Do not mention the SQL query itself unless it's particularly relevant to understanding the answer.
        """)
        
        chain = prompt | llm | StrOutputParser()
        response = chain.invoke({
            "query": query,
            "sql_query": sql_query,
            "result": result
        })
    
    return AgentState(
        query=state.query,
        available_tables=state.available_tables,
        relevant_tables=state.relevant_tables,
        table_schemas=state.table_schemas,
        generated_sql=state.generated_sql,
        execution_result=state.execution_result,
        execution_error=state.execution_error,
        corrected_sql=state.corrected_sql,
        final_response=response
    )

# Define conditional edges
def should_correct_sql(state):
    """Determine if SQL correction is needed based on execution errors"""
    return "correct_sql" if state.execution_error else "formulate_response"

# Build the LangGraph
def build_sql_agent_graph():
    """Build and return the SQL agent graph"""
    workflow = StateGraph(AgentState)
    
    # Add nodes
    workflow.add_node("fetch_tables", fetch_available_tables)
    workflow.add_node("determine_tables", determine_relevant_tables)
    workflow.add_node("retrieve_schemas", retrieve_table_schemas)
    workflow.add_node("generate_sql", generate_sql_query)
    workflow.add_node("validate_sql", validate_sql)
    workflow.add_node("execute_sql", execute_sql_query)
    workflow.add_node("correct_sql", correct_sql_on_error)
    workflow.add_node("formulate_response", formulate_response)
    
    # Add edges
    workflow.add_edge("fetch_tables", "determine_tables")
    workflow.add_edge("determine_tables", "retrieve_schemas")
    workflow.add_edge("retrieve_schemas", "generate_sql")
    workflow.add_edge("generate_sql", "validate_sql")
    workflow.add_edge("validate_sql", "execute_sql")
    
    # Add conditional edge
    workflow.add_conditional_edges(
        "execute_sql",
        should_correct_sql,
        {
            "correct_sql": "correct_sql",
            "formulate_response": "formulate_response"
        }
    )
    
    workflow.add_edge("correct_sql", "formulate_response")
    workflow.add_edge("formulate_response", END)
    
    # Set entry point
    workflow.set_entry_point("fetch_tables")
    
    # Compile the graph
    return workflow.compile()

# Create the agent
sql_agent = build_sql_agent_graph()

In [18]:
def query_sql_agent(question: str) -> str:
    """Query the SQL agent with a natural language question"""
    result = sql_agent.invoke({"query": question})
    
    # Print the result to debug the structure
    print(result)
    
    # Access final_response correctly
    return result.get("final_response", "No response found")


In [19]:
# Example usage
if __name__ == "__main__":
    # Test the agent with sample questions
    sample_questions = [
        "Give me the top 3 most popular products"
        # "Which customer has spent the most money?",
        # "How many products have a validity period longer than 20 days?",
        # "What's the average price of all products in our database?"
    ]
    
    for question in sample_questions:
        print(f"\nQuestion: {question}")
        print("-" * 50)
        answer = query_sql_agent(question)
        print(f"Answer: {answer}")
        print("=" * 80)



Question: Give me the top 3 most popular products
--------------------------------------------------
{'query': 'Give me the top 3 most popular products', 'available_tables': ['customer_purchases', 'products'], 'relevant_tables': ['customer_purchases', 'products'], 'table_schemas': {'customer_purchases': "Error retrieving schema: Not an executable object: 'SELECT * FROM customer_purchases LIMIT 5'", 'products': "Error retrieving schema: Not an executable object: 'SELECT * FROM products LIMIT 5'"}, 'generated_sql': "Here is the reviewed and corrected SQL query:\nSELECT p.product_name, COUNT(cp.product_id) AS purchase_count\nFROM products p\nINNER JOIN customer_purchases cp ON p.product_id = cp.product_id\nGROUP BY p.product_name\nORDER BY purchase_count DESC\nLIMIT 3;\n\nThe original query looks correct, so I'm returning it unchanged.", 'execution_result': "[('Data Pack 3GB', 3), ('Data Pack 1GB', 2), ('Unlimited Calls', 2)]", 'execution_error': None, 'corrected_sql': 'SELECT p.product_