In [1]:
# %%
%pip install langgraph openai chromadb pydantic tiktoken langchain

Collecting langchain
  Downloading langchain-0.3.25-py3-none-any.whl.metadata (7.8 kB)
Collecting langchain-text-splitters<1.0.0,>=0.3.8 (from langchain)
  Downloading langchain_text_splitters-0.3.8-py3-none-any.whl.metadata (1.9 kB)
Collecting SQLAlchemy<3,>=1.4 (from langchain)
  Downloading sqlalchemy-2.0.41-cp39-cp39-win_amd64.whl.metadata (9.8 kB)
Collecting async-timeout<5.0.0,>=4.0.0 (from langchain)
  Downloading async_timeout-4.0.3-py3-none-any.whl.metadata (4.2 kB)
Collecting greenlet>=1 (from SQLAlchemy<3,>=1.4->langchain)
  Downloading greenlet-3.2.3-cp39-cp39-win_amd64.whl.metadata (4.2 kB)
Downloading langchain-0.3.25-py3-none-any.whl (1.0 MB)
   ---------------------------------------- 0.0/1.0 MB ? eta -:--:--
   ------------------------------- -------- 0.8/1.0 MB 4.8 MB/s eta 0:00:01
   ---------------------------------------- 1.0/1.0 MB 4.8 MB/s eta 0:00:00
Downloading async_timeout-4.0.3-py3-none-any.whl (5.7 kB)
Downloading langchain_text_splitters-0.3.8-py3-none-any

In [None]:
# %%
import os
from getpass import getpass

# Reuse a key that is already exported in the environment. Only when none is
# set, prompt for it without echoing, so the secret never has to be typed into
# (and saved with) the notebook source.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("Enter your OpenAI API Key: ")


In [None]:
# %%
import json
from openai import OpenAI
import chromadb
from chromadb.config import Settings
from openai import OpenAI
from openai import embeddings
from uuid import uuid4

# Load KB
# JSON text is UTF-8; pass the encoding explicitly so the load does not depend
# on the platform's locale default (e.g. cp1252 on Windows).
with open("self_critique_loop_dataset.json", "r", encoding="utf-8") as f:
    kb_data = json.load(f)

# Initialize Chroma
# Client is created with anonymized telemetry disabled; the collection is
# fetched if it already exists under this name, otherwise created.
chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
collection = chroma_client.get_or_create_collection(name="kb_index")

# Embed and upsert
from openai import OpenAI

# No arguments are passed, so the client relies on its default configuration
# (the OPENAI_API_KEY environment variable set earlier in the notebook).
client = OpenAI()

def embed(text, model="text-embedding-3-small"):
    """Return the embedding vector for ``text``.

    Args:
        text: Text to embed; forwarded as ``input`` of the embeddings request.
        model: Name of the embedding model. Defaults to the model that was
            previously hard-coded, so existing ``embed(text)`` calls behave
            exactly as before.

    Returns:
        The ``embedding`` attribute of the first item in the response's
        ``data`` list.
    """
    response = client.embeddings.create(model=model, input=text)
    # A single input was sent, so only the first data item is relevant.
    return response.data[0].embedding

# Index every KB entry, one upsert per document: the snippet text is both the
# embedded content and the stored document, while source and last_updated are
# attached as metadata.
for doc in kb_data:
    snippet = doc["answer_snippet"]
    vector = embed(snippet)
    metadata = {
        "source": doc["source"],
        "last_updated": doc["last_updated"],
    }
    collection.upsert(
        ids=[doc["doc_id"]],
        embeddings=[vector],
        documents=[snippet],
        metadatas=[metadata],
    )


In [None]:
# %%
from langgraph.graph import StateGraph, END
from typing import TypedDict, List, Dict
from langchain_core.runnables import Runnable
from langchain_core.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI

# LangGraph State
class RAGState(TypedDict):
    """State dictionary handed from node to node in the RAG workflow.

    Each node function in this notebook receives the state and returns a copy
    of it with one more key filled in.
    """
    # The question being answered; read by every node function.
    user_question: str
    # Retrieved snippets; retrieve_kb stores dicts with the keys
    # "doc_id", "answer_snippet" and "source".
    kb_hits: List[Dict]
    # First-pass answer text, written by generate_answer.
    initial_answer: str
    # Stripped model reply written by critique_answer; the critique prompt asks
    # for either "COMPLETE" or "REFINE: <keywords>".
    critique_result: str
    # Written by refine_answer: the revised answer, or the initial answer
    # unchanged when the critique does not contain "REFINE".
    refined_answer: str


In [None]:
# %%
def retrieve_kb(state: RAGState) -> RAGState:
    """Fetch the KB snippets closest to the user's question.

    Embeds ``state["user_question"]`` and queries the Chroma collection for up
    to five nearest entries.

    Args:
        state: Current workflow state; only ``user_question`` is read.

    Returns:
        A copy of ``state`` with ``kb_hits`` set to a list of dicts holding
        ``doc_id``, ``answer_snippet`` and ``source``, in the order the query
        returned them. ``source`` is ``None`` when an entry has no metadata or
        no ``source`` key.
    """
    query_vector = embed(state["user_question"])
    results = collection.query(query_embeddings=[query_vector], n_results=5)

    # A single query embedding was sent, so every result field carries exactly
    # one inner list; walk the three parallel lists together.
    kb_hits = [
        {
            "doc_id": doc_id,
            "answer_snippet": snippet,
            # Entries stored without metadata come back as None; do not crash on them.
            "source": (metadata or {}).get("source"),
        }
        for doc_id, snippet, metadata in zip(
            results["ids"][0],
            results["documents"][0],
            results["metadatas"][0],
        )
    ]

    return {**state, "kb_hits": kb_hits}


In [None]:
# %%
# Chat model used by the generate, critique and refine steps (temperature 0).
llm = ChatOpenAI(temperature=0, model="gpt-4")

# Template text for the first-pass answer.
# Placeholders: {user_question}, {kb_context}.
_ANSWER_TEMPLATE = """
You are a software best-practices assistant.

User Question:
{user_question}

Retrieved Snippets:
{kb_context}

Task:
Based on these snippets, write a concise answer to the user’s question.
Cite each snippet you use by its doc_id in square brackets (e.g., [KB004]).
Return only the answer text.
"""

prompt_gen = PromptTemplate.from_template(_ANSWER_TEMPLATE)

def generate_answer(state: RAGState) -> RAGState:
    """Draft a first answer from the retrieved snippets.

    Renders every hit in ``state["kb_hits"]`` as ``[doc_id] snippet`` (one per
    line), fills ``prompt_gen`` with that context and the user question, and
    sends the prompt to ``llm``.

    Returns:
        A copy of ``state`` with ``initial_answer`` set to the model reply's
        ``content``.
    """
    snippet_lines = []
    for hit in state["kb_hits"]:
        snippet_lines.append(f"[{hit['doc_id']}] {hit['answer_snippet']}")
    kb_context = "\n".join(snippet_lines)

    prompt_text = prompt_gen.format(
        kb_context=kb_context,
        user_question=state["user_question"],
    )
    reply = llm.invoke(prompt_text)

    updated_state = dict(state)
    updated_state["initial_answer"] = reply.content
    return updated_state


In [None]:
# %%
# Template text for the critique step.
# Placeholders: {user_question}, {initial_answer}, {kb_context}.
# The reply format requested here ("COMPLETE" / "REFINE: ...") is what the
# later steps look for when deciding whether to refine.
_CRITIQUE_TEMPLATE = """
You are a critical QA assistant. The user asked: {user_question}

Initial Answer:
{initial_answer}

KB Snippets:
{kb_context}

Task:
Determine if the initial answer fully addresses the question using only these snippets.
- If it does, respond exactly: COMPLETE
- If it misses any point or cites missing info, respond: REFINE: <short list of missing topic keywords>

Return exactly one line.
"""

prompt_critique = PromptTemplate.from_template(_CRITIQUE_TEMPLATE)

def critique_answer(state: RAGState) -> RAGState:
    """Ask the model to judge the initial answer against the retrieved snippets.

    Formats ``prompt_critique`` with the user question, the initial answer and
    the ``[doc_id] snippet`` listing of ``state["kb_hits"]``, then invokes
    ``llm`` with it.

    Returns:
        A copy of ``state`` with ``critique_result`` set to the model reply's
        ``content`` with surrounding whitespace stripped.
    """
    kb_context = "\n".join(
        f"[{hit['doc_id']}] {hit['answer_snippet']}" for hit in state["kb_hits"]
    )
    prompt_text = prompt_critique.format(
        kb_context=kb_context,
        initial_answer=state["initial_answer"],
        user_question=state["user_question"],
    )
    verdict = llm.invoke(prompt_text).content.strip()

    updated_state = dict(state)
    updated_state["critique_result"] = verdict
    return updated_state


In [None]:
# %%
# Template text for the refinement step.
# Placeholders: {user_question}, {initial_answer}, {critique_result},
# {extra_doc_id}, {extra_snippet}.
_REFINE_TEMPLATE = """
You are a software best-practices assistant refining your answer. The user asked: {user_question}

Initial Answer:
{initial_answer}

Critique: {critique_result}

Additional Snippet:
[{extra_doc_id}] {extra_snippet}

Task:
Incorporate this snippet into the answer, covering the missing points.
Cite any snippet you use by doc_id in square brackets.
Return only the final refined answer.
"""

prompt_refine = PromptTemplate.from_template(_REFINE_TEMPLATE)

def refine_answer(state: RAGState) -> RAGState:
    """Revise the initial answer with one extra KB snippet aimed at the critique.

    When ``state["critique_result"]`` does not contain ``"REFINE"`` the initial
    answer is passed through unchanged. Otherwise the text following the
    ``REFINE`` marker is appended to the user question, the combined query is
    embedded, and the best-matching snippet that is not already in
    ``state["kb_hits"]`` is given to the model together with the initial answer
    and the critique.

    Args:
        state: Workflow state; reads ``user_question``, ``initial_answer``,
            ``critique_result`` and (if present) ``kb_hits``.

    Returns:
        A copy of ``state`` with ``refined_answer`` set to the model's revised
        answer, or to ``initial_answer`` when no refinement is requested or the
        collection returns no snippet at all.
    """
    critique = state["critique_result"]
    if "REFINE" not in critique:
        return {**state, "refined_answer": state["initial_answer"]}

    # Keep only what follows the marker so that any text the model put in front
    # of "REFINE" does not leak into the retrieval query.
    missing_keywords = critique.partition("REFINE")[2].lstrip(":").strip()
    new_query = f"{state['user_question']} and information on {missing_keywords}"
    query_vector = embed(new_query)

    # The refined query is close to the original one, so its single best match
    # is usually a snippet the initial answer already used. Request one result
    # more than was already retrieved so a genuinely new snippet can be found.
    seen_ids = {hit["doc_id"] for hit in state.get("kb_hits", [])}
    results = collection.query(
        query_embeddings=[query_vector],
        n_results=len(seen_ids) + 1,
    )
    candidates = list(zip(results["ids"][0], results["documents"][0]))
    if not candidates:
        # Nothing can be added; keep the initial answer instead of raising IndexError.
        return {**state, "refined_answer": state["initial_answer"]}

    extra_doc_id, extra_snippet = next(
        (pair for pair in candidates if pair[0] not in seen_ids),
        candidates[0],  # every candidate was already retrieved: use the best match
    )

    formatted_prompt = prompt_refine.format(
        user_question=state["user_question"],
        initial_answer=state["initial_answer"],
        critique_result=critique,
        extra_doc_id=extra_doc_id,
        extra_snippet=extra_snippet,
    )
    response = llm.invoke(formatted_prompt)
    return {**state, "refined_answer": response.content}


In [None]:
# %%
# Assemble the graph: register the four step functions as nodes, then wire the
# fixed path retrieve_kb -> generate_answer -> critique_answer.
workflow = StateGraph(RAGState)

for node_name, node_fn in (
    ("retrieve_kb", retrieve_kb),
    ("generate_answer", generate_answer),
    ("critique_answer", critique_answer),
    ("refine_answer", refine_answer),
):
    workflow.add_node(node_name, node_fn)

workflow.set_entry_point("retrieve_kb")
for upstream, downstream in (
    ("retrieve_kb", "generate_answer"),
    ("generate_answer", "critique_answer"),
):
    workflow.add_edge(upstream, downstream)

def decide_next(state: RAGState):
    """Pick the step after the critique.

    Returns ``"refine_answer"`` when ``state["critique_result"]`` contains
    ``"REFINE"``; otherwise returns ``END``.
    """
    if "REFINE" in state["critique_result"]:
        return "refine_answer"
    return END

# Branch after the critique: decide_next's return value is looked up in this
# mapping to choose between the refine step and the end of the graph.
critique_routes = {
    "refine_answer": "refine_answer",
    END: END,
}
workflow.add_conditional_edges("critique_answer", decide_next, critique_routes)

# The refine step is the last node on its branch.
workflow.add_edge("refine_answer", END)

app = workflow.compile()


In [None]:
# %%
# Run the compiled graph on a sample question and print each stage's output.
question = "What are best practices for caching?"

final_output = app.invoke({"user_question": question})

print(" Initial KB Hits:")
for hit in final_output["kb_hits"]:
    print(f"- [{hit['doc_id']}] {hit['answer_snippet']}")

print(" Initial Answer:")
print(final_output["initial_answer"])

print("\n Critique Result:")
print(final_output["critique_result"])

# The refined answer is only shown (and used as the final answer) when the
# critique contains "REFINE"; otherwise the initial answer is final.
was_refined = "REFINE" in final_output["critique_result"]
if was_refined:
    print("\n Refined Answer:")
    final_answer = final_output["refined_answer"]
    print(final_answer)
else:
    final_answer = final_output["initial_answer"]

print("\n Final Answer:")
print(final_answer)
