![Redis](https://redis.io/wp-content/uploads/2024/04/Logotype.svg?auto=webp&quality=85,75&width=120)

# Agent Memory Using Redis and LangGraph
This notebook demonstrates how to manage short-term and long-term agent memory using Redis and LangGraph. We'll explore:

1. Short-term memory management using LangGraph's checkpointer
2. Long-term memory storage and retrieval using RedisVL
3. Manually storing and retrieving long-term memory vs. exposing tool access (AKA function-calling)
4. Managing conversation history size with summarization
5. Memory consolidation and decay

## Let's Begin!
<a href="https://colab.research.google.com/github/redis-developer/redis-ai-resources/blob/main/python-recipes/agents/03_memory_agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

## Packages

In [6]:
%pip install -q langchain-openai langgraph-checkpoint langgraph-checkpoint-redis "langchain-community>=0.2.11" tavily-python langchain-redis pydantic ulid



### OpenAI API key

An OpenAI API key with billing enabled is required for this lesson.

In [2]:
# NBVAL_SKIP
import getpass
import os


def _set_env(key: str):
    if key not in os.environ:
        os.environ[key] = getpass.getpass(f"{key}:")


_set_env("OPENAI_API_KEY")

## Run Redis

### For colab

In [None]:
# NBVAL_SKIP
%%sh
curl -fsSL https://packages.redis.io/gpg | sudo gpg --dearmor -o /usr/share/keyrings/redis-archive-keyring.gpg
echo "deb [signed-by=/usr/share/keyrings/redis-archive-keyring.gpg] https://packages.redis.io/deb $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/redis.list
sudo apt-get update  > /dev/null 2>&1
sudo apt-get install redis-stack-server  > /dev/null 2>&1
redis-stack-server --daemonize yes

#### For Alternative Environments
There are many ways to get the necessary redis-stack instance running:
1. In the cloud: deploy a [FREE instance of Redis in the cloud](https://redis.com/try-free/), or use your own Redis Enterprise deployment if you have one.
2. Per OS, [see the docs](https://redis.io/docs/latest/operate/oss_and_stack/install/install-stack/)
3. With docker: `docker run -d --name redis-stack-server -p 6379:6379 redis/redis-stack-server:latest`

## Test connection

In [None]:
from redis import Redis

# Use the environment variable if set, otherwise default to localhost
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379")

redis_client = Redis.from_url(REDIS_URL)
redis_client.ping()

True

## Short-Term vs. Long-Term Memory

The agent uses **short-term memory** and **long-term memory**. The two are
implemented differently, and the agent uses them differently. Let's
dig into the details. We'll return to code soon!

### Short-Term Memory

For short-term memory, the agent keeps track of conversation history with Redis.
Because this is a LangGraph agent, we use the `RedisSaver` class to achieve
this. `RedisSaver` is what LangGraph refers to as a _checkpointer_. You can read
more about checkpointers in the [LangGraph
documentation](https://langchain-ai.github.io/langgraph/concepts/persistence/).

If Redis persistence is on, then Redis will persist short-term memory to
disk. This means if you quit the agent and return with the same thread ID and
user ID, you'll resume the same conversation.

Conversation histories can grow long and pollute an LLM's context window. To manage
this, after every "turn" of a conversation, the agent summarizes messages when the
conversation grows past a configurable threshold. Checkpointers do not do this by
default, so we've created a node in the graph for summarization.
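The trigger logic behind such a node can be sketched in plain Python. `summarize_if_needed` is a hypothetical helper, not the notebook's summarization node. Its `threshold` and `keep_last` parameters are assumptions, and `summarize` stands in for an LLM call. Once the history grows past the threshold, everything except the most recent messages is collapsed into a single summary message.

```python
from typing import Callable, List


def summarize_if_needed(
    messages: List[str],
    threshold: int,
    keep_last: int,
    summarize: Callable[[List[str]], str],
) -> List[str]:
    """Collapse older messages into one summary once history exceeds `threshold`."""
    if keep_last < 1:
        raise ValueError("keep_last must be at least 1")
    if len(messages) <= threshold:
        return messages  # Still short enough; leave the history alone
    older, recent = messages[:-keep_last], messages[-keep_last:]
    summary = summarize(older)  # In a real agent, this would be an LLM call
    return [f"Summary of earlier conversation: {summary}"] + recent


history = [f"msg {i}" for i in range(8)]
trimmed = summarize_if_needed(
    history,
    threshold=6,
    keep_last=2,
    summarize=lambda msgs: f"{len(msgs)} messages about trip planning",
)
print(trimmed)
# ['Summary of earlier conversation: 6 messages about trip planning', 'msg 6', 'msg 7']
```

Keeping the last few messages verbatim preserves the immediate context the LLM needs for the next turn, and the summary keeps the gist of everything before it.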

**NOTE**: We'll see example code for the summarization node later in this notebook.

### Long-Term Memory

Aside from conversation history, the agent stores long-term memories in a search
index in Redis, using [RedisVL](https://docs.redisvl.com/en/latest/).

The agent tracks two types of long-term memories:

- **Episodic**: User-specific experiences and preferences
- **Semantic**: General knowledge about travel destinations and requirements

**NOTE** If you're familiar with the [CoALA
paper](https://arxiv.org/abs/2309.02427), the terms "episodic" and "semantic"
here map to the same concepts in the paper. CoALA discusses a third type of
memory, _procedural_. In our example, we consider logic encoded in Python in the
agent codebase to be its procedural memory.

### Representing Long-Term Memory in Python
We use a couple of Pydantic models to represent long-term memories, both before
and after they're stored in Redis:

In [None]:
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional

from pydantic import BaseModel, Field
import ulid


class MemoryType(str, Enum):
    """
    The type of a long-term memory.

    EPISODIC: User specific experiences and preferences

    SEMANTIC: General knowledge on top of the user's preferences and LLM's
    training data.
    """

    EPISODIC = "episodic"
    SEMANTIC = "semantic"


class Memory(BaseModel):
    """Represents a single long-term memory."""

    content: str
    memory_type: MemoryType
    metadata: str
    
    
class Memories(BaseModel):
    """
    A list of memories extracted from a conversation by an LLM.

    NOTE: OpenAI's structured output requires us to wrap the list in an object.
    """

    memories: List[Memory]


class StoredMemory(Memory):
    """A stored long-term memory"""

    id: str  # The redis key
    memory_id: ulid.ULID = Field(default_factory=lambda: ulid.ULID())
    created_at: datetime = Field(default_factory=datetime.now)
    user_id: Optional[str] = None
    thread_id: Optional[str] = None
    memory_type: Optional[MemoryType] = None

We'll return to these models soon, to see them in action!

## Short-Term Memory Storage and Retrieval

The `RedisSaver` class handles the basics of short-term memory storage for us,
so we don't need to do anything here.

## Long-Term Memory Storage and Retrieval

We use RedisVL to store and retrieve long-term memories with vector embeddings.
This allows for semantic search of past experiences and knowledge.

Let's set up a new search index to store and query memories:

In [None]:
from redisvl.index import SearchIndex
from redisvl.schema.schema import IndexSchema

# Define schema for long-term memory index
memory_schema = IndexSchema(
    **{
        "index": {
            "name": "agent_memories",
            "prefix": "memory:",
            "key_separator": ":",
            "storage_type": "json",
        },
        "fields": [
            {"name": "content", "type": "text"},
            {"name": "memory_type", "type": "tag"},
            {"name": "metadata", "type": "text"},
            {"name": "created_at", "type": "text"},
            {"name": "user_id", "type": "tag"},
            {"name": "memory_id", "type": "tag"},
            # The helper functions below filter on thread_id, so it must be indexed
            {"name": "thread_id", "type": "tag"},
            {
                "name": "embedding",
                "type": "vector",
                "attrs": {
                    "algorithm": "flat",
                    "dims": 1536,  # OpenAI embedding dimension
                    "distance_metric": "cosine",
                    "datatype": "float32",
                },
            },
        ],
    }
)

# Create search index
try:
    long_term_memory_index = SearchIndex(schema=memory_schema, redis_client=redis_client)
    # overwrite=True replaces an existing index definition with the same name
    # (stored documents are kept because drop defaults to False)
    long_term_memory_index.create(overwrite=True)
    print("Long-term memory index ready")
except Exception as e:
    print(f"Error creating index: {e}")

### Storage and Retrieval Functions

Now that we have a search index in Redis, we can use RedisVL to write functions
that store and retrieve memories.

First, we'll write a utility function to check if a memory similar to a given
memory already exists in the index. Later, we can use this to avoid storing
duplicate memories.
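To build intuition for the `distance_threshold` used below, here is a pure-Python sketch of the same check. It assumes that Redis's cosine distance is `1 - cosine similarity`, so 0 means "pointing the same way". Tiny hand-written 3-dimensional vectors stand in for real 1536-dimensional embeddings. `cosine_distance` and `is_duplicate` are hypothetical helpers, not RedisVL APIs.

```python
import math
from typing import List, Sequence


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """1 - cosine similarity: 0.0 for vectors pointing the same way."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    if norm == 0:
        return 1.0  # Treat zero vectors as unrelated rather than dividing by zero
    return 1.0 - dot / norm


def is_duplicate(
    candidate: Sequence[float],
    stored: List[Sequence[float]],
    distance_threshold: float = 0.1,
) -> bool:
    """True if any stored vector lies within `distance_threshold` of `candidate`."""
    return any(cosine_distance(candidate, v) <= distance_threshold for v in stored)


stored = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
print(is_duplicate([0.99, 0.05, 0.0], stored))  # True: distance to the first vector is ~0.001
print(is_duplicate([0.7, 0.7, 0.0], stored))    # False: distance to each is ~0.29
```

A threshold of 0.1 flags only near-paraphrases as duplicates. Raising it makes deduplication more aggressive, at the risk of discarding genuinely new memories.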

#### Checking for Similar Memories

In [None]:
import logging

from langchain_openai import OpenAIEmbeddings
from redisvl.query import VectorRangeQuery
from redisvl.query.filter import Tag

logger = logging.getLogger(__name__)

# If we have any memories that aren't associated with a user, we'll use this ID.
SYSTEM_USER_ID = "system"

openai_embed = OpenAIEmbeddings(model="text-embedding-ada-002")


def similar_memory_exists(
    content: str,
    memory_type: MemoryType,
    user_id: str = SYSTEM_USER_ID,
    thread_id: Optional[str] = None,
    distance_threshold: float = 0.1,
) -> bool:
    """Check if a similar long-term memory already exists in Redis."""
    query_embedding = openai_embed.embed_query(content)
    filters = (Tag("user_id") == user_id) & (Tag("memory_type") == memory_type.value)
    if thread_id:
        filters = filters & (Tag("thread_id") == thread_id)

    # Search for similar memories
    vector_query = VectorRangeQuery(
        vector=query_embedding,
        num_results=1,
        vector_field_name="embedding",
        filter_expression=filters,
        distance_threshold=distance_threshold,
        return_fields=["id"],
    )
    results = long_term_memory_index.query(vector_query)
    logger.debug(f"Similar memory search results: {results}")

    if results:
        logger.debug(
            f"{len(results)} similar {'memory' if len(results) == 1 else 'memories'} found. First: "
            f"{results[0]['id']}. Skipping storage."
        )
        return True

    return False


#### Storing and Retrieving Long-Term Memories

We'll use the `similar_memory_exists()` function when we store memories:

In [None]:

from datetime import datetime
from typing import List, Optional, Union

import ulid


def store_memory(
    content: str,
    memory_type: MemoryType,
    user_id: str = SYSTEM_USER_ID,
    thread_id: Optional[str] = None,
    metadata: Optional[str] = None,
):
    """Store a long-term memory in Redis, avoiding duplicates."""
    if metadata is None:
        metadata = "{}"

    logger.info(f"Preparing to store memory: {content}")

    if similar_memory_exists(content, memory_type, user_id, thread_id):
        logger.info("Similar memory found, skipping storage")
        return

    embedding = openai_embed.embed_query(content)

    memory_data = {
        "user_id": user_id or SYSTEM_USER_ID,
        "content": content,
        "memory_type": memory_type.value,
        "metadata": metadata,
        "created_at": datetime.now().isoformat(),
        "embedding": embedding,
        "memory_id": str(ulid.ULID()),
        "thread_id": thread_id,
    }

    try:
        long_term_memory_index.load([memory_data])
    except Exception as e:
        logger.error(f"Error storing memory: {e}")
        return

    logger.info(f"Stored {memory_type} memory: {content}")
    


And now that we're storing memories, we can retrieve them:

In [None]:
def retrieve_memories(
    query: str,
    memory_type: Union[Optional[MemoryType], List[MemoryType]] = None,
    user_id: str = SYSTEM_USER_ID,
    thread_id: Optional[str] = None,
    distance_threshold: float = 0.1,
    limit: int = 5,
) -> List[StoredMemory]:
    """Retrieve relevant memories from Redis"""
    # Create vector query
    logger.debug(f"Retrieving memories for query: {query}")
    vector_query = VectorRangeQuery(
        vector=openai_embed.embed_query(query),
        return_fields=[
            "content",
            "memory_type",
            "metadata",
            "created_at",
            "memory_id",
            "thread_id",
            "user_id",
        ],
        num_results=limit,
        vector_field_name="embedding",
        dialect=2,
        distance_threshold=distance_threshold,
    )

    base_filters = [f"@user_id:{{{user_id or SYSTEM_USER_ID}}}"]

    if memory_type:
        if isinstance(memory_type, list):
            base_filters.append(f"@memory_type:{{{'|'.join(m.value for m in memory_type)}}}")
        else:
            base_filters.append(f"@memory_type:{{{memory_type.value}}}")

    if thread_id:
        base_filters.append(f"@thread_id:{{{thread_id}}}")

    vector_query.set_filter(" ".join(base_filters))

    # Execute search
    results = long_term_memory_index.query(vector_query)

    # Parse results
    memories = []
    for doc in results:
        try:
            memory = StoredMemory(
                id=doc["id"],
                memory_id=doc["memory_id"],
                user_id=doc["user_id"],
                thread_id=doc.get("thread_id", None),
                memory_type=MemoryType(doc["memory_type"]),
                content=doc["content"],
                created_at=doc["created_at"],
                metadata=doc["metadata"],
            )
            memories.append(memory)
        except Exception as e:
            logger.error(f"Error parsing memory: {e}")
            continue
    return memories
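`retrieve_memories()` builds its filter as a raw RediSearch query string. One caveat is that tag values containing punctuation or spaces, such as a user ID like `user-42`, generally need backslash escaping in tag queries, and the function above does not escape them. The hypothetical helper below builds the same filter with escaping. The set of escaped characters is an assumption modeled on common RediSearch guidance, not a RedisVL API.

```python
import re
from typing import Iterable, Optional

# Characters assumed to need a backslash inside a RediSearch tag value
# (punctuation and spaces). Hypothetical helper, not part of RedisVL.
_TAG_SPECIAL = re.compile(r"([,.<>{}\[\]\\\"':;!@#$%^&*()\-+=~/| ])")


def escape_tag(value: str) -> str:
    return _TAG_SPECIAL.sub(r"\\\1", value)


def build_memory_filter(
    user_id: str,
    memory_types: Iterable[str] = (),
    thread_id: Optional[str] = None,
) -> str:
    parts = [f"@user_id:{{{escape_tag(user_id)}}}"]
    types = [escape_tag(t) for t in memory_types]
    if types:
        parts.append(f"@memory_type:{{{'|'.join(types)}}}")
    if thread_id:
        parts.append(f"@thread_id:{{{escape_tag(thread_id)}}}")
    # Space-separated clauses are ANDed together in RediSearch query syntax
    return " ".join(parts)


print(build_memory_filter("user-42", ["episodic", "semantic"]))
# @user_id:{user\-42} @memory_type:{episodic|semantic}
```

As far as we understand, RedisVL's `Tag` filter expressions, which `similar_memory_exists()` uses, handle this escaping for you. That is one reason to prefer them over hand-built strings.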

## Managing Long-Term Memory Manually vs. Calling Tools

While making LLM queries, agents can store and retrieve relevant long-term
memories in one of two ways (there are others, but these are the two we'll discuss):

1. Expose memory retrieval and storage as "tools" that the LLM can decide to call contextually.
2. Manually augment prompts with relevant memories, and manually extract and store relevant memories.

These approaches both have tradeoffs.

**Tool-calling** leaves the decision to store a memory or find relevant memories
up to the LLM. Each tool call costs an extra LLM round trip, which adds latency to
requests. It will generally result in fewer calls to Redis, but it will also sometimes
miss potentially relevant context or fail to extract relevant memories from a conversation.

**Manual memory management** will result in more calls to Redis but will produce
fewer round-trip LLM requests, reducing latency. Manually extracting memories
will generally extract more memories than tool calls, which will store more data
in Redis and should result in more context added to LLM requests. More context
means more contextual awareness but also higher token spend.

You can test both approaches with this agent by changing the `memory_strategy`
variable.
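For comparison, here is roughly what the tool-calling approach hands to the LLM: a JSON-schema description of a memory tool in the OpenAI Chat Completions `tools` format. The tool name, description, and parameters below are hypothetical and were chosen to mirror `store_memory()`. In LangChain, you would more typically decorate a Python function with `@tool` and pass it to `ChatOpenAI.bind_tools()`, which generates a definition like this for you.

```python
import json

# Hypothetical tool definition in the OpenAI "tools" format; the LLM reads the
# description and JSON schema to decide when and how to call it.
store_memory_tool = {
    "type": "function",
    "function": {
        "name": "store_memory",
        "description": "Save a fact worth remembering in future conversations.",
        "parameters": {
            "type": "object",
            "properties": {
                "content": {"type": "string", "description": "The memory to store."},
                "memory_type": {"type": "string", "enum": ["episodic", "semantic"]},
                "metadata": {"type": "string", "description": "JSON-encoded tags and context."},
            },
            "required": ["content", "memory_type"],
        },
    },
}

# With tool-calling, the LLM decides whether to emit arguments like these;
# the agent then parses them and runs its own store_memory().
example_call_arguments = '{"content": "User prefers window seats", "memory_type": "episodic"}'
args = json.loads(example_call_arguments)
print(args["memory_type"])  # episodic
```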

## Managing Memory Manually
With the manual memory management strategy, we're going to extract memories after
every interaction between the user and the agent. We're then going to retrieve
those memories during future interactions before we send the query.
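The per-turn ordering can be sketched with plain callables. `run_turn` and the stand-ins passed to it are hypothetical. In this notebook, retrieval is done by `retrieve_relevant_memories()`, the LLM call by the agent, and extraction by `extract_memories()`. Only the ordering (retrieve, call the LLM, then extract) is the point.

```python
from typing import Callable, List


def run_turn(
    user_input: str,
    history: List[str],
    retrieve: Callable[[str], List[str]],
    call_llm: Callable[[str], str],
    extract_and_store: Callable[[List[str]], None],
) -> str:
    """One manual-memory turn; `history` is updated in place."""
    # 1. Before the query: augment the prompt with relevant long-term memories
    prompt = user_input
    memories = retrieve(user_input)
    if memories:
        prompt += "\n\nRelevant memories:\n" + "\n".join(f"- {m}" for m in memories)

    # 2. Send the (possibly augmented) prompt to the LLM
    reply = call_llm(prompt)
    history.extend([f"User: {user_input}", f"Assistant: {reply}"])

    # 3. After the interaction: extract and store new memories
    extract_and_store(history)
    return reply


store = ["User prefers window seats"]
history: List[str] = []
prompts_seen: List[str] = []


def fake_llm(prompt: str) -> str:
    prompts_seen.append(prompt)  # Record what the "LLM" was sent
    return "Booked a window seat to Lisbon."


run_turn(
    "Book my flight to Lisbon",
    history,
    # Toy retrieval: return every stored memory for flight-related queries
    retrieve=lambda q: [m for m in store if "flight" in q.lower()],
    call_llm=fake_llm,
    extract_and_store=lambda h: store.append("User is traveling to Lisbon"),
)
print(prompts_seen[0])
# Book my flight to Lisbon
#
# Relevant memories:
# - User prefers window seats
```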

### Extracting Memories
We'll call this `extract_memories` function manually after each interaction:

In [None]:
from langchain_core.messages import HumanMessage
from langchain_core.runnables.config import RunnableConfig
from langchain_openai import ChatOpenAI
from langgraph.graph.message import MessagesState


class RuntimeState(MessagesState):
    """Agent state (just messages for now)"""

    pass


memory_llm = ChatOpenAI(model="gpt-4o", temperature=0.3).with_structured_output(
    Memories
)


def extract_memories(
    last_processed_message_id: Optional[str],
    state: RuntimeState,
    config: RunnableConfig,
) -> Optional[str]:
    """Extract and store memories in long-term memory"""
    logger.debug(f"Last message ID is: {last_processed_message_id}")

    if len(state["messages"]) < 3:  # Too little conversation to extract memories from yet
        logger.debug("Not enough messages to extract memories")
        return last_processed_message_id

    user_id = config.get("configurable", {}).get("user_id", None)
    if not user_id:
        logger.warning("No user ID found in config when extracting memories")
        return last_processed_message_id

    # Get the messages
    messages = state["messages"]

    # Find the newest message ID (or None if no IDs)
    newest_message_id = None
    for msg in reversed(messages):
        if hasattr(msg, "id") and msg.id:
            newest_message_id = msg.id
            break

    logger.debug(f"Newest message ID is: {newest_message_id}")

    # If we've already processed up to this message ID, skip
    if (
        last_processed_message_id
        and newest_message_id
        and last_processed_message_id == newest_message_id
    ):
        logger.debug(f"Already processed messages up to ID {newest_message_id}")
        return last_processed_message_id

    # Find the index of the message with last_processed_message_id
    start_index = 0
    if last_processed_message_id:
        for i, msg in enumerate(messages):
            if hasattr(msg, "id") and msg.id == last_processed_message_id:
                start_index = i + 1  # Start processing from the next message
                break

    # Check if there are messages to process
    if start_index >= len(messages):
        logger.debug("No new messages to process since last processed message")
        return newest_message_id

    # Get only the messages after the last processed message
    messages_to_process = messages[start_index:]

    # If there are not enough messages to process, include some context
    if len(messages_to_process) < 3 and start_index > 0:
        # Include up to 3 messages before the start_index for context
        context_start = max(0, start_index - 3)
        messages_to_process = messages[context_start:]

    # Format messages for the memory agent
    message_history = "\n".join(
        [
            f"{'User' if isinstance(msg, HumanMessage) else 'Assistant'}: {msg.content}"
            for msg in messages_to_process
        ]
    )

    prompt = f"""
    You are a long-term memory manager. Your job is to analyze this message history
    and extract information that might be useful in future conversations.
    
    Extract two types of memories:
    1. EPISODIC: Personal experiences and preferences specific to this user
       Example: "User prefers window seats" or "User had a bad experience in Paris"
    
    2. SEMANTIC: General facts and knowledge about travel that could be useful
       Example: "The best time to visit Japan is during cherry blossom season in April"
    
    For each memory, provide:
    - Type: The memory type (EPISODIC/SEMANTIC)
    - Content: The actual information to store
    - Metadata: Relevant tags and context (as JSON)
    
    IMPORTANT RULES:
    1. Only extract information that would be genuinely useful for future interactions.
    2. Do not extract procedural knowledge - that is handled by the system's built-in tools and prompts.
    3. You are a large language model, not a human - do not extract facts that you already know.
    
    Message history:
    {message_history}
    
    Extracted memories:
    """

    memories_to_store: Memories = memory_llm.invoke([HumanMessage(content=prompt)])  # type: ignore

    # Store each extracted memory
    for memory_data in memories_to_store.memories:
        store_memory(
            content=memory_data.content,
            memory_type=memory_data.memory_type,
            user_id=user_id,
            metadata=memory_data.metadata,
        )

    # Return data with the newest processed message ID
    return newest_message_id
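The bookkeeping around `last_processed_message_id` is a cursor (or watermark) pattern. The function remembers the newest message it has already processed, and on the next call it only looks at what came after that message, padding short batches with a little earlier context. Below is a stripped-down, hypothetical version that uses a minimal `Msg` dataclass instead of LangChain messages and assumes every message has an ID. It mirrors the control flow above but is not a drop-in replacement.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Msg:
    id: str
    content: str


def new_messages_since(
    messages: List[Msg], last_processed_id: Optional[str], context: int = 3
) -> Tuple[List[Msg], Optional[str]]:
    """Return (messages to process, new cursor) given the last processed ID."""
    if not messages:
        return [], last_processed_id
    newest_id = messages[-1].id
    if last_processed_id == newest_id:
        return [], last_processed_id  # Nothing new since the last run

    # Start just after the cursor; if the cursor isn't found, start from 0
    start = 0
    if last_processed_id is not None:
        for i, msg in enumerate(messages):
            if msg.id == last_processed_id:
                start = i + 1
                break

    batch = messages[start:]
    if len(batch) < context and start > 0:
        # Pad a short batch with up to `context` earlier messages
        batch = messages[max(0, start - context):]
    return batch, newest_id


msgs = [Msg("1", "Hi"), Msg("2", "Hello!"), Msg("3", "I like window seats"), Msg("4", "Noted.")]
batch, cursor = new_messages_since(msgs, None)
print([m.id for m in batch], cursor)  # ['1', '2', '3', '4'] 4

msgs += [Msg("5", "Book Lisbon"), Msg("6", "Done.")]
batch, cursor = new_messages_since(msgs, cursor)
print([m.id for m in batch], cursor)  # ['2', '3', '4', '5', '6'] 6

print(new_messages_since(msgs, cursor))  # ([], '6')
```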


On future interactions, we'll query for relevant memories and add them to
the prompt:

In [None]:
def retrieve_relevant_memories(
    state: RuntimeState, config: RunnableConfig
) -> RuntimeState:
    """Retrieve relevant memories based on the current conversation."""
    if not state["messages"]:
        logger.debug("No messages in state")
        return state

    latest_message = state["messages"][-1]
    if not isinstance(latest_message, HumanMessage):
        logger.debug("Latest message is not a HumanMessage: %s", latest_message)
        return state

    user_id = config.get("configurable", {}).get("user_id", SYSTEM_USER_ID)

    query = str(latest_message.content)
    relevant_memories = retrieve_memories(
        query=query,
        memory_type=[MemoryType.EPISODIC, MemoryType.SEMANTIC],
        limit=5,
        user_id=user_id,
        distance_threshold=0.3,
    )

    logger.debug(f"All relevant memories: {relevant_memories}")

    if relevant_memories:
        memory_context = "\n\n### Relevant memories from previous conversations:\n"

        # Group by memory type
        memory_types = {
            MemoryType.EPISODIC: "User Preferences & History",
            MemoryType.SEMANTIC: "Travel Knowledge",
        }

        for mem_type, type_label in memory_types.items():
            memories_of_type = [
                m for m in relevant_memories if m.memory_type == mem_type
            ]
            if memories_of_type:
                memory_context += f"\n**{type_label}**:\n"
                for mem in memories_of_type:
                    memory_context += f"- {mem.content}\n"

        augmented_message = HumanMessage(content=f"{query}\n{memory_context}")
        state["messages"][-1] = augmented_message

        logger.debug(f"Augmented message: {augmented_message.content}")

    return state.copy()
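The formatting step doesn't depend on Redis or OpenAI, so it can be pulled out into a pure function and tested on its own. The sketch below is a hypothetical refactor. It takes `(memory_type, content)` pairs instead of `StoredMemory` objects and reproduces the grouped layout that `retrieve_relevant_memories()` builds.

```python
from typing import Dict, List, Tuple

# Section labels, in display order (the same labels the notebook uses)
MEMORY_LABELS: Dict[str, str] = {
    "episodic": "User Preferences & History",
    "semantic": "Travel Knowledge",
}


def format_memory_context(memories: List[Tuple[str, str]]) -> str:
    """Group (memory_type, content) pairs under labeled headings."""
    if not memories:
        return ""
    context = "\n\n### Relevant memories from previous conversations:\n"
    for mem_type, label in MEMORY_LABELS.items():
        contents = [content for kind, content in memories if kind == mem_type]
        if contents:
            context += f"\n**{label}**:\n"
            context += "".join(f"- {content}\n" for content in contents)
    return context


query = "When should I visit Japan?"
context = format_memory_context(
    [
        ("semantic", "The best time to visit Japan is during cherry blossom season in April"),
        ("episodic", "User prefers window seats"),
    ]
)
print(f"{query}\n{context}")
```

Episodic memories always come first, regardless of the order in which the search returned them, because the loop follows the label order rather than the result order.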
