In [1]:
#Dependencies
# !pip install langchain langchain_openai langchain_community openai langgraph dotenv pymongo

In [2]:
from dotenv import load_dotenv
# Load variables from ../.env into the process environment; the cells below read
# their MONGODB_* and LLM_* settings through os.getenv.
load_dotenv("../.env")

True

In [None]:
import json
import os
import re
from datetime import datetime
from decimal import Decimal
from itertools import islice
from typing import Optional, Dict, List, Union, Any

from bson import ObjectId
from bson.decimal128 import Decimal128
from langchain_core.tools import tool
from pydantic import BaseModel, Field, field_validator
from pymongo import MongoClient
from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError, OperationFailure

# Enhanced helper to convert various BSON types to JSON-serializable formats
def serialize_document(doc: Any) -> Any:
    """Recursively convert a MongoDB document into JSON-serializable values.

    Conversions applied:
      * dict       -> dict with every value converted (keys are left untouched)
      * list       -> list with every item converted
      * datetime   -> ISO-8601 string via ``isoformat()``
      * Decimal    -> float (precision may be lost)
      * ObjectId   -> str
      * Decimal128 -> float, via ``to_decimal()`` (precision may be lost)

    Any other value is returned unchanged.
    """
    if isinstance(doc, dict):
        return {key: serialize_document(value) for key, value in doc.items()}
    if isinstance(doc, list):
        return [serialize_document(item) for item in doc]
    # Fast path: plain JSON scalars are by far the most common leaves and
    # never need conversion, so skip the BSON type checks for them.
    if doc is None or isinstance(doc, (str, int, float, bool)):
        return doc
    if isinstance(doc, datetime):
        return doc.isoformat()
    if isinstance(doc, Decimal):
        return float(doc)
    if isinstance(doc, ObjectId):
        return str(doc)
    if isinstance(doc, Decimal128):
        # PyMongo decodes BSON decimal128 fields to bson.decimal128.Decimal128,
        # not decimal.Decimal, so it needs its own branch.
        return float(doc.to_decimal())
    return doc


class MongoQueryInput(BaseModel):
    """Arguments accepted by the MongoDB extraction tool."""

    uri: str = Field(..., description="MongoDB connection URI")
    db_name: str = Field(..., description="Name of the MongoDB database")
    collection_name: str = Field(..., description="Name of the collection")
    query: Optional[Dict] = Field(default_factory=dict, description="MongoDB query filter")
    projection: Optional[Dict] = Field(default=None, description="Fields to include or exclude")
    limit: int = Field(default=10, description="Maximum number of documents to retrieve")
    use_aggregation: bool = Field(default=False, description="Use aggregation pipeline instead of find()")
    # validate_default=True: pydantic skips field validators for values that come
    # from the default, so without it `use_aggregation=True` with `pipeline`
    # omitted passed validation while an explicit `pipeline=None` was rejected.
    pipeline: Optional[List[Dict]] = Field(default=None, validate_default=True, description="Aggregation pipeline if use_aggregation is True")

    @field_validator('db_name', 'collection_name')
    @classmethod
    def validate_names(cls, v):
        """Allow only letters, digits, hyphens and underscores in database/collection names."""
        if not re.match(r'^[a-zA-Z0-9_-]+$', v):
            raise ValueError(f"Invalid name: {v}. Only alphanumeric characters, hyphens, and underscores allowed.")
        return v

    @field_validator('limit')
    @classmethod
    def validate_limit(cls, v):
        """Reject limits outside the inclusive range 1..1000."""
        if v <= 0 or v > 1000:
            raise ValueError("Limit must be between 1 and 1000")
        return v

    @field_validator('pipeline')
    @classmethod
    def validate_pipeline(cls, v, info):
        """Validate pipeline when aggregation is used."""
        # `use_aggregation` is declared before `pipeline`, so it is already in
        # info.data here (unless it failed its own validation, in which case the
        # .get default applies).
        use_aggregation = info.data.get('use_aggregation', False)
        if use_aggregation and not v:
            raise ValueError("Pipeline is required when use_aggregation is True")
        return v

class MongoQueryOutput(BaseModel):
    """Result envelope returned by the MongoDB extraction tool, for both successes and failures."""
    # Serialized documents returned by the query; empty on failure.
    documents: List[Dict]
    # Number of entries in `documents`.
    count: int
    # Count of all documents matching the filter; left as None for aggregation runs and failures.
    total_available: Optional[int] = Field(default=None, description="Total documents matching query (for non-aggregation queries)")
    success: bool = Field(default=True, description="Whether the operation was successful")
    error_message: Optional[str] = Field(default=None, description="Error message if operation failed")

def _mongo_failure(message: str) -> MongoQueryOutput:
    """Build the uniform failure payload (no documents, success=False) carrying *message*."""
    return MongoQueryOutput(documents=[], count=0, success=False, error_message=message)


@tool
def extract_mongo_data(input: MongoQueryInput) -> MongoQueryOutput:
    """
    Extract data from a MongoDB collection using a query or aggregation pipeline.

    This tool provides flexible MongoDB querying capabilities with proper error handling
    and BSON serialization. It supports both regular find() queries and aggregation pipelines.
    At most `limit` documents are returned in either mode. Failures are reported through
    `success=False` and `error_message` instead of being raised.
    """
    # NOTE: `input` shadows the builtin, but it is the tool's public argument
    # name and therefore has to stay.
    client = None
    try:
        # Create client with timeout settings
        client = MongoClient(
            input.uri,
            serverSelectionTimeoutMS=5000,
            connectTimeoutMS=5000,
            socketTimeoutMS=5000,
        )

        # Test the connection up front so connection problems surface here
        # rather than in the middle of the query.
        client.admin.command('ping')

        collection = client[input.db_name][input.collection_name]

        if input.use_aggregation:
            if not input.pipeline:
                raise ValueError("Pipeline is required when use_aggregation is True")

            # Enforce `limit` client-side: appending a $limit stage would be
            # invalid after a terminal stage, and leaving the cursor unbounded
            # would ignore the documented maximum.
            cursor = islice(collection.aggregate(input.pipeline), input.limit)
            total_available = None  # Aggregation doesn't have a simple count
        else:
            # Regular find query
            query = input.query or {}

            # Get total count for pagination info
            total_available = collection.count_documents(query)

            # Execute find with projection and limit
            cursor = collection.find(query, input.projection).limit(input.limit)

        results = [serialize_document(doc) for doc in cursor]

        return MongoQueryOutput(
            documents=results,
            count=len(results),
            total_available=total_available,
            success=True,
        )

    # ServerSelectionTimeoutError is a subclass of ConnectionFailure, so it
    # must be handled first or its branch can never run.
    except ServerSelectionTimeoutError as e:
        return _mongo_failure(f"MongoDB server selection timeout: {str(e)}")

    except ConnectionFailure as e:
        return _mongo_failure(f"Failed to connect to MongoDB: {str(e)}")

    except OperationFailure as e:
        return _mongo_failure(f"MongoDB operation failed: {str(e)}")

    except ValueError as e:
        return _mongo_failure(f"Invalid input: {str(e)}")

    except Exception as e:
        # Tool boundary: report anything unexpected to the agent instead of
        # crashing the whole run.
        return _mongo_failure(f"Unexpected error: {str(e)}")

    finally:
        # Ensure connection is closed
        if client is not None:
            client.close()

In [None]:
from langchain.agents import create_openai_functions_agent, AgentExecutor
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
# from tools.mongo_tool import extract_mongo_data

# Deterministic model (temperature 0) with the MongoDB extraction tool as its only tool.
llm = ChatOpenAI(model="gpt-4", temperature=0)
tools = [extract_mongo_data]

prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful database manager that can query MongoDB databases and provide information. Use the extract_mongo_data tool when users ask about database queries."),
    ("user", "{input}"),
    # The scratchpad carries the agent's earlier function calls and tool results
    # as real messages; a string template here would flatten them into text
    # inside a single assistant message.
    MessagesPlaceholder(variable_name="agent_scratchpad"),
])

agent = create_openai_functions_agent(llm, tools, prompt)
# verbose=True prints every tool invocation together with its arguments.
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

def query_database_with_llm(user_question: str):
    """Answer *user_question* by letting the agent decide how to query MongoDB.

    Connection details and the collection description are read from the
    environment (with local fallbacks) and embedded in the prompt handed to
    the agent executor. Returns the executor's raw response mapping.
    """
    # NOTE(review): the URI is interpolated into the prompt, so any credentials
    # embedded in it are sent to the LLM and echoed by the executor's verbose log.
    mongo_uri, mongo_db, mongo_collection, collection_schema = (
        os.getenv(variable, fallback)
        for variable, fallback in (
            ('MONGODB_URI', 'mongodb://localhost:27017'),
            ('MONGODB_DB_NAME', 'MCQDB'),
            ('MONGODB_COLLECTION_NAME', 'MigratoryData'),
            ('MONGODB_COLLECTION_SCHEMA', '{}'),
        )
    )

    agent_request = f"""
        I need to query a MongoDB database. Here are the connection details:
        - URI: {mongo_uri}
        - Database: {mongo_db}
        - Collection: {mongo_collection}

        Details about the collection:
        {collection_schema}

        Important query constraints:
        - The limit must be between 1 and 10 documents
        - Default limit is 10 if not specified

        User question: {user_question}
        
        Please use the extract_mongo_data tool to answer this question.

        Output Format: Do not show all the content of the documents, just provide the metadata from returned documents.
        """
    return agent_executor.invoke({"input": agent_request})


if __name__ == "__main__":
    # Example 1: Agent-based approach
    print("=== Agent-based Query ===")
    # An empty answer falls back to a sample question.
    agent_mongo_query = (
        input("Enter your question for the MongoDB database: ")
        or "Return unique chapter names"
    )
    result = query_database_with_llm(agent_mongo_query)
    print(result['output'])
    

=== Agent-based Query ===


[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `extract_mongo_data` with `{'input': {'uri': 'mongodb+srv://shishir:shishir@cluster0.wumztey.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0', 'db_name': 'MCQDB', 'collection_name': 'MigratoryData', 'query': {'chapter_name': 'Life Processes'}, 'projection': {'_id': 1, 'topic_id': 1, 'topic_name': 1, 'grade': 1, 'board': 1, 'subject': 1, 'chapter_name': 1, 'publication': 1, 'status': 1, 'qid': 1, 'lo': 1, 'lu': 1, 'isClone': 1, 'correct_option': 1, 'lu_name': 1, 'loid': 1, 'luid': 1}, 'limit': 10, 'use_aggregation': False, 'pipeline': None}}`


[0m[36;1m[1;3mdocuments=[{'_id': '678631e9a9572dda48204eaf', 'topic_id': '10_biology_life processes_2', 'topic_name': 'Autotrophic Nutrition', 'chapter_name': 'Life Processes', 'status': 'pushed', 'subject': 'biology', 'qid': 7237, 'board': 'CBSE', 'lo': '(LU006446.001); Default - Remember', 'lu': 'LU006446 : Autotrophic mode of nutritio

## Chat-Mongo

In [None]:
from langchain.agents import create_openai_functions_agent, AgentExecutor
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.memory import ConversationBufferMemory

# Conversation history is kept as message objects under the "chat_history" key,
# which the prompt below consumes through a MessagesPlaceholder.
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True
)

prompt = ChatPromptTemplate.from_messages([
    ("system", """You are a helpful and friendly database assistant who can both engage in general conversation 
    and help with MongoDB database queries when needed. Only use the extract_mongo_data tool when users 
    specifically ask for database information or queries. Otherwise, engage in normal conversation.
    
    When using the database:
    - Ensure queries are relevant and necessary
    - Keep limits between 1 and 10 documents
    - Format database results concisely
    """),
    MessagesPlaceholder(variable_name="chat_history"),
    ("user", "{input}"),
    # The scratchpad carries the agent's earlier function calls and tool results
    # as real messages; a string template here would flatten them into text
    # inside a single assistant message.
    MessagesPlaceholder(variable_name="agent_scratchpad"),
])

# Model name and temperature are configurable through the environment.
model_name = os.getenv('LLM_MODEL', 'gpt-4')
model_temperature = float(os.getenv('LLM_TEMPERATURE', '0.0'))
llm = ChatOpenAI(model=model_name, temperature=model_temperature)
tools = [extract_mongo_data]

agent = create_openai_functions_agent(llm, tools, prompt)
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    memory=memory,
    verbose=True
)

def chat_with_agent(user_input: str):
    """Send one conversational turn to the agent and return its text reply.

    The user's message is wrapped with the database context (URI, database,
    collection and schema, all read from the environment with local fallbacks)
    before being passed to the agent executor.
    """
    mongo_uri, mongo_db, mongo_collection, collection_schema = (
        os.getenv(variable, fallback)
        for variable, fallback in (
            ('MONGODB_URI', 'mongodb://localhost:27017'),
            ('MONGODB_DB_NAME', 'MCQDB'),
            ('MONGODB_COLLECTION_NAME', 'MigratoryData'),
            ('MONGODB_COLLECTION_SCHEMA', '{}'),
        )
    )

    enhanced_input = f"""
    Database Context:
    - URI: {mongo_uri}
    - Database: {mongo_db}
    - Collection: {mongo_collection}
    - Schema: {collection_schema}

    User question: {user_input}
    """

    # Only the final answer is surfaced; the rest of the response mapping is dropped.
    return agent_executor.invoke({"input": enhanced_input})['output']


if __name__ == "__main__":
    print("Chat with the AI (type 'exit' to quit)")

    # Keep taking turns until the user types "exit" (in any letter case).
    while (user_input := input("\nYou: ")).lower() != 'exit':
        response = chat_with_agent(user_input)
        print(f"\nAssistant: {response}")

  memory = ConversationBufferMemory(


Chat with the AI (type 'exit' to quit)


[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mHello! How can I assist you today?[0m

[1m> Finished chain.[0m

Assistant: Hello! How can I assist you today?


[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mHello! How can I assist you today?[0m

[1m> Finished chain.[0m

Assistant: Hello! How can I assist you today?


[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mNice to meet you, Shishir! How can I assist you today?[0m

[1m> Finished chain.[0m

Assistant: Nice to meet you, Shishir! How can I assist you today?


[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mThe seven wonders of the world are:

1. The Great Wall of China
2. Petra, Jordan
3. Christ the Redeemer in Rio de Janeiro, Brazil
4. Machu Picchu in Peru
5. The Chichen Itza Pyramid in Mexico
6. The Roman Colosseum in Italy
7. The Taj Mahal in India

These are known as the "New Seven Wonders of the World," as determined by a global poll