Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions agent_memory_server/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
migrate_add_memory_hashes_1,
migrate_add_memory_type_3,
)
from agent_memory_server.utils.redis import ensure_search_index_exists, get_redis_conn
from agent_memory_server.utils.redis import get_redis_conn


logger = get_logger(__name__)
Expand All @@ -46,11 +46,26 @@ def rebuild_index():
"""Rebuild the search index."""
import asyncio

from agent_memory_server.vectorstore_adapter import RedisVectorStoreAdapter
from agent_memory_server.vectorstore_factory import get_vectorstore_adapter

configure_logging()

async def setup_and_run():
redis = await get_redis_conn()
await ensure_search_index_exists(redis, overwrite=True)
# Get the vectorstore adapter
adapter = await get_vectorstore_adapter()

# Only Redis adapter supports index rebuilding
if isinstance(adapter, RedisVectorStoreAdapter):
index = adapter.vectorstore.index
logger.info(f"Dropping and recreating index '{index.name}'")
index.create(overwrite=True)
logger.info("Index rebuilt successfully")
else:
logger.error(
"Index rebuilding is only supported for Redis vectorstore. "
"Current vectorstore does not support this operation."
)

asyncio.run(setup_and_run())

Expand Down Expand Up @@ -200,8 +215,7 @@ def schedule_task(task_path: str, args: list[str]):
sys.exit(1)

async def setup_and_run_task():
redis = await get_redis_conn()
await ensure_search_index_exists(redis)
await get_redis_conn()

# Import the task function
module_path, function_name = task_path.rsplit(".", 1)
Expand Down Expand Up @@ -269,14 +283,8 @@ async def _ensure_stream_and_group():
raise

async def _run_worker():
# Ensure Redis stream/consumer group and search index exist before starting worker
await _ensure_stream_and_group()
try:
redis = await get_redis_conn()
# Don't overwrite if an index already exists; just ensure it's present
await ensure_search_index_exists(redis, overwrite=False)
except Exception as e:
logger.warning(f"Failed to ensure search index exists: {e}")
await get_redis_conn()
await Worker.run(
docket_name=settings.docket_name,
url=settings.redis_url,
Expand Down
18 changes: 1 addition & 17 deletions agent_memory_server/long_term_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,7 @@
rerank_with_recency,
update_memory_hash_if_text_changed,
)
from agent_memory_server.utils.redis import (
ensure_search_index_exists,
get_redis_conn,
)
from agent_memory_server.utils.redis import get_redis_conn
from agent_memory_server.vectorstore_factory import get_vectorstore_adapter


Expand Down Expand Up @@ -614,19 +611,6 @@ async def compact_long_term_memories(
index_name = Keys.search_index_name()
logger.info(f"Using index '{index_name}' for semantic duplicate compaction.")

# Check if the index exists before proceeding
try:
await redis_client.execute_command(f"FT.INFO {index_name}")
except Exception as info_e:
if "unknown index name" in str(info_e).lower():
logger.info(f"Search index {index_name} doesn't exist, creating it")
# Ensure 'get_search_index' is called with the correct name to create it if needed
await ensure_search_index_exists(redis_client, index_name=index_name)
else:
logger.warning(
f"Error checking index '{index_name}': {info_e} - attempting to proceed."
)

# Get all memories using the vector store adapter
try:
# Convert filters to adapter format
Expand Down
25 changes: 2 additions & 23 deletions agent_memory_server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from agent_memory_server.logging import get_logger
from agent_memory_server.utils.redis import (
_redis_pool as connection_pool,
ensure_search_index_exists,
get_redis_conn,
)

Expand Down Expand Up @@ -75,29 +74,9 @@ async def lifespan(app: FastAPI):
"Long-term memory requires OpenAI for embeddings, but OpenAI API key is not set"
)

# Set up RediSearch index if long-term memory is enabled
# Set up Redis connection if long-term memory is enabled
if settings.long_term_memory:
redis = await get_redis_conn()

# Get embedding dimensions from model config
embedding_model_config = MODEL_CONFIGS.get(settings.embedding_model)
vector_dimensions = (
str(embedding_model_config.embedding_dimensions)
if embedding_model_config
else "1536"
)
distance_metric = "COSINE"

try:
await ensure_search_index_exists(
redis,
index_name=settings.redisvl_index_name,
vector_dimensions=vector_dimensions,
distance_metric=distance_metric,
)
except Exception as e:
logger.error(f"Failed to ensure RediSearch index: {e}")
raise
await get_redis_conn()

# Initialize Docket for background tasks if enabled
if settings.use_docket:
Expand Down
20 changes: 6 additions & 14 deletions agent_memory_server/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,10 @@ async def call_tool(self, name, arguments):
return await super().call_tool(name, arguments)

async def run_sse_async(self):
"""Ensure Redis search index exists before starting SSE server."""
from agent_memory_server.utils.redis import (
ensure_search_index_exists,
get_redis_conn,
)
"""Start SSE server."""
from agent_memory_server.utils.redis import get_redis_conn

redis = await get_redis_conn()
await ensure_search_index_exists(redis)
await get_redis_conn()

# Run the SSE server using our custom implementation
import uvicorn
Expand All @@ -176,14 +172,10 @@ async def run_sse_async(self):
).serve()

async def run_stdio_async(self):
"""Ensure Redis search index exists before starting STDIO MCP server."""
from agent_memory_server.utils.redis import (
ensure_search_index_exists,
get_redis_conn,
)
"""Start STDIO MCP server."""
from agent_memory_server.utils.redis import get_redis_conn

redis = await get_redis_conn()
await ensure_search_index_exists(redis)
await get_redis_conn()
return await super().run_stdio_async()


Expand Down
37 changes: 0 additions & 37 deletions agent_memory_server/utils/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,12 @@
from typing import Any

from redis.asyncio import Redis
from redisvl.index import AsyncSearchIndex

from agent_memory_server.config import settings
from agent_memory_server.vectorstore_adapter import RedisVectorStoreAdapter
from agent_memory_server.vectorstore_factory import get_vectorstore_adapter


logger = logging.getLogger(__name__)
_redis_pool: Redis | None = None
_index: AsyncSearchIndex | None = None


async def get_redis_conn(url: str = settings.redis_url, **kwargs) -> Redis:
Expand All @@ -35,39 +31,6 @@ async def get_redis_conn(url: str = settings.redis_url, **kwargs) -> Redis:
return _redis_pool


async def ensure_search_index_exists(
redis: Redis,
index_name: str = settings.redisvl_index_name,
vector_dimensions: str = settings.redisvl_vector_dimensions,
distance_metric: str = settings.redisvl_distance_metric,
overwrite: bool = True,
) -> None:
"""
Ensure that the async search index exists, create it if it doesn't.
This function is deprecated and only exists for compatibility.
The VectorStore adapter now handles index creation automatically.

Args:
redis: A Redis client instance
vector_dimensions: Dimensions of the embedding vectors
distance_metric: Distance metric to use (default: COSINE)
index_name: The name of the index
"""
# If this is Redis, creating the adapter will create the index.
adapter = await get_vectorstore_adapter()

if overwrite:
if isinstance(adapter, RedisVectorStoreAdapter):
index = adapter.vectorstore.index
if index is not None:
index.create(overwrite=True)
else:
logger.warning(
"Overwriting the search index is only supported for RedisVectorStoreAdapter. "
"Consult your vector store's documentation to learn how to recreate the index."
)


def safe_get(doc: Any, key: str, default: Any | None = None) -> Any:
"""Get a value from a Document, returning a default if the key is not present.

Expand Down
82 changes: 82 additions & 0 deletions docker-compose-task-workers.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
services:
  # For testing a production-like setup, you can run this API and the
  # task-worker container. This API container does NOT use --no-worker, so when
  # it starts background work, the task-worker will process those tasks.
  api:
    build:
      context: .
      dockerfile: Dockerfile
    ports:
      - "8000:8000"
    environment:
      - REDIS_URL=redis://redis:6379
      - PORT=8000
      # Add your API keys here or use a .env file
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
      # Optional configurations with defaults
      - LONG_TERM_MEMORY=True
      - GENERATION_MODEL=gpt-4o-mini
      - EMBEDDING_MODEL=text-embedding-3-small
      - ENABLE_TOPIC_EXTRACTION=True
      - ENABLE_NER=True
    depends_on:
      - redis
    # Bind-mount the package source for live-reload style development.
    volumes:
      - ./agent_memory_server:/app/agent_memory_server
    healthcheck:
      # NOTE(review): assumes curl is installed in the image — confirm the
      # Dockerfile provides it, otherwise the healthcheck always fails.
      test: [ "CMD", "curl", "-f", "http://localhost:8000/v1/health" ]
      interval: 30s
      timeout: 10s
      retries: 3
    command: ["agent-memory", "api", "--host", "0.0.0.0", "--port", "8000"]


  mcp:
    build:
      context: .
      dockerfile: Dockerfile
    environment:
      - REDIS_URL=redis://redis:6379
      - PORT=9050
      # Add your API keys here or use a .env file
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
    ports:
      # PORT=9050 above means the server listens on 9050 inside the container,
      # so the mapping must target container port 9050. (The previous mapping
      # "9050:9000" pointed at container port 9000, where nothing listens.)
      - "9050:9050"
    depends_on:
      - redis
    command: ["agent-memory", "mcp", "--mode", "sse"]

  # Dedicated background-task worker; processes tasks the API enqueues.
  task-worker:
    build:
      context: .
      dockerfile: Dockerfile
    environment:
      - REDIS_URL=redis://redis:6379
      # Add your API keys here or use a .env file
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
      # Optional configurations with defaults
    depends_on:
      - redis
    command: ["agent-memory", "task-worker"]
    volumes:
      - ./agent_memory_server:/app/agent_memory_server
    restart: unless-stopped

  redis:
    image: redis:8
    ports:
      - "16380:6379" # Redis exposed on host port 16380 to avoid clashing with a local Redis on 6379
    volumes:
      - redis_data:/data
    # Snapshot every 30s if >=1 write; AOF disabled; don't block writes on bgsave errors.
    command: redis-server --save "30 1" --loglevel warning --appendonly no --stop-writes-on-bgsave-error no
    healthcheck:
      test: [ "CMD", "redis-cli", "ping" ]
      interval: 30s
      timeout: 10s
      retries: 3

volumes:
  redis_data:
33 changes: 18 additions & 15 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,30 @@ def test_version_command(self):
class TestRebuildIndex:
"""Tests for the rebuild_index command."""

@patch("agent_memory_server.cli.ensure_search_index_exists")
@patch("agent_memory_server.cli.get_redis_conn")
def test_rebuild_index_command(self, mock_get_redis_conn, mock_ensure_index):
@patch("agent_memory_server.vectorstore_factory.get_vectorstore_adapter")
def test_rebuild_index_command(self, mock_get_adapter):
"""Test rebuild_index command execution."""
# Use AsyncMock which returns completed awaitables
mock_redis = Mock()
mock_get_redis_conn.return_value = mock_redis
mock_ensure_index.return_value = None
from agent_memory_server.vectorstore_adapter import RedisVectorStoreAdapter

# Create a mock adapter with a mock index
mock_index = Mock()
mock_index.name = "test_index"
mock_index.create = Mock()

mock_vectorstore = Mock()
mock_vectorstore.index = mock_index

mock_adapter = Mock(spec=RedisVectorStoreAdapter)
mock_adapter.vectorstore = mock_vectorstore

mock_get_adapter.return_value = mock_adapter

runner = CliRunner()
result = runner.invoke(rebuild_index)

assert result.exit_code == 0
mock_get_redis_conn.assert_called_once()
mock_ensure_index.assert_called_once_with(mock_redis, overwrite=True)
mock_get_adapter.assert_called_once()
mock_index.create.assert_called_once_with(overwrite=True)


class TestMigrateMemories:
Expand Down Expand Up @@ -440,7 +449,6 @@ def test_schedule_task_argument_parsing(self):
class TestTaskWorker:
"""Tests for the task_worker command."""

@patch("agent_memory_server.cli.ensure_search_index_exists")
@patch("agent_memory_server.cli.get_redis_conn")
@patch("docket.Worker.run")
@patch("agent_memory_server.cli.settings")
Expand All @@ -449,7 +457,6 @@ def test_task_worker_success(
mock_settings,
mock_worker_run,
mock_get_redis_conn,
mock_ensure_index,
redis_url,
):
"""Test successful task worker start."""
Expand All @@ -460,7 +467,6 @@ def test_task_worker_success(
mock_worker_run.return_value = None
mock_redis = AsyncMock()
mock_get_redis_conn.return_value = mock_redis
mock_ensure_index.return_value = None

runner = CliRunner()
result = runner.invoke(
Expand All @@ -481,7 +487,6 @@ def test_task_worker_docket_disabled(self, mock_settings):
assert result.exit_code == 1
assert "Docket is disabled in settings" in result.output

@patch("agent_memory_server.cli.ensure_search_index_exists")
@patch("agent_memory_server.cli.get_redis_conn")
@patch("docket.Worker.run")
@patch("agent_memory_server.cli.settings")
Expand All @@ -490,7 +495,6 @@ def test_task_worker_default_params(
mock_settings,
mock_worker_run,
mock_get_redis_conn,
mock_ensure_index,
redis_url,
):
"""Test task worker with default parameters."""
Expand All @@ -501,7 +505,6 @@ def test_task_worker_default_params(
mock_worker_run.return_value = None
mock_redis = AsyncMock()
mock_get_redis_conn.return_value = mock_redis
mock_ensure_index.return_value = None

runner = CliRunner()
result = runner.invoke(task_worker)
Expand Down
Loading