In [35]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from langchain.chains.openai_functions import create_structured_output_runnable
from enum import Enum
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from typing import Annotated, Any, Dict, List, Optional, Sequence, TypedDict, Union, Literal
from langchain_core.messages import BaseMessage, HumanMessage
import operator
from langchain.pydantic_v1 import BaseModel
import os

class Worker(Enum):
    """Identifiers of the worker agents the supervisor can hand work to.

    Each member's value doubles as the worker's node name in the graph and as
    one of the routing options offered to the supervisor.
    """

    CODE = 'code_worker'
    SUMMARY = 'summary_worker'

    @classmethod
    def get_description(cls, worker: 'Worker'):
        """Look up a short, human-readable description of ``worker``.

        The text is shown to the supervisor LLM next to each worker's name.

        Args:
            worker: The worker to describe.

        Returns:
            The description string, or ``None`` when nothing is registered
            for ``worker.value``.
        """
        # Built inside the method: a dict assigned in the Enum body would be
        # turned into an enum member.
        descriptions_by_value = dict(
            [
                (cls.CODE.value, 'A coding worker/agent that can produce and execute Python code'),
                (cls.SUMMARY.value, 'A summary writing worker/agent that can produce summaries of human-AI conversations'),
            ]
        )
        return descriptions_by_value.get(worker.value)

class AgentState(TypedDict):
    """Shared LangGraph state passed between the supervisor and the workers."""

    # The user's original question; the supervisor prompt interpolates it
    # when deciding whether to FINISH.
    initial_query: str
    # Conversation so far. LangGraph combines updates to this key with
    # ``operator.add``, so messages returned by a node are concatenated onto
    # the existing list instead of replacing it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Routing decision written by the supervisor: a worker's ``.value``
    # (e.g. "code_worker") or "FINISH". The supervisor's JSON parser yields a
    # plain string here, not a ``Worker`` member.
    next: str


def create_agent_supervisor(workers: Sequence[Worker], llm: Optional[ChatOpenAI] = None):
    """Build the supervisor runnable that decides which worker acts next.

    The chain formats a prompt that lists ``workers``. It forces the LLM to call
    the ``route`` function and parses that call's arguments into a dict such as
    ``{"next": "code_worker"}`` or ``{"next": "FINISH"}``.

    Args:
        workers: The workers the supervisor may route to.
        llm: Chat model to use. When omitted, a streaming ``ChatOpenAI`` is
            created with the model name from the ``GPT4_MODEL_NAME``
            environment variable.

    Returns:
        A runnable whose inputs are ``messages`` and ``initial_query``.
    """
    # Materialise once: the workers are iterated twice below (options and the
    # bullet list), and a one-shot iterable would be empty the second time.
    workers = list(workers)

    system_prompt = (
        "You are a supervisor tasked with managing a conversation between the"
        " following workers:\n{workers}\n\n Given the following user request,"
        " respond with the worker to act next. Each worker will perform a"
        " task and respond with their results and status. When finished,"
        " respond with FINISH."
    )

    # "FINISH" is the sentinel that ends the conversation; every other option
    # is a worker's node name.
    options = ["FINISH"] + [worker.value for worker in workers]

    # OpenAI function schema restricting ``next`` to the known options.
    function_def = {
        "name": "route",
        "description": "Select the next role.",
        "parameters": {
            "title": "routeSchema",
            "type": "object",
            "properties": {
                "next": {
                    "title": "Next",
                    "anyOf": [
                        {"enum": options},
                    ],
                },
            },
            "required": ["next"],
        },
    }

    bullet_point_list = "\n".join(
        f"{i}) {worker.value} - {Worker.get_description(worker)}"
        for i, worker in enumerate(workers, start=1)
    )

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            (
                "system",
                (
                    "Given the conversation above, who should act next, if any?"
                    ' Return "FINISH" if the initial human query ("{initial_query}") has been answered.\n\n'
                    "Select one of: {options}\n\n"
                )
            )
        ]
    ).partial(options=options, workers=bullet_point_list)

    if llm is None:
        llm = ChatOpenAI(model=os.getenv('GPT4_MODEL_NAME'), streaming=True)

    return (
        prompt
        # function_call="route" forces the model to answer via the schema above.
        | llm.bind_functions(functions=[function_def], function_call="route")
        | JsonOutputFunctionsParser()
    )

In [36]:
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain_core.messages import HumanMessage
from langchain.tools import BaseTool
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser


def create_agent(llm: ChatOpenAI, tools: Sequence[BaseTool], system_prompt: str):
    """Create a worker runnable whose result carries its answer under ``"output"``.

    With tools, the worker is an ``AgentExecutor`` around an OpenAI tools agent.
    Its prompt is ``system_prompt`` followed by the conversation ``messages``
    and the agent scratchpad. Without tools, it is a plain prompt -> LLM chain.
    In that case ``system_prompt`` is the whole (single-message) template and
    may reference state keys such as ``{messages}`` directly.

    Args:
        llm: Chat model driving the worker.
        tools: Tools the agent may call; may be empty.
        system_prompt: Instructions for the worker.

    Returns:
        A runnable whose result is a dict containing ``"output"`` (the answer
        text), the same shape for both kinds of worker.
    """
    if not tools:
        # StrOutputParser yields a bare string. Wrap it in the same
        # {"output": ...} shape the AgentExecutor branch produces, so callers
        # that read result["output"] work for tool-less workers too.
        return (
            ChatPromptTemplate.from_template(system_prompt)
            | llm
            | StrOutputParser()
            | (lambda text: {"output": text})
        )

    # Only the tools agent needs the scratchpad-aware prompt.
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    agent = create_openai_tools_agent(llm, tools, prompt)
    return AgentExecutor(agent=agent, tools=tools)


async def agent_node(state: AgentState, agent, name):
    """Run ``agent`` on the graph state and report its answer as a named message.

    Args:
        state: Current graph state, passed to the agent unchanged.
        agent: Runnable supporting ``ainvoke``. It returns either a dict with
            an ``"output"`` key (e.g. an ``AgentExecutor``) or a bare string
            (e.g. a ``prompt | llm | StrOutputParser()`` chain).
        name: Name attached to the produced message, identifying the worker.

    Returns:
        A state update holding one ``HumanMessage`` under ``"messages"``.
    """
    result = await agent.ainvoke(state)
    # Accept both result shapes: indexing a plain string with "output" would
    # raise TypeError.
    content = result if isinstance(result, str) else result["output"]
    return {"messages": [HumanMessage(content=content, name=name)]}

In [37]:
import functools
from langgraph.graph import StateGraph, END
from langchain_experimental.tools import PythonREPLTool
import os

# The supervisor may route to every defined worker, in definition order.
supervisor_chain = create_agent_supervisor(list(Worker))

# One shared streaming model backs both workers.
llm = ChatOpenAI(model=os.getenv('GPT3_MODEL_NAME'), streaming=True)

# Summary worker: no tools, so create_agent builds a plain prompt -> LLM
# chain; the prompt interpolates the conversation via {messages}.
summary_agent = create_agent(
    llm,
    [],
    "You are a summary agent. Write a summary of the conversation so far: {messages}",
)
summary_node = functools.partial(agent_node, agent=summary_agent, name="Summarizer")

# Coding worker: a tools agent that can execute Python through a REPL tool.
code_agent = create_agent(llm, [PythonREPLTool()], "You are a coding agent.")
code_node = functools.partial(agent_node, agent=code_agent, name="Coder")

SUPERVISOR = 'supervisor'

# Register each node under its routing name (worker nodes use Worker values).
workflow = StateGraph(AgentState)
workflow.add_node(Worker.SUMMARY.value, summary_node)
workflow.add_node(Worker.CODE.value, code_node)
workflow.add_node(SUPERVISOR, supervisor_chain)

In [38]:
# Every worker reports back to the supervisor after it acts.
for worker in Worker:
    workflow.add_edge(worker.value, SUPERVISOR)

# The supervisor's "next" field names the node to run; "FINISH" ends the run.
conditional_map = {
    **{member.value: member.value for member in Worker},
    "FINISH": END,
}
workflow.add_conditional_edges(SUPERVISOR, lambda state: state["next"], conditional_map)

# Each run starts by asking the supervisor who should act first.
workflow.set_entry_point(SUPERVISOR)

app = workflow.compile()

In [39]:
# messages = [HumanMessage(content="What is 576 * 5 / 13?")]
# async for event in app.astream_events(
#     {
#         'initial_query': messages[0].content,
#         "messages": messages
#     },
#     {"recursion_limit": 100},
#     version="v1"
# ):
#     kind = event["event"]
#     print(event)
    # if kind == "on_chat_model_stream":
    #     content = event["data"]["chunk"].content
    #     if content:
    #         # Empty content in the context of OpenAI means
    #         # that the model is asking for a tool to be invoked.
    #         # So we only print non-empty content
    #         print(content, end="|")
    # elif kind == "on_tool_start":
    #     print("--")
    #     print(
    #         f"Starting tool: {event['name']} with inputs: {event['data'].get('input')}"
    #     )
    # elif kind == "on_tool_end":
    #     print(f"Done tool: {event['name']}")
    #     print(f"Tool output was: {event['data'].get('output')}")
    #     print("--")

In [40]:
initial_msg = "What are the first 30 numbers of the sequence a^3 - 1/a?"

async for s in app.astream(
    {
        'initial_query': initial_msg,
        "messages": [HumanMessage(
        content=initial_msg)]},
    {"recursion_limit": 100},
):
    if "__end__" not in s:
        print(s)
        print("----")

{'supervisor': {'next': 'code_worker'}}
----
{'code_worker': {'messages': [HumanMessage(content='Here are the first 30 numbers of the sequence \\( a^3 - \\frac{1}{a} \\):\n\n1. 0\n2. 7\n3. 26\n4. 63\n5. 124\n6. 215\n7. 342\n8. 511\n9. 728\n10. 999\n11. 1330\n12. 1727\n13. 2206\n14. 2773\n15. 3434\n16. 4195\n17. 5062\n18. 6041\n19. 7138\n20. 8360\n21. 9713\n22. 11204\n23. 12839\n24. 14624\n25. 16565\n26. 18668\n27. 20939\n28. 23384\n29. 26009\n30. 28820\n\nThese are the values obtained by substituting values from 1 to 30 into the expression \\( a^3 - \\frac{1}{a} \\).', name='Coder')]}}
----
{'supervisor': {'next': 'FINISH'}}
----
