In [13]:
import pymongo
from pymongo import MongoClient
from dotenv import load_dotenv
import os

load_dotenv()

True

In [14]:
connection_string = os.getenv("DB_CONNECTION_SECRET")

In [15]:
client = MongoClient(
        connection_string,
        connectTimeoutMS=5000,  # 5 seconds
        serverSelectionTimeoutMS=5000  # 5 seconds
    )

In [16]:
db = client.get_database("devTinder")

In [275]:
from langchain_mongodb.agent_toolkit import MongoDBDatabase, MongoDBDatabaseToolkit
from langchain_groq import ChatGroq

# Set up the MongoDB database
db = MongoDBDatabase.from_connection_string(
    os.getenv("DB_CONNECTION_SECRET"),
    database="devTinder"
)

llm = ChatGroq(model="llama3-8b-8192", api_key=os.getenv("GROQ_API_KEY"))

# Create the toolkit with the database
toolkit = MongoDBDatabaseToolkit(db=db, llm=llm)

# Get the tools from the toolkit
tools = toolkit.get_tools()

print("Available tools:")
for tool in tools:
    print(tool.name)

Available tools:
mongodb_query
mongodb_schema
mongodb_list_collections
mongodb_query_checker


In [34]:
db.get_usable_collection_names()

['chats', 'connectionrequests', 'payments', 'test_collection', 'users']

In [35]:
llm.invoke("Hello, how are you?")

AIMessage(content="I'm just a language model, I don't have emotions or feelings like humans do, but I'm functioning properly and ready to assist you with any questions or tasks you may have! It's great to chat with you. How can I help you today?", additional_kwargs={}, response_metadata={'token_usage': {'completion_tokens': 53, 'prompt_tokens': 16, 'total_tokens': 69, 'completion_time': 0.044166667, 'prompt_time': 0.00246693, 'queue_time': 0.234590618, 'total_time': 0.046633597}, 'model_name': 'llama3-8b-8192', 'system_fingerprint': 'fp_dadc9d6142', 'finish_reason': 'stop', 'logprobs': None}, id='run-5a84df2d-7d2a-4199-b2c9-8ef58d97bee0-0', usage_metadata={'input_tokens': 16, 'output_tokens': 53, 'total_tokens': 69})

In [36]:
for tool in tools:
    # Check if the tool's name is "mongodb_list_collections"
    if tool.name == "mongodb_list_collections":
        list_collections_tool = tool
        break

In [37]:
for tool in tools:
    # Check if the tool's name is "mongodb_schema"
    if tool.name == "mongodb_schema":
        # If found, assign it to get_schema_tool
        get_schema_tool = tool
        break 

In [38]:
llm_to_get_schema=llm.bind_tools([get_schema_tool])
llm_to_get_schema

RunnableBinding(bound=ChatGroq(client=<groq.resources.chat.completions.Completions object at 0x1249fa650>, async_client=<groq.resources.chat.completions.AsyncCompletions object at 0x1257f0770>, model_name='llama3-8b-8192', model_kwargs={}, groq_api_key=SecretStr('**********')), kwargs={'tools': [{'type': 'function', 'function': {'name': 'mongodb_schema', 'description': 'Input to this tool is a comma-separated list of collections, output is the schema and sample rows for those collections. Be sure that the collectionss actually exist by calling mongodb_list_collections first! Example Input: collection1, collection2, collection3', 'parameters': {'properties': {'collection_names': {'description': "A comma-separated list of the collection names for which to return the schema. Example input: 'collection1, collection2, collection3'", 'type': 'string'}}, 'required': ['collection_names'], 'type': 'object'}}}]}, config={}, config_factories=[])

In [100]:
from langchain.tools import tool

# The @tool Decorator: This marks the function as a LangChain tool, making it available for use by a language model agent. 
# Tools in LangChain are functions that agents can call to perform specific tasks.
# Mentioning the prompt inside as in Mongodb (https://langchain-mongodb.readthedocs.io/en/latest/langchain_mongodb/agent_toolkit/langchain_mongodb.agent_toolkit.database.MongoDBDatabase.html#langchain_mongodb.agent_toolkit.database.MongoDBDatabase.run_no_throw) only uses aggregation queries.

@tool
def query_to_database(query: str) -> str:
    """
        Execute a MongoDB **aggregation query string** against the database and return the result.
        The query string MUST be in the MongoDB shell format: 'db.collectionName.aggregate([pipeline])'.
        
        DONT USE other aggregation queries like find, findOne, etc.
        
        Example query: 
        
        **IMPORTANT TOOL USAGE RULES:**
        1.  The `query_to_database` tool ONLY accepts MongoDB **aggregation query strings**.
        2.  The query string MUST strictly follow the format: `'db.collectionName.aggregate([pipeline])'`.
        3.  Use the correct collection name (e.g., `users`, `payments`).
        4.  Use the correct field names based on the known schema (e.g., `firstName`, `lastName`, `emailId`, `createdAt`). Do NOT guess field names like `first_name`.
        5.  Use `$match` within the pipeline for filtering documents (like a WHERE clause).
        6.  Use `$project` to select specific fields.
        7.  Use `$count` to count documents.
        8.  Use `$limit` and `$sort` for those specific operations.
        9.  Do NOT attempt to use other commands like `find`, `findOne`, `countDocuments` directly in the query string.

        **QUERY EXAMPLES:**

        **1. Show all documents in a collection (e.g., `users`):**
        *User Request:* "Show all users", "List all users"
        *Tool Query:* `'db.users.aggregate([ { "$match": {} } ])'`

        **2. Show documents matching specific criteria (e.g., users with `firstName` "Rohan"):**
        *User Request:* "Find users named Rohan", "Get user Rohan's details"
        *Tool Query:* `'db.users.aggregate([ { "$match": { "firstName": "Rohan" } } ])'`

        **3. Show documents matching multiple criteria (e.g., users with `firstName` "Rohan" AND `lastName` "Gore"):**
        *User Request:* "Find user Rohan Gore"
        *Tool Query:* `'db.users.aggregate([ { "$match": { "firstName": "Rohan", "lastName": "Gore" } } ])'`
        *(Alternative using $and):* `'db.users.aggregate([ { "$match": { "$and": [ { "firstName": "Rohan" }, { "lastName": "Gore" } ] } } ])'`

        **4. Show specific fields for matching documents (e.g., `firstName` and `emailId` for user "Rohan"):**
        *User Request:* "Show Rohan's first name and email"
        *Tool Query:* `'db.users.aggregate([ { "$match": { "firstName": "Rohan" } }, { "$project": { "firstName": 1, "emailId": 1, "_id": 0 } } ])'`

        **5. Count documents matching criteria (e.g., count users named "Rohan"):**
        *User Request:* "How many users are named Rohan?"
        *Tool Query:* `'db.users.aggregate([ { "$match": { "firstName": "Rohan" } }, { "$count": "matching_users_count" } ])'`

        **6. Limit the number of results (e.g., show the first 5 users):**
        *User Request:* "Show 5 users"
        *Tool Query:* `'db.users.aggregate([ { "$match": {} }, { "$limit": 5 } ])'`

        **7. Sort results (e.g., show users sorted by `createdAt` descending):**
        *User Request:* "Show users sorted by creation date, newest first"
        *Tool Query:* `'db.users.aggregate([ { "$match": {} }, { "$sort": { "createdAt": -1 } } ])'`

        **8. Combine operations (e.g., show `emailId` of the 5 newest users named "Rohan"):**
        *User Request:* "Show the email addresses of the 5 most recent users named Rohan"
        *Tool Query:* `'db.users.aggregate([ { "$match": { "firstName": "Rohan" } }, { "$sort": { "createdAt": -1 } }, { "$limit": 5 }, { "$project": { "emailId": 1, "_id": 0 } } ])'`

        Remember to always construct the query string in the exact `db.collectionName.aggregate([pipeline])` format for the `query_to_database` tool. Use the collection schema information (like field names `firstName`, `emailId`) when formulating the pipeline stages.
"""
    
    # runs the query and if it is invalid or returns no result, gracefully handles it will return an error message.
    result = db.run_no_throw(query) 
    
    
    if not result:
        return "No result returned from the query. Please try again."
    return result

In [101]:
db.get_collection_info(["users"])

'Database name: devTinder\nCollection name: users\nSchema from a sample of documents from the collection:\n_id: ObjectId\nfirstName: String\nlastName: String\nemailId: String\npassword: String\nphotoUrl: String\nskills: Array\ncreatedAt: Timestamp\nupdatedAt: Timestamp\n__v: Number\n\n/*\n3 documents from users collection:\n[\n  {\n    "_id": {\n      "$oid": "66ece21ff3406ae729fafc0c"\n    },\n    "firstName": "Rohan",\n    "lastName": "Gore",\n    "emailId": "rg@gmail.com",\n    "password": "$2b$10$ptCy5NAP59AF5t",\n    "photoUrl": "http://dummy.com",\n    "skills": [],\n    "createdAt": {\n      "$date": "2024-09-20T02:46:55.558Z"\n    },\n    "updatedAt": {\n      "$date": "2024-09-20T02:46:55.558Z"\n    },\n    "__v": 0\n  },\n  {\n    "_id": {\n      "$oid": "66eec283bf081b9c5cb8c96e"\n    },\n    "firstName": "Vibhor",\n    "lastName": "J",\n    "emailId": "vb@gmail.com",\n    "password": "$2b$10$hMCB8xIJxKcr1y",\n    "photoUrl": "http://dummy.com",\n    "skills": [],\n    "crea

In [102]:
# Example of using the tool
query_to_database.invoke('db.users.aggregate([ { "$match": { "firstName": "Sahil" } } ])')

'[\n  {\n    "_id": {\n      "$oid": "67ce3b18d583a646a05cb04f"\n    },\n    "firstName": "Sahil",\n    "lastName": "Bhoir",\n    "emailId": "sb@gmail.com",\n    "password": "$2b$10$HagDQ/B00ra/BpR8tXTSy.ANQMuhsTSaPkwwz4UBn2TYpvXKGub2q",\n    "photoUrl": "http://dummy.com",\n    "skills": [],\n    "createdAt": {\n      "$date": "2025-03-10T01:06:32.236Z"\n    },\n    "updatedAt": {\n      "$date": "2025-03-12T01:59:29.740Z"\n    },\n    "__v": 0,\n    "isPayment": true\n  }\n]'

In [103]:
query_to_database.invoke('db.users.aggregate([ { "$match": { } } ])')

'[\n  {\n    "_id": {\n      "$oid": "66ece21ff3406ae729fafc0c"\n    },\n    "firstName": "Rohan",\n    "lastName": "Gore",\n    "emailId": "rg@gmail.com",\n    "password": "$2b$10$ptCy5NAP59AF5txOtIutV.CmoIZoqh1TK1kAaEWpIQjQSjlTwhdgi",\n    "photoUrl": "http://dummy.com",\n    "skills": [],\n    "createdAt": {\n      "$date": "2024-09-20T02:46:55.558Z"\n    },\n    "updatedAt": {\n      "$date": "2024-09-20T02:46:55.558Z"\n    },\n    "__v": 0\n  },\n  {\n    "_id": {\n      "$oid": "66eec283bf081b9c5cb8c96e"\n    },\n    "firstName": "Vibhor",\n    "lastName": "J",\n    "emailId": "vb@gmail.com",\n    "password": "$2b$10$hMCB8xIJxKcr1y5Ho9s94.jSm2/TMNZRWK0ojct5hVigojFJ7hLb.",\n    "photoUrl": "http://dummy.com",\n    "skills": [],\n    "createdAt": {\n      "$date": "2024-09-21T12:56:35.503Z"\n    },\n    "updatedAt": {\n      "$date": "2024-09-26T03:18:45.028Z"\n    },\n    "__v": 0\n  },\n  {\n    "_id": {\n      "$oid": "66f0e4794fad30bd74decb15"\n    },\n    "firstName": "Shirish

In [104]:
## Tool binding
"""
First, it binds the query_to_database tool to the language model (LLM).
This essentially gives the LLM access to the database query functionality.
"""
llm_with_tools = llm.bind_tools([query_to_database])

# Now, when the LLM is asked to show all employees, it can use the query_to_database tool to execute the query.
llm_with_tools.invoke("show all users") # it executes "select * from employees;"


AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_34c5', 'function': {'arguments': '{"query":"\'db.users.aggregate([ { \\"$match\\": {} } ])\'"}', 'name': 'query_to_database'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 85, 'prompt_tokens': 1899, 'total_tokens': 1984, 'completion_time': 0.070833333, 'prompt_time': 0.234551936, 'queue_time': 0.237272201, 'total_time': 0.305385269}, 'model_name': 'llama3-8b-8192', 'system_fingerprint': 'fp_a97cfe35ae', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-ba714cdf-96a2-475b-962e-9edad7e17138-0', tool_calls=[{'name': 'query_to_database', 'args': {'query': '\'db.users.aggregate([ { "$match": {} } ])\''}, 'id': 'call_34c5', 'type': 'tool_call'}], usage_metadata={'input_tokens': 1899, 'output_tokens': 85, 'total_tokens': 1984})

In [110]:
# Suppress all warnings
import warnings
warnings.filterwarnings("ignore")

from typing import Annotated, Literal
from langchain_core.messages import AIMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from typing_extensions import TypedDict
from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import AnyMessage, add_messages
from typing import Any
from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks

In [122]:
from langchain_core.prompts import ChatPromptTemplate

# System Prompt for MongoDB Aggregation String Check
mongo_query_check_system = """
You are an expert MongoDB BSON Query Checker, specializing in validating aggregation query strings for a specific tool.

**Tool Requirement:** The tool ONLY accepts query strings in the exact MongoDB shell format: `'db.collectionName.aggregate([pipeline])'`.

**Your Task:** Carefully review the provided MongoDB aggregation query string for correctness and adherence to the required format. Check for common mistakes, including:

1.  **Format Adherence:** Does the string strictly follow `'db.collectionName.aggregate([pipeline])'`? (e.g., presence of `db.`, `.aggregate(`, `[` and `]`).
2.  **Valid Collection Name:** Is a valid collection name used (e.g., `users`, `payments`)?
3.  **Valid Pipeline:** Is the pipeline `[...]` a valid JSON array of aggregation stages?
4.  **Correct Field Names:** Are the field names likely correct based on common conventions or known schema? (e.g., prefer `firstName` over `first_name` if that's the pattern).
5.  **Correct Operators:** Are valid aggregation pipeline stages used (e.g., `$match`, `$project`, `$group`, `$sort`, `$limit`, `$count`, `$lookup`)?
6.  **Operator Syntax:** Is the syntax for each operator correct (e.g., `$match` takes an object, `$project` takes an object, `$sort` takes an object, `$limit` takes a number, `$count` takes a string)?
7.  **Data Type Mismatches:** Are query conditions comparing fields to values of the correct type? (e.g., querying a string field with a number, or vice-versa).
8.  **String Quoting:** Are string literals correctly enclosed in double quotes (`"`) within the query object parts?
9.  **Null Handling:** Is comparison with `null` done correctly?
10. **Disallowed Commands:** Does the string incorrectly contain other commands like `find`, `findOne`, `countDocuments`? The tool only runs `aggregate`.

**Output:**
- If you find any mistakes, **rewrite the entire query string** in the correct `'db.collectionName.aggregate([pipeline])'` format to fix the errors while preserving the original intent.
- If the query string is already correct and perfectly follows the required format, reproduce it **exactly** as is.
"""

# Create the prompt template
query_check_prompt = ChatPromptTemplate.from_messages([
    ("system", mongo_query_check_system),
    ("placeholder", "{query}") # Placeholder for the MongoDB query string to check
])

"""
Combining with Tools: 
The line check_generated_query = query_check_prompt | llm_with_tools 
combines the prompt template with the LLM that has access to the database tools. This means that when the LLM processes a query, 
it can also utilize the query_to_database tool if needed.
"""
check_generated_query = query_check_prompt | llm_with_tools

In [124]:
check_generated_query.invoke({"messages": [("user", 'SHOW ALL USERS')]})

AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_1jem', 'function': {'arguments': '{"query":"\'db.users.aggregate([ { \\"$match\\": { \\"firstName\\": \\"Rohan\\" } } ])\'"}', 'name': 'query_to_database'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 86, 'prompt_tokens': 4769, 'total_tokens': 4855, 'completion_time': 0.071666667, 'prompt_time': 0.616123907, 'queue_time': -0.850360135, 'total_time': 0.687790574}, 'model_name': 'llama3-8b-8192', 'system_fingerprint': 'fp_dadc9d6142', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-6cf9a823-ff2b-42f9-99fd-95e16c290c2f-0', tool_calls=[{'name': 'query_to_database', 'args': {'query': '\'db.users.aggregate([ { "$match": { "firstName": "Rohan" } } ])\''}, 'id': 'call_1jem', 'type': 'tool_call'}], usage_metadata={'input_tokens': 4769, 'output_tokens': 86, 'total_tokens': 4855})

In [125]:
check_generated_query.invoke({"messages": [("user", "SELECT everything FROM users LIMITs 5;")]})

AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_zky2', 'function': {'arguments': '{"query":"db.users.aggregate([ { \\"$match\\": { \\"firstName\\": \\"Rohan\\" } } ])"}', 'name': 'query_to_database'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 85, 'prompt_tokens': 4769, 'total_tokens': 4854, 'completion_time': 0.070833333, 'prompt_time': 0.651271192, 'queue_time': -0.886334001, 'total_time': 0.722104525}, 'model_name': 'llama3-8b-8192', 'system_fingerprint': 'fp_179b0f92c9', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-8bd159ab-71e9-4550-b350-d88056f6736d-0', tool_calls=[{'name': 'query_to_database', 'args': {'query': 'db.users.aggregate([ { "$match": { "firstName": "Rohan" } } ])'}, 'id': 'call_zky2', 'type': 'tool_call'}], usage_metadata={'input_tokens': 4769, 'output_tokens': 85, 'total_tokens': 4854})

In [126]:
# Defining the Class for formatted output

class SubmitFinalAnswer(BaseModel):
    """Submit the final answer to the user based on the query results."""
    final_answer: str = Field(..., description="The final answer to the user's question.")
    
llm_with_final_answer = llm.bind_tools([SubmitFinalAnswer])

In [145]:
# First, update the system prompt to be more explicit about the tool calling format
query_gen_system_prompt = """You are a MongoDB aggregation expert who helps translate natural language questions into proper MongoDB aggregation queries.

IMPORTANT: When calling the query_to_database tool, you MUST provide a SINGLE STRING parameter named "query" in the exact format: 'db.collectionName.aggregate([pipeline])'.

For example:
- CORRECT: query_to_database(query='db.users.aggregate([ {{ "$match": {{}} }} ])')
- INCORRECT: query_to_database(pipeline=[{{"$match": {{}}}}])

Guidelines for generating MongoDB queries:
1. Use ONLY aggregation framework syntax ($match, $project, $sort, $limit, $count)
2. Always use the correct field names from the schema (e.g., 'firstName', not 'first_name')
3. Always limit results to 5 documents using $limit unless otherwise specified
4. Format your query as a properly escaped string with single quotes around the entire query
5. Use double quotes for field names and string values inside the query

Example of a complete valid call:
query_to_database(query='db.users.aggregate([ {{ "$match": {{}} }}, {{ "$limit": 5 }} ])')

After the query returns results:
1. If you get an error, correct your query and try again
2. Once you have the results, use SubmitFinalAnswer to provide the final response to the user
"""

# Create the prompt template
query_gen_prompt = ChatPromptTemplate.from_messages([
    ("system", query_gen_system_prompt), 
    ("placeholder", "{messages}")
])

# Make sure we have both tools available in a single binding
all_tools = [query_to_database, SubmitFinalAnswer]
llm_with_all_tools = llm.bind_tools(all_tools)

# Bind the prompt to the LLM with all tools
query_generator = query_gen_prompt | llm_with_all_tools

In [147]:
# Example of using the query generator
query_generator.invoke({"messages":[("can you fetch the user name Rohan Gore from users?")]})

AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_m5qt', 'function': {'arguments': '{"query":"\'db.users.aggregate([ { \\"$match\\": { \\"firstName\\": \\"Rohan\\", \\"lastName\\": \\"Gore\\" } } ])\'"}', 'name': 'query_to_database'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 93, 'prompt_tokens': 4617, 'total_tokens': 4710, 'completion_time': 0.0775, 'prompt_time': 0.594773278, 'queue_time': -0.829749747, 'total_time': 0.672273278}, 'model_name': 'llama3-8b-8192', 'system_fingerprint': 'fp_dadc9d6142', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-14c0452a-3b8b-4979-8a9b-6d09f0c812d7-0', tool_calls=[{'name': 'query_to_database', 'args': {'query': '\'db.users.aggregate([ { "$match": { "firstName": "Rohan", "lastName": "Gore" } } ])\''}, 'id': 'call_m5qt', 'type': 'tool_call'}], usage_metadata={'input_tokens': 4617, 'output_tokens': 93, 'total_tokens': 4710})

In [148]:
query_generator.invoke({"messages":[("can you give the count of users?")]})

AIMessage(content='/tool-use>\n{\n    "tool_calls": [\n        {\n            "id": "pending",\n            "type": "function",\n            "function": {\n                "name": "query_to_database"\n            },\n            "parameters": {\n                "query": "\'db.users.aggregate([ { \\"$match\\": {} }, { \\"$count\\": \\"total_users_count\\" } ])\'"\n            }\n        }\n    ]\n}\n</tool-use>', additional_kwargs={}, response_metadata={'token_usage': {'completion_tokens': 86, 'prompt_tokens': 2294, 'total_tokens': 2380, 'completion_time': 0.071666667, 'prompt_time': 0.286891079, 'queue_time': 0.23272687000000003, 'total_time': 0.358557746}, 'model_name': 'llama3-8b-8192', 'system_fingerprint': 'fp_a97cfe35ae', 'finish_reason': 'stop', 'logprobs': None}, id='run-c665a030-44fb-4b0c-8c56-adde8db44c7f-0', usage_metadata={'input_tokens': 2294, 'output_tokens': 86, 'total_tokens': 2380})

# Creating the Nodes

In [289]:
"""
Purpose: This class defines the structure of the State object, which holds the conversation messages.

Attributes: It has a single attribute, messages, which is a list of messages (of type AnyMessage). 

The Annotated type suggests that there may be additional processing or validation applied to the messages.
"""
class State(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]
    
    
"""
Purpose: This function handles errors that occur during tool calls.

Parameters: It takes the state object as input.

Error Handling: It retrieves the error message from the state and the list of tool calls from the last message.

Return Value: It returns a dictionary containing messages that inform the user of the error, 
including the specific tool call that failed. Each error message is associated with the corresponding tool call ID.

Usage: This function is useful for providing feedback to users when something goes wrong during a tool invocation.
"""
def handle_tool_error(state:State):
    error = state.get("error") 
    tool_calls = state["messages"][-1].tool_calls
    
    return { "messages": [ ToolMessage(content=f"Error: {repr(error)}\n please fix your mistakes.",tool_call_id=tc["id"],) for tc in tool_calls ] }

"""
Purpose: This function creates a node that can execute a list of tools and handle errors with a fallback mechanism.

Parameters: It takes a list of tools as input.

Return Value: It returns a ToolNode that is configured to use the provided tools and includes a 
fallback to the handle_tool_error function if an error occurs during execution.

Usage: This setup allows for robust error handling in a system where multiple tools may be called, 
ensuring that users receive appropriate feedback if something goes wrong.
"""
def create_node_from_tool_with_fallback(tools:list)-> RunnableWithFallbacks[Any, dict]:
    
    # Create a simple tool lookup dictionary
    tool_map = {tool.name: tool for tool in tools}
    
    def execute_tools(state):
        messages = state["messages"]
        last_message = messages[-1]
        
        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
            return {"messages": []}
        
        results = []
        for tool_call in last_message.tool_calls:
            try:
                tool_name = tool_call["name"]
                tool = tool_map.get(tool_name)
                if not tool:
                    raise ValueError(f"Tool '{tool_name}' not found")
                
                args = tool_call.get("args", {})
                
                # Print the query if this is the query_to_database tool
                if tool_name == "query_to_database" and "query" in args:
                    print(f"\n==== EXECUTING MONGODB QUERY ====")
                    print(f"{args['query']}")
                    print("==================================\n")
                                    
                output = tool.invoke(args)
                results.append(ToolMessage(content=str(output), tool_call_id=tool_call["id"]))
            except Exception as e:
                results.append(ToolMessage(content=f"Error: {str(e)}", tool_call_id=tool_call["id"]))
        
        return {"messages": results}
    
    return RunnableLambda(execute_tools).with_fallbacks([RunnableLambda(handle_tool_error)], exception_key="error")

In [290]:
# Node Creation: This line creates a node named list_collection using the create_node_from_tool_with_fallback function.
# Tool Binding: The argument [list_collections_tool] indicates that this node will use the list_collections_tool, 
# which is presumably a tool designed to list the collection in a database.
list_collection=create_node_from_tool_with_fallback([list_collections_tool])

# Node Creation: This line creates a node named get_schema in a similar manner.
# Tool Binding: The argument [get_schema_tool] indicates that this node will use the get_schema_tool, 
# which is intended to retrieve the database schema.
get_schema=create_node_from_tool_with_fallback([get_schema_tool])

# Node Creation: This line creates a node named query_database.
# Tool Binding: The argument [query_to_database] indicates that this node will use the query_to_database tool, 
# which is responsible for executing SQL queries against the database.
query_database=create_node_from_tool_with_fallback([query_to_database])

In [291]:
"""
Purpose: This function is designed to be called when the first tool (in this case, mongodb_list_collections) is invoked in the conversation.

Parameters: It takes a state parameter, which represents the current state of the conversation.

Return Value: It returns a dictionary containing a list of messages. 
The message includes a tool call to mongodb_list_collections, which is likely intended to retrieve a list of collections from a Mongodb.
"""

def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
    """
    This function is called when the first tool is called.
    It takes the state of the conversation and returns a dictionary with the tool call and the list of messages.
    """
    return {"messages": [AIMessage(content="", tool_calls=[{"name":"mongodb_list_collections", "args":{}, "id": "tool_call_id"}])]}
    

In [292]:
"""
The primary purpose of this function is to validate or check the Mongodb query provided by the user. 
It leverages the check_generated_query mechanism, which is presumably set up to analyze Mongodb queries for correctness and provide feedback.
"""
def check_the_given_query(state: State):
    print(f"Checking the given query: {state}")
    
    """
    It invokes method on check_generated_query, passing in a dictionary that contains the last message from the state object.
    state["messages"][-1] retrieves the most recent message in the conversation, which is likely the Mongodb query that needs to be checked.
    """
    return {"messages": [check_generated_query.invoke({"messages": [state["messages"][-1]]})]}

In [293]:
# This state contains information about the current conversation, including user messages.
def generation_query(state: State):
    # calls the invoke method on query_generator, passing the current state. 
    # This generates a message that includes the mongo query based on the user's input.
    message = query_generator.invoke(state)
    
    # Sometimes, the LLM will hallucinate and call the wrong tool. We need to catch this and return an error message.
    tool_messages = [] # To collect any error messages related to tool calls.
    if message.tool_calls:
        for tc in message.tool_calls:
            # Other tools, like query_to_database, are used for different purposes 
            # (e.g., generating or executing mongo queries) and are not meant to be called directly 
            # in the context of submitting answers.
            # This ensures that the LLM only calls the SubmitFinalAnswer tool 
            # when it's ready to submit the final answer.

            if tc["name"] != "SubmitFinalAnswer":
                tool_messages.append(
                    ToolMessage(
                        content=f"""Error: The tool {tc['name']} is not valid. Please fix your mistakes. 
                        Remember to only call SubmitFinalAnswer to submit the final answer. 
                        Generated queries should be outputted WITHOUT a tool call.""",
                        tool_call_id=tc["id"]  # Correctly placed
                    )
                )    
    else:
        tool_messages = []
    
    return {"messages": [message] + tool_messages}

In [294]:
def should_continue(state: State):
    print(f"State from should_continue: {state}")
    messages = state["messages"]
    last_message = messages[-1]
    
    print(f"Last message: {last_message}")
    
    if getattr(last_message, "tool_calls", None): # checks if the last message has any associated tool calls
        print("Last message is a tool call. Ending the conversation.")
        
        return END
    elif last_message.content.startswith("Error: "):
        print("Last message is an error. regenrating the query.")
        # suggesting that the system should return to the query generation phase to correct the issue.
        return "query_gen"
    else:
        print("Last message is a normal message. Continuing the conversation.")
        # the system should proceed to the next step in the query correction or generation process.
        return "correct_query"

In [295]:
"""
the llm_get_schema function retrieves the database schema by invoking the LLM with the current 
conversation messages.
"""

def llm_get_schema(state: State):
    print(f"Getting the llm_get_schema: {state}")
    # It invoke method on llm_to_get_schema, passing the list of messages from the current state.
    # This invocation is expected to trigger the LLM to process the messages and 
    # generate a response related to the database schema.
    messages = state["messages"]
    response = llm_to_get_schema.invoke(messages) # "llm_to_get_schema" output is the schema and sample rows for those tables
    
    return {"messages": [response]}

## Building Agent Workflow

In [296]:
"""
This initializes a new StateGraph object called workflow. The StateGraph is likely a structure that 
manages different states and transitions between them based on certain conditions or events. 
The State parameter indicates that the graph will use the State type to track the current state of the conversation or process.
"""
workflow = StateGraph(State)

"""
Node Name: "first_tool_call" is the name of the node.
Function: first_tool_call is the function that will be executed when this node is reached. 
This function likely initiates the first tool call in the workflow.
"""
workflow.add_node("first_tool_call", first_tool_call) # invokes sql_db_list_tables tool from SQLDatabaseToolkit
workflow.add_node("list_collections_tool", list_collection) # get list of tables from the database
workflow.add_node("model_get_schema", llm_get_schema) # get schema of the database from the model
workflow.add_node("get_schema_tool", get_schema) # get schema of the database
workflow.add_node("query_gen", generation_query) # generate the query
workflow.add_node("correct_query", check_the_given_query) # check the given query
workflow.add_node("execute_query", query_database) # execute the query

<langgraph.graph.state.StateGraph at 0x124fc6510>

In [297]:
"""
The provided code snippet sets up the edges (or transitions) between nodes in a workflow 
represented by a StateGraph. This defines how the system moves from one state to another 
based on specific actions or conditions.
"""

# Start the workflow.
workflow.add_edge(START, "first_tool_call") # directed edge from the START node to the "first_tool_call" node

# This indicates that the next action is to list the tables in the database.
workflow.add_edge("first_tool_call", "list_collections_tool") # get list of tables from the database

# This indicates that the next action is to get the schema of the database from the model.
workflow.add_edge("list_collections_tool", "model_get_schema") # get schema of the database from the model

# responsible for interacting with the language model to retrieve the schema.
workflow.add_edge("model_get_schema", "get_schema_tool")

# This indicates that the next action is to generate the query.
workflow.add_edge("get_schema_tool", "query_gen") # generate the query

# This indicates that the next action is to check the given query.
workflow.add_conditional_edges("query_gen", should_continue, { 
    END: END,
    "correct_query": "correct_query",
    "query_gen": "query_gen"  # Add this line to handle the error case
}) # check the given query

# This indicates that the next action is to execute the query.
workflow.add_edge("correct_query", "execute_query") # execute the query

# This indicates that the next action is to generate the query.
workflow.add_edge("execute_query", "query_gen") # generate the query

<langgraph.graph.state.StateGraph at 0x124fc6510>

In [298]:
app = workflow.compile()

In [300]:
from langchain_core.messages import HumanMessage

query = {
    "messages": [
        HumanMessage(content="how many users are in users collection?")
    ]
}

response=app.invoke(query)

Getting the llm_get_schema: {'messages': [HumanMessage(content='how many users are in users collection?', additional_kwargs={}, response_metadata={}, id='256be232-30cd-4ca6-8e0b-d245f0d6f32d'), AIMessage(content='', additional_kwargs={}, response_metadata={}, id='7e2a64ce-ea78-4fe1-b8a2-c0616150fcac', tool_calls=[{'name': 'mongodb_list_collections', 'args': {}, 'id': 'tool_call_id', 'type': 'tool_call'}]), ToolMessage(content='chats, connectionrequests, payments, test_collection, users', id='235e8853-1462-4831-931e-76eb1a909fb0', tool_call_id='tool_call_id')]}
State from should_continue: {'messages': [HumanMessage(content='how many users are in users collection?', additional_kwargs={}, response_metadata={}, id='256be232-30cd-4ca6-8e0b-d245f0d6f32d'), AIMessage(content='', additional_kwargs={}, response_metadata={}, id='7e2a64ce-ea78-4fe1-b8a2-c0616150fcac', tool_calls=[{'name': 'mongodb_list_collections', 'args': {}, 'id': 'tool_call_id', 'type': 'tool_call'}]), ToolMessage(content='ch

BadRequestError: Error code: 400 - {'error': {'message': "Failed to call a function. Please adjust your prompt. See 'failed_generation' for more details.", 'type': 'invalid_request_error', 'code': 'tool_use_failed', 'failed_generation': '<tool-use>\n{\n    "tool_calls": [\n        {\n            "id": "pending",\n            "type": "function",\n            "function": {\n                "name": "query_to_database",\n                "parameters": {\n                    "query": "\'db.users.aggregate([ { \\"$match\\": { \\"firstName\\": \\"Rohan\\" } } ])\'"\n                }\n            }\n        }\n    ]\n}\n</tool-use>'}}

In [286]:
response["messages"][-1].tool_calls[0]["args"]["final_answer"]

'There are 6 users in the users collection'

In [None]:
from langchain_core.messages import HumanMessage

query = {
    "messages": [
        HumanMessage(content="find the user name Rohan Gore from users collection?")
    ]
}

response=app.invoke(query)

response["messages"][-1].tool_calls[0]["args"]["final_answer"]

Getting the llm_get_schema: {'messages': [HumanMessage(content='find the user name Rohan Gore from users collection?', additional_kwargs={}, response_metadata={}, id='2357e13c-09da-45b6-966c-b0593126eaf6'), AIMessage(content='', additional_kwargs={}, response_metadata={}, id='55793cc7-3cda-4797-a6fe-8ef83877f856', tool_calls=[{'name': 'mongodb_list_collections', 'args': {}, 'id': 'tool_call_id', 'type': 'tool_call'}]), ToolMessage(content='chats, connectionrequests, payments, test_collection, users', id='3afea25a-6872-4170-8a3f-d74ff128bfbc', tool_call_id='tool_call_id')]}
State from should_continue: {'messages': [HumanMessage(content='find the user name Rohan Gore from users collection?', additional_kwargs={}, response_metadata={}, id='2357e13c-09da-45b6-966c-b0593126eaf6'), AIMessage(content='', additional_kwargs={}, response_metadata={}, id='55793cc7-3cda-4797-a6fe-8ef83877f856', tool_calls=[{'name': 'mongodb_list_collections', 'args': {}, 'id': 'tool_call_id', 'type': 'tool_call'}]