The code below is a getting-started walkthrough based on the following Medium blog post:

https://medium.com/towards-data-science/from-basics-to-advanced-exploring-langgraph-e8c1cf4db787

In [2]:
import os
import traceback
from pathlib import Path

import psycopg
from dotenv import load_dotenv


In [None]:
env_loaded = load_dotenv()
print(f"Env loaded: {env_loaded}")

In [36]:
# Global connection object (initialized lazily by the helper function)
_conn = None

def _get_db_connection():
    """Helper function to get (or lazily create) the shared database connection."""
    global _conn
    if _conn is None or _conn.closed:
        try:
            # Pass the parameters as keywords instead of building a URL: this avoids
            # escaping problems with special characters in the password and keeps
            # credentials out of printed output
            _conn = psycopg.connect(
                host=os.environ.get("POSTGRES_HOST"),
                port=os.environ.get("POSTGRES_PORT"),
                user=os.environ.get("POSTGRES_USER"),
                password=os.environ.get("POSTGRES_PASSWORD"),
                dbname=os.environ.get("POSTGRES_DB"),
            )
            _conn.autocommit = False  # Set autocommit to False for more control
        except Exception as error:
            print(traceback.format_exc())
            raise ConnectionError(f"Error connecting to database: {error}") from error

    return _conn

In [5]:
def get_sql_data(query):
    """
    Executes the SQL query passed as an argument and returns the data from
    the PostgreSQL database as a nicely formatted text table.

    Args:
        query (str): The SQL query to execute.

    Returns:
        str: A nicely formatted string representation of the data returned
             by the query.
             Returns "Error: No results found" if the query returns no data or
             an error message if an exception occurs.
    """
    conn = None
    try:
        conn = _get_db_connection()
        with conn.cursor() as cur:  # the cursor is closed even if execute() fails
            cur.execute(query)
            rows = cur.fetchall()
            # Get column names for headers
            column_names = [desc[0] for desc in cur.description]
        # End the transaction so the shared connection does not stay
        # "idle in transaction" (this function never commits anything)
        conn.rollback()

        if not rows:
            return "Error: No results found"

        # Format data with headers
        formatted_data = ""

        # Calculate maximum width for each column
        max_widths = [len(str(col)) for col in column_names]
        for row in rows:
            for i, value in enumerate(row):
                max_widths[i] = max(max_widths[i], len(str(value)))

        # Create header line
        header_line = "|"
        for i, col in enumerate(column_names):
            header_line += f" {col.ljust(max_widths[i])} |"
        formatted_data += header_line + "\n"

        # Create separator line
        separator_line = "|"
        for width in max_widths:
            separator_line += f"-{'-'*width}-|"
        formatted_data += separator_line + "\n"

        # Create data lines
        for row in rows:
            row_line = "|"
            for i, value in enumerate(row):
                row_line += f" {str(value).ljust(max_widths[i])} |"
            formatted_data += row_line + "\n"

        return formatted_data

    except Exception as error:
        # A failed statement aborts the transaction; roll back so later queries
        # on the shared connection keep working. The global connection itself
        # is intentionally left open.
        if conn is not None and not conn.closed:
            try:
                conn.rollback()
            except psycopg.Error:
                pass
        return f"Database returned error: {error}"
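The table formatting inside get_sql_data does not depend on the database, so it can be tried on its own. Below is a minimal sketch that pulls that logic into a hypothetical helper, format_table (not part of the original notebook), written with str.join instead of repeated string concatenation; it is intended to produce the same layout as the loops above.

```python
def format_table(column_names, rows):
    """Render rows as a pipe-delimited text table (hypothetical helper).

    Mirrors the layout built inside get_sql_data: a header line, a dashed
    separator line, then one line per row, each terminated by a newline.
    """
    # Column width = longest of the header and every value in that column
    widths = [len(str(col)) for col in column_names]
    for row in rows:
        for i, value in enumerate(row):
            widths[i] = max(widths[i], len(str(value)))

    def render(values):
        # Left-justify each cell and wrap it in "| ... |" like the original
        return "|" + "".join(f" {str(v).ljust(w)} |" for v, w in zip(values, widths))

    lines = [render(column_names)]
    lines.append("|" + "".join(f"-{'-' * w}-|" for w in widths))
    lines.extend(render(row) for row in rows)
    return "".join(line + "\n" for line in lines)


# Example with in-memory data instead of a live query
print(format_table(["id", "name"], [(1, "alpha"), (22, "b")]))
```

Keeping the formatting in a pure function like this makes it easy to test without a running PostgreSQL instance.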

In [None]:
# Example usage:
query1 = "SELECT * FROM cate_ewa.pcb_labels;"

print("Query 1 Result:\n", get_sql_data(query1))

Let’s define a single tool named execute_sql, which can execute any SQL query. We use Pydantic to specify the tool’s input schema, so the LLM agent has all the information it needs to use the tool effectively.

In [6]:
from langchain_core.tools import tool
from pydantic.v1 import BaseModel, Field
from typing import Optional


class SQLQuery(BaseModel):
    query: str = Field(description="SQL query to execute")


@tool(args_schema=SQLQuery)
def execute_sql(query: str) -> str:
    """Returns the result of SQL query execution"""
    return get_sql_data(query)

We can print the tool’s name, description, and arguments to see what information is passed to the LLM.

In [None]:
print(
    f"""
name: {execute_sql.name}
description: {execute_sql.description}
arguments: {execute_sql.args}
"""
)

Our current example is relatively straightforward, so we only need to store the message history. Let’s define the agent state.

AgentState has a single field, messages, which is a list of AnyMessage objects. We annotate it with operator.add as a reducer: each time a node returns messages, they are appended to the existing list in the state. Without this reducer, each new value would replace the previous one instead of being added to the list.

In [28]:
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated
import operator
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage


# defining agent state
class AgentState(TypedDict):
    messages: Annotated[list[AnyMessage], operator.add]
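To make the reducer behaviour concrete, here is a simplified, pure-Python model of how an update returned by a node could be merged into the state. DemoState and apply_update are hypothetical illustrations, not LangGraph's actual implementation: keys annotated with a reducer (here operator.add) are combined with the old value, while keys without one are overwritten.

```python
import operator
from typing import Annotated, TypedDict, get_type_hints


class DemoState(TypedDict):
    # Same pattern as AgentState, but with plain strings instead of messages
    messages: Annotated[list[str], operator.add]
    last_node: str  # no reducer: a new value simply overwrites the old one


def apply_update(state, update, schema=DemoState):
    """Merge a node's returned dict into the state (simplified model).

    Keys annotated with a reducer are combined via reducer(old, new);
    all other keys are overwritten.
    """
    hints = get_type_hints(schema, include_extras=True)
    new_state = dict(state)
    for key, value in update.items():
        metadata = getattr(hints[key], "__metadata__", ())
        if metadata and key in new_state:
            reducer = metadata[0]  # e.g. operator.add
            new_state[key] = reducer(new_state[key], value)
        else:
            new_state[key] = value
    return new_state


state = {"messages": ["user: hi"], "last_node": "start"}
state = apply_update(state, {"messages": ["ai: hello"], "last_node": "llm"})
print(state)  # {'messages': ['user: hi', 'ai: hello'], 'last_node': 'llm'}
```

This is why call_llm and execute_function below can return only the new messages: the reducer appends them to the history.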

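The next step is to define the agent itself. The `__init__` method takes three arguments: the model, a list of tools, and a system prompt. It builds the graph: an `llm` node that calls the model, a `function` node that executes the requested tools, a conditional edge from `llm` that routes to `function` when the last message contains tool calls (and to `END` otherwise), and an edge from `function` back to `llm`. Finally, it compiles the graph and binds the tools to the model so that the LLM knows which tools it can call.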

In [13]:
class SQLAgent:
    def __init__(self, model, tools, system_prompt=""):
        self.system_prompt = system_prompt

        # initialising graph with a state
        graph = StateGraph(AgentState)

        # adding nodes
        graph.add_node("llm", self.call_llm)
        graph.add_node("function", self.execute_function)
        graph.add_conditional_edges(
            "llm", self.exists_function_calling, {True: "function", False: END}
        )
        graph.add_edge("function", "llm")

        # setting starting point
        graph.set_entry_point("llm")

        self.graph = graph.compile()
        self.tools = {t.name: t for t in tools}
        self.model = model.bind_tools(tools)

    def exists_function_calling(self, state: AgentState):
        result = state["messages"][-1]
        return len(result.tool_calls) > 0

    def call_llm(self, state: AgentState):
        messages = state["messages"]
        if self.system_prompt:
            messages = [SystemMessage(content=self.system_prompt)] + messages
        message = self.model.invoke(messages)
        return {"messages": [message]}

    def execute_function(self, state: AgentState):
        tool_calls = state["messages"][-1].tool_calls
        results = []
        for t in tool_calls:
            print(f"Calling: {t}")
            if t["name"] not in self.tools:  # check for bad tool name from LLM
                print("\n ....bad tool name....")
                result = "bad tool name, retry"  # instruct LLM to retry if bad
            else:
                result = self.tools[t["name"]].invoke(t["args"])
            results.append(
                ToolMessage(tool_call_id=t["id"], name=t["name"], content=str(result))
            )
        print("Back to the model!")
        return {"messages": results}
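The compiled graph runs a simple loop: call the model and, while its reply contains tool calls, execute them and feed the results back. The sketch below reproduces that control flow in plain Python with a stub model and a fake tool, so it runs without an LLM or a database. StubMessage, stub_model, run_agent and fake_tools are hypothetical stand-ins, not LangGraph or LangChain APIs.

```python
from dataclasses import dataclass, field


@dataclass
class StubMessage:
    # Minimal stand-in for a chat/tool message (hypothetical, for illustration)
    role: str
    content: str
    tool_calls: list = field(default_factory=list)


def stub_model(messages):
    """Hypothetical model: request one tool call, then answer using its result."""
    if not any(m.role == "tool" for m in messages):
        call = {"name": "execute_sql", "args": {"query": "SELECT 1"}, "id": "call_1"}
        return StubMessage("ai", "", [call])
    return StubMessage("ai", f"The query returned: {messages[-1].content}")


def run_agent(model, tools, messages):
    """Plain-Python equivalent of the llm -> function -> llm loop."""
    messages = list(messages)
    while True:
        reply = model(messages)  # "llm" node
        messages.append(reply)
        if not reply.tool_calls:  # exists_function_calling() is False -> END
            return messages
        for call in reply.tool_calls:  # "function" node
            tool = tools.get(call["name"])
            result = tool(**call["args"]) if tool else "bad tool name, retry"
            messages.append(StubMessage("tool", str(result)))


fake_tools = {"execute_sql": lambda query: "1"}  # pretend database result
history = run_agent(stub_model, fake_tools, [StubMessage("user", "Run SELECT 1")])
print([m.role for m in history])  # ['user', 'ai', 'tool', 'ai']
print(history[-1].content)  # The query returned: 1
```

Expressing the same loop as a graph gives you visualisation, streaming and checkpointing for free, which the rest of this notebook relies on.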

In [None]:
import os
from langchain_openai.chat_models import AzureChatOpenAI

# system prompt
prompt = """You are a senior expert in SQL and data analysis. 
So, you can help the team to gather needed data to power their decisions. 
You are very accurate and take into account all the nuances in data.
Your goal is to provide the detailed documentation for the table in database 
that will help users."""

model = AzureChatOpenAI(
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    azure_deployment=os.getenv("AZURE_OPENAI_DEPLOYMENT"),
    api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
    temperature=0,
    max_tokens=8192,
)


In [None]:
doc_agent = SQLAgent(model, [execute_sql], system_prompt=prompt)

LangGraph provides us with quite a handy feature to visualise graphs. To use it, you need to install pygraphviz.

Refer to the [requirements.txt](../requirements.txt) for more information.

In [None]:
from IPython.display import Image

Image(doc_agent.graph.get_graph().draw_png())

Next, we ask a question that requires the agent to call the tool.

In [None]:
# Simple Question
messages = [HumanMessage(content="Can you list down all the schemas in my database ?")]

# Complicated Question
# messages = [HumanMessage(content="Can you list down all the schemas in my database ? Along with that, List down all the tables present in each of these schemas.")]

result = doc_agent.graph.invoke({"messages": messages})

Now we ask a general question that does not require a tool call.

In [20]:
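# General question that should not need a tool call.
# Stored in a separate variable so that `result` keeps the schema run inspected below.
messages = [HumanMessage(content="What is python ? Is it a popular programming language. Explain within 50 words.")]

general_result = doc_agent.graph.invoke({"messages": messages})
print(general_result["messages"][-1].content)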

The result variable contains all the messages generated during execution.

For the schema question, the messages show that the agent called the tool to fetch the schema information. For a general question, no tool call is made.

In [None]:
result["messages"]

Here’s the final result. It looks pretty decent.

In [None]:
print(result["messages"][-1].content)

## Using Prebuilt Agents

In [16]:
from langgraph.prebuilt import create_react_agent

prebuilt_doc_agent = create_react_agent(model, [execute_sql], state_modifier=prompt)

In [None]:
from IPython.display import Image

Image(prebuilt_doc_agent.get_graph().draw_png())

In [None]:
inputs = {"messages": [("user", "What columns are in airflow_test.dag_runs table?")]}

result = prebuilt_doc_agent.invoke(inputs)

In [None]:
result["messages"]

In [None]:
print(result["messages"][-1].content)

In [23]:
def print_stream(stream):
    for s in stream:
        message = s["messages"][-1]
        if isinstance(message, tuple):
            print(message)
        else:
            message.pretty_print()

In [None]:
print_stream(prebuilt_doc_agent.stream(inputs, stream_mode="values"))
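print_stream relies on stream_mode="values", where each chunk is the whole state, so s["messages"][-1] is always the newest message. The persistence cells below instead iterate event.values() and read v["messages"], which matches chunks keyed by node name that carry only that node's update. The generator below is a hypothetical, simplified model of those two chunk shapes (simulate_stream and the node names "agent" and "tools" are placeholders chosen for this sketch), not LangGraph's implementation.

```python
def simulate_stream(initial_state, steps, mode):
    """Yield chunks shaped like the two stream modes used in this notebook.

    steps: list of (node_name, update) pairs, where update is the dict a node returns.
    mode="values": yield the full state, starting with the input state.
    mode="updates": yield {node_name: update} after each node runs.
    """
    state = {"messages": list(initial_state["messages"])}
    if mode == "values":
        yield {"messages": list(state["messages"])}
    for node_name, update in steps:
        # Merge the update the same way operator.add does for AgentState
        state["messages"] = state["messages"] + update["messages"]
        if mode == "values":
            yield {"messages": list(state["messages"])}
        else:
            yield {node_name: update}


steps = [
    ("agent", {"messages": ["ai: call execute_sql"]}),
    ("tools", {"messages": ["tool: query result"]}),
    ("agent", {"messages": ["ai: final answer"]}),
]
values = list(simulate_stream({"messages": ["user: question"]}, steps, "values"))
updates = list(simulate_stream({"messages": ["user: question"]}, steps, "updates"))

# values mode: the newest message is always chunk["messages"][-1]
print([chunk["messages"][-1] for chunk in values])
# updates mode: iterate chunk.values() to reach each node's own update
print([list(chunk) for chunk in updates])  # [['agent'], ['tools'], ['agent']]
```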

## Persistence and streaming

Let us add in-memory persistence with MemorySaver so that the agent keeps context across invocations.

In [38]:
from langgraph.checkpoint.memory import MemorySaver

memory = MemorySaver()

In [39]:
prebuilt_doc_agent = create_react_agent(model, [execute_sql], checkpointer=memory)

In [None]:
thread = {"configurable": {"thread_id": "18"}}
messages = [
    HumanMessage(content="What info do we have in airflow_test.dag_runs table?")
]

for event in prebuilt_doc_agent.stream({"messages": messages}, thread):
    for v in event.values():
        v["messages"][-1].pretty_print()

Because we reuse the same thread, the agent can draw on the context of the previous invocation.

In [None]:
followup_messages = [
    HumanMessage(
        content="I would like to know the column names and types."
    )
]

for event in prebuilt_doc_agent.stream({"messages": followup_messages}, thread):
    for v in event.values():
        v["messages"][-1].pretty_print()

Let us inspect the memory object to see what has been persisted.

In [None]:
for obj in memory.list(config=None):
    print(obj)

Now let us start a new thread and check whether the agent has enough context to answer the same follow-up question. Since the new thread has no history, the agent cannot respond appropriately.

In [None]:
thread = {"configurable": {"thread_id": "20"}}

followup_messages = [
    HumanMessage(content="I would like to know the column names and types.")
]

for event in prebuilt_doc_agent.stream({"messages": followup_messages}, thread):
    for v in event.values():
        v["messages"][-1].pretty_print()
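Conceptually, the checkpointer stores conversation state per thread_id: invoking the agent with the same thread loads the earlier messages, while a new thread starts empty. ToyCheckpointer and invoke_with_memory below are hypothetical, heavily simplified stand-ins (MemorySaver actually stores full checkpoints of the graph state), meant only to show why thread 18 can answer the follow-up and thread 20 cannot.

```python
from collections import defaultdict


class ToyCheckpointer:
    """Hypothetical stand-in for a checkpointer: one message history per thread_id."""

    def __init__(self):
        self._threads = defaultdict(list)

    def history(self, config):
        return list(self._threads[config["configurable"]["thread_id"]])

    def save(self, config, messages):
        self._threads[config["configurable"]["thread_id"]] = list(messages)


def invoke_with_memory(checkpointer, config, new_messages):
    """Load the thread's history, append the new input, and persist it again."""
    messages = checkpointer.history(config) + list(new_messages)
    # A real agent would call the model here; we record a placeholder reply
    messages.append(f"ai: reply seeing {len(messages)} message(s)")
    checkpointer.save(config, messages)
    return messages


saver = ToyCheckpointer()
thread_18 = {"configurable": {"thread_id": "18"}}
thread_20 = {"configurable": {"thread_id": "20"}}
invoke_with_memory(saver, thread_18, ["user: What info do we have in airflow_test.dag_runs?"])
print(invoke_with_memory(saver, thread_18, ["user: column names and types?"])[-1])  # ai: reply seeing 3 message(s)
print(invoke_with_memory(saver, thread_20, ["user: column names and types?"])[-1])  # ai: reply seeing 1 message(s)
```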

In real-life applications, managing memory is essential. Conversations can become lengthy, and at some point it is no longer practical to pass the whole history to the LLM on every call. Therefore, it’s worth trimming or filtering messages. We won’t go deep into the specifics here, but you can find guidance in the [LangGraph documentation](https://langchain-ai.github.io/langgraph/how-tos/memory/manage-conversation-history/). Another option is to compress the conversation history by summarizing it (see this [example](https://langchain-ai.github.io/langgraph/how-tos/memory/add-summary-conversation-history/#how-to-add-summary-of-the-conversation-history)).
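As an illustration of trimming, here is a minimal sketch that keeps the system prompt plus the most recent messages. It assumes a hypothetical (role, content) tuple format and a hypothetical helper trim_history; LangChain also ships a trim_messages utility for real message objects, whose options are described in its documentation. One detail worth handling: a tool result only makes sense next to the AI message that requested it, so the sketch drops tool messages left at the start of the kept window.

```python
def trim_history(messages, max_messages):
    """Keep system messages plus the most recent max_messages other messages.

    messages is a list of (role, content) tuples (hypothetical format).
    Leading "tool" messages in the kept window are dropped because the AI
    message that requested them was trimmed away.
    """
    system = [m for m in messages if m[0] == "system"]
    others = [m for m in messages if m[0] != "system"]
    recent = others[-max_messages:] if max_messages > 0 else []
    while recent and recent[0][0] == "tool":
        recent = recent[1:]
    return system + recent


history = [
    ("system", "You are a senior expert in SQL and data analysis."),
    ("user", "List the schemas"),
    ("ai", "<tool call: execute_sql>"),
    ("tool", "<query result>"),
    ("ai", "Here are the schemas ..."),
    ("user", "And the tables?"),
]
print(trim_history(history, 3))  # system prompt + last AI answer + latest question
```

Trimming by message count is the simplest policy; counting tokens is usually closer to what the model's context limit actually constrains.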