In [None]:
# Rough Plan
# Make sure the backend is running (localhost:8080, or ping the correct URL) before running this script
# Function to do a check on backend

# Use langgraph to create a doc reviewer

# TODO: Use an agent to decide whether to retrieve or not; create basic agents for non-retrieval tasks (e.g. "what time is it?") — don't do RAG for those, use a tool instead

# Once retrieved, use an LLM to do the review, using the following example as a guide

In [None]:
import requests
from typing import Dict, Any, List, Optional
from langchain_openai import AzureChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import ToolNode
from langchain.tools.retriever import create_retriever_tool
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.messages import AIMessage
from langchain_community.vectorstores import Chroma
from pydantic import BaseModel, Field
from typing import Annotated, Sequence
from langchain_core.messages import BaseMessage

from langgraph.graph.message import add_messages
import os
import logging

import dotenv
dotenv.load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

In [None]:
# The Azure auth token is read from the AUTH_TOKEN environment variable (which
# may come from the .env file loaded by dotenv.load_dotenv() in the imports cell)
# so the secret never has to be pasted into, or saved with, the notebook.
# Falls back to an empty string when the variable is not set.
AUTH_TOKEN = os.environ.get("AUTH_TOKEN", "")

In [None]:
def fetch_context(
    query: str,
    tags: List[str],
    customer_uuid: Optional[str] = None,
    filter_data: Optional[dict] = None,
    timeout: float = 30.0,
) -> List[str]:
    """
    Fetches context from the backend.

    Sends a POST request to the backend's ``/api/context`` endpoint with the
    module-level ``AUTH_TOKEN`` as a bearer token, and returns the value of the
    ``context`` key of the JSON response.

    Args:
        query (str): The query to send to the backend.
        tags (List[str]): Tags to filter the context.
        customer_uuid (Optional[str]): Customer UUID for filtering. Sent as an
            empty string when omitted.
        filter_data (Optional[dict]): Additional filter data. Sent as an empty
            dict when omitted.
        timeout (float): Seconds to wait for the backend before giving up.

    Returns:
        List[str]: The list of relevant contexts retrieved from the backend
        (empty when the response has no ``context`` key).

    Raises:
        requests.HTTPError: If the backend answers with an error status code.
        requests.Timeout: If the backend does not answer within ``timeout`` seconds.
        ValueError: If the response body is not a JSON object, or its
            ``context`` value is not a list.
    """
    payload = {
        "query": query,
        "tags": tags,
        "customerUuid": customer_uuid or "",
        "filterData": filter_data or {}
    }
    headers = {
        "Authorization": f"Bearer {AUTH_TOKEN}",
        "Content-Type": "application/json"
    }

    # Without an explicit timeout, requests waits indefinitely and a hung
    # backend would block the notebook kernel.
    response = requests.post(
        "http://localhost:8080/api/context",
        json=payload,
        headers=headers,
        timeout=timeout,
    )
    response.raise_for_status()

    body = response.json()
    if not isinstance(body, dict):
        # Guard before calling .get(): a JSON array or scalar has no keys.
        raise ValueError("The response body must be a JSON object.")

    context_list = body.get("context", [])
    if not isinstance(context_list, list):
        raise ValueError("The 'context' key in the response must be a list.")

    return context_list

In [None]:
def grade_documents(contexts: List[str], question: str) -> List[str]:
    """
    Grades each context string for relevance to the question.

    Every context is sent, together with the question, to an Azure OpenAI chat
    model that is asked for a binary 'yes'/'no' relevance score. Contexts
    scored 'yes' are kept, in their original order; all others are dropped.

    Args:
        contexts (List[str]): List of context strings to grade.
        question (str): The user question.

    Returns:
        List[str]: List of relevant contexts that passed the grading.

    Raises:
        KeyError: If the OPENAI_API_HOST environment variable is not set
            (only when there is at least one context to grade).
    """
    if not contexts:
        # Nothing to grade, so skip building the model client entirely.
        return []

    # The grader model, prompt and chain are identical for every context, so
    # they are built once here instead of on every loop iteration.
    # NOTE(review): '2024-07-18' looks like a model version rather than an
    # Azure OpenAI API version — confirm against the endpoint being used.
    model = AzureChatOpenAI(
        temperature=0,
        azure_endpoint=os.environ["OPENAI_API_HOST"],
        azure_deployment='gpt-4o-mini',
        openai_api_version='2024-07-18',
        streaming=True
    )
    prompt = PromptTemplate(
        template="""You are a grader assessing relevance of a retrieved document to a user question. 
            Here is the retrieved document: {context}
            Here is the user question: {question}
            If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
            Give a binary score 'yes' or 'no'.""",
        input_variables=["context", "question"],
    )
    chain = prompt | model

    relevant_contexts = []
    total = len(contexts)
    for idx, context in enumerate(contexts):
        logging.info(f"Grading context {idx + 1}/{total}: {context}")

        result = chain.invoke({"context": context, "question": question})

        # Only an AIMessage carries the textual verdict we know how to read.
        if not isinstance(result, AIMessage):
            logging.error(f"Unexpected result type: {type(result)}")
            continue

        # Models often decorate the verdict ("Yes.", "'yes'", "**yes**"), so
        # compare the first word with surrounding quotes/punctuation removed
        # rather than requiring the whole reply to be exactly "yes".
        words = result.content.lower().split()
        verdict = words[0].strip("'\"`*.,:;!") if words else ""

        if verdict == "yes":
            relevant_contexts.append(context)
        else:
            logging.warning(f"Context {idx + 1} dropped: {context}")

    # Grading never reorders contexts; it can only drop some of them.
    dropped_count = total - len(relevant_contexts)
    if dropped_count:
        logging.info(f"{dropped_count} of {total} contexts were dropped during grading.")

    return relevant_contexts

In [None]:
def create_langgraph_workflow():
    """
    Creates a LangGraph workflow for document review using the backend API.

    The graph runs ``agent`` (fetch contexts from the backend), then
    ``grade_wrapper`` (relevance grading), and then routes to ``generate``
    when at least one context passed grading or to ``not_enough_context``
    otherwise. Both branches end in ``log_summary``.

    Returns:
        The compiled LangGraph workflow (the result of ``StateGraph.compile()``).
    """
    # Reply used whenever there is no usable context to answer from.
    no_context_reply = "I don't have enough relevant context to give you an answer."

    # Define the state schema
    class AgentState(BaseModel):
        messages: Annotated[Sequence[BaseMessage], add_messages]
        contexts: Optional[List[str]] = None  # Raw contexts fetched from the backend
        graded_contexts: Optional[List[str]] = None  # Contexts that passed grading
        dropped_contexts: Optional[List[str]] = None  # Contexts rejected by grading
        graded_indices: Optional[List[int]] = None  # Positions in `contexts` that passed grading
        dropped_indices: Optional[List[int]] = None  # Positions in `contexts` that were rejected
        result: Optional[str] = None  # Routing key set by grade_wrapper: "yes" or "no"

    def agent(state: AgentState):
        """
        Fetches contexts from the backend and adds them to the state.

        The first message in the state is treated as the user question.
        """
        question = state.messages[0].content
        tags = ["CHIPGPT"]

        logging.info(f"Agent received question: {question}")
        contexts = fetch_context(query=question, tags=tags)

        logging.info(f"Fetched {len(contexts)} contexts from the backend.")
        state.contexts = contexts
        return state

    def grade_wrapper(state: AgentState):
        """
        Grades the contexts and updates the state with graded contexts or a message indicating insufficient context.

        Sets ``state.result`` to "yes" when at least one context is relevant and
        to "no" otherwise; in the "no" case the fallback reply is appended to
        the messages so the downstream node has nothing left to add.
        """
        messages = state.messages
        question = messages[0].content
        contexts = state.contexts

        logging.info(f"Grading contexts for question: {question}")

        if not contexts:
            logging.info("No contexts to grade.")
            # Add the reply here as well: not_enough_context relies on the
            # message already being present in the state.
            state.messages = list(messages) + [AIMessage(content=no_context_reply)]
            state.result = "no"
            return state

        graded_contexts = grade_documents(contexts, question)

        logging.info(f"Graded contexts: {graded_contexts}")

        # Split the original positions into kept and dropped in a single pass.
        graded_indices = []
        dropped_indices = []
        dropped_contexts = []
        for position, context in enumerate(contexts):
            if context in graded_contexts:
                graded_indices.append(position)
            else:
                dropped_indices.append(position)
                dropped_contexts.append(context)

        state.graded_indices = graded_indices
        state.dropped_indices = dropped_indices
        state.dropped_contexts = dropped_contexts

        if not graded_contexts:
            logging.info("No relevant contexts found.")
            # The assistant is the one speaking, so this is an AIMessage.
            state.messages = list(messages) + [AIMessage(content=no_context_reply)]
            state.result = "no"
        else:
            state.graded_contexts = graded_contexts
            state.result = "yes"

        return state

    def get_result(state: AgentState):
        """
        Returns the result of the grading step ("yes" or "no"), used as the routing key.
        """
        return state.result

    def generate(state: AgentState):
        """
        Generates the final response using the graded contexts.

        The answer is appended to the messages as an AIMessage.
        """
        messages = state.messages
        question = messages[0].content
        docs = state.graded_contexts

        logging.debug(f"Docs received in generate: {docs} (type: {type(docs)})")

        if not docs:
            logging.info("No contexts available for generation.")
            state.messages = list(messages) + [AIMessage(content=no_context_reply)]
            return state

        docs_combined = "\n".join(docs)

        logging.info(f"Generating response for question: {question}")
        logging.info(f"Using graded contexts: {docs_combined}")

        prompt = PromptTemplate(
            template="""You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. 
                    Question: {question}  
                    Context: {context}  
                    Answer:""",
            input_variables=["context", "question"],
        )

        # NOTE(review): "2024-11-20" looks like a model version rather than an
        # Azure OpenAI API version — confirm against the endpoint being used.
        llm = AzureChatOpenAI(
            temperature=0,
            azure_endpoint=os.environ["OPENAI_API_HOST"],
            azure_deployment="gpt-4o",
            openai_api_version="2024-11-20",
            streaming=True,
        )
        rag_chain = prompt | llm | StrOutputParser()
        response = rag_chain.invoke({"context": docs_combined, "question": question})

        logging.info(f"Generated response: {response}")
        # StrOutputParser yields a plain string; wrap it so the answer is
        # stored as an assistant message rather than as a bare str.
        state.messages = list(messages) + [AIMessage(content=response)]
        return state

    def not_enough_context(state: AgentState):
        """
        Handles cases where there's not enough context.
        """
        return state  # State already has the message set in grade_wrapper

    def log_summary(state: AgentState):
        """
        Logs a summary of the workflow execution, including dropped context indices and the final order of contexts.
        """
        logging.info("Workflow Summary:")
        # `contexts` is Optional, so fall back to an empty list before len().
        logging.info(f"Original context count: {len(state.contexts or [])}")
        logging.info(f"Dropped context indices: {state.dropped_indices}")
        logging.info(f"Final graded context indices (used in response): {state.graded_indices}")
        return state

    # Use the defined schema for the state
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", agent)
    workflow.add_node("grade_wrapper", grade_wrapper)
    workflow.add_node("generate", generate)
    workflow.add_node("not_enough_context", not_enough_context)
    workflow.add_node("log_summary", log_summary)

    # Add edges
    workflow.add_edge(START, "agent")
    workflow.add_edge("agent", "grade_wrapper")

    # Route on the "yes"/"no" value that grade_wrapper stored in state.result
    workflow.add_conditional_edges("grade_wrapper", get_result, {"yes": "generate", "no": "not_enough_context"})

    workflow.add_edge("generate", "log_summary")
    workflow.add_edge("not_enough_context", "log_summary")
    workflow.add_edge("log_summary", END)

    return workflow.compile()

In [None]:
if __name__ == "__main__":
    # Create the LangGraph workflow
    graph = create_langgraph_workflow()

    # The question is passed as a HumanMessage (not a plain dict); the
    # workflow reads the first message as the user question.
    inputs = {
        "messages": [
            HumanMessage(content="What is Nomic Embed?")
        ]
    }

    # Stream the outputs from the graph, printing each item as it is yielded
    try:
        for output in graph.stream(inputs):
            print(output)
    except Exception:
        # Keep the notebook running, but log the full traceback so the
        # failure can actually be diagnosed.
        logging.exception("An error occurred while running the workflow")

In [None]:
from IPython.display import Image, display

try:
    display(Image(graph.get_graph(xray=True).draw_mermaid_png()))
except Exception as exc:
    # Rendering requires some extra dependencies and is optional, so a failure
    # is not fatal — but record why the diagram was skipped instead of hiding it.
    logging.info(f"Skipping graph visualisation: {exc}")