From 520dbf7fe79353ae0e7412e76c68c574e6f08668 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Wed, 22 Oct 2025 12:10:53 -0400 Subject: [PATCH 01/12] feat: add multiagent hooks, add serialize & deserialize function to multiagent base & agent result --- __init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 __init__.py diff --git a/__init__.py b/__init__.py new file mode 100644 index 000000000..e69de29bb From 0beae2743e2c537c147cf8e8fdb5fffcc8f89125 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Wed, 22 Oct 2025 16:14:19 -0400 Subject: [PATCH 02/12] feat: add multiagent session manager, register hooks, fix import issue, rename deserialize function # Conflicts: # src/strands/experimental/agent_config.py --- src/strands/multiagent/base.py | 4 +- src/strands/session/file_session_manager.py | 56 +++++++++++-- .../session/repository_session_manager.py | 17 +++- src/strands/session/s3_session_manager.py | 32 ++++++- src/strands/session/session_manager.py | 84 +++++++++++++++++-- src/strands/types/session.py | 1 + tests/strands/multiagent/test_base.py | 4 +- .../session/test_file_session_manager.py | 49 +++++++++++ .../session/test_s3_session_manager.py | 24 ++++++ 9 files changed, 247 insertions(+), 24 deletions(-) diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 1628a8a9d..f0e7f78f0 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -137,7 +137,7 @@ def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult": metrics = _parse_metrics(data.get("accumulated_metrics", {})) multiagent_result = cls( - status=Status(data.get("status", Status.PENDING.value)), + status=Status(data.get("status")), results=results, accumulated_usage=usage, accumulated_metrics=metrics, @@ -204,7 +204,7 @@ def serialize_state(self) -> dict[str, Any]: """Return a JSON-serializable snapshot of the orchestrator state.""" raise NotImplementedError - def deserialize_state(self, payload: dict[str, Any]) -> None: + def restore_from_session(self, payload: dict[str, Any]) -> None: """Restore orchestrator state from a session dict.""" raise NotImplementedError diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index 491f7ad60..67c07df72 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -5,14 +5,18 @@ import os import shutil import tempfile -from typing import Any, Optional, cast +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any, Optional, cast from .. import _identifier from ..types.exceptions import SessionException -from ..types.session import Session, SessionAgent, SessionMessage +from ..types.session import Session, SessionAgent, SessionMessage, SessionType from .repository_session_manager import RepositorySessionManager from .session_repository import SessionRepository +if TYPE_CHECKING: + from ..multiagent.base import MultiAgentBase + logger = logging.getLogger(__name__) SESSION_PREFIX = "session_" @@ -37,19 +41,26 @@ class FileSessionManager(RepositorySessionManager, SessionRepository): ``` """ - def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs: Any): + def __init__( + self, + session_id: str, + storage_dir: Optional[str] = None, + session_type: SessionType = SessionType.AGENT, + **kwargs: Any, + ): """Initialize FileSession with filesystem storage. Args: session_id: ID for the session. ID is not allowed to contain path separators (e.g., a/b). storage_dir: Directory for local filesystem storage (defaults to temp dir). + session_type: single agent or multiagent. **kwargs: Additional keyword arguments for future extensibility. """ self.storage_dir = storage_dir or os.path.join(tempfile.gettempdir(), "strands/sessions") os.makedirs(self.storage_dir, exist_ok=True) - super().__init__(session_id=session_id, session_repository=self) + super().__init__(session_id=session_id, session_repository=self, session_type=session_type) def _get_session_path(self, session_id: str) -> str: """Get session directory path. @@ -107,8 +118,10 @@ def _read_file(self, path: str) -> dict[str, Any]: def _write_file(self, path: str, data: dict[str, Any]) -> None: """Write JSON file.""" os.makedirs(os.path.dirname(path), exist_ok=True) - with open(path, "w", encoding="utf-8") as f: + tmp = f"{path}.tmp" + with open(tmp, "w", encoding="utf-8", newline="\n") as f: json.dump(data, f, indent=2, ensure_ascii=False) + os.replace(tmp, path) def create_session(self, session: Session, **kwargs: Any) -> Session: """Create a new session.""" @@ -118,7 +131,8 @@ def create_session(self, session: Session, **kwargs: Any) -> Session: # Create directory structure os.makedirs(session_dir, exist_ok=True) - os.makedirs(os.path.join(session_dir, "agents"), exist_ok=True) + if self.session_type == SessionType.AGENT: + os.makedirs(os.path.join(session_dir, "agents"), exist_ok=True) # Write session file session_file = os.path.join(session_dir, "session.json") @@ -239,3 +253,33 @@ def list_messages( messages.append(SessionMessage.from_dict(message_data)) return messages + + def write_multi_agent_json(self, source: "MultiAgentBase") -> None: + """Write multi-agent state to filesystem. + + Args: + source: Multi-agent source object to persist + """ + state = source.serialize_state() + state_path = os.path.join(self._get_session_path(self.session_id), "multi_agent_state.json") + self._write_file(state_path, state) + + # Update session metadata + session_dir = self._get_session_path(self.session.session_id) + session_file = os.path.join(session_dir, "session.json") + with open(session_file, "r", encoding="utf-8") as f: + metadata = json.load(f) + metadata["updated_at"] = datetime.now(timezone.utc).isoformat() + self._write_file(session_file, metadata) + + def read_multi_agent_json(self) -> dict[str, Any]: + """Read multi-agent state from filesystem. + + Returns: + Multi-agent state dictionary or empty dict if not found + """ + state_path = os.path.join(self._get_session_path(self.session_id), "multi_agent_state.json") + if not os.path.exists(state_path): + return {} + state_data = self._read_file(state_path) + return state_data diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index e5075de93..bcdcbe30f 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -24,7 +24,13 @@ class RepositorySessionManager(SessionManager): """Session manager for persisting agents in a SessionRepository.""" - def __init__(self, session_id: str, session_repository: SessionRepository, **kwargs: Any): + def __init__( + self, + session_id: str, + session_repository: SessionRepository, + session_type: SessionType = SessionType.AGENT, + **kwargs: Any, + ): """Initialize the RepositorySessionManager. If no session with the specified session_id exists yet, it will be created @@ -34,22 +40,27 @@ def __init__(self, session_id: str, session_repository: SessionRepository, **kwa session_id: ID to use for the session. A new session with this id will be created if it does not exist in the repository yet session_repository: Underlying session repository to use to store the sessions state. + session_type: single agent or multiagent. **kwargs: Additional keyword arguments for future extensibility. """ + super().__init__(session_type=session_type) + self.session_repository = session_repository self.session_id = session_id session = session_repository.read_session(session_id) # Create a session if it does not exist yet if session is None: logger.debug("session_id=<%s> | session not found, creating new session", self.session_id) - session = Session(session_id=session_id, session_type=SessionType.AGENT) + session = Session(session_id=session_id, session_type=session_type) session_repository.create_session(session) self.session = session + self.session_type = session.session_type # Keep track of the latest message of each agent in case we need to redact it. - self._latest_agent_message: dict[str, Optional[SessionMessage]] = {} + if self.session_type == SessionType.AGENT: + self._latest_agent_message: dict[str, Optional[SessionMessage]] = {} def append_message(self, message: Message, agent: "Agent", **kwargs: Any) -> None: """Append a message to the agent's session. diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index c6ce28d80..cdd838c03 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -2,7 +2,7 @@ import json import logging -from typing import Any, Dict, List, Optional, cast +from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast import boto3 from botocore.config import Config as BotocoreConfig @@ -10,10 +10,13 @@ from .. import _identifier from ..types.exceptions import SessionException -from ..types.session import Session, SessionAgent, SessionMessage +from ..types.session import Session, SessionAgent, SessionMessage, SessionType from .repository_session_manager import RepositorySessionManager from .session_repository import SessionRepository +if TYPE_CHECKING: + from ..multiagent.base import MultiAgentBase + logger = logging.getLogger(__name__) SESSION_PREFIX = "session_" @@ -46,6 +49,7 @@ def __init__( boto_session: Optional[boto3.Session] = None, boto_client_config: Optional[BotocoreConfig] = None, region_name: Optional[str] = None, + session_type: SessionType = SessionType.AGENT, **kwargs: Any, ): """Initialize S3SessionManager with S3 storage. @@ -58,6 +62,7 @@ def __init__( boto_session: Optional boto3 session boto_client_config: Optional boto3 client configuration region_name: AWS region for S3 storage + session_type: single agent or multiagent. **kwargs: Additional keyword arguments for future extensibility. """ self.bucket = bucket @@ -78,7 +83,7 @@ def __init__( client_config = BotocoreConfig(user_agent_extra="strands-agents") self.client = session.client(service_name="s3", config=client_config) - super().__init__(session_id=session_id, session_repository=self) + super().__init__(session_id=session_id, session_type=session_type, session_repository=self) def _get_session_path(self, session_id: str) -> str: """Get session S3 prefix. @@ -294,3 +299,24 @@ def list_messages( except ClientError as e: raise SessionException(f"S3 error reading messages: {e}") from e + + def write_multi_agent_json(self, source: "MultiAgentBase") -> None: + """Write multi-agent state to S3. + + Args: + source: Multi-agent source object to persist + """ + session_prefix = self._get_session_path(self.session_id) + state_key = f"{session_prefix}multi_agent_state.json" + state = source.serialize_state() + self._write_s3_object(state_key, state) + + def read_multi_agent_json(self) -> dict[str, Any]: + """Read multi-agent state from S3. + + Returns: + Multi-agent state dictionary or empty dict if not found + """ + session_prefix = self._get_session_path(self.session_id) + state_key = f"{session_prefix}multi_agent_state.json" + return self._read_s3_object(state_key) or {} diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index 66a07ea43..cb4b4a5f9 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -1,14 +1,24 @@ """Session manager interface for agent session management.""" +import logging from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any +from ..experimental.hooks.multiagent.events import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + MultiAgentInitializedEvent, +) from ..hooks.events import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent from ..hooks.registry import HookProvider, HookRegistry from ..types.content import Message +from ..types.session import SessionType if TYPE_CHECKING: from ..agent.agent import Agent + from ..multiagent.base import MultiAgentBase + +logger = logging.getLogger(__name__) class SessionManager(HookProvider, ABC): @@ -20,19 +30,39 @@ class SessionManager(HookProvider, ABC): for an agent, and should be persisted in the session. """ + def __init__(self, session_type: SessionType = SessionType.AGENT) -> None: + """Initialize SessionManager with session type. + + Args: + session_type: Type of session (AGENT or MULTI_AGENT) + """ + self.session_type: SessionType = session_type + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: """Register hooks for persisting the agent to the session.""" - # After the normal Agent initialization behavior, call the session initialize function to restore the agent - registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent)) + if not hasattr(self, "session_type"): + self.session_type = SessionType.AGENT + logger.debug("Session type not set, defaulting to AGENT") - # For each message appended to the Agents messages, store that message in the session - registry.add_callback(MessageAddedEvent, lambda event: self.append_message(event.message, event.agent)) + if self.session_type == SessionType.MULTI_AGENT: + registry.add_callback(MultiAgentInitializedEvent, self._on_multiagent_initialized) + registry.add_callback(AfterNodeCallEvent, lambda event: self.write_multi_agent_json(event.source)) + registry.add_callback( + AfterMultiAgentInvocationEvent, lambda event: self.write_multi_agent_json(event.source) + ) - # Sync the agent into the session for each message in case the agent state was updated - registry.add_callback(MessageAddedEvent, lambda event: self.sync_agent(event.agent)) + else: + # After the normal Agent initialization behavior, call the session initialize function to restore the agent + registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent)) - # After an agent was invoked, sync it with the session to capture any conversation manager state updates - registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent)) + # For each message appended to the Agents messages, store that message in the session + registry.add_callback(MessageAddedEvent, lambda event: self.append_message(event.message, event.agent)) + + # Sync the agent into the session for each message in case the agent state was updated + registry.add_callback(MessageAddedEvent, lambda event: self.sync_agent(event.agent)) + + # After an agent was invoked, sync it with the session to capture any conversation manager state updates + registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent)) @abstractmethod def redact_latest_message(self, redact_message: Message, agent: "Agent", **kwargs: Any) -> None: @@ -71,3 +101,41 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: agent: Agent to initialize **kwargs: Additional keyword arguments for future extensibility. """ + + def write_multi_agent_json(self, source: "MultiAgentBase") -> None: + """Write multi-agent state to persistent storage. + + Args: + source: Multi-agent source object to persist + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not support multi-agent persistence " + "(write_multi_agent_json). Provide an implementation or use a " + "SessionManager with session_type=SessionType.MULTI_AGENT." + ) + + def read_multi_agent_json(self) -> dict[str, Any]: + """Read multi-agent state from persistent storage. + + Returns: + Multi-agent state dictionary or empty dict if not found + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not support multi-agent persistence " + "(read_multi_agent_json). Provide an implementation or use a " + "SessionManager with session_type=SessionType.MULTI_AGENT." + ) + + def _on_multiagent_initialized(self, event: MultiAgentInitializedEvent) -> None: + """Handle multi-agent initialization: restore from storage or create initial snapshot. + + If existing state is found, deserializes it into the source. Otherwise, + persists the current state as the initial snapshot. + """ + source: MultiAgentBase = event.source + payload = self.read_multi_agent_json() + # payload can be {} or Graph/Swarm state json + if payload: + source.restore_from_session(payload) + else: + self.write_multi_agent_json(source) diff --git a/src/strands/types/session.py b/src/strands/types/session.py index 926480f2c..e0e8f396c 100644 --- a/src/strands/types/session.py +++ b/src/strands/types/session.py @@ -22,6 +22,7 @@ class SessionType(str, Enum): """ AGENT = "AGENT" + MULTI_AGENT = "MULTI_AGENT" def encode_bytes_values(obj: Any) -> Any: diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index 4e8a5dd06..3438c5622 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -148,7 +148,7 @@ async def invoke_async(self, task: str) -> MultiAgentResult: def serialize_state(self) -> dict: return {} - def deserialize_state(self, payload: dict) -> None: + def restore_from_session(self, payload: dict) -> None: pass # Should not raise an exception - __call__ is provided by base class @@ -177,7 +177,7 @@ async def invoke_async(self, task, invocation_state, **kwargs): def serialize_state(self) -> dict: return {} - def deserialize_state(self, payload: dict) -> None: + def restore_from_session(self, payload: dict) -> None: pass agent = TestMultiAgent() diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index f124ddf58..61b20aba4 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -408,3 +408,52 @@ def test__get_message_path_invalid_message_id(message_id, file_manager): """Test that message_id that is not an integer raises ValueError.""" with pytest.raises(ValueError, match=r"message_id=<.*> \| message id must be an integer"): file_manager._get_message_path("session1", "agent1", message_id) + + +def test_write_read_multi_agent_json(file_manager, sample_session): + """Test writing and reading multi-agent state.""" + file_manager.create_session(sample_session) + + # Create mock MultiAgentBase object + class MockMultiAgent: + def serialize_state(self): + return {"type": "graph", "status": "completed", "nodes": ["node1", "node2"]} + + mock_agent = MockMultiAgent() + expected_state = {"type": "graph", "status": "completed", "nodes": ["node1", "node2"]} + + # Write multi-agent state + file_manager.write_multi_agent_json(mock_agent) + + # Read multi-agent state + result = file_manager.read_multi_agent_json() + assert result == expected_state + + +def test_read_multi_agent_json_nonexistent(file_manager): + """Test reading multi-agent state when file doesn't exist.""" + result = file_manager.read_multi_agent_json() + assert result == {} + + +def test_list_messages_missing_directory(file_manager, sample_session, sample_agent): + """Test listing messages when messages directory is missing.""" + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Remove messages directory + messages_dir = os.path.join( + file_manager._get_agent_path(sample_session.session_id, sample_agent.agent_id), "messages" + ) + os.rmdir(messages_dir) + + with pytest.raises(SessionException, match="Messages directory missing"): + file_manager.list_messages(sample_session.session_id, sample_agent.agent_id) + + +def test_create_existing_session(file_manager, sample_session): + """Test creating a session that already exists.""" + file_manager.create_session(sample_session) + + with pytest.raises(SessionException, match="already exists"): + file_manager.create_session(sample_session) diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py index c4d6a0154..d5f2b7d97 100644 --- a/tests/strands/session/test_s3_session_manager.py +++ b/tests/strands/session/test_s3_session_manager.py @@ -374,3 +374,27 @@ def test__get_message_path_invalid_message_id(message_id, s3_manager): """Test that message_id that is not an integer raises ValueError.""" with pytest.raises(ValueError, match=r"message_id=<.*> \| message id must be an integer"): s3_manager._get_message_path("session1", "agent1", message_id) + + +def test_write_read_multi_agent_json(s3_manager, sample_session): + """Test multi-agent state persistence.""" + s3_manager.create_session(sample_session) + + # Create mock MultiAgentBase object + class MockMultiAgent: + def serialize_state(self): + return {"type": "graph", "status": "completed"} + + mock_agent = MockMultiAgent() + expected_state = {"type": "graph", "status": "completed"} + + s3_manager.write_multi_agent_json(mock_agent) + + result = s3_manager.read_multi_agent_json() + assert result == expected_state + + +def test_read_multi_agent_json_nonexistent(s3_manager): + """Test reading multi-agent state when file doesn't exist.""" + result = s3_manager.read_multi_agent_json() + assert result == {} From 7a94c151a1ad952c309accc5a5f998c4fed5ef60 Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Wed, 22 Oct 2025 16:19:45 -0400 Subject: [PATCH 03/12] Delete __init__.py --- __init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 __init__.py diff --git a/__init__.py b/__init__.py deleted file mode 100644 index e69de29bb..000000000 From 836d2a2d1244c042af404e9d5a053541bcbc781c Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Thu, 23 Oct 2025 10:36:35 -0400 Subject: [PATCH 04/12] fix: address comments --- src/strands/multiagent/base.py | 2 +- src/strands/session/file_session_manager.py | 1 + src/strands/session/repository_session_manager.py | 1 + src/strands/session/session_manager.py | 2 +- tests/strands/multiagent/test_base.py | 4 ++-- 5 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index f0e7f78f0..51c368d78 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -204,7 +204,7 @@ def serialize_state(self) -> dict[str, Any]: """Return a JSON-serializable snapshot of the orchestrator state.""" raise NotImplementedError - def restore_from_session(self, payload: dict[str, Any]) -> None: + def deserialize_state(self, payload: dict[str, Any]) -> None: """Restore orchestrator state from a session dict.""" raise NotImplementedError diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index 67c07df72..a5459fb58 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -45,6 +45,7 @@ def __init__( self, session_id: str, storage_dir: Optional[str] = None, + *, session_type: SessionType = SessionType.AGENT, **kwargs: Any, ): diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index bcdcbe30f..f4f1723b6 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -28,6 +28,7 @@ def __init__( self, session_id: str, session_repository: SessionRepository, + *, session_type: SessionType = SessionType.AGENT, **kwargs: Any, ): diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index cb4b4a5f9..b0955622d 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -136,6 +136,6 @@ def _on_multiagent_initialized(self, event: MultiAgentInitializedEvent) -> None: payload = self.read_multi_agent_json() # payload can be {} or Graph/Swarm state json if payload: - source.restore_from_session(payload) + source.deserialize_state(payload) else: self.write_multi_agent_json(source) diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index 3438c5622..4e8a5dd06 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -148,7 +148,7 @@ async def invoke_async(self, task: str) -> MultiAgentResult: def serialize_state(self) -> dict: return {} - def restore_from_session(self, payload: dict) -> None: + def deserialize_state(self, payload: dict) -> None: pass # Should not raise an exception - __call__ is provided by base class @@ -177,7 +177,7 @@ async def invoke_async(self, task, invocation_state, **kwargs): def serialize_state(self) -> dict: return {} - def restore_from_session(self, payload: dict) -> None: + def deserialize_state(self, payload: dict) -> None: pass agent = TestMultiAgent() From aa3c905c1ccced5f2e609ab773be67c3c29481b5 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Thu, 23 Oct 2025 17:45:47 -0400 Subject: [PATCH 05/12] fix: renaming function to keep consistent with existing code --- src/strands/session/file_session_manager.py | 6 ++++-- src/strands/session/s3_session_manager.py | 5 +++-- src/strands/session/session_manager.py | 19 +++++++++---------- .../session/test_file_session_manager.py | 6 +++--- .../session/test_s3_session_manager.py | 6 +++--- 5 files changed, 22 insertions(+), 20 deletions(-) diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index a5459fb58..d6f296c19 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -119,6 +119,7 @@ def _read_file(self, path: str) -> dict[str, Any]: def _write_file(self, path: str, data: dict[str, Any]) -> None: """Write JSON file.""" os.makedirs(os.path.dirname(path), exist_ok=True) + # This automic write ensure the completeness of session files in both single agent/ multi agents tmp = f"{path}.tmp" with open(tmp, "w", encoding="utf-8", newline="\n") as f: json.dump(data, f, indent=2, ensure_ascii=False) @@ -255,11 +256,12 @@ def list_messages( return messages - def write_multi_agent_json(self, source: "MultiAgentBase") -> None: + def sync_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> None: """Write multi-agent state to filesystem. Args: source: Multi-agent source object to persist + **kwargs: Additional keyword arguments for future extensibility. """ state = source.serialize_state() state_path = os.path.join(self._get_session_path(self.session_id), "multi_agent_state.json") @@ -273,7 +275,7 @@ def write_multi_agent_json(self, source: "MultiAgentBase") -> None: metadata["updated_at"] = datetime.now(timezone.utc).isoformat() self._write_file(session_file, metadata) - def read_multi_agent_json(self) -> dict[str, Any]: + def initialize_multi_agent(self) -> dict[str, Any]: """Read multi-agent state from filesystem. Returns: diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index cdd838c03..5599c848d 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -300,18 +300,19 @@ def list_messages( except ClientError as e: raise SessionException(f"S3 error reading messages: {e}") from e - def write_multi_agent_json(self, source: "MultiAgentBase") -> None: + def sync_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> None: """Write multi-agent state to S3. Args: source: Multi-agent source object to persist + **kwargs: Additional keyword arguments for future extensibility. """ session_prefix = self._get_session_path(self.session_id) state_key = f"{session_prefix}multi_agent_state.json" state = source.serialize_state() self._write_s3_object(state_key, state) - def read_multi_agent_json(self) -> dict[str, Any]: + def initialize_multi_agent(self) -> dict[str, Any]: """Read multi-agent state from S3. Returns: diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index b0955622d..4f43e3784 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -46,10 +46,8 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: if self.session_type == SessionType.MULTI_AGENT: registry.add_callback(MultiAgentInitializedEvent, self._on_multiagent_initialized) - registry.add_callback(AfterNodeCallEvent, lambda event: self.write_multi_agent_json(event.source)) - registry.add_callback( - AfterMultiAgentInvocationEvent, lambda event: self.write_multi_agent_json(event.source) - ) + registry.add_callback(AfterNodeCallEvent, lambda event: self.sync_multi_agent(event.source)) + registry.add_callback(AfterMultiAgentInvocationEvent, lambda event: self.sync_multi_agent(event.source)) else: # After the normal Agent initialization behavior, call the session initialize function to restore the agent @@ -102,19 +100,20 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: **kwargs: Additional keyword arguments for future extensibility. """ - def write_multi_agent_json(self, source: "MultiAgentBase") -> None: + def sync_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> None: """Write multi-agent state to persistent storage. Args: source: Multi-agent source object to persist + **kwargs: Additional keyword arguments for future extensibility. """ raise NotImplementedError( f"{self.__class__.__name__} does not support multi-agent persistence " - "(write_multi_agent_json). Provide an implementation or use a " + "(sync_multi_agent). Provide an implementation or use a " "SessionManager with session_type=SessionType.MULTI_AGENT." ) - def read_multi_agent_json(self) -> dict[str, Any]: + def initialize_multi_agent(self) -> dict[str, Any]: """Read multi-agent state from persistent storage. Returns: @@ -122,7 +121,7 @@ def read_multi_agent_json(self) -> dict[str, Any]: """ raise NotImplementedError( f"{self.__class__.__name__} does not support multi-agent persistence " - "(read_multi_agent_json). Provide an implementation or use a " + "(initialize_multi_agent). Provide an implementation or use a " "SessionManager with session_type=SessionType.MULTI_AGENT." ) @@ -133,9 +132,9 @@ def _on_multiagent_initialized(self, event: MultiAgentInitializedEvent) -> None: persists the current state as the initial snapshot. """ source: MultiAgentBase = event.source - payload = self.read_multi_agent_json() + payload = self.initialize_multi_agent() # payload can be {} or Graph/Swarm state json if payload: source.deserialize_state(payload) else: - self.write_multi_agent_json(source) + self.sync_multi_agent(source) diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index 61b20aba4..bdd3e966d 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -423,16 +423,16 @@ def serialize_state(self): expected_state = {"type": "graph", "status": "completed", "nodes": ["node1", "node2"]} # Write multi-agent state - file_manager.write_multi_agent_json(mock_agent) + file_manager.sync_multi_agent(mock_agent) # Read multi-agent state - result = file_manager.read_multi_agent_json() + result = file_manager.initialize_multi_agent() assert result == expected_state def test_read_multi_agent_json_nonexistent(file_manager): """Test reading multi-agent state when file doesn't exist.""" - result = file_manager.read_multi_agent_json() + result = file_manager.initialize_multi_agent() assert result == {} diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py index d5f2b7d97..85ddf40d9 100644 --- a/tests/strands/session/test_s3_session_manager.py +++ b/tests/strands/session/test_s3_session_manager.py @@ -388,13 +388,13 @@ def serialize_state(self): mock_agent = MockMultiAgent() expected_state = {"type": "graph", "status": "completed"} - s3_manager.write_multi_agent_json(mock_agent) + s3_manager.sync_multi_agent(mock_agent) - result = s3_manager.read_multi_agent_json() + result = s3_manager.initialize_multi_agent() assert result == expected_state def test_read_multi_agent_json_nonexistent(s3_manager): """Test reading multi-agent state when file doesn't exist.""" - result = s3_manager.read_multi_agent_json() + result = s3_manager.initialize_multi_agent() assert result == {} From 01e6dbfebcc6f7efe4f3781a33ad487ef8f60058 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Sun, 26 Oct 2025 02:38:48 -0400 Subject: [PATCH 06/12] feat: add multiagent session/repository management pattern --- src/strands/multiagent/base.py | 2 + src/strands/session/file_session_manager.py | 80 +++++++----- .../session/repository_session_manager.py | 24 ++++ src/strands/session/s3_session_manager.py | 51 ++++---- src/strands/session/session_manager.py | 25 ++-- src/strands/session/session_repository.py | 17 ++- tests/fixtures/mock_session_repository.py | 32 +++++ .../session/test_file_session_manager.py | 118 +++++++++++++----- .../test_repository_session_manager.py | 69 ++++++++++ .../session/test_s3_session_manager.py | 102 ++++++++++++--- 10 files changed, 400 insertions(+), 120 deletions(-) diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 51c368d78..6c22d7818 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -166,6 +166,8 @@ class MultiAgentBase(ABC): multi-agent orchestration capabilities. """ + id: str + @abstractmethod async def invoke_async( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index d6f296c19..206c4bf33 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -22,6 +22,7 @@ SESSION_PREFIX = "session_" AGENT_PREFIX = "agent_" MESSAGE_PREFIX = "message_" +MULTI_AGENT_PREFIX = "multi_agent_" class FileSessionManager(RepositorySessionManager, SessionRepository): @@ -135,7 +136,8 @@ def create_session(self, session: Session, **kwargs: Any) -> Session: os.makedirs(session_dir, exist_ok=True) if self.session_type == SessionType.AGENT: os.makedirs(os.path.join(session_dir, "agents"), exist_ok=True) - + else: + os.makedirs(os.path.join(session_dir, "multi_agents"), exist_ok=True) # Write session file session_file = os.path.join(session_dir, "session.json") session_dict = session.to_dict() @@ -152,6 +154,15 @@ def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: session_data = self._read_file(session_file) return Session.from_dict(session_data) + def update_session(self, session_id: str, **kwargs: Any) -> None: + """Update session updated_at field.""" + session_file = os.path.join(self._get_session_path(session_id), "session.json") + session_data = self.read_session(session_id) + if session_data is None: + raise SessionException(f"Session {session_id} does not exist") + session_data.updated_at = datetime.now(timezone.utc).isoformat() + self._write_file(session_file, session_data.to_dict()) + def delete_session(self, session_id: str, **kwargs: Any) -> None: """Delete session and all associated data.""" session_dir = self._get_session_path(session_id) @@ -256,33 +267,40 @@ def list_messages( return messages - def sync_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> None: - """Write multi-agent state to filesystem. - - Args: - source: Multi-agent source object to persist - **kwargs: Additional keyword arguments for future extensibility. - """ - state = source.serialize_state() - state_path = os.path.join(self._get_session_path(self.session_id), "multi_agent_state.json") - self._write_file(state_path, state) - - # Update session metadata - session_dir = self._get_session_path(self.session.session_id) - session_file = os.path.join(session_dir, "session.json") - with open(session_file, "r", encoding="utf-8") as f: - metadata = json.load(f) - metadata["updated_at"] = datetime.now(timezone.utc).isoformat() - self._write_file(session_file, metadata) - - def initialize_multi_agent(self) -> dict[str, Any]: - """Read multi-agent state from filesystem. - - Returns: - Multi-agent state dictionary or empty dict if not found - """ - state_path = os.path.join(self._get_session_path(self.session_id), "multi_agent_state.json") - if not os.path.exists(state_path): - return {} - state_data = self._read_file(state_path) - return state_data + def _get_multi_agent_path(self, session_id: str, multi_agent_id: str) -> str: + """Get multi-agent state file path.""" + session_path = self._get_session_path(session_id) + multi_agent_id = _identifier.validate(multi_agent_id, _identifier.Identifier.AGENT) + return os.path.join(session_path, "multi_agents", f"{MULTI_AGENT_PREFIX}{multi_agent_id}") + + def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None: + """Create a new multiagent state in the session.""" + multi_agent_id = multi_agent.id + multi_agent_dir = self._get_multi_agent_path(session_id, multi_agent_id) + os.makedirs(multi_agent_dir, exist_ok=True) + + multi_agent_file = os.path.join(multi_agent_dir, "multi_agent.json") + session_data = multi_agent.serialize_state() + self._write_file(multi_agent_file, session_data) + + def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]: + """Read multi-agent state from filesystem.""" + multi_agent_file = os.path.join(self._get_multi_agent_path(session_id, multi_agent_id), "multi_agent.json") + if not os.path.exists(multi_agent_file): + return None + return self._read_file(multi_agent_file) + + def update_multi_agent(self, session_id: str, multi_agent_state: dict[str, Any], **kwargs: Any) -> None: + """Update multi-agent state from filesystem.""" + multi_agent_id = multi_agent_state.get("id") + if multi_agent_id is None: + raise SessionException("MultiAgent state must have an 'id' field") + previous_multi_agent_state = self.read_multi_agent(session_id=session_id, multi_agent_id=multi_agent_id) + if previous_multi_agent_state is None: + raise SessionException(f"MultiAgent state {multi_agent_id} in session {session_id} does not exist") + + multi_agent_file = os.path.join(self._get_multi_agent_path(session_id, multi_agent_id), "multi_agent.json") + self._write_file(multi_agent_file, multi_agent_state) + + # Update session.update_at + self.update_session(session_id) diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index f4f1723b6..100184ab4 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from ..agent.agent import Agent + from ..multiagent.base import MultiAgentBase logger = logging.getLogger(__name__) @@ -164,3 +165,26 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: # Restore the agents messages array including the optional prepend messages agent.messages = prepend_messages + [session_message.to_message() for session_message in session_messages] + + def sync_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> None: + """Serialize and update the multi-agent state into the session repository. + + Args: + source: Multi-agent source object to sync to the session. + **kwargs: Additional keyword arguments for future extensibility. + """ + self.session_repository.update_multi_agent(self.session_id, source.serialize_state()) + + def initialize_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> None: + """Initialize multi-agent state from the session repository. + + Args: + source: Multi-agent source object to restore state into + **kwargs: Additional keyword arguments for future extensibility. + """ + state = self.session_repository.read_multi_agent(self.session_id, source.id, **kwargs) + if state is None: + self.session_repository.create_multi_agent(self.session_id, source) + else: + logger.debug("session_id=<%s> | restoring multi-agent state", self.session_id) + source.deserialize_state(state) diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index 5599c848d..170f753a6 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -22,6 +22,7 @@ SESSION_PREFIX = "session_" AGENT_PREFIX = "agent_" MESSAGE_PREFIX = "message_" +MULTI_AGENT_PREFIX = "multi_agent_" class S3SessionManager(RepositorySessionManager, SessionRepository): @@ -300,24 +301,32 @@ def list_messages( except ClientError as e: raise SessionException(f"S3 error reading messages: {e}") from e - def sync_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> None: - """Write multi-agent state to S3. - - Args: - source: Multi-agent source object to persist - **kwargs: Additional keyword arguments for future extensibility. - """ - session_prefix = self._get_session_path(self.session_id) - state_key = f"{session_prefix}multi_agent_state.json" - state = source.serialize_state() - self._write_s3_object(state_key, state) - - def initialize_multi_agent(self) -> dict[str, Any]: - """Read multi-agent state from S3. - - Returns: - Multi-agent state dictionary or empty dict if not found - """ - session_prefix = self._get_session_path(self.session_id) - state_key = f"{session_prefix}multi_agent_state.json" - return self._read_s3_object(state_key) or {} + def _get_multi_agent_path(self, session_id: str, multi_agent_id: str) -> str: + """Get multi-agent S3 prefix.""" + session_path = self._get_session_path(session_id) + multi_agent_id = _identifier.validate(multi_agent_id, _identifier.Identifier.AGENT) + return f"{session_path}multi_agents/{MULTI_AGENT_PREFIX}{multi_agent_id}/" + + def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None: + """Create a new multiagent state in S3.""" + multi_agent_id = multi_agent.id + multi_agent_key = f"{self._get_multi_agent_path(session_id, multi_agent_id)}multi_agent.json" + session_data = multi_agent.serialize_state() + self._write_s3_object(multi_agent_key, session_data) + + def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]: + """Read multi-agent state from S3.""" + multi_agent_key = f"{self._get_multi_agent_path(session_id, multi_agent_id)}multi_agent.json" + return self._read_s3_object(multi_agent_key) + + def update_multi_agent(self, session_id: str, multi_agent_state: dict[str, Any], **kwargs: Any) -> None: + """Update multi-agent state in S3.""" + multi_agent_id = multi_agent_state.get("id") + if multi_agent_id is None: + raise SessionException("MultiAgent state must have an 'id' field") + previous_multi_agent_state = self.read_multi_agent(session_id=session_id, multi_agent_id=multi_agent_id) + if previous_multi_agent_state is None: + raise SessionException(f"MultiAgent state {multi_agent_id} in session {session_id} does not exist") + + multi_agent_key = f"{self._get_multi_agent_path(session_id, multi_agent_id)}multi_agent.json" + self._write_s3_object(multi_agent_key, multi_agent_state) diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index 4f43e3784..182dcc67a 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -45,7 +45,7 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: logger.debug("Session type not set, defaulting to AGENT") if self.session_type == SessionType.MULTI_AGENT: - registry.add_callback(MultiAgentInitializedEvent, self._on_multiagent_initialized) + registry.add_callback(MultiAgentInitializedEvent, lambda event: self.initialize_multi_agent(event.source)) registry.add_callback(AfterNodeCallEvent, lambda event: self.sync_multi_agent(event.source)) registry.add_callback(AfterMultiAgentInvocationEvent, lambda event: self.sync_multi_agent(event.source)) @@ -101,7 +101,7 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: """ def sync_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> None: - """Write multi-agent state to persistent storage. + """Serialize and sync multi-agent with the session storage. Args: source: Multi-agent source object to persist @@ -113,28 +113,19 @@ def sync_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> None: "SessionManager with session_type=SessionType.MULTI_AGENT." ) - def initialize_multi_agent(self) -> dict[str, Any]: + def initialize_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> None: """Read multi-agent state from persistent storage. + Args: + **kwargs: Additional keyword arguments for future extensibility. + source: Multi-agent state to initialize. + Returns: Multi-agent state dictionary or empty dict if not found + """ raise NotImplementedError( f"{self.__class__.__name__} does not support multi-agent persistence " "(initialize_multi_agent). Provide an implementation or use a " "SessionManager with session_type=SessionType.MULTI_AGENT." ) - - def _on_multiagent_initialized(self, event: MultiAgentInitializedEvent) -> None: - """Handle multi-agent initialization: restore from storage or create initial snapshot. - - If existing state is found, deserializes it into the source. Otherwise, - persists the current state as the initial snapshot. - """ - source: MultiAgentBase = event.source - payload = self.initialize_multi_agent() - # payload can be {} or Graph/Swarm state json - if payload: - source.deserialize_state(payload) - else: - self.sync_multi_agent(source) diff --git a/src/strands/session/session_repository.py b/src/strands/session/session_repository.py index 6b0fded7a..7a7e02e39 100644 --- a/src/strands/session/session_repository.py +++ b/src/strands/session/session_repository.py @@ -1,10 +1,13 @@ """Session repository interface for agent session management.""" from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from ..types.session import Session, SessionAgent, SessionMessage +if TYPE_CHECKING: + from ..multiagent import MultiAgentBase + class SessionRepository(ABC): """Abstract repository for creating, reading, and updating Sessions, AgentSessions, and AgentMessages.""" @@ -49,3 +52,15 @@ def list_messages( self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any ) -> list[SessionMessage]: """List Messages from an Agent with pagination.""" + + @abstractmethod + def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None: + """Create a new MultiAgent state for the Session.""" + + @abstractmethod + def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]: + """Read the MultiAgent state for the Session.""" + + @abstractmethod + def update_multi_agent(self, session_id: str, multi_agent_state: dict[str, Any], **kwargs: Any) -> None: + """Update the MultiAgent state for the Session.""" diff --git a/tests/fixtures/mock_session_repository.py b/tests/fixtures/mock_session_repository.py index f3923f68b..96410a159 100644 --- a/tests/fixtures/mock_session_repository.py +++ b/tests/fixtures/mock_session_repository.py @@ -11,6 +11,7 @@ def __init__(self): self.sessions = {} self.agents = {} self.messages = {} + self.multi_agents = {} def create_session(self, session) -> None: """Create a session.""" @@ -20,11 +21,19 @@ def create_session(self, session) -> None: self.sessions[session_id] = session self.agents[session_id] = {} self.messages[session_id] = {} + self.multi_agents[session_id] = {} def read_session(self, session_id) -> SessionAgent: """Read a session.""" return self.sessions.get(session_id) + def update_session(self, session_id, **kwargs) -> None: + """Update a session.""" + if session_id not in self.sessions: + raise SessionException(f"Session {session_id} does not exist") + # Mock implementation - just mark as updated + pass + def create_agent(self, session_id, session_agent) -> None: """Create an agent.""" agent_id = session_agent.agent_id @@ -95,3 +104,26 @@ def list_messages(self, session_id, agent_id, limit=None, offset=0) -> list[Sess if limit is not None: return sorted_messages[offset : offset + limit] return sorted_messages[offset:] + + def create_multi_agent(self, session_id, multi_agent, **kwargs) -> None: + """Create multi-agent state.""" + multi_agent_id = multi_agent.id + if session_id not in self.sessions: + raise SessionException(f"Session {session_id} does not exist") + state = multi_agent.serialize_state() + self.multi_agents.setdefault(session_id, {})[multi_agent_id] = state + + def read_multi_agent(self, session_id, multi_agent_id, **kwargs): + """Read multi-agent state.""" + if session_id not in self.sessions: + return None + return self.multi_agents.get(session_id, {}).get(multi_agent_id) + + def update_multi_agent(self, session_id, multi_agent_state, **kwargs) -> None: + """Update multi-agent state.""" + multi_agent_id = multi_agent_state.get("id") + if session_id not in self.sessions: + raise SessionException(f"Session {session_id} does not exist") + if multi_agent_id not in self.multi_agents.get(session_id, {}): + raise SessionException(f"MultiAgent {multi_agent_id} does not exist in session {session_id}") + self.multi_agents[session_id][multi_agent_id] = multi_agent_state diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index bdd3e966d..f49cf41ad 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -53,6 +53,30 @@ def sample_message(): ) +@pytest.fixture +def mock_multi_agent(): + """Create mock multi-agent for testing.""" + from unittest.mock import Mock + + mock = Mock() + mock.id = "test-multi-agent" + mock.state = {"key": "value"} + mock.serialize_state.return_value = {"id": "test-multi-agent", "state": {"key": "value"}} + return mock + + +@pytest.fixture +def multi_agent_session(): + """Create sample multi-agent session for testing.""" + return Session(session_id="test-session", session_type=SessionType.MULTI_AGENT) + + +@pytest.fixture +def multi_agent_manager(temp_dir): + """Create FileSessionManager with multi-agent session type.""" + return FileSessionManager(session_id="test", storage_dir=temp_dir, session_type=SessionType.MULTI_AGENT) + + def test_create_session(file_manager, sample_session): """Test creating a session.""" file_manager.create_session(sample_session) @@ -410,50 +434,76 @@ def test__get_message_path_invalid_message_id(message_id, file_manager): file_manager._get_message_path("session1", "agent1", message_id) -def test_write_read_multi_agent_json(file_manager, sample_session): - """Test writing and reading multi-agent state.""" - file_manager.create_session(sample_session) +def test_create_multi_agent(multi_agent_manager, multi_agent_session, mock_multi_agent): + """Test creating multi-agent state.""" + multi_agent_manager.create_session(multi_agent_session) + multi_agent_manager.create_multi_agent(multi_agent_session.session_id, mock_multi_agent) + + # Verify file created + multi_agent_file = os.path.join( + multi_agent_manager._get_multi_agent_path(multi_agent_session.session_id, mock_multi_agent.id), + "multi_agent.json", + ) + assert os.path.exists(multi_agent_file) + + # Verify content + with open(multi_agent_file, "r") as f: + data = json.load(f) + assert data["id"] == mock_multi_agent.id + assert data["state"] == mock_multi_agent.state - # Create mock MultiAgentBase object - class MockMultiAgent: - def serialize_state(self): - return {"type": "graph", "status": "completed", "nodes": ["node1", "node2"]} - mock_agent = MockMultiAgent() - expected_state = {"type": "graph", "status": "completed", "nodes": ["node1", "node2"]} +def test_read_multi_agent(multi_agent_manager, multi_agent_session, mock_multi_agent): + """Test reading multi-agent state.""" + # Create session and multi-agent + multi_agent_manager.create_session(multi_agent_session) + multi_agent_manager.create_multi_agent(multi_agent_session.session_id, mock_multi_agent) - # Write multi-agent state - file_manager.sync_multi_agent(mock_agent) + # Read multi-agent + result = multi_agent_manager.read_multi_agent(multi_agent_session.session_id, mock_multi_agent.id) - # Read multi-agent state - result = file_manager.initialize_multi_agent() - assert result == expected_state + assert result["id"] == mock_multi_agent.id + assert result["state"] == mock_multi_agent.state -def test_read_multi_agent_json_nonexistent(file_manager): - """Test reading multi-agent state when file doesn't exist.""" - result = file_manager.initialize_multi_agent() - assert result == {} +def test_read_nonexistent_multi_agent(multi_agent_manager, multi_agent_session): + """Test reading multi-agent state that doesn't exist.""" + result = multi_agent_manager.read_multi_agent(multi_agent_session.session_id, "nonexistent") + assert result is None -def test_list_messages_missing_directory(file_manager, sample_session, sample_agent): - """Test listing messages when messages directory is missing.""" - file_manager.create_session(sample_session) - file_manager.create_agent(sample_session.session_id, sample_agent) +def test_update_multi_agent(multi_agent_manager, multi_agent_session, mock_multi_agent): + """Test updating multi-agent state.""" + # Create session and multi-agent + multi_agent_manager.create_session(multi_agent_session) + multi_agent_manager.create_multi_agent(multi_agent_session.session_id, mock_multi_agent) - # Remove messages directory - messages_dir = os.path.join( - file_manager._get_agent_path(sample_session.session_id, sample_agent.agent_id), "messages" - ) - os.rmdir(messages_dir) + # Update multi-agent + updated_state = {"id": mock_multi_agent.id, "state": {"updated": "value"}} + multi_agent_manager.update_multi_agent(multi_agent_session.session_id, updated_state) + + # Verify update + result = multi_agent_manager.read_multi_agent(multi_agent_session.session_id, mock_multi_agent.id) + assert result["state"] == {"updated": "value"} - with pytest.raises(SessionException, match="Messages directory missing"): - file_manager.list_messages(sample_session.session_id, sample_agent.agent_id) +def test_update_nonexistent_multi_agent(multi_agent_manager, multi_agent_session): + """Test updating multi-agent state that doesn't exist.""" + # Create session + multi_agent_manager.create_session(multi_agent_session) -def test_create_existing_session(file_manager, sample_session): - """Test creating a session that already exists.""" - file_manager.create_session(sample_session) + # Update nonexistent multi-agent + with pytest.raises(SessionException): + multi_agent_manager.update_multi_agent(multi_agent_session.session_id, {"id": "nonexistent"}) + + +def test_create_session_multi_agent_directory_structure(multi_agent_manager, multi_agent_session): + """Test multi-agent session creates correct directory structure.""" + multi_agent_manager.create_session(multi_agent_session) + + # Verify directory structure + session_dir = multi_agent_manager._get_session_path(multi_agent_session.session_id) + multi_agents_dir = os.path.join(session_dir, "multi_agents") - with pytest.raises(SessionException, match="already exists"): - file_manager.create_session(sample_session) + assert os.path.exists(session_dir) + assert os.path.exists(multi_agents_dir) diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py index 923b13daa..de095f584 100644 --- a/tests/strands/session/test_repository_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -31,6 +31,26 @@ def agent(): return Agent(messages=[{"role": "user", "content": [{"text": "Hello!"}]}]) +@pytest.fixture +def mock_multi_agent(): + """Create mock multi-agent for testing.""" + from unittest.mock import Mock + + mock = Mock() + mock.id = "test-multi-agent" + mock.serialize_state.return_value = {"id": "test-multi-agent", "state": {"key": "value"}} + mock.deserialize_state = Mock() + return mock + + +@pytest.fixture +def multi_agent_session_manager(mock_repository): + """Create a multi-agent session manager.""" + return RepositorySessionManager( + session_id="test-multi-session", session_repository=mock_repository, session_type=SessionType.MULTI_AGENT + ) + + def test_init_creates_session_if_not_exists(mock_repository): """Test that init creates a session if it doesn't exist.""" # Session doesn't exist yet @@ -177,3 +197,52 @@ def test_append_message(session_manager): assert len(messages) == 1 assert messages[0].message["role"] == "user" assert messages[0].message["content"][0]["text"] == "Hello" + + +def test_init_multi_agent_session_type(mock_repository): + """Test creating session manager with multi-agent type.""" + manager = RepositorySessionManager( + session_id="multi-session", session_repository=mock_repository, session_type=SessionType.MULTI_AGENT + ) + + assert manager.session_type == SessionType.MULTI_AGENT + session = mock_repository.read_session("multi-session") + assert session.session_type == SessionType.MULTI_AGENT + + +def test_sync_multi_agent(multi_agent_session_manager, mock_multi_agent): + """Test syncing multi-agent state.""" + # Create multi-agent first + multi_agent_session_manager.session_repository.create_multi_agent("test-multi-session", mock_multi_agent) + + # Sync multi-agent + multi_agent_session_manager.sync_multi_agent(mock_multi_agent) + + # Verify repository update_multi_agent was called + state = multi_agent_session_manager.session_repository.read_multi_agent("test-multi-session", mock_multi_agent.id) + assert state["id"] == "test-multi-agent" + assert state["state"] == {"key": "value"} + + +def test_initialize_multi_agent_new(multi_agent_session_manager, mock_multi_agent): + """Test initializing new multi-agent state.""" + multi_agent_session_manager.initialize_multi_agent(mock_multi_agent) + + # Verify multi-agent was created + state = multi_agent_session_manager.session_repository.read_multi_agent("test-multi-session", mock_multi_agent.id) + assert state["id"] == "test-multi-agent" + assert state["state"] == {"key": "value"} + + +def test_initialize_multi_agent_existing(multi_agent_session_manager, mock_multi_agent): + """Test initializing existing multi-agent state.""" + # Create existing state first + multi_agent_session_manager.session_repository.create_multi_agent("test-multi-session", mock_multi_agent) + existing_state = {"id": "test-multi-agent", "state": {"restored": "data"}} + multi_agent_session_manager.session_repository.update_multi_agent("test-multi-session", existing_state) + + # Initialize multi-agent + multi_agent_session_manager.initialize_multi_agent(mock_multi_agent) + + # Verify deserialize_state was called with existing state + mock_multi_agent.deserialize_state.assert_called_once_with(existing_state) diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py index 85ddf40d9..0ac6f9640 100644 --- a/tests/strands/session/test_s3_session_manager.py +++ b/tests/strands/session/test_s3_session_manager.py @@ -376,25 +376,95 @@ def test__get_message_path_invalid_message_id(message_id, s3_manager): s3_manager._get_message_path("session1", "agent1", message_id) -def test_write_read_multi_agent_json(s3_manager, sample_session): - """Test multi-agent state persistence.""" - s3_manager.create_session(sample_session) +@pytest.fixture +def mock_multi_agent(): + """Create mock multi-agent for testing.""" + from unittest.mock import Mock + + mock = Mock() + mock.id = "test-multi-agent" + mock.state = {"key": "value"} + mock.serialize_state.return_value = {"id": "test-multi-agent", "state": {"key": "value"}} + return mock + + +@pytest.fixture +def multi_agent_session(): + """Create sample multi-agent session for testing.""" + return Session( + session_id="test-multi-session", + session_type=SessionType.MULTI_AGENT, + ) + + +@pytest.fixture +def multi_agent_manager(mocked_aws, s3_bucket): + """Create S3SessionManager with multi-agent session type.""" + yield S3SessionManager( + session_id="test-multi", + bucket=s3_bucket, + prefix="sessions/", + region_name="us-west-2", + session_type=SessionType.MULTI_AGENT, + ) + + +def test_create_multi_agent(multi_agent_manager, multi_agent_session, mock_multi_agent): + """Test creating multi-agent state in S3.""" + multi_agent_manager.create_session(multi_agent_session) + multi_agent_manager.create_multi_agent(multi_agent_session.session_id, mock_multi_agent) - # Create mock MultiAgentBase object - class MockMultiAgent: - def serialize_state(self): - return {"type": "graph", "status": "completed"} + # Verify S3 object created + key = f"{ + multi_agent_manager._get_multi_agent_path(multi_agent_session.session_id, mock_multi_agent.id) + }multi_agent.json" + response = multi_agent_manager.client.get_object(Bucket=multi_agent_manager.bucket, Key=key) + data = json.loads(response["Body"].read().decode("utf-8")) + + assert data["id"] == mock_multi_agent.id + assert data["state"] == mock_multi_agent.state + + +def test_read_multi_agent(multi_agent_manager, multi_agent_session, mock_multi_agent): + """Test reading multi-agent state from S3.""" + # Create session and multi-agent + multi_agent_manager.create_session(multi_agent_session) + multi_agent_manager.create_multi_agent(multi_agent_session.session_id, mock_multi_agent) + + # Read multi-agent + result = multi_agent_manager.read_multi_agent(multi_agent_session.session_id, mock_multi_agent.id) + + assert result["id"] == mock_multi_agent.id + assert result["state"] == mock_multi_agent.state - mock_agent = MockMultiAgent() - expected_state = {"type": "graph", "status": "completed"} - s3_manager.sync_multi_agent(mock_agent) +def test_read_nonexistent_multi_agent(multi_agent_manager, multi_agent_session): + """Test reading multi-agent state that doesn't exist.""" + multi_agent_manager.create_session(multi_agent_session) + result = multi_agent_manager.read_multi_agent(multi_agent_session.session_id, "nonexistent") + assert result is None + - result = s3_manager.initialize_multi_agent() - assert result == expected_state +def test_update_multi_agent(multi_agent_manager, multi_agent_session, mock_multi_agent): + """Test updating multi-agent state in S3.""" + # Create session and multi-agent + multi_agent_manager.create_session(multi_agent_session) + multi_agent_manager.create_multi_agent(multi_agent_session.session_id, mock_multi_agent) + # Update multi-agent + updated_state = {"id": mock_multi_agent.id, "state": {"updated": "value"}} + multi_agent_manager.update_multi_agent(multi_agent_session.session_id, updated_state) + + # Verify update + result = multi_agent_manager.read_multi_agent(multi_agent_session.session_id, mock_multi_agent.id) + assert result["state"] == {"updated": "value"} -def test_read_multi_agent_json_nonexistent(s3_manager): - """Test reading multi-agent state when file doesn't exist.""" - result = s3_manager.initialize_multi_agent() - assert result == {} + +def test_update_nonexistent_multi_agent(multi_agent_manager, multi_agent_session): + """Test updating multi-agent state that doesn't exist.""" + # Create session + multi_agent_manager.create_session(multi_agent_session) + + # Update nonexistent multi-agent + with pytest.raises(SessionException): + multi_agent_manager.update_multi_agent(multi_agent_session.session_id, {"id": "nonexistent"}) From 85c1c2db282cbd2d07d3cd397e0884f7c5db54bb Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Sun, 26 Oct 2025 03:00:06 -0400 Subject: [PATCH 07/12] fix: fix unit tests --- tests/strands/session/test_s3_session_manager.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py index 0ac6f9640..b89110a51 100644 --- a/tests/strands/session/test_s3_session_manager.py +++ b/tests/strands/session/test_s3_session_manager.py @@ -415,9 +415,10 @@ def test_create_multi_agent(multi_agent_manager, multi_agent_session, mock_multi multi_agent_manager.create_multi_agent(multi_agent_session.session_id, mock_multi_agent) # Verify S3 object created - key = f"{ - multi_agent_manager._get_multi_agent_path(multi_agent_session.session_id, mock_multi_agent.id) - }multi_agent.json" + key = ( + f"{multi_agent_manager._get_multi_agent_path(multi_agent_session.session_id, mock_multi_agent.id)}" + f"multi_agent.json" + ) response = multi_agent_manager.client.get_object(Bucket=multi_agent_manager.bucket, Key=key) data = json.loads(response["Body"].read().decode("utf-8")) From dfd776c3afaec8f52d4d99beeeb276e938824c08 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Mon, 27 Oct 2025 17:35:02 -0400 Subject: [PATCH 08/12] fix: address comments --- src/strands/multiagent/base.py | 5 ++++- src/strands/session/s3_session_manager.py | 1 + src/strands/session/session_manager.py | 2 +- src/strands/session/session_repository.py | 6 +++--- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 6c22d7818..9ab107bb9 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -137,7 +137,7 @@ def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult": metrics = _parse_metrics(data.get("accumulated_metrics", {})) multiagent_result = cls( - status=Status(data.get("status")), + status=Status(data["status"]), results=results, accumulated_usage=usage, accumulated_metrics=metrics, @@ -164,6 +164,9 @@ class MultiAgentBase(ABC): This class integrates with existing Strands Agent instances and provides multi-agent orchestration capabilities. + + Attributes: + id: Unique MultiAgent id for session management,etc. """ id: str diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index 170f753a6..49f68bd15 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -50,6 +50,7 @@ def __init__( boto_session: Optional[boto3.Session] = None, boto_client_config: Optional[BotocoreConfig] = None, region_name: Optional[str] = None, + *, session_type: SessionType = SessionType.AGENT, **kwargs: Any, ): diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index 182dcc67a..d5278525d 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -121,7 +121,7 @@ def initialize_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> Non source: Multi-agent state to initialize. Returns: - Multi-agent state dictionary or empty dict if not found + Multi-agent state dictionary or empty dict if not found. """ raise NotImplementedError( diff --git a/src/strands/session/session_repository.py b/src/strands/session/session_repository.py index 7a7e02e39..7473a9fd9 100644 --- a/src/strands/session/session_repository.py +++ b/src/strands/session/session_repository.py @@ -53,14 +53,14 @@ def list_messages( ) -> list[SessionMessage]: """List Messages from an Agent with pagination.""" - @abstractmethod def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None: """Create a new MultiAgent state for the Session.""" + raise NotImplementedError("MultiAgent is not implemented for this repository") - @abstractmethod def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]: """Read the MultiAgent state for the Session.""" + raise NotImplementedError("MultiAgent is not implemented for this repository") - @abstractmethod def update_multi_agent(self, session_id: str, multi_agent_state: dict[str, Any], **kwargs: Any) -> None: """Update the MultiAgent state for the Session.""" + raise NotImplementedError("MultiAgent is not implemented for this repository") From 8668b16a66a860b366ebc29c9afb0ca58c098537 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Tue, 28 Oct 2025 13:24:54 -0400 Subject: [PATCH 09/12] fix: update parameter to use MultiAgentBase --- src/strands/session/file_session_manager.py | 25 ++++--------------- .../session/repository_session_manager.py | 2 +- src/strands/session/s3_session_manager.py | 12 ++++----- src/strands/session/session_repository.py | 2 +- tests/fixtures/mock_session_repository.py | 9 ++++--- .../session/test_file_session_manager.py | 16 +++++++++--- .../test_repository_session_manager.py | 9 ++++++- .../session/test_s3_session_manager.py | 16 +++++++++--- 8 files changed, 49 insertions(+), 42 deletions(-) diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index 206c4bf33..554161f2e 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -5,7 +5,6 @@ import os import shutil import tempfile -from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, Optional, cast from .. import _identifier @@ -154,15 +153,6 @@ def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: session_data = self._read_file(session_file) return Session.from_dict(session_data) - def update_session(self, session_id: str, **kwargs: Any) -> None: - """Update session updated_at field.""" - session_file = os.path.join(self._get_session_path(session_id), "session.json") - session_data = self.read_session(session_id) - if session_data is None: - raise SessionException(f"Session {session_id} does not exist") - session_data.updated_at = datetime.now(timezone.utc).isoformat() - self._write_file(session_file, session_data.to_dict()) - def delete_session(self, session_id: str, **kwargs: Any) -> None: """Delete session and all associated data.""" session_dir = self._get_session_path(session_id) @@ -290,17 +280,12 @@ def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) return None return self._read_file(multi_agent_file) - def update_multi_agent(self, session_id: str, multi_agent_state: dict[str, Any], **kwargs: Any) -> None: + def update_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None: """Update multi-agent state from filesystem.""" - multi_agent_id = multi_agent_state.get("id") - if multi_agent_id is None: - raise SessionException("MultiAgent state must have an 'id' field") - previous_multi_agent_state = self.read_multi_agent(session_id=session_id, multi_agent_id=multi_agent_id) + multi_agent_state = multi_agent.serialize_state() + previous_multi_agent_state = self.read_multi_agent(session_id=session_id, multi_agent_id=multi_agent.id) if previous_multi_agent_state is None: - raise SessionException(f"MultiAgent state {multi_agent_id} in session {session_id} does not exist") + raise SessionException(f"MultiAgent state {multi_agent.id} in session {session_id} does not exist") - multi_agent_file = os.path.join(self._get_multi_agent_path(session_id, multi_agent_id), "multi_agent.json") + multi_agent_file = os.path.join(self._get_multi_agent_path(session_id, multi_agent.id), "multi_agent.json") self._write_file(multi_agent_file, multi_agent_state) - - # Update session.update_at - self.update_session(session_id) diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index 100184ab4..b4da340a9 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -173,7 +173,7 @@ def sync_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> None: source: Multi-agent source object to sync to the session. **kwargs: Additional keyword arguments for future extensibility. """ - self.session_repository.update_multi_agent(self.session_id, source.serialize_state()) + self.session_repository.update_multi_agent(self.session_id, source) def initialize_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> None: """Initialize multi-agent state from the session repository. diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index 49f68bd15..46a4345b4 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -320,14 +320,12 @@ def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) multi_agent_key = f"{self._get_multi_agent_path(session_id, multi_agent_id)}multi_agent.json" return self._read_s3_object(multi_agent_key) - def update_multi_agent(self, session_id: str, multi_agent_state: dict[str, Any], **kwargs: Any) -> None: + def update_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None: """Update multi-agent state in S3.""" - multi_agent_id = multi_agent_state.get("id") - if multi_agent_id is None: - raise SessionException("MultiAgent state must have an 'id' field") - previous_multi_agent_state = self.read_multi_agent(session_id=session_id, multi_agent_id=multi_agent_id) + multi_agent_state = multi_agent.serialize_state() + previous_multi_agent_state = self.read_multi_agent(session_id=session_id, multi_agent_id=multi_agent.id) if previous_multi_agent_state is None: - raise SessionException(f"MultiAgent state {multi_agent_id} in session {session_id} does not exist") + raise SessionException(f"MultiAgent state {multi_agent.id} in session {session_id} does not exist") - multi_agent_key = f"{self._get_multi_agent_path(session_id, multi_agent_id)}multi_agent.json" + multi_agent_key = f"{self._get_multi_agent_path(session_id, multi_agent.id)}multi_agent.json" self._write_s3_object(multi_agent_key, multi_agent_state) diff --git a/src/strands/session/session_repository.py b/src/strands/session/session_repository.py index 7473a9fd9..3f5476bdf 100644 --- a/src/strands/session/session_repository.py +++ b/src/strands/session/session_repository.py @@ -61,6 +61,6 @@ def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) """Read the MultiAgent state for the Session.""" raise NotImplementedError("MultiAgent is not implemented for this repository") - def update_multi_agent(self, session_id: str, multi_agent_state: dict[str, Any], **kwargs: Any) -> None: + def update_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None: """Update the MultiAgent state for the Session.""" raise NotImplementedError("MultiAgent is not implemented for this repository") diff --git a/tests/fixtures/mock_session_repository.py b/tests/fixtures/mock_session_repository.py index 96410a159..58eb395aa 100644 --- a/tests/fixtures/mock_session_repository.py +++ b/tests/fixtures/mock_session_repository.py @@ -119,11 +119,12 @@ def read_multi_agent(self, session_id, multi_agent_id, **kwargs): return None return self.multi_agents.get(session_id, {}).get(multi_agent_id) - def update_multi_agent(self, session_id, multi_agent_state, **kwargs) -> None: + def update_multi_agent(self, session_id, multi_agent, **kwargs) -> None: """Update multi-agent state.""" - multi_agent_id = multi_agent_state.get("id") + multi_agent_id = multi_agent.id if session_id not in self.sessions: raise SessionException(f"Session {session_id} does not exist") if multi_agent_id not in self.multi_agents.get(session_id, {}): - raise SessionException(f"MultiAgent {multi_agent_id} does not exist in session {session_id}") - self.multi_agents[session_id][multi_agent_id] = multi_agent_state + raise SessionException(f"MultiAgent {multi_agent} does not exist in session {session_id}") + state = multi_agent.serialize_state() + self.multi_agents[session_id][multi_agent_id] = state diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index f49cf41ad..f09a6aee7 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -478,9 +478,13 @@ def test_update_multi_agent(multi_agent_manager, multi_agent_session, mock_multi multi_agent_manager.create_session(multi_agent_session) multi_agent_manager.create_multi_agent(multi_agent_session.session_id, mock_multi_agent) - # Update multi-agent - updated_state = {"id": mock_multi_agent.id, "state": {"updated": "value"}} - multi_agent_manager.update_multi_agent(multi_agent_session.session_id, updated_state) + # Update multi-agent - create a new mock with updated state + from unittest.mock import Mock + + updated_mock = Mock() + updated_mock.id = mock_multi_agent.id + updated_mock.serialize_state.return_value = {"id": mock_multi_agent.id, "state": {"updated": "value"}} + multi_agent_manager.update_multi_agent(multi_agent_session.session_id, updated_mock) # Verify update result = multi_agent_manager.read_multi_agent(multi_agent_session.session_id, mock_multi_agent.id) @@ -493,8 +497,12 @@ def test_update_nonexistent_multi_agent(multi_agent_manager, multi_agent_session multi_agent_manager.create_session(multi_agent_session) # Update nonexistent multi-agent + from unittest.mock import Mock + + nonexistent_mock = Mock() + nonexistent_mock.id = "nonexistent" with pytest.raises(SessionException): - multi_agent_manager.update_multi_agent(multi_agent_session.session_id, {"id": "nonexistent"}) + multi_agent_manager.update_multi_agent(multi_agent_session.session_id, nonexistent_mock) def test_create_session_multi_agent_directory_structure(multi_agent_manager, multi_agent_session): diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py index de095f584..af0a7c7db 100644 --- a/tests/strands/session/test_repository_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -238,8 +238,15 @@ def test_initialize_multi_agent_existing(multi_agent_session_manager, mock_multi """Test initializing existing multi-agent state.""" # Create existing state first multi_agent_session_manager.session_repository.create_multi_agent("test-multi-session", mock_multi_agent) + + # Create a mock with updated state for the update call + from unittest.mock import Mock + + updated_mock = Mock() + updated_mock.id = "test-multi-agent" existing_state = {"id": "test-multi-agent", "state": {"restored": "data"}} - multi_agent_session_manager.session_repository.update_multi_agent("test-multi-session", existing_state) + updated_mock.serialize_state.return_value = existing_state + multi_agent_session_manager.session_repository.update_multi_agent("test-multi-session", updated_mock) # Initialize multi-agent multi_agent_session_manager.initialize_multi_agent(mock_multi_agent) diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py index b89110a51..ec9b4059a 100644 --- a/tests/strands/session/test_s3_session_manager.py +++ b/tests/strands/session/test_s3_session_manager.py @@ -452,9 +452,13 @@ def test_update_multi_agent(multi_agent_manager, multi_agent_session, mock_multi multi_agent_manager.create_session(multi_agent_session) multi_agent_manager.create_multi_agent(multi_agent_session.session_id, mock_multi_agent) - # Update multi-agent - updated_state = {"id": mock_multi_agent.id, "state": {"updated": "value"}} - multi_agent_manager.update_multi_agent(multi_agent_session.session_id, updated_state) + # Update multi-agent - create a new mock with updated state + from unittest.mock import Mock + + updated_mock = Mock() + updated_mock.id = mock_multi_agent.id + updated_mock.serialize_state.return_value = {"id": mock_multi_agent.id, "state": {"updated": "value"}} + multi_agent_manager.update_multi_agent(multi_agent_session.session_id, updated_mock) # Verify update result = multi_agent_manager.read_multi_agent(multi_agent_session.session_id, mock_multi_agent.id) @@ -467,5 +471,9 @@ def test_update_nonexistent_multi_agent(multi_agent_manager, multi_agent_session multi_agent_manager.create_session(multi_agent_session) # Update nonexistent multi-agent + from unittest.mock import Mock + + nonexistent_mock = Mock() + nonexistent_mock.id = "nonexistent" with pytest.raises(SessionException): - multi_agent_manager.update_multi_agent(multi_agent_session.session_id, {"id": "nonexistent"}) + multi_agent_manager.update_multi_agent(multi_agent_session.session_id, nonexistent_mock) From b5a1f641e209caf831e7959a484982c7ecd01846 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Tue, 28 Oct 2025 13:51:25 -0400 Subject: [PATCH 10/12] fix: fix unit tests --- tests/fixtures/mock_session_repository.py | 7 ------- tests/strands/session/test_file_session_manager.py | 10 +--------- 2 files changed, 1 insertion(+), 16 deletions(-) diff --git a/tests/fixtures/mock_session_repository.py b/tests/fixtures/mock_session_repository.py index 58eb395aa..af369ba1c 100644 --- a/tests/fixtures/mock_session_repository.py +++ b/tests/fixtures/mock_session_repository.py @@ -27,13 +27,6 @@ def read_session(self, session_id) -> SessionAgent: """Read a session.""" return self.sessions.get(session_id) - def update_session(self, session_id, **kwargs) -> None: - """Update a session.""" - if session_id not in self.sessions: - raise SessionException(f"Session {session_id} does not exist") - # Mock implementation - just mark as updated - pass - def create_agent(self, session_id, session_agent) -> None: """Create an agent.""" agent_id = session_agent.agent_id diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index f09a6aee7..213be0f01 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -3,7 +3,7 @@ import json import os import tempfile -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest @@ -56,8 +56,6 @@ def sample_message(): @pytest.fixture def mock_multi_agent(): """Create mock multi-agent for testing.""" - from unittest.mock import Mock - mock = Mock() mock.id = "test-multi-agent" mock.state = {"key": "value"} @@ -478,9 +476,6 @@ def test_update_multi_agent(multi_agent_manager, multi_agent_session, mock_multi multi_agent_manager.create_session(multi_agent_session) multi_agent_manager.create_multi_agent(multi_agent_session.session_id, mock_multi_agent) - # Update multi-agent - create a new mock with updated state - from unittest.mock import Mock - updated_mock = Mock() updated_mock.id = mock_multi_agent.id updated_mock.serialize_state.return_value = {"id": mock_multi_agent.id, "state": {"updated": "value"}} @@ -496,9 +491,6 @@ def test_update_nonexistent_multi_agent(multi_agent_manager, multi_agent_session # Create session multi_agent_manager.create_session(multi_agent_session) - # Update nonexistent multi-agent - from unittest.mock import Mock - nonexistent_mock = Mock() nonexistent_mock.id = "nonexistent" with pytest.raises(SessionException): From 0c3db715e1fffbd6f79c12b647b73b49ce276254 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Tue, 28 Oct 2025 13:59:44 -0400 Subject: [PATCH 11/12] fix: fix unit tests import --- tests/strands/session/test_repository_session_manager.py | 4 ++-- tests/strands/session/test_s3_session_manager.py | 8 +------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py index af0a7c7db..0a38a8c95 100644 --- a/tests/strands/session/test_repository_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -1,5 +1,7 @@ """Tests for AgentSessionManager.""" +from unittest.mock import Mock + import pytest from strands.agent.agent import Agent @@ -34,7 +36,6 @@ def agent(): @pytest.fixture def mock_multi_agent(): """Create mock multi-agent for testing.""" - from unittest.mock import Mock mock = Mock() mock.id = "test-multi-agent" @@ -240,7 +241,6 @@ def test_initialize_multi_agent_existing(multi_agent_session_manager, mock_multi multi_agent_session_manager.session_repository.create_multi_agent("test-multi-session", mock_multi_agent) # Create a mock with updated state for the update call - from unittest.mock import Mock updated_mock = Mock() updated_mock.id = "test-multi-agent" diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py index ec9b4059a..09fa4d6ba 100644 --- a/tests/strands/session/test_s3_session_manager.py +++ b/tests/strands/session/test_s3_session_manager.py @@ -1,6 +1,7 @@ """Tests for S3SessionManager.""" import json +from unittest.mock import Mock import boto3 import pytest @@ -379,7 +380,6 @@ def test__get_message_path_invalid_message_id(message_id, s3_manager): @pytest.fixture def mock_multi_agent(): """Create mock multi-agent for testing.""" - from unittest.mock import Mock mock = Mock() mock.id = "test-multi-agent" @@ -452,9 +452,6 @@ def test_update_multi_agent(multi_agent_manager, multi_agent_session, mock_multi multi_agent_manager.create_session(multi_agent_session) multi_agent_manager.create_multi_agent(multi_agent_session.session_id, mock_multi_agent) - # Update multi-agent - create a new mock with updated state - from unittest.mock import Mock - updated_mock = Mock() updated_mock.id = mock_multi_agent.id updated_mock.serialize_state.return_value = {"id": mock_multi_agent.id, "state": {"updated": "value"}} @@ -470,9 +467,6 @@ def test_update_nonexistent_multi_agent(multi_agent_manager, multi_agent_session # Create session multi_agent_manager.create_session(multi_agent_session) - # Update nonexistent multi-agent - from unittest.mock import Mock - nonexistent_mock = Mock() nonexistent_mock.id = "nonexistent" with pytest.raises(SessionException): From 2157597a0c0b39fe24fdb7ca29f3bb750f863811 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Tue, 28 Oct 2025 14:01:48 -0400 Subject: [PATCH 12/12] fix: fix unit tests import --- tests/strands/session/test_repository_session_manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py index 0a38a8c95..e2931f3c3 100644 --- a/tests/strands/session/test_repository_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -241,7 +241,6 @@ def test_initialize_multi_agent_existing(multi_agent_session_manager, mock_multi multi_agent_session_manager.session_repository.create_multi_agent("test-multi-session", mock_multi_agent) # Create a mock with updated state for the update call - updated_mock = Mock() updated_mock.id = "test-multi-agent" existing_state = {"id": "test-multi-agent", "state": {"restored": "data"}}