In [1]:
from langgraph.graph import START, END, StateGraph


from State_Schema.state_schema import AgentState
from agents.sql_generation_agent import sql_agent
from agents.visualization_agent import visualization_agent
from agents.helper_functions import should_generate_graph, should_retry, check_scope
from agents.analysis_agent import analysis_agent
from agents.decide_graph_need import decide_graph_need
from agents.execute_sql_query_agent import execute_query
from agents.error_agent import error_agent
from agents.guardrails_agent import guardrails_agent


def create_text2sql_agent():
    """
    defines all nodes and edges and creates langgraph graph 
    """

    graph = StateGraph(AgentState)

    graph.add_node("guardrails_agent", guardrails_agent)
    graph.add_node("sql_query_generation_agent", sql_agent)
    graph.add_node("execute_query", execute_query)
    graph.add_node("error", error_agent)
    graph.add_node("visualize_agent", visualization_agent)
    graph.add_node("analysis_agent", analysis_agent)
    graph.add_node("decide_graph_need", decide_graph_need)


    graph.set_entry_point("sql_query_generation_agent")

    # graph.add_conditional_edges(
    #     "guardrails_agent",
    #     check_scope,
    #     {
    #         "in_scope": "sql_query_generation_agent",
    #         "out_of_scope": END
    #     }
    # )
    graph.add_edge("sql_query_generation_agent", "execute_query")
    graph.add_conditional_edges(
        "execute_query",
        should_retry,
        {
            "success": "analysis_agent",
            "retry": "error",
        }
    )

    graph.add_edge("error", "execute_query")

    graph.add_edge(
        "analysis_agent",
        "decide_graph_need"
    )

    graph.add_conditional_edges(
        "decide_graph_need",
        should_generate_graph,
        {
            "visualize_agent": "visualize_agent",
            "skip_graph": END
        }
    )

    graph.add_edge("visualize_agent", END)

    return graph.compile()


def generate_graph_visualization(output_path="text2sql_workflow.png"):
    """
    Generate a PNG visualization of the LangGraph workflow.
    
    Args:
        output_path: Path where the PNG file will be saved (default: "text2sql_workflow.png")
    
    Returns:
        str: Path to the generated PNG file
    """
    try:
        # Get the graph visualization
        graph_image = create_text2sql_agent().get_graph().draw_mermaid_png()
        
        # Save to file
        with open(output_path, "wb") as f:
            f.write(graph_image)
        
        print(f"Graph visualization saved to: {output_path}")
        return output_path
        
    except Exception as e:
        print(f"Error generating graph visualization: {e}")
        print("Make sure you have 'pygraphviz' or 'grandalf' installed:")
        print("  pip install pygraphviz")
        print("  or")
        print("  pip install grandalf")
        return None
    






In [3]:
generate_graph_visualization()

Graph visualization saved to: text2sql_workflow.png


'text2sql_workflow.png'

In [2]:
from langchain_core.messages import HumanMessage


workflow = create_text2sql_agent()

output = workflow.invoke(
    {
        "question": [HumanMessage("what are the top 5 states by number of customers?")],
        "error": "",
        "iteration": 0
    }
)

print(output)

[32m2026-01-13 11:41:06.602[0m | [1mINFO    [0m | [36magents.sql_generation_agent[0m:[36msql_agent[0m:[36m59[0m - [1mreturning sql query for execution[0m
[32m2026-01-13 11:41:06.605[0m | [1mINFO    [0m | [36magents.execute_sql_query_agent[0m:[36mexecute_query[0m:[36m23[0m - [1min execute sql query node[0m
[32m2026-01-13 11:41:06.606[0m | [1mINFO    [0m | [36magents.execute_sql_query_agent[0m:[36mexecute_query[0m:[36m27[0m - [1mSELECT customer_state, COUNT(*) AS customer_count
FROM customers
GROUP BY customer_state
ORDER BY customer_count DESC
LIMIT 5;[0m
[32m2026-01-13 11:41:06.755[0m | [1mINFO    [0m | [36magents.execute_sql_query_agent[0m:[36mexecute_query[0m:[36m83[0m - [1mDone execution[0m
[32m2026-01-13 11:41:06.756[0m | [1mINFO    [0m | [36magents.execute_sql_query_agent[0m:[36mexecute_query[0m:[36m84[0m - [1m[
    {
        "customer_state": "SP",
        "customer_count": 41746
    },
    {
        "customer_state": "RJ

**************************************************
before llm call
code="fig = px.bar(df.head(5), x='customer_state', y='customer_count',\n             title='Top 5 States by Number of Customers',\n             labels={'customer_state': 'State', 'customer_count': 'Number of Customers'},\n             color='customer_count', color_continuous_scale=px.colors.sequential.Viridis)\nfig.update_layout(autosize=True, hovermode='x unified',\n                  margin=dict(l=50, r=50, t=100, b=50))\nfig.update_xaxes(title_text='State')\nfig.update_yaxes(title_text='Number of Customers')"
{'question': [HumanMessage(content='what are the top 5 states by number of customers?', additional_kwargs={}, response_metadata={})], 'sql_query': 'SELECT customer_state, COUNT(*) AS customer_count\nFROM customers\nGROUP BY customer_state\nORDER BY customer_count DESC\nLIMIT 5;', 'query_result': '[\n    {\n        "customer_state": "SP",\n        "customer_count": 41746\n    },\n    {\n        "customer_state": "

In [3]:
output["sql_query"]

'SELECT customer_state, COUNT(*) AS num_customers\nFROM customers\nGROUP BY customer_state\nORDER BY num_customers DESC\nLIMIT 5;'

In [3]:
output["query_result"]

'[\n    {\n        "customer_state": "SP",\n        "customer_count": 41746\n    },\n    {\n        "customer_state": "RJ",\n        "customer_count": 12852\n    },\n    {\n        "customer_state": "MG",\n        "customer_count": 11635\n    },\n    {\n        "customer_state": "RS",\n        "customer_count": 5466\n    },\n    {\n        "customer_state": "PR",\n        "customer_count": 5045\n    }\n]'

In [2]:
import json

x = json.dumps([{
        "customer_state": "SP",
        "customer_count": 41746
    },
    {
        "customer_state": "RJ",
        "customer_count": 12852
    },
    {
        "customer_state": "MG",
        "customer_count": 11635
    },
    {
        "customer_state": "RS",
        "customer_count": 5466
    }],)

import pandas as pd

df = pd.read_json(x)
df

  df = pd.read_json(x)


Unnamed: 0,customer_state,customer_count
0,SP,41746
1,RJ,12852
2,MG,11635
3,RS,5466


In [3]:
import plotly.express as px

fig = px.bar(df.head(5), x='customer_state', y='customer_count',
             title='Top 5 States by Number of Customers',             
             labels={'customer_state': 'State', 'customer_count': 'Number of Customers'},             
             color='customer_count', 
             color_continuous_scale=px.colors.sequential.Viridis)
fig.update_layout(autosize=True, 
                  hovermode='x unified',                  
                  margin=dict(l=50, r=50, t=100, b=50))
fig.update_xaxes(title_text='State')
fig.update_yaxes(title_text='Number of Customers')
