# In [None]:
import functools
import os
from typing import Sequence

from dotenv import load_dotenv
from IPython.display import Image, display
from ibm_watsonx_ai.foundation_models.utils.enums import DecodingMethods
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import BaseMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_ibm import WatsonxLLM
from langgraph.graph import END, StateGraph
from typing_extensions import TypedDict

# Copy KEY=VALUE pairs from a local .env file into os.environ. Variables that
# are already set in the process environment take precedence
# (override=False is python-dotenv's default, spelled out here).
load_dotenv(override=False)

# Define the object passed between nodes
class AgentState(TypedDict):
    """State that LangGraph threads through the Router, Researcher and Creator nodes.

    Each node returns a partial dict. No reducers are declared, so every key a
    node returns replaces the previous value of that key.
    """

    # On entry: the user's HumanMessage. After each node: the AIMessage(s)
    # that node returned.
    messages: Sequence[BaseMessage]
    # Name of the node that produced the latest update.
    sender: str
    # The original user query. The Router stores it for the downstream agents.
    user_query: str
    # Routing decision ("Researcher" or "Creator"). The Router writes it and the
    # conditional edge reads it. It must be declared here: LangGraph only keeps
    # keys that are part of the state schema, so an undeclared key would never
    # reach extract_selected_agent.
    selected_agent: str

# Tool Initialization for the Researcher
tavily_tool = TavilySearchResults(max_results=5)  # Tool to fetch search results from the internet

# WatsonxLLM generation parameters for the Creator (free-form answers):
# sampled decoding with a 50-token floor for open-ended explanations.
parameters = {
    GenParams.DECODING_METHOD: DecodingMethods.SAMPLE.value,
    GenParams.MAX_NEW_TOKENS: 1000,
    GenParams.MIN_NEW_TOKENS: 50,
    GenParams.TEMPERATURE: 0.7,
    GenParams.TOP_K: 50,
    GenParams.TOP_P: 1,
}

# Router parameters. The router is asked for a single word, so decode greedily
# (the same query always routes the same way) and keep the reply short.
# Reusing `parameters` would force at least 50 sampled tokens. The reply would
# then ramble and could mention both agent names.
router_parameters = {
    GenParams.DECODING_METHOD: DecodingMethods.GREEDY.value,
    GenParams.MAX_NEW_TOKENS: 10,
    GenParams.MIN_NEW_TOKENS: 1,
}

# Load API key and project ID from environment variables
watsonx_api_key = os.getenv("WATSONX_API_KEY")
project_id = os.getenv("PROJECT_ID")
# Region endpoint. Set WATSONX_URL to use a region other than us-south.
url = os.getenv("WATSONX_URL", "https://us-south.ml.cloud.ibm.com")

if not watsonx_api_key:
    raise ValueError("Please set the WATSONX_API_KEY in your .env file.")
if not project_id:
    raise ValueError("Please set the PROJECT_ID in your .env file.")


def _make_watsonx_llm(params, model_id="ibm/granite-13b-instruct-v2"):
    """Build a WatsonxLLM bound to the endpoint and credentials loaded above.

    Args:
        params: watsonx.ai text-generation parameters (GenParams keys).
        model_id: Foundation model to use. Change it to choose your desired model.

    Returns:
        A configured ``WatsonxLLM`` instance.
    """
    return WatsonxLLM(
        model_id=model_id,
        url=url,
        apikey=watsonx_api_key,
        project_id=project_id,
        params=params,
    )


# Define LLM models using WatsonxLLM: one for answering, one for routing.
llm_creator = _make_watsonx_llm(parameters)
llm_router = _make_watsonx_llm(router_parameters)

# Router Agent Node (modified)
def router_agent_node(state, name):
    """Router agent that selects either 'Researcher' or 'Creator' based on the message."""
    query = state["messages"][-1].content

    # Router prompt that instructs the LLM to choose the appropriate agent
    router_prompt = (
        "You are a routing agent. Your task is to select one of two agents based on the user input.\n"
        "If the user query is about recent information or data like prices, select 'Researcher'.\n"
        "If the user query is more general or requires a knowledge-based response, select 'Creator'.\n\n"
        f"User Query: {query}\n\n"
        "Answer with only one word: 'Researcher' or 'Creator'."
    )

    # Generate the response using the LLM
    selected_agent = llm_router.predict(router_prompt).strip().lower()

    # Normalize the output to either 'Researcher' or 'Creator'
    if 'researcher' in selected_agent:
        selected_agent = 'Researcher'
    elif 'creator' in selected_agent:
        selected_agent = 'Creator'
    else:
        raise ValueError(f"Unexpected agent response: {selected_agent}")

    return {
        "messages": [AIMessage(content=f"Router Agent: Selected {selected_agent}", name=name)],
        "sender": name,
        "selected_agent": selected_agent,  # Include the selected agent in the state
        "user_query": query,  # Store the user query in the state
    }

# Researcher Agent Node
def researcher_agent_node(state, name):
    """Answer the routed query with raw web-search results.

    Runs ``state["user_query"]`` (stored by the Router) through ``tavily_tool``
    and wraps the results in a single AIMessage tagged with ``name``.

    Returns:
        A partial state update with ``messages`` and ``sender``.
    """
    user_query = state["user_query"]
    hits = tavily_tool.invoke(user_query)
    summary = "Researcher Agent: Fetched search results for '{}': {}".format(user_query, hits)
    reply = AIMessage(content=summary, name=name)
    return {"messages": [reply], "sender": name}

# Creator Agent Node (LLM)
def creator_agent_node(state, name):
    """Creator agent that uses a language model to generate responses.

    Args:
        state: Current graph state. ``state["user_query"]`` must have been set
            by the Router.
        name: Node name, recorded as ``sender`` and on the emitted message.

    Returns:
        A partial state update with the LLM's answer in ``messages`` and ``sender``.
    """
    query = state["user_query"]
    # Include the user query in the prompt for the Creator
    creator_prompt = f"Please answer the following query: {query}"
    # `invoke` replaces the deprecated `predict`; for an LLM it returns the generated text.
    result = llm_creator.invoke(creator_prompt)
    return {
        "messages": [AIMessage(content=f"Creator Agent: {result}", name=name)],
        "sender": name,
    }

# Bind each agent function to the node name it reports as its sender.
research_node = functools.partial(researcher_agent_node, name="Researcher")
creator_node = functools.partial(creator_agent_node, name="Creator")

# Build the graph over AgentState and register the three agent nodes.
# The Router is registered first, matching the original order.
workflow = StateGraph(AgentState)
for _node_name, _node_fn in (
    ("Router", functools.partial(router_agent_node, name="Router")),
    ("Researcher", research_node),
    ("Creator", creator_node),
):
    workflow.add_node(_node_name, _node_fn)
del _node_name, _node_fn  # loop temporaries; keep the module namespace clean

# Routing logic based on Router Agent's decision
def extract_selected_agent(state) -> str:
    """Return the Router's decision so the conditional edge can pick the next node.

    Reads ``state["selected_agent"]`` (``"Researcher"`` or ``"Creator"``).
    A missing key raises ``KeyError``.
    """
    decision = state["selected_agent"]
    return decision

# Route out of the Router to whichever agent it selected.
workflow.add_conditional_edges(
    "Router",
    extract_selected_agent,
    {"Researcher": "Researcher", "Creator": "Creator"},
)

# Researcher and Creator are terminal: the run ends after either one.
# A plain edge to END states this directly. The previous conditional edges
# always returned "__end__", which did the same thing but hid the intent.
workflow.add_edge("Researcher", END)
workflow.add_edge("Creator", END)

# Set entry point
workflow.set_entry_point("Router")

# Compile the workflow for execution
graph = workflow.compile()

# Function to display the architecture
def display_architecture():
    """Render the compiled graph as a Mermaid PNG in the notebook output.

    This is best-effort: rendering may need optional dependencies. If it fails,
    a message and the underlying error are printed instead of raising.
    """
    try:
        png_bytes = graph.get_graph(xray=True).draw_mermaid_png()
        display(Image(png_bytes))
    except Exception as exc:  # a bare `except:` would also trap KeyboardInterrupt/SystemExit
        print("Unable to display graph architecture. Extra dependencies might be missing.")
        print(f"  ({type(exc).__name__}: {exc})")

# Function to evaluate a message and run the system
def evaluate_message(message: str, recursion_limit: int = 50):
    """Run ``message`` through the compiled graph, printing each streamed step.

    Args:
        message: The user's query. It becomes the initial HumanMessage in ``messages``.
        recursion_limit: Value passed as LangGraph's ``recursion_limit`` config.
            The default of 50 matches the previously hard-coded value.

    Errors raised while building the input or streaming are printed, not
    propagated, so one failing example does not stop the rest of the cell.
    """
    try:
        initial_state = {"messages": [HumanMessage(content=message)]}
        # graph.stream is lazy; errors surface during iteration, inside this try.
        events = graph.stream(initial_state, config={"recursion_limit": recursion_limit})
        for step in events:
            print(step)
            print("----")
        print("Final")
    except Exception as e:
        print(f"Error during evaluation: {e}")

# Run the demo only when executed directly (a notebook cell or `python file.py`),
# so importing this module does not trigger paid LLM and search calls.
if __name__ == "__main__":
    # Display the architecture
    display_architecture()

    # Example calls to test both routes
    evaluate_message("Fetch the bitcoin price over the past 5 days.")  # Researcher case
    evaluate_message("Explain what Bitcoin is.")  # Creator case
