# Basic Multi-agent Collaboration

A single agent can usually operate effectively with a handful of tools within a single domain, but even with powerful models like `gpt-4`, it tends to become less effective when it has to choose among many tools.

One way to approach complicated tasks is "divide and conquer": create a "specialized agent" for each task or domain and route each task to the correct "expert". Each agent is then a sequence of LLM calls that decides how to use its own specific "tool".
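To make the routing idea concrete, here is a toy sketch in plain Python. The expert functions, the `EXPERTS` table, and the keyword heuristic in `route` are all hypothetical: in this notebook the "experts" are LLM-backed agents and the routing is done by Burr transitions together with the LLM's own decisions, not by keyword matching.

```python
from typing import Callable


# Hypothetical "experts": in the notebook each would be an LLM-backed agent
# with its own system message and tools; here they are plain functions.
def research_expert(task: str) -> str:
    return f"[researcher] looking up: {task}"


def chart_expert(task: str) -> str:
    return f"[chart_generator] plotting: {task}"


EXPERTS: dict[str, Callable[[str], str]] = {
    "research": research_expert,
    "chart": chart_expert,
}


def route(task: str) -> str:
    """Send a task to the expert whose domain matches it (toy keyword heuristic)."""
    lowered = task.lower()
    if any(word in lowered for word in ("plot", "chart", "graph")):
        return EXPERTS["chart"](task)
    return EXPERTS["research"](task)


print(route("Fetch the UK's GDP over the past 5 years"))
print(route("Draw a line graph of GDP"))
```

Each expert stays small and focused; the only shared piece is the routing decision.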

This notebook (inspired by the paper [AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation](https://arxiv.org/abs/2308.08155) by Wu et al.) shows one way to do this using Burr.


In [1]:
# %pip install -U burr[start] langchain-community langchain-core langchain-experimental openai sf-hamilton[visualization]

In [2]:
# Environment variables
import os
# Make sure TAVILY_API_KEY & OPENAI_API_KEY are set
# os.environ['TAVILY_API_KEY'] = 'your_tavily_api_key' # get one at https://tavily.com
# os.environ['OPENAI_API_KEY'] = 'your_openai_api_key' # get one at https://platform.openai.com

In [20]:
# import everything that you'll need
import json
import pprint

from hamilton import driver
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_experimental.utilities import PythonREPL

from burr import core
from burr.core import ApplicationBuilder, State, action, default
from burr.lifecycle import PostRunStepHook
from burr.tracking import client as burr_tclient

# our hamilton module -- see below
import func_agent

# Define the tools that the agents will use

Here we construct the Python objects that our agents will use as tools.

In [4]:
repl = PythonREPL()
tavily_tool = TavilySearchResults(max_results=5)

def python_repl(code: str) -> dict:
    """Use this to execute python code. If you want to see the output of a value,
    you should print it out with `print(...)`. This is visible to the user.

    :param code: string. The python code to execute.
    :return: the output
    """
    try:
        result = repl.run(code)
    except BaseException as e:
        return {"error": repr(e), "status": "error", "code": f"```python\n{code}\n```"}
    return {"status": "success", "code": f"```python\n{code}\n```", "Stdout": result}

# These are our tools that we will use in the application.
tools = [tavily_tool, python_repl]
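The tool list mixes a LangChain tool object and a plain function, so the code that later executes a tool has to resolve names for both. The sketch below shows that lookup as a name-to-tool registry. `FakeSearchTool`, `fake_python_repl`, `resolve_tool_name`, and `call_tool` are hypothetical stand-ins (named so they do not overwrite the real `tools` or `python_repl`); the resolution rule (`name` attribute, else `__name__`) and the `_run`-or-call choice mirror what `tool_node` does below.

```python
import json
from typing import Any


class FakeSearchTool:
    """Hypothetical stand-in for a LangChain tool: has a `name` and a `_run` method."""

    name = "search"

    def _run(self, query: str) -> list:
        return [{"url": "https://example.com", "content": f"results for {query}"}]


def fake_python_repl(code: str) -> dict:
    """Hypothetical plain-function tool; it echoes the code instead of executing it."""
    return {"status": "success", "code": code}


def resolve_tool_name(tool: Any) -> str:
    # Objects with a `name` attribute use it; plain functions fall back to `__name__`.
    return getattr(tool, "name", None) or tool.__name__


def call_tool(tools: list, name: str, json_args: str) -> Any:
    """Look up a tool by name and call it with JSON-encoded keyword arguments."""
    registry = {resolve_tool_name(t): t for t in tools}
    if name not in registry:
        raise ValueError(f"Tool {name} not found.")
    tool = registry[name]
    kwargs = json.loads(json_args)
    # Use `_run` when present (LangChain-style), otherwise call the function directly.
    return tool._run(**kwargs) if hasattr(tool, "_run") else tool(**kwargs)


demo_tools = [FakeSearchTool(), fake_python_repl]
print(call_tool(demo_tools, "search", '{"query": "UK GDP"}'))
print(call_tool(demo_tools, "fake_python_repl", '{"code": "print(1)"}'))
```

Building the registry once turns the nested loop in `tool_node` into a dictionary lookup, and a missing tool fails loudly with the same `ValueError`.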

# Define the agents
Our "agents" are each effectively a series of LLM calls.
In this example we use Hamilton to orchestrate that series of LLM calls.

For the source code and all the prompts used, see [func_agent.py](./func_agent.py). The next cell displays the structure
of this DAG.
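One output of the agent DAG is `parsed_tool_calls`, and `tool_node` below reads each entry's `id`, `function_name`, and `function_args` (a JSON string). The sketch below shows one plausible way to produce that shape from an OpenAI-style chat completion message whose `tool_calls` entries carry `function.name` and `function.arguments`. The function `parse_tool_calls` is hypothetical and is an assumption about what the `parsed_tool_calls` node does; see `func_agent.py` for the actual implementation.

```python
def parse_tool_calls(message: dict) -> list:
    """Flatten an OpenAI-style assistant message's tool calls into the shape tool_node reads."""
    parsed = []
    for call in message.get("tool_calls") or []:
        parsed.append(
            {
                "id": call["id"],
                "function_name": call["function"]["name"],
                "function_args": call["function"]["arguments"],  # still a JSON string
            }
        )
    return parsed


assistant_message = {
    "role": "assistant",
    "content": None,
    "tool_calls": [
        {
            "id": "call_1",
            "type": "function",
            "function": {"name": "python_repl", "arguments": '{"code": "print(2 + 2)"}'},
        }
    ],
}
print(parse_tool_calls(assistant_message))
```

A message without `tool_calls` yields an empty list, which is exactly the signal the transitions use to decide that no tool needs to run.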

In [5]:
# The agent that we'll use. Our agents here differ only in the system message and tools passed in.
agent_dag = driver.Builder().with_modules(func_agent).build()
agent_dag

# Define the actions that map to agents
We now create specific actions that map to the agents we need for this example.

We want a "chart generator" action that maps to an agent that can generate a chart based on the provided context/data.

We want a "researcher" action that maps to an agent that can search for information on a topic.

Finally, we want a "tool_node" action that runs the tool requested by the prior action, i.e. by the agent that just ran.
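All three actions follow the same state-update pattern: an agent appends its message, records any tool calls, and notes itself as `sender`; `tool_node` then appends the tool results and clears the pending calls. The sketch below reproduces that pattern with plain dictionaries. `agent_step` and `tool_step` are hypothetical helpers, and the non-mutating style is meant to mirror Burr's `State.update`/`State.append`, which return a new state rather than modifying the old one.

```python
def agent_step(state: dict, sender: str, message: dict, parsed_tool_calls: list) -> dict:
    """Return a new state dict, mirroring how the agent actions update Burr's State."""
    return {
        **state,
        "messages": [*state["messages"], message],  # like state.append(messages=message)
        "parsed_tool_calls": parsed_tool_calls,     # like state.update(parsed_tool_calls=...)
        "sender": sender,                            # read by the transitions out of tool_node
    }


def tool_step(state: dict, tool_messages: list) -> dict:
    """Append tool results and clear pending calls, as tool_node does."""
    return {
        **state,
        "messages": [*state["messages"], *tool_messages],
        "parsed_tool_calls": [],
    }


s0 = {"messages": [], "query": "UK GDP", "sender": "", "parsed_tool_calls": []}
call = {"id": "call_1", "function_name": "search", "function_args": '{"query": "UK GDP"}'}
s1 = agent_step(s0, "researcher", {"role": "assistant", "content": ""}, [call])
s2 = tool_step(s1, [{"tool_call_id": "call_1", "role": "tool", "name": "search", "content": "..."}])
print(s2["sender"], len(s2["messages"]), s2["parsed_tool_calls"])
```

Keeping `sender` in state is what lets a single shared `tool_node` hand control back to whichever agent asked for the tool.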

In [7]:
@action(reads=["query", "messages"], writes=["messages", "parsed_tool_calls", "sender"])
def chart_generator(state: State) -> tuple[dict, State]:
    """The chart generator action.

    :param state: state of the application
    :return: the raw agent result and the updated state
    """
    query = state["query"]
    result = agent_dag.execute(
        ["parsed_tool_calls", "llm_function_message"],
        inputs={
            "tools": [python_repl],
            "system_message": "Any charts you display will be visible by the user. When done say 'FINAL ANSWER'.",
            "user_query": query,
            "messages": state["messages"],
        },
    )
    new_message = result["llm_function_message"]
    parsed_tool_calls = result["parsed_tool_calls"]
    state = state.update(parsed_tool_calls=parsed_tool_calls)
    state = state.append(messages=new_message)
    state = state.update(sender="chart_generator")
    return result, state

@action(reads=["query", "messages"], writes=["messages", "parsed_tool_calls", "sender"])
def researcher(state: State) -> tuple[dict, State]:
    """The researcher action.

    :param state: state of the application
    :return: the raw agent result and the updated state
    """
    query = state["query"]
    result = agent_dag.execute(
        ["parsed_tool_calls", "llm_function_message"],
        inputs={
            "tools": [tavily_tool],
            "system_message": "You should provide accurate data for the chart generator to use. When done say 'FINAL ANSWER'.",
            "user_query": query,
            "messages": state["messages"],
        },
    )
    new_message = result["llm_function_message"]
    parsed_tool_calls = result["parsed_tool_calls"]
    state = state.update(parsed_tool_calls=parsed_tool_calls)
    state = state.append(messages=new_message)
    state = state.update(sender="researcher")
    return result, state


@action(reads=["messages", "parsed_tool_calls"], writes=["messages", "parsed_tool_calls"])
def tool_node(state: State) -> tuple[dict, State]:
    """Execute each tool call requested by the previous agent and append the results."""
    new_messages = []
    parsed_tool_calls = state["parsed_tool_calls"]

    for tool_call in parsed_tool_calls:
        tool_name = tool_call["function_name"]
        tool_args = tool_call["function_args"]
        tool_found = False
        for tool in tools:
            # LangChain tools expose `.name`; plain functions fall back to `__name__`.
            name = getattr(tool, "name", None)
            if name is None:
                name = tool.__name__
            if name == tool_name:
                tool_found = True
                kwargs = json.loads(tool_args)
                # Execute the tool! LangChain tools run via `_run`; functions are called directly.
                if hasattr(tool, "_run"):
                    result = tool._run(**kwargs)
                else:
                    result = tool(**kwargs)
                new_messages.append(
                    {
                        "tool_call_id": tool_call["id"],
                        "role": "tool",
                        "name": tool_name,
                        "content": result,
                    }
                )
                break  # stop at the first tool with a matching name
        if not tool_found:
            raise ValueError(f"Tool {tool_name} not found.")

    for tool_result in new_messages:
        state = state.append(messages=tool_result)
    # Clear the pending calls so the next agent starts fresh.
    state = state.update(parsed_tool_calls=[])
    # The result holds only the new tool messages; state already has them appended.
    return {"messages": new_messages}, state

@action(reads=[], writes=[])
def terminal_step(state: State) -> tuple[dict, State]:
    """Terminal no-op step; a placeholder where final processing could happen."""
    return {}, state

# Define the Graph / Application
With Burr, we now construct our application, i.e. the graph, by:

1. Defining what the actions are and how to transition between them.
2. Defining the initial state of the application. In our example this means we need to provide a "query" for the agents to work on.
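The transition table in `build_application` below can be read as an ordered list of `(from, to, condition)` rules. The plain-Python sketch below simulates it, assuming (as we understand Burr's semantics) that conditions for the current action are checked in the order given and the first one that holds wins, with `always` standing in for Burr's `default`. `has_final_answer` adds a hypothetical guard for a `None` message content, which the notebook's `core.expr` conditions do not have.

```python
from typing import Callable, Optional


def has_final_answer(state: dict) -> bool:
    # Guard against `content=None`, which tool-calling messages may carry.
    content = state["messages"][-1].get("content") or ""
    return "FINAL ANSWER" in content


def has_tool_calls(state: dict) -> bool:
    return len(state["parsed_tool_calls"]) > 0


def always(state: dict) -> bool:
    return True  # plays the role of Burr's `default`


# (from, to, condition), checked top to bottom; the first match wins.
TRANSITIONS: list[tuple[str, str, Callable[[dict], bool]]] = [
    ("researcher", "tool_node", has_tool_calls),
    ("researcher", "terminal", has_final_answer),
    ("researcher", "chart_generator", always),
    ("chart_generator", "tool_node", has_tool_calls),
    ("chart_generator", "terminal", has_final_answer),
    ("chart_generator", "researcher", always),
    ("tool_node", "researcher", lambda s: s["sender"] == "researcher"),
    ("tool_node", "chart_generator", lambda s: s["sender"] == "chart_generator"),
]


def next_action(current: str, state: dict) -> Optional[str]:
    for source, target, condition in TRANSITIONS:
        if source == current and condition(state):
            return target
    return None  # no outgoing transition, e.g. after "terminal"


pending = {
    "messages": [{"role": "assistant", "content": None}],
    "parsed_tool_calls": [{"id": "call_1"}],
    "sender": "researcher",
}
print(next_action("researcher", pending))  # tool_node
print(next_action("tool_node", pending))   # researcher
```

Order matters: because the tool-call check comes first, an agent that requests a tool is never sent to `terminal` or handed off before that tool runs.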

Because Burr comes with built-in persistence, we can also load a prior execution and continue from
any point in its history by setting `app_instance_id` and `sequence_id` in the next cell before building the application.
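The idea behind resuming is that every step leaves a snapshot, and you pick one by application id and step number, falling back to a fresh start when nothing is found. The sketch below models that with a hypothetical in-memory `HISTORY` and `load_snapshot` function; it is not Burr's storage format or API (the notebook uses `LocalTrackingClient.load`), it only borrows the `position` and `state` keys that the notebook reads from the loaded data.

```python
from typing import Optional

# Hypothetical in-memory history: one snapshot per executed step, per app id.
HISTORY: dict = {
    "app-123": [
        {"sequence_id": 0, "position": "researcher", "state": {"query": "q", "messages": []}},
        {"sequence_id": 1, "position": "tool_node", "state": {"query": "q", "messages": ["m1"]}},
        {"sequence_id": 2, "position": "researcher", "state": {"query": "q", "messages": ["m1", "m2"]}},
    ]
}


def load_snapshot(app_id: str, sequence_id: Optional[int] = None) -> Optional[dict]:
    """Return the requested snapshot, the latest one if no sequence_id is given, else None."""
    snapshots = HISTORY.get(app_id)
    if not snapshots:
        return None
    if sequence_id is None:
        return snapshots[-1]  # most recent step
    for snapshot in snapshots:
        if snapshot["sequence_id"] == sequence_id:
            return snapshot
    return None


snapshot = load_snapshot("app-123", sequence_id=1)
if snapshot is None:
    state, entry_point = {"query": "q", "messages": []}, "researcher"  # fresh start
else:
    state, entry_point = snapshot["state"], snapshot["position"]
print(entry_point, state["messages"])
```

The fallback branch mirrors the notebook's own behaviour of printing a warning and using the default state when no persisted state is found.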

In [11]:
# Adjust these if you want to load a prior execution
app_instance_id = None
sequence_id = None
project_name = "demo:hamilton-multi-agent-v1"

# CHANGE THIS IF YOU WANT SOMETHING DIFFERENT!
default_query = ("Fetch the UK's GDP over the past 5 years, then draw a line graph of it. "
                 "Once the python code has been written and the graph drawn, the task is complete.")

In [14]:
# Determine initial state and entry point
def default_state_and_entry_point() -> tuple[dict, str]:
    """Returns the default state and entry point for the application."""
    return {
        "messages": [],
        "query": default_query,
        "sender": "",
        "parsed_tool_calls": [],
    }, "researcher"

if app_instance_id:
    tracker = burr_tclient.LocalTrackingClient(project_name)
    persisted_state = tracker.load("demo", app_id=app_instance_id, sequence_no=sequence_id)
    if not persisted_state:
        print(f"Warning: No persisted state found for app_id {app_instance_id}.")
        state, entry_point = default_state_and_entry_point()
    else:
        state = persisted_state["state"]
        entry_point = persisted_state["position"]
else:
    state, entry_point = default_state_and_entry_point()

In [13]:
# Build the application 
def build_application(state: dict, entry_point: str):
    _app = (
        ApplicationBuilder()
        # set the actions
        .with_actions(
            researcher=researcher,
            chart_generator=chart_generator,
            tool_node=tool_node,
            terminal=terminal_step,
        )
        # set the transitions
        .with_transitions(
            ("researcher", "tool_node", core.expr("len(parsed_tool_calls) > 0")),
            (
                "researcher",
                "terminal",
                core.expr("'FINAL ANSWER' in messages[-1]['content']"),
            ),
            ("researcher", "chart_generator", default),
            ("chart_generator", "tool_node", core.expr("len(parsed_tool_calls) > 0")),
            (
                "chart_generator",
                "terminal",
                core.expr("'FINAL ANSWER' in messages[-1]['content']"),
            ),
            ("chart_generator", "researcher", default),
            ("tool_node", "researcher", core.expr("sender == 'researcher'")),
            ("tool_node", "chart_generator", core.expr("sender == 'chart_generator'")),
        )
        # set a few other things
        .with_identifiers(partition_key="demo")
        .with_state(**state)
        .with_entrypoint(entry_point)
        .with_tracker(project=project_name)
        .build()
    )
    return _app
app = build_application(state, entry_point)
app.visualize(
    output_file_path="statemachine", include_conditions=True, format="png"
)

# Open the Burr UI to trace the execution
In another terminal, run:
```bash
burr
```
Then open [http://localhost:7241](http://localhost:7241) in your browser to watch the application execute.

In [11]:
# this will run until completion.
last_action, last_result, last_state = app.run(halt_after=["terminal"])

In [None]:
pprint.pprint(last_state)

# Change the Query!
The starting query is part of the initial state, so to run a different query we just build a new application
with an adjusted initial state.
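The next cell assigns the new query into `state` in place, so afterwards `state` no longer holds the UK query. If you want to keep both starting points around, you can derive a new initial state from a copy instead. `with_query`, `base_state`, and the two shortened query strings below are hypothetical names for illustration.

```python
import copy

uk_query = "Fetch the UK's GDP over the past 5 years, then draw a line graph of it."
usa_query = "Fetch the USA's GDP over the past 5 years, then draw a line graph of it."

base_state = {"messages": [], "query": uk_query, "sender": "", "parsed_tool_calls": []}


def with_query(initial_state: dict, query: str) -> dict:
    """Return a deep copy of the initial state with a different query."""
    new_state = copy.deepcopy(initial_state)  # no lists shared with the original
    new_state["query"] = query
    return new_state


usa_state = with_query(base_state, usa_query)
print(base_state["query"] == uk_query, usa_state["query"] == usa_query)  # True True
```

The copied dict can then be passed to `build_application(usa_state, entry_point)` while `base_state` remains available for rerunning the original query.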

In [16]:
# Let's change the query
state["query"] = ("Fetch the USA's GDP over the past 5 years, then draw a line graph of it. "
                  "Once the python code has been written and the graph drawn, the task is complete.")
app2 = build_application(state, entry_point)

In [17]:
# this will run until completion.
last_action, last_result, last_state = app2.run(halt_after=["terminal"])

In [21]:
pprint.pprint(last_state)