In [1]:
from langgraph.graph import START, END, StateGraph


from State_Schema.state_schema import AgentState
from agents.sql_generation_agent import sql_agent
from agents.visualization_agent import visualization_agent
from agents.helper_functions import should_generate_graph, should_retry, check_scope
from agents.analysis_agent import analysis_agent
from agents.decide_graph_need import decide_graph_need
from agents.execute_sql_query_agent import execute_query
from agents.error_agent import error_agent
from agents.guardrails_agent import guardrails_agent


def create_text2sql_agent():
    """
    defines all nodes and edges and creates langgraph graph 
    """

    graph = StateGraph(AgentState)

    graph.add_node("guardrails_agent", guardrails_agent)
    graph.add_node("sql_query_generation_agent", sql_agent)
    graph.add_node("execute_query", execute_query)
    graph.add_node("error", error_agent)
    graph.add_node("visualize_agent", visualization_agent)
    graph.add_node("analysis_agent", analysis_agent)
    graph.add_node("decide_graph_need", decide_graph_need)


    graph.set_entry_point("sql_query_generation_agent")

    # graph.add_conditional_edges(
    #     "guardrails_agent",
    #     check_scope,
    #     {
    #         "in_scope": "sql_query_generation_agent",
    #         "out_of_scope": END
    #     }
    # )
    graph.add_edge("sql_query_generation_agent", "execute_query")
    graph.add_conditional_edges(
        "execute_query",
        should_retry,
        {
            "success": "analysis_agent",
            "retry": "error",
        }
    )

    graph.add_edge("error", "execute_query")

    graph.add_edge(
        "analysis_agent",
        "decide_graph_need"
    )

    graph.add_conditional_edges(
        "decide_graph_need",
        should_generate_graph,
        {
            "visualize_agent": "visualize_agent",
            "skip_graph": END
        }
    )

    graph.add_edge("visualize_agent", END)

    return graph.compile()


def generate_graph_visualization(output_path="text2sql_workflow.png"):
    """
    Generate a PNG visualization of the LangGraph workflow.
    
    Args:
        output_path: Path where the PNG file will be saved (default: "text2sql_workflow.png")
    
    Returns:
        str: Path to the generated PNG file
    """
    try:
        # Get the graph visualization
        graph_image = create_text2sql_agent().get_graph().draw_mermaid_png()
        
        # Save to file
        with open(output_path, "wb") as f:
            f.write(graph_image)
        
        print(f"Graph visualization saved to: {output_path}")
        return output_path
        
    except Exception as e:
        print(f"Error generating graph visualization: {e}")
        print("Make sure you have 'pygraphviz' or 'grandalf' installed:")
        print("  pip install pygraphviz")
        print("  or")
        print("  pip install grandalf")
        return None
    






In [2]:
generate_graph_visualization()

Graph visualization saved to: text2sql_workflow.png


'text2sql_workflow.png'

In [3]:
from langchain_core.messages import HumanMessage


workflow = create_text2sql_agent()

output = workflow.invoke(
    {
        "question": [HumanMessage("what are the top 5 states by number of customers?")],
        "error": "",
        "iteration": 0
    }
)

print(output)

[32m2026-01-15 10:57:16.436[0m | [1mINFO    [0m | [36magents.sql_generation_agent[0m:[36msql_agent[0m:[36m59[0m - [1mreturning sql query for execution[0m
[32m2026-01-15 10:57:16.440[0m | [1mINFO    [0m | [36magents.execute_sql_query_agent[0m:[36mexecute_query[0m:[36m23[0m - [1min execute sql query node[0m
[32m2026-01-15 10:57:16.441[0m | [1mINFO    [0m | [36magents.execute_sql_query_agent[0m:[36mexecute_query[0m:[36m27[0m - [1mSELECT customer_state, COUNT(customer_id) AS num_customers
FROM customers
GROUP BY customer_state
ORDER BY num_customers DESC
LIMIT 5;[0m
[32m2026-01-15 10:57:16.521[0m | [1mINFO    [0m | [36magents.execute_sql_query_agent[0m:[36mexecute_query[0m:[36m83[0m - [1mDone execution[0m
[32m2026-01-15 10:57:16.523[0m | [1mINFO    [0m | [36magents.execute_sql_query_agent[0m:[36mexecute_query[0m:[36m84[0m - [1m[
  {
    "customer_state": "SP",
    "num_customers": 41746
  },
  {
    "customer_state": "RJ",
    "num

**************************************************
before llm call
code="df_sorted = df.sort_values('num_customers', ascending=False).head(20)\nfig = go.Figure(data=go.Bar(\n    x=df_sorted['customer_state'],\n    y=df_sorted['num_customers'],\n    marker_color=px.colors.qualitative.Plotly\n))\nfig.update_layout(\n    title='Top States by Number of Customers',\n    xaxis_title='Customer State',\n    yaxis_title='Number of Customers',\n    hovermode='x unified',\n    autosize=True,\n    height=600,\n    width=800,\n    margin=dict(t=80, l=80, r=80, b=80)\n)"


[32m2026-01-15 10:58:19.893[0m | [1mINFO    [0m | [36magents.visualization_agent[0m:[36mvisualization_agent[0m:[36m128[0m - [1m{"data":[{"marker":{"color":["#636EFA","#EF553B","#00CC96","#AB63FA","#FFA15A","#19D3F3","#FF6692","#B6E880","#FF97FF","#FECB52"]},"x":["SP","RJ","MG","RS","PR"],"y":{"dtype":"i4","bdata":"EqMAADQyAABzLQAAWhUAALUTAAA="},"type":"bar"}],"layout":{"template":{"data":{"histogram2dcontour":[{"type":"histogram2dcontour","colorbar":{"outlinewidth":0,"ticks":""},"colorscale":[[0.0,"#0d0887"],[0.1111111111111111,"#46039f"],[0.2222222222222222,"#7201a8"],[0.3333333333333333,"#9c179e"],[0.4444444444444444,"#bd3786"],[0.5555555555555556,"#d8576b"],[0.6666666666666666,"#ed7953"],[0.7777777777777778,"#fb9f3a"],[0.8888888888888888,"#fdca26"],[1.0,"#f0f921"]]}],"choropleth":[{"type":"choropleth","colorbar":{"outlinewidth":0,"ticks":""}}],"histogram2d":[{"type":"histogram2d","colorbar":{"outlinewidth":0,"ticks":""},"colorscale":[[0.0,"#0d0887"],[0.1111111111111111,"#

{'question': [HumanMessage(content='what are the top 5 states by number of customers?', additional_kwargs={}, response_metadata={})], 'sql_query': 'SELECT customer_state, COUNT(customer_id) AS num_customers\nFROM customers\nGROUP BY customer_state\nORDER BY num_customers DESC\nLIMIT 5;', 'query_result': '[\n  {\n    "customer_state": "SP",\n    "num_customers": 41746\n  },\n  {\n    "customer_state": "RJ",\n    "num_customers": 12852\n  },\n  {\n    "customer_state": "MG",\n    "num_customers": 11635\n  },\n  {\n    "customer_state": "RS",\n    "num_customers": 5466\n  },\n  {\n    "customer_state": "PR",\n    "num_customers": 5045\n  }\n]', 'final_answer': 'The top 5 states by number of customers are:\n\n1. **SP** - 41,746 customers  \n2. **RJ** - 12,852 customers  \n3. **MG** - 11,635 customers  \n4. **RS** - 5,466 customers  \n5. **PR** - 5,045 customers  \n\nThese are two-letter state codes (e.g., SP = SÃ£o Paulo, RJ = Rio de Janeiro).', 'error': '', 'iteration': 0, 'needs_graph': 

In [4]:
import json

json_data = [
  {
    "customer_state": "SP",
    "num_customers": 41746
  },
  {
    "customer_state": "RJ",
    "num_customers": 12852
  },
  {
    "customer_state": "MG",
    "num_customers": 11635
  },
  {
    "customer_state": "RS",
    "num_customers": 5466
  },
  {
    "customer_state": "PR",
    "num_customers": 5045
  }
]

json.dumps(json_data)

'[{"customer_state": "SP", "num_customers": 41746}, {"customer_state": "RJ", "num_customers": 12852}, {"customer_state": "MG", "num_customers": 11635}, {"customer_state": "RS", "num_customers": 5466}, {"customer_state": "PR", "num_customers": 5045}]'

In [3]:
output["sql_query"]

'SELECT customer_state, COUNT(*) AS num_customers\nFROM customers\nGROUP BY customer_state\nORDER BY num_customers DESC\nLIMIT 5;'

In [None]:
output["query_result"]

'[\n    {\n        "customer_state": "SP",\n        "num_customers": 41746\n    },\n    {\n        "customer_state": "RJ",\n        "num_customers": 12852\n    },\n    {\n        "customer_state": "MG",\n        "num_customers": 11635\n    },\n    {\n        "customer_state": "RS",\n        "num_customers": 5466\n    },\n    {\n        "customer_state": "PR",\n        "num_customers": 5045\n    }\n]'

In [7]:
import json

x = json.dumps(output["query_result"],)
# x

import pandas as pd

df = pd.read_json(x)
df


Passing literal json to 'read_json' is deprecated and will be removed in a future version. To read from a literal string, wrap it in a 'StringIO' object.



ValueError: DataFrame constructor not properly called!

In [9]:
import plotly.express as px
import plotly.graph_objects as go

df_top5 = df.nlargest(5, 'customer_count')
fig = px.bar(df_top5, x='customer_state', y='customer_count',
              title='Top 5 States by Number of Customers',
              labels={'customer_state': 'State', 'customer_count': 'Number of Customers'},
              color='customer_count',
              color_continuous_scale=px.colors.sequential.Viridis)
fig.update_traces(hoverinfo='x+y', marker=dict(line=dict(width=1, color='DarkSlateGrey')))
fig.update_layout(title_x=0.5, xaxis_title='State', yaxis_title='Number of Customers',
                  template='plotly_white', autosize=True)

KeyError: 'customer_count'

In [8]:
import plotly.express as px
import plotly.graph_objects as go

df_sorted = df.sort_values('num_customers', ascending=False).head(20)
fig = go.Figure(data=go.Bar(
    x=df_sorted['customer_state'],
    y=df_sorted['num_customers'],
    marker_color=px.colors.qualitative.Plotly
))
fig.update_layout(
    title='Top States by Number of Customers',
    xaxis_title='Customer State',
    yaxis_title='Number of Customers',
    hovermode='x unified',
    autosize=True,
    height=600,
    width=800,
    margin=dict(t=80, l=80, r=80, b=80)
)


KeyError: 'num_customers'