In [1]:
from pymongo import MongoClient
from langchain_groq import ChatGroq
from langchain_community.embeddings import HuggingFaceEmbeddings
import os
from dotenv import load_dotenv
load_dotenv()

MONGO_URI = os.getenv("MONGO_URI")
MONGO_DB_NAME = os.getenv("MONGO_DB_NAME")
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")

In [2]:

mongo = MongoClient(MONGO_URI)
db = mongo[MONGO_DB_NAME]
bills_col = db.bills


In [None]:
from pinecone import Pinecone, ServerlessSpec
index_name = "bills-rag"
pinecone_client = Pinecone(
        api_key=os.getenv("PINECONE_API_KEY"),
    )

if index_name not in pinecone_client.list_indexes().names():
    pinecone_client.create_index(
        name=index_name,
        dimension=384,
        metric="dotproduct",
        spec=ServerlessSpec(
            cloud="aws",
            region=os.getenv("PINECONE_REGION"),
        ),
    )

In [4]:
pcIndex = pinecone_client.Index(index_name)
pcIndex

<pinecone.db_data.index.Index at 0x1cb67ec1d90>

In [5]:
groq_api_key = os.getenv("GROQ_API_KEY")
groq_llm = ChatGroq(
    api_key=groq_api_key,
    model_name="llama-3.1-8b-instant",
    temperature=0,
)
groq_llm

ChatGroq(profile={'max_input_tokens': 131072, 'max_output_tokens': 8192, 'image_inputs': False, 'audio_inputs': False, 'video_inputs': False, 'image_outputs': False, 'audio_outputs': False, 'video_outputs': False, 'reasoning_output': False, 'tool_calling': True}, client=<groq.resources.chat.completions.Completions object at 0x000001CB6913E590>, async_client=<groq.resources.chat.completions.AsyncCompletions object at 0x000001CB694CF3D0>, model_name='llama-3.1-8b-instant', temperature=1e-08, model_kwargs={}, groq_api_key=SecretStr('**********'))

In [6]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from pydantic import BaseModel


In [7]:
class QueryPlan(BaseModel):
    type: str
    operation: str
    entities: dict
    filters: dict | None
    time_range: dict | None
    needs_rag: bool


In [8]:
classifier_prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a query planner for a bill management app. Output ONLY valid JSON."),
    ("human", "{query}")
])
classifier_chain = (
    classifier_prompt
    | groq_llm
    | JsonOutputParser(pydantic_object=QueryPlan)
)
classifier_chain


ChatPromptTemplate(input_variables=['query'], input_types={}, partial_variables={}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], input_types={}, partial_variables={}, template='You are a query planner for a bill management app. Output ONLY valid JSON.'), additional_kwargs={}), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['query'], input_types={}, partial_variables={}, template='{query}'), additional_kwargs={})])
| ChatGroq(profile={'max_input_tokens': 131072, 'max_output_tokens': 8192, 'image_inputs': False, 'audio_inputs': False, 'video_inputs': False, 'image_outputs': False, 'audio_outputs': False, 'video_outputs': False, 'reasoning_output': False, 'tool_calling': True}, client=<groq.resources.chat.completions.Completions object at 0x000001CB6913E590>, async_client=<groq.resources.chat.completions.AsyncCompletions object at 0x000001CB694CF3D0>, model_name='llama-3.1-8b-instant', temperature=1e-08, model_kwargs={}, groq_api_key=S

In [22]:
def run_mongo_pipeline(pipeline):
    return list(bills_col.aggregate(pipeline))


from backend.templates.query_templates import QUERY_TEMPLATES

def execute_mongo(plan, user_id):
    template_fn = QUERY_TEMPLATES[plan.operation]
    pipeline = template_fn(user_id, plan)
    return run_mongo_pipeline(pipeline)


In [11]:
from langchain_community.embeddings import HuggingFaceEmbeddings
embeddings_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True}
)

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
def vector_search(query, user_id, category=None, top_k=5):
    vector = embeddings_model.embed_query(query)

    index_filter = {"user_id": user_id}
    if category:
        index_filter["category"] = category

    results = pcIndex.query(
        vector=vector,
        top_k=top_k,
        filter=index_filter,
        include_metadata=True
    )

    return [match["metadata"]["text"] for match in results["matches"]]


In [14]:
def semantic_chain(plan, user_query, user_id):
    context = vector_search(
        query=user_query,
        user_id=user_id,
        category=plan.entities.get("category")
    )

    prompt = f"""
    Answer the question using the following bill context:

    {context}

    Question: {user_query}
    """

    return groq_llm.invoke(prompt).content


In [16]:
def mixed_chain(plan, user_query, user_id):
    mongo_result = execute_mongo(plan, user_id)

    facts = mongo_result[0]
    context = vector_search(
        query=user_query,
        user_id=user_id,
        category=plan.entities.get("category")
    )

    prompt = f"""
    Facts:
    {facts}

    Bill Details:
    {context}

    Answer the user question clearly.
    """

    return groq_llm.invoke(prompt).content


In [17]:
def query_router(user_query, user_id):
    plan = classifier_chain.invoke({"query": user_query})

    if plan.type in ["FILTER", "AGGREGATION"]:
        return execute_mongo(plan, user_id)

    if plan.type == "SEMANTIC":
        return semantic_chain(plan, user_query, user_id)

    if plan.type == "MIXED":
        return mixed_chain(plan, user_query, user_id)

    raise ValueError("Unsupported query type")


In [None]:
# response = query_router(
#     "How much did I spend on medical bills last year and what were the treatments?",
#     user_id="u1"
# )

# print(response)


AttributeError: 'dict' object has no attribute 'type'