# SQL Query Generation and Execution with LangGraph

## Introduction

This notebook demonstrates how to use LangChain, LangGraph, OpenAI models, and `SQLDatabaseToolkit` to generate and execute SQL queries from natural language questions. The workflow initializes a SQLite database, defines tools for querying and schema retrieval, and builds a stateful graph that manages query generation, checking, and execution. It also includes basic error handling and result formatting.

## Installation and Setup

First, install the required libraries and import the modules used throughout the notebook.

The installed packages cover:
- **OpenAI** and **Anthropic** model integrations (only the OpenAI integration is used below).
- **LangChain tools** for database interactions and workflow management.
- **Stateful workflows** with `langgraph`.
- Optional **vector database** support through ChromaDB, which this notebook does not use.

In [None]:
!pip install -qU langchain-openai
!pip install -qU langchain-anthropic
!pip install -qU langchain_community
!pip install -qU langgraph
!pip install -qU chromadb

In [None]:
# Import necessary modules
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import ToolMessage, AIMessage
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langgraph.prebuilt import ToolNode
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import AnyMessage, add_messages
from pydantic import BaseModel, Field
from typing import Annotated, Literal, TypedDict, Any

## Initialize the Database

This tutorial uses a SQLite database, which is lightweight and needs no server setup. We download the Chinook database, a sample database that represents a digital media store, and open it with `SQLDatabase`. Find more information about the database [here](https://www.sqlitetutorial.net/sqlite-sample-database/).

In [None]:
import requests

url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"
response = requests.get(url)

if response.status_code == 200:
    # Open a local file in binary write mode
    with open("Chinook.db", "wb") as file:
        # Write the content of the response (the file) to the local file
        file.write(response.content)
    print("File downloaded and saved as Chinook.db")
else:
    print(f"Failed to download the file. Status code: {response.status_code}")

# Initialize the database
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(f"Database Type: {db.dialect}")
print(f"Tables: {db.get_usable_table_names()}")
print(f"Artist: {db.run('SELECT * FROM Artist LIMIT 5;')}")
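For readers who want to inspect a SQLite file without LangChain, the table listing can be reproduced with Python's built-in `sqlite3` module. The sketch below is only an illustration and is not part of the notebook's workflow. It builds a small in-memory database with two hypothetical tables so that it runs without the downloaded file; connecting to `Chinook.db` instead should list the Chinook tables. The helper name `list_tables` is an assumption made for this example.

```python
import sqlite3


def list_tables(conn: sqlite3.Connection) -> list[str]:
    """Return the names of user-defined tables in alphabetical order."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' "
        "ORDER BY name"
    ).fetchall()
    return [name for (name,) in rows]


# Hypothetical stand-in schema; replace ":memory:" with "Chinook.db" to inspect the real file.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
conn.execute("CREATE TABLE Album (AlbumId INTEGER PRIMARY KEY, Title TEXT, ArtistId INTEGER)")

tables = list_tables(conn)
print(tables)
conn.close()
```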

## Define Tools for the Agent

Next, we define the tools that the agent will use to interact with the database. These tools include querying the database, retrieving the schema, and listing tables.

This code:
1. Loads the OpenAI API key from Kaggle Secrets. The `kaggle_secrets` module is only available on Kaggle; elsewhere, read the key from an environment variable instead.
2. Initializes a `SQLDatabaseToolkit` with the necessary tools for interacting with a SQL database.
3. Extracts specific tools for querying, schema retrieval, table listing, and query checking.
4. Binds the schema tool to a language model so that the model can request table schemas.

This setup enables the agent to:
- Execute SQL queries.
- Retrieve database schema and table information.
- Check SQL queries for correctness.
- Use OpenAI's language model to assist in database interactions.

In [None]:
from kaggle_secrets import UserSecretsClient

# Load OpenAI API Key
my_api_key = UserSecretsClient().get_secret("my-openai-api-key")

# Define tools for the agent
toolkit = SQLDatabaseToolkit(db=db, llm=ChatOpenAI(model="gpt-4o-mini", api_key=my_api_key))
tools = toolkit.get_tools()

# Extract specific tools from the toolkit
# (the loop variable is named `t` so it is not confused with the imported `tool` decorator)
db_query_tool = next(t for t in tools if t.name == "sql_db_query")
db_schema_tool = next(t for t in tools if t.name == "sql_db_schema")
db_list_tables_tool = next(t for t in tools if t.name == "sql_db_list_tables")
db_query_checker_tool = next(t for t in tools if t.name == "sql_db_query_checker")

# Bind the schema tool to a model
db_schema_model = ChatOpenAI(model="gpt-4o-mini", temperature=0, api_key=my_api_key).bind_tools([db_schema_tool])

### Define the Workflow State

We define the state of the workflow, which will keep track of the messages exchanged during the query generation and execution process.

In [None]:
# Define the workflow state
class State(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]
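The `Annotated[..., add_messages]` annotation tells LangGraph to merge the messages returned by each node into the existing list instead of overwriting it. The sketch below imitates that idea in plain Python. `append_reducer` and `apply_update` are hypothetical stand-ins written for this illustration, and messages are simplified to dictionaries; the real `add_messages` also handles details such as message IDs, which are omitted here.

```python
from typing import Callable

Message = dict[str, str]  # simplified stand-in for LangChain message objects
Reducer = Callable[[list[Message], list[Message]], list[Message]]


def append_reducer(existing: list[Message], new: list[Message]) -> list[Message]:
    """Combine the old and new messages without mutating either list."""
    return existing + new


def apply_update(state: dict, update: dict, reducer: Reducer) -> dict:
    """Apply a node's partial update to the state using the given reducer."""
    return {"messages": reducer(state["messages"], update["messages"])}


state = {"messages": [{"role": "user", "content": "How many artists are there?"}]}
update = {"messages": [{"role": "ai", "content": "SELECT COUNT(*) FROM Artist;"}]}

state = apply_update(state, update, append_reducer)
print([m["role"] for m in state["messages"]])
```

This is why every node in the notebook returns only the new messages: the reducer takes care of keeping the history.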

### Define the Initial Tool Node

The initial tool node is responsible for starting the workflow by listing the tables in the database.

In [None]:
# Define the initial tool node
def initial_tool_node(state: State) -> dict[str, list[AIMessage]]:
    """
    Initializes the workflow by creating a tool call to list tables in the database.
    """
    tool_call_id = "tool_abcd123"  # Hardcoded for debugging
    print(f"--- First Tool Call Node ---\nTool Call ID: {tool_call_id}")
    return {
        "messages": [
            AIMessage(
                content="",
                tool_calls=[
                    {
                        "name": "sql_db_list_tables",
                        "args": {},
                        "id": tool_call_id,
                    }
                ],
            )
        ]
    }

### Define the List Tables Node

This node lists all tables in the database and returns the result as a ToolMessage.

In [None]:
# Define the list tables node
def list_tables_node(state: State) -> dict[str, list[ToolMessage]]:
    """
    Lists all tables in the database and returns the result as a ToolMessage.
    """
    print("--- List Tables Node ---")
    result = db_list_tables_tool.invoke({})
    print("Tables in the database:", result)

    # Get the tool_call_id from the previous message
    tool_call_id = state["messages"][-1].tool_calls[0]["id"]
    print(f"Tool Call ID: {tool_call_id}")
    return {"messages": [ToolMessage(content=result, tool_call_id=tool_call_id)]}
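Each `ToolMessage` carries the `tool_call_id` of the tool call it answers, which is why the node copies the ID from the preceding message. Chat model APIs generally expect every tool call in the history to have a matching response. The sketch below checks that pairing on a simplified history made of plain dictionaries. `unanswered_tool_calls` and the dictionary layout are assumptions for illustration, not LangChain APIs.

```python
def unanswered_tool_calls(messages: list[dict]) -> list[str]:
    """Return the IDs of tool calls that have no matching tool response."""
    requested: list[str] = []
    answered: set[str] = set()
    for message in messages:
        for call in message.get("tool_calls", []):
            requested.append(call["id"])
        if message.get("type") == "tool":
            answered.add(message["tool_call_id"])
    return [call_id for call_id in requested if call_id not in answered]


# Hypothetical history: the model issued two tool calls, but only one was answered.
history = [
    {
        "type": "ai",
        "tool_calls": [
            {"id": "call_1", "name": "sql_db_schema"},
            {"id": "call_2", "name": "sql_db_schema"},
        ],
    },
    {"type": "tool", "tool_call_id": "call_1", "content": "CREATE TABLE ..."},
]

missing = unanswered_tool_calls(history)
print(missing)
```

Nodes in this notebook read only `tool_calls[0]`, so a model response with several tool calls would leave some of them unanswered.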

### Define the Model Get Schema Node

This node uses the model to generate a schema request based on the current state.

In [None]:
# Define the model get schema node
def model_get_schema_node(state: State) -> dict[str, list[AIMessage]]:
    """
    Uses the model to generate a schema request based on the current state.
    """
    print("--- Model Get Schema Node ---")
    return {"messages": [db_schema_model.invoke(state["messages"])]}

### Define the Retrieve Schema Node

This node retrieves the schema for the table or tables named in the model's tool call and returns it as a ToolMessage.

In [None]:
# Define the retrieve schema node
def retrieve_schema_node(state: State) -> dict[str, list[ToolMessage]]:
    """
    Retrieves the schema for a specific table and returns it as a ToolMessage.
    """
    print("--- Retrieve Schema Node ---")
    table_name = state["messages"][-1].tool_calls[0]["args"]["table_names"]
    result = db_schema_tool.invoke(table_name)
    print(f"Schema for table '{table_name}':\n{result}")

    # Get the tool_call_id from the previous message
    tool_call_id = state["messages"][-1].tool_calls[0]["id"]
    print(f"Tool Call ID: {tool_call_id}")

    # Return a ToolMessage with the same tool_call_id
    return {"messages": [ToolMessage(content=result, tool_call_id=tool_call_id)]}

### Define the SubmitFinalAnswer Class

This class represents the final answer to be submitted to the user.

In [None]:
# Define the SubmitFinalAnswer class
class SubmitFinalAnswer(BaseModel):
    """
    A Pydantic model representing the final answer to be submitted to the user.
    """
    final_answer: str = Field(..., description="The final answer to the user")

## Define the Query Generation System Prompt

This prompt guides the model in generating SQL queries based on the input question. 

This code sets up a **query generation pipeline**:
1. The model acts as a SQL expert and generates SQL queries based on the input question.
2. The query is generated with strict guidelines to ensure correctness, relevance, and safety.
3. The final answer is submitted using the `SubmitFinalAnswer` tool.

The prompt instructs the model so that:
- SQL queries are **syntactically correct** and select only the relevant columns.
- Results are **limited** to at most 5 rows unless the user asks for more.
- Errors and empty result sets lead to a **rewritten query**.
- No data-modifying statements are issued.

In [None]:
# Define the query generation system prompt
query_gen_system = """You are a SQL expert with a strong attention to detail.

Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.

DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.

When generating the query:

Output the SQL query that answers the input question without a tool call.

Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.

If you get an error while executing a query, rewrite the query and try again.

If you get an empty result set, you should try to rewrite the query to get a non-empty result set.
NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.

If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database."""
query_gen_prompt = ChatPromptTemplate.from_messages(
    [("system", query_gen_system), ("placeholder", "{messages}")]
)
query_gen_chain = query_gen_prompt | ChatOpenAI(model="gpt-4o-mini", temperature=0, api_key=my_api_key).bind_tools([SubmitFinalAnswer])
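The prompt asks the model not to modify data, but a prompt is not an enforcement mechanism. One possible safeguard, which this notebook does not use, is to open the SQLite file in read-only mode so that writes fail at the database level. The sketch below shows the idea with the standard `sqlite3` module and a temporary database file containing a hypothetical `Artist` table. How to pass such a connection to `SQLDatabase` is left out and would need to be checked against the LangChain documentation.

```python
import os
import sqlite3
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "demo.db")

    # Create a small database with a normal read-write connection.
    rw = sqlite3.connect(path)
    rw.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
    rw.execute("INSERT INTO Artist (Name) VALUES ('AC/DC')")
    rw.commit()
    rw.close()

    # Reopen the same file read-only through a SQLite URI.
    ro = sqlite3.connect(Path(path).as_uri() + "?mode=ro", uri=True)
    rows = ro.execute("SELECT Name FROM Artist").fetchall()

    try:
        ro.execute("DELETE FROM Artist")
        write_blocked = False
    except sqlite3.OperationalError:
        # SQLite refuses to write through a read-only connection.
        write_blocked = True
    ro.close()

print(rows, write_blocked)
```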

### Define the Query Generation Node

This node generates a SQL query based on the current state and returns the result.

1. Validates the input state to ensure it contains a `ToolMessage`.
2. Generates an SQL query using the `query_gen_chain`.
3. Handles errors if the wrong tool is called during query generation.
4. Returns the generated query or error messages as part of the workflow state.

This node plays a critical role in the workflow by ensuring that SQL queries are generated correctly and that errors are handled gracefully.

In [None]:
# Define the query generation node
def query_gen_node(state: State):
    """
    Generates a SQL query based on the current state and returns the result.
    """
    print("--- Query Gen Node ---")
    # Ensure the last message is a ToolMessage
    if isinstance(state["messages"][-1], ToolMessage):
        tool_call_id = state["messages"][-1].tool_call_id
        print(f"Tool Call ID from previous message: {tool_call_id}")
    else:
        raise ValueError("Expected a ToolMessage as the last message.")

    # Generate the query
    message = query_gen_chain.invoke(state)
    tool_messages = []
    if message.tool_calls:
        for tc in message.tool_calls:
            if tc["name"] != "SubmitFinalAnswer":
                tool_messages.append(
                    ToolMessage(
                        content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.",
                        tool_call_id=tc["id"],
                    )
                )
    return {"messages": [message] + tool_messages}

### Define the Should Continue Function

This function determines the next step in the workflow based on the current state.

In [None]:
# Define the should_continue function
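def should_continue(state: State) -> Literal[END, "correct_query_node", "query_gen_node"]:
    """
    Determines the next step in the workflow based on the current state.
    The returned strings must match the node names registered in the graph.
    """
    messages = state["messages"]
    last_message = messages[-1]
    # A tool call at this point is SubmitFinalAnswer, so the workflow ends
    if getattr(last_message, "tool_calls", None):
        return END
    # An error message sends the workflow back to query generation
    if last_message.content.startswith("Error:"):
        return "query_gen_node"
    else:
        return "correct_query_node"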
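The routing logic can be exercised in isolation. The sketch below uses plain dictionaries in place of message objects and returns the node names as they are registered in the workflow section. `route` and the `"__end__"` constant are stand-ins for the notebook's function and LangGraph's `END`, written only for this illustration.

```python
END = "__end__"  # stand-in for langgraph.graph.END


def route(last_message: dict) -> str:
    """Pick the next node from the last message, mirroring the three cases above."""
    if last_message.get("tool_calls"):
        return END  # the model submitted a final answer
    if last_message.get("content", "").startswith("Error:"):
        return "query_gen_node"  # ask the model to try again
    return "correct_query_node"  # plain text is treated as a SQL query to check


final = route({"content": "", "tool_calls": [{"name": "SubmitFinalAnswer"}]})
retry = route({"content": "Error: The wrong tool was called"})
check = route({"content": "SELECT Name FROM Artist LIMIT 5;"})
print(final, retry, check)
```

When the possible destinations are passed to `add_conditional_edges` as a list, the value returned by the routing function has to be one of those names, so a mismatch between the two only shows up at run time.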

### Define the DB Statement Execution Tool

This tool executes a SQL query against the database and returns the result.

In [None]:
# Define the db_stmt_exec_tool function
@tool
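def db_stmt_exec_tool(query: str) -> str:
    """
    Execute a SQL query against the database and return the result.
    If the query fails or returns no rows, return an error message.
    """
    # run_no_throw returns the error text instead of raising when the query fails
    result = db.run_no_throw(query)
    # An empty result is also reported as an error so the model rewrites the query
    if not result:
        return "Error: Query failed. Please rewrite your query and try again."
    return result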
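The tool relies on `run_no_throw`, which is understood to return an error string instead of raising and an empty string when a query yields no rows. The sketch below mimics that behavior with the standard `sqlite3` module so the three outcomes can be compared side by side. `run_query_no_throw` and the `Genre` table are hypothetical and only approximate what LangChain does internally.

```python
import sqlite3


def run_query_no_throw(conn: sqlite3.Connection, query: str) -> str:
    """Run a query and return rows as a string, an empty string, or an error text."""
    try:
        rows = conn.execute(query).fetchall()
    except sqlite3.Error as exc:
        return f"Error: {exc}"
    return str(rows) if rows else ""


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Genre (GenreId INTEGER PRIMARY KEY, Name TEXT)")
conn.execute("INSERT INTO Genre (Name) VALUES ('Rock')")

ok = run_query_no_throw(conn, "SELECT Name FROM Genre")
empty = run_query_no_throw(conn, "SELECT Name FROM Genre WHERE Name = 'Jazz'")
failed = run_query_no_throw(conn, "SELECT Nme FROM Genre")  # misspelled column
conn.close()

print(repr(ok), repr(empty), repr(failed))
```

Under these assumptions, a valid query with no matching rows and a broken query both end up as an `Error:` message in the notebook's tool, which prompts the model to rewrite the query in either case.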

### Define the Query Check System Prompt

This prompt guides the model in checking the SQL query for common mistakes.

1. The model acts as a SQL expert and checks the query for common mistakes.
2. If mistakes are found, the query is rewritten; otherwise, the original query is used.
3. The validated query is passed to the `db_stmt_exec_tool` for execution.

This reduces the risk that a malformed query reaches the database, although an LLM-based check cannot guarantee that every executed query is **correct and safe**.

In [None]:
# Define the query check system prompt
query_check_system = """You are a SQL expert with a strong attention to detail.
Double check the SQLite query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

You will call the appropriate tool to execute the query after running this check."""

query_check_prompt = ChatPromptTemplate.from_messages(
    [("system", query_check_system), ("placeholder", "{messages}")]
)
query_check_chain = query_check_prompt | ChatOpenAI(model="gpt-4o-mini", temperature=0, api_key=my_api_key).bind_tools([db_stmt_exec_tool], tool_choice="required")
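The first item on the checklist, `NOT IN` with `NULL` values, is easy to reproduce. The sketch below uses an in-memory SQLite database with a hypothetical `Employee` table (loosely modeled on the sample database) to find employees who manage nobody. The naive query silently returns no rows because the subquery contains a `NULL`; filtering the `NULL` out gives the intended answer.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, ReportsTo INTEGER);
    INSERT INTO Employee VALUES (1, NULL), (2, 1), (3, 1);
    """
)

# Naive version: the subquery yields NULL for employee 1, so NOT IN is never true.
naive = conn.execute(
    "SELECT EmployeeId FROM Employee "
    "WHERE EmployeeId NOT IN (SELECT ReportsTo FROM Employee) "
    "ORDER BY EmployeeId"
).fetchall()

# Fixed version: exclude NULLs from the subquery.
fixed = conn.execute(
    "SELECT EmployeeId FROM Employee "
    "WHERE EmployeeId NOT IN "
    "(SELECT ReportsTo FROM Employee WHERE ReportsTo IS NOT NULL) "
    "ORDER BY EmployeeId"
).fetchall()
conn.close()

print(naive, fixed)
```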

### Define the Correct Query Node

This node corrects the SQL query if necessary and returns the corrected query.

In [None]:
# Define the correct query node
def correct_query_node(state: State) -> dict[str, list[AIMessage]]:
    """
    Corrects the SQL query if necessary and returns the corrected query.
    """
    print("--- Correct Query Node ---")
    return {"messages": [query_check_chain.invoke({"messages": [state["messages"][-1]]})]}

### Define the Execute Query Node

This node executes the SQL query and returns the result as a ToolMessage.

In [None]:
# Define the execute query node
def execute_query_node(state: State) -> dict[str, list[ToolMessage]]:
    """
    Executes the SQL query and returns the result as a ToolMessage.
    """
    print("--- Execute Query Node ---")
    try:
        query = state["messages"][-1].tool_calls[0]["args"]["query"]
        result = db_stmt_exec_tool.invoke(query)
        print(f"Query Results:\n{result}")
    except Exception as e:
        result = f"Error: {str(e)}"
        print(result)

    # Get the tool_call_id from the previous message
    tool_call_id = state["messages"][-1].tool_calls[0]["id"]
    print(f"Tool Call ID: {tool_call_id}")
    return {"messages": [ToolMessage(content=result, tool_call_id=tool_call_id)]}

## Define the Workflow

We define the workflow by adding nodes and edges to the state graph. The result is a **stateful, step-by-step workflow** for generating and executing SQL queries. The workflow:
1. Starts by listing database tables.
2. Retrieves the schema for a specific table.
3. Generates and corrects SQL queries.
4. Executes the queries and handles errors or corrections.
5. Visualizes the entire workflow as a diagram for better understanding.

The workflow is designed to be modular, with each node handling a specific task, and conditional edges ensuring the correct flow based on the state of the process.

In [None]:
# Define the workflow
workflow = StateGraph(State)

# Add nodes
workflow.add_node("initial_tool_node", initial_tool_node)
workflow.add_node("list_tables_node", list_tables_node)
workflow.add_node("model_get_schema_node", model_get_schema_node)
workflow.add_node("retrieve_schema_node", retrieve_schema_node)
workflow.add_node("query_gen_node", query_gen_node)
workflow.add_node("correct_query_node", correct_query_node)
workflow.add_node("execute_query_node", execute_query_node)

# Add edges (the conditional edge after query_gen_node is decided by should_continue)
workflow.add_edge(START, "initial_tool_node")
workflow.add_edge("initial_tool_node", "list_tables_node")
workflow.add_edge("list_tables_node", "model_get_schema_node")
workflow.add_edge("model_get_schema_node", "retrieve_schema_node")
workflow.add_edge("retrieve_schema_node", "query_gen_node")
workflow.add_conditional_edges("query_gen_node", should_continue, [END, "correct_query_node", "query_gen_node"])
workflow.add_edge("correct_query_node", "execute_query_node")
workflow.add_edge("execute_query_node", "query_gen_node")

app = workflow.compile()

# Visualize the graph
from IPython.display import Image, display
from langchain_core.runnables.graph import MermaidDrawMethod

display(
    Image(
        app.get_graph().draw_mermaid_png(
            draw_method=MermaidDrawMethod.API,
        )
    )
)
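To make the control flow easier to follow without rendering the diagram, the sketch below walks the same edges in plain Python. It is a simplified model rather than LangGraph itself: `EDGES` copies the fixed edges from the cell above, and the hypothetical `trace` function takes the sequence of decisions made at the conditional edge after `query_gen_node`.

```python
EDGES = {
    "initial_tool_node": "list_tables_node",
    "list_tables_node": "model_get_schema_node",
    "model_get_schema_node": "retrieve_schema_node",
    "retrieve_schema_node": "query_gen_node",
    "correct_query_node": "execute_query_node",
    "execute_query_node": "query_gen_node",
}


def trace(decisions: list[str]) -> list[str]:
    """Return the visited nodes, given successive outcomes of the conditional edge."""
    path: list[str] = []
    node = "initial_tool_node"
    remaining = iter(decisions)
    while node != "END":
        path.append(node)
        if node == "query_gen_node":
            # Conditional edge: stop when no further decision is supplied.
            node = next(remaining, "END")
        else:
            node = EDGES[node]
    return path


# One round of checking and execution, then a final answer.
path = trace(["correct_query_node", "END"])
print(" -> ".join(path))
```

The trace shows that `query_gen_node` is visited twice in this scenario: once to produce the query and once, after execution, to turn the result into a final answer.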

## Define Helper Functions for Query Execution

We define helper functions to extract, execute, and format SQL queries and their results.

### Key Workflow:

- The helpers consume the events streamed by the workflow, which generates SQL queries from natural language questions.
- They extract the query from the final answer, execute it on the database, and format the results for easy interpretation.
- Query execution and result formatting are wrapped in error handling so that a single failure does not stop the notebook.

These functions are used in the final section to run the workflow on several user questions.

In [None]:
from typing import Optional

from openai import BadRequestError

def extract_sql_query(final_answer: str) -> Optional[str]:
    """
    Extracts the SQL query from the final_answer string.
    
    Args:
        final_answer (str): The final answer string that may contain an SQL query.
    
    Returns:
        str: The extracted SQL query if found, otherwise None.
    """
    if "```sql" in final_answer:
        return final_answer.split("```sql")[1].split("```")[0].strip()
    return None

def execute_sql_query(sql_query: str):
    """
    Executes the SQL query and returns the results.
    
    Args:
        sql_query (str): The SQL query to execute.
    
    Returns:
        Any: The results of the SQL query execution, or None if an error occurs.
    """
    try:
        results = db.run(sql_query)
        return results
    except Exception as e:
        print("Error executing SQL query:", e)
        return None

def format_results(results) -> str:
    """
    Formats the query results into a human-readable string.
    
    Args:
        results (Any): The results of the SQL query execution.
    
    Returns:
        str: A formatted string representing the query results.
    """
    if isinstance(results, str):
        # db.run returns a string by default, so this is the usual path
        return results

    # Generic header: this helper is called for every question, not only sales per country
    formatted_results = "Query results:\n"
    try:
        for row in results:
            # Handle cases where row is a tuple or list
            if isinstance(row, (tuple, list)) and len(row) >= 2:
                formatted_results += f"- {row[0]}: ${row[1]:.2f}\n"
            else:
                # Handle unexpected row formats
                formatted_results += f"- {row}\n"
    except Exception as e:
        print("Error formatting results:", e)
        return str(results)  # Fallback: return results as a string

    return formatted_results

def process_event(event):
    """
    Processes the event to extract, execute, and print the final answer.
    
    Args:
        event (dict): The event containing the state of the workflow.
    """
    if "query_gen_node" not in event:
        return

    query_gen_state = event["query_gen_node"]
    if "messages" not in query_gen_state:
        return

    last_message = query_gen_state["messages"][-1]
    if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
        return

    final_answer = last_message.tool_calls[0]["args"]["final_answer"]
    sql_query = extract_sql_query(final_answer)

    if not sql_query:
        print("No SQL query found in the final answer. The agent provided the following response:")
        print(final_answer)
        return

    print(f"Extracted SQL Query:\n{sql_query}\n")
    results = execute_sql_query(sql_query)

    if results:
        formatted_results = format_results(results)
        print("Final Answer:", formatted_results)
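As far as we can tell, `db.run` returns the rows as a string by default (the printed form of a list of tuples), so `format_results` normally takes its early-return branch. If per-row formatting is wanted, one option is to parse that string first. The sketch below assumes the string is a valid Python literal, which will not hold for every value type (dates, for example); `parse_rows`, `format_rows`, and the sample values are hypothetical.

```python
import ast


def parse_rows(raw: str) -> list[tuple]:
    """Parse a string such as "[('USA', 1.5)]" into a list of tuples, or [] on failure."""
    if not raw:
        return []
    try:
        value = ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        return []
    if not isinstance(value, list):
        return []
    return [tuple(row) for row in value if isinstance(row, (tuple, list))]


def format_rows(rows: list[tuple]) -> str:
    """Render each row as a bullet with its values separated by ': '."""
    return "\n".join("- " + ": ".join(str(v) for v in row) for row in rows)


raw = "[('USA', 523.06), ('Canada', 303.96)]"  # illustrative values only
rows = parse_rows(raw)
formatted = format_rows(rows)
print(formatted)
```

`ast.literal_eval` only accepts Python literals, so it does not execute arbitrary code from the result string.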

## Execute Queries Based on User Questions

Finally, we execute the workflow for different user questions and print the results.

In [None]:
try:
    question = "What is the total sales amount per country?"
    for event in app.stream({"messages": [("user", question)]}):
        process_event(event)
except BadRequestError as e:
    print(f"Error processing question: {question}")
    print(f"Error details: {e}")

In [None]:
try:
    question = "What is the total sales amount per genre?"
    for event in app.stream({"messages": [("user", question)]}):
        process_event(event)
except BadRequestError as e:
    print(f"Error processing question: {question}")
    print(f"Error details: {e}")

In [None]:
try:
    question = "How many tracks are in each playlist?"
    for event in app.stream({"messages": [("user", question)]}):
        process_event(event)
except BadRequestError as e:
    print(f"Error processing question: {question}")
    print(f"Error details: {e}")

In [None]:
try:
    question = "Which 5 artists have the most tracks in the database?"
    for event in app.stream({"messages": [("user", question)]}):
        process_event(event)
except BadRequestError as e:
    print(f"Error processing question: {question}")
    print(f"Error details: {e}")

In [None]:
try:
    question = "Who are the top 3 customers by total spending?"
    for event in app.stream({"messages": [("user", question)]}):
        process_event(event)
except BadRequestError as e:
    print(f"Error processing question: {question}")
    print(f"Error details: {e}")
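The cells above repeat the same try/except pattern for each question. One way to reduce the repetition is a small loop that takes the streaming function and the event handler as arguments. The sketch below is self-contained and uses a fake stream to simulate one success and one failure; `run_questions` and `fake_stream` are hypothetical names. In the notebook, the call would presumably look like `run_questions(questions, app.stream, process_event, BadRequestError)`.

```python
from typing import Callable, Iterable, Iterator


def run_questions(
    questions: Iterable[str],
    stream: Callable[[dict], Iterator[dict]],
    handle_event: Callable[[dict], None],
    expected_error: type = Exception,
) -> dict[str, str]:
    """Run each question through the stream and record whether it succeeded."""
    outcomes: dict[str, str] = {}
    for question in questions:
        try:
            for event in stream({"messages": [("user", question)]}):
                handle_event(event)
            outcomes[question] = "ok"
        except expected_error as exc:
            outcomes[question] = f"failed: {exc}"
    return outcomes


def fake_stream(payload: dict) -> Iterator[dict]:
    """Stand-in for app.stream: fails for one question, yields one event otherwise."""
    question = payload["messages"][0][1]
    if "genre" in question:
        raise ValueError("simulated API error")
    yield {"query_gen_node": {"messages": []}}


seen: list[dict] = []
outcomes = run_questions(
    ["Total sales per country?", "Total sales per genre?"],
    fake_stream,
    seen.append,
    ValueError,
)
print(outcomes)
```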

## Conclusion

This notebook demonstrates a workflow for generating and executing SQL queries from natural language questions. By combining LangChain, LangGraph, OpenAI models, and `SQLDatabaseToolkit`, the system inspects the database schema, generates queries, checks them for common mistakes, and formats results for easy interpretation.

This workflow can be applied to various real-world scenarios, such as:
- **Business Intelligence**: Automating the generation of reports and insights from databases.
- **Data Exploration**: Enabling non-technical users to query databases using natural language.
- **Customer Support**: Providing automated answers to customer queries based on database data.
- **Education**: Teaching SQL concepts by translating natural language questions into queries.