In [1]:
import os
from dotenv import load_dotenv

load_dotenv('/home/TomKerby/Research/lit_review/.env', override=True)
os.environ["LANGCHAIN_PROJECT"] = "react_agent"

In [2]:
from langchain_ollama import ChatOllama
from langchain_huggingface import ChatHuggingFace
from langchain_openai import ChatOpenAI

def get_model(model_type: str = "ollama", **kwargs):
    """Factor for different model types"""
    model_map = {
        "ollama": ChatOllama,
        "huggingface": ChatHuggingFace,
        "openai": ChatOpenAI
    }
    return model_map[model_type](**kwargs)

model = get_model(
    model_type="ollama",
    model="llama3.3",
    temperature=0.7
)

# Retriever Tool

In [3]:
from langchain.tools.retriever import create_retriever_tool
from langchain_community.retrievers import WikipediaRetriever, ArxivRetriever

wiki_retriever = WikipediaRetriever()
arxiv_retriever = ArxivRetriever()

# wiki_retriever_tool = create_retriever_tool(
#     wiki_retriever,
#     "retrieve_wikipedia_articles",
#     "Search and return information from Wikipedia.",
# )

# arxiv_retriever_tool = create_retriever_tool(
#     arxiv_retriever,
#     "retrieve_arxiv_articles",
#     "Search and return academic articles from ArXiv.",
# )

# tools = [wiki_retriever_tool, arxiv_retriever_tool]



In [None]:
from typing import Annotated, Sequence
from typing_extensions import TypedDict
from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages

class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], add_messages]

In [None]:
from langchain_core.messages import BaseMessage

class CustomDocumentMessage(BaseMessage):
    def __init__(self, documents):
        self.documents = documents
        self.content = "\n\n".join(
            f"Source {i+1}: {doc.metadata.get('source', 'Unknown')}\n{doc.page_content}"
            for i, doc in enumerate(documents)
        )
        self.metadata = {"sources": [doc.metadata.get("source", "Unknown") for doc in documents]}
    
    def __repr__(self):
        return f"CustomDocumentMessage(documents={self.documents!r})"

from langgraph.prebuilt import ToolNode

class CustomToolNode(ToolNode):
    def invoke(self, input, config=None, **kwargs):
        # Call the parent invoke method, passing config and any extra kwargs.
        result = super().invoke(input, config, **kwargs)
        # Check if the result is a list (we assume a list of document objects).
        if isinstance(result, list):
            return {"messages": [CustomDocumentMessage(result)]}
        return result

    
def create_custom_retriever_tool(retriever, name, description):
    # Define a tool function that calls the retriever.
    def tool_function(query: str):
        """
        Calls the retriever to get documents matching the query.
        
        Args:
            query (str): The query string.

        Returns:
            List: A list of document objects with `page_content` and `metadata`.
        """
        documents = retriever.retrieve(query)
        return documents
    
    custom_node = CustomToolNode([tool_function])
    custom_node.name = name
    custom_node.description = description
    return custom_node

wiki_retriever_tool = create_custom_retriever_tool(
    wiki_retriever,
    "retrieve_wikipedia_articles",
    "Search and return information from Wikipedia."
)

arxiv_retriever_tool = create_custom_retriever_tool(
    arxiv_retriever,
    "retrieve_arxiv_articles",
    "Search and return academic articles from ArXiv."
)

tools = [wiki_retriever_tool, arxiv_retriever_tool]


In [None]:
from typing import Annotated, Literal, Sequence
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from pydantic import BaseModel
from langgraph.prebuilt import tools_condition

### Edges


def grade_documents(state) -> Literal["generate", "rewrite"]:
    """
    Determines whether the retrieved documents are relevant to the question.

    Args:
        state (messages): The current state

    Returns:
        str: A decision for whether the documents are relevant or not
    """

    print("---CHECK RELEVANCE---")

    class grade(BaseModel):
        """Binary score for relevance check."""
        binary_score: Literal["yes", "no"]

    model = ChatOllama(temperature=0, model="llama3.3", streaming=True)
    llm_with_tool = model.with_structured_output(grade, method="json_schema")

    prompt = PromptTemplate(
        template="""You are a grader assessing relevance of a retrieved document to a user question. \n 
        Here is the retrieved document: \n\n {context} \n\n
        Here is the user question: {question} \n
        If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
        Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.""",
        input_variables=["context", "question"],
    )

    chain = prompt | llm_with_tool

    messages = state["messages"]
    last_message = messages[-1]
    
    question = messages[0].content
    docs = last_message.content

    scored_result = chain.invoke({"question": question, "context": docs})
    score = scored_result.binary_score

    if score == "yes":
        print("---DECISION: DOCS RELEVANT---")
        return "generate"

    else:
        print("---DECISION: DOCS NOT RELEVANT---")
        print(score)
        return "rewrite"


### Nodes


def agent(state):
    """
    Invokes the agent model to generate a response based on the current state. Given
    the question, it will decide to retrieve using the retriever tool, or simply end.

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with the agent response appended to messages
    """
    print("---CALL AGENT---")
    messages = state["messages"]
    model = ChatOllama(temperature=0, model="llama3.3", streaming=True)
    model = model.bind_tools(tools)
    response = model.invoke(messages)
    print(response)
    return {"messages": [response]}


def rewrite(state):
    """
    Transform the query to produce a better question.

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with re-phrased question
    """

    print("---TRANSFORM QUERY---")
    messages = state["messages"]
    question = messages[0].content

    msg = [
        HumanMessage(
            content=f""" \n 
    Look at the input and try to reason about the underlying semantic intent / meaning. \n 
    Here is the initial question:
    \n ------- \n
    {question} 
    \n ------- \n
    Formulate an improved question: """,
        )
    ]

    # Grader
    model = ChatOllama(temperature=0, num_ctx=32768, model="llama3.3", streaming=True)
    response = model.invoke(msg)
    return {"messages": [response]}


def generate(state):
    print("---GENERATE---")
    messages = state["messages"]
    question = messages[0].content
    last_message = messages[-1]

    # Since last_message is our CustomDocumentMessage, access the underlying documents:
    docs = last_message.documents  # Now docs is a list of document objects.

    # Format docs as desired (using the metadata stored in each document)
    def format_docs(docs):
        return "\n\n".join(
            f"Source {i+1}: {doc.metadata.get('source', '')}\n{doc.page_content}"
            for i, doc in enumerate(docs)
        )

    formatted_docs = format_docs(docs)
    prompt = PromptTemplate.from_template("""
    Answer the question based only on these documents:
    {docs}

    Question: {question}
    """)
    llm = ChatOllama(temperature=0, model="llama3.3", streaming=True)
    rag_chain = prompt | llm | StrOutputParser()
    response = rag_chain.invoke({"docs": formatted_docs, "question": question})
    return {"messages": [response]}


In [13]:
from langchain_core.tools.base import BaseTool  # or import from the appropriate module

class MultiToolNode:
    def __init__(self, tool_list, name="multi_tool_node", description=""):
        self.tool_list = tool_list
        self.name = name
        self.description = description

    def invoke(self, input, config=None, **kwargs):
        # Example: call each tool and aggregate their outputs.
        aggregated_results = []
        for tool in self.tool_list:
            result = tool.invoke(input, config, **kwargs)
            # Expect that each result is a dict with a "messages" key.
            if isinstance(result, dict) and "messages" in result:
                aggregated_results.extend(result["messages"])
            else:
                aggregated_results.append(result)
        # Wrap the aggregated results in your custom message if needed.
        return {"messages": [CustomDocumentMessage(aggregated_results)]}

    # Make the node callable so that the workflow can execute it.
    def __call__(self, input, config=None, **kwargs):
        return self.invoke(input, config, **kwargs)



In [14]:
from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import ToolNode

workflow = StateGraph(AgentState)

workflow.add_node("agent", agent)

retrieve_node = MultiToolNode(tools, name="retrieve", description="Aggregates retrieval tools.")

workflow.add_node("retrieve", retrieve_node)
# retrieve = CustomToolNode(tools)
# workflow.add_node("retrieve", retrieve) 
workflow.add_node("rewrite", rewrite) 
workflow.add_node(
    "generate", generate
)
workflow.add_edge(START, "agent")

workflow.add_conditional_edges(
    "agent",
    tools_condition,
    {
        "tools": "retrieve",
        END: END,
    },
)

workflow.add_conditional_edges(
    "retrieve",
    grade_documents,
)
workflow.add_edge("generate", END)
workflow.add_edge("rewrite", "agent")

# Compile
graph = workflow.compile()

In [None]:
from IPython.display import Image, display

try:
    display(Image(graph.get_graph(xray=True).draw_mermaid_png()))
except Exception:
    pass

In [16]:
import pprint

inputs = {
    "messages": [
        ("user", "Who is the sitting president of the United States?"),
    ]
}
for output in graph.stream(inputs):
    for key, value in output.items():
        pprint.pprint(f"Output from node '{key}':")
        pprint.pprint("---")
        pprint.pprint(value, indent=2, width=80, depth=None)
    pprint.pprint("\n---\n")

---CALL AGENT---


ValueError: Unsupported function

retrieve_wikipedia_articles(tags=None, recurse=True, explode_args=False, func_accepts_config=True, func_accepts={'store': ('__pregel_store', None)}, tools_by_name={'tool_function': StructuredTool(name='tool_function', description='Calls the retriever to get documents matching the query.\n\nArgs:\n    query (str): The query string.\n\nReturns:\n    List: A list of document objects with `page_content` and `metadata`.', args_schema=<class 'langchain_core.utils.pydantic.tool_function'>, func=<function create_custom_retriever_tool.<locals>.tool_function at 0x74e86cb70f40>)}, tool_to_state_args={'tool_function': {}}, tool_to_store_arg={'tool_function': None}, handle_tool_errors=True, messages_key='messages', description='Search and return information from Wikipedia.')

Functions must be passed in as Dict, pydantic.BaseModel, or Callable. If they're a dict they must either be in OpenAI function format or valid JSON schema with top-level 'title' and 'description' keys.