# Import

In [None]:
import getpass
import os
from sqlite_dataset import SQLiteDataset, Field, String, Float, Integer
import pandas as pd
import matplotlib.pyplot as plt
from io import BytesIO
import base64  

from typing import Any, Annotated, Literal, Union, Dict, List
from pydantic import BaseModel, Field
from typing_extensions import TypedDict

from langchain_core.tools import tool, StructuredTool
from langchain_core.messages import ToolMessage, AIMessage, SystemMessage, HumanMessage
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langgraph.prebuilt import ToolNode
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_openai import AzureChatOpenAI

from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import AnyMessage, add_messages

## setup

In [None]:
def _set_env(key: str):
    """Prompt for the environment variable `key` unless it is already set.

    The value is read with `getpass.getpass` (prompt text is `"<key>:"`) and
    stored in `os.environ`.
    """
    if key in os.environ:
        return
    os.environ[key] = getpass.getpass(f"{key}:")


# Make sure both settings exist before any model is built; GPT_URL is read
# below as the Azure endpoint.
# NOTE(review): AZURE_OPENAI_API_KEY is presumably picked up by AzureChatOpenAI
# from the environment — confirm.
_set_env("AZURE_OPENAI_API_KEY")
_set_env("GPT_URL")

## DB: build it if necessary

In [None]:
# # Define the MySalesDataset class
# class MySalesDataset(SQLiteDataset):
#     # Define fields with their data types and specify the table name 'sales'
#     invoice_id = Field(String, tablename='sales')
#     branch = Field(String, tablename='sales')
#     city = Field(String, tablename='sales')
#     customer_type = Field(String, tablename='sales')
#     gender = Field(String, tablename='sales')
#     product_line = Field(String, tablename='sales')
#     unit_price = Field(Float, tablename='sales')
#     quantity = Field(Integer, tablename='sales')
#     tax_5_perc = Field(Float, tablename='sales')
#     total = Field(Float, tablename='sales')
#     date = Field(String, tablename='sales')  # Assuming date in string format like '1/5/2019'
#     time = Field(String, tablename='sales')  # Assuming time in string format like '13:08'
#     payment = Field(String, tablename='sales')
#     cogs = Field(Float, tablename='sales')
#     gross_margin_percentage = Field(Float, tablename='sales')
#     gross_income = Field(Float, tablename='sales')
#     rating = Field(Float, tablename='sales')


# data = pd.read_csv('data/supermarket_sales.csv')
# data.columns = data.columns.str.lower().str.replace(' ', '_')
# data.rename(columns={'tax_5%': 'tax_5_perc'}, inplace=True)
# # Load data from a DataFrame (assuming `data` is already defined as a pandas DataFrame)
# data_records = data.to_dict(orient='records')

# # Initialize and insert data into MySalesDataset
# with MySalesDataset('sales_dataset.db') as ds:
#     ds.insert_data('sales', data_records)


## set db

In [None]:
# Open the local SQLite file and print the dialect and usable table names as a
# quick sanity check.
db = SQLDatabase.from_uri('sqlite:///sales_dataset.db')
print(db.dialect)
print(db.get_usable_table_names())
# db.run("SELECT * FROM sales LIMIT 10;")

# Query Agent with validation/retry

## fallback tool

In [None]:
def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
    """
    Wrap `tools` in a ToolNode whose failures are routed to `handle_tool_error`.

    The exception is handed to the fallback under the "error" key, so it can be
    reported back to the agent as a tool message.
    """
    tool_node = ToolNode(tools)
    error_fallback = RunnableLambda(handle_tool_error)
    return tool_node.with_fallbacks([error_fallback], exception_key="error")


def handle_tool_error(state) -> dict:
    """Answer every pending tool call with an error ToolMessage.

    Reads the exception stored under "error" and the tool calls of the last
    message, and returns one ToolMessage per call id carrying the error text.
    """
    failure = state.get("error")
    pending_calls = state["messages"][-1].tool_calls
    replies = []
    for call in pending_calls:
        replies.append(
            ToolMessage(
                content=f"Error: {repr(failure)}\n please fix your mistakes.",
                tool_call_id=call["id"],
            )
        )
    return {"messages": replies}

## query agent tools

In [None]:
# Build the SQL toolkit on top of the shared `db` connection.
toolkit = SQLDatabaseToolkit(
    db=db,
    llm=AzureChatOpenAI(
        temperature=0,
        api_version="2024-08-01-preview",
        azure_endpoint=os.environ["GPT_URL"],
        azure_deployment="gpt-4o",
    ),
)
tools = toolkit.get_tools()

# Pick, by name, the two toolkit tools that the graph calls directly.
list_tables_tool = next(t for t in tools if t.name == "sql_db_list_tables")
get_schema_tool = next(t for t in tools if t.name == "sql_db_schema")

# Quick check that both tools respond.
print(list_tables_tool.invoke(""))
print(get_schema_tool.invoke("sales"))

In [None]:
@tool
def db_query_tool(query: str) -> str:
    """
    Execute a SQL query against the database and get back the result.
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.
    """
    # NOTE(review): with @tool the docstring above is presumably used as the
    # tool description shown to the model — keep its wording stable.
    outcome = db.run_no_throw(query)
    if outcome:
        return outcome
    # A falsy result is reported as a failure so the model rewrites the query.
    return "Error: Query failed. Please rewrite your query and try again."


# Smoke test of the tool on a small query.
print(db_query_tool.invoke("SELECT * FROM sales LIMIT 10;"))

## error checker LLM 

In [None]:
# System prompt for the model that reviews a candidate SQL query before it is executed.
query_check_system = """You are a SQL expert with a strong attention to detail.
Double check the SQLite query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

You will call the appropriate tool to execute the query after running this check."""

# Prompt = system instructions followed by the messages passed in at invoke time.
query_check_prompt = ChatPromptTemplate.from_messages(
    [("system", query_check_system), ("placeholder", "{messages}")]
)
# db_query_tool is the only bound tool and tool_choice="required", so the reply
# is expected to be a db_query_tool call.
query_check = query_check_prompt | AzureChatOpenAI(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ["GPT_URL"],
    azure_deployment="gpt-4o").bind_tools(
    [db_query_tool], tool_choice="required"
)

# Smoke test of the checker on a trivial query.
query_check.invoke({"messages": [("user", "SELECT * FROM sales LIMIT 10;")]})

# Define Sql Graph

In [None]:
# Define the state for the agent
class State(TypedDict):
    """State passed between the graph nodes."""

    # Conversation history; the annotation attaches `add_messages` to this key.
    messages: Annotated[list[AnyMessage], add_messages]

In [None]:
# Define a new graph
workflow = StateGraph(State)


# First node: force a call to the list-tables tool so the agent starts from the available tables
def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
    """Return a hard-coded AI message that calls `sql_db_list_tables` with no arguments."""
    list_tables_call = {
        "name": "sql_db_list_tables",
        "args": {},
        "id": "tool_abcd123",
    }
    forced_message = AIMessage(content="", tool_calls=[list_tables_call])
    return {"messages": [forced_message]}
    
def model_check_query(state: State) -> dict[str, list[AIMessage]]:
    """
    Run the query-check chain on the most recent message and return its reply.
    """
    latest = state["messages"][-1]
    reviewed = query_check.invoke({"messages": [latest]})
    return {"messages": [reviewed]}



In [None]:
# Entry node that emits the forced list-tables tool call
workflow.add_node("first_tool_call", first_tool_call)

# Add nodes for the first two tools (each wrapped so tool errors come back as messages)
workflow.add_node(
    "list_tables_tool", create_tool_node_with_fallback([list_tables_tool])
)
workflow.add_node("get_schema_tool", create_tool_node_with_fallback([get_schema_tool]))

In [None]:
# Add a node for a model to choose the relevant tables based on the question and available tables
# Only the schema tool is bound to this model.
model_get_schema = AzureChatOpenAI(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ["GPT_URL"],
    azure_deployment="gpt-4o").bind_tools(
    [get_schema_tool]
)
# The node passes the whole message history to the model and returns its reply
# under "messages".
workflow.add_node(
    "model_get_schema",
    lambda state: {
        "messages": [model_get_schema.invoke(state["messages"])],
    },
)



In [None]:
# Describe a tool to represent the end state
# NOTE(review): this model is bound to the LLM as a tool, so the docstring and
# the field description are presumably sent to the model — confirm before rewording.
class SubmitFinalAnswer(BaseModel):
    """Submit the final answer to the user based on the query results."""

    # Read back by the caller from the tool call args as "final_answer".
    final_answer: str = Field(..., description="The final answer to the user")

In [None]:

# Add a node for a model to generate a query based on the question and schema
query_gen_system = """You are a SQL expert with a strong attention to detail.

Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.

DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.

When generating the query:

Output the SQL query that answers the input question without a tool call.

Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.

If you get an error while executing a query, rewrite the query and try again.

If you get an empty result set, you should try to rewrite the query to get a non-empty result set. 
NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.

If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database."""


# Prompt = system instructions followed by the conversation so far.
query_gen_prompt = ChatPromptTemplate.from_messages(
    [("system", query_gen_system), ("placeholder", "{messages}")]
)
# Only SubmitFinalAnswer is bound; query_gen_node answers any other tool call
# with an error message.
query_gen = query_gen_prompt | AzureChatOpenAI(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ["GPT_URL"],
    azure_deployment="gpt-4o").bind_tools(
    [SubmitFinalAnswer]
)


def query_gen_node(state: State):
    """Ask the query-generation chain for the next message.

    Every tool call other than SubmitFinalAnswer is answered with an error
    ToolMessage so the model can correct itself on the next turn.
    """
    message = query_gen.invoke(state)

    # Sometimes the LLM calls a tool it should not; reply to each such call with an error.
    wrong_tool_replies = [
        ToolMessage(
            content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.",
            tool_call_id=tc["id"],
        )
        for tc in (message.tool_calls or [])
        if tc["name"] != "SubmitFinalAnswer"
    ]
    return {"messages": [message] + wrong_tool_replies}

In [None]:
# Add the node that generates the query or the final answer
workflow.add_node("query_gen", query_gen_node)

# Add a node for the model to check the query before executing it
workflow.add_node("correct_query", model_check_query)

# Add node for executing the query (tool errors are returned as messages)
workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))


In [None]:
# Define a conditional edge to decide whether to continue or end the workflow
def should_continue(state: State) -> Literal[END, "correct_query", "query_gen"]:
    """Route after query_gen based on the last message.

    A message with tool calls ends the run; content starting with "Error:"
    goes back to query_gen; anything else goes to correct_query.
    """
    last_message = state["messages"][-1]
    # A tool call on the last message means we are done
    if getattr(last_message, "tool_calls", None):
        return END
    return "query_gen" if last_message.content.startswith("Error:") else "correct_query"

## edges

In [None]:
# Specify the edges between the nodes
# Fixed prefix: list tables -> let the model pick tables -> fetch schema -> generate query
workflow.add_edge(START, "first_tool_call")
workflow.add_edge("first_tool_call", "list_tables_tool")
workflow.add_edge("list_tables_tool", "model_get_schema")
workflow.add_edge("model_get_schema", "get_schema_tool")
workflow.add_edge("get_schema_tool", "query_gen")
# should_continue sends query_gen to END, to correct_query, or back to query_gen
workflow.add_conditional_edges(
    "query_gen",
    should_continue,
)
# A checked query is executed and its result goes back to query_gen
workflow.add_edge("correct_query", "execute_query")
workflow.add_edge("execute_query", "query_gen")

# Compile the workflow into a runnable
app = workflow.compile()

In [None]:
from IPython.display import Image, display
from langchain_core.runnables.graph import MermaidDrawMethod

# Render the compiled graph as a Mermaid PNG and show it inline.
# NOTE(review): MermaidDrawMethod.API presumably calls a remote rendering
# service and needs network access — confirm.
display(
    Image(
        app.get_graph().draw_mermaid_png(
            draw_method=MermaidDrawMethod.API,
        )
    )
)

In [None]:
# Run the SQL agent end to end on a sample question.
messages = app.invoke(
    {"messages": [("user", "What is the average total price?")]}
)
# The run ends on a message with a tool call (see should_continue); read the
# final_answer argument of its first call.
json_str = messages["messages"][-1].tool_calls[0]["args"]["final_answer"]
json_str

In [None]:
# Debug: stream the same question and print every intermediate event.
for i, event in enumerate(app.stream(
    {"messages": [("user", "What is the average total price?")]}
)):
    print(f'\n======== Event {i} ==========:\n', event)

# Create Chart analyzer graph
Starting from the data, the intended flow is:
-> plot generation tool -> save image -> pass to analyzer agent -> open image and comment on it -> END -> show plot

In [None]:
# Define the state for the agent
class State(TypedDict):
    """State passed between the graph nodes."""

    # Conversation history; the annotation attaches `add_messages` to this key.
    messages: Annotated[list[AnyMessage], add_messages]

## fallback tool

In [None]:
def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
    """
    Wrap `tools` in a ToolNode whose failures are routed to `handle_tool_error`.

    The exception is handed to the fallback under the "error" key, so it can be
    reported back to the agent as a tool message.
    """
    tool_node = ToolNode(tools)
    error_fallback = RunnableLambda(handle_tool_error)
    return tool_node.with_fallbacks([error_fallback], exception_key="error")


def handle_tool_error(state) -> dict:
    """Answer every pending tool call with an error ToolMessage.

    Reads the exception stored under "error" and the tool calls of the last
    message, and returns one ToolMessage per call id carrying the error text.
    """
    failure = state.get("error")
    pending_calls = state["messages"][-1].tool_calls
    replies = []
    for call in pending_calls:
        replies.append(
            ToolMessage(
                content=f"Error: {repr(failure)}\n please fix your mistakes.",
                tool_call_id=call["id"],
            )
        )
    return {"messages": replies}

## plot tool

In [None]:
# class PlotArgs(BaseModel):
#     """Arguments for the plot_data function"""
#     data_str: str = Field(
#         ..., 
#         description="String containing the data to plot, will be converted to DataFrame"
#     )
#     code: str = Field(
#         ..., 
#         description="Python code string that uses 'df' to create the plot"
#     )
    
#     class Config:
#         arbitrary_types_allowed = True
# @tool
# def _plot_data(
#     data_str: str,
#     code: str
# ) -> str:
#     """Executes plotting code on a DataFrame."""
#     plt.figure()
#     img_bytes = BytesIO()
    
#     try:
#         # Convert dict to DataFrame
        
#         data_list = eval(data_str)
#         # Creazione del DataFrame
#         df = pd.DataFrame(data_list)
        
#         # Create a clean namespace for execution
#         namespace = {
#             'df': df,
#             'plt': plt,
#             'pd': pd
#         }
        
#         # Execute the provided code
#         exec(code, namespace)
        
#         plt.savefig(img_bytes, format='png', bbox_inches='tight', dpi=300)
#         plt.close()
#     except Exception as e:
#         plt.close()
#         return f"Failed to execute. Error: {repr(e)}"
        
#     return "\nPlot Genereted."

# # Create the structured tool
# plot_data_tool = StructuredTool(
#     name="plot_data",
#     description="""Create plots from data. 
#     Required Arguments:
#         - data_str: A string with data values that will be converted to a DataFrame
#         - code: A string containing Python plotting code that uses 'df' as the DataFrame name
#     """,
#     func = _plot_data,
#     args_schema = PlotArgs
# )

In [None]:
@tool
def _plot_data(
    data_str: str,
    code: str
) -> str:
    """Executes plotting code on a DataFrame."""
    # NOTE(review): with @tool the docstring above is presumably used as the
    # tool description shown to the model — keep its wording stable.
    #
    # data_str: Python literal (e.g. a list of tuples) turned into the DataFrame `df`.
    # code: plotting code run with `df`, `plt` and `pd` in scope; per the chart
    #       generation prompt it should save the figure to temp_plot.png itself.
    # Returns a success text whose last line is the image path, or a
    # "Failed to execute" text carrying the exception.

    # Remember which figures already exist so that only the figures opened
    # during this call are closed afterwards. Previously the scratch figure
    # created below was never closed and figures piled up call after call.
    figures_before = set(plt.get_fignums())
    plt.figure()

    try:
        # SECURITY: data_str and code are produced by the model; eval/exec run
        # arbitrary Python in this process. Only use with trusted input or in
        # a sandbox.
        data_list = eval(data_str)
        df = pd.DataFrame(data_list)

        # Names visible to the generated code
        namespace = {
            'df': df,
            'plt': plt,
            'pd': pd
        }

        # Execute the provided code
        exec(code, namespace)

    except Exception as e:
        return f"Failed to execute. Error: {repr(e)}"
    finally:
        # Release every figure opened during this call, on success and on failure.
        for fig_num in set(plt.get_fignums()) - figures_before:
            plt.close(fig_num)

    return "\nPlot Genereted. img_path:\ntemp_plot.png"

In [None]:
# print(_plot_data.invoke(data_str="[(319.632538235294, 'Electronic accessories'), (305.0892977528089, 'Fashion accessories'), (322.67151724137955, 'Food and beverages'), (323.64301973684223, 'Health and beauty'), (336.6369562500001, 'Home and lifestyle'), (332.06521987951794, 'Sports and travel')]",
#                         code='''import matplotlib.pyplot as plt

#                             # Creazione del bar plot
#                             plt.figure(figsize=(10, 6))
#                             plt.bar(df["Category"], df["Value"], color='skyblue')
#                             plt.xlabel("Category")
#                             plt.ylabel("Value")
#                             plt.title("Bar Plot of Categories vs Values")
#                             plt.xticks(rotation=45, ha="right")
#                             plt.tight_layout()
                            
#                             # Mostra il grafico
#                             plt.show()'''))

## code creator agent

In [None]:
# Describe a tool to represent the end state
# NOTE(review): this model is bound to the LLM as a tool, so the docstring and
# the field description are presumably sent to the model — confirm before rewording.
class SubmitFinalAnswer(BaseModel):
    """Submit the final answer to the user based on the query results."""

    # Single string field holding the answer text.
    final_answer: str = Field(..., description="The final answer to the user")

In [None]:
# System prompt for the model that writes the plotting code and calls the plot tool
chart_gen_system = """You are a Python expert specialized in data visualization.

Given an input string of data and a specific visualization, output a syntactically correct Python code to run to:
- create that plot
- save the image in temp_plot.png
- close the plot

When generating the code:

Output the Python code that answers the input question that call the proper plot_data tool passing the code and data.
"""


# Prompt = system instructions followed by the conversation so far.
chart_gen_prompt = ChatPromptTemplate.from_messages(
    [("system", chart_gen_system), ("placeholder", "{messages}")]
)
# Only the _plot_data tool is bound to this model.
chart_gen = chart_gen_prompt | AzureChatOpenAI(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ["GPT_URL"],
    azure_deployment="gpt-4o").bind_tools(
    [_plot_data]
)


def chart_gen_node(state: State):
    """Ask the chart-generation chain for the next message.

    Every tool call other than `_plot_data` is answered with an error
    ToolMessage so the model can correct itself.
    """
    message = chart_gen.invoke(state)

    # Sometimes the LLM calls a tool it should not; reply to each such call with an error.
    wrong_tool_replies = [
        ToolMessage(
            content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call _plot_data to create the chart.",
            tool_call_id=tc["id"],
        )
        for tc in (message.tool_calls or [])
        if tc["name"] != '_plot_data'
    ]
    return {"messages": [message] + wrong_tool_replies}

In [None]:
# System prompt for the model that comments on the rendered plot image
chart_analyzer_system = """You are an expert data analyst.
Given the image of a plot as input, output a comment to that plot that helps finding insights on the data represented.
You can talk about distributions, anomalies, minimum or maximus values and so on, but stick to the data you see.
When you have generated the comment call the SubmitFinalAnswer tool to show your response to the user.

Do not add this kind of text (example) ![Bar Graph](temp_plot.png) since we don't need it, just respond with a comment on the plot.
"""


# chart_analyzer_prompt = ChatPromptTemplate.from_messages(
#     [("system", chart_analyzer_system), ("placeholder", "{messages}")]
# )

#def chart_analyze_node(state: State):
    # message = chart_analyzer.invoke(state)

    # # Sometimes, the LLM will hallucinate and call the wrong tool. We need to catch this and return an error message.
    # tool_messages = []
    # if message.tool_calls:
    #     for tc in message.tool_calls:
    #         if tc["name"] not in  ['SubmitFinalAnswer']:
    #             tool_messages.append(
    #                 ToolMessage(
    #                     content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit final answer.",
    #                     tool_call_id=tc["id"],
    #                 )
    #             )
    # else:
    #     tool_messages = []
    # return {"messages": [message] + tool_messages}


# Prompt layout: system instructions, then the plot as an inline base64 image,
# then the running conversation.
prompt_messages = [
    SystemMessage(content=chart_analyzer_system),
    HumanMessagePromptTemplate.from_template(
        template=[
            # The plot is saved as temp_plot.png, so the data URL declares PNG
            # (it previously claimed image/jpeg).
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,{img_base64}"}},
        ]
    ),
    MessagesPlaceholder("messages"),
]

chart_analyzer_prompt = ChatPromptTemplate(messages=prompt_messages)


# Only SubmitFinalAnswer is bound to the analyzer model.
chart_analyzer = chart_analyzer_prompt | AzureChatOpenAI(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ["GPT_URL"],
    azure_deployment="gpt-4o").bind_tools(
    [SubmitFinalAnswer]
)


def chart_analyze_node(state: State):
    """Send the most recently generated plot image to the chart analyzer model.

    The image path is taken from the last line of the latest `_plot_data` tool
    message. The file is read, base64-encoded and passed to `chart_analyzer`
    together with the conversation.

    Returns:
        A state update holding the model reply, plus an error ToolMessage for
        every tool call other than SubmitFinalAnswer. When there is no
        `_plot_data` message, or its last line is not an existing file (for
        example because the plot code failed), the update instead holds a
        single AI message whose content starts with "Error:".
    """
    # Start from None so a missing tool message no longer raises UnboundLocalError.
    img_path = None
    # Walk backwards so the latest plot wins when the history holds several.
    for message in reversed(state["messages"]):
        if isinstance(message, ToolMessage) and message.name == "_plot_data":
            img_path = message.content.split("\n")[-1].strip()
            break

    # A failed tool call ends with an error text rather than a path, so check
    # that the file exists before opening it.
    if not img_path or not os.path.isfile(img_path):
        return {
            "messages": [
                AIMessage(
                    content="Error: no plot image is available to analyze. The plot was not generated or could not be saved."
                )
            ]
        }

    with open(img_path, 'rb') as f:
        img_base64 = base64.b64encode(f.read()).decode('utf-8')

    message = chart_analyzer.invoke({'messages': state['messages'], "img_base64": img_base64})

    # Sometimes the LLM calls a tool it should not; reply to each such call with an error.
    tool_messages = [
        ToolMessage(
            content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit final answer.",
            tool_call_id=tc["id"],
        )
        for tc in (message.tool_calls or [])
        if tc["name"] != 'SubmitFinalAnswer'
    ]
    return {"messages": [message] + tool_messages}
 

## Nodes and edges

In [None]:

# New graph for the chart flow: generate plot code -> run it -> analyze the image
workflow = StateGraph(State)
workflow.add_node("chart_gen_node", chart_gen_node)
# Tool node that executes the _plot_data call (errors are returned as messages)
workflow.add_node(
    "_plot_data", create_tool_node_with_fallback([_plot_data])
)
workflow.add_node("chart_analyze_node", chart_analyze_node)

In [None]:
# Specify the edges between the nodes
# Straight pipeline; chart_analyze_node has no outgoing edge added here.
workflow.add_edge(START, "chart_gen_node")
workflow.add_edge("chart_gen_node", "_plot_data")
workflow.add_edge("_plot_data", "chart_analyze_node")
# workflow.add_conditional_edges(
#     "query_gen",
#     should_continue,
# )
# workflow.add_edge("correct_query", "execute_query")
# workflow.add_edge("execute_query", "query_gen")

# Compile the workflow into a runnable
app = workflow.compile()

In [None]:
from IPython.display import Image, display
from langchain_core.runnables.graph import MermaidDrawMethod

# Render the compiled graph as a Mermaid PNG and show it inline.
# NOTE(review): MermaidDrawMethod.API presumably calls a remote rendering
# service and needs network access — confirm.
display(
    Image(
        app.get_graph().draw_mermaid_png(
            draw_method=MermaidDrawMethod.API,
        )
    )
)

In [None]:
# ## debug 
# for event in app.stream(
#     {"messages": [("user", """[(319.632538235294, 'Electronic accessories'), (305.0892977528089, 'Fashion accessories'), (322.67151724137955, 'Food and beverages'), (323.64301973684223, 'Health and beauty'), (336.6369562500001, 'Home and lifestyle'), (332.06521987951794, 'Sports and travel')]
#  Plot this data in a bar graph""")]}
# ):
#     print(event)

In [None]:
# Run the chart flow on a sample data string plus a plotting request.
messages = app.invoke(
   {"messages": [("user", """[(319.632538235294, 'Electronic accessories'), (305.0892977528089, 'Fashion accessories'), (322.67151724137955, 'Food and beverages'), (323.64301973684223, 'Health and beauty'), (336.6369562500001, 'Home and lifestyle'), (332.06521987951794, 'Sports and travel')]
 Plot this data in a bar graph""")]}
)
# Text content of the last message of the run
json_str = messages["messages"][-1].content
json_str


In [None]:

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
# Load the image written by the plotting step
img_path = "temp_plot.png"
img = mpimg.imread(img_path)

plt.figure(figsize=(12, 8))  # Adjust these values to enlarge the figure further

# Show the image
plt.imshow(img)
plt.axis('off')  # Hide the axes for a cleaner view
plt.show()

# Put together  Query + Chart analyzer

## fallback errors

In [None]:
def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
    """
    Wrap `tools` in a ToolNode whose failures are routed to `handle_tool_error`.

    The exception is handed to the fallback under the "error" key, so it can be
    reported back to the agent as a tool message.
    """
    tool_node = ToolNode(tools)
    error_fallback = RunnableLambda(handle_tool_error)
    return tool_node.with_fallbacks([error_fallback], exception_key="error")


def handle_tool_error(state) -> dict:
    """Answer every pending tool call with an error ToolMessage.

    Reads the exception stored under "error" and the tool calls of the last
    message, and returns one ToolMessage per call id carrying the error text.
    """
    failure = state.get("error")
    pending_calls = state["messages"][-1].tool_calls
    replies = []
    for call in pending_calls:
        replies.append(
            ToolMessage(
                content=f"Error: {repr(failure)}\n please fix your mistakes.",
                tool_call_id=call["id"],
            )
        )
    return {"messages": replies}

## query agent tools

In [None]:
# Build the SQL toolkit on top of the shared `db` connection.
toolkit = SQLDatabaseToolkit(
    db=db,
    llm=AzureChatOpenAI(
        temperature=0,
        api_version="2024-08-01-preview",
        azure_endpoint=os.environ["GPT_URL"],
        azure_deployment="gpt-4o",
    ),
)
tools = toolkit.get_tools()

# Pick, by name, the two toolkit tools that the graph calls directly.
list_tables_tool = next(t for t in tools if t.name == "sql_db_list_tables")
get_schema_tool = next(t for t in tools if t.name == "sql_db_schema")

# print(list_tables_tool.invoke(""))
# print(get_schema_tool.invoke("sales"))

In [None]:
@tool
def db_query_tool(query: str) -> str:
    """
    Execute a SQL query against the database and get back the result.
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.
    """
    # NOTE(review): with @tool the docstring above is presumably used as the
    # tool description shown to the model — keep its wording stable.
    outcome = db.run_no_throw(query)
    if outcome:
        return outcome
    # A falsy result is reported as a failure so the model rewrites the query.
    return "Error: Query failed. Please rewrite your query and try again."


# print(db_query_tool.invoke("SELECT * FROM sales LIMIT 10;"))

## error checker LLM

In [None]:
# System prompt for the model that reviews a candidate SQL query before it is executed.
query_check_system = """You are a SQL expert with a strong attention to detail.
Double check the SQLite query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

You will call the appropriate tool to execute the query after running this check."""

# Prompt = system instructions followed by the messages passed in at invoke time.
query_check_prompt = ChatPromptTemplate.from_messages(
    [("system", query_check_system), ("placeholder", "{messages}")]
)
# db_query_tool is the only bound tool and tool_choice="required", so the reply
# is expected to be a db_query_tool call.
query_check = query_check_prompt | AzureChatOpenAI(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ["GPT_URL"],
    azure_deployment="gpt-4o").bind_tools(
    [db_query_tool], tool_choice="required"
)

#query_check.invoke({"messages": [("user", "SELECT * FROM sales LIMIT 10;")]})

## Define the state for the agent

In [None]:
class State(TypedDict):
    """State passed between the graph nodes."""

    # Conversation history; the annotation attaches `add_messages` to this key.
    messages: Annotated[list[AnyMessage], add_messages]

## initialize workflow

In [None]:
# Define a new graph
workflow = StateGraph(State)


# First node: force a call to the list-tables tool so the agent starts from the available tables
def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
    """Return a hard-coded AI message that calls `sql_db_list_tables` with no arguments."""
    list_tables_call = {
        "name": "sql_db_list_tables",
        "args": {},
        "id": "tool_abcd123",
    }
    forced_message = AIMessage(content="", tool_calls=[list_tables_call])
    return {"messages": [forced_message]}
    
def model_check_query(state: State) -> dict[str, list[AIMessage]]:
    """
    Run the query-check chain on the most recent message and return its reply.
    """
    latest = state["messages"][-1]
    reviewed = query_check.invoke({"messages": [latest]})
    return {"messages": [reviewed]}



In [None]:
# Entry node that emits the forced list-tables tool call
workflow.add_node("first_tool_call", first_tool_call)

# Add nodes for the first two tools (each wrapped so tool errors come back as messages)
workflow.add_node(
    "list_tables_tool", create_tool_node_with_fallback([list_tables_tool])
)
workflow.add_node("get_schema_tool", create_tool_node_with_fallback([get_schema_tool]))


In [None]:
# Add a node for a model to choose the relevant tables based on the question and available tables
# Only the schema tool is bound to this model.
model_get_schema = AzureChatOpenAI(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ["GPT_URL"],
    azure_deployment="gpt-4o").bind_tools(
    [get_schema_tool]
)
# The node passes the whole message history to the model and returns its reply
# under "messages".
workflow.add_node(
    "model_get_schema",
    lambda state: {
        "messages": [model_get_schema.invoke(state["messages"])],
    },
)


## final answer tool

In [None]:
# Describe a tool to represent the end state
# NOTE(review): this model is bound to the LLM as a tool, so the docstring and
# the field description are presumably sent to the model — confirm before rewording.
class SubmitFinalAnswer(BaseModel):
    """Submit the final answer to the user based on the query results."""

    # Single string field holding the answer text.
    final_answer: str = Field(..., description="The final answer to the user")

## chart agent

In [None]:
@tool
def _plot_data(
    data_str: str,
    code: str
) -> str:
    """Executes plotting code on a DataFrame."""
    # NOTE(review): with @tool the docstring above is presumably used as the
    # tool description shown to the model — keep its wording stable.
    #
    # data_str: Python literal (e.g. a list of tuples) turned into the DataFrame `df`.
    # code: plotting code run with `df`, `plt` and `pd` in scope; per the chart
    #       generation prompt it should save the figure to temp_plot.png itself.
    # Returns a success text whose last line is the image path, or a
    # "Failed to execute" text carrying the exception.

    # Remember which figures already exist so that only the figures opened
    # during this call are closed afterwards. Previously the scratch figure
    # created below was never closed and figures piled up call after call.
    figures_before = set(plt.get_fignums())
    plt.figure()

    try:
        # SECURITY: data_str and code are produced by the model; eval/exec run
        # arbitrary Python in this process. Only use with trusted input or in
        # a sandbox.
        data_list = eval(data_str)
        df = pd.DataFrame(data_list)

        # Names visible to the generated code
        namespace = {
            'df': df,
            'plt': plt,
            'pd': pd
        }

        # Execute the provided code
        exec(code, namespace)

    except Exception as e:
        return f"Failed to execute. Error: {repr(e)}"
    finally:
        # Release every figure opened during this call, on success and on failure.
        for fig_num in set(plt.get_fignums()) - figures_before:
            plt.close(fig_num)

    return "\nPlot Genereted. img_path:\ntemp_plot.png"
    
# System prompt for the model that writes the plotting code and calls the plot tool
chart_gen_system = """You are a Python expert specialized in data visualization.

Given an input string of data and a specific visualization, output a syntactically correct Python code to run to:
- create that plot
- save the image in temp_plot.png
- close the plot

When generating the code:

Output the Python code that answers the input question that call the proper plot_data tool passing the code and data.
"""


# Prompt = system instructions followed by the conversation so far.
chart_gen_prompt = ChatPromptTemplate.from_messages(
    [("system", chart_gen_system), ("placeholder", "{messages}")]
)
# Only the _plot_data tool is bound to this model.
chart_gen = chart_gen_prompt | AzureChatOpenAI(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ["GPT_URL"],
    azure_deployment="gpt-4o").bind_tools(
    [_plot_data]
)


def chart_gen_node(state: State):
    # Graph node: ask the chart-generation model for a `_plot_data` tool call.
    #
    # Returns a state update whose "messages" list holds, in order: replies to
    # any tool calls still pending on the incoming history, the model's
    # response, and an error ToolMessage for every call to an unexpected tool.
    #
    # NOTE(review): no docstring is added on purpose. This function is also
    # passed to `bind_tools` for `query_gen`, and langchain presumably derives
    # the tool description from the docstring -- confirm before adding one.
    messages = list(state["messages"])

    # The graph routes here when the previous AI message calls the
    # `chart_gen_node` tool, and nothing answers that call. The OpenAI chat API
    # rejects a history where an assistant tool call has no matching tool
    # message, so acknowledge every pending call before invoking the model.
    last_message = messages[-1] if messages else None
    pending_calls = getattr(last_message, "tool_calls", None) or []
    acks = [
        ToolMessage(
            content="Request accepted. Generate the chart from the data passed in this tool call.",
            tool_call_id=tc["id"],
        )
        for tc in pending_calls
    ]

    message = chart_gen.invoke({**state, "messages": messages + acks})

    # Sometimes, the LLM will hallucinate and call the wrong tool. We need to catch this and return an error message.
    tool_messages = [
        ToolMessage(
            content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call _plot_data to create the chart.",
            tool_call_id=tc["id"],
        )
        for tc in (message.tool_calls or [])
        if tc["name"] not in ['_plot_data']
    ]
    return {"messages": acks + [message] + tool_messages}

# Chart-analysis model: receives the rendered plot as an image and must submit
# its commentary through the SubmitFinalAnswer tool.
chart_analyzer_system = """You are an expert data analyst.
Given the image of a plot as input, output a comment to that plot that helps finding insights on the data represented.
You can talk about distributions, anomalies, minimum or maximus values and so on, but stick to the data you see.
When you have generated the comment call the SubmitFinalAnswer tool to show your response to the user.

Do not add this kind of text in the output (example):
![Bar Graph](temp_plot.png) 
since we don't need it, just respond with a comment on the plot.
"""

# Message order: system instructions, the plot image, then the conversation.
prompt_messages = [
    SystemMessage(content=chart_analyzer_system),
    HumanMessagePromptTemplate.from_template(
        template=[
            # The plot is written to temp_plot.png, so the data URL must
            # declare a PNG payload (it previously claimed image/jpeg).
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,{img_base64}"}},
        ]
    ),
    MessagesPlaceholder("messages"),
]

chart_analyzer_prompt = ChatPromptTemplate(messages=prompt_messages)


# Prompt piped into the Azure chat model, with SubmitFinalAnswer as its only tool.
chart_analyzer = chart_analyzer_prompt | AzureChatOpenAI(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ["GPT_URL"],
    azure_deployment="gpt-4o").bind_tools(
    [SubmitFinalAnswer]
)


def chart_analyze_node(state: State):
    """Graph node: send the rendered plot to the analysis model.

    Takes the image path from the last line of the first `_plot_data`
    ToolMessage in ``state["messages"]``, base64-encodes that file and invokes
    ``chart_analyzer`` with it.

    Returns:
        A state update whose "messages" list holds the model response plus an
        error ToolMessage for every call to a tool other than
        SubmitFinalAnswer. When no plot image is available (no `_plot_data`
        result, or the reported path is not a file, e.g. because plotting
        failed) it holds a single AIMessage describing the problem.
    """
    # Start from None so a history without a `_plot_data` result is handled
    # instead of raising UnboundLocalError below.
    img_path = None
    for message in state["messages"]:
        if isinstance(message, ToolMessage) and message.name == "_plot_data":
            # On success the tool reports the image path on its last line.
            img_path = message.content.split("\n")[-1].strip()
            break

    # On failure the tool's last line is error text, not a path: check the file
    # really exists before trying to open it.
    if not img_path or not os.path.isfile(img_path):
        return {
            "messages": [
                AIMessage(content="Error: no plot image is available to analyze.")
            ]
        }

    with open(img_path, 'rb') as f:
        img_base64 = base64.b64encode(f.read()).decode('utf-8')

    message = chart_analyzer.invoke({'messages': state['messages'], "img_base64": img_base64})

    # Sometimes, the LLM will hallucinate and call the wrong tool. We need to catch this and return an error message.
    tool_messages = [
        ToolMessage(
            content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit final answer.",
            tool_call_id=tc["id"],
        )
        for tc in (message.tool_calls or [])
        if tc["name"] not in ['SubmitFinalAnswer']
    ]
    return {"messages": [message] + tool_messages}


# Node that asks the model for plotting code
workflow.add_node("chart_gen_node", chart_gen_node)
# Node that runs the `_plot_data` tool, built with the file's fallback wrapper
workflow.add_node(
    "_plot_data", create_tool_node_with_fallback([_plot_data])
)
# Node that asks the model to comment on the rendered plot
workflow.add_node("chart_analyze_node", chart_analyze_node)

## query agent

In [None]:
from langchain.tools import StructuredTool
from pydantic import BaseModel, Field
from typing import Dict, Any

# Define the input schema for the tool
class ChartGenInput(BaseModel):
    # Single string argument; per its description it carries the query results.
    state: str = Field(
        description="The current state containing query results as string"
    )

# Now create the tool with the schema defined above
# NOTE(review): `chart_gen_tool` is not referenced anywhere else in the visible
# code (the query model is bound to `chart_gen_node` directly), and its schema
# declares `state` as a string while `chart_gen_node` passes its argument to
# the chart chain as graph state -- confirm whether this tool is still needed.
chart_gen_tool = StructuredTool(
    name="generate_chart",
    description="Generates a visualization of the data when needed. Use this when the user asks for a visual representation.",
    func=chart_gen_node,
    args_schema=ChartGenInput,
    return_direct=False
)

In [None]:

# Query-generation model: writes the SQLite query for the user's question and,
# once results are available, either submits the answer or asks for a chart.
query_gen_system = """You are a SQL expert with a strong attention to detail.

Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.

DO NOT call any tool besides SubmitFinalAnswer to submit the final answer, unless input question is about data visualization. 
In that case call chart_gen_node passing the results as a string.

When generating the query:

Output the SQL query that answers the input question without a tool call.

Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.

If you get an error while executing a query, rewrite the query and try again.

If you get an empty result set, you should try to rewrite the query to get a non-empty result set. 
NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.

If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
"""


# System prompt followed by the running conversation.
query_gen_prompt = ChatPromptTemplate.from_messages(
    [("system", query_gen_system), ("placeholder", "{messages}")]
)
# NOTE(review): `chart_gen_node` is bound as a plain function, so the tool name
# the model sees is presumably "chart_gen_node" -- the name that the prompt,
# `query_gen_node` and `should_continue` all match on. Confirm.
query_gen = query_gen_prompt | AzureChatOpenAI(
    temperature=0,
    api_version="2024-08-01-preview",
    azure_endpoint=os.environ["GPT_URL"],
    azure_deployment="gpt-4o").bind_tools(
    [SubmitFinalAnswer, chart_gen_node]
)


def query_gen_node(state: State):
    """Graph node: run the query-generation model on the current state.

    Returns a state update whose "messages" list starts with the model
    response, followed by one error ToolMessage per call to a tool other than
    SubmitFinalAnswer or chart_gen_node.
    """
    response = query_gen.invoke(state)

    # The model occasionally calls a tool it was not given; answer each such
    # call with an error so it can correct itself on the next turn.
    rejections = [
        ToolMessage(
            content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer or chart_gen_node to create a plot.",
            tool_call_id=tc["id"],
        )
        for tc in (response.tool_calls or [])
        if tc["name"] not in ["SubmitFinalAnswer", "chart_gen_node"]
    ]
    return {"messages": [response] + rejections}

In [None]:
# Node that generates the SQL query, or the final answer once results exist
workflow.add_node("query_gen", query_gen_node)

# Node where the model checks the query before it is executed
workflow.add_node("correct_query", model_check_query)

# Node that executes the query through db_query_tool
workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))


In [None]:
# Conditional edge: choose the next node from the most recent message
def should_continue(state: State) -> Literal[END, "correct_query", "chart_gen_node", "query_gen"]:
    """Pick the next node from the last message in the state.

    - A SubmitFinalAnswer tool call ends the run.
    - A chart_gen_node tool call hands over to the chart path.
    - Content starting with "Error:" goes back to "query_gen".
    - Anything else is treated as a query and goes to "correct_query".

    The first tool call that matches one of the two known names decides.
    """
    latest = state["messages"][-1]

    for call in getattr(latest, "tool_calls", None) or ():
        tool_name = call["name"]
        if tool_name == 'SubmitFinalAnswer':
            return END
        if tool_name == 'chart_gen_node':
            return "chart_gen_node"

    return "query_gen" if latest.content.startswith("Error:") else "correct_query"

In [None]:


# Base query path: list the tables, fetch the schema, then generate a query
workflow.add_edge(START, "first_tool_call")
workflow.add_edge("first_tool_call", "list_tables_tool")
workflow.add_edge("list_tables_tool", "model_get_schema")
workflow.add_edge("model_get_schema", "get_schema_tool")
workflow.add_edge("get_schema_tool", "query_gen")



# After query_gen, should_continue picks: END, "correct_query",
# "chart_gen_node", or back to "query_gen"
workflow.add_conditional_edges(
    "query_gen",
    should_continue,
)

# Check the query, execute it, then return to query_gen with the result
workflow.add_edge("correct_query", "execute_query")
workflow.add_edge("execute_query", "query_gen")

# Plotting path; it is entered only through should_continue, so there is no
# direct query_gen -> chart_gen_node edge
workflow.add_edge("chart_gen_node", "_plot_data")
workflow.add_edge("_plot_data", "chart_analyze_node")
workflow.add_edge("chart_analyze_node", END)

# Compile the graph into a runnable app
app = workflow.compile()

### display

In [None]:
from IPython.display import Image, display
from langchain_core.runnables.graph import MermaidDrawMethod

# Render the compiled graph as a PNG and show it inline.
# NOTE(review): MermaidDrawMethod.API presumably renders through a remote
# Mermaid service and so needs network access -- confirm.
display(
    Image(
        app.get_graph().draw_mermaid_png(
            draw_method=MermaidDrawMethod.API,
        )
    )
)

# Run

In [None]:
# Stream the graph for a plain SQL question and print every event it emits.
for i, event in enumerate(app.stream(
    {"messages": [("user", "What is the average total price?")]}
)):
    print(f'\n======== Event {i} ==========:\n', event)

In [None]:
# Run the graph to completion and read the answer from the first tool call of
# the last message.
# NOTE(review): assumes the run ended with a SubmitFinalAnswer tool call;
# despite its name, `json_str` holds the `final_answer` string.
messages = app.invoke(
   {"messages": [("user","What is the average total price?")]}
)
json_str = messages["messages"][-1].tool_calls[0]["args"]["final_answer"]
json_str


In [None]:
# Debug run: stream the graph for a plotting question to exercise the chart path.
for i, event in enumerate(app.stream(
    {"messages": [("user", "Plot the average total price per product line")]}
)):
    print(f'\n======== Event {i} ==========:\n', event)