In [1]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_openai import ChatOpenAI
import os

llm = ChatOpenAI(
    openai_api_base="https://llmfoundry.straive.com/openai/v1/",
    openai_api_key=f'{os.environ["LLMFOUNDRY_TOKEN"]}:my-test-project',
    model="gpt-4o-mini",
)
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()

for tool in tools:
    print(f"{tool.name}: {tool.description}\n")

# SQL DatabaseToolkit powered by LLM and prints available tools like querying, summarizing tables,along descriptions

# we use this to interact LLM with SQL Database.

sql_db_query: Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.

sql_db_schema: Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3

sql_db_list_tables: Input is an empty string, output is a comma-separated list of tables in the database.

sql_db_query_checker: Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!



In [53]:
from typing import Literal
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode

# Tool Selection and Node 
# this line selects the tool and selects sql_db_schema tool from tools retrives structure such as table names & col names

# Wraps the get_schema_tool into a ToolNode named "get_schema", making it usable in a LangGraph workflow.
get_schema_tool = next(tool for tool in tools if tool.name == "sql_db_schema")
get_schema_node = ToolNode([get_schema_tool], name="get_schema")

# this selects the tool named "sql_db_query" — likely used to run actual SQL queries against the database.
# ToolNode called "run_query" to use in the graph.
run_query_tool = next(tool for tool in tools if tool.name == "sql_db_query")
run_query_node = ToolNode([run_query_tool], name="run_query")


# Example: create a predetermined tool call
def list_tables(state: MessagesState):
    tool_call = {
        "name": "sql_db_list_tables",
        "args": {},
        "id": "",
        "type": "tool_call",
    }
    tool_call_message = AIMessage(content="", tool_calls=[tool_call])

    list_tables_tool = next(tool for tool in tools if tool.name == "sql_db_list_tables")
    tool_message = list_tables_tool.invoke(tool_call)
    response = AIMessage(f"Available tables: {tool_message.content}")
    print("LIST_OFTABLES",{"messages": [tool_call_message, tool_message, response]})
    return {"messages": [tool_call_message, tool_message, response]}


# Example: force a model to create a tool call
def call_get_schema(state: MessagesState):
    # Note that LangChain enforces that all models accept `tool_choice="any"`
    # as well as `tool_choice=<string name of tool>`.
    llm_with_tools = llm.bind_tools([get_schema_tool], tool_choice="any")
    response = llm_with_tools.invoke(state["messages"])
    print("\nSCHEMA_RESPONSE",{"messages": [response]})
    return {"messages": [response]}

# Unless the user
# specifies a specific number of examples they wish to obtain, always limit your
# query to at most {top_k} results.
generate_query_system_prompt = """
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run,
then look at the results of the query and return the answer. 

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
""".format(
    dialect=db.dialect,
)


def generate_query(state: MessagesState):
    system_message = {
        "role": "system",
        "content": generate_query_system_prompt,
    }
    # We do not force a tool call here, to allow the model to
    # respond naturally when it obtains the solution.
    llm_with_tools = llm.bind_tools([run_query_tool])
    response = llm_with_tools.invoke([system_message] + state["messages"])
    # print(response)
    print("QUERY_RESPONSE",{"messages": [response]})
    return {"messages": [response]}


check_query_system_prompt = """
You are a SQL expert with a strong attention to detail.
Double check the {dialect} query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes,
just reproduce the original query.

You will call the appropriate tool to execute the query after running this check.
""".format(dialect=db.dialect)


def check_query(state: MessagesState):
    system_message = {
        "role": "system",
        "content": check_query_system_prompt,
    }

    # Generate an artificial user message to check
    tool_call = state["messages"][-1].tool_calls[0]
    user_message = {"role": "user", "content": tool_call["args"]["query"]}
    llm_with_tools = llm.bind_tools([run_query_tool], tool_choice="any")
    response = llm_with_tools.invoke([system_message, user_message])
    response.id = state["messages"][-1].id

    return {"messages": [response]}

In [61]:
def should_continue(state: MessagesState) -> Literal[END, "check_query"]:
    messages = state["messages"]
    last_message = messages[-1]
    print("LAST_MESSAGE",last_message.tool_calls)
    if not last_message.tool_calls:
        return END
    else:
        return "check_query"


builder = StateGraph(MessagesState)
builder.add_node(list_tables)
builder.add_node(call_get_schema)
builder.add_node(get_schema_node, "get_schema")
builder.add_node(generate_query)
builder.add_node(check_query)
builder.add_node(run_query_node, "run_query")

builder.add_edge(START, "list_tables")
builder.add_edge("list_tables", "call_get_schema")
builder.add_edge("call_get_schema", "get_schema")
builder.add_edge("get_schema", "generate_query")
builder.add_conditional_edges(
    "generate_query",
    should_continue,
)
builder.add_edge("check_query", "run_query")
builder.add_edge("run_query", "generate_query")

agent = builder.compile()

In [62]:
question = "What is the title of the albums by artist 'Queen'?"

for step in agent.stream(
    {"messages": [{"role": "user", "content": question}]},
    stream_mode="values",
):
    for  msg in step["messages"]:
        if isinstance(msg, AIMessage):  # ✅ Check type instead of .role
            s = msg.content
    # step["messages"][-1].pretty_print()
print(s)

LIST_OFTABLES {'messages': [AIMessage(content='', additional_kwargs={}, response_metadata={}, tool_calls=[{'name': 'sql_db_list_tables', 'args': {}, 'id': '', 'type': 'tool_call'}]), ToolMessage(content='Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track', name='sql_db_list_tables', tool_call_id=''), AIMessage(content='Available tables: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track', additional_kwargs={}, response_metadata={})]}


KeyboardInterrupt: 

In [56]:
print(agent.get_graph().draw_ascii())

                        +-----------+                    
                        | __start__ |                    
                        +-----------+                    
                              *                          
                              *                          
                              *                          
                       +-------------+                   
                       | list_tables |                   
                       +-------------+                   
                              *                          
                              *                          
                              *                          
                     +-----------------+                 
                     | call_get_schema |                 
                     +-----------------+                 
                              *                          
                              *                          
              

In [45]:
print(agent.get_graph().draw_mermaid())

---
config:
  flowchart:
    curve: linear
---
graph TD;
	__start__([<p>__start__</p>]):::first
	list_tables(list_tables)
	call_get_schema(call_get_schema)
	get_schema(get_schema)
	generate_query(generate_query)
	check_query(check_query)
	run_query(run_query)
	__end__([<p>__end__</p>]):::last
	__start__ --> list_tables;
	call_get_schema --> get_schema;
	check_query --> run_query;
	generate_query -.-> __end__;
	generate_query -.-> check_query;
	get_schema --> generate_query;
	list_tables --> call_get_schema;
	run_query --> generate_query;
	classDef default fill:#f2f0ff,line-height:1.2
	classDef first fill-opacity:0
	classDef last fill:#bfb6fc



LIST_OFTABLES {'messages': [AIMessage(content='', additional_kwargs={}, response_metadata={}, tool_calls=[{'name': 'sql_db_list_tables', 'args': {}, 'id': '', 'type': 'tool_call'}]), ToolMessage(content='Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track', name='sql_db_list_tables', tool_call_id=''), AIMessage(content='Available tables: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track', additional_kwargs={}, response_metadata={})]}

SCHEMA_RESPONSE {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_tTmOOuCmnxu4JURGyOhJJaQu', 'function': {'arguments': '{"table_names":"Album, Artist"}', 'name': 'sql_db_schema'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 18, 'prompt_tokens': 207, 'total_tokens': 225, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_to

In [28]:
mermaid = builder.get_mermaid()


AttributeError: 'StateGraph' object has no attribute 'get_mermaid'