## App querying SQLite database using LangChain's SQLDatabaseChain

In [1]:
# [REQUEST]
#   Create sample code to query SQLite database by LLM, LangChain and SQLDatabaseChain?
# [RESPONSE]
#   I'll provide a comprehensive example of querying a SQLite database using
#   LangChain's SQLDatabaseChain with step-by-step explanations:

# [Install required packages]
# pip install langchain langchain-community langchain-experimental langchain-openai langgraph python-dotenv sqlalchemy

### Links
- [**SQLDatabase** (LangChain v0.1)](https://python.langchain.com/v0.1/docs/integrations/toolkits/sql_database/)
- [**SQLDatabase Toolkit** (LangChain latest)](https://python.langchain.com/docs/integrations/tools/sql_database/)

In [2]:
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langchain_experimental.sql import SQLDatabaseChain

### Set up OpenAI API Key

In [3]:
# Set up OpenAI API Key
# if "OPENAI_API_KEY" not in os.environ:
#     os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter OpenAI API Key: ")

In [4]:
from dotenv import load_dotenv, find_dotenv
# Load environment variables (e.g. OPENAI_API_KEY) from a .env file located by
# find_dotenv(); the return value is bound to `_` because it is not used.
_ = load_dotenv(find_dotenv())
# openai_api_key = os.environ["OPENAI_API_KEY"]

# Alternative kept for reference (presumably lets .env values take precedence
# over variables already set in the environment — confirm in python-dotenv docs):
# load_dotenv(override = True)
# api_key = os.getenv('OPENAI_API_KEY')

# OpenAI chat model name used when ChatOpenAI is instantiated further down.
MODEL_GPT="gpt-4o-mini"

### 1. Create a connection to the SQLite database

In [5]:
# 1. Create a connection to the SQLite database
# Generic form of the call, with a placeholder path:
# db = SQLDatabase.from_uri("sqlite:///path/to/your/database.db")
# The path used here is relative, so presumably it depends on the directory the
# notebook is run from — adjust it if the notebook or the data folder moves.
sqlite_db_path = "../../data/street_tree_db.sqlite"
db = SQLDatabase.from_uri(f"sqlite:///{sqlite_db_path}")

### 2. Initialize the Language Model

In [6]:
# 2. Initialize the Language Model
# temperature=0 asks for the most deterministic responses; replace MODEL_GPT
# with another OpenAI chat model (e.g. "gpt-3.5-turbo") if preferred.
llm = ChatOpenAI(model=MODEL_GPT, temperature=0)

### 3. Create the SQL Database Chain

In [7]:
# 3. Create the SQL Database Chain
# verbose=True enables detailed output of the chain's reasoning.
# NOTE(review): in the run recorded further down, the model wrapped its SQL in a
# ```sql Markdown fence and that text reached SQLite unchanged, which raised
# sqlite3.OperationalError ("syntax error"). Consider a prompt that forbids
# fences, or a step that strips them before the query is executed.
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)

### 4. Example queries

In [8]:
# 4. Example queries
def query_database(question):
    """
    Ask the module-level ``db_chain`` a natural-language question.

    Args:
        question (str): Natural language question about the database

    Returns:
        Whatever ``db_chain.invoke`` returns for the question (presumably a
        dict holding the question and the chain's answer — verify against the
        chain's documentation), or ``None`` when the chain raises.

    Note:
        Any exception is caught, reported on stdout and turned into ``None``,
        so callers have to check for ``None`` themselves.
    """
    # Previously: db_chain.run(question)
    try:
        return db_chain.invoke(question)
    except Exception as exc:
        print(f"An error occurred: {exc}")
    return None

### 5. Example usage

In [9]:
# 5. Example usage
if __name__ == "__main__":
    # Questions are phrased in natural language. For the one asked below, the
    # SQL one would hope for looks like:
    #   SELECT COUNT(DISTINCT "qSpecies") AS "SpeciesCount" FROM street_trees;
    #
    # Generic examples from the original template (for a Customers/products
    # style database), kept for reference:
    #   "What is the total number of records in the Customers table?"
    #   "List the top 3 most expensive products"
    #   "Find the average price of products in each category"
    queries = ["How many species of trees are in San Francisco?"]

    # Ask every question in turn and show what came back (None means the
    # helper reported an error).
    for query in queries:
        print(f"\nQuery: {query}")
        result = query_database(query)
        print(f"Result: {result}")


Query: How many species of trees are in San Francisco?


[1m> Entering new SQLDatabaseChain chain...[0m
How many species of trees are in San Francisco?
[32;1m[1;3m```sql
SELECT DISTINCT "qSpecies" FROM street_trees;
```[0mAn error occurred: (sqlite3.OperationalError) near "```sql
SELECT DISTINCT "qSpecies" FROM street_trees;
```": syntax error
[SQL: ```sql
SELECT DISTINCT "qSpecies" FROM street_trees;
```]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
Result: None


## [chatbot-database.py]

In [10]:
# # Import required libraries
# from langchain_community.utilities import SQLDatabase
# from langchain_openai import ChatOpenAI
# from langchain_experimental.sql import SQLDatabaseChain

# # 1. Create a connection to the SQLite database
# # Replace 'path/to/your/database.db' with the actual path to your SQLite database
# db = SQLDatabase.from_uri("sqlite:///path/to/your/database.db")

# # 2. Initialize the Language Model
# # You can adjust the temperature and model as needed
# llm = ChatOpenAI(
#     model="gpt-3.5-turbo",  # Choose your preferred OpenAI model
#     temperature=0  # Set to 0 for more deterministic responses
# )

# # 3. Create the SQL Database Chain
# db_chain = SQLDatabaseChain.from_llm(
#     llm, 
#     db, 
#     verbose=True  # Enables detailed output of the chain's reasoning
# )

# # 4. Example queries
# def query_database(question):
#     """
#     Helper function to run queries against the database
    
#     Args:
#         question (str): Natural language question about the database
    
#     Returns:
#         The answer to the query
#     """
#     try:
#         result = db_chain.run(question)
#         return result
#     except Exception as e:
#         print(f"An error occurred: {e}")
#         return None

# # 5. Example usage
# if __name__ == "__main__":
#     # Example queries
#     queries = [
#         "What is the total number of records in the Customers table?",
#         "List the top 3 most expensive products",
#         "Find the average price of products in each category"
#     ]
    
#     # Run each query and print results
#     for query in queries:
#         print(f"\nQuery: {query}")
#         result = query_database(query)
#         print(f"Result: {result}")

# # [Error Handling Example]
# # def safe_query_database(question):
# #     try:
# #         # Add input sanitization if needed
# #         if not question or len(question) > 500:
# #             raise ValueError("Invalid query length")
        
# #         result = db_chain.run(question)
# #         return result
# #     except Exception as e:
# #         print(f"Query failed: {e}")
# #         # Log the error or handle it appropriately
# #         return None


## Build a Question/Answering system over SQL data
- https://python.langchain.com/docs/tutorials/sql_qa/
- **Chinook Database**
  - https://github.com/lerocha/chinook-database
  - https://github.com/lerocha/chinook-database/tree/master/ChinookDatabase/DataSources
  - https://github.com/lerocha/chinook-database/blob/master/ChinookDatabase/DataSources/Chinook_Sqlite.sqlite
  - https://github.com/lerocha/chinook-database/blob/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql
  - https://www.sqlitetutorial.net/sqlite-sample-database/

### Setup

#### Sample data

In [11]:
from langchain_community.utilities import SQLDatabase

# Re-point `db` at the Chinook sample database (relative path).
# Other sources tried in this cell: "sqlite:///Chinook.db" and
# "../../data/street_tree_db.sqlite" (with "SELECT * FROM street_trees LIMIT 1;").
sqlite_db_path = "../../data/Chinook.sqlite"
db = SQLDatabase.from_uri(f"sqlite:///{sqlite_db_path}")

# Sanity check: SQL dialect, the tables LangChain can see, and a small sample query.
print(db.dialect)
print(db.get_usable_table_names())

result = db.run("SELECT * FROM Artist LIMIT 10;")
print(type(result))  # the recorded output shows <class 'str'>, i.e. rows come back as text
print(result)

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
<class 'str'>
[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]


### Chains

#### Application state

In [12]:
from typing_extensions import TypedDict

class State(TypedDict):
    """State dictionary passed between the steps of the question-answering graph."""

    # Natural-language question supplied when the graph is started.
    question: str
    # SQL text produced by write_query.
    query: str
    # Output of running the query, as returned by execute_query.
    result: str
    # Final natural-language answer produced by generate_answer.
    answer: str

#### Convert question to SQL query

In [13]:
import getpass
import os

from langchain.chat_models import init_chat_model

# Prompt for the key only when the environment does not already provide a
# non-empty OPENAI_API_KEY.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter API key for OpenAI: ")

# Rebinds `llm`: from here on the notebook uses this chat model instance.
llm = init_chat_model("gpt-4o-mini", model_provider="openai")

In [14]:
from langchain import hub

# Pull the SQL-generation prompt from the LangChain Hub (presumably needs
# network access — confirm for offline use).
query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")

# Sanity check: the template is expected to contain exactly one message,
# which is then printed so its placeholders are visible.
assert len(query_prompt_template.messages) == 1
query_prompt_template.messages[0].pretty_print()


Given an input question, create a syntactically correct [33;1m[1;3m{dialect}[0m query to run to help find the answer. Unless the user specifies in his question a specific number of examples they wish to obtain, always limit your query to at most [33;1m[1;3m{top_k}[0m results. You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

Only use the following tables:
[33;1m[1;3m{table_info}[0m

Question: [33;1m[1;3m{input}[0m


In [15]:
from typing_extensions import Annotated

# Schema handed to llm.with_structured_output() in write_query.
# NOTE(review): the docstring and the Annotated description below are presumably
# forwarded to the model as part of the schema — treat them as runtime text, not
# as ordinary comments, and confirm before rewording them.
class QueryOutput(TypedDict):
    """Generated SQL query."""

    query: Annotated[str, ..., "Syntactically valid SQL query."]

def write_query(state: State):
    """Turn the question in ``state`` into a SQL query.

    The hub prompt is filled with the database dialect, a row limit of 10,
    the table descriptions reported by ``db`` and ``state["question"]``; the
    model is then asked for output shaped like ``QueryOutput``.

    Args:
        state: Graph state; only the ``"question"`` key is read.

    Returns:
        dict: ``{"query": <generated SQL text>}``.
    """
    prompt_inputs = {
        "dialect": db.dialect,
        "top_k": 10,
        "table_info": db.get_table_info(),
        "input": state["question"],
    }
    filled_prompt = query_prompt_template.invoke(prompt_inputs)
    generated = llm.with_structured_output(QueryOutput).invoke(filled_prompt)
    return {"query": generated["query"]}

In [16]:
# Try the step on its own; only the "question" key of the state is read here.
write_query({"question": "How many Employees are there?"})

{'query': 'SELECT COUNT(*) AS TotalEmployees FROM Employee;'}

#### Execute query

In [17]:
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool

def execute_query(state: State):
    """Execute the SQL query in ``state["query"]`` if it is read-only.

    The query text is written by a language model, so it is treated as
    untrusted: only statements that start with ``SELECT`` or ``WITH`` and
    contain no data- or schema-modifying keyword are passed to the database.
    Anything else is refused without being executed.

    Args:
        state: Graph state; only the ``"query"`` key is read.

    Returns:
        dict: ``{"result": <tool output>}`` for an accepted query, or
        ``{"result": "Error: ..."}`` describing the refusal.
    """
    import re  # local import so this cell does not depend on another cell's imports

    sql = state["query"]

    # Blank out string literals, quoted identifiers and comments in a single
    # pass, so keywords inside them can neither trigger nor dodge the checks.
    skeleton = re.sub(
        r"'(?:[^']|'')*'"
        r'|"(?:[^"]|"")*"'
        r"|`[^`]*`"
        r"|\[[^\]]*\]"
        r"|--[^\n]*"
        r"|/\*.*?\*/",
        " ",
        sql,
        flags=re.S,
    )
    keywords = re.findall(r"[a-z_][a-z0-9_]*", skeleton.lower())
    write_keywords = {
        "insert", "update", "delete", "drop", "alter",
        "create", "truncate", "attach", "detach",
    }
    # WITH is allowed for CTE-based reads, but a CTE can also prefix a write
    # statement, hence the additional keyword scan.
    read_only = (
        bool(keywords)
        and keywords[0] in ("select", "with")
        and write_keywords.isdisjoint(keywords)
    )
    if not read_only:
        return {"result": f"Error: refusing to execute non-read-only SQL: {sql!r}"}

    execute_query_tool = QuerySQLDatabaseTool(db=db)
    return {"result": execute_query_tool.invoke(sql)}

In [18]:
# Try the step on its own with a hand-written query; only the "query" key is read.
execute_query({"query": "SELECT COUNT(EmployeeId) AS EmployeeCount FROM Employee;"})

{'result': '[(8,)]'}

#### Generate answer

In [19]:
def generate_answer(state: State):
    """Produce the final answer from the question, the SQL and its result.

    Args:
        state: Graph state; the ``"question"``, ``"query"`` and ``"result"``
            keys are read.

    Returns:
        dict: ``{"answer": <content of the model's reply>}``.
    """
    instructions = (
        "Given the following user question, corresponding SQL query, "
        "and SQL result, answer the user question.\n\n"
    )
    context = (
        f'Question: {state["question"]}\n'
        f'SQL Query: {state["query"]}\n'
        f'SQL Result: {state["result"]}'
    )
    reply = llm.invoke(instructions + context)
    return {"answer": reply.content}

#### Orchestrating with LangGraph

In [20]:
from langgraph.graph import START, StateGraph

# Wire the three steps into a linear pipeline:
# write_query -> execute_query -> generate_answer.
graph_builder = StateGraph(State)
graph_builder.add_sequence([write_query, execute_query, generate_answer])

# The graph starts at the first step of the sequence.
graph_builder.add_edge(START, "write_query")
graph = graph_builder.compile()

In [21]:
# LangGraph also comes with built-in utilities for visualizing the control flow of your application:
# from IPython.display import Image, display

# display(Image(graph.get_graph().draw_mermaid_png()))

In [22]:
# Run the whole pipeline for one question. With stream_mode="updates" the
# recorded output shows one dict per step, keyed by the step's name and holding
# that step's state update.
for step in graph.stream(
    {"question": "How many employees are there?"}, stream_mode="updates"
):
    print(step)

{'write_query': {'query': 'SELECT COUNT(*) AS NumberOfEmployees FROM Employee;'}}
{'execute_query': {'result': '[(8,)]'}}
{'generate_answer': {'answer': 'There are 8 employees.'}}


#### Human-in-the-loop (with user approval)

In [23]:
from langgraph.checkpoint.memory import MemorySaver

# Recompile the same builder with a checkpointer and ask it to pause before
# the "execute_query" step; this rebinds `graph`.
memory = MemorySaver()
graph = graph_builder.compile(
    checkpointer=memory,
    interrupt_before=["execute_query"],
)

# With persistence enabled, a thread ID identifies the run so that it can be
# continued after the review.
config = {"configurable": {"thread_id": "1"}}

In [24]:
# LangGraph also comes with built-in utilities for visualizing the control flow of your application:
# from IPython.display import Image, display

# display(Image(graph.get_graph().draw_mermaid_png()))

In [25]:
# First pass: the graph was compiled with interrupt_before=["execute_query"],
# so this run pauses once the SQL has been written.
for step in graph.stream(
    {"question": "How many employees are there?"},
    config,
    stream_mode="updates",
):
    print(step)

try:
    user_approval = input("Do you want to go to execute query? (yes/no): ")
except Exception:
    # No usable input (e.g. stdin unavailable): treat it as a refusal.
    user_approval = "no"

# Normalise the reply so stray whitespace or a short "y" still counts as
# approval instead of silently cancelling the run.
if user_approval.strip().lower() in {"yes", "y"}:
    # If approved, continue the graph execution; passing None as input resumes
    # the run identified by `config`.
    for step in graph.stream(None, config, stream_mode="updates"):
        print(step)
else:
    print("Operation cancelled by user.")

{'write_query': {'query': 'SELECT COUNT(*) AS EmployeeCount FROM Employee;'}}
{'__interrupt__': ()}


Do you want to go to execute query? (yes/no):  yes


{'execute_query': {'result': '[(8,)]'}}
{'generate_answer': {'answer': 'There are 8 employees.'}}


### Agents [TODO:]
- https://python.langchain.com/docs/tutorials/sql_qa/#agents