In [17]:
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit  
from langchain_community.utilities.sql_database import SQLDatabase  
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.pydantic_v1 import BaseModel, Field
import os
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from typing_extensions import TypedDict
from typing import Annotated, Literal
from langgraph.graph.message import AnyMessage, add_messages
from langgraph.graph import START, END, StateGraph
from typing import Any
from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langgraph.prebuilt import ToolNode
from langchain_core.messages import AIMessage


In [18]:
load_dotenv()
os.environ['GROQ_API_KEY'] = os.getenv('GROQ_API_KEY')

In [19]:
db = SQLDatabase.from_uri("sqlite:///Report_Card.db")
db

<langchain_community.utilities.sql_database.SQLDatabase at 0x1b935263250>

In [20]:
llm =  ChatGroq(model= 'llama3')
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()

In [21]:
for tool in tools:
    print(tool.name)

sql_db_query
sql_db_schema
sql_db_list_tables
sql_db_query_checker


In [22]:
from langchain_core.tools import tool
@tool

def query_to_database(query:str) -> str:
    '''
    Execute a SQL Query against the database and return the result.
    If the query is invalid and returns no result, an error message will be returned.
    In case of an error, the user is advised to rewrite the query and try again.
    '''
    result = db.run_no_throw(query)
    if not result:
        return 'Error!! Query failed please rewrite the query.'
    else:
        return result

In [23]:
list_tables_tool = next((tool for tool in tools if tool.name == "sql_db_list_tables"), None)
get_schema_tool = next((tool for tool in tools if tool.name == "sql_db_schema"), None)
llm_to_get_schema=llm.bind_tools([get_schema_tool])

In [24]:
llm_with_tools = llm.bind_tools(tools= [query_to_database])
llm_with_tools

RunnableBinding(bound=ChatGroq(client=<groq.resources.chat.completions.Completions object at 0x000001B9359B07D0>, async_client=<groq.resources.chat.completions.AsyncCompletions object at 0x000001B9359B11D0>, model_name='llama3', model_kwargs={}, groq_api_key=SecretStr('**********')), kwargs={'tools': [{'type': 'function', 'function': {'name': 'query_to_database', 'description': 'Execute a SQL Query against the database and return the result.\nIf the query is invalid and returns no result, an error message will be returned.\nIn case of an error, the user is advised to rewrite the query and try again.', 'parameters': {'properties': {'query': {'type': 'string'}}, 'required': ['query'], 'type': 'object'}}}]}, config={}, config_factories=[])

In [25]:
class State (TypedDict):
    messages : Annotated[list[AnyMessage], add_messages]

In [26]:
def handle_tool_error(state:State):
    error = state.get("error") 
    tool_calls = state["messages"][-1].tool_calls
    return { "messages": [ ToolMessage(content=f"Error: {repr(error)}\n please fix your mistakes.",tool_call_id=tc["id"],) for tc in tool_calls ] }

def create_node_from_tool_with_fallback(tools:list)-> RunnableWithFallbacks[Any, dict]:
    return ToolNode(tools).with_fallbacks([RunnableLambda(handle_tool_error)], exception_key="error")

In [27]:
list_tables=create_node_from_tool_with_fallback([list_tables_tool])
get_schema=create_node_from_tool_with_fallback([get_schema_tool])
query_database=create_node_from_tool_with_fallback([query_to_database])

In [28]:
query_check_system = """You are a SQL expert. Carefully review the SQL query for common mistakes, including:

Issues with NULL handling (e.g., NOT IN with NULLs)
Improper use of UNION instead of UNION ALL
Incorrect use of BETWEEN for exclusive ranges
Data type mismatches or incorrect casting
Quoting identifiers improperly
Incorrect number of arguments in functions
Errors in JOIN conditions

If you find any mistakes, rewrite the query to fix them. If it's correct, reproduce it as is."""

query_check_prompt = ChatPromptTemplate.from_messages([
    {'role' : 'system', 'content' : query_check_system},
    MessagesPlaceholder(variable_name= 'messages')
])

check_generated_query = query_check_prompt | llm_with_tools

In [29]:
class SubmitFinalAnswer(BaseModel):
    '''
    Submit the final answer to the user based on the query result.
    '''

    final_answer : str = Field(..., description= 'Final answer to the user')

llm_with_final_answer = llm.bind_tools(tools= [SubmitFinalAnswer])

In [30]:
query_gen_system_prompt = """You are a SQL expert with a strong attention to detail.Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.

1. DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.

When generating the query:

2. Output the SQL query that answers the input question without a tool call.

3. Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.

4. You can order the results by a relevant column to return the most interesting examples in the database.

5. Never query for all the columns from a specific table, only ask for the relevant columns given the question.

6. If you get an error while executing a query, rewrite the query and try again.

7. If you get an empty result set, you should try to rewrite the query to get a non-empty result set.

8. NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.

9. If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.

10. DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. Do not return any sql query except answer. """

query_gen_prompt = ChatPromptTemplate.from_messages([
    {'role' : "system", 'content' :query_gen_system_prompt}, 
    MessagesPlaceholder(variable_name= 'messages')
])

query_generator = query_gen_prompt | llm_with_final_answer

In [32]:
def first_tool_call(state : State) -> dict[str, list[AIMessage]]:
    print(f'log of the state from first_tool_call, {state}')

    return {'messages' : list(AIMessage(content= '', tool_calls= [{'name' : 'sql_db_list_tables', 'args' : {}, 'id' : 'tool101' }]))}

In [33]:
def check_the_given_query(state : State):
    print(f'state from check the given query: {state}')

    return {[
        check_generated_query.invoke({
        'messages' : [state['messages'][-1]]
        })
    ]}

In [None]:
def generation_query(state : State):
    