In [13]:
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit  
from langchain_community.utilities.sql_database import SQLDatabase  
import os
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from typing_extensions import TypedDict
from typing import Annotated, Literal
from langgraph.graph.message import AnyMessage, add_messages
from langgraph.graph import START, END, StateGraph
from typing import Any
from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langgraph.prebuilt import ToolNode
from langchain_core.messages import AIMessage


In [14]:
load_dotenv()
os.environ['GROQ_API_KEY'] = os.getenv('GROQ_API_KEY')

In [15]:
db = SQLDatabase.from_uri("sqlite:///Report_Card.db")
db

<langchain_community.utilities.sql_database.SQLDatabase at 0x22414df8cd0>

In [16]:
llm =  ChatGroq(model= 'llama3')
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()

In [17]:
for tool in tools:
    print(tool.name)

sql_db_query
sql_db_schema
sql_db_list_tables
sql_db_query_checker


In [18]:
from langchain_core.tools import tool
@tool

def query_to_database(query:str) -> str:
    '''
    Execute a SQL Query against the database and return the result.
    If the query is invalid and returns no result, an error message will be returned.
    In case of an error, the user is advised to rewrite the query and try again.
    '''
    result = db.run_no_throw(query)
    if not result:
        return 'Error!! Query failed please rewrite the query.'
    else:
        return result

In [19]:
list_tables_tool = next((tool for tool in tools if tool.name == "sql_db_list_tables"), None)
get_schema_tool = next((tool for tool in tools if tool.name == "sql_db_schema"), None)
llm_to_get_schema=llm.bind_tools([get_schema_tool])

In [20]:
llm_with_tools = llm.bind_tools(tools= [query_to_database])
llm_with_tools

RunnableBinding(bound=ChatGroq(client=<groq.resources.chat.completions.Completions object at 0x0000022414DF9A90>, async_client=<groq.resources.chat.completions.AsyncCompletions object at 0x0000022414DFA490>, model_name='llama3', model_kwargs={}, groq_api_key=SecretStr('**********')), kwargs={'tools': [{'type': 'function', 'function': {'name': 'query_to_database', 'description': 'Execute a SQL Query against the database and return the result.\nIf the query is invalid and returns no result, an error message will be returned.\nIn case of an error, the user is advised to rewrite the query and try again.', 'parameters': {'properties': {'query': {'type': 'string'}}, 'required': ['query'], 'type': 'object'}}}]}, config={}, config_factories=[])

In [9]:
class State (TypedDict):
    messages : Annotated[list[AnyMessage], add_messages]

In [11]:
def handle_tool_error(state:State):
    error = state.get("error") 
    tool_calls = state["messages"][-1].tool_calls
    return { "messages": [ ToolMessage(content=f"Error: {repr(error)}\n please fix your mistakes.",tool_call_id=tc["id"],) for tc in tool_calls ] }

def create_node_from_tool_with_fallback(tools:list)-> RunnableWithFallbacks[Any, dict]:
    return ToolNode(tools).with_fallbacks([RunnableLambda(handle_tool_error)], exception_key="error")

In [None]:
list_tables=create_node_from_tool_with_fallback([list_tables_tool])
get_schema=create_node_from_tool_with_fallback([get_schema_tool])
query_database=create_node_from_tool_with_fallback([query_to_database])

In [10]:
def first_tool_call(state : State) -> dict[str, list[AIMessage]]:
    print(f'log of the state from first_tool_call, {state}')

    return {'message' : list(AIMessage(content= '', tool_calls= [{'name' : 'sql_db_list_tables', 'args' : {}, 'id' : 'tool101' }]))}

In [None]:
def list_all_table(state : State):
    

In [None]:
def model_get_schema(state : State):
    pass

In [None]:
def get_schema_tool(state : State):
    pass

In [None]:
def query_generation(state : State):
    pass


In [None]:
def correct_query(state : State):
    pass

In [None]:
def execute_query(state : State):
    pass