# How to create a custom checkpointer using Postgres

When creating LangGraph agents, you can also set them up so that they persist their state. This allows you to do things like interact with an agent multiple times and have it remember previous interactions.

This example shows how to use `Postgres` as the backend for persisting checkpoint state.

NOTE: this is just an example implementation. You can implement your own checkpointer using a different database or modify this one as long as it conforms to the `BaseCheckpointSaver` interface.

## Checkpointer implementation

In [1]:
# %%capture --no-stderr
# %pip install -U psycopg psycopg-pool langgraph

In [2]:
import os
import logging
from dotenv import load_dotenv

from langgraph.checkpoint.sqlite import SqliteSaver
from IPython.display import Image, display

In [3]:
# Show INFO-level log records (e.g. the httpx request lines in the outputs below).
logging.basicConfig(level=logging.INFO)
# Module logger (not used by the cells shown in this notebook).
log = logging.getLogger(__name__)

# Load variables from a local .env file into the environment; DB_URI is read
# from the environment further down.
load_dotenv()

True

In [4]:
"""Implementation of a langgraph checkpoint saver using Postgres."""
from contextlib import asynccontextmanager, contextmanager
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterator,
    Generator,
    Optional,
    Union,
    Tuple,
    List,
    Sequence,
)

import psycopg
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.serde.jsonplus import JsonPlusSerializer
from langgraph.checkpoint.base import Checkpoint, CheckpointMetadata, CheckpointTuple
from psycopg_pool import AsyncConnectionPool, ConnectionPool


class JsonAndBinarySerializer(JsonPlusSerializer):
    """JSON serializer that passes raw binary payloads through untouched.

    ``dumps`` returns a ``(type_tag, payload)`` pair. Top-level ``bytes`` and
    ``bytearray`` values are tagged ``"bytes"`` / ``"bytearray"`` and kept as
    is. Everything else is tagged ``"json"`` and encoded by
    ``JsonPlusSerializer``. ``loads`` reverses the mapping and raises
    ``NotImplementedError`` for an unknown tag.

    NOTE(review): ``PostgresSaver`` below is constructed with a plain
    ``JsonPlusSerializer``, so this class is not used by the notebook as shown.
    """

    def _default(self, obj):
        # Binary values nested inside a JSON document are emitted as a hex
        # string plus a ``fromhex`` constructor hint.
        if not isinstance(obj, (bytes, bytearray)):
            return super()._default(obj)
        return self._encode_constructor_args(
            obj.__class__, method="fromhex", args=[obj.hex()]
        )

    def dumps(self, obj: Any) -> tuple[str, bytes]:
        for binary_type, tag in ((bytes, "bytes"), (bytearray, "bytearray")):
            if isinstance(obj, binary_type):
                return tag, obj
        return "json", super().dumps(obj)

    def loads(self, s: tuple[str, bytes]) -> Any:
        tag = s[0]
        if tag == "bytes":
            return s[1]
        if tag == "bytearray":
            return bytearray(s[1])
        if tag == "json":
            return super().loads(s[1])
        raise NotImplementedError(f"Unknown serialization type: {tag}")


@contextmanager
def _get_sync_connection(
    connection: Union[psycopg.Connection, ConnectionPool, None],
) -> Generator[psycopg.Connection, None, None]:
    """Yield a usable synchronous connection derived from *connection*.

    A ``psycopg.Connection`` is yielded as is; the caller keeps ownership.
    A ``ConnectionPool`` lends one of its connections for the duration of the
    ``with`` block. Any other value (including ``None``) raises ``ValueError``.
    """
    if isinstance(connection, ConnectionPool):
        with connection.connection() as pooled:
            yield pooled
        return
    if not isinstance(connection, psycopg.Connection):
        raise ValueError(
            "Invalid sync connection object. Please initialize the check pointer "
            "with an appropriate sync connection object. "
            f"Got {type(connection)}."
        )
    yield connection


@asynccontextmanager
async def _get_async_connection(
    connection: Union[psycopg.AsyncConnection, AsyncConnectionPool, None],
) -> AsyncGenerator[psycopg.AsyncConnection, None]:
    """Yield a usable asynchronous connection derived from *connection*.

    A ``psycopg.AsyncConnection`` is yielded as is; the caller keeps
    ownership. An ``AsyncConnectionPool`` lends one of its connections for the
    duration of the ``async with`` block. Any other value (including ``None``)
    raises ``ValueError``.
    """
    if isinstance(connection, AsyncConnectionPool):
        async with connection.connection() as pooled:
            yield pooled
        return
    if not isinstance(connection, psycopg.AsyncConnection):
        raise ValueError(
            "Invalid async connection object. Please initialize the check pointer "
            "with an appropriate async connection object. "
            f"Got {type(connection)}."
        )
    yield connection


class PostgresSaver(BaseCheckpointSaver):
    """LangGraph checkpoint saver backed by Postgres (psycopg 3).

    Checkpoints are stored in a ``checkpoints`` table keyed by
    ``(thread_id, thread_ts)``. Pending (intermediate) writes are stored in a
    ``writes`` table keyed by ``(thread_id, thread_ts, task_id, idx)``. Create
    the schema once with :meth:`create_tables` or :meth:`acreate_tables`.

    Which connection each method uses:

    * The sync API (``put``, ``put_writes``, ``get_tuple``, ``list``) uses
      ``sync_connection``.
    * The async API (``aput``, ``aput_writes``, ``aget_tuple``, ``alist``)
      uses ``async_connection``.

    Calling a method whose connection was not provided raises ``ValueError``.

    Thread ids are converted with ``str()`` before being bound as query
    parameters, matching the ``TEXT`` column type.
    """

    sync_connection: Optional[Union[psycopg.Connection, ConnectionPool]] = None
    """The synchronous connection or pool to the Postgres database.

    If providing a connection object, please ensure that the connection is open
    and remember to close the connection when done.
    """
    async_connection: Optional[
        Union[psycopg.AsyncConnection, AsyncConnectionPool]
    ] = None
    """The asynchronous connection or pool to the Postgres database.

    If providing a connection object, please ensure that the connection is open
    and remember to close the connection when done.
    """

    def __init__(
        self,
        sync_connection: Optional[Union[psycopg.Connection, ConnectionPool]] = None,
        async_connection: Optional[
            Union[psycopg.AsyncConnection, AsyncConnectionPool]
        ] = None,
    ):
        """Create a saver.

        Args:
            sync_connection: Connection or pool used by the sync methods.
            async_connection: Connection or pool used by the async methods.
        """
        super().__init__(serde=JsonPlusSerializer())
        self.sync_connection = sync_connection
        self.async_connection = async_connection

    @contextmanager
    def _get_sync_connection(self) -> Generator[psycopg.Connection, None, None]:
        """Get the connection to the Postgres database."""
        # Resolves to the module-level helper: class scope is not visible here.
        with _get_sync_connection(self.sync_connection) as connection:
            yield connection

    @asynccontextmanager
    async def _get_async_connection(
        self,
    ) -> AsyncGenerator[psycopg.AsyncConnection, None]:
        """Get the connection to the Postgres database."""
        async with _get_async_connection(self.async_connection) as connection:
            yield connection

    CREATE_TABLES_QUERY = """
    CREATE TABLE IF NOT EXISTS checkpoints (
        thread_id TEXT NOT NULL,
        thread_ts TEXT NOT NULL,
        parent_ts TEXT,
        checkpoint BYTEA NOT NULL,
        metadata BYTEA NOT NULL,
        PRIMARY KEY (thread_id, thread_ts)
    );
    CREATE TABLE IF NOT EXISTS writes (
        thread_id TEXT NOT NULL,
        thread_ts TEXT NOT NULL,
        task_id TEXT NOT NULL,
        idx INTEGER NOT NULL,
        channel TEXT NOT NULL,
        value BYTEA,
        PRIMARY KEY (thread_id, thread_ts, task_id, idx)
    );
    """

    DROP_TABLES_QUERY = "DROP TABLE IF EXISTS checkpoints, writes;"

    @staticmethod
    def create_tables(connection: Union[psycopg.Connection, ConnectionPool], /) -> None:
        """Create the schema for the checkpoint saver (idempotent)."""
        with _get_sync_connection(connection) as conn:
            with conn.cursor() as cur:
                cur.execute(PostgresSaver.CREATE_TABLES_QUERY)

    @staticmethod
    async def acreate_tables(
        connection: Union[psycopg.AsyncConnection, AsyncConnectionPool], /
    ) -> None:
        """Create the schema for the checkpoint saver (idempotent)."""
        async with _get_async_connection(connection) as conn:
            async with conn.cursor() as cur:
                await cur.execute(PostgresSaver.CREATE_TABLES_QUERY)

    @staticmethod
    def drop_tables(
        connection: Union[psycopg.Connection, ConnectionPool], /
    ) -> None:
        """Drop the checkpoint saver tables.

        Accepts a connection or a pool, like :meth:`create_tables`.
        """
        with _get_sync_connection(connection) as conn:
            with conn.cursor() as cur:
                cur.execute(PostgresSaver.DROP_TABLES_QUERY)

    @staticmethod
    async def adrop_tables(
        connection: Union[psycopg.AsyncConnection, AsyncConnectionPool], /
    ) -> None:
        """Drop the checkpoint saver tables.

        Accepts a connection or a pool, like :meth:`acreate_tables`.
        """
        async with _get_async_connection(connection) as conn:
            async with conn.cursor() as cur:
                await cur.execute(PostgresSaver.DROP_TABLES_QUERY)

    UPSERT_CHECKPOINT_QUERY = """
    INSERT INTO checkpoints
        (thread_id, thread_ts, parent_ts, checkpoint, metadata)
    VALUES
        (%s, %s, %s, %s, %s)
    ON CONFLICT (thread_id, thread_ts)
    DO UPDATE SET checkpoint = EXCLUDED.checkpoint,
                  metadata = EXCLUDED.metadata;
    """

    def _checkpoint_params(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
    ) -> Tuple[Any, ...]:
        """Build the bind parameters for ``UPSERT_CHECKPOINT_QUERY``."""
        configurable = config["configurable"]
        return (
            str(configurable["thread_id"]),
            checkpoint["id"],
            # The incoming config's thread_ts (if any) is the checkpoint this
            # one derives from; store NULL when there is none.
            configurable.get("thread_ts") or None,
            self.serde.dumps(checkpoint),
            self.serde.dumps(metadata),
        )

    @staticmethod
    def _saved_config(config: RunnableConfig, checkpoint: Checkpoint) -> RunnableConfig:
        """Return the config that addresses the checkpoint just stored."""
        return {
            "configurable": {
                "thread_id": config["configurable"]["thread_id"],
                "thread_ts": checkpoint["id"],
            },
        }

    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
    ) -> RunnableConfig:
        """Store a checkpoint and return the config that addresses it.

        Args:
            config: The configuration for the checkpoint.
                A dict with a `configurable` key which is a dict with
                a `thread_id` key and an optional `thread_ts` key. When
                present, `thread_ts` is recorded as the new checkpoint's parent.
                For example, { 'configurable': { 'thread_id': 'test_thread' } }
            checkpoint: The checkpoint to persist; its `id` becomes the stored
                `thread_ts`.
            metadata: Metadata stored alongside the checkpoint.

        Returns:
            The RunnableConfig that describes the checkpoint that was just created.
            It'll contain the `thread_id` and `thread_ts` of the checkpoint.
        """
        # Serialize before borrowing a connection so it is held only for the query.
        params = self._checkpoint_params(config, checkpoint, metadata)
        with self._get_sync_connection() as conn:
            with conn.cursor() as cur:
                cur.execute(self.UPSERT_CHECKPOINT_QUERY, params)
        return self._saved_config(config, checkpoint)

    async def aput(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
    ) -> RunnableConfig:
        """Async version of :meth:`put`; same arguments and return value."""
        params = self._checkpoint_params(config, checkpoint, metadata)
        async with self._get_async_connection() as conn:
            async with conn.cursor() as cur:
                await cur.execute(self.UPSERT_CHECKPOINT_QUERY, params)
        return self._saved_config(config, checkpoint)

    UPSERT_WRITES_QUERY = """
    INSERT INTO writes
        (thread_id, thread_ts, task_id, idx, channel, value)
    VALUES
        (%s, %s, %s, %s, %s, %s)
    ON CONFLICT (thread_id, thread_ts, task_id, idx)
    DO UPDATE SET value = EXCLUDED.value;
    """

    def _write_rows(
        self,
        config: RunnableConfig,
        writes: Sequence[Tuple[str, Any]],
        task_id: str,
    ) -> List[Tuple[Any, ...]]:
        """Build one ``UPSERT_WRITES_QUERY`` parameter row per ``(channel, value)``."""
        if not writes:
            return []
        thread_id = str(config["configurable"]["thread_id"])
        thread_ts = str(config["configurable"]["thread_ts"])
        return [
            (thread_id, thread_ts, task_id, idx, channel, self.serde.dumps(value))
            for idx, (channel, value) in enumerate(writes)
        ]

    def put_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[Tuple[str, Any]],
        task_id: str,
    ) -> None:
        """Store intermediate writes for the checkpoint addressed by *config*.

        Args:
            config: Must contain `thread_id` and `thread_ts` under `configurable`.
            writes: `(channel, value)` pairs; their position becomes `idx`.
            task_id: Identifier of the task that produced the writes.
        """
        rows = self._write_rows(config, writes, task_id)
        if not rows:
            return  # nothing to store; skip the database round trip
        # NOTE(review): no explicit commit, same as `put`. This presumably
        # relies on the pool's connection context (or the caller's own
        # connection) to end the transaction; confirm for non-autocommit
        # raw connections.
        with self._get_sync_connection() as conn:
            with conn.cursor() as cur:
                cur.executemany(self.UPSERT_WRITES_QUERY, rows)

    async def aput_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[Tuple[str, Any]],
        task_id: str,
    ) -> None:
        """Async version of :meth:`put_writes`."""
        rows = self._write_rows(config, writes, task_id)
        if not rows:
            return
        async with self._get_async_connection() as conn:
            async with conn.cursor() as cur:
                await cur.executemany(self.UPSERT_WRITES_QUERY, rows)

    # thread_id is the trailing column so the listing also works across threads.
    LIST_CHECKPOINTS_QUERY_STR = """
    SELECT checkpoint, metadata, thread_ts, parent_ts, thread_id
    FROM checkpoints
    {where}
    ORDER BY thread_ts DESC
    """

    def list(
        self,
        config: Optional[RunnableConfig],
        *,
        filter: Optional[dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None,
    ) -> Generator[CheckpointTuple, None, None]:
        """Yield checkpoints, newest `thread_ts` first.

        Args:
            config: Restrict results to this config's `thread_id`; `None`
                lists the checkpoints of every thread.
            filter: Metadata filter. Not supported; a truthy value raises
                `NotImplementedError`.
            before: Only return checkpoints whose `thread_ts` sorts before
                this config's `thread_ts`.
            limit: Maximum number of checkpoints to return; falsy means no limit.
        """
        query, args = self._list_query(config, filter, before, limit)
        with self._get_sync_connection() as conn:
            with conn.cursor() as cur:
                cur.execute(query, args)
                for row in cur:
                    yield self._list_row_to_tuple(row)

    async def alist(
        self,
        config: Optional[RunnableConfig],
        *,
        filter: Optional[dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None,
    ) -> AsyncIterator[CheckpointTuple]:
        """Async version of :meth:`list`; same arguments and ordering."""
        query, args = self._list_query(config, filter, before, limit)
        async with self._get_async_connection() as conn:
            async with conn.cursor() as cur:
                await cur.execute(query, args)
                async for row in cur:
                    yield self._list_row_to_tuple(row)

    def _list_query(
        self,
        config: Optional[RunnableConfig],
        filter: Optional[dict[str, Any]],
        before: Optional[RunnableConfig],
        limit: Optional[int],
    ) -> Tuple[str, Tuple[Any, ...]]:
        """Build the SQL text and bind parameters for :meth:`list` / :meth:`alist`."""
        where, args = self._search_where(config, filter, before)
        query = self.LIST_CHECKPOINTS_QUERY_STR.format(where=where)
        if limit:
            # Bind LIMIT as a parameter instead of formatting it into the SQL.
            query += " LIMIT %s"
            args.append(int(limit))
        return query, tuple(args)

    def _list_row_to_tuple(self, row: Sequence[Any]) -> CheckpointTuple:
        """Convert a ``LIST_CHECKPOINTS_QUERY_STR`` row into a ``CheckpointTuple``."""
        checkpoint, metadata, thread_ts, parent_ts, thread_id = row
        return CheckpointTuple(
            config={
                "configurable": {
                    "thread_id": thread_id,
                    "thread_ts": thread_ts,
                }
            },
            checkpoint=self.serde.loads(checkpoint),
            metadata=self.serde.loads(metadata),
            # The parent is the checkpoint recorded in parent_ts.
            parent_config=(
                {
                    "configurable": {
                        "thread_id": thread_id,
                        "thread_ts": parent_ts,
                    }
                }
                if parent_ts
                else None
            ),
        )

    GET_CHECKPOINT_BY_TS_QUERY = """
    SELECT checkpoint, metadata, thread_ts, parent_ts
    FROM checkpoints
    WHERE thread_id = %(thread_id)s AND thread_ts = %(thread_ts)s
    """

    GET_CHECKPOINT_QUERY = """
    SELECT checkpoint, metadata, thread_ts, parent_ts
    FROM checkpoints
    WHERE thread_id = %(thread_id)s
    ORDER BY thread_ts DESC LIMIT 1
    """

    # Ordered so pending writes come back deterministically, in write order per task.
    GET_PENDING_WRITES_QUERY = """
    SELECT task_id, channel, value
    FROM writes
    WHERE thread_id = %(thread_id)s AND thread_ts = %(thread_ts)s
    ORDER BY task_id, idx
    """

    def _get_checkpoint_query(
        self, config: RunnableConfig
    ) -> Tuple[str, dict[str, Any]]:
        """Choose the lookup query: by exact `thread_ts` if given, else the latest."""
        configurable = config["configurable"]
        params: dict[str, Any] = {"thread_id": str(configurable["thread_id"])}
        thread_ts = configurable.get("thread_ts")
        if thread_ts:
            params["thread_ts"] = thread_ts
            return self.GET_CHECKPOINT_BY_TS_QUERY, params
        return self.GET_CHECKPOINT_QUERY, params

    def _load_checkpoint_tuple(
        self,
        config: RunnableConfig,
        row: Sequence[Any],
        write_rows: Sequence[Sequence[Any]],
    ) -> CheckpointTuple:
        """Deserialize a checkpoint row and its pending-write rows."""
        checkpoint, metadata, thread_ts, parent_ts = row
        thread_id = config["configurable"]["thread_id"]
        if not config["configurable"].get("thread_ts"):
            # The latest checkpoint was requested: report which one was found.
            config = {
                "configurable": {
                    "thread_id": thread_id,
                    "thread_ts": thread_ts,
                }
            }
        return CheckpointTuple(
            config=config,
            checkpoint=self.serde.loads(checkpoint),
            metadata=self.serde.loads(metadata),
            parent_config=(
                {
                    "configurable": {
                        "thread_id": thread_id,
                        "thread_ts": parent_ts,
                    }
                }
                if parent_ts
                else None
            ),
            pending_writes=[
                (task_id, channel, self.serde.loads(value))
                for task_id, channel, value in write_rows
            ],
        )

    def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """Get the checkpoint tuple for the given configuration.

        Args:
            config: The configuration for the checkpoint.
                A dict with a `configurable` key which is a dict with
                a `thread_id` key and an optional `thread_ts` key.
                For example, { 'configurable': { 'thread_id': 'test_thread' } }

        Returns:
            The checkpoint tuple for the given configuration if it exists,
            otherwise None.
            If thread_ts is None, the latest checkpoint is returned if it exists.
        """
        query, params = self._get_checkpoint_query(config)
        with self._get_sync_connection() as conn:
            with conn.cursor() as cur:
                cur.execute(query, params)
                row = cur.fetchone()
                if row is None:
                    return None
                # row[2] is the thread_ts of the checkpoint actually found.
                cur.execute(
                    self.GET_PENDING_WRITES_QUERY,
                    {"thread_id": params["thread_id"], "thread_ts": row[2]},
                )
                return self._load_checkpoint_tuple(config, row, cur.fetchall())

    async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """Async version of :meth:`get_tuple`; same arguments and return value."""
        query, params = self._get_checkpoint_query(config)
        async with self._get_async_connection() as conn:
            async with conn.cursor() as cur:
                await cur.execute(query, params)
                row = await cur.fetchone()
                if row is None:
                    return None
                await cur.execute(
                    self.GET_PENDING_WRITES_QUERY,
                    {"thread_id": params["thread_id"], "thread_ts": row[2]},
                )
                return self._load_checkpoint_tuple(
                    config, row, await cur.fetchall()
                )

    def _search_where(
        self,
        config: Optional[RunnableConfig],
        filter: Optional[dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
    ) -> Tuple[str, List[Any]]:
        """Return WHERE clause predicates for given config, filter, and before parameters.

        Args:
            config: If given, restrict to this config's `thread_id`.
            filter: Metadata filter; not supported because metadata is stored
                serialized. A truthy value raises `NotImplementedError`.
            before: If given, restrict to `thread_ts` values sorting before
                this config's `thread_ts`.

        Returns:
            A tuple of the WHERE clause ("" when there are no predicates) and
            the list of positional parameter values.
        """
        wheres: List[str] = []
        param_values: List[Any] = []

        if config is not None:
            wheres.append("thread_id = %s")
            param_values.append(str(config["configurable"]["thread_id"]))

        if filter:
            raise NotImplementedError(
                "Filtering by metadata is not supported by PostgresSaver."
            )

        if before is not None:
            wheres.append("thread_ts < %s")
            param_values.append(before["configurable"]["thread_ts"])

        where_clause = ("WHERE " + " AND ".join(wheres)) if wheres else ""
        return where_clause, param_values

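## Setup environment

Environment variables are loaded from a `.env` file by `load_dotenv()` in the setup cell at the top. The examples below read the Postgres connection string from `DB_URI`, and `ChatOpenAI` needs an OpenAI API key.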

## Setup model and tools for the graph

In [5]:
from typing import Literal
from langchain_core.runnables import ConfigurableField
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent


@tool
def get_weather(city: Literal["nyc", "sf"]):
    """Use this to get weather information."""
    # The docstring above is the tool description shown to the model; keep it as is.
    # Canned answers for the two cities allowed by the `city` annotation.
    for known_city, forecast in (
        ("nyc", "It might be cloudy in nyc"),
        ("sf", "It's always sunny in sf"),
    ):
        if city == known_city:
            return forecast
    raise AssertionError("Unknown city")


# Tools exposed to the ReAct agent.
tools = [get_weather]
# Chat model driving the agent; temperature=0 for (near-)deterministic replies.
model = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)

## Use a sync connection

In [6]:
# Postgres connection string read from the environment (loaded from .env by
# load_dotenv earlier), e.g.:
# DB_URI = "postgresql://postgres:postgres@localhost:5432/postgres?sslmode=disable"
DB_URI = os.getenv("DB_URI")
if not DB_URI:
    # Fail fast: otherwise the pools/connections below would be handed None.
    raise RuntimeError(
        "Set the DB_URI environment variable to a Postgres connection string."
    )

### With a connection pool

In [7]:
from psycopg_pool import ConnectionPool

# Sync connection pool shared by the checkpointer. open=True opens it now,
# which the constructor already did implicitly. Passing it explicitly avoids
# psycopg_pool's warning that this default is going to change.
pool = ConnectionPool(
    # Example configuration
    conninfo=DB_URI,
    max_size=20,
    open=True,
)

checkpointer = PostgresSaver(sync_connection=pool)
# Safe to re-run: the schema uses CREATE TABLE IF NOT EXISTS.
checkpointer.create_tables(pool)

In [8]:
# Prebuilt ReAct agent whose state is persisted through the Postgres checkpointer.
graph = create_react_agent(model, tools=tools, checkpointer=checkpointer)
# thread_id selects which conversation's checkpoints are read and written.
config = {"configurable": {"thread_id": "1"}}
res = graph.invoke({"messages": [("human", "what's the weather in sf")]}, config)

INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


In [8]:
# Final graph state: the message history of the run.
res

{'messages': [HumanMessage(content="what's the weather in sf", id='6e5b6f06-8f9d-453a-a142-975b82f027d3'),
  AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_QSYPVNDFQlwjaQZVoneY7GpH', 'function': {'arguments': '{"city":"sf"}', 'name': 'get_weather'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 14, 'prompt_tokens': 57, 'total_tokens': 71}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_8b761cb050', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-94affcb2-d3a1-4db8-b6fc-d6a9a9f102a3-0', tool_calls=[{'name': 'get_weather', 'args': {'city': 'sf'}, 'id': 'call_QSYPVNDFQlwjaQZVoneY7GpH', 'type': 'tool_call'}], usage_metadata={'input_tokens': 57, 'output_tokens': 14, 'total_tokens': 71}),
  ToolMessage(content="It's always sunny in sf", name='get_weather', id='8204ac53-c99f-48ba-85ac-4c051027caef', tool_call_id='call_QSYPVNDFQlwjaQZVoneY7GpH'),
  AIMessage(content='The weather in San Francisco is always sunny

In [9]:
# Latest checkpoint stored for thread "1" (config carries no thread_ts).
checkpointer.get(config)

{'v': 1,
 'ts': '2024-07-21T04:10:21.713560+00:00',
 'id': '1ef47172-5e19-671c-8003-f74751dfd697',
 'channel_values': {'messages': [HumanMessage(content="what's the weather in sf", id='6e5b6f06-8f9d-453a-a142-975b82f027d3'),
   AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_QSYPVNDFQlwjaQZVoneY7GpH', 'function': {'arguments': '{"city":"sf"}', 'name': 'get_weather'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 14, 'prompt_tokens': 57, 'total_tokens': 71}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_8b761cb050', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-94affcb2-d3a1-4db8-b6fc-d6a9a9f102a3-0', tool_calls=[{'name': 'get_weather', 'args': {'city': 'sf'}, 'id': 'call_QSYPVNDFQlwjaQZVoneY7GpH', 'type': 'tool_call'}], usage_metadata={'input_tokens': 57, 'output_tokens': 14, 'total_tokens': 71}),
   ToolMessage(content="It's always sunny in sf", name='get_weather', id='8204ac53-c99f-48ba-85ac-4c05102

### With a connection

In [10]:
from psycopg import Connection

# Same flow with a single caller-owned connection instead of a pool.
# Everything that touches the database stays inside the `with` block.
with Connection.connect(DB_URI) as conn:
    checkpointer = PostgresSaver(sync_connection=conn)

    graph = create_react_agent(model, tools=tools, checkpointer=checkpointer)
    config = {"configurable": {"thread_id": "2"}}
    res = graph.invoke({"messages": [("human", "what's the weather in nyc")]}, config)

    # Read back the latest checkpoint (and its pending writes) while the
    # connection is still open.
    checkpoint_tuple = checkpointer.get_tuple(config)

INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


In [11]:
# Inspect the tuple fetched inside the connection block.
checkpoint_tuple

CheckpointTuple(config={'configurable': {'thread_id': '2', 'thread_ts': '1ef47136-3e5b-6a56-8003-49dc57315575'}}, checkpoint={'v': 1, 'ts': '2024-07-21T03:43:27.772508+00:00', 'id': '1ef47136-3e5b-6a56-8003-49dc57315575', 'channel_values': {'messages': [HumanMessage(content="what's the weather in nyc", id='6c4cb097-3f9f-49a8-8267-bde8995da194'), AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_uXeGd9iy1ZGnU7b61eSDKmUG', 'function': {'arguments': '{"city":"nyc"}', 'name': 'get_weather'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 15, 'prompt_tokens': 58, 'total_tokens': 73}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_611b667b19', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-6f9132f3-0fd0-4d06-bd75-d4dcd8811de1-0', tool_calls=[{'name': 'get_weather', 'args': {'city': 'nyc'}, 'id': 'call_uXeGd9iy1ZGnU7b61eSDKmUG', 'type': 'tool_call'}], usage_metadata={'input_tokens': 58, 'output_tokens': 15, 'tota

## Use an async connection

### With a connection pool

In [19]:
from psycopg_pool import AsyncConnectionPool

pool = AsyncConnectionPool(
    # Example configuration
    conninfo=DB_URI,
    max_size=20,
)


checkpointer = PostgresSaver(async_connection=pool)
await checkpointer.acreate_tables(pool)



In [20]:
# Async run on thread "4"; state is persisted through the async pool.
graph = create_react_agent(model, tools=tools, checkpointer=checkpointer)
config = {"configurable": {"thread_id": "4"}}
res = await graph.ainvoke(
    {"messages": [("human", "what's the weather in nyc")]}, config
)

ERROR:asyncio:Task was destroyed but it is pending!
task: <Task pending name='pool-7-worker-0' coro=<AsyncConnectionPool.worker() done, defined at /Users/may/.virtualenvs/generative_ai/lib/python3.11/site-packages/psycopg_pool/pool_async.py:613> wait_for=<Future cancelled>>
ERROR:asyncio:Task was destroyed but it is pending!
task: <Task pending name='pool-7-worker-1' coro=<AsyncConnectionPool.worker() done, defined at /Users/may/.virtualenvs/generative_ai/lib/python3.11/site-packages/psycopg_pool/pool_async.py:613> wait_for=<Future cancelled>>
ERROR:asyncio:Task was destroyed but it is pending!
task: <Task pending name='pool-7-worker-2' coro=<AsyncConnectionPool.worker() done, defined at /Users/may/.virtualenvs/generative_ai/lib/python3.11/site-packages/psycopg_pool/pool_async.py:613> wait_for=<Future cancelled>>
ERROR:asyncio:Task was destroyed but it is pending!
task: <Task pending name='pool-8-worker-0' coro=<AsyncConnectionPool.worker() done, defined at /Users/may/.virtualenvs/gene

In [21]:
# Async read of the latest checkpoint for thread "4".
checkpoint_tuple = await checkpointer.aget_tuple(config)

In [22]:
# Inspect the tuple returned by aget_tuple.
checkpoint_tuple

CheckpointTuple(config={'configurable': {'thread_id': '4', 'thread_ts': '1ef4713a-b1f9-6c84-8003-32b5bbfe69af'}}, checkpoint={'v': 1, 'ts': '2024-07-21T03:45:27.270099+00:00', 'id': '1ef4713a-b1f9-6c84-8003-32b5bbfe69af', 'channel_values': {'messages': [HumanMessage(content="what's the weather in nyc", id='b793e62b-2c48-4554-93ed-e141521d0a55'), AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_pIGkjx0ClYeA4zs6JVxbcReG', 'function': {'arguments': '{"city":"nyc"}', 'name': 'get_weather'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 15, 'prompt_tokens': 58, 'total_tokens': 73}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_661538dc1f', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-dc668344-8de9-4fd8-911f-88334cc7160f-0', tool_calls=[{'name': 'get_weather', 'args': {'city': 'nyc'}, 'id': 'call_pIGkjx0ClYeA4zs6JVxbcReG', 'type': 'tool_call'}], usage_metadata={'input_tokens': 58, 'output_tokens': 15, 'tota

### With a connection

In [51]:
from psycopg import AsyncConnection

# Same async flow with a single caller-owned AsyncConnection.
async with await AsyncConnection.connect(DB_URI) as conn:
    checkpointer = PostgresSaver(async_connection=conn)
    graph = create_react_agent(model, tools=tools, checkpointer=checkpointer)
    config = {"configurable": {"thread_id": "5"}}
    res = await graph.ainvoke(
        {"messages": [("human", "what's the weather in nyc")]}, config
    )
    # Collect every checkpoint of thread "5" (ordered newest thread_ts first).
    checkpoint_tuples = [c async for c in checkpointer.alist(config)]

INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


In [52]:
# All checkpoints listed for thread "5".
checkpoint_tuples

[CheckpointTuple(config={'configurable': {'thread_id': '5', 'thread_ts': '1ef470c1-1b99-6268-8003-e6d0fd793910'}}, checkpoint={'v': 1, 'ts': '2024-07-21T02:51:03.432846+00:00', 'id': '1ef470c1-1b99-6268-8003-e6d0fd793910', 'channel_values': {'messages': [HumanMessage(content="what's the weather in nyc", id='0c9ec6de-a0b7-4b3d-9643-2d2aa8ae5b3e'), AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_2iGrvVXTbXRU5VpYuSjub3Ku', 'function': {'arguments': '{"city":"nyc"}', 'name': 'get_weather'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 15, 'prompt_tokens': 58, 'total_tokens': 73}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_661538dc1f', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-047bcf5a-0d57-42de-8ad1-21f4a0cbac17-0', tool_calls=[{'name': 'get_weather', 'args': {'city': 'nyc'}, 'id': 'call_2iGrvVXTbXRU5VpYuSjub3Ku', 'type': 'tool_call'}], usage_metadata={'input_tokens': 58, 'output_tokens': 15, 'tot

In [55]:
# How many checkpoints the run produced, then each one separated by a blank line.
print(len(checkpoint_tuples))
for checkpoint_tuple in checkpoint_tuples:
    print(checkpoint_tuple, end="\n\n")

5
CheckpointTuple(config={'configurable': {'thread_id': '5', 'thread_ts': '1ef470c1-1b99-6268-8003-e6d0fd793910'}}, checkpoint={'v': 1, 'ts': '2024-07-21T02:51:03.432846+00:00', 'id': '1ef470c1-1b99-6268-8003-e6d0fd793910', 'channel_values': {'messages': [HumanMessage(content="what's the weather in nyc", id='0c9ec6de-a0b7-4b3d-9643-2d2aa8ae5b3e'), AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_2iGrvVXTbXRU5VpYuSjub3Ku', 'function': {'arguments': '{"city":"nyc"}', 'name': 'get_weather'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 15, 'prompt_tokens': 58, 'total_tokens': 73}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_661538dc1f', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-047bcf5a-0d57-42de-8ad1-21f4a0cbac17-0', tool_calls=[{'name': 'get_weather', 'args': {'city': 'nyc'}, 'id': 'call_2iGrvVXTbXRU5VpYuSjub3Ku', 'type': 'tool_call'}], usage_metadata={'input_tokens': 58, 'output_tokens': 15, 'to

# SQLite

For comparison, the same kind of persistence using LangGraph's built-in `SqliteSaver`.

In [22]:
import sqlite3
from langgraph.graph import StateGraph

In [23]:
# Minimal graph whose state is a single int: one node that adds 1.
builder = StateGraph(int)
builder.add_node("add_one", lambda x: x + 1)
builder.set_entry_point("add_one")
builder.set_finish_point("add_one")

In [24]:
# Alternatives: wrap an existing sqlite3 connection, or use an in-memory DB.
# conn = sqlite3.connect("checkpoints.sqlite")
# memory = SqliteSaver(conn)
# memory = SqliteSaver.from_conn_string(":memory:")
# Persist checkpoints to a local SQLite file.
memory = SqliteSaver.from_conn_string("checkpoints.sqlite3")
graph = builder.compile(checkpointer=memory)
config = {"configurable": {"thread_id": "2"}}
# On a fresh database this thread has no state yet (values=None).
graph.get_state(config)

StateSnapshot(values=None, next=(), config={'configurable': {'thread_id': '2'}}, metadata=None, created_at=None, parent_config=None)

In [25]:
# Run the graph once: 3 -> add_one -> 4.
graph.invoke(3, config)

4

In [26]:
# The saved state now holds the result, plus thread_ts and parent checkpoint info.
graph.get_state(config)

StateSnapshot(values=4, next=(), config={'configurable': {'thread_id': '2', 'thread_ts': '1ef47164-f80e-6178-8001-277aef3da081'}}, metadata={'source': 'loop', 'step': 1, 'writes': {'add_one': 4}}, created_at='2024-07-21T04:04:22.047366+00:00', parent_config={'configurable': {'thread_id': '2', 'thread_ts': '1ef47164-f808-62aa-8000-4dbc8cf307ce'}})