In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path
from pprint import pprint
from operator import add
import typing as t
import textwrap
from pydantic import BaseModel, Field

from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.prebuilt import ToolNode

sys.path.append(str(Path().cwd().parent))

from llm_experiments.models import instantiate_chat
from llm_experiments.tools import BrowserManager

In [3]:
class State(MessagesState):
    """Shared state for the planner / browser-handler / tool-node graph.

    ``messages`` is inherited from LangGraph's ``MessagesState``: its
    ``add_messages`` reducer appends each node's update to the history and
    coerces plain strings and ``(role, content)`` tuples into chat messages.
    """

    # TypedDict bodies may only contain annotations (PEP 589); pydantic ``Field``
    # defaults are meaningless here, so the fields are documented with comments.

    # Completion flag written by the planner; "true" ends the run, "false" continues.
    # Kept as a string literal to mirror the planner's structured-output field.
    is_done: t.Literal["true", "false"]
    # Free-text description of the planner's chosen next step.
    next_action: str

In [4]:
class Plan(BaseModel):
    # Structured output requested from the planner LLM. The field descriptions are
    # emitted into the model's JSON schema, so their wording is part of the prompt.
    # (No class docstring on purpose: it would also be added to that schema.)

    # Text addressed to the user.
    message: t.Annotated[str, Field(description="The message to send to the user")]
    # What the agent should do next, in free text.
    next_action: t.Annotated[str, Field(description="The next action to take")]
    # String-valued completion flag ("true" / "false").
    is_done: t.Annotated[t.Literal["true", "false"], Field(description="Whether the task is done")]

In [None]:
async with BrowserManager() as b:
    model = instantiate_chat("4o-mini")
    model_w_tools = model.bind_tools(b.tools)

    tool_node = ToolNode(b.tools)


    def should_continue(state: State):
        if state["is_done"] == "true":
            return END
        if state["is_done"] == "false":
            return "browser_hanlder"


    def browser_hanlder(state: State):
        latest_message = state["messages"][-1]
        prompt = textwrap.dedent(f"""
            You are an AI browser automation agent responsible for executing web actions. 
            Your task is to analyze the latest message and determine how to interact with the browser.

            ### Context:
            - The planner has determined that a browser action is required.
            - You must analyze the tool calls provided and execute them accordingly.
            - If an action fails (e.g., page not found, age verification required), provide an alternative approach.

            ### Given Information:
            - **Latest instruction from planner:** 
            {latest_message}

            ### Your Task:
            - **Execute browser actions** based on the provided tool calls.
            - **Handle potential errors**, such as:
            - Page not found (404)
            - Age verification prompts
            - Navigation failures
            - If navigation succeeds, proceed to extract relevant information.
            - If an error occurs, suggest a recovery step or inform the user.
        """)
        print("\n\n", prompt, "\n\n")
        return {"messages": [model_w_tools.invoke(prompt)]}
    


    async def planner(state: State):
        prompt = textwrap.dedent(f"""
            You are an AI planner guiding a browser automation agent. 
            Your task is to analyze the conversation history and determine the next action.

            ### Context:
            - The agent interacts with a web browser and performs actions based on user messages.
            - You need to decide the best course of action for the agent.
            - Each message may contain a user query, an instruction, or feedback from previous steps.

            ### Given Information:
            - **Full conversation history:** 
            {state["messages"]}
            
            - **Latest user message:** 
            {state["messages"][-1]}

            ### Your Task:
            - Carefully analyze the intent of the latest message.
            - Decide on the next action:
            - **Use browser tools**: If the message requires web browsing, interacting with a website, or retrieving online content.
            - **Generate a response**: If the message requires answering without web interaction.
            - **End the session**: If the conversation is complete.
            - Clearly specify the next action and include a short explanation of why it was chosen.
        """)

        res = await model_w_tools.with_structured_output(Plan).ainvoke(prompt)
        return {"message": res.message, "next_action": res.next_action, "is_done": res.is_done}


    workflow = StateGraph(State)

    workflow.add_node("planner", planner)
    workflow.add_node("browser_hanlder", browser_hanlder)
    workflow.add_node("tool_node", tool_node)

    workflow.add_edge(START, "planner")
    workflow.add_conditional_edges("planner", should_continue, ["browser_hanlder", END])
    workflow.add_edge("browser_hanlder", "tool_node")
    workflow.add_edge("tool_node", "planner")

    app = workflow.compile()


    async for chunk in app.astream({"messages": ["find the most popular video on youtube"]}):
        print("agent: ", chunk)