In [None]:
# Supervisor routes between different independent agents

import getpass
import os
import pandas as pd
import sqlite3

def _set_if_undefined(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"Please provide your {var}")


# Collect the credentials this notebook needs; each one is prompted for only
# when it is not already set in the environment.
_set_if_undefined("OPENAI_API_KEY")
_set_if_undefined("LANGCHAIN_API_KEY")
# The Tavily search tool is commented out in this notebook, so its key stays disabled.
#_set_if_undefined("TAVILY_API_KEY")

# Optional: turn on LangSmith tracing and group the runs under one project name.
os.environ.update(
    {
        "LANGCHAIN_TRACING_V2": "true",
        "LANGCHAIN_PROJECT": "Agent supervisor",
    }
)

In [34]:
from langchain_community.tools.sql_database.tool import (
    InfoSQLDatabaseTool,
    ListSQLDatabaseTool,
    QuerySQLCheckerTool,
    QuerySQLDataBaseTool,
)
# Load the salary dataset from salary_data.csv (path relative to the working
# directory) and preview its first rows.
df = pd.read_csv("salary_data.csv")
df.head()

Unnamed: 0,Age,Gender,Education Level,Job Title,Years of Experience,Salary
0,32.0,Male,Bachelor's,Software Engineer,5.0,90000.0
1,28.0,Female,Master's,Data Analyst,3.0,65000.0
2,45.0,Male,PhD,Senior Manager,15.0,150000.0
3,36.0,Female,Bachelor's,Sales Associate,7.0,60000.0
4,52.0,Male,Master's,Director,20.0,200000.0


In [46]:
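# One-time setup: uncomment the two lines below to load `df` into a `salaries`
# table inside salaries.db, the SQLite file that the next cell connects to.
# Re-running them as-is fails once the table exists, because `DataFrame.to_sql`
# defaults to if_exists="fail" and raises ValueError.
#connection = sqlite3.connect("salaries.db")
#df.to_sql(name="salaries", con=connection)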

In [35]:
from langchain_community.utilities.sql_database import SQLDatabase
# Wrap the SQLite database file salaries.db in a LangChain SQLDatabase; the SQL
# tools defined later in the notebook reach the data through `db`.
db = SQLDatabase.from_uri("sqlite:///salaries.db")

### Create Tools
For this example, you will make one agent that writes and runs SQL queries against the salaries database, and one agent that creates plots. The tools they'll use are defined in the cells below:


In [36]:
from typing import Annotated

from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_experimental.tools import PythonREPLTool

#tavily_tool = TavilySearchResults(max_results=5)

# This executes code locally, which can be unsafe
#python_repl_tool = PythonREPLTool()

### Helper Utilities
Define a helper function below, which makes it easier to add new agent worker nodes.

In [37]:
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain_openai import AzureChatOpenAI
from langchain_core.messages import (
    BaseMessage,
    HumanMessage,
    ToolMessage,
)
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

from langgraph.graph import END, StateGraph, START

def create_agent(llm: AzureChatOpenAI, tools: list, system_prompt: str):
    """Assemble a tool-calling worker agent and wrap it in an ``AgentExecutor``.

    Args:
        llm: Chat model that drives the agent.
        tools: Tools the agent may call; also handed to the executor.
        system_prompt: Text used as the system message of the agent's prompt.

    Returns:
        An ``AgentExecutor`` created with ``handle_parsing_errors=True``.
    """
    # Prompt layout: system instructions, then the conversation so far, then the
    # scratchpad slot the agent uses for its intermediate tool calls.
    message_layout = [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
    worker_prompt = ChatPromptTemplate.from_messages(message_layout)
    return AgentExecutor(
        agent=create_openai_tools_agent(llm, tools, worker_prompt),
        tools=tools,
        handle_parsing_errors=True,
    )

We can also define a function that will serve as a node in the graph - it takes care of converting the agent's response to a human message. This is important because that is how we will add it to the global state of the graph.

In [38]:
def agent_node(state, agent, name):
    """Run ``agent`` on the graph ``state`` and package its answer as a state update.

    The agent's ``"output"`` is wrapped in a ``HumanMessage`` tagged with
    ``name`` and returned as a one-element ``"messages"`` list.
    """
    agent_output = agent.invoke(state)["output"]
    reply = HumanMessage(content=agent_output, name=name)
    return {"messages": [reply]}

#### Create Agent Supervisor
It will use function calling to choose the next worker node OR finish processing.

In [None]:
from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

# Worker nodes the supervisor can dispatch to.
members = ["sql_developer", "chart_generator"]
# Routing choices offered to the LLM: stop, or hand over to one of the workers.
options = ["FINISH", *members]

# The supervisor is an LLM node: it only picks the next worker to act and
# decides when the work is complete. `{members}` is filled in when the prompt
# template is built.
system_prompt = (
    "You are a supervisor tasked with managing a conversation between the following workers:  {members}."
    " Given the following user request, respond with the worker to act next."
    " Each worker will perform a task and respond with their results and status."
    " When finished, respond with FINISH."
)

# OpenAI function-calling schema: a single required field "next" restricted to
# `options`, which keeps parsing the model's routing decision simple.
function_def = {
    "name": "route",
    "description": "Select the next role.",
    "parameters": {
        "title": "routeSchema",
        "type": "object",
        "properties": {"next": {"title": "Next", "anyOf": [{"enum": options}]}},
        "required": ["next"],
    },
}
# Supervisor prompt: system instructions, the running conversation, then a
# closing system message that asks for the routing decision.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        ("system", "Given the conversation above, who should act next? Or should we FINISH? Select one of: {options}"),
    ]
)
# Pre-fill the two template variables that never change between invocations.
prompt = prompt.partial(options=str(options), members=", ".join(members))

# Azure OpenAI chat model shared by the supervisor chain, the worker agents and
# the `check_sql` tool.
# The API key is read from the OPENAI_API_KEY environment variable that the
# first cell prompts for, so no secret has to be pasted into the notebook.
# "model", "type", "endpoint" and "version" are placeholders: replace them with
# the values of your Azure deployment before running.
llm = AzureChatOpenAI(
    model="model",
    openai_api_key=os.environ["OPENAI_API_KEY"],
    openai_api_type="type",
    azure_endpoint="endpoint",
    api_version="version",
)

# Supervisor pipeline: render the routing prompt, call the LLM with the `route`
# function bound (function_call="route" names it as the function to call), then
# parse the function-call arguments as JSON.
# The recorded runs show the result as a dict such as {'next': 'sql_developer'};
# the graph's conditional edge reads that "next" value.
supervisor_chain = (
    prompt
    | llm.bind_functions(functions=[function_def], function_call="route")
    | JsonOutputFunctionsParser()
)

In [40]:
from typing import Annotated

from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.tools import tool
from langchain_experimental.utilities import PythonREPL

#tavily_tool = TavilySearchResults(max_results=5)

# Warning: This executes code locally, which can be unsafe when not sandboxed

# Single REPL instance that every `python_repl` tool call runs its code through.
repl = PythonREPL()


# Tool that runs Python source through `repl` and reports what happened.
# Anything raised while running (BaseException included) is converted into a
# "Failed to execute" message instead of propagating to the caller.
@tool
def python_repl(
    code: Annotated[str, "The python code to execute to generate your chart."],
):
    """Use this to execute python code. If you want to see the output of a value,
    you should print it out with `print(...)`. This is visible to the user."""
    try:
        stdout_text = repl.run(code)
    except BaseException as exc:
        return f"Failed to execute. Error: {exc!r}"
    # Echo the executed code together with its stdout, then remind the agent how
    # to signal that it is done.
    report = f"Successfully executed:\n```python\n{code}\n```\nStdout: {stdout_text}"
    return f"{report}\n\nIf you have completed all tasks, respond with FINAL ANSWER."

# Agent tool wrapping ListSQLDatabaseTool for the shared `db`; it takes no
# arguments, so the underlying tool is invoked with an empty string.
@tool("list_tables")
def list_tables() -> str:
    """List the available tables in the database"""
    table_lister = ListSQLDatabaseTool(db=db)
    return table_lister.invoke("")

# Agent tool wrapping InfoSQLDatabaseTool for the shared `db`; the `tables`
# string is passed through to it unchanged.
@tool("tables_schema")
def tables_schema(tables: str) -> str:
    """
    Input is a comma-separated list of tables, output is the schema and sample rows
    for those tables. Be sure that the tables actually exist by calling `list_tables` first!
    Example Input: table1, table2, table3
    """
    return InfoSQLDatabaseTool(db=db).invoke(tables)

# Agent tool wrapping QuerySQLDataBaseTool for the shared `db`; the query text
# is passed through to it unchanged.
@tool("execute_sql")
def execute_sql(sql_query: str) -> str:
    """Execute a SQL query against the database. Returns the result"""
    query_runner = QuerySQLDataBaseTool(db=db)
    return query_runner.invoke(sql_query)

# Agent tool wrapping QuerySQLCheckerTool, built from the shared `db` and `llm`;
# the query is handed over under the "query" key.
@tool("check_sql")
def check_sql(sql_query: str) -> str:
    """
    Use this tool to double check if your query is correct before executing it. Always use this
    tool before executing a query with `execute_sql`.
    """
    checker = QuerySQLCheckerTool(db=db, llm=llm)
    return checker.invoke({"query": sql_query})

In [41]:
import functools
import operator
from typing import Sequence, TypedDict

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

from langgraph.graph import END, StateGraph, START


# Shared state of the graph; it is the input handed to every node.
class AgentState(TypedDict):
    # Conversation history. The `operator.add` annotation tells the graph to
    # combine a node's returned messages with the existing ones by addition,
    # so new messages are added to the current state rather than replacing it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Where to route next: a worker name or "FINISH". The supervisor node is the
    # one that fills this in.
    next: str

# SQL developer worker: an agent equipped with the four database tools
# (list_tables, tables_schema, check_sql, execute_sql). Its system prompt tells
# the model what each tool is for.
sql_dev = create_agent(llm, 
                       tools=[list_tables, tables_schema, check_sql, execute_sql], 
                       system_prompt=""" 
                        You are an experienced database engineer who is master at creating efficient and complex SQL queries. 
                        You have a deep understanding of how different databases work and how to optimize queries. 
                        Use the `list_tables` to find available tables. 
                        Use the `tables_schema` to understand the metadata for the tables. 
                        Use the `check_sql` to check your queries for correctness. 
                        Use the `execute_sql` to execute queries against the database.
                        Your main goal is to construct and execute sql queries based on a request    
                        """)
# Pre-bind the agent and its node name so the resulting callable only needs `state`.
sql_dev_node = functools.partial(agent_node, agent = sql_dev, name="sql_developer")

# Chart generator worker: an agent whose only tool is `python_repl`. Its system
# prompt instructs the model to write and run a plotting script, save the figure
# as img.png, and return the saved image's path.
chart_generator = create_agent(
    llm,
    tools = [python_repl],
    system_prompt= """ 
                    You are a helpful AI assistant expert that writes and executes Python scripts to visualize data in a dataframe.
                    Analyze the dataframe given to you carefully and write a script to visualize the data as asked by the user, then execute the script.
                    Import the necessary libraries for plotting. Return the path to the saved image as answer. 
                    The final answer should not have a parse-able action.
                    Use the 'python_repl' to execute a python script.
                    Before using plt.show(), save the image using plt.savefig('img.png').
                    If you encounter errors multiple times, try changing the approach.
                    Close the plot with plt.close(). Return the path of the saved image in answer.
                    If the user does not request for a pie chart or bar chart or scatter plot, then don't do anything. 
                    """,
)

# summary_writer: an agent with no tools, prompted to produce concise summaries.
# NOTE(review): in the cells visible here it is never wrapped in a node, added to
# `workflow`, or listed in `members`, so the supervisor cannot route to it.
# Confirm whether it should be wired into the graph or removed.
summary_writer = create_agent(
    llm,
    tools = [],
    system_prompt="You are a summary writer. Your task is to provide concise and accurate summaries."
)

# Pre-bind the chart agent and its node name so the resulting callable only needs `state`.
chart_node = functools.partial(agent_node, agent = chart_generator, name="chart_generator")

# Build the graph over AgentState and register the two workers plus the supervisor.
# The worker node names are the same strings listed in `members`, which the
# wiring cell reuses as edge sources and routing targets.
workflow = StateGraph(AgentState)
workflow.add_node("sql_developer", sql_dev_node)
workflow.add_node("chart_generator", chart_node)
workflow.add_node("supervisor", supervisor_chain)

<langgraph.graph.state.StateGraph at 0x2c43c696c80>

In [42]:
# Every worker hands control back to the supervisor once it has run.
for member in members:
    workflow.add_edge(member, "supervisor")

# Routing table for the supervisor's "next" value: each worker name maps to the
# node of the same name, and "FINISH" maps to END.
conditional_map = {**{name: name for name in members}, "FINISH": END}
workflow.add_conditional_edges(
    "supervisor",
    lambda state: state["next"],
    conditional_map,
)

# Execution enters the graph at the supervisor.
workflow.add_edge(START, "supervisor")

graph = workflow.compile()

In [43]:
# Run the graph on a counting question and print each streamed step.
for s in graph.stream(
    {"messages": [HumanMessage(content="What are the number of employees doing data scientist job?")]}
):
    # Skip the terminal "__end__" event; print every other step followed by a separator.
    if "__end__" in s:
        continue
    print(s, "----", sep="\n")

{'supervisor': {'next': 'sql_developer'}}
----
{'sql_developer': {'messages': [HumanMessage(content='The number of employees doing the job of a Data Scientist is 453.', additional_kwargs={}, response_metadata={}, name='sql_developer')]}}
----
{'supervisor': {'next': 'FINISH'}}
----


In [44]:
# Run the graph on a charting request and print each streamed step.
for s in graph.stream(
    {"messages": [HumanMessage(content="Draw a bar chart showing the effect of salary on employee experience")]}
):
    # Skip the terminal "__end__" event; print every other step followed by a separator.
    if "__end__" in s:
        continue
    print(s, "----", sep="\n")

{'supervisor': {'next': 'sql_developer'}}
----
{'sql_developer': {'messages': [HumanMessage(content='Here is the data showing the average salary by years of experience:\n\n| Years of Experience | Average Salary       |\n|---------------------|----------------------|\n| 0.0                 | 29680.23             |\n| 0.5                 | 35000.00             |\n| 1.0                 | 46992.85             |\n| 1.5                 | 36279.17             |\n| 2.0                 | 58699.46             |\n| 3.0                 | 72944.41             |\n| 4.0                 | 83332.09             |\n| 5.0                 | 103111.09            |\n| 6.0                 | 111891.15            |\n| 7.0                 | 122108.23            |\n| 8.0                 | 126438.14            |\n| 9.0                 | 138021.46            |\n| 10.0                | 131690.32            |\n| 11.0                | 153060.32            |\n| 12.0                | 153398.06            |\n| 13.0      