In [1]:
import streamlit as st
from langchain_community.utilities import SQLDatabase
from langchain.agents import initialize_agent, Tool
from langchain.agents import AgentType
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, StateGraph
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
import re
from langchain.prompts import PromptTemplate
from langchain_ollama import OllamaLLM
from typing_extensions import TypedDict


db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [2]:



class State(TypedDict):
    question: str
    query: str
    result: str
    answer: str

In [3]:


# Load Mistral from Ollama
llm = OllamaLLM(model="mistral")


In [4]:



query_prompt_template = PromptTemplate(
    input_variables=["query", "dialect", "top_k"],
    template="""Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.


**Instructions:**
- You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
- DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
- **Always** use column names explicitly (e.g., `COUNT(EmployeeId) AS EmployeeCount` instead of `COUNT(*)`).
- Do **not** explain the query.
- Do **not** include any additional text.
- Do **not** describe database tables unless explicitly asked.
- Only return the final SQL query.
- Always limit results to **{top_k}** unless specified otherwise.
- Do **not** add anything to the end of the table name (e.g 'Employee' instead of 'Employees'

To start you should ALWAYS look at the tables in the database to see what you can query.
Do NOT skip this step.
Then you should query the schema of the most relevant tables.



**User Question:** {query}

**SQL Query Output:**
```sql
"""
)

In [5]:


def extract_sql_query(text):
    """Extract SQL query from the LLM response."""
    match = re.search(r"```sql\n(.*?)\n```", text, re.DOTALL)
    return match.group(1) if match else text.strip()

def write_query(state):
    """Generate a SQL query using the refined prompt."""
    prompt = query_prompt_template.format(
        query=state["question"],
        dialect=db.dialect,
        top_k=5
    )

    result = llm.invoke(prompt)

    return {"query": extract_sql_query(result)}


In [6]:
write_query({"question": "How many Employees are there?"})  


ResponseError: <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN"
"http://www.w3.org/TR/html4/loose.dtd">
<html>
<head>
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8">
<title>Notification: Proxy Authorization Required</title>
<style type="text/css">
body {
  font-family: Arial, Helvetica, sans-serif;
  font-size: 14px;
  color:#333333;
  background-color: #ffffff;
}
h1 {
  font-size: 18px;
  font-weight: bold;
  text-decoration: none;
  padding-top: 0px;
  color: #2970A6;
}
a:link {
    color: #2970A6;
  text-decoration: none;
}
a:hover {
    color: #2970A6;
  text-decoration: underline;
}
p.buttonlink {
  margin-bottom: 24px;
}
.copyright {
  font-size: 12px;
  color: #666666;
  margin: 5px 5px 0px 30px;

}
.details {
  font-size: 14px;
  color: #969696;
  border: none;
  padding: 20px 20px 20px 20px;
  margin: 0px 10px 10px 35px;
}

.shadow {
  border: 3px solid #9f9f9f;
  padding: 10px 25px 10px 25px;
  margin: 10px 35px 0px 30px;
  background-color: #ffffff;
  width: 600px;

  -moz-box-shadow: 3px 3px 3px #cccccc;
  -webkit-box-shadow: 3px 3px 3px #cccccc;
  box-shadow: 3px 3px 3px #cccccc;
  /* For IE 8 */
  -ms-filter: "progid:DXImageTransform.Microsoft.Shadow(Strength=5, Direction=135, Color='cccccc')";
  /* For IE 5.5 - 7 */
  filter: progid:DXImageTransform.Microsoft.Shadow(Strength=5, Direction=135, Color='cccccc');
}
.logo {
  border: none;
  margin: 5px 5px 0px 30px;
}
</style>
</head>

<body>
<div class="logo"></div><p>&nbsp;</p>
<div class="shadow">
<h1>This Page Cannot Be Displayed</h1>


<p>
Authentication is required to access the Internet using this system.
A valid user ID and password must be entered when prompted.
</p>



<p>
If you have questions, please contact
your organization's network administrator 
and provide the codes shown below.
</p>

</div>

<div class="details"><p>
Date: Wed, 12 Feb 2025 09:55:54 CST<br />
Username: <br />
Source IP: 53.144.125.164<br />
URL: POST http://127.0.0.1/api/generate<br />
Category: URL Filtering Bypassed<br />
Reason: UNKNOWN<br />
Notification: PROXY_AUTH_REQUIRED
</p></div>
</body>
</html>
 (status code: 407)

In [7]:



def execute_query(state: State):
    """Execute SQL query."""
    execute_query_tool = QuerySQLDatabaseTool(db=db)
    return {"result": execute_query_tool.invoke(state["query"])}

In [8]:
execute_query({"query": "SELECT COUNT(EmployeeId) AS EmployeeCount FROM Employee;"})

{'result': '[(8,)]'}

In [None]:
def generate_answer(state: State):
    """Answer question using retrieved information as context."""
    prompt = (
        "Given the following user question, corresponding SQL query, "
        "and SQL result, answer the user question.\n\n"
        f'Question: {state["question"]}\n'
        f'SQL Query: {state["query"]}\n'
        f'SQL Result: {state["result"]}'
    )
    response = llm.invoke(prompt)
    return {"answer": response}

In [None]:


graph_builder = StateGraph(State).add_sequence(
    [write_query, execute_query, generate_answer]
)
graph_builder.add_edge(START, "write_query")
graph = graph_builder.compile()

In [None]:
for step in graph.stream(
    {"question": "How many employee are there?"}, stream_mode="updates"
):
    print(step)

In [None]:


memory = MemorySaver()
graph = graph_builder.compile(checkpointer=memory, interrupt_before=["execute_query"])

# Now that we're using persistence, we need to specify a thread ID
# so that we can continue the run after review.
config = {"configurable": {"thread_id": "1"}}

## Below is code for an agent instead of a chain



In [None]:


toolkit = SQLDatabaseToolkit(db=db, llm=llm)

tools = toolkit.get_tools()

tools

In [None]:
prompt_template = PromptTemplate(
    input_variables=["dialect", "top_k", "input"],
    template="""You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

**Instructions:**
- You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
- DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
- **Always** use column names explicitly (e.g., `COUNT(EmployeeId) AS EmployeeCount` instead of `COUNT(*)`).
- Do **not** explain the query.
- Do **not** include any additional text.
- Do **not** describe database tables unless explicitly asked.
- Only return the final SQL query.
- Always limit results to **{top_k}** unless specified otherwise.
- Do **not** add anything to the end of the table name (e.g 'Employee' instead of 'Employees')



To start you should ALWAYS look at the tables in the database to see what you can query.
Do NOT skip this step.
Then you should query the schema of the most relevant tables.
Question: {input}
"""
)

In [None]:
user_question =  "Which country's customers spent the most?"

system_message = prompt_template.format(dialect="SQLite", top_k=5, input = user_question)

In [None]:

# ... (your database connection and LLM initialization) ...

toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()

# Refine tool descriptions (VERY IMPORTANT for ZERO_SHOT_REACT_DESCRIPTION):
for tool in tools:
    if tool.name == "QuerySQLDatabaseTool":
        tool.description = """
        Use this tool to execute SQL queries against the database. 
        Input: A detailed and CORRECT SQL query.
        Output: The result from the database.
        IMPORTANT: Before using this tool, ALWAYS use the 'QuerySQLCheckerTool' to validate your query.
        If the query is incorrect, the tool will return an error message.
        If you encounter an issue like 'Unknown column...', use the 'InfoSQLDatabaseTool' to get the correct table fields.
        """
    elif tool.name == "InfoSQLDatabaseTool":
        tool.description = """
        Use this tool to get information about the database schema and sample rows.
        Input: A comma-separated list of table names.
        Output: The schema and sample rows for those tables.
        Use this tool to understand the database structure or to find the correct column names for your queries.
        Call the 'ListSQLDatabaseTool' first to know what tables are available.
        """
    elif tool.name == "ListSQLDatabaseTool":
        tool.description = """
        Use this tool to get a list of available tables in the database.
        Input: None.
        Output: A list of table names.
        Use this tool before using 'InfoSQLDatabaseTool' to make sure the tables exist.
        """
    elif tool.name == "QuerySQLCheckerTool":
        tool.description = """
        Use this tool to check if your SQL query is correct BEFORE executing it with 'QuerySQLDatabaseTool'.
        Input: The SQL query you want to check.
        Output: The original query if it is correct, or a corrected query if there were mistakes.
        ALWAYS use this tool first.
        """


agent = initialize_agent(
    tools,
    llm,  # Your Ollama LLM instance
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
    handle_parsing_errors=True,
    prompt=system_message,
)

# Example usage:
agent.invoke( "Which country's customers spent the most?")

In [None]:


agent.invoke("Describe the playlisttrack table")


In [None]:
agent.invoke("What are 10 track names and their artists?")