diff --git a/src/strands/session/__init__.py b/src/strands/session/__init__.py index 7b5310190..fbd55706d 100644 --- a/src/strands/session/__init__.py +++ b/src/strands/session/__init__.py @@ -3,6 +3,7 @@ This module provides session management functionality. """ +from .dynamodb_session_manager import DynamoDBSessionManager from .file_session_manager import FileSessionManager from .repository_session_manager import RepositorySessionManager from .s3_session_manager import S3SessionManager @@ -13,6 +14,7 @@ "FileSessionManager", "RepositorySessionManager", "S3SessionManager", + "DynamoDBSessionManager", "SessionManager", "SessionRepository", ] diff --git a/src/strands/session/dynamodb_session_manager.py b/src/strands/session/dynamodb_session_manager.py new file mode 100644 index 000000000..9c559b7cc --- /dev/null +++ b/src/strands/session/dynamodb_session_manager.py @@ -0,0 +1,352 @@ +"""DynamoDB-based session manager for cloud storage.""" + +import logging +from decimal import Decimal +from typing import Any, List, Optional + +import boto3 +from boto3.dynamodb.types import TypeDeserializer, TypeSerializer +from botocore.config import Config as BotocoreConfig +from botocore.exceptions import ClientError + +from .. import _identifier +from ..types.exceptions import SessionException +from ..types.session import Session, SessionAgent, SessionMessage +from .repository_session_manager import RepositorySessionManager +from .session_repository import SessionRepository + +logger = logging.getLogger(__name__) + + +def _convert_decimals_to_native_types(obj: Any) -> Any: + """Convert Decimal objects to native Python types recursively. + + DynamoDB's TypeDeserializer returns Decimal objects for numeric values, + but other AWS services expect native Python int/float types. + """ + if isinstance(obj, Decimal): + # Convert to int if it's a whole number, otherwise float + return int(obj) if obj % 1 == 0 else float(obj) + elif isinstance(obj, dict): + return {key: _convert_decimals_to_native_types(value) for key, value in obj.items()} + elif isinstance(obj, list): + return [_convert_decimals_to_native_types(item) for item in obj] + else: + return obj + + +class DynamoDBSessionManager(RepositorySessionManager, SessionRepository): + """DynamoDB-based session manager for cloud storage. + + Uses a single table design with the following structure: + - PK (HASH): session_ + - SK (RANGE): session | agent_ | agent_#message_ + + Example: + ``` + ┌─────────────────┬──────────────────────────┬─────────────────┬──────────────────┐ + │ PK │ SK │ entity_type │ data │ + ├─────────────────┼──────────────────────────┼─────────────────┼──────────────────┤ + │ session_abc123 │ session │ SESSION │ {session_json} │ + │ session_abc123 │ agent_agent1 │ AGENT │ {agent_json} │ + │ session_abc123 │ agent_agent1#message_0 │ MESSAGE │ {message_json} │ + │ session_abc123 │ agent_agent1#message_1 │ MESSAGE │ {message_json} │ + └─────────────────┴──────────────────────────┴─────────────────┴──────────────────┘ + ``` + """ + + def __init__( + self, + session_id: str, + table_name: str, + boto_session: Optional[boto3.Session] = None, + boto_client_config: Optional[BotocoreConfig] = None, + region_name: Optional[str] = None, + **kwargs: Any, + ): + """Initialize DynamoDBSessionManager. + + Args: + session_id: ID for the session + table_name: DynamoDB table name + boto_session: Optional boto3 session + boto_client_config: Optional boto3 client configuration + region_name: AWS region for DynamoDB + **kwargs: Additional keyword arguments for future extensibility. + """ + self.table_name = table_name + + session = boto_session or boto3.Session(region_name=region_name) + + # Add strands-agents to the request user agent + if boto_client_config: + existing_user_agent = getattr(boto_client_config, "user_agent_extra", None) + # Append 'strands-agents' to existing user_agent_extra or set it if not present + if existing_user_agent: + new_user_agent = f"{existing_user_agent} strands-agents" + else: + new_user_agent = "strands-agents" + client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent)) + else: + client_config = BotocoreConfig(user_agent_extra="strands-agents") + + self.client = session.client(service_name="dynamodb", config=client_config) + self.serializer = TypeSerializer() + self.deserializer = TypeDeserializer() + super().__init__(session_id=session_id, session_repository=self) + + def _validate_dynamodb_id(self, id_: str, id_type: _identifier.Identifier) -> str: + """Validate ID for DynamoDB key structure. + + Args: + id_: ID to validate + id_type: Type of ID for error messages + + Returns: + Validated ID + + Raises: + ValueError: If ID contains characters that would break DynamoDB key structure + """ + if "_" in id_ or "#" in id_: + raise ValueError(f"{id_type.value}_id={id_} | id cannot contain underscore (_) or hash (#) characters") + return id_ + + def _get_session_pk(self, session_id: str) -> str: + """Get session partition key.""" + session_id = self._validate_dynamodb_id(session_id, _identifier.Identifier.SESSION) + return f"session_{session_id}" + + def _get_session_sk(self) -> str: + """Get session sort key.""" + return "session" + + def _get_agent_sk(self, agent_id: str) -> str: + """Get agent sort key.""" + agent_id = self._validate_dynamodb_id(agent_id, _identifier.Identifier.AGENT) + return f"agent_{agent_id}" + + def _get_message_sk(self, agent_id: str, message_id: int) -> str: + """Get message sort key.""" + if not isinstance(message_id, int): + raise ValueError(f"message_id=<{message_id}> | message id must be an integer") + agent_id = self._validate_dynamodb_id(agent_id, _identifier.Identifier.AGENT) + return f"agent_{agent_id}#message_{message_id}" + + def create_session(self, session: Session, **kwargs: Any) -> Session: + """Create a new session in DynamoDB.""" + pk = self._get_session_pk(session.session_id) + sk = self._get_session_sk() + + try: + self.client.put_item( + TableName=self.table_name, + Item={ + "PK": {"S": pk}, + "SK": {"S": sk}, + "entity_type": {"S": "SESSION"}, + "data": self.serializer.serialize(session.to_dict()), + }, + ConditionExpression="attribute_not_exists(PK)", + ) + return session + except ClientError as e: + if e.response["Error"]["Code"] == "ConditionalCheckFailedException": + raise SessionException(f"Session {session.session_id} already exists") from e + raise SessionException(f"DynamoDB error creating session: {e}") from e + + def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: + """Read session data from DynamoDB.""" + pk = self._get_session_pk(session_id) + sk = self._get_session_sk() + + try: + response = self.client.get_item(TableName=self.table_name, Key={"PK": {"S": pk}, "SK": {"S": sk}}) + if "Item" not in response: + return None + + data = self.deserializer.deserialize(response["Item"]["data"]) + data = _convert_decimals_to_native_types(data) + return Session.from_dict(data) + except ClientError as e: + raise SessionException(f"DynamoDB error reading session: {e}") from e + + def delete_session(self, session_id: str, **kwargs: Any) -> None: + """Delete session and all associated data from DynamoDB.""" + pk = self._get_session_pk(session_id) + + try: + # Query all items for this session + response = self.client.query( + TableName=self.table_name, + KeyConditionExpression="PK = :pk", + ExpressionAttributeValues={":pk": {"S": pk}}, + ) + + if not response["Items"]: + raise SessionException(f"Session {session_id} does not exist") + + # Delete all items in batches + for i in range(0, len(response["Items"]), 25): + batch = response["Items"][i : i + 25] + delete_requests = [{"DeleteRequest": {"Key": {"PK": item["PK"], "SK": item["SK"]}}} for item in batch] + self.client.batch_write_item(RequestItems={self.table_name: delete_requests}) + + except ClientError as e: + raise SessionException(f"DynamoDB error deleting session: {e}") from e + + def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: Any) -> None: + """Create a new agent in DynamoDB.""" + pk = self._get_session_pk(session_id) + sk = self._get_agent_sk(session_agent.agent_id) + + try: + self.client.put_item( + TableName=self.table_name, + Item={ + "PK": {"S": pk}, + "SK": {"S": sk}, + "entity_type": {"S": "AGENT"}, + "data": self.serializer.serialize(session_agent.to_dict()), + }, + ) + except ClientError as e: + raise SessionException(f"DynamoDB error creating agent: {e}") from e + + def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]: + """Read agent data from DynamoDB.""" + pk = self._get_session_pk(session_id) + sk = self._get_agent_sk(agent_id) + + try: + response = self.client.get_item(TableName=self.table_name, Key={"PK": {"S": pk}, "SK": {"S": sk}}) + if "Item" not in response: + return None + + data = self.deserializer.deserialize(response["Item"]["data"]) + data = _convert_decimals_to_native_types(data) + return SessionAgent.from_dict(data) + except ClientError as e: + raise SessionException(f"DynamoDB error reading agent: {e}") from e + + def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: Any) -> None: + """Update agent data in DynamoDB.""" + previous_agent = self.read_agent(session_id=session_id, agent_id=session_agent.agent_id) + if previous_agent is None: + raise SessionException(f"Agent {session_agent.agent_id} in session {session_id} does not exist") + + # Preserve creation timestamp + session_agent.created_at = previous_agent.created_at + + pk = self._get_session_pk(session_id) + sk = self._get_agent_sk(session_agent.agent_id) + + try: + self.client.put_item( + TableName=self.table_name, + Item={ + "PK": {"S": pk}, + "SK": {"S": sk}, + "entity_type": {"S": "AGENT"}, + "data": self.serializer.serialize(session_agent.to_dict()), + }, + ) + except ClientError as e: + raise SessionException(f"DynamoDB error updating agent: {e}") from e + + def create_message(self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any) -> None: + """Create a new message in DynamoDB.""" + pk = self._get_session_pk(session_id) + sk = self._get_message_sk(agent_id, session_message.message_id) + + try: + self.client.put_item( + TableName=self.table_name, + Item={ + "PK": {"S": pk}, + "SK": {"S": sk}, + "entity_type": {"S": "MESSAGE"}, + "data": self.serializer.serialize(session_message.to_dict()), + }, + ) + except ClientError as e: + raise SessionException(f"DynamoDB error creating message: {e}") from e + + def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> Optional[SessionMessage]: + """Read message data from DynamoDB.""" + pk = self._get_session_pk(session_id) + sk = self._get_message_sk(agent_id, message_id) + + try: + response = self.client.get_item(TableName=self.table_name, Key={"PK": {"S": pk}, "SK": {"S": sk}}) + if "Item" not in response: + return None + + data = self.deserializer.deserialize(response["Item"]["data"]) + data = _convert_decimals_to_native_types(data) + return SessionMessage.from_dict(data) + except ClientError as e: + raise SessionException(f"DynamoDB error reading message: {e}") from e + + def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any) -> None: + """Update message data in DynamoDB.""" + previous_message = self.read_message( + session_id=session_id, agent_id=agent_id, message_id=session_message.message_id + ) + if previous_message is None: + raise SessionException(f"Message {session_message.message_id} does not exist") + + # Preserve creation timestamp + session_message.created_at = previous_message.created_at + + pk = self._get_session_pk(session_id) + sk = self._get_message_sk(agent_id, session_message.message_id) + + try: + self.client.put_item( + TableName=self.table_name, + Item={ + "PK": {"S": pk}, + "SK": {"S": sk}, + "entity_type": {"S": "MESSAGE"}, + "data": self.serializer.serialize(session_message.to_dict()), + }, + ) + except ClientError as e: + raise SessionException(f"DynamoDB error updating message: {e}") from e + + def list_messages( + self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any + ) -> List[SessionMessage]: + """List messages for an agent with pagination from DynamoDB.""" + pk = self._get_session_pk(session_id) + agent_prefix = f"agent_{self._validate_dynamodb_id(agent_id, _identifier.Identifier.AGENT)}#message_" + + try: + # Query messages for this agent + response = self.client.query( + TableName=self.table_name, + KeyConditionExpression="PK = :pk AND begins_with(SK, :sk_prefix)", + ExpressionAttributeValues={":pk": {"S": pk}, ":sk_prefix": {"S": agent_prefix}}, + ) + + # Sort by message ID (extracted from SK) + items = sorted(response["Items"], key=lambda x: int(x["SK"]["S"].split("_")[-1])) + + # Apply pagination + if limit is not None: + items = items[offset : offset + limit] + else: + items = items[offset:] + + # Convert to SessionMessage objects + messages = [] + for item in items: + data = self.deserializer.deserialize(item["data"]) + data = _convert_decimals_to_native_types(data) + messages.append(SessionMessage.from_dict(data)) + + return messages + + except ClientError as e: + raise SessionException(f"DynamoDB error listing messages: {e}") from e diff --git a/tests/strands/session/test_dynamodb_session_manager.py b/tests/strands/session/test_dynamodb_session_manager.py new file mode 100644 index 000000000..26be99f3d --- /dev/null +++ b/tests/strands/session/test_dynamodb_session_manager.py @@ -0,0 +1,453 @@ +"""Tests for DynamoDBSessionManager.""" + +import time +from decimal import Decimal + +import boto3 +import pytest +from botocore.config import Config as BotocoreConfig +from moto import mock_aws + +from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager +from strands.session.dynamodb_session_manager import DynamoDBSessionManager, \ + _convert_decimals_to_native_types +from strands.types.content import ContentBlock +from strands.types.exceptions import SessionException +from strands.types.session import Session, SessionAgent, SessionMessage, SessionType + + +@pytest.fixture +def mocked_aws(): + """ + Mock all AWS interactions + Requires you to create your own boto3 clients + """ + with mock_aws(): + yield + + +@pytest.fixture(scope="function") +def dynamodb_table(mocked_aws): + """DynamoDB table for testing.""" + dynamodb = boto3.resource("dynamodb", region_name="us-west-2") + client = boto3.client("dynamodb", region_name="us-west-2") + table_name = "test-session-table" + + dynamodb.create_table( + TableName=table_name, + KeySchema=[{"AttributeName": "PK", "KeyType": "HASH"}, {"AttributeName": "SK", "KeyType": "RANGE"}], + AttributeDefinitions=[ + {"AttributeName": "PK", "AttributeType": "S"}, + {"AttributeName": "SK", "AttributeType": "S"}, + ], + BillingMode="PAY_PER_REQUEST", + ) + time.sleep(1) # make sure table is ready + response = client.describe_table(TableName=table_name) + assert response["Table"]["TableStatus"] == "ACTIVE" + + return table_name + + +@pytest.fixture +def dynamodb_manager(mocked_aws, dynamodb_table): + """Create DynamoDBSessionManager with mocked DynamoDB.""" + yield DynamoDBSessionManager(session_id="test", table_name=dynamodb_table, region_name="us-west-2") + + +@pytest.fixture +def sample_session(): + """Create sample session for testing.""" + return Session( + session_id="test-session-123", + session_type=SessionType.AGENT, + ) + + +@pytest.fixture +def sample_agent(): + """Create sample agent for testing.""" + return SessionAgent( + agent_id="test-agent-456", + state={"key": "value"}, + conversation_manager_state=NullConversationManager().get_state(), + ) + + +@pytest.fixture +def sample_message(): + """Create sample message for testing.""" + return SessionMessage.from_message( + message={ + "role": "user", + "content": [ContentBlock(text="test_message")], + }, + index=0, + ) + + +def test_init_dynamodb_session_manager(mocked_aws, dynamodb_table): + session_manager = DynamoDBSessionManager(session_id="test", table_name=dynamodb_table, region_name="us-west-2") + assert "strands-agents" in session_manager.client.meta.config.user_agent_extra + + +def test_init_dynamodb_session_manager_with_config(mocked_aws, dynamodb_table): + session_manager = DynamoDBSessionManager( + session_id="test", table_name=dynamodb_table, boto_client_config=BotocoreConfig(), region_name="us-west-2" + ) + assert "strands-agents" in session_manager.client.meta.config.user_agent_extra + + +def test_init_dynamodb_session_manager_with_existing_user_agent(mocked_aws, dynamodb_table): + session_manager = DynamoDBSessionManager( + session_id="test", + table_name=dynamodb_table, + boto_client_config=BotocoreConfig(user_agent_extra="test"), + region_name="us-west-2", + ) + assert "strands-agents" in session_manager.client.meta.config.user_agent_extra + + +def test_create_session(dynamodb_manager, sample_session): + """Test creating a session in DynamoDB.""" + result = dynamodb_manager.create_session(sample_session) + assert result == sample_session + + # Verify DynamoDB item created + response = dynamodb_manager.client.get_item( + TableName=dynamodb_manager.table_name, + Key={"PK": {"S": f"session_{sample_session.session_id}"}, "SK": {"S": "session"}}, + ) + assert "Item" in response + assert response["Item"]["entity_type"]["S"] == "SESSION" + + data = dynamodb_manager.deserializer.deserialize(response["Item"]["data"]) + assert data["session_id"] == sample_session.session_id + assert data["session_type"] == sample_session.session_type + + +def test_create_session_already_exists(dynamodb_manager, sample_session): + """Test creating a session that already exists.""" + dynamodb_manager.create_session(sample_session) + + with pytest.raises(SessionException): + dynamodb_manager.create_session(sample_session) + + +def test_read_session(dynamodb_manager, sample_session): + """Test reading a session from DynamoDB.""" + # Create session first + dynamodb_manager.create_session(sample_session) + + # Read it back + result = dynamodb_manager.read_session(sample_session.session_id) + + assert result.session_id == sample_session.session_id + assert result.session_type == sample_session.session_type + + +def test_read_nonexistent_session(dynamodb_manager): + """Test reading a session that doesn't exist.""" + result = dynamodb_manager.read_session("nonexistent-session") + assert result is None + + +def test_delete_session(dynamodb_manager, sample_session): + """Test deleting a session from DynamoDB.""" + # Create session first + dynamodb_manager.create_session(sample_session) + + # Verify session exists + response = dynamodb_manager.client.get_item( + TableName=dynamodb_manager.table_name, + Key={"PK": {"S": f"session_{sample_session.session_id}"}, "SK": {"S": "session"}}, + ) + assert "Item" in response + + # Delete session + dynamodb_manager.delete_session(sample_session.session_id) + + # Verify deletion + response = dynamodb_manager.client.get_item( + TableName=dynamodb_manager.table_name, + Key={"PK": {"S": f"session_{sample_session.session_id}"}, "SK": {"S": "session"}}, + ) + assert "Item" not in response + + +def test_create_agent(dynamodb_manager, sample_session, sample_agent): + """Test creating an agent in DynamoDB.""" + dynamodb_manager.create_session(sample_session) + dynamodb_manager.create_agent(sample_session.session_id, sample_agent) + + # Verify DynamoDB item created + response = dynamodb_manager.client.get_item( + TableName=dynamodb_manager.table_name, + Key={"PK": {"S": f"session_{sample_session.session_id}"}, "SK": {"S": f"agent_{sample_agent.agent_id}"}}, + ) + assert "Item" in response + assert response["Item"]["entity_type"]["S"] == "AGENT" + + data = dynamodb_manager.deserializer.deserialize(response["Item"]["data"]) + assert data["agent_id"] == sample_agent.agent_id + assert data["state"] == sample_agent.state + + +def test_read_agent(dynamodb_manager, sample_session, sample_agent): + """Test reading an agent from DynamoDB.""" + # Create session and agent + dynamodb_manager.create_session(sample_session) + dynamodb_manager.create_agent(sample_session.session_id, sample_agent) + + # Read agent + result = dynamodb_manager.read_agent(sample_session.session_id, sample_agent.agent_id) + + assert result.agent_id == sample_agent.agent_id + assert result.state == sample_agent.state + assert isinstance(result.conversation_manager_state.get("removed_message_count"), int) + + +def test_read_nonexistent_agent(dynamodb_manager, sample_session): + """Test reading an agent that doesn't exist.""" + # Create session + dynamodb_manager.create_session(sample_session) + # Read agent + result = dynamodb_manager.read_agent(sample_session.session_id, "nonexistent-agent") + + assert result is None + + +def test_update_agent(dynamodb_manager, sample_session, sample_agent): + """Test updating an agent in DynamoDB.""" + dynamodb_manager.create_session(sample_session) + dynamodb_manager.create_agent(sample_session.session_id, sample_agent) + + sample_agent.state = {"updated": "value"} + dynamodb_manager.update_agent(sample_session.session_id, sample_agent) + + result = dynamodb_manager.read_agent(sample_session.session_id, sample_agent.agent_id) + assert result.state == {"updated": "value"} + + +def test_update_nonexistent_agent(dynamodb_manager, sample_session, sample_agent): + """Test updating an agent that doesn't exist.""" + dynamodb_manager.create_session(sample_session) + + with pytest.raises(SessionException): + dynamodb_manager.update_agent(sample_session.session_id, sample_agent) + + +def test_create_message(dynamodb_manager, sample_session, sample_agent, sample_message): + """Test creating a message in DynamoDB.""" + # Create session and agent + dynamodb_manager.create_session(sample_session) + dynamodb_manager.create_agent(sample_session.session_id, sample_agent) + + # Create message + dynamodb_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Verify DynamoDB item created + response = dynamodb_manager.client.get_item( + TableName=dynamodb_manager.table_name, + Key={ + "PK": {"S": f"session_{sample_session.session_id}"}, + "SK": {"S": f"agent_{sample_agent.agent_id}#message_{sample_message.message_id}"}, + }, + ) + assert "Item" in response + assert response["Item"]["entity_type"]["S"] == "MESSAGE" + + data = dynamodb_manager.deserializer.deserialize(response["Item"]["data"]) + assert data["message_id"] == sample_message.message_id + + +def test_read_message(dynamodb_manager, sample_session, sample_agent, sample_message): + """Test reading a message from DynamoDB.""" + dynamodb_manager.create_session(sample_session) + dynamodb_manager.create_agent(sample_session.session_id, sample_agent) + dynamodb_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + result = dynamodb_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) + assert isinstance(result.message_id, int) + assert result.message_id == sample_message.message_id + assert result.message["role"] == sample_message.message["role"] + assert result.message["content"] == sample_message.message["content"] + + +def test_read_nonexistent_message(dynamodb_manager, sample_session, sample_agent): + """Test reading a message that doesn't exist.""" + # Create session and agent, no message + dynamodb_manager.create_session(sample_session) + dynamodb_manager.create_agent(sample_session.session_id, sample_agent) + + # Read message + result = dynamodb_manager.read_message(sample_session.session_id, sample_agent.agent_id, 999) + + assert result is None + + +def test_list_messages_all(dynamodb_manager, sample_session, sample_agent): + """Test listing all messages from DynamoDB.""" + # Create session and agent + dynamodb_manager.create_session(sample_session) + dynamodb_manager.create_agent(sample_session.session_id, sample_agent) + + # Create multiple messages + messages = [] + for i in range(5): + message = SessionMessage( + { + "role": "user", + "content": [ContentBlock(text=f"Message {i}")], + }, + i, + ) + messages.append(message) + dynamodb_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # List all messages + result = dynamodb_manager.list_messages(sample_session.session_id, sample_agent.agent_id) + + assert len(result) == 5 + for msg in result: + assert isinstance(msg.message_id, int) + + +def test_list_messages_with_pagination(dynamodb_manager, sample_session, sample_agent): + """Test listing messages with pagination in DynamoDB.""" + # Create session and agent + dynamodb_manager.create_session(sample_session) + dynamodb_manager.create_agent(sample_session.session_id, sample_agent) + + # Create multiple messages + for index in range(10): + message = SessionMessage.from_message( + message={ + "role": "user", + "content": [ContentBlock(text="test_message")], + }, + index=index, + ) + dynamodb_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # List with limit + result = dynamodb_manager.list_messages(sample_session.session_id, sample_agent.agent_id, limit=3) + assert len(result) == 3 + + # List with offset + result = dynamodb_manager.list_messages(sample_session.session_id, sample_agent.agent_id, offset=5) + assert len(result) == 5 + + +def test_update_message(dynamodb_manager, sample_session, sample_agent, sample_message): + """Test updating a message in DynamoDB.""" + dynamodb_manager.create_session(sample_session) + dynamodb_manager.create_agent(sample_session.session_id, sample_agent) + dynamodb_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + sample_message.message["content"] = [ContentBlock(text="Updated content")] + dynamodb_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + result = dynamodb_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) + assert result.message["content"][0]["text"] == "Updated content" + + +def test_update_nonexistent_message(dynamodb_manager, sample_session, sample_agent, sample_message): + """Test updating a message that doesn't exist.""" + dynamodb_manager.create_session(sample_session) + dynamodb_manager.create_agent(sample_session.session_id, sample_agent) + + with pytest.raises(SessionException): + dynamodb_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + +@pytest.mark.parametrize( + "session_id", + [ + "session_with_underscore", + "session#with#hash", + "session_and#both", + ], +) +def test__get_session_pk_invalid_session_id(session_id, dynamodb_manager): + with pytest.raises( + ValueError, match=f"session_id={session_id} | id cannot contain underscore \(_\) or hash \(#\) characters" + ): + dynamodb_manager._get_session_pk(session_id) + + +@pytest.mark.parametrize( + "agent_id", + [ + "agent_with_underscore", + "agent#with#hash", + "agent_and#both", + ], +) +def test__get_agent_sk_invalid_agent_id(agent_id, dynamodb_manager): + with pytest.raises( + ValueError, match=f"agent_id={agent_id} | id cannot contain underscore \(_\) or hash \(#\) characters" + ): + dynamodb_manager._get_agent_sk(agent_id) + + +@pytest.mark.parametrize( + "message_id", + [ + "not_an_int", + None, + [], + ], +) +def test__get_message_sk_invalid_message_id(message_id, dynamodb_manager): + """Test that message_id that is not an integer raises ValueError.""" + with pytest.raises(ValueError, match=r"message_id=<.*> \| message id must be an integer"): + dynamodb_manager._get_message_sk("agent1", message_id) + + +def test_convert_decimals_to_native_types(): + """Test the Decimal conversion utility function.""" + # Test simple Decimal conversion + assert _convert_decimals_to_native_types(Decimal('10')) == 10 + assert _convert_decimals_to_native_types(Decimal('10.5')) == 10.5 + assert _convert_decimals_to_native_types(Decimal('0')) == 0 + + # Test nested dictionary conversion + data = { + 'limit': Decimal('10'), + 'max_length': Decimal('8000'), + 'temperature': Decimal('0.5'), + 'name': 'test', + 'enabled': True, + 'nested': { + 'count': Decimal('42'), + 'ratio': Decimal('3.14') + } + } + + result = _convert_decimals_to_native_types(data) + + assert result['limit'] == 10 + assert isinstance(result['limit'], int) + assert result['max_length'] == 8000 + assert isinstance(result['max_length'], int) + assert result['temperature'] == 0.5 + assert isinstance(result['temperature'], float) + assert result['name'] == 'test' + assert result['enabled'] is True + assert result['nested']['count'] == 42 + assert isinstance(result['nested']['count'], int) + assert result['nested']['ratio'] == 3.14 + assert isinstance(result['nested']['ratio'], float) + + # Test list conversion + list_data = [Decimal('1'), Decimal('2.5'), 'string', {'nested': Decimal('100')}] + result = _convert_decimals_to_native_types(list_data) + + assert result[0] == 1 + assert isinstance(result[0], int) + assert result[1] == 2.5 + assert isinstance(result[1], float) + assert result[2] == 'string' + assert result[3]['nested'] == 100 + assert isinstance(result[3]['nested'], int) \ No newline at end of file