diff --git a/agent-memory-client/agent_memory_client/__init__.py b/agent-memory-client/agent_memory_client/__init__.py index fcc51a5..0e6039e 100644 --- a/agent-memory-client/agent_memory_client/__init__.py +++ b/agent-memory-client/agent_memory_client/__init__.py @@ -5,7 +5,7 @@ memory management capabilities for AI agents and applications. """ -__version__ = "0.11.1" +__version__ = "0.12.0" from .client import MemoryAPIClient, MemoryClientConfig, create_memory_client from .exceptions import ( diff --git a/agent-memory-client/agent_memory_client/client.py b/agent-memory-client/agent_memory_client/client.py index 9907629..9d8461a 100644 --- a/agent-memory-client/agent_memory_client/client.py +++ b/agent-memory-client/agent_memory_client/client.py @@ -5,6 +5,7 @@ """ import asyncio +import logging # noqa: F401 import re from collections.abc import AsyncIterator, Sequence from typing import TYPE_CHECKING, Any, Literal, TypedDict @@ -39,7 +40,6 @@ RecencyConfig, SessionListResponse, WorkingMemory, - WorkingMemoryGetOrCreateResponse, WorkingMemoryResponse, ) @@ -120,10 +120,16 @@ def __init__(self, config: MemoryClientConfig): Args: config: MemoryClientConfig instance with server connection details """ + from . import __version__ + self.config = config self._client = httpx.AsyncClient( base_url=config.base_url, timeout=config.timeout, + headers={ + "User-Agent": f"agent-memory-client/{__version__}", + "X-Client-Version": __version__, + }, ) async def close(self) -> None: @@ -289,11 +295,11 @@ async def get_or_create_working_memory( namespace: str | None = None, model_name: ModelNameLiteral | None = None, context_window_max: int | None = None, - ) -> WorkingMemoryGetOrCreateResponse: + ) -> tuple[bool, WorkingMemory]: """ Get working memory for a session, creating it if it doesn't exist. - This method returns both the working memory and whether it was created or found. + This method returns a tuple with the creation status and the working memory. This is important for applications that need to know if they're working with a new session or an existing one. 
@@ -305,24 +311,24 @@ async def get_or_create_working_memory( context_window_max: Optional direct specification of context window tokens Returns: - WorkingMemoryGetOrCreateResponse containing the memory and creation status + Tuple of (created: bool, memory: WorkingMemory) + - created: True if the session was created, False if it already existed + - memory: The WorkingMemory object Example: ```python # Get or create session memory - result = await client.get_or_create_working_memory( + created, memory = await client.get_or_create_working_memory( session_id="chat_session_123", user_id="user_456" ) - if result.created: - print("Created new session") + if created: + logging.info("Created new session") else: - print("Found existing session") + logging.info("Found existing session") - # Access the memory - memory = result.memory - print(f"Session has {len(memory.messages)} messages") + logging.info(f"Session has {len(memory.messages)} messages") ``` """ try: @@ -334,29 +340,54 @@ async def get_or_create_working_memory( model_name=model_name, context_window_max=context_window_max, ) - return WorkingMemoryGetOrCreateResponse( - memory=existing_memory, created=False - ) - except Exception: - # Session doesn't exist, create it - empty_memory = WorkingMemory( - session_id=session_id, - namespace=namespace or self.config.default_namespace, - messages=[], - memories=[], - data={}, - user_id=user_id, - ) - created_memory = await self.put_working_memory( - session_id=session_id, - memory=empty_memory, - user_id=user_id, - model_name=model_name, - context_window_max=context_window_max, - ) + # Check if this is an unsaved session (deprecated behavior for old clients) + if getattr(existing_memory, "unsaved", None) is True: + # This is an unsaved session - we need to create it properly + empty_memory = WorkingMemory( + session_id=session_id, + namespace=namespace or self.config.default_namespace, + messages=[], + memories=[], + data={}, + user_id=user_id, + ) + + created_memory = await self.put_working_memory( + session_id=session_id, + memory=empty_memory, + user_id=user_id, + model_name=model_name, + context_window_max=context_window_max, + ) - return WorkingMemoryGetOrCreateResponse(memory=created_memory, created=True) + return (True, created_memory) + + return (False, existing_memory) + except httpx.HTTPStatusError as e: + if e.response.status_code == 404: + # Session doesn't exist, create it + empty_memory = WorkingMemory( + session_id=session_id, + namespace=namespace or self.config.default_namespace, + messages=[], + memories=[], + data={}, + user_id=user_id, + ) + + created_memory = await self.put_working_memory( + session_id=session_id, + memory=empty_memory, + user_id=user_id, + model_name=model_name, + context_window_max=context_window_max, + ) + + return (True, created_memory) + else: + # Re-raise other HTTP errors + raise async def put_working_memory( self, @@ -484,11 +515,10 @@ async def set_working_memory_data( existing_memory = None if preserve_existing: try: - result_obj = await self.get_or_create_working_memory( + created, existing_memory = await self.get_or_create_working_memory( session_id=session_id, namespace=namespace, ) - existing_memory = result_obj.memory except Exception: existing_memory = None @@ -544,11 +574,10 @@ async def add_memories_to_working_memory( ``` """ # Get existing memory - result_obj = await self.get_or_create_working_memory( + created, existing_memory = await self.get_or_create_working_memory( session_id=session_id, namespace=namespace, ) - existing_memory = 
result_obj.memory # Determine final memories list if replace or not existing_memory: @@ -610,7 +639,7 @@ async def create_long_term_memory( ] response = await client.create_long_term_memory(memories) - print(f"Stored memories: {response.status}") + logging.info(f"Stored memories: {response.status}") ``` """ # Apply default namespace and ensure IDs are present @@ -764,9 +793,9 @@ async def search_long_term_memory( distance_threshold=0.3 ) - print(f"Found {results.total} memories") + logging.info(f"Found {results.total} memories") for memory in results.memories: - print(f"- {memory.text[:100]}... (distance: {memory.dist})") + logging.info(f"- {memory.text[:100]}... (distance: {memory.dist})") ``` """ # Convert dictionary filters to their proper filter objects if needed @@ -916,9 +945,9 @@ async def search_memory_tool( min_relevance=0.7 ) - print(result["summary"]) # "Found 2 relevant memories for: user preferences about UI themes" + logging.info(result["summary"]) # "Found 2 relevant memories for: user preferences about UI themes" for memory in result["memories"]: - print(f"- {memory['text']} (score: {memory['relevance_score']})") + logging.info(f"- {memory['text']} (score: {memory['relevance_score']})") ``` LLM Framework Integration: @@ -1119,18 +1148,17 @@ async def get_working_memory_tool( session_id="current_session" ) - print(memory_state["summary"]) # Human-readable summary - print(f"Messages: {memory_state['message_count']}") - print(f"Memories: {len(memory_state['memories'])}") + logging.info(memory_state["summary"]) # Human-readable summary + logging.info(f"Messages: {memory_state['message_count']}") + logging.info(f"Memories: {len(memory_state['memories'])}") ``` """ try: - result_obj = await self.get_or_create_working_memory( + created, result = await self.get_or_create_working_memory( session_id=session_id, namespace=namespace or self.config.default_namespace, user_id=user_id, ) - result = result_obj.memory # Format for LLM consumption message_count = len(result.messages) if result.messages else 0 @@ -1200,24 +1228,23 @@ async def get_or_create_working_memory_tool( ) if memory_state["created"]: - print("Created new session") + logging.info("Created new session") else: - print("Found existing session") + logging.info("Found existing session") - print(memory_state["summary"]) # Human-readable summary - print(f"Messages: {memory_state['message_count']}") - print(f"Memories: {len(memory_state['memories'])}") + logging.info(memory_state["summary"]) # Human-readable summary + logging.info(f"Messages: {memory_state['message_count']}") + logging.info(f"Memories: {len(memory_state['memories'])}") ``` """ try: - result_obj = await self.get_or_create_working_memory( + created, result = await self.get_or_create_working_memory( session_id=session_id, namespace=namespace or self.config.default_namespace, user_id=user_id, ) # Format for LLM consumption - result = result_obj.memory message_count = len(result.messages) if result.messages else 0 memory_count = len(result.memories) if result.memories else 0 data_keys = list(result.data.keys()) if result.data else [] @@ -1238,11 +1265,11 @@ async def get_or_create_working_memory_tool( } ) - status_text = "new session" if result_obj.created else "existing session" + status_text = "new session" if created else "existing session" return { "session_id": session_id, - "created": result_obj.created, + "created": created, "message_count": message_count, "memory_count": memory_count, "memories": formatted_memories, @@ -1299,7 +1326,7 @@ async def 
add_memory_tool( entities=["vegetarian", "restaurants"] ) - print(result["summary"]) # "Successfully stored semantic memory" + logging.info(result["summary"]) # "Successfully stored semantic memory" ``` """ try: @@ -1373,7 +1400,7 @@ async def update_memory_data_tool( } ) - print(result["summary"]) # "Successfully updated 3 data entries" + logging.info(result["summary"]) # "Successfully updated 3 data entries" ``` """ try: @@ -1948,9 +1975,9 @@ async def resolve_tool_call( ) if result["success"]: - print(result["formatted_response"]) + logging.info(result["formatted_response"]) else: - print(f"Error: {result['error']}") + logging.error(f"Error: {result['error']}") ``` """ try: @@ -2004,7 +2031,7 @@ async def resolve_tool_calls( for result in results: if result["success"]: - print(f"{result['function_name']}: {result['formatted_response']}") + logging.info(f"{result['function_name']}: {result['formatted_response']}") ``` """ results = [] @@ -2062,9 +2089,9 @@ async def resolve_function_call( ) if result["success"]: - print(result["formatted_response"]) + logging.info(result["formatted_response"]) else: - print(f"Error: {result['error']}") + logging.error(f"Error: {result['error']}") ``` """ import json @@ -2352,7 +2379,7 @@ async def resolve_function_calls( results = await client.resolve_function_calls(calls, "session123") for result in results: if result["success"]: - print(f"{result['function_name']}: {result['formatted_response']}") + logging.info(f"{result['function_name']}: {result['formatted_response']}") ``` """ results = [] @@ -2395,10 +2422,9 @@ async def promote_working_memories_to_long_term( Acknowledgement of promotion operation """ # Get current working memory - result_obj = await self.get_or_create_working_memory( + created, working_memory = await self.get_or_create_working_memory( session_id=session_id, namespace=namespace ) - working_memory = result_obj.memory # Filter memories if specific IDs are requested memories_to_promote = working_memory.memories @@ -2611,10 +2637,9 @@ async def update_working_memory_data( WorkingMemoryResponse with updated memory """ # Get existing memory - result_obj = await self.get_or_create_working_memory( + created, existing_memory = await self.get_or_create_working_memory( session_id=session_id, namespace=namespace, user_id=user_id ) - existing_memory = result_obj.memory # Determine final data based on merge strategy if existing_memory and existing_memory.data: @@ -2667,10 +2692,9 @@ async def append_messages_to_working_memory( WorkingMemoryResponse with updated memory (potentially summarized if token limit exceeded) """ # Get existing memory - result_obj = await self.get_or_create_working_memory( + created, existing_memory = await self.get_or_create_working_memory( session_id=session_id, namespace=namespace, user_id=user_id ) - existing_memory = result_obj.memory # Convert messages to MemoryMessage objects converted_messages = [] diff --git a/agent-memory-client/agent_memory_client/models.py b/agent-memory-client/agent_memory_client/models.py index e00732b..2c83760 100644 --- a/agent-memory-client/agent_memory_client/models.py +++ b/agent-memory-client/agent_memory_client/models.py @@ -236,6 +236,14 @@ class WorkingMemoryResponse(WorkingMemory): default=None, description="Percentage until auto-summarization triggers (0-100, reaches 100% at summarization threshold)", ) + new_session: bool | None = Field( + default=None, + description="True if session was created, False if existing session was found, None if not applicable", + ) + unsaved: bool | 
None = Field( + default=None, + description="True if this session data has not been persisted to Redis yet (deprecated behavior for old clients)", + ) class MemoryRecordResult(MemoryRecord): @@ -283,15 +291,6 @@ class MemoryRecordResults(BaseModel): next_offset: int | None = None -class WorkingMemoryGetOrCreateResponse(BaseModel): - """Response from get_or_create_working_memory operations""" - - memory: WorkingMemoryResponse - created: bool = Field( - description="True if the session was created, False if it already existed" - ) - - class MemoryPromptResponse(BaseModel): """Response from memory prompt endpoint""" diff --git a/agent_memory_server/__init__.py b/agent_memory_server/__init__.py index 935abaf..7e344ac 100644 --- a/agent_memory_server/__init__.py +++ b/agent_memory_server/__init__.py @@ -1,3 +1,3 @@ """Redis Agent Memory Server - A memory system for conversational AI.""" -__version__ = "0.10.0" +__version__ = "0.11.0" diff --git a/agent_memory_server/api.py b/agent_memory_server/api.py index 4e59f87..1150557 100644 --- a/agent_memory_server/api.py +++ b/agent_memory_server/api.py @@ -1,7 +1,8 @@ +import re from typing import Any import tiktoken -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Header, HTTPException, Query from mcp.server.fastmcp.prompts import base from mcp.types import TextContent @@ -38,6 +39,31 @@ router = APIRouter() +def parse_client_version(client_version: str | None) -> tuple[int, int, int] | None: + """Parse client version string into tuple (major, minor, patch)""" + if not client_version: + return None + + # Extract version from format like "0.12.0" + match = re.match(r"(\d+)\.(\d+)\.(\d+)", client_version) + if not match: + return None + + return (int(match.group(1)), int(match.group(2)), int(match.group(3))) + + +def is_old_client(client_version: str | None) -> bool: + """Check if client version is older than 0.12.0 (needs deprecated behavior)""" + parsed = parse_client_version(client_version) + if not parsed: + # No version header means very old client + return True + + major, minor, patch = parsed + # Version 0.12.0 is when we introduced proper REST behavior + return (major, minor, patch) < (0, 12, 0) + + @router.post("/v1/long-term-memory/forget") async def forget_endpoint( policy: dict, @@ -320,6 +346,7 @@ async def get_working_memory( namespace: str | None = None, model_name: ModelNameLiteral | None = None, context_window_max: int | None = None, + x_client_version: str | None = Header(None, alias="X-Client-Version"), current_user: UserInfo = Depends(get_current_user), ): """ @@ -347,15 +374,30 @@ async def get_working_memory( user_id=user_id, ) + # Handle missing sessions based on client version + new_session = False + unsaved = None + if not working_mem: - # Create empty working memory if none exists - working_mem = WorkingMemory( - messages=[], - memories=[], - session_id=session_id, - namespace=namespace, - user_id=user_id, - ) + if is_old_client(x_client_version): + # Deprecated behavior: return empty session with unsaved=True (don't persist) + logger.warning( + f"Client version {x_client_version or 'unknown'} using deprecated behavior. " + "GET /v1/working-memory/{session_id} will return 404 for missing sessions in version 1.0. " + "Use get_or_create_working_memory client method instead." 
+ ) + unsaved = True + # Create empty working memory but DO NOT persist it + working_mem = WorkingMemory( + session_id=session_id, + namespace=namespace, + user_id=user_id, + ) + else: + # Proper REST behavior: return 404 for missing sessions + raise HTTPException( + status_code=404, detail=f"Session {session_id} not found" + ) # Apply token-based truncation if we have messages and model info if working_mem.messages and (model_name or context_window_max): @@ -383,12 +425,14 @@ async def get_working_memory( ) ) - # Return WorkingMemoryResponse with both percentage values + # Return WorkingMemoryResponse with percentage values, new_session flag, and unsaved flag working_mem_data = working_mem.model_dump() working_mem_data["context_percentage_total_used"] = total_percentage working_mem_data["context_percentage_until_summarization"] = ( until_summarization_percentage ) + working_mem_data["new_session"] = new_session + working_mem_data["unsaved"] = unsaved return WorkingMemoryResponse(**working_mem_data) @@ -421,6 +465,8 @@ async def put_working_memory( """ redis = await get_redis_conn() + # PUT semantics: we simply replace whatever exists (or create if it doesn't exist) + # Ensure session_id matches memory.session_id = session_id @@ -477,7 +523,7 @@ async def put_working_memory( ) ) - # Return WorkingMemoryResponse with both percentage values + # Return WorkingMemoryResponse with percentage values (no new_session for PUT) updated_memory_data = updated_memory.model_dump() updated_memory_data["context_percentage_total_used"] = total_percentage updated_memory_data["context_percentage_until_summarization"] = ( diff --git a/agent_memory_server/config.py b/agent_memory_server/config.py index 663c43f..d1e065c 100644 --- a/agent_memory_server/config.py +++ b/agent_memory_server/config.py @@ -1,49 +1,178 @@ import os +from enum import Enum from typing import Any, Literal import yaml from dotenv import load_dotenv +from pydantic import BaseModel from pydantic_settings import BaseSettings load_dotenv() +class ModelProvider(str, Enum): + """Type of model provider""" + + OPENAI = "openai" + ANTHROPIC = "anthropic" + + +class ModelConfig(BaseModel): + """Configuration for a model""" + + provider: ModelProvider + name: str + max_tokens: int + embedding_dimensions: int = 1536 # Default for OpenAI ada-002 + + # Model configuration mapping MODEL_CONFIGS = { - "gpt-4o": {"provider": "openai", "embedding_dimensions": None}, - "gpt-4o-mini": {"provider": "openai", "embedding_dimensions": None}, - "gpt-4": {"provider": "openai", "embedding_dimensions": None}, - "gpt-3.5-turbo": {"provider": "openai", "embedding_dimensions": None}, - "text-embedding-3-small": {"provider": "openai", "embedding_dimensions": 1536}, - "text-embedding-3-large": {"provider": "openai", "embedding_dimensions": 3072}, - "text-embedding-ada-002": {"provider": "openai", "embedding_dimensions": 1536}, - "claude-3-opus-20240229": {"provider": "anthropic", "embedding_dimensions": None}, - "claude-3-sonnet-20240229": {"provider": "anthropic", "embedding_dimensions": None}, - "claude-3-haiku-20240307": {"provider": "anthropic", "embedding_dimensions": None}, - "claude-3-5-sonnet-20240620": { - "provider": "anthropic", - "embedding_dimensions": None, - }, - "claude-3-5-sonnet-20241022": { - "provider": "anthropic", - "embedding_dimensions": None, - }, - "claude-3-5-haiku-20241022": { - "provider": "anthropic", - "embedding_dimensions": None, - }, - "claude-3-7-sonnet-20250219": { - "provider": "anthropic", - "embedding_dimensions": None, - }, - 
"claude-3-7-sonnet-latest": {"provider": "anthropic", "embedding_dimensions": None}, - "claude-3-5-sonnet-latest": {"provider": "anthropic", "embedding_dimensions": None}, - "claude-3-5-haiku-latest": {"provider": "anthropic", "embedding_dimensions": None}, - "claude-3-opus-latest": {"provider": "anthropic", "embedding_dimensions": None}, - "o1": {"provider": "openai", "embedding_dimensions": None}, - "o1-mini": {"provider": "openai", "embedding_dimensions": None}, - "o3-mini": {"provider": "openai", "embedding_dimensions": None}, + # OpenAI Models + "gpt-3.5-turbo": ModelConfig( + provider=ModelProvider.OPENAI, + name="gpt-3.5-turbo", + max_tokens=4096, + embedding_dimensions=1536, + ), + "gpt-3.5-turbo-16k": ModelConfig( + provider=ModelProvider.OPENAI, + name="gpt-3.5-turbo-16k", + max_tokens=16384, + embedding_dimensions=1536, + ), + "gpt-4": ModelConfig( + provider=ModelProvider.OPENAI, + name="gpt-4", + max_tokens=8192, + embedding_dimensions=1536, + ), + "gpt-4-32k": ModelConfig( + provider=ModelProvider.OPENAI, + name="gpt-4-32k", + max_tokens=32768, + embedding_dimensions=1536, + ), + "gpt-4o": ModelConfig( + provider=ModelProvider.OPENAI, + name="gpt-4o", + max_tokens=128000, + embedding_dimensions=1536, + ), + "gpt-4o-mini": ModelConfig( + provider=ModelProvider.OPENAI, + name="gpt-4o-mini", + max_tokens=128000, + embedding_dimensions=1536, + ), + # Newer reasoning models + "o1": ModelConfig( + provider=ModelProvider.OPENAI, + name="o1", + max_tokens=200000, + embedding_dimensions=1536, + ), + "o1-mini": ModelConfig( + provider=ModelProvider.OPENAI, + name="o1-mini", + max_tokens=128000, + embedding_dimensions=1536, + ), + "o3-mini": ModelConfig( + provider=ModelProvider.OPENAI, + name="o3-mini", + max_tokens=200000, + embedding_dimensions=1536, + ), + # Embedding models + "text-embedding-ada-002": ModelConfig( + provider=ModelProvider.OPENAI, + name="text-embedding-ada-002", + max_tokens=8191, + embedding_dimensions=1536, + ), + "text-embedding-3-small": ModelConfig( + provider=ModelProvider.OPENAI, + name="text-embedding-3-small", + max_tokens=8191, + embedding_dimensions=1536, + ), + "text-embedding-3-large": ModelConfig( + provider=ModelProvider.OPENAI, + name="text-embedding-3-large", + max_tokens=8191, + embedding_dimensions=3072, + ), + # Anthropic Models + "claude-3-opus-20240229": ModelConfig( + provider=ModelProvider.ANTHROPIC, + name="claude-3-opus-20240229", + max_tokens=200000, + embedding_dimensions=1536, + ), + "claude-3-sonnet-20240229": ModelConfig( + provider=ModelProvider.ANTHROPIC, + name="claude-3-sonnet-20240229", + max_tokens=200000, + embedding_dimensions=1536, + ), + "claude-3-haiku-20240307": ModelConfig( + provider=ModelProvider.ANTHROPIC, + name="claude-3-haiku-20240307", + max_tokens=200000, + embedding_dimensions=1536, + ), + "claude-3-5-sonnet-20240620": ModelConfig( + provider=ModelProvider.ANTHROPIC, + name="claude-3-5-sonnet-20240620", + max_tokens=200000, + embedding_dimensions=1536, + ), + # Latest Anthropic Models + "claude-3-7-sonnet-20250219": ModelConfig( + provider=ModelProvider.ANTHROPIC, + name="claude-3-7-sonnet-20250219", + max_tokens=200000, + embedding_dimensions=1536, + ), + "claude-3-5-sonnet-20241022": ModelConfig( + provider=ModelProvider.ANTHROPIC, + name="claude-3-5-sonnet-20241022", + max_tokens=200000, + embedding_dimensions=1536, + ), + "claude-3-5-haiku-20241022": ModelConfig( + provider=ModelProvider.ANTHROPIC, + name="claude-3-5-haiku-20241022", + max_tokens=200000, + embedding_dimensions=1536, + ), + # Convenience 
aliases + "claude-3-7-sonnet-latest": ModelConfig( + provider=ModelProvider.ANTHROPIC, + name="claude-3-7-sonnet-20250219", + max_tokens=200000, + embedding_dimensions=1536, + ), + "claude-3-5-sonnet-latest": ModelConfig( + provider=ModelProvider.ANTHROPIC, + name="claude-3-5-sonnet-20241022", + max_tokens=200000, + embedding_dimensions=1536, + ), + "claude-3-5-haiku-latest": ModelConfig( + provider=ModelProvider.ANTHROPIC, + name="claude-3-5-haiku-20241022", + max_tokens=200000, + embedding_dimensions=1536, + ), + "claude-3-opus-latest": ModelConfig( + provider=ModelProvider.ANTHROPIC, + name="claude-3-opus-20240229", + max_tokens=200000, + embedding_dimensions=1536, + ), } @@ -167,14 +296,14 @@ class Config: extra = "ignore" # Ignore extra environment variables @property - def generation_model_config(self) -> dict[str, Any]: + def generation_model_config(self) -> ModelConfig | None: """Get configuration for the generation model.""" - return MODEL_CONFIGS.get(self.generation_model, {}) + return MODEL_CONFIGS.get(self.generation_model) @property - def embedding_model_config(self) -> dict[str, Any]: + def embedding_model_config(self) -> ModelConfig | None: """Get configuration for the embedding model.""" - return MODEL_CONFIGS.get(self.embedding_model, {}) + return MODEL_CONFIGS.get(self.embedding_model) def load_yaml_config(self, config_path: str) -> dict[str, Any]: """Load configuration from YAML file.""" diff --git a/agent_memory_server/llms.py b/agent_memory_server/llms.py index 8653026..de4901c 100644 --- a/agent_memory_server/llms.py +++ b/agent_memory_server/llms.py @@ -1,185 +1,23 @@ import json import logging import os -from enum import Enum from typing import Any import anthropic import numpy as np from openai import AsyncOpenAI -from pydantic import BaseModel -from agent_memory_server.config import settings +from agent_memory_server.config import ( + MODEL_CONFIGS, + ModelConfig, + ModelProvider, + settings, +) logger = logging.getLogger(__name__) -class ModelProvider(str, Enum): - """Type of model provider""" - - OPENAI = "openai" - ANTHROPIC = "anthropic" - - -class ModelConfig(BaseModel): - """Configuration for a model""" - - provider: ModelProvider - name: str - max_tokens: int - embedding_dimensions: int = 1536 # Default for OpenAI ada-002 - - -# Model configurations -MODEL_CONFIGS = { - # OpenAI Models - "gpt-3.5-turbo": ModelConfig( - provider=ModelProvider.OPENAI, - name="gpt-3.5-turbo", - max_tokens=4096, - embedding_dimensions=1536, - ), - "gpt-3.5-turbo-16k": ModelConfig( - provider=ModelProvider.OPENAI, - name="gpt-3.5-turbo-16k", - max_tokens=16384, - embedding_dimensions=1536, - ), - "gpt-4": ModelConfig( - provider=ModelProvider.OPENAI, - name="gpt-4", - max_tokens=8192, - embedding_dimensions=1536, - ), - "gpt-4-32k": ModelConfig( - provider=ModelProvider.OPENAI, - name="gpt-4-32k", - max_tokens=32768, - embedding_dimensions=1536, - ), - "gpt-4o": ModelConfig( - provider=ModelProvider.OPENAI, - name="gpt-4o", - max_tokens=128000, - embedding_dimensions=1536, - ), - "gpt-4o-mini": ModelConfig( - provider=ModelProvider.OPENAI, - name="gpt-4o-mini", - max_tokens=128000, - embedding_dimensions=1536, - ), - # Newer reasoning models - "o1": ModelConfig( - provider=ModelProvider.OPENAI, - name="o1", - max_tokens=200000, - embedding_dimensions=1536, - ), - "o1-mini": ModelConfig( - provider=ModelProvider.OPENAI, - name="o1-mini", - max_tokens=128000, - embedding_dimensions=1536, - ), - "o3-mini": ModelConfig( - provider=ModelProvider.OPENAI, - name="o3-mini", - 
max_tokens=200000, - embedding_dimensions=1536, - ), - # Embedding models - "text-embedding-ada-002": ModelConfig( - provider=ModelProvider.OPENAI, - name="text-embedding-ada-002", - max_tokens=8191, - embedding_dimensions=1536, - ), - "text-embedding-3-small": ModelConfig( - provider=ModelProvider.OPENAI, - name="text-embedding-3-small", - max_tokens=8191, - embedding_dimensions=1536, - ), - "text-embedding-3-large": ModelConfig( - provider=ModelProvider.OPENAI, - name="text-embedding-3-large", - max_tokens=8191, - embedding_dimensions=3072, - ), - # Anthropic Models - "claude-3-opus-20240229": ModelConfig( - provider=ModelProvider.ANTHROPIC, - name="claude-3-opus-20240229", - max_tokens=200000, - embedding_dimensions=1536, - ), - "claude-3-sonnet-20240229": ModelConfig( - provider=ModelProvider.ANTHROPIC, - name="claude-3-sonnet-20240229", - max_tokens=200000, - embedding_dimensions=1536, - ), - "claude-3-haiku-20240307": ModelConfig( - provider=ModelProvider.ANTHROPIC, - name="claude-3-haiku-20240307", - max_tokens=200000, - embedding_dimensions=1536, - ), - "claude-3-5-sonnet-20240620": ModelConfig( - provider=ModelProvider.ANTHROPIC, - name="claude-3-5-sonnet-20240620", - max_tokens=200000, - embedding_dimensions=1536, - ), - # Latest Anthropic Models - "claude-3-7-sonnet-20250219": ModelConfig( - provider=ModelProvider.ANTHROPIC, - name="claude-3-7-sonnet-20250219", - max_tokens=200000, - embedding_dimensions=1536, - ), - "claude-3-5-sonnet-20241022": ModelConfig( - provider=ModelProvider.ANTHROPIC, - name="claude-3-5-sonnet-20241022", - max_tokens=200000, - embedding_dimensions=1536, - ), - "claude-3-5-haiku-20241022": ModelConfig( - provider=ModelProvider.ANTHROPIC, - name="claude-3-5-haiku-20241022", - max_tokens=200000, - embedding_dimensions=1536, - ), - # Convenience aliases - "claude-3-7-sonnet-latest": ModelConfig( - provider=ModelProvider.ANTHROPIC, - name="claude-3-7-sonnet-20250219", - max_tokens=200000, - embedding_dimensions=1536, - ), - "claude-3-5-sonnet-latest": ModelConfig( - provider=ModelProvider.ANTHROPIC, - name="claude-3-5-sonnet-20241022", - max_tokens=200000, - embedding_dimensions=1536, - ), - "claude-3-5-haiku-latest": ModelConfig( - provider=ModelProvider.ANTHROPIC, - name="claude-3-5-haiku-20241022", - max_tokens=200000, - embedding_dimensions=1536, - ), - "claude-3-opus-latest": ModelConfig( - provider=ModelProvider.ANTHROPIC, - name="claude-3-opus-20240229", - max_tokens=200000, - embedding_dimensions=1536, - ), -} - - def get_model_config(model_name: str) -> ModelConfig: """Get configuration for a model""" if model_name in MODEL_CONFIGS: diff --git a/agent_memory_server/main.py b/agent_memory_server/main.py index bb4a715..2488f87 100644 --- a/agent_memory_server/main.py +++ b/agent_memory_server/main.py @@ -8,10 +8,9 @@ from agent_memory_server import __version__ from agent_memory_server.api import router as memory_router from agent_memory_server.auth import verify_auth_config -from agent_memory_server.config import settings +from agent_memory_server.config import MODEL_CONFIGS, ModelProvider, settings from agent_memory_server.docket_tasks import register_tasks from agent_memory_server.healthcheck import router as health_router -from agent_memory_server.llms import MODEL_CONFIGS, ModelProvider from agent_memory_server.logging import get_logger from agent_memory_server.utils.redis import ( _redis_pool as connection_pool, diff --git a/agent_memory_server/models.py b/agent_memory_server/models.py index 54c09a8..01c240b 100644 --- 
a/agent_memory_server/models.py +++ b/agent_memory_server/models.py @@ -389,6 +389,14 @@ class WorkingMemoryResponse(WorkingMemory): default=None, description="Percentage until auto-summarization triggers (0-100, reaches 100% at summarization threshold)", ) + new_session: bool | None = Field( + default=None, + description="True if session was created, False if existing session was found, None if not applicable", + ) + unsaved: bool | None = Field( + default=None, + description="True if this session data has not been persisted to Redis yet (deprecated behavior for old clients)", + ) class WorkingMemoryRequest(BaseModel): diff --git a/agent_memory_server/vectorstore_factory.py b/agent_memory_server/vectorstore_factory.py index d3f1ff2..efde800 100644 --- a/agent_memory_server/vectorstore_factory.py +++ b/agent_memory_server/vectorstore_factory.py @@ -46,7 +46,8 @@ def create_embeddings() -> Embeddings: An Embeddings instance """ embedding_config = settings.embedding_model_config - provider = embedding_config.get("provider", "openai") + # Only support ModelConfig objects + provider = embedding_config.provider if embedding_config else "openai" if provider == "openai": try: diff --git a/docs/memory-integration-patterns.md b/docs/memory-integration-patterns.md index 6bdf9ee..efe9ef2 100644 --- a/docs/memory-integration-patterns.md +++ b/docs/memory-integration-patterns.md @@ -59,29 +59,40 @@ if response.choices[0].message.tool_calls: ```python class LLMMemoryAgent: - def __init__(self, memory_url: str, session_id: str, user_id: str): + def __init__(self, memory_url: str, session_id: str, user_id: str, model_name: str = "gpt-4o"): self.memory_client = MemoryAPIClient(base_url=memory_url) self.openai_client = openai.AsyncOpenAI() self.session_id = session_id self.user_id = user_id - self.conversation_history = [] + self.model_name = model_name async def chat(self, user_message: str) -> str: - # Add user message to conversation - self.conversation_history.append({ - "role": "user", - "content": user_message - }) + # Get or create working memory session for conversation history + created, working_memory = await self.memory_client.get_or_create_working_memory( + session_id=self.session_id, + model_name=self.model_name, + user_id=self.user_id + ) - # Get memory tools + # Get conversation context that includes relevant long-term memories + context = await self.memory_client.memory_prompt( + query=user_message, + session_id=self.session_id, + long_term_search={ + "text": user_message, + "filters": {"user_id": {"eq": self.user_id}}, + "limit": 5 + } + ) + + # Get memory tools for the LLM tools = MemoryAPIClient.get_all_memory_tool_schemas() - # Generate response with memory tools + # Generate response with memory tools and context response = await self.openai_client.chat.completions.create( - model="gpt-4o", - messages=[ - {"role": "system", "content": "You are a helpful assistant with persistent memory. 
Remember important user information and retrieve relevant context when needed."}, - *self.conversation_history + model=self.model_name, + messages=context.messages + [ + {"role": "user", "content": user_message} ], tools=tools ) @@ -97,10 +108,21 @@ class LLMMemoryAgent: ) assistant_message = response.choices[0].message.content - self.conversation_history.append({ - "role": "assistant", - "content": assistant_message - }) + + # Store the conversation turn in working memory + from agent_memory_client.models import WorkingMemory, MemoryMessage + + await self.memory_client.set_working_memory( + session_id=self.session_id, + working_memory=WorkingMemory( + session_id=self.session_id, + messages=[ + MemoryMessage(role="user", content=user_message), + MemoryMessage(role="assistant", content=assistant_message) + ], + user_id=self.user_id + ) + ) return assistant_message @@ -108,7 +130,8 @@ class LLMMemoryAgent: agent = LLMMemoryAgent( memory_url="http://localhost:8000", session_id="alice_chat", - user_id="alice" + user_id="alice", + model_name="gpt-4o" ) # First conversation @@ -235,8 +258,7 @@ class CodeDrivenAgent: session_id: str ) -> str: # 1. Get working memory session (creates if doesn't exist) - result = await self.memory_client.get_or_create_working_memory(session_id) - working_memory = result.memory + created, working_memory = await self.memory_client.get_or_create_working_memory(session_id) # 2. Search for relevant context using session ID context_search = await self.memory_client.memory_prompt( @@ -344,8 +366,7 @@ results = await asyncio.gather(*search_tasks) async def get_enriched_context(user_query: str, user_id: str, session_id: str): """Get context that includes both working memory and relevant long-term memories""" # First, get the working memory session (creates if doesn't exist) - result = await client.get_or_create_working_memory(session_id) - working_memory = result.memory + created, working_memory = await client.get_or_create_working_memory(session_id) # Then use memory_prompt with session ID return await client.memory_prompt( @@ -501,8 +522,7 @@ class AutoLearningAgent: """Process conversation with automatic learning""" # 1. Get working memory session (creates if doesn't exist) - result = await self.memory_client.get_or_create_working_memory(session_id) - working_memory = result.memory + created, working_memory = await self.memory_client.get_or_create_working_memory(session_id) # 2. Get existing context for better responses context = await self.memory_client.memory_prompt( @@ -651,8 +671,7 @@ class HybridMemoryAgent: async def chat(self, user_message: str, user_id: str, session_id: str) -> str: # 1. Get working memory session (creates if doesn't exist) - result = await self.memory_client.get_or_create_working_memory(session_id) - working_memory = result.memory + created, working_memory = await self.memory_client.get_or_create_working_memory(session_id) # 2. 
Code-driven: Get relevant context context = await self.memory_client.memory_prompt( diff --git a/docs/python-sdk.md b/docs/python-sdk.md index 6fbc108..62bddd3 100644 --- a/docs/python-sdk.md +++ b/docs/python-sdk.md @@ -335,12 +335,12 @@ conversation = { await client.set_working_memory("session-123", conversation) # Retrieve or create working memory -result = await client.get_or_create_working_memory("session-123") -if result.created: +created, memory = await client.get_or_create_working_memory("session-123") +if created: print("Created new session") else: print("Found existing session") -print(f"Session has {len(result.memory.messages)} messages") +print(f"Session has {len(memory.messages)} messages") ``` ## Memory-Enhanced Conversations diff --git a/examples/memory_editing_agent.py b/examples/memory_editing_agent.py index c43c49f..644f3fb 100644 --- a/examples/memory_editing_agent.py +++ b/examples/memory_editing_agent.py @@ -456,13 +456,12 @@ async def _generate_response( """Generate a response using the LLM with conversation context.""" # Get working memory for context client = await self.get_client() - result_obj = await client.get_or_create_working_memory( + created, working_memory = await client.get_or_create_working_memory( session_id=session_id, namespace=self._get_namespace(user_id), model_name="gpt-4o-mini", user_id=user_id, ) - working_memory = result_obj.memory context_messages = working_memory.messages diff --git a/examples/travel_agent.py b/examples/travel_agent.py index 97ca9ff..fa233fc 100644 --- a/examples/travel_agent.py +++ b/examples/travel_agent.py @@ -257,12 +257,11 @@ async def cleanup(self): async def _get_working_memory(self, session_id: str, user_id: str) -> WorkingMemory: """Get working memory for a session, creating it if it doesn't exist.""" client = await self.get_client() - result_obj = await client.get_or_create_working_memory( + created, result = await client.get_or_create_working_memory( session_id=session_id, namespace=self._get_namespace(user_id), model_name="gpt-4o-mini", # Controls token-based truncation ) - result = result_obj.memory return WorkingMemory(**result.model_dump()) async def _search_web(self, query: str) -> str: diff --git a/tests/integration/test_vectorstore_factory_integration.py b/tests/integration/test_vectorstore_factory_integration.py index b8bb71f..3ae56a9 100644 --- a/tests/integration/test_vectorstore_factory_integration.py +++ b/tests/integration/test_vectorstore_factory_integration.py @@ -10,6 +10,7 @@ import pytest from langchain_core.embeddings import Embeddings +from agent_memory_server.config import ModelConfig, ModelProvider from agent_memory_server.vectorstore_factory import ( _import_and_call_factory, create_embeddings, @@ -89,8 +90,13 @@ class TestEmbeddingsCreation: def test_create_openai_embeddings(self, mock_settings): """Test OpenAI embeddings creation.""" - # Configure mock settings - mock_settings.embedding_model_config = {"provider": "openai"} + # Configure mock settings with ModelConfig object + mock_settings.embedding_model_config = ModelConfig( + provider=ModelProvider.OPENAI, + name="text-embedding-3-small", + max_tokens=8191, + embedding_dimensions=1536, + ) mock_settings.embedding_model = "text-embedding-3-small" mock_settings.openai_api_key = "test-key" @@ -107,7 +113,12 @@ def test_create_openai_embeddings(self, mock_settings): def test_create_embeddings_unsupported_provider(self, mock_settings): """Test embeddings creation with unsupported provider.""" - mock_settings.embedding_model_config = 
{"provider": "unsupported"} + # Create a mock model config with unsupported provider + mock_config = Mock() + mock_config.provider = ( + "unsupported" # Set directly as string, bypassing enum validation + ) + mock_settings.embedding_model_config = mock_config with pytest.raises(ValueError, match="Unsupported embedding provider"): create_embeddings() diff --git a/tests/test_api.py b/tests/test_api.py index 7b9a9d8..436c8a0 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -434,11 +434,28 @@ async def test_delete_memory(self, client, session): response = await client.get( f"/v1/working-memory/{session_id}?namespace=test-namespace&user_id=test-user" ) + # Should return 200 with unsaved session (deprecated behavior for old clients) assert response.status_code == 200 + data = response.json() + assert data["unsaved"] is True # Not persisted (deprecated behavior) + assert len(data["messages"]) == 0 # Empty session + assert len(data["memories"]) == 0 + + @pytest.mark.asyncio + async def test_get_nonexistent_session_with_new_client_returns_404(self, client): + """Test that new clients (with version header) get 404 for missing sessions""" + # Simulate new client by sending version header + headers = {"X-Client-Version": "0.12.0"} + + response = await client.get( + "/v1/working-memory/nonexistent-session?namespace=test-namespace&user_id=test-user", + headers=headers, + ) - # Should return empty working memory after deletion + # Should return 404 for proper REST behavior + assert response.status_code == 404 data = response.json() - assert len(data["messages"]) == 0 + assert "not found" in data["detail"].lower() @pytest.mark.requires_api_keys diff --git a/tests/test_client_api.py b/tests/test_client_api.py index 63df23c..d69bec1 100644 --- a/tests/test_client_api.py +++ b/tests/test_client_api.py @@ -67,9 +67,15 @@ async def memory_test_client( memory_app: FastAPI, ) -> AsyncGenerator[MemoryAPIClient, None]: """Create a memory client that uses the test FastAPI app.""" + from agent_memory_client import __version__ + async with AsyncClient( transport=ASGITransport(app=memory_app), base_url="http://test", + headers={ + "User-Agent": f"agent-memory-client/{__version__}", + "X-Client-Version": __version__, + }, ) as http_client: # Create the memory client with our test http client config = MemoryClientConfig( @@ -156,15 +162,18 @@ async def test_session_lifecycle(memory_test_client: MemoryAPIClient): response = await memory_test_client.delete_working_memory(session_id) assert response.status == "ok" - # Verify it's gone by mocking a 404 response + # Verify session is gone - new proper REST behavior returns 404 for missing sessions with patch( "agent_memory_server.working_memory.get_working_memory" ) as mock_get_memory: mock_get_memory.return_value = None - # This should not raise an error anymore since the unified API returns empty working memory instead of 404 - session = await memory_test_client.get_working_memory(session_id) - assert len(session.messages) == 0 # Should return empty working memory + # Should raise MemoryNotFoundError (404) since session was deleted + import pytest + from agent_memory_client.exceptions import MemoryNotFoundError + + with pytest.raises(MemoryNotFoundError): + await memory_test_client.get_working_memory(session_id) @pytest.mark.asyncio diff --git a/tests/test_client_enhancements.py b/tests/test_client_enhancements.py index 8fd410c..4ad934b 100644 --- a/tests/test_client_enhancements.py +++ b/tests/test_client_enhancements.py @@ -10,7 +10,6 @@ MemoryRecordResult, 
MemoryRecordResults, MemoryTypeEnum, - WorkingMemoryGetOrCreateResponse, WorkingMemoryResponse, ) from fastapi import FastAPI @@ -78,10 +77,8 @@ async def test_promote_working_memories_to_long_term(self, enhanced_test_client) user_id=None, ) - # Mock the get_or_create response - get_or_create_response = WorkingMemoryGetOrCreateResponse( - memory=working_memory_response, created=False - ) + # Mock the get_or_create response - now returns (created, memory) tuple + get_or_create_response = (False, working_memory_response) with ( patch.object( @@ -131,9 +128,7 @@ async def test_promote_specific_memory_ids(self, enhanced_test_client): ) # Mock the get_or_create response - get_or_create_response = WorkingMemoryGetOrCreateResponse( - memory=working_memory_response, created=False - ) + get_or_create_response = (False, working_memory_response) with ( patch.object( @@ -173,9 +168,7 @@ async def test_promote_no_memories(self, enhanced_test_client): ) # Mock the get_or_create response - get_or_create_response = WorkingMemoryGetOrCreateResponse( - memory=working_memory_response, created=False - ) + get_or_create_response = (False, working_memory_response) with patch.object( enhanced_test_client, "get_or_create_working_memory" @@ -433,9 +426,7 @@ async def test_update_working_memory_data_merge(self, enhanced_test_client): user_id=None, ) - get_or_create_response = WorkingMemoryGetOrCreateResponse( - memory=existing_memory, created=False - ) + get_or_create_response = (False, existing_memory) with ( patch.object( @@ -478,9 +469,7 @@ async def test_update_working_memory_data_replace(self, enhanced_test_client): user_id=None, ) - get_or_create_response = WorkingMemoryGetOrCreateResponse( - memory=existing_memory, created=False - ) + get_or_create_response = (False, existing_memory) with ( patch.object( @@ -526,9 +515,7 @@ async def test_update_working_memory_data_deep_merge(self, enhanced_test_client) ) as mock_get, patch.object(enhanced_test_client, "put_working_memory") as mock_put, ): - mock_get.return_value = WorkingMemoryGetOrCreateResponse( - memory=existing_memory, created=False - ) + mock_get.return_value = (False, existing_memory) mock_put.return_value = existing_memory updates = { @@ -580,9 +567,7 @@ async def test_append_messages_to_working_memory(self, enhanced_test_client): ) as mock_get, patch.object(enhanced_test_client, "put_working_memory") as mock_put, ): - mock_get.return_value = WorkingMemoryGetOrCreateResponse( - memory=existing_memory, created=False - ) + mock_get.return_value = (False, existing_memory) mock_put.return_value = existing_memory await enhanced_test_client.append_messages_to_working_memory( diff --git a/tests/test_extraction.py b/tests/test_extraction.py index 0f4b4ab..10f0d1d 100644 --- a/tests/test_extraction.py +++ b/tests/test_extraction.py @@ -177,6 +177,12 @@ class TestTopicExtractionIntegration: async def test_bertopic_integration(self): """Integration test for BERTopic topic extraction (skipped if not available)""" + # Check if bertopic is available + try: + import bertopic # noqa: F401 + except ImportError: + pytest.skip("bertopic not available") + # Save and set topic_model_source original_source = settings.topic_model_source original_enable_topic_extraction = settings.enable_topic_extraction
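
Below is a minimal usage sketch of the client-facing changes in this diff: `get_or_create_working_memory` returning a `(created, memory)` tuple instead of `WorkingMemoryGetOrCreateResponse`, and the new 404 behavior for clients that send the `X-Client-Version` header. It assumes the 0.12.0 `agent-memory-client` package as shown above; the server URL, session ID, and user ID are hypothetical, and `MemoryClientConfig` defaults beyond `base_url` are assumed from the surrounding code and tests.

```python
# Sketch only: exercises the 0.12.0 client API changes shown in this diff.
import asyncio
import logging

from agent_memory_client import MemoryAPIClient, MemoryClientConfig
from agent_memory_client.exceptions import MemoryNotFoundError  # raised on 404 per the updated tests

logging.basicConfig(level=logging.INFO)


async def main() -> None:
    # Hypothetical server URL; other MemoryClientConfig fields use their defaults (assumption).
    config = MemoryClientConfig(base_url="http://localhost:8000")
    client = MemoryAPIClient(config)

    # The client now sends X-Client-Version automatically, so a plain GET on a
    # missing session raises a not-found error instead of returning an unsaved
    # empty session (the deprecated pre-0.12.0 behavior).
    try:
        await client.get_working_memory("chat_session_123")
    except MemoryNotFoundError:
        logging.info("Session does not exist yet")

    # get_or_create_working_memory now returns a (created, memory) tuple.
    created, memory = await client.get_or_create_working_memory(
        session_id="chat_session_123",
        user_id="user_456",
    )
    logging.info("Created new session" if created else "Found existing session")
    logging.info(f"Session has {len(memory.messages)} messages")

    await client.close()


asyncio.run(main())
```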