# Notebook App for an Agentic Research Assistant Using Tavily and LangGraph

## Requirements

In [None]:
!pip install langchain_cohere langchain-core langgraph langchain_core python-dotenv

In [None]:
!pip install --upgrade tavily-python

## Libraries

In [86]:
import os
import json
import asyncio
import operator
from typing import TypedDict, List, Annotated, Literal, Dict, Union, Optional 
from datetime import datetime

from tavily import AsyncTavilyClient, TavilyClient

from langchain_core.tools import tool
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_cohere.chat_models import ChatCohere
from langgraph.graph import StateGraph, START, END, add_messages

## Set API Keys

In [87]:
# Set your API keys here (placeholders below).
# NOTE(review): in the cells visible in this notebook these three constants are
# not passed to any client -- AsyncTavilyClient() and ChatOpenAI(...) are built
# without an api_key argument -- so presumably the keys are expected to come
# from environment variables / the .env file instead. Confirm before relying
# on these assignments.
TAVILY_API_KEY = "YOUR TAIVLY API KEY"
COHERE_API_KEY = "YOUR COHERE API KEY"
OPENAI_API_KEY =  "YOUR OPEN API KEY"

# Or load the keys from a .env file in the working directory
from dotenv import load_dotenv
load_dotenv('.env')

True

## Using the OpenAI Model

In [88]:
class ResearchState(TypedDict):
    # State schema handed to StateGraph below; every node receives it and
    # returns a partial update with the same keys.
    #
    # documents: outer dict with string keys; tool_node stores each retrieved
    #   document under its 'url'. Each inner dict may have str/int keys and
    #   str/float values. No reducer is attached to this field (unlike
    #   `messages`).
    documents: Dict[str, Dict[Union[str, int], Union[str, float]]]
    # messages: the conversation so far; annotated with the `add_messages`
    #   reducer so that messages returned by a node are added to the list.
    messages: Annotated[list[AnyMessage], add_messages]

# One citation backing the final answer: the source url plus a verbatim quote.
# Used as the element type of QuotedAnswer.citations.
# NOTE: documented with `#` comments on purpose -- the field descriptions are
# part of the schema, so the strings below are left exactly as they are.
class Citation(BaseModel):
    # Url of the source that supports the answer (required field).
    source_id: str = Field(
        ...,
        description="The url of a SPECIFIC source which justifies the answer.",
    )
    # Exact text taken from that source (required field).
    quote: str = Field(
        ...,
        description="The VERBATIM quote from the specified source that justifies the answer.",
    )


# Structured-output schema requested from the model in cite_answer:
# an answer string plus the list of citations that justify it.
class QuotedAnswer(BaseModel):
    """Answer the user question based only on the given sources, and cite the sources used."""
    # Final answer text; the description asks for sources as markdown hyperlinks.
    answer: str = Field(
        ...,
        description="The answer to the user question, which is based only on the given sources. Include any relevant sources in the answer as markdown hyperlinks. For example: 'This is a sample text ([url website](url))'"
    )
    # Supporting citations (required field). cite_answer currently reads only
    # `answer` from the response.
    citations: List[Citation] = Field(
        ..., description="Citations from the given sources that justify the answer."
    )

# Schema for a single sub-query; TavilySearchInput.sub_queries is a list of these.
# The description strings are part of the schema, so they are left untouched.
class TavilyQuery(BaseModel):
    # Search text for this sub-query.
    query: str = Field(description="sub query")
    # Search type; the description restricts it to 'general' or 'news' but the
    # type itself is a plain str (no validation here).
    topic: str = Field(description="type of search, should be 'general' or 'news'")
    # Look-back window in days for 'news' searches.
    days: int = Field(description="number of days back to run 'news' search")
    # Whether raw page content should be requested in addition to the summary.
    raw_content: bool = Field(description="include raw content from found sources, use it ONLY if you need more deatiled information besides the summary content provided")
    # Optional list of domains to restrict the research to; defaults to None.
    domains: Optional[List[str]] = Field(default=None, description="list of domains to include in the research. Useful when trying to gather more detailed information.")
    
# Argument schema (args_schema) of the tavily_search tool: the model supplies a
# list of TavilyQuery sub-queries in a single tool call.
class TavilySearchInput(BaseModel):
    sub_queries: List[TavilyQuery] = Field(description="set of sub-queries that can be answered in isolation")


@tool("tavily_search", args_schema=TavilySearchInput, return_direct=True)
async def tavily_search(sub_queries: List[TavilyQuery]):
    """Perform searches for each sub-query using the Tavily search tool concurrently."""
    # NOTE: the docstring above is what the @tool decorator exposes as the tool
    # description, so it is kept short and unchanged; implementation notes live
    # in comments instead.
    #
    # Args:
    #   sub_queries: TavilyQuery items (query, topic, days, raw_content, domains).
    # Returns:
    #   A flat list with the 'results' entries of every successful search, in
    #   sub-query order. A failed search contributes nothing (best effort).

    # Month and year, e.g. "08-2024", appended to each query to bias towards recent results.
    date_string = datetime.now().strftime('%m-%Y')

    async def perform_search(itm):
        """Run one Tavily search; on any error, report it and return an empty list."""
        search_kwargs = {
            "query": itm.query + ' ' + date_string,
            "topic": itm.topic,
            "days": itm.days,
            "include_raw_content": itm.raw_content,
            "max_results": 10,
        }
        # Honour the optional domain restriction declared on TavilyQuery.
        # Only sent when the model actually provided domains, so searches
        # without domains behave exactly as before.
        if itm.domains:
            search_kwargs["include_domains"] = itm.domains
        try:
            response = await tavily_client.search(**search_kwargs)
            return response['results']
        except Exception as e:
            # Deliberately broad: one failing sub-query must not abort the others.
            print(f"Error occurred during search for query '{itm.query}': {str(e)}")
            return []

    # Run all sub-queries concurrently; gather keeps the input order.
    search_responses = await asyncio.gather(
        *(perform_search(itm) for itm in sub_queries)
    )

    # Flatten the per-query result lists into one list.
    return [result for response in search_responses for result in response]


# Tools the chat model is allowed to call.
tools = [tavily_search]
# Async Tavily client used inside tavily_search (constructed without arguments).
tavily_client = AsyncTavilyClient()
# Chat model with the tools bound to it.
model = ChatOpenAI(model="gpt-4o-mini", temperature=0).bind_tools(tools)

# Lookup table used by tool_node: tool name -> tool object.
tools_by_name = {registered.name: registered for registered in tools}
async def tool_node(state: ResearchState):
    """Execute the tool calls of the last message and collect newly found documents.

    For every tool call on ``state["messages"][-1]`` the matching tool from
    ``tools_by_name`` is awaited. Each returned document whose ``'url'`` has
    not been seen yet is stored under that url, and a ToolMessage listing the
    documents that this particular call added is produced.

    Returns:
        A state update with the new ToolMessages under ``"messages"`` and the
        full url -> document mapping under ``"documents"``.
    """
    # `documents` may be absent (the graph is started with only `messages`)
    # or None; work on a copy so the incoming state is not mutated in place.
    docs = dict(state.get('documents') or {})
    msgs = []
    for tool_call in state["messages"][-1].tool_calls:
        selected_tool = tools_by_name[tool_call["name"]]
        new_docs = await selected_tool.ainvoke(tool_call["args"])
        # Serialized documents added by THIS tool call only, so a later
        # ToolMessage does not repeat documents reported by an earlier one.
        added = []
        for doc in new_docs:
            # Skip documents that were already retrieved (same url).
            if doc['url'] not in docs:
                docs[doc['url']] = doc
                added.append(json.dumps(doc))
        docs_str = "".join(added)
        msgs.append(ToolMessage(content=f"Found the following new documents: {docs_str}", tool_call_id=tool_call["id"]))
    return {"messages": msgs, "documents": docs}
    
        
def call_model(state: ResearchState):
    """Send the conversation so far to the tool-enabled model and return its reply."""
    reply = model.invoke(state['messages'])
    print(reply)
    # Wrapped in a list: the reply is added to the messages already in the state.
    return {"messages": [reply]}

def cite_answer(state: ResearchState):
    """Produce the final answer from the conversation using the QuotedAnswer schema.

    The model is asked for a structured ``QuotedAnswer``; only its ``answer``
    text is put back into the state, wrapped in an AIMessage.
    """
    # AIMessage is not part of the notebook's import cell, so import it here;
    # without this the node fails with NameError when it builds its result.
    from langchain_core.messages import AIMessage

    messages = state['messages']
    response = model.with_structured_output(QuotedAnswer).invoke(input=messages)
    # We return a list, because this will get added to the existing list
    return {"messages": [AIMessage(content=response.answer)]}
    
# Routing function: picks the node that runs after the model has replied
def should_continue(state: ResearchState) -> Literal["tools", "citation_model"]:
    """Return "tools" when the last message carries tool calls, else "citation_model"."""
    latest = state['messages'][-1]
    # Tool calls pending -> run them; otherwise answer the user with citations.
    return "tools" if latest.tool_calls else "citation_model"

# Build the graph over ResearchState:
#   route_query -> (tools -> route_query)* -> citation_model -> END
workflow = StateGraph(ResearchState)

# Register the three nodes: the tool-enabled model, the tool executor and the
# structured-output answer step
workflow.add_node("route_query", call_model)
workflow.add_node("tools", tool_node)
workflow.add_node("citation_model", cite_answer)
# Set the entrypoint as route_query
workflow.set_entry_point("route_query")

# After route_query, should_continue decides which node runs next
workflow.add_conditional_edges(
    "route_query",
    # No explicit path mapping is given: should_continue returns the name of
    # the next node ("tools" or "citation_model") directly.
    should_continue,
)

# Add a normal edge from `tools` to `route_query`.
# This means that after `tools` is called, `route_query` node is called next.
workflow.add_edge("tools", "route_query")
# Future option: add another step here that filters the retrieved documents with a reranker
workflow.add_edge("citation_model", END)

# Compile the graph into the runnable `app` used by the cell below
app = workflow.compile()

In [85]:
# Initial conversation: a single user request.
messages = [
    HumanMessage(
        content="Important recent information about the company Stripe"
    )
]
# Alternative prompts to try:
# content="Important recent information about the company Stripe"
# content="Wild fire prevention startups, divided by the type of technology"

# Stream the run and show the latest message of every streamed state.
# NOTE: a top-level `async for` only works where top-level await is available
# (e.g. a notebook cell); in a plain script this loop must live inside an
# async function.
# NOTE(review): only `messages` is supplied in the initial state; `documents`
# is not set here -- confirm tool_node copes with that on the installed
# langgraph version.
async for s in app.astream({"messages": messages}, stream_mode="values"):
    message = s["messages"][-1]
    # Tuples are printed as-is; message objects use their own pretty printer.
    if isinstance(message, tuple):
        print(message)
    else:
        message.pretty_print()


Important recent information about the company Stripe
content='' additional_kwargs={'tool_calls': [{'id': 'call_eYFdkNULmymySw7guB6hDmAa', 'function': {'arguments': '{"sub_queries":[{"query":"recent news about Stripe","topic":"news","days":30,"raw_content":false},{"query":"Stripe company updates","topic":"general","days":30,"raw_content":false},{"query":"Stripe financial performance","topic":"general","days":30,"raw_content":false}]}', 'name': 'tavily_search'}, 'type': 'function'}]} response_metadata={'token_usage': {'completion_tokens': 75, 'prompt_tokens': 166, 'total_tokens': 241}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_48196bc67a', 'finish_reason': 'tool_calls', 'logprobs': None} id='run-a7ccf93d-d9f5-4b70-8f69-facf1bc13a09-0' tool_calls=[{'name': 'tavily_search', 'args': {'sub_queries': [{'query': 'recent news about Stripe', 'topic': 'news', 'days': 30, 'raw_content': False}, {'query': 'Stripe company updates', 'topic': 'general', 'days': 30, 'raw_conte

## Using Cohere Model

In [8]:
# from langchain_core.tools import tool
# from langgraph.prebuilt import ToolNode
# from langchain_core.pydantic_v1 import BaseModel, Field
# import operator
# from langchain_openai import ChatOpenAI
# from typing import Dict, Union
# import json


# class ResearchState(TypedDict):
#     user_query: str
#     critique: str
#     answer: str
#     documents: Annotated[list[dict], operator.add]
#     #documents: List[dict]
#     #documents: list[dict]
#     web_queries: List[str]
#     revision_number: int
#     max_revisions: int
#     messages: Annotated[list[AnyMessage], add_messages]

# class Citation(BaseModel):
#     source_id: int = Field(
#         ...,
#         description="The integer ID of a SPECIFIC source which justifies the answer.",
#     )
#     quote: str = Field(
#         ...,
#         description="The VERBATIM quote from the specified source that justifies the answer.",
#     )


# class QuotedAnswer(BaseModel):
#     """Answer the user question based only on the given sources, and cite the sources used."""

#     answer: str = Field(
#         ...,
#         description="The answer to the user question, which is based only on the given sources.",
#     )
#     citations: List[Citation] = Field(
#         ..., description="Citations from the given sources that justify the answer."
#     )


# # @tool("tavily_search",args_schema=SearchInput, return_direct=True)
# # async def tavily_search(query: str, topic: str):
# #     """Perform web search using the Tavily search tool."""
# #     return await tavily_client.search(query=query, topic=topic)

# # Define args_schema for tavily search
# class TavilySearchInput(BaseModel):
#     sub_queries: List[str] = Field(description="break down the user's input into a set of sub-queries / sub-problems that can be answered in isolation")
#     topic: str = Field(description="type of search, should be 'general' or 'news'")
#     days: int = Field(description="number of days back to run 'news' search")

# @tool("tavily_search",args_schema=TavilySearchInput, return_direct=True)
# async def tavily_search(sub_queries: List[str], topic: str, days: int):
#     """Perform searches for each sub-query using the Tavily search tool."""
#     search_results = []
#     for sub_query in sub_queries:
#         response = await tavily_client.search(query=sub_query, topic=topic,include_raw_content=False)
#         for r in response['results']:
#             r.pop('raw_content', None)
#             r['score'] = str(r['score']) # Converting to string for cohere
#             search_results.append(r)
#         # print(results)
#         #search_results.extend(response['results'])
#     # print("search_results",search_results)
#     return search_results


# tools = [tavily_search]
# # tool_node = ToolNode(tools)

# tavily_client = AsyncTavilyClient(api_key=TAVILY_API_KEY)
# model = ChatOpenAI(model="gpt-4o-mini",temperature=0).bind_tools(tools)
# # model_with_tools = ChatCohere(model="command-r-plus", temperature=0).bind_tools(tools)

# tools_by_name = {tool.name: tool for tool in tools}
# async def tool_node(state: ResearchState):
#     docs = []
#     msgs = []
#     for tool_call in state["messages"][-1].tool_calls:
#         tool = tools_by_name[tool_call["name"]]
#         # print(tool)
#         observation = await tool.ainvoke(tool_call["args"])
#         # print(observation)
#         docs.extend(observation)
#         msgs.append(ToolMessage(content=f"Added documents: {observation}", tool_call_id=tool_call["id"]))
#     # print("inside tool:",docs)
#     return {"messages": msgs, "documents": docs}
    
        
# def call_model(state: ResearchState):
#     messages = state['messages']
#     # print("state['messages']:",state['messages'])
#     print("state['documents']:",state['documents'])
#     response = model.invoke(messages)
#     print(response)
#     # We return a list, because this will get added to the existing list
#     return {"messages": [response]}

# def call_model_with_docs(state: ResearchState):
#     messages = state['messages']
#     print("state['messages']:",state['messages'])
#     # print("state['documents']:",state['documents'])
#     response = model.with_structured_output(QuotedAnswer).invoke(input=messages)
#     print("response with docs:\n",response)
#     # We return a list, because this will get added to the existing list
#     return {"messages": [response]}
# #COHERE
# # def call_model_with_docs(state: ResearchState):
# #     messages = state['messages']
# #     # print("state['messages']:",state['messages'])
# #     print("state['documents']:",state['documents'])
# #     response = model.with_structured_output(QuotedAnswer).invoke(input=messages, 
# #                             preamble="""You are an expert write a coherent and deatiled response based on the user's question with the most relevant datasources.""",
# #                             documents=state['documents'])
# #     # We return a list, because this will get added to the existing list
# #     return {"messages": [response]}
    
# # Define the function that determines whether to continue or not
# def should_continue(state: ResearchState) -> Literal["tools", "RAG model"]:
#     messages = state['messages']
#     last_message = messages[-1]
#     # If the LLM makes a tool call, then we route to the "tools" node
#     if last_message.tool_calls:
#         return "tools"
#     # Otherwise, we stop (reply to the user)
#     return "RAG model"

# # Define a graph
# workflow = StateGraph(ResearchState)

# # Add nodes
# workflow.add_node("route_query", call_model)
# workflow.add_node("tools", tool_node)
# workflow.add_node("RAG model", call_model_with_docs)
# # Set the entrypoint as route_query
# workflow.set_entry_point("route_query")

# # Determine which node is called next
# workflow.add_conditional_edges(
#     "route_query",
#     # Next, we pass in the function that will determine which node is called next.
#     should_continue,
# )

# # Add a normal edge from `tools` to `route_query`.
# # This means that after `tools` is called, `route_query` node is called next.
# workflow.add_edge("tools", "route_query")
# workflow.add_edge("RAG model", END)

# app = workflow.compile()

In [43]:
        # tool descriptions that the model has access to
        # tools = [
        #    {
        #        "name": "tavily_search",
        #        "description": "Connect to a general/news web search engine to gather more information on user's query",
        #        "parameter_definitions": {
        #            "type": {
        #                "description": "type of search to run, 'general', 'news' or both",
        #                "type": "str",
        #                "required": True
        #            }
        #        }
        #    }
        # ]