# Import

In [None]:
# TODO: handle the main agent -> first tool step differently: should first_tool_call become an agent itself?

In [None]:
import ast
import getpass
import os
from sqlite_dataset import SQLiteDataset, Field, String, Float, Integer
import pandas as pd
import matplotlib.pyplot as plt
from io import BytesIO
import base64

from typing import Any, Annotated, Literal, Union, Dict, List, Sequence, Optional
# NOTE: keep this after the sqlite_dataset import: the pydantic `Field` used by
# SubmitFinalAnswer must be the object that ends up bound to the name `Field`.
from pydantic import BaseModel, Field
from typing_extensions import NotRequired, TypedDict
import operator

from langchain_core.tools import tool, StructuredTool
from langchain_core.messages import BaseMessage, ToolMessage, AIMessage, SystemMessage, HumanMessage
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langgraph.prebuilt import ToolNode, InjectedState, create_react_agent
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_openai import AzureChatOpenAI

from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import AnyMessage, add_messages

In [None]:
def _set_env(key: str):
    """Ensure environment variable ``key`` is set, prompting for it if missing.

    The prompt uses getpass, so the typed value is not echoed. An existing value
    is left untouched.
    """
    if key in os.environ:
        return
    os.environ[key] = getpass.getpass(f"{key}:")


# Ask interactively for any credential that is not already exported.
for _required_key in ("AZURE_OPENAI_API_KEY", "GPT_URL", "LANGCHAIN_API_KEY"):
    _set_env(_required_key)
del _required_key

# Turn on LangSmith tracing and group the runs under one project.
os.environ.update(
    {
        "LANGCHAIN_TRACING_V2": "true",
        "LANGCHAIN_PROJECT": "langchain-academy-project",
    }
)

In [None]:
# Connect to the local SQLite sales database file and show what it exposes.
db = SQLDatabase.from_uri("sqlite:///sales_dataset.db")
print(db.dialect)                    # SQL dialect reported by the connection
print(db.get_usable_table_names())   # tables the SQL tools are allowed to see
# db.run("SELECT * FROM sales LIMIT 10;")

# utils 

In [None]:
def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
    """Wrap ``tools`` in a ToolNode whose failures are reported back to the agent.

    If the ToolNode raises, the fallback receives the original input with the
    exception stored under the ``"error"`` key. ``handle_tool_error`` then turns it
    into ToolMessages, so the model sees the failure instead of the graph crashing.
    """
    node = ToolNode(tools)
    error_reporter = RunnableLambda(handle_tool_error)
    return node.with_fallbacks([error_reporter], exception_key="error")


def handle_tool_error(state) -> dict:
    """Answer every pending tool call on the last message with an error ToolMessage.

    ``state["error"]`` holds the exception placed there by the fallback. Each call in
    ``state["messages"][-1].tool_calls`` receives a reply carrying that error, so
    every tool call gets a response.
    """
    error = state.get("error")
    pending_calls = state["messages"][-1].tool_calls
    feedback = f"Error: {repr(error)}\n please fix your mistakes."
    replies = []
    for call in pending_calls:
        replies.append(ToolMessage(content=feedback, tool_call_id=call["id"]))
    return {"messages": replies}

In [None]:
# SQL toolkit over `db`; its helper LLM uses the same gpt-4o deployment as the rest of the notebook.
toolkit = SQLDatabaseToolkit(
    db=db,
    llm=AzureChatOpenAI(
        temperature=0,
        api_version="2024-08-01-preview",
        azure_endpoint=os.environ['GPT_URL'],
        azure_deployment="gpt-4o",
    ),
)
tools = toolkit.get_tools()

# The graph calls these two toolkit tools directly; StopIteration means one is missing.
list_tables_tool = next(t for t in tools if t.name == "sql_db_list_tables")
get_schema_tool = next(t for t in tools if t.name == "sql_db_schema")

# print(list_tables_tool.invoke(""))
# print(get_schema_tool.invoke("sales"))

In [None]:
@tool
def db_query_tool(query: str) -> str:
    """
    Execute a SQL query against the database and get back the result.
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.
    """
    # NOTE: with @tool the docstring above is the description the model sees; keep it verbatim.
    # `run_no_throw` presumably returns the error text instead of raising (per its name);
    # an empty result string is reported to the model as a failed query.
    outcome = db.run_no_throw(query)
    return outcome or "Error: Query failed. Please rewrite your query and try again."


# print(db_query_tool.invoke("SELECT * FROM sales LIMIT 10;"))

In [None]:
query_check_system = """You are a SQL expert with a strong attention to detail.
Double check the SQLite query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

You will call the appropriate tool to execute the query after running this check."""

# Reviewer chain: system prompt followed by the incoming messages.
query_check_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", query_check_system),
        ("placeholder", "{messages}"),
    ]
)
# tool_choice="required" forces the reply to be a db_query_tool call.
query_check = query_check_prompt | (
    AzureChatOpenAI(
        temperature=0,
        api_version="2024-08-01-preview",
        azure_endpoint=os.environ['GPT_URL'],
        azure_deployment="gpt-4o",
    ).bind_tools([db_query_tool], tool_choice="required")
)

# query_check.invoke({"messages": [("user", "SELECT * FROM sales LIMIT 10;")]})

## States

In [None]:
class PrivateState(TypedDict):
    """Supervisor routing state: next worker, message history and an optional final answer."""

    # Worker to run next, or "FINISH" to stop.
    next: Literal["FINISH", "query_gen_node", "chart_analyze_node"]
    # Conversation history; `add_messages` is the reducer that merges node updates into it.
    messages: Annotated[list[Any], add_messages]
    # TypedDict fields cannot carry defaults: the previous `= None` was silently ignored at
    # runtime (instances are plain dicts) and rejected by type checkers. Mark the key optional.
    final_answer: NotRequired[Optional[str]]


class State(TypedDict):
    """State shared by every node of `workflow`: the running list of messages."""

    # `add_messages` is the reducer that merges each node's returned messages into this list.
    messages: Annotated[list[AnyMessage], add_messages]



In [None]:
# Define a new graph whose nodes read and update `State`.
workflow = StateGraph(State)


def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
    """Open the run with a hard-coded ``sql_db_list_tables`` tool call.

    No model is consulted. The returned AIMessage requests the list-tables tool with
    no arguments, so the following tool node fetches the available table names.
    """
    forced_call = {"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}
    return {"messages": [AIMessage(content="", tool_calls=[forced_call])]}
    
def model_check_query(state: State) -> dict[str, list[AIMessage]]:
    """
    Graph node: have the `query_check` chain review the SQL in the latest message.

    Only the last message is sent. The chain is bound to db_query_tool with
    tool_choice="required", so its reply is a db_query_tool call carrying the
    (possibly rewritten) query.
    """
    latest = state["messages"][-1]
    reviewed = query_check.invoke({"messages": [latest]})
    return {"messages": [reviewed]}



In [None]:
# Entry node that emits the forced sql_db_list_tables call.
workflow.add_node("first_tool_call", first_tool_call)

# Tool nodes for table listing and schema lookup; failures come back as error ToolMessages.
workflow.add_node("list_tables_tool", create_tool_node_with_fallback([list_tables_tool]))
workflow.add_node("get_schema_tool", create_tool_node_with_fallback([get_schema_tool]))


In [None]:
# Model that reads the question and the table list, then requests the schemas it needs.
model_get_schema = AzureChatOpenAI(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ['GPT_URL'],
    azure_deployment="gpt-4o",
).bind_tools([get_schema_tool])


def _model_get_schema_node(state):
    """Run `model_get_schema` on the whole message history and append its reply."""
    return {"messages": [model_get_schema.invoke(state["messages"])]}


workflow.add_node("model_get_schema", _model_get_schema_node)




In [None]:
# # Describe a tool to represent the end state
# class SubmitFinalAnswer(BaseModel):
#     """Submit the final answer to the user based on the query results."""

#     final_answer: str = Field(..., description="The final answer to the user")

# CHART TEAM

In [None]:
@tool
def _plot_data(
    data_str: str,
    code: str
) -> str:
    """Executes plotting code on a DataFrame."""
    # NOTE: with @tool the docstring above is the description the model sees, so it is
    # kept verbatim. `code` is model-generated and run with exec(); only use this where
    # executing arbitrary Python is acceptable.
    img_path = "temp_plot.png"  # file the chart_gen prompt tells the code to save to

    # Remove any image left by an earlier call. After exec, the file's presence then
    # proves that this call's code actually saved a plot.
    try:
        os.remove(img_path)
    except FileNotFoundError:
        pass

    fig = plt.figure()
    try:
        # The data arrives as a Python-literal string (e.g. stringified query rows).
        # literal_eval parses it without evaluating arbitrary expressions, unlike eval().
        data_list = ast.literal_eval(data_str)
        df = pd.DataFrame(data_list)

        # Only these names are visible to the generated code.
        namespace = {
            'df': df,
            'plt': plt,
            'pd': pd
        }
        exec(code, namespace)
    except Exception as e:
        return f"Failed to execute. Error: {repr(e)}"
    finally:
        # Release the figure even when the generated code failed before closing it.
        plt.close(fig)

    if not os.path.isfile(img_path):
        return f"Failed to execute. Error: the code did not save the plot to {img_path}"
    # chart_analyze_node reads the image path from the last line of this message.
    return "\nPlot Genereted. img_path:\ntemp_plot.png"
    
# Code-generation model for the chart team: it must answer with a `_plot_data` tool call.
chart_gen_system = """You are a Python expert specialized in data visualization.

Given an input string of data and a specific visualization, output a syntactically correct Python code to run to:
- create that plot
- save the image in temp_plot.png
- close the plot

When generating the code:

Output the Python code that answers the input question that call the proper plot_data tool passing the code and data.
"""


chart_gen_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", chart_gen_system),
        ("placeholder", "{messages}"),
    ]
)
# tool_choice="required": the reply is always a `_plot_data` call with code and data.
chart_gen = chart_gen_prompt | (
    AzureChatOpenAI(
        temperature=0,
        api_version="2024-08-01-preview",
        azure_endpoint=os.environ['GPT_URL'],
        azure_deployment="gpt-4o",
    ).bind_tools([_plot_data], tool_choice="required")
)


def check_chart_gen_node(state: State):
    """
    Graph node: (re)generate plotting code from the latest message.

    Only ``state["messages"][-1]`` is sent to `chart_gen`. That chain is bound to
    `_plot_data` with tool_choice="required", so the reply is a `_plot_data` call whose
    code is expected to create the plot, save it to temp_plot.png and close it.
    """
    newest = state["messages"][-1]
    plot_call = chart_gen.invoke({"messages": [newest]})
    return {"messages": [plot_call]}  # alternative tried before: chart_gen.invoke(state)


# Analyzer model: receives the saved plot as an image together with the conversation.
chart_analyzer_system = """You are an expert data analyst.

Given an input string of data and a specific visualization, output a syntactically correct Python code with this steps to run:
- create that plot
- save the image in temp_plot.png
- close the plot

then look at the plot created and return a comment to that plot to supervisor.

When generating the code:

Output the Python code that answers the input question without tool call.
If you get an error while executing the code, rewrite it and try again.

The comment to the generated plot must help the user finding insights on the data represented.
You can talk about distributions, anomalies, minimum or maximus values and so on, but stick to the data you see.

Do not add this kind of text in the output (example):
![Bar Graph](temp_plot.png) 
since we don't need it, just respond with a comment on the plot.
"""

prompt_messages = [
    SystemMessage(content=chart_analyzer_system),
    HumanMessagePromptTemplate.from_template(
        template=[
            # The plot is saved as temp_plot.png. matplotlib picks the format from the
            # extension, so the data URL declares image/png rather than image/jpeg.
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,{img_base64}"}},
        ]
    ),
    MessagesPlaceholder("messages"),
]

# Requires two inputs: `img_base64` (the encoded plot) and `messages`.
chart_analyzer_prompt = ChatPromptTemplate(messages=prompt_messages)


chart_analyzer = chart_analyzer_prompt | AzureChatOpenAI(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ['GPT_URL'],
    azure_deployment="gpt-4o"
)


def chart_analyze_node(state: Annotated[dict, InjectedState]):
    """Use this tool to plot data retrieved from db"""
    # NOTE: this function is also passed to create_react_agent as a tool, so the docstring
    # above is its tool description and is kept verbatim.
    messages = state["messages"]

    # Take the most recent `_plot_data` result: after a retry, earlier results may be
    # errors or refer to an outdated plot. That tool puts the image path on its last line.
    img_path = None
    for msg in reversed(messages):
        if isinstance(msg, ToolMessage) and msg.name == "_plot_data":
            text = msg.content if isinstance(msg.content, str) else ""
            candidate = text.split("\n")[-1].strip()
            # A failed plot returns error text instead of a path, so only accept real files.
            if candidate and os.path.isfile(candidate):
                img_path = candidate
            break

    if img_path:
        with open(img_path, 'rb') as f:
            img_base64 = base64.b64encode(f.read()).decode('utf-8')
        reply = chart_analyzer.invoke({'messages': messages, "img_base64": img_base64})
    else:
        # No plot is available. `chart_analyzer`'s prompt requires `img_base64`, so call its
        # model step directly with the same system prompt and the conversation only.
        reply = chart_analyzer.last.invoke(
            [SystemMessage(content=chart_analyzer_system), *messages]
        )

    return {"messages": [reply.content]}


# Chart team: the analyzer, the code generator, and the tool node that runs the plotting code.
workflow.add_node("chart_analyze_node", chart_analyze_node)
workflow.add_node("check_chart_gen_node", check_chart_gen_node)
workflow.add_node("_plot_data", create_tool_node_with_fallback([_plot_data]))


# QUERY RETRIEVAL TEAM

In [None]:

# Query-generation model: writes SQLite for the question and reports results to the supervisor.
# NOTE(review): this model is not bound to any tool, so it cannot call db_query_tool itself;
# confirm how its SQL is meant to reach execute_query.
query_gen_system = """You are a SQL expert with a strong attention to detail.

Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer to supervisor.

When generating the query:

Output the SQL query that answers the input question without a tool call.

Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.

If you get an error while executing a query, rewrite the query and try again.

If you get an empty result set, you should try to rewrite the query to get a non-empty result set. 
NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.

If you have enough information to answer the input question, simply report the results of the query to the supervisor.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database."""


query_gen_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", query_gen_system),
        ("placeholder", "{messages}"),
    ]
)
_query_gen_llm_kwargs = dict(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ['GPT_URL'],
    azure_deployment="gpt-4o",
)
query_gen = query_gen_prompt | AzureChatOpenAI(**_query_gen_llm_kwargs)
del _query_gen_llm_kwargs


def query_gen_node(state: Annotated[dict, InjectedState]):
    """Use this tool to retrieve data from db"""
    # NOTE: also handed to create_react_agent as a tool, so the docstring above is the tool
    # description the supervisor sees; keep it verbatim.
    print('\n\n\nquery in ', state)  # debug: incoming state
    reply = query_gen.invoke(state)
    answer = reply.content
    print('\n\n\nquery out ', answer)  # debug: model answer
    # NOTE(review): a bare string is returned rather than an AIMessage; confirm that both
    # the graph reducer and the react agent's tool wrapper handle it as intended.
    return {"messages": [answer]}
   

In [None]:
# from langchain.tools import StructuredTool
# from pydantic import BaseModel, Field
# from typing import Dict, Any

# # Define the input schema for the tool
# class ChartGenInput(BaseModel):
#     state: str = Field(
#         description="The current state containing query results as string"
#     )

# # Now create the tool using the schema defined above
# chart_gen_tool = StructuredTool(
#     name="chart_gen_tool",
#     description="Generates a visualization of the data when needed. Use this when the user asks for a visual representation.",
#     func=check_chart_gen_node,
#     args_schema=ChartGenInput,
#     return_direct=False
# )



# Main agent

In [None]:
# Describe a tool to represent the end state
# Schema-only tool handed to the main agent (see `tools` for create_react_agent); calling it
# carries the final answer text. NOTE: as with other pydantic tool schemas, the class docstring
# and the Field description are presumably surfaced to the model, so edit them as prompt text.
class SubmitFinalAnswer(BaseModel):
    """Submit the final answer to the user based on the query results."""

    final_answer: str = Field(..., description="The final answer to the user")

In [None]:
from typing import Literal
# System prompt for the supervisor ("main agent"), which routes work to the worker tools.
main_agent_system = """
You are a supervisor data analyst tasked with managing a conversation between the user and 
you team, made by following workers: {members}. The user is asking you and your team for help in analyzing a database,
so it might ask you simple questions like "what is the most sold product of 2020?" or questions more complex 
that also require a plot, like "plot the top three product sold in 2020 in a bar graph". 
Given the following user request, your work is to respond with the worker to act next until you have satisfied user request. 

Each worker will perform a task and respond with their results and status. 
- query_gen_node will retrieve data from the database and give you the results performing an SQL query
- chart_analyze_node will use data retrieved by the query_gen_node to create a plot and comment it, if user needs to

When you have completed user's requests, call SubmitFinalAnswer tool to submit the final answer.
DO NOT call any tool besides SubmitFinalAnswer to submit the final answer, query_gen_node when you need to retrieve data and chart_analyze_node to plot them.
"""

# Names substituted for {members} in the system prompt.
members = ["query_gen_node", "chart_analyze_node", "SubmitFinalAnswer"]
# options = ["FINISH"] + members

main_agent_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", main_agent_system),
        MessagesPlaceholder(variable_name="messages"),
        (
            "system",
            # Adjacent string literals are concatenated as-is, so each sentence needs a
            # trailing space; without it they ran together ("next?Select ...tools.If ...").
            "Given the conversation above, what tool should you call next? "
            "Select one of your tools. "
            "If you think you have final response to report call SubmitFinalAnswer",
        ),
    ]
).partial(members=", ".join(members))


# Supervisor model, using the same gpt-4o deployment as the other clients.
llm = AzureChatOpenAI(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ['GPT_URL'],
    azure_deployment="gpt-4o",
)
# Tools the supervisor may call: the end-state schema plus the two worker functions.
tools = [SubmitFinalAnswer, query_gen_node, chart_analyze_node]

# Earlier structured-output supervisor, kept for reference:
# def main_agent_node(state) -> PrivateState:
#     main_agent = main_agent_prompt | llm.with_structured_output(routeResponse)
#     response = main_agent.invoke(state)
#     print(f"Main agent response: {response}")
#     # print(f"Main agent response: {state}")
#     # Debug log
#     #return response
#     return {"messages": state['messages'],
#             "next": response.next,
#             "final_answer": response.final_answer}

# Prebuilt ReAct agent driven by `main_agent_prompt`.
main_agent_node = create_react_agent(
    model=llm,
    tools=tools,
    state_modifier=main_agent_prompt,
)

# add nodes and edges

In [None]:
# Query team: SQL writer, SQL checker, and the tool node that executes the checked query.
workflow.add_node("query_gen_node", query_gen_node)
workflow.add_node("correct_query", model_check_query)
workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))

# Supervisor: the prebuilt ReAct agent.
workflow.add_node("main_agent", main_agent_node)

In [None]:
# Conditional edge out of `query_gen_node`: detour through the query checker on errors.
def should_correct_query(state: State) -> Literal["correct_query", "query_gen_node"]:
    """Route to ``correct_query`` if the last message reports an error, else to ``query_gen_node``.

    An error is recognised by the ``"Error:"`` prefix used by ``db_query_tool`` and
    ``handle_tool_error``. Non-string content (e.g. a list of content blocks) is treated
    as "no error" rather than raising AttributeError.
    NOTE(review): the non-error route leads back to query_gen_node itself; confirm the
    intended exit from this loop.
    """
    content = state["messages"][-1].content
    if isinstance(content, str) and content.startswith("Error:"):
        return "correct_query"
    return "query_gen_node"

In [None]:
# Conditional edge out of `chart_analyze_node`: regenerate the plotting code on errors.
def should_correct_code(state: State) -> Literal["check_chart_gen_node", "chart_analyze_node"]:
    """Route to ``check_chart_gen_node`` if the last message reports an error, else to ``chart_analyze_node``.

    An error is recognised by the ``"Error:"`` prefix (as produced by
    ``handle_tool_error``). Non-string content is treated as "no error" rather than
    raising AttributeError.
    NOTE(review): `_plot_data` failures start with "Failed to execute." rather than
    "Error:", and the non-error route leads back to chart_analyze_node itself; confirm
    both are intended.
    """
    content = state["messages"][-1].content
    if isinstance(content, str) and content.startswith("Error:"):
        return "check_chart_gen_node"
    return "chart_analyze_node"

In [None]:
# Conditional edge out of `main_agent`: dispatch to the worker named in its first tool call.
def who_is_next(state: State) -> Literal[END, "chart_analyze_node", "query_gen_node"]:
    """Choose the next node from the first tool call on the latest message.

    Returns ``"query_gen_node"`` or ``"chart_analyze_node"`` when the agent asked for that
    worker. Returns ``END`` for any other tool (e.g. ``SubmitFinalAnswer``) and also when
    the message has no tool call at all. Falling through to ``None`` is not one of the
    annotated routes.
    """
    last_message = state["messages"][-1]
    tool_calls = getattr(last_message, "tool_calls", None)
    print(tool_calls)
    if not tool_calls:
        print('error in tool call, last message: ', last_message)
        return END

    # Route on the name of the first requested tool.
    tool_name = tool_calls[0].get("name")
    if tool_name in ("query_gen_node", "chart_analyze_node"):
        return tool_name
    # Any other tool (e.g. SubmitFinalAnswer) ends the run.
    return END

In [None]:
# Fixed entry path: force the table listing, let the model choose tables, fetch their
# schemas, then hand over to the supervisor.
for _src, _dst in (
    (START, "first_tool_call"),
    ("first_tool_call", "list_tables_tool"),
    ("list_tables_tool", "model_get_schema"),
    ("model_get_schema", "get_schema_tool"),
    ("get_schema_tool", "main_agent"),
):
    workflow.add_edge(_src, _dst)
#workflow.add_node("format_response", format_final_response)

# for member in members:
#     # We want our workers to ALWAYS "report back" to the supervisor when done
#     workflow.add_edge(member, "main_agent")

# conditional_map = {k: k for k in members}
# conditional_map["FINISH"] = "format_response"

# workflow.add_conditional_edges("main_agent", lambda x: x["next"], conditional_map)

# Query team loop: query_gen_node -> (on error) correct_query -> execute_query -> query_gen_node.
# NOTE(review): should_correct_query only returns "correct_query" or "query_gen_node", so no
# edge leads from this team back to main_agent or END; confirm the intended exit.
workflow.add_conditional_edges("query_gen_node", should_correct_query)
for _src, _dst in (
    ("correct_query", "execute_query"),
    ("execute_query", "query_gen_node"),
):
    workflow.add_edge(_src, _dst)

# Chart team loop: chart_analyze_node -> (on error) check_chart_gen_node -> _plot_data -> chart_analyze_node.
# NOTE(review): same shape as the query team; there is no edge back to main_agent or END.
workflow.add_conditional_edges("chart_analyze_node", should_correct_code)
for _src, _dst in (
    ("check_chart_gen_node", "_plot_data"),
    ("_plot_data", "chart_analyze_node"),
):
    workflow.add_edge(_src, _dst)

# Supervisor dispatch: who_is_next chooses a worker or END.
workflow.add_conditional_edges("main_agent", who_is_next)
del _src, _dst


In [None]:
# Compile
# Build the runnable graph from `workflow`; `app` is drawn and streamed in the cells below.
app = workflow.compile()

In [None]:
from IPython.display import Image, display
from langchain_core.runnables.graph import MermaidDrawMethod

# Render the compiled graph as a Mermaid diagram (API draw method) and show it inline.
graph_png = app.get_graph().draw_mermaid_png(draw_method=MermaidDrawMethod.API)
display(Image(graph_png))

# Run

## sql basic

In [None]:
# Plain SQL question; recursion_limit caps how many graph steps this run may take.
event_stream = app.stream(
    {"messages": [HumanMessage(content="What is the average total price?")]},
    {"recursion_limit": 8},
)
for i, event in enumerate(event_stream):
    if "__end__" in event:
        continue
    # Print only the first key of each event.
    call = next(iter(event))
    print(f'\n======== Event {i} ==========:\n', call)

## plot 

In [None]:
# debug
plot_request = {"messages": [("user", "Plot the average total price per product line")]}
for i, event in enumerate(app.stream(plot_request)):
    # Print every streamed event in full.
    print(f'\n======== Event {i} ==========:\n', event)