In [2]:
import os
import sqlite3
import argparse

from pydantic import BaseModel
from typing import Optional, List
from langchain_groq import ChatGroq
from langchain_milvus import Milvus
from langgraph.checkpoint.sqlite import SqliteSaver
from langchain.schema import HumanMessage, AIMessage
from langchain_community.tools import TavilySearchResults
from langchain_huggingface import HuggingFaceEmbeddings
from langgraph.graph import StateGraph, START, END, add_messages
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder

os.environ["GROQ_API_KEY"] = "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"
os.environ["TAVILY_API_KEY"] = "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"

In [None]:
########## Defaults ##########

TEMPERATURE=0.5
LLM_MODEL_NAME="mixtral-8x7b-32768"
EMBED_MODEL_NAME="sentence-transformers/all-MiniLM-L6-v2"
MILVUS_URI="./milvus_example.db"

########## Components ##########

print("Setting up LLM...")
llm = ChatGroq(temperature=TEMPERATURE, model_name=LLM_MODEL_NAME)

print("Setting up embed model...")
embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL_NAME)

print("Setting up Milvus vector DB...")
vector_db = Milvus(
    embedding_function=embeddings,
    connection_args={"uri": MILVUS_URI},
    index_params={"index_type": "FLAT", "metric_type": "L2"},
    auto_id=True
)

print("Setting up Tavily tool...")
web_search = TavilySearchResults(
    max_results=1,
    search_depth="advanced",
    include_answer=True,
    include_raw_content=True,
    include_images=True
)

In [37]:

########## Schema ##########

class ChatState(BaseModel):
    do_rag: bool = False
    do_web_search: bool = False
    rag_context: Optional[str] = ""
    web_context: Optional[str] = ""
    curr_query: Optional[str] = ""
    curr_response: Optional[str] = ""
    messages: List[HumanMessage | AIMessage] = []

########## Nodes ##########
    
# Routing Function
def router(state: ChatState):
    print("state in router-", state)
    if state.do_rag and state.do_web_search:
        return ["rag_node", "web_search_node"]
    elif state.do_rag:
        return ["rag_node"]
    elif state.do_web_search:
        return ["web_search_node"]
    else:
        return ["llm_node"]

# RAG Node
def rag_node(state: ChatState) -> ChatState:
    print("state in rag-", state)
    """Retrieve context from vector database."""
    context = vector_db.similarity_search(state.curr_query, k=3)
    context = "\n".join([doc.page_content for doc in context])
    return {"rag_context": context}

# Web Search Node
def web_search_node(state: ChatState) -> ChatState:
    print("state in web-", state)
    """Retrieve context from web search."""
    context = web_search.run(state.curr_query)[0]['content']
    return {"web_context": context}

prompt_template = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful assistant"),
    MessagesPlaceholder("messages"),
    ("human", "{prompt}")
])

# LLM Node to handle conversation
def llm_node(state: ChatState) -> ChatState:
    """Generate response using LLM while maintaining conversation history."""

    print("state in llm-", state)

    context_parts = []  # Collect available contexts dynamically
    
    if state.rag_context:
        context_parts.append(f"RAG Context:\n{state.rag_context}")
    if state.web_context:
        context_parts.append(f"Web Context:\n{state.web_context}")
    
    # Join all available parts with spacing
    context_str = "\n\n".join(context_parts) if context_parts else "No additional context available."
    
    # Final formatted string
    prompt = f"User Query:\n{state.curr_query}\n\n{context_str}"
    
    # Extract past conversation messages and format as history
    prompt = prompt_template.format(messages=state.messages, prompt=prompt)
    
    print("llm invoke....")
    # Generate response using LLM
    #response = llm.invoke(prompt).content
    
    # Update conversation history
    state.messages.append(AIMessage(content="Hello"))

    return {"message": state.messages, "curr_response": "hello"}

In [42]:
########## Graph ##########

conn = sqlite3.connect("checkpoints.sqlite", check_same_thread=False)
memory = SqliteSaver(conn)

# Build the LangGraph workflow
graph = StateGraph(ChatState)
graph.add_node("rag_node", rag_node)
graph.add_node("web_search_node", web_search_node)
graph.add_node("llm_node", llm_node)

# Define edges
graph.add_conditional_edges(START, router)
graph.add_edge("rag_node", "llm_node")
graph.add_edge("web_search_node", "llm_node")
graph.add_edge("llm_node", END)  # Ensure LLM output leads to END
print("Compiling graph...")
graph = graph.compile(checkpointer=memory)
# graph = graph.compile()

def invoke_conversation(query, do_web_search, do_rag, thread_id):
    """Handles new and ongoing conversations based on thread_id."""
    state = ChatState(curr_query=query, do_web_search=do_web_search, do_rag=do_rag)
    config = {"configurable": {"thread_id": thread_id}}
    result = graph.invoke(state, config=config)
    return result["response"]

Compiling graph...


In [None]:
from IPython.display import Image, display
from langchain_core.runnables.graph import MermaidDrawMethod

display(
    Image(
        graph.get_graph().draw_mermaid_png(
            draw_method=MermaidDrawMethod.API,
        )
    )
)

In [44]:
print("Starting...")

state = ChatState(
    do_rag=False,
    do_web_search=False,
    curr_query="My name is yegyanathan"
)

config = {"configurable": {"thread_id": "1"}}
print("Invoking graph...")
result = graph.invoke(state, config=config)
print(result)

Starting...
Invoking graph...
state in router- do_rag=False do_web_search=False rag_context='' web_context='' curr_query='My name is yegyanathan' curr_response='hello' messages=[]
state in llm- do_rag=False do_web_search=False rag_context='' web_context='' curr_query='My name is yegyanathan' curr_response='hello' messages=[]
llm invoke....
{'do_rag': False, 'do_web_search': False, 'curr_query': 'My name is yegyanathan', 'curr_response': 'hello'}


In [36]:
state

ChatState(do_rag=False, do_web_search=False, rag_context='', web_context='', curr_query='Hi how are you?', curr_response='', messages=[])

In [3]:
conn = sqlite3.connect("checkpoints.sqlite", check_same_thread=False)

In [4]:
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = cursor.fetchall()
print(tables)  # List of table names

[('checkpoints',), ('writes',)]


In [5]:
# Get table schema
table_name = "writes"
cursor.execute(f"PRAGMA table_info({table_name})")

# Fetch and print results
columns = cursor.fetchall()
for col in columns:
    print(col)


(0, 'thread_id', 'TEXT', 1, None, 1)
(1, 'checkpoint_ns', 'TEXT', 1, "''", 2)
(2, 'checkpoint_id', 'TEXT', 1, None, 3)
(3, 'task_id', 'TEXT', 1, None, 4)
(4, 'idx', 'INTEGER', 1, None, 5)
(5, 'channel', 'TEXT', 1, None, 0)
(6, 'type', 'TEXT', 0, None, 0)
(7, 'value', 'BLOB', 0, None, 0)


In [6]:
cursor.execute(f"SELECT * from writes")

<sqlite3.Cursor at 0x13b517dc0>

In [7]:
columns = cursor.fetchall()
for col in columns:
    print(col)

('1', '', '1eff0ccf-fcc6-6bbe-bfff-a17a6d99b3d8', 'cd1f5fdc-7bda-7beb-9cd5-c839dd731933', 0, 'use_rag', 'msgpack', b'\xc2')
('1', '', '1eff0ccf-fcc6-6bbe-bfff-a17a6d99b3d8', 'cd1f5fdc-7bda-7beb-9cd5-c839dd731933', 1, 'use_web_search', 'msgpack', b'\xc3')
('1', '', '1eff0ccf-fcc6-6bbe-bfff-a17a6d99b3d8', 'cd1f5fdc-7bda-7beb-9cd5-c839dd731933', 2, 'messages', 'msgpack', b'\x91\xc7\xdf\x05\x94\xbdlangchain_core.messages.human\xacHumanMessage\x87\xa7content\xd9+Who is the current chief minister of delhi?\xb1additional_kwargs\x80\xb1response_metadata\x80\xa4type\xa5human\xa4name\xc0\xa2id\xd9$9028a7bf-43ad-4632-a84d-0259db227f96\xa7example\xc2\xb3model_validate_json')
('1', '', '1eff0ccf-fcc6-6bbe-bfff-a17a6d99b3d8', 'cd1f5fdc-7bda-7beb-9cd5-c839dd731933', 3, 'rag_context', 'msgpack', b'\xa0')
('1', '', '1eff0ccf-fcc6-6bbe-bfff-a17a6d99b3d8', 'cd1f5fdc-7bda-7beb-9cd5-c839dd731933', 4, 'web_context', 'msgpack', b'\xa0')
('1', '', '1eff0ccf-fcc6-6bbe-bfff-a17a6d99b3d8', 'cd1f5fdc-7bda-7beb-9c