# Import

 TODO: the query agent needs to be handled differently.

In [None]:
import getpass
import os
from sqlite_dataset import SQLiteDataset, Field, String, Float, Integer
import pandas as pd
import matplotlib.pyplot as plt
from io import BytesIO
import base64  

from typing import Any, Annotated, Literal, Union, Dict, List
from pydantic import BaseModel, Field
from typing_extensions import TypedDict

from langchain_core.tools import tool, StructuredTool
from langchain_core.messages import ToolMessage, AIMessage, SystemMessage, HumanMessage
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langgraph.prebuilt import ToolNode
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_openai import AzureChatOpenAI

from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import AnyMessage, add_messages

In [None]:
def _set_env(key: str):
    if key not in os.environ:
        os.environ[key] = getpass.getpass(f"{key}:")


# Ask for the Azure OpenAI key and endpoint unless they are already exported;
# GPT_URL is later passed to AzureChatOpenAI as `azure_endpoint`.
_set_env("AZURE_OPENAI_API_KEY")
_set_env("GPT_URL")

In [None]:
# Open the local SQLite sales database (path relative to the working
# directory) and print the dialect and the tables LangChain can use.
db = SQLDatabase.from_uri('sqlite:///sales_dataset.db')
print(db.dialect)
print(db.get_usable_table_names())
# db.run("SELECT * FROM sales LIMIT 10;")

# utils 

In [None]:
def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
    """
    Build a ToolNode for ``tools`` whose exceptions are reported back to the agent.

    If the ToolNode raises, the exception is handed to ``handle_tool_error``
    under the ``"error"`` key, which answers every pending tool call with an
    error ToolMessage instead of crashing the graph.
    """
    tool_node = ToolNode(tools)
    error_fallback = RunnableLambda(handle_tool_error)
    return tool_node.with_fallbacks([error_fallback], exception_key="error")


def handle_tool_error(state) -> dict:
    """Convert a tool failure into one error ToolMessage per pending tool call.

    ``state["error"]`` holds the exception captured by the fallback, and the
    last entry of ``state["messages"]`` is the AI message whose tool calls
    failed. Each of those calls gets a reply asking the model to fix its mistake.
    """
    failure = state.get("error")
    pending_calls = state["messages"][-1].tool_calls
    replies = []
    for call in pending_calls:
        replies.append(
            ToolMessage(
                content=f"Error: {repr(failure)}\n please fix your mistakes.",
                tool_call_id=call["id"],
            )
        )
    return {"messages": replies}

In [None]:
# Standard LangChain SQL toolkit over `db`; it requires an LLM instance
# (Azure gpt-4o deployment, temperature 0).
toolkit = SQLDatabaseToolkit(db=db, llm=AzureChatOpenAI(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ['GPT_URL'],
    azure_deployment="gpt-4o"))
tools = toolkit.get_tools()

# Pick the two toolkit tools this graph uses by their registered names
# (next() raises StopIteration if a name is not present).
list_tables_tool = next(tool for tool in tools if tool.name == "sql_db_list_tables")
get_schema_tool = next(tool for tool in tools if tool.name == "sql_db_schema")

# print(list_tables_tool.invoke(""))
# print(get_schema_tool.invoke("sales"))

In [None]:
@tool
def db_query_tool(query: str) -> str:
    """
    Execute a SQL query against the database and get back the result.
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.
    """
    # NOTE: the docstring above is the tool description sent to the model.
    # Any falsy result (e.g. an empty string) is turned into a retry hint.
    outcome = db.run_no_throw(query)
    return outcome if outcome else "Error: Query failed. Please rewrite your query and try again."


# print(db_query_tool.invoke("SELECT * FROM sales LIMIT 10;"))

In [None]:
# System prompt for the "query checker" step: the model reviews a SQLite query
# for common mistakes and then executes it through db_query_tool.
query_check_system = """You are a SQL expert with a strong attention to detail.
Double check the SQLite query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

You will call the appropriate tool to execute the query after running this check."""

query_check_prompt = ChatPromptTemplate.from_messages(
    [("system", query_check_system), ("placeholder", "{messages}")]
)
# tool_choice="required" forces the reply to be a db_query_tool call rather
# than free text.
query_check = query_check_prompt | AzureChatOpenAI(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ['GPT_URL'],
    azure_deployment="gpt-4o").bind_tools(
    [db_query_tool], tool_choice="required"
)

#query_check.invoke({"messages": [("user", "SELECT * FROM sales LIMIT 10;")]})

In [None]:
class State(TypedDict):
    """Graph state for the query/chart workflow."""
    # Conversation history; the add_messages reducer merges the messages each
    # node returns into this list instead of overwriting it.
    messages: Annotated[list[AnyMessage], add_messages]
    

In [None]:
# Define a new graph over `State`; nodes and edges are attached in later cells.
workflow = StateGraph(State)

# Entry node: force a call to the list-tables tool so the graph always starts
# by fetching the available tables from the database.
def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
    """Return a hard-coded AI message that calls ``sql_db_list_tables``.

    The incoming state is ignored; no LLM round-trip is needed for this step.
    """
    list_tables_call = {
        "name": "sql_db_list_tables",
        "args": {},
        "id": "tool_abcd123",
    }
    forced_message = AIMessage(content="", tool_calls=[list_tables_call])
    return {"messages": [forced_message]}
    
def model_check_query(state: State) -> dict[str, list[AIMessage]]:
    """
    Run the query-checker chain on the most recent message only.

    Because the checker is bound with ``tool_choice="required"``, its reply is a
    ``db_query_tool`` call carrying the (possibly corrected) query.
    """
    latest = state["messages"][-1]
    checked = query_check.invoke({"messages": [latest]})
    return {"messages": [checked]}



In [None]:
# Entry node that forces the initial sql_db_list_tables call.
workflow.add_node("first_tool_call", first_tool_call)

# Tool nodes that execute the list-tables and schema tools, reporting tool
# failures back as error ToolMessages.
workflow.add_node(
    "list_tables_tool", create_tool_node_with_fallback([list_tables_tool])
)
workflow.add_node("get_schema_tool", create_tool_node_with_fallback([get_schema_tool]))


In [None]:
# Add a node for a model to choose the relevant tables based on the question and available tables.
# Only get_schema_tool is bound, so any tool call it makes is a schema request.
# NOTE(review): no tool_choice is set, yet the edge to get_schema_tool is
# unconditional; confirm a reply without tool calls is handled there.
model_get_schema = AzureChatOpenAI(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ['GPT_URL'],
    azure_deployment="gpt-4o").bind_tools(
    [get_schema_tool]
)
workflow.add_node(
    "model_get_schema",
    # Feed the whole message history to the model and append its reply.
    lambda state: {
        "messages": [model_get_schema.invoke(state["messages"])],
    },
)


In [None]:
# Describe a tool to represent the end state.
# NOTE: when this class is passed to bind_tools, its docstring and the field
# description are sent to the model as the tool schema, so edit them with care.
class SubmitFinalAnswer(BaseModel):
    """Submit the final answer to the user based on the query results."""

    # Answer text to show the user; `...` makes the field required.
    final_answer: str = Field(..., description="The final answer to the user")

# chart agent

In [None]:
@tool
def _plot_data(
    data_str: str,
    code: str
) -> str:
    """Executes plotting code on a DataFrame."""
    # NOTE: the docstring above is the tool description sent to the model.
    # SECURITY: `code` is produced by an LLM and run with exec() -- this is
    # arbitrary code execution; only use this tool in a trusted sandbox.
    import ast

    fig = plt.figure()

    try:
        # Parse the data as a plain Python literal (list/tuple/dict of values)
        # rather than eval(), so the data string itself cannot run code.
        data_list = ast.literal_eval(data_str)
        # Build the DataFrame exposed to the plotting code as `df`.
        df = pd.DataFrame(data_list)

        # Namespace visible to the generated code.
        namespace = {
            'df': df,
            'plt': plt,
            'pd': pd
        }

        # The generated code is expected to save the figure to temp_plot.png
        # and close it, as requested by chart_gen_system.
        exec(code, namespace)

    except Exception as e:
        # Release the figure opened above so failed attempts do not pile up
        # open matplotlib figures.
        plt.close(fig)
        return f"Failed to execute. Error: {repr(e)}"

    # chart_analyze_node reads the image path from the last line of this text.
    return "\nPlot Genereted. img_path:\ntemp_plot.png"
    
# Chart-code generator: the model writes matplotlib code for the requested plot
# and calls the _plot_data tool with that code and the data.
# NOTE(review): the prompt says "plot_data" while the bound tool is "_plot_data".
chart_gen_system = """You are a Python expert specialized in data visualization.

Given an input string of data and a specific visualization, output a syntactically correct Python code to run to:
- create that plot
- save the image in temp_plot.png
- close the plot

When generating the code:

Output the Python code that answers the input question that call the proper plot_data tool passing the code and data.
"""


chart_gen_prompt = ChatPromptTemplate.from_messages(
    [("system", chart_gen_system), ("placeholder", "{messages}")]
)
# Only _plot_data is bound; chart_gen_node flags calls to any other tool.
chart_gen = chart_gen_prompt | AzureChatOpenAI(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ['GPT_URL'],
    azure_deployment="gpt-4o").bind_tools(
    [_plot_data]
)


def chart_gen_node(state: State):
    """Ask the chart-generation model for plotting code.

    Returns the model reply followed by an error ToolMessage for every tool
    call that is not ``_plot_data``, because the model occasionally
    hallucinates other tool names.
    """
    reply = chart_gen.invoke(state)

    wrong_tool_errors = [
        ToolMessage(
            content=f"Error: The wrong tool was called: {call['name']}. Please fix your mistakes. Remember to only call _plot_data to create the chart.",
            tool_call_id=call["id"],
        )
        for call in (reply.tool_calls or [])
        if call["name"] != '_plot_data'
    ]
    return {"messages": [reply] + wrong_tool_errors}

# Chart analyzer: the model looks at the rendered plot image and writes a short
# data-driven comment, returned through the SubmitFinalAnswer tool.
chart_analyzer_system = """You are an expert data analyst.
Given the image of a plot as input, output a comment to that plot that helps finding insights on the data represented.
You can talk about distributions, anomalies, minimum or maximus values and so on, but stick to the data you see.
When you have generated the comment call the SubmitFinalAnswer tool to show your response to the user.

Do not add this kind of text in the output (example):
![Bar Graph](temp_plot.png) 
since we don't need it, just respond with a comment on the plot.
"""

prompt_messages = [
    SystemMessage(content=chart_analyzer_system),
    # The plot is sent inline as a base64 data URL. The image file is
    # temp_plot.png, so the declared MIME type must be image/png.
    HumanMessagePromptTemplate.from_template(
        template=[
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,{img_base64}"}},
        ]
    ),
    MessagesPlaceholder("messages"),
]

chart_analyzer_prompt = ChatPromptTemplate(messages=prompt_messages)


chart_analyzer = chart_analyzer_prompt | AzureChatOpenAI(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ['GPT_URL'],
    azure_deployment="gpt-4o").bind_tools(
    [SubmitFinalAnswer]
)


def chart_analyze_node(state: State):
    """Ask the chart-analyzer model to comment on the latest generated plot.

    The image path is read from the last line of the newest ``_plot_data``
    ToolMessage. The file is base64-encoded and passed to ``chart_analyzer``
    together with the conversation. Tool calls in the reply other than
    ``SubmitFinalAnswer`` are answered with an error ToolMessage.

    When no usable plot exists, the node returns an error AIMessage instead of
    raising, so it always produces a state update. This covers three cases:
    there is no ``_plot_data`` message, the tool reported a failure, or the
    image file cannot be read.
    """
    img_path = None
    # Newest first, so a retried plot wins over an earlier attempt.
    for msg in reversed(state["messages"]):
        if isinstance(msg, ToolMessage) and msg.name == "_plot_data":
            content = str(msg.content)
            # _plot_data reports failures as "Failed to execute. ..."; other
            # tool errors start with "Error". Neither carries an image path.
            if not content.startswith(("Failed to execute", "Error")):
                img_path = content.split("\n")[-1].strip()
            break

    if not img_path:
        return {"messages": [AIMessage(content="Error: no plot image is available to analyze.")]}

    try:
        with open(img_path, 'rb') as f:
            img_base64 = base64.b64encode(f.read()).decode('utf-8')
    except OSError as e:
        return {"messages": [AIMessage(content=f"Error: could not read plot image {img_path!r}: {e!r}")]}

    message = chart_analyzer.invoke({'messages': state['messages'], "img_base64": img_base64})

    # Sometimes the LLM hallucinates and calls the wrong tool; answer such
    # calls with an error so the model can correct itself.
    tool_messages = [
        ToolMessage(
            content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit final answer.",
            tool_call_id=tc["id"],
        )
        for tc in (message.tool_calls or [])
        if tc["name"] != 'SubmitFinalAnswer'
    ]
    return {"messages": [message] + tool_messages}


# Chart pipeline nodes: code generation -> plotting tool -> image analysis.
workflow.add_node("chart_gen_node", chart_gen_node)
# Tool node that executes _plot_data calls, with failures surfaced as messages.
workflow.add_node(
    "_plot_data", create_tool_node_with_fallback([_plot_data])
)
workflow.add_node("chart_analyze_node", chart_analyze_node)

# query agent

In [None]:

# Add a node for a model to generate a query based on the question and schema.
# The prompt asks the model to return its results through SubmitFinalAnswer, so
# that tool must be bound. Generated SQL is still returned as plain text
# (no tool call).
query_gen_system = """You are a SQL expert with a strong attention to detail.

Given an input question, output a syntactically correct SQLite query to run, then return the results of the query using SubmitFinalAnswer tool.

When generating the query:

Output the SQL query that answers the input question without a tool call.

Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.

If you get an error while executing a query, rewrite the query and try again.

If you get an empty result set, you should try to rewrite the query to get a non-empty result set. 
NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.

If you have enough information to answer the input question, simply return the results of the query using SubmitFinalAnswer tool.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database."""


query_gen_prompt = ChatPromptTemplate.from_messages(
    [("system", query_gen_system), ("placeholder", "{messages}")]
)
query_gen = query_gen_prompt | AzureChatOpenAI(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ['GPT_URL'],
    azure_deployment="gpt-4o").bind_tools(
    [SubmitFinalAnswer]
)



# def query_gen_node(state: State):
    
#     message = query_gen.invoke(state)
#     print('\n\n\n response query gen', message)
#     # Sometimes, the LLM will hallucinate and call the wrong tool. We need to catch this and return an error message.
#     tool_messages = []
#     # if message.tool_calls:
#     #     for tc in message.tool_calls:
#     #         if tc["name"] != "SubmitFinalAnswer":
#     #             tool_messages.append(
#     #                 ToolMessage(
#     #                     content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.",
#     #                     tool_call_id=tc["id"],
#     #                 )
#     #             )
#     # else:
#     #     tool_messages = []
#     return {"messages": [message] + tool_messages}


def query_gen_node(state: State):
    """Run the query-generation model and validate its tool calls.

    The model reply is always returned first, so any ToolMessage produced here
    directly follows the AI message whose tool call it answers.

    * A plain-text reply (the generated SQL) is returned as-is.
    * A ``SubmitFinalAnswer`` call is left on the reply itself. It is only a
      schema with no executable tool, and routing on it is done by the
      conditional edge that follows this node.
    * Any other tool call gets an error ToolMessage asking the model to output
      queries without a tool call.
    """
    message = query_gen.invoke(state)

    tool_messages = [
        ToolMessage(
            content=f"Error: The wrong tool was called: {tool_call['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.",
            tool_call_id=tool_call["id"],
        )
        for tool_call in (message.tool_calls or [])
        if tool_call["name"] != "SubmitFinalAnswer"
    ]

    return {"messages": [message] + tool_messages}


# Main agent

In [None]:
from langchain.tools import StructuredTool
from pydantic import BaseModel, Field
from typing import Dict, Any

# Define the input schema for the tool.
# NOTE: field descriptions become part of the tool schema shown to the model.
class ChartGenInput(BaseModel):
    state: str = Field(
        description="The current state containing query results as string"
    )

# Now create the tool with the defined schema.
# NOTE(review): chart_gen_node passes its argument to chart_gen.invoke(...) as a
# graph state, but this schema supplies `state` as a plain string; confirm the
# prompt still receives a "messages" key. This tool is not referenced elsewhere
# in the visible code (main_agent binds chart_gen_node directly).
chart_gen_tool = StructuredTool(
    name="chart_gen_tool",
    description="Generates a visualization of the data when needed. Use this when the user asks for a visual representation.",
    func=chart_gen_node,
    args_schema=ChartGenInput,
    return_direct=False
)


# Define the input schema for the tool.
# NOTE: field descriptions become part of the tool schema shown to the model.
class QueryGenInput(BaseModel):
    state: str = Field(
        description="The current state containing input question as string"
    )

# Now create the tool with the defined schema.
# NOTE(review): query_gen_node passes its argument to query_gen.invoke(...) as a
# graph state, but this schema supplies `state` as a plain string; confirm the
# prompt still receives a "messages" key. This tool is not referenced elsewhere
# in the visible code (main_agent binds query_gen_node directly).
query_gen_tool = StructuredTool(
    name="query_gen_tool",
    description="Generates a SQL query in SQLite. Use this when you need to interrogate a database.",
    func=query_gen_node,
    args_schema=QueryGenInput,
    return_direct=False
)

In [None]:

# Routing model: decides whether to answer directly, query the database
# (query_gen_node) or plot (chart_gen_node). The tool names in the prompt must
# match the names bound below, or the model is told to call tools it does not have.
main_agent_system = """You are an expert data analyst who is asked to perform some analysis on a database.

Given an input question, you will decide where to route your actions:
- if you have already information in memory to answer the question, call SubmitFinalAnswer tool
- if you need to retrieve data before answering call the query_gen_node tool to create an SQL query to interrogate the database.

Then once you have the results, call either:
- SubmitFinalAnswer if with the data retrieved after query_gen_node you can answer input question
- chart_gen_node if with the data retrieved you must create a plot to answer input question

DO NOT call any tool besides the ones listed.
Remember that chart_gen_node needs the resulting data as a string to create the plot.

"""


main_agent_prompt = ChatPromptTemplate.from_messages(
    [("system", main_agent_system), ("placeholder", "{messages}")]
)
main_agent = main_agent_prompt | AzureChatOpenAI(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ['GPT_URL'],
    azure_deployment="gpt-4o").bind_tools(
    [SubmitFinalAnswer, query_gen_node, chart_gen_node]
)


def main_agent_node(state: State):
    """Let the routing model choose the next step.

    Returns the model reply followed by an error ToolMessage for every tool
    call whose name is not SubmitFinalAnswer, query_gen_node or chart_gen_node
    (the model sometimes hallucinates tool names).
    """
    allowed_tools = ("SubmitFinalAnswer", "query_gen_node", "chart_gen_node")
    reply = main_agent.invoke(state)

    errors = []
    for call in reply.tool_calls or []:
        if call["name"] in allowed_tools:
            continue
        errors.append(
            ToolMessage(
                content=f"Error: The wrong tool was called: {call['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer query_gen_node or chart_gen_node.",
                tool_call_id=call["id"],
            )
        )

    return {"messages": [reply] + errors}

# add nodes and edges

In [None]:
# SQL pipeline nodes: generate -> check -> execute.
workflow.add_node("query_gen_node", query_gen_node)

# Add a node for the model to check the query before executing it
workflow.add_node("correct_query", model_check_query)

# Add node for executing the query
workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))

# Router node that decides between answering, querying and plotting.
workflow.add_node("main_agent", main_agent_node)

In [None]:
# # Define a conditional edge to decide whether to continue or end the workflow
# def should_continue(state: State) -> Literal["main_agent", "query_gen_node"]:
#     messages = state["messages"]
#     last_message = messages[-1]
#     # If there is a tool call, then we finish
   
#     if getattr(last_message, "tool_calls", None):
#         return "main_agent"
#     else:
#         return "query_gen"

In [None]:
# Conditional edge after query_gen_node: check/execute the generated output or
# hand control back to the main agent.
def should_correct_query(state: State) -> Literal["correct_query", "main_agent"]:
    """Route the output of ``query_gen_node``.

    Returns ``"main_agent"`` when the last message carries tool calls, and
    ``"correct_query"`` otherwise (a generated query or an ``Error:`` reply).
    Every path therefore yields a valid destination rather than ``None``.
    """
    last_message = state["messages"][-1]

    if getattr(last_message, "tool_calls", None):
        return "main_agent"

    # NOTE(review): "Error:" replies keep going to correct_query as before;
    # routing them back to query_gen_node may be the better retry target.
    return "correct_query"

In [None]:
# Define a conditional edge to decide whether to continue or end the workflow
def should_continue_to_plot_or_query(state: State) -> Literal[END, "query_gen_node", "chart_gen_node"]:
    messages = state["messages"]
    last_message = messages[-1]
    # If there is a tool call, then we finish
   
    if getattr(last_message, "tool_calls", None):
        for tc in last_message.tool_calls:
            if tc["name"] in ['SubmitFinalAnswer']:
                return END
            elif tc["name"] in ['chart_gen_node']:
                return "chart_gen_node"
            elif tc["name"] in ['query_gen_node']:
                return "query_gen_node"
    
    # if last_message.content.startswith("Error:"):
    #     return "chart_gen_node"
    # else:
    #     return END

In [None]:


# Add the base query path:
# START -> forced list-tables call -> schema selection -> main agent.
workflow.add_edge(START, "first_tool_call")
workflow.add_edge("first_tool_call", "list_tables_tool")
workflow.add_edge("list_tables_tool", "model_get_schema")
workflow.add_edge("model_get_schema", "get_schema_tool")
workflow.add_edge("get_schema_tool", "main_agent")
# workflow.add_edge("main_agent", "query_gen")

# workflow.add_conditional_edges(
#     "query_gen",
#     should_continue,
# )

# main_agent routes to END, the SQL pipeline or the chart pipeline.
workflow.add_conditional_edges(
    "main_agent",
    should_continue_to_plot_or_query,
)

# query_gen_node output is either checked/executed or handed back to main_agent.
workflow.add_conditional_edges(
    "query_gen_node",
    should_correct_query,
)

# Checked query is executed, and the results go back to query_gen_node.
workflow.add_edge("correct_query", "execute_query")
workflow.add_edge("execute_query", "query_gen_node")

# Add plotting path
# workflow.add_edge("query_gen", "chart_gen_node")
workflow.add_edge("chart_gen_node", "_plot_data")
workflow.add_edge("_plot_data", "chart_analyze_node")
workflow.add_edge("chart_analyze_node", END)

# Compile
app = workflow.compile()

In [None]:
from IPython.display import Image, display
from langchain_core.runnables.graph import MermaidDrawMethod

# Render the compiled graph as a Mermaid PNG (drawn through the remote
# Mermaid API, so this needs network access).
display(
    Image(
        app.get_graph().draw_mermaid_png(
            draw_method=MermaidDrawMethod.API,
        )
    )
)

# Run

In [None]:
# Stream a pure-query question through the graph, printing each event.
for i, event in enumerate(app.stream(
    {"messages": [("user", "What is the average total price?")]}
)):
    print(f'\n======== Event {i} ==========:\n', event)

In [None]:
## debug
# Stream a question that should trigger the chart path, printing each event.
for i, event in enumerate(app.stream(
    {"messages": [("user", "Plot the average total price per product line")]}
)):
    print(f'\n======== Event {i} ==========:\n', event)

# Orchestrator

In [None]:
def _set_env(key: str):
    if key not in os.environ:
        os.environ[key] = getpass.getpass(f"{key}:")


# Credentials for Azure OpenAI and LangChain tracing (prompted if missing).
_set_env("AZURE_OPENAI_API_KEY")
_set_env("LANGCHAIN_API_KEY")

# Turn on LangChain tracing and log runs under the "orchestrator" project.
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "orchestrator"

In [None]:
from typing import Sequence
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from langchain_core.messages import BaseMessage
from pydantic import BaseModel
import operator
from typing import Literal
# The agent state is the input to each node in the graph
class AgentState(TypedDict):
    """State shared by the supervisor and its worker nodes."""
    # The annotation tells the graph that new messages will always
    # be added to the current states (operator.add concatenates the sequences)
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # The 'next' field indicates where to route to next
    next: str


In [None]:
# Rebuild the SQL toolkit for the orchestrator section (same configuration as
# the earlier cell: Azure gpt-4o deployment, temperature 0).
toolkit = SQLDatabaseToolkit(db=db, llm=AzureChatOpenAI(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ['GPT_URL'],
    azure_deployment="gpt-4o"))
tools = toolkit.get_tools()

# Select tools by registered name (next() raises StopIteration if missing).
list_tables_tool = next(tool for tool in tools if tool.name == "sql_db_list_tables")
get_schema_tool = next(tool for tool in tools if tool.name == "sql_db_schema")

@tool
def db_query_tool(query: str) -> str:
    """
    Execute a SQL query against the database and get back the result.
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.
    """
    # NOTE: the docstring above is the tool description sent to the model.
    # Any falsy result (e.g. an empty string) is turned into a retry hint.
    outcome = db.run_no_throw(query)
    return outcome if outcome else "Error: Query failed. Please rewrite your query and try again."

# System prompt for the "query checker" step: the model reviews a SQLite query
# for common mistakes and then executes it through db_query_tool.
query_check_system = """You are a SQL expert with a strong attention to detail.
Double check the SQLite query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

You will call the appropriate tool to execute the query after running this check."""

query_check_prompt = ChatPromptTemplate.from_messages(
    [("system", query_check_system), ("placeholder", "{messages}")]
)
# tool_choice="required" forces the reply to be a db_query_tool call.
query_check = query_check_prompt | AzureChatOpenAI(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ['GPT_URL'],
    azure_deployment="gpt-4o").bind_tools(
    [db_query_tool], tool_choice="required"
)


def model_check_query(state: AgentState) -> dict[str, list[AIMessage]]:
    """
    Run the query-checker chain on the most recent message only.

    Because the checker is bound with ``tool_choice="required"``, its reply is a
    ``db_query_tool`` call carrying the (possibly corrected) query.
    """
    latest = state["messages"][-1]
    checked = query_check.invoke({"messages": [latest]})
    return {"messages": [checked]}



@tool
def _plot_data(
    data_str: str,
    code: str
) -> str:
    """Executes plotting code on a DataFrame. Input are: 
    - data_str: A string with data values as list, that will be converted to a DataFrame 
    - code: A string containing Python plotting code that uses 'df' as the DataFrame name"""
    # NOTE: the docstring above is the tool description sent to the model.
    # SECURITY: `code` is produced by an LLM and run with exec() -- this is
    # arbitrary code execution; only use this tool in a trusted sandbox.
    import ast

    fig = plt.figure()

    try:
        # Parse the data as a plain Python literal (list/tuple/dict of values)
        # rather than eval(), so the data string itself cannot run code.
        data_list = ast.literal_eval(data_str)
        # Build the DataFrame exposed to the plotting code as `df`.
        df = pd.DataFrame(data_list)

        # Namespace visible to the generated code.
        namespace = {
            'df': df,
            'plt': plt,
            'pd': pd
        }

        # The generated code is expected to save and close the figure itself.
        exec(code, namespace)

    except Exception as e:
        # Release the figure opened above so failed attempts do not pile up
        # open matplotlib figures.
        plt.close(fig)
        return f"Failed to execute. Error: {repr(e)}"

    return "\nPlot Genereted. img_path:\ntemp_plot.png"
    

def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
    """
    Build a ToolNode for ``tools`` whose exceptions are reported back to the agent.

    If the ToolNode raises, the exception is handed to ``handle_tool_error``
    under the ``"error"`` key, which answers every pending tool call with an
    error ToolMessage instead of crashing the graph.
    """
    tool_node = ToolNode(tools)
    error_fallback = RunnableLambda(handle_tool_error)
    return tool_node.with_fallbacks([error_fallback], exception_key="error")


def handle_tool_error(state) -> dict:
    """Convert a tool failure into one error ToolMessage per pending tool call.

    ``state["error"]`` holds the exception captured by the fallback, and the
    last entry of ``state["messages"]`` is the AI message whose tool calls
    failed. Each of those calls gets a reply asking the model to fix its mistake.
    """
    failure = state.get("error")
    pending_calls = state["messages"][-1].tool_calls
    replies = []
    for call in pending_calls:
        replies.append(
            ToolMessage(
                content=f"Error: {repr(failure)}\n please fix your mistakes.",
                tool_call_id=call["id"],
            )
        )
    return {"messages": replies}

In [None]:
from langchain_core.messages import HumanMessage


def agent_node(state, agent, name):
    """Run a worker ``agent`` and report its final reply back as a HumanMessage.

    Only the content of the last message the agent produced is kept. It is
    labelled with ``name`` so the supervisor can tell which worker answered.
    """
    outcome = agent.invoke(state)
    final_text = outcome["messages"][-1].content
    return {"messages": [HumanMessage(content=final_text, name=name)]}



In [None]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from pydantic import BaseModel
from typing import Literal

# Names of the worker nodes the supervisor can route to.
members = ["sql_coder", "chart_plotter"]

system_prompt = (
    "You are a supervisor tasked with managing a conversation between the"
    " following workers:  {members}. Given the following user request,"
    " respond with the worker to act next. Each worker will perform a"
    " task and respond with their results and status. When finished,"
    " respond with FINISH."
)
# Our team supervisor is an LLM node. It just picks the next agent to process
# and decides when the work is completed
options = ["FINISH"] + members


# Structured-output schema for the supervisor's routing decision.
# NOTE: the Literal must stay in sync with `members` + "FINISH", because the
# conditional edge looks up `next` in a map built from those names.
class routeResponse(BaseModel):
    
    next: Literal["sql_coder", "chart_plotter", "FINISH"]
   

# {members} and {options} are filled in up front via .partial(); only the
# message history is supplied at invoke time.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        (
            "system",
            "Given the conversation above, who should act next?"
            " Or should we FINISH? Select one of: {options}",
        ),
    ]
).partial(options=str(options), members=", ".join(members))


# Shared chat model for the supervisor and the worker agents (Azure gpt-4o
# deployment, temperature 0).
llm = AzureChatOpenAI(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ['GPT_URL'],
    azure_deployment="gpt-4o")


def supervisor_agent(state):
    """Ask the supervisor LLM which worker should act next (or FINISH).

    Returns a ``routeResponse`` whose ``next`` field drives the conditional edge.
    NOTE: ``prompt``, ``llm`` and ``routeResponse`` are module globals looked up
    at call time, so later cells that rebind them change this function too.
    """
    router = llm.with_structured_output(routeResponse)
    supervisor_chain = prompt | router
    return supervisor_chain.invoke(state)

In [None]:
import functools
import operator

from typing_extensions import TypedDict

from langchain_core.messages import BaseMessage

from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import create_react_agent



# Worker agents are ReAct agents sharing `llm`: chart_plotter can only call
# _plot_data, sql_coder gets the list-tables, schema and query tools.
chart_agent = create_react_agent(llm, tools=[_plot_data])
chart_node = functools.partial(agent_node, agent=chart_agent, name="chart_plotter")


code_agent = create_react_agent(llm, tools=[list_tables_tool, get_schema_tool, db_query_tool])
code_node = functools.partial(agent_node, agent=code_agent, name="sql_coder")

# Supervisor graph: two worker nodes plus the routing supervisor.
workflow = StateGraph(AgentState)
workflow.add_node("chart_plotter", chart_node)
workflow.add_node("sql_coder", code_node)
workflow.add_node("supervisor", supervisor_agent)



In [None]:
for member in members:
    # We want our workers to ALWAYS "report back" to the supervisor when done
    workflow.add_edge(member, "supervisor")
# The supervisor populates the "next" field in the graph state
# which routes to a node or finishes
conditional_map = {k: k for k in members}
conditional_map["FINISH"] = END
workflow.add_conditional_edges("supervisor", lambda x: x["next"], conditional_map)

# Finally, add entrypoint
workflow.add_edge(START, "supervisor")

graph = workflow.compile()

In [None]:
from IPython.display import Image, display
from langchain_core.runnables.graph import MermaidDrawMethod

# Render the supervisor graph as a Mermaid PNG (drawn through the remote
# Mermaid API, so this needs network access).
display(
    Image(
        graph.get_graph().draw_mermaid_png(
            draw_method=MermaidDrawMethod.API,
        )
    )
)

In [None]:
# Stream a pure-retrieval request through the supervisor graph, printing every
# step except the terminal "__end__" event.
for s in graph.stream(
    {
        "messages": [
            HumanMessage(content="retrieve first 5 elem from table sales")
        ]
    }
):
    
    if "__end__" not in s:
        print(s)
        print("----")


In [None]:
# Stream a request that needs both retrieval and plotting, printing every step
# except the terminal "__end__" event.
for s in graph.stream(
    {
        "messages": [
            HumanMessage(content="plot first 5 top product sold from table sales")
        ]
    }
):
    
    if "__end__" not in s:
        print(s)
        print("----")


## orchestrator teams

In [None]:
import functools
import operator

from typing_extensions import TypedDict
from typing import Optional

from langchain_core.messages import BaseMessage

from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import create_react_agent

members = ["sql_coder", "chart_plotter"]
system_prompt = (
    """You are a supervisor tasked with managing a conversation between the following workers:  {members}. 
    Each worker will perform a task and respond with their results and status. 
    - sql_coder will retrieve data from the database and give you the results performing an SQL query
    - chart_plotter will use data retrieved by the sql_coder to create a plot and comment it, ONLY if user input requires a plot, do not call otherwise
    
    Given the following user request, respond with the worker to act next.
    
    Moreover, use the key question of your output to:
    - rephrase user input to just include the part of the question that is about retrieving data, when you call sql_coder
    - pass both user question and data retrieved by the sql_coder, when you call chart_plotter
    
    When finished, respond with FINISH."""
)
# Our team supervisor is an LLM node. It just picks the next agent to process
# and decides when the work is completed
options = ["FINISH"] + members


# Structured-output schema for the supervisor's routing decision; supervisor_agent
# parses the LLM reply into this model via llm.with_structured_output.
# NOTE(review): intentionally documented with comments rather than a class docstring,
# since LangChain typically forwards the model docstring to the LLM as the schema
# description — adding one would change the model's input.
class routeResponse(BaseModel):
    
    # Worker to run next, or "FINISH" to end the run; must stay in sync with `options`.
    next: Literal["sql_coder", "chart_plotter", "FINISH"]
    # Request text for the chosen worker (the supervisor prompt asks for a rephrased
    # question). No default: presumably required-but-nullable under pydantic v2 — confirm version.
    question: Optional[str]

# Supervisor prompt: system instructions, the running conversation, then a routing question.
# {options} and {members} are pre-filled with .partial(), leaving "messages" to be
# supplied from the graph state at invoke time.
_route_instruction = (
    "Given the conversation above, who should act next?"
    " Or should we FINISH? Select one of: {options}"
)
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        ("system", _route_instruction),
    ]
).partial(members=", ".join(members), options=str(options))


# Shared Azure OpenAI chat model used by the supervisor and by both worker agents.
# temperature=0 keeps routing and SQL generation as repeatable as the service allows.
llm = AzureChatOpenAI(
    azure_deployment="gpt-4o",
    azure_endpoint=os.environ["GPT_URL"],
    api_version="2024-08-01-preview",
    temperature=0,
)


def supervisor_agent(state):
    """Ask the supervisor LLM which worker should act next.

    Runs the supervisor prompt over the current graph state and parses the
    reply into a ``routeResponse`` (fields ``next`` and ``question``), which is
    returned as this node's state update.
    """
    router_llm = llm.with_structured_output(routeResponse)
    return (prompt | router_llm).invoke(state)

# Instructions for the chart_plotter agent. The required "Here is the [plot_type]"
# opening is what should_recreate_plot checks before returning to the supervisor.
chart_gen_system = """You are a Python expert specialized in data visualization.
Given an input question create a plot with data provided.
When finished comment the plot starting with 

Here is the [plot_type] 

then add some description about the data."""

# ReAct agent whose only tool is _plot_data, wrapped as the "chart_plotter" graph node.
chart_agent = create_react_agent(
    llm,
    tools=[_plot_data],
    state_modifier=chart_gen_system,
)
chart_node = functools.partial(agent_node, name="chart_plotter", agent=chart_agent)

# Instructions for the sql_coder agent. The "FINAL ANSWER" prefix it asks for is
# what the sql_coder routing function matches to hand control back to the supervisor.
query_gen_system = """You are an SQL expert.
Given an input question you will be called two times:

- The first time you are requested to output a syntactically correct SQLite query to respond to the question. Output just the query.
- The second time you will have the data resulting from query execution, report the results preproned by FINAL ANSWER, (example) FINAL ANSWER: The average value is ...

When generating the query:
If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
DO NOT MAKE UP DATA, you must retrieve them, do not pass back anything that is not the result of a query executed on the db.
When reporting FINAL ANSWER:
you have finished your job only when you have the data to respond to the input question, the final answer do not include the query or any other information
about how to retrieve the data, ONLY the results."""

# NOTE: THIS PERFORMS ARBITRARY CODE EXECUTION. PROCEED WITH CAUTION
# NOTE(review): this warning looks inherited from another example. This agent's tools
# (list_tables_tool, get_schema_tool) appear only to inspect the database, and SQL
# appears to run in the separate execute_query node — confirm which component it targets.
code_agent = create_react_agent(
    llm,
    tools=[list_tables_tool, get_schema_tool],
    state_modifier=query_gen_system,
)
code_node = functools.partial(agent_node, name="sql_coder", agent=code_agent)

# Assemble the state graph: two worker agents, the supervisor, and the query check/execute pair.
workflow = StateGraph(AgentState)
for _node_name, _node_action in (
    ("chart_plotter", chart_node),
    ("sql_coder", code_node),
    ("supervisor", supervisor_agent),
    ("correct_query", model_check_query),
    # db_query_tool runs inside a ToolNode whose fallback turns raised errors into
    # "Error: ..." ToolMessages (see create_tool_node_with_fallback).
    ("execute_query", create_tool_node_with_fallback([db_query_tool])),
):
    workflow.add_node(_node_name, _node_action)


In [None]:

# Conditional edge after sql_coder: decide where its latest output goes next.
def should_correct_query(state: "AgentState") -> Literal["supervisor", "correct_query", "sql_coder"]:
    """Route the sql_coder's last message to the next node.

    The sql_coder is prompted to first emit a bare SQL query and, once that
    query has been executed, to report results prefixed with ``FINAL ANSWER``.

    Returns:
        "supervisor": the message is the final answer, so the worker is done.
        "sql_coder": the message reports an error ("Error: ..."), so the query
            is regenerated.
        "correct_query": anything else, i.e. a freshly generated query that
            must be checked and executed (correct_query -> execute_query ->
            back to sql_coder).
    """
    content = state["messages"][-1].content
    # Content may not be a plain string (e.g. a list of content parts); coerce so
    # the prefix checks below cannot raise AttributeError.
    text = content if isinstance(content, str) else str(content)
    # LLM replies sometimes start with a newline or spaces before the keyword.
    text = text.lstrip()

    if text.startswith("FINAL ANSWER"):
        return "supervisor"
    if text.startswith("Error:"):
        return "sql_coder"
    return "correct_query"


# Conditional edge after chart_plotter: return to the supervisor once the plot is described.
def should_recreate_plot(state: AgentState) -> Literal["supervisor", "chart_plotter"]:
    """Decide whether chart_plotter is finished.

    Prints the worker's last message (debug aid). When that message starts
    with "Here is" (the opening the chart prompt asks for), control goes back
    to the supervisor; any other text sends the work back to chart_plotter.
    """
    latest_content = state["messages"][-1].content
    print(latest_content)
    plot_done = latest_content.startswith("Here is")
    return "supervisor" if plot_done else "chart_plotter"
    

# Supervisor routing: its structured `next` field names the worker to run,
# and "FINISH" maps to END.
conditional_map = dict(zip(members, members), FINISH=END)
workflow.add_conditional_edges("supervisor", lambda st: st["next"], conditional_map)

# Worker routing: no path map is passed, so LangGraph derives the possible
# targets from each router's Literal return annotation.
workflow.add_conditional_edges("sql_coder", should_correct_query)
workflow.add_conditional_edges("chart_plotter", should_recreate_plot)

# Query check/execution loop: correct_query -> execute_query -> back to sql_coder.
workflow.add_edge("correct_query", "execute_query")
workflow.add_edge("execute_query", "sql_coder")

# Entry point: every run starts at the supervisor.
workflow.add_edge(START, "supervisor")

graph = workflow.compile()

In [None]:
from IPython.display import Image, display
from langchain_core.runnables.graph import MermaidDrawMethod

# Render the supervisor graph's topology as a Mermaid PNG and show it inline.
# MermaidDrawMethod.API renders remotely, so this cell needs network access.
_graph_png = graph.get_graph().draw_mermaid_png(draw_method=MermaidDrawMethod.API)
display(Image(_graph_png))

In [None]:
# Stream a data-only request through the supervisor graph, printing each node's update.
_inputs = {
    "messages": [HumanMessage(content="retrieve first 5 elem from table sales")],
}
for s in graph.stream(_inputs):
    if "__end__" in s:
        continue
    print(s)
    print("----")


In [None]:
# Run the whole graph once and keep the final state in `messages`.
_request = HumanMessage(content="retrieve first 5 elem from table sales")
messages = graph.invoke({"messages": [_request]})
# The answer text is the content of messages["messages"][-1].


In [None]:
# Display the content of the last message in the final graph state.
_last_message = messages["messages"][-1]
_last_message.content

In [None]:
# Stream a plotting request (bar plot) and print each node's update.
_inputs = {
    "messages": [
        HumanMessage(content="plot the average total price per product line in a bar plot"),
    ],
}
for s in graph.stream(_inputs):
    if "__end__" in s:
        continue
    print(s)
    print("----")


In [None]:
# Run a plotting request end to end and keep the final state in `messages`.
_request = HumanMessage(content="plot the average total price per product line")
messages = graph.invoke({"messages": [_request]})
# The answer text is the content of messages["messages"][-1].


In [None]:
# Display the content of the last message in the final graph state.
_last_message = messages["messages"][-1]
_last_message.content