In [None]:
!pip install pymongo python-dotenv langchain langchain-groq


In [1]:
import json
import os
from dotenv import load_dotenv
from pymongo import MongoClient
from langchain_groq import ChatGroq
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

In [2]:
# Load secrets
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
MONGO_URI = os.getenv("MONGO_URI")

# Connect to MongoDB
client = MongoClient(MONGO_URI)
db = client["sample_analytics"]

# Load Groq LLaMA-4 model
llm = ChatGroq(
    model="meta-llama/llama-4-scout-17b-16e-instruct",
    temperature=0,
    api_key=GROQ_API_KEY
)

In [3]:

# MongoDB schema-aware prompt
query_prompt = PromptTemplate(
    input_variables=["question", "sample"],
    template="""
You are a MongoDB AI assistant. Your task is to output ONLY a valid JSON object with:
{{
  "collection": "<collection_name>",
  "query": [ aggregation_pipeline_array ]
}}

Valid collections in the `sample_analytics` database:

1. `accounts`
   - `account_id` (int)
   - `limit` (number): amount of money in the account
   - `products` (array of strings): may include "InvestmentStock", "InvestmentFund", "Derivatives", "Commodity"
   - `customer_id` (string)

2. `customers`
   - `username` (string): register number
   - `name` (string)
   - `email` (string)
   - `birthdate` (date)
   - `address` (object): {{ "street": ..., "city": ..., "state": ..., "zip": ..., "country": ... }}
   - `accounts` (array of account_ids)

3. `transactions`
   - `_id`
   - `account_id` (int)
   - `transaction_count` (int)
   - `bucket_start_date` (ISODate)
   - `bucket_end_date` (ISODate)
   - `transactions`: array of nested transactions (up to 66)

RULES:
- Only return a valid JSON object.
- Never return explanations or extra text.
- Use $match for filters like amount > 1000.
- Use numeric filters, not "$1000" strings.

Example:
Q: List all transactions above $1000
A:
{{
  "collection": "transactions",
  "query": [
    {{ "$match": {{ "transactions.amount": {{ "$gt": 1000 }} }} }}
  ]
}}

Now answer:
Question: {question}
Sample: {sample}
"""
)


# Result summarization prompt
final_prompt = PromptTemplate(
    input_variables=["question", "results"],
    template="""
You have a MongoDB question and its query output below. Convert the results into a simple natural language answer.
Only include the answer, nothing else.

Question: {question}
Query Output: {results}
Answer:
"""
)



In [5]:
llmchain = LLMChain(llm=llm, prompt=query_prompt)

# Main loop to ask questions
while True:
    user_question = input("Ask your MongoDB question (or type 'exit' to quit): ").strip()
    if user_question.lower() in ["exit", "quit"]:
        print("Exiting...")
        break
    if not user_question:
        print("⚠️ Please enter a question.")
        continue

    # Generate aggregation pipeline from LLM
    response = llmchain.invoke({
        "question": user_question,
        "sample": "Q: Show top 5 accounts by limit.\nA: {\"collection\": \"accounts\", \"query\": [{ \"$sort\": { \"limit\": -1 } }, { \"$limit\": 5 }]}"
    })

    print("\n🔧 LLM-generated query:")
    print(response["text"])

    try:
        result = json.loads(response["text"])
        collection_name = result["collection"]
        query = result["query"]

        if collection_name not in db.list_collection_names():
            raise ValueError(f"Collection `{collection_name}` not found in MongoDB")

        collection = db[collection_name]
        results = list(collection.aggregate(query))

        if not results:
            print("⚠️ No results found. Here's a sample from the collection:")
            print(list(collection.find().limit(3)))
        else:
            final_chain = LLMChain(llm=llm, prompt=final_prompt)
            for i, result_doc in enumerate(results, 1):
                final_response = final_chain.invoke({
                    "question": user_question,
                    "results": result_doc
                })
                print(f"\n🔹 Result {i}: {final_response['text']}")

    except Exception as e:
        print("❌ Error parsing or executing the LLM output.")
        print("🔧 Raw Output:\n", response["text"])
        raise e


🔧 LLM-generated query:
{
 "collection": "accounts",
 "query": [
  { "$group": { "_id": null, "count": { "$sum": 1 } } },
  { "$project": { "_id": 0 } }
 ]
}

🔹 Result 1: There are 1746 unique accounts.

🔧 LLM-generated query:
{
 "collection": "accounts",
 "query": [
  { "$sort": { "limit": -1 } },
  { "$limit": 1 },
  { "$lookup": { "from": "customers", "localField": "customer_id", "foreignField": "username", "as": "customer" } },
  { "$unwind": "$customer" },
  { "$project": { "_id": 0, "account_id": 1, "customer_name": "$customer.name" } }
 ]
}
⚠️ No results found. Here's a sample from the collection:
[{'_id': ObjectId('5ca4bbc7a2dd94ee5816238c'), 'account_id': 371138, 'limit': 9000, 'products': ['Derivatives', 'InvestmentStock']}, {'_id': ObjectId('5ca4bbc7a2dd94ee5816238d'), 'account_id': 557378, 'limit': 10000, 'products': ['InvestmentStock', 'Commodity', 'Brokerage', 'CurrencyService']}, {'_id': ObjectId('5ca4bbc7a2dd94ee5816238e'), 'account_id': 198100, 'limit': 10000, 'product