-
Couldn't load subscription status.
- Fork 452
feat: add multiagent session/repository management. #1071
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
520dbf7
0beae27
7a94c15
836d2a2
aa3c905
01e6dbf
85c1c2d
dfd776c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -5,19 +5,24 @@ | |||||||
| import os | ||||||||
| import shutil | ||||||||
| import tempfile | ||||||||
| from typing import Any, Optional, cast | ||||||||
| from datetime import datetime, timezone | ||||||||
| from typing import TYPE_CHECKING, Any, Optional, cast | ||||||||
|
|
||||||||
| from .. import _identifier | ||||||||
| from ..types.exceptions import SessionException | ||||||||
| from ..types.session import Session, SessionAgent, SessionMessage | ||||||||
| from ..types.session import Session, SessionAgent, SessionMessage, SessionType | ||||||||
| from .repository_session_manager import RepositorySessionManager | ||||||||
| from .session_repository import SessionRepository | ||||||||
|
|
||||||||
| if TYPE_CHECKING: | ||||||||
| from ..multiagent.base import MultiAgentBase | ||||||||
|
|
||||||||
| logger = logging.getLogger(__name__) | ||||||||
|
|
||||||||
| SESSION_PREFIX = "session_" | ||||||||
| AGENT_PREFIX = "agent_" | ||||||||
| MESSAGE_PREFIX = "message_" | ||||||||
| MULTI_AGENT_PREFIX = "multi_agent_" | ||||||||
|
|
||||||||
|
|
||||||||
| class FileSessionManager(RepositorySessionManager, SessionRepository): | ||||||||
|
|
@@ -37,19 +42,27 @@ class FileSessionManager(RepositorySessionManager, SessionRepository): | |||||||
| ``` | ||||||||
| """ | ||||||||
|
|
||||||||
| def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs: Any): | ||||||||
| def __init__( | ||||||||
| self, | ||||||||
| session_id: str, | ||||||||
| storage_dir: Optional[str] = None, | ||||||||
| *, | ||||||||
| session_type: SessionType = SessionType.AGENT, | ||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we need to require that this is a named parameter - @dbschmigelski can you confirm - I recall you we discussed a similar situation before? (if so, is there anywhere we can document/write this down)?
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Out of curiosity, why does this need to be a named parameter? |
||||||||
| **kwargs: Any, | ||||||||
| ): | ||||||||
| """Initialize FileSession with filesystem storage. | ||||||||
|
|
||||||||
| Args: | ||||||||
| session_id: ID for the session. | ||||||||
| ID is not allowed to contain path separators (e.g., a/b). | ||||||||
| storage_dir: Directory for local filesystem storage (defaults to temp dir). | ||||||||
| session_type: single agent or multiagent. | ||||||||
| **kwargs: Additional keyword arguments for future extensibility. | ||||||||
| """ | ||||||||
| self.storage_dir = storage_dir or os.path.join(tempfile.gettempdir(), "strands/sessions") | ||||||||
| os.makedirs(self.storage_dir, exist_ok=True) | ||||||||
|
|
||||||||
| super().__init__(session_id=session_id, session_repository=self) | ||||||||
| super().__init__(session_id=session_id, session_repository=self, session_type=session_type) | ||||||||
|
|
||||||||
| def _get_session_path(self, session_id: str) -> str: | ||||||||
| """Get session directory path. | ||||||||
|
|
@@ -107,8 +120,11 @@ def _read_file(self, path: str) -> dict[str, Any]: | |||||||
| def _write_file(self, path: str, data: dict[str, Any]) -> None: | ||||||||
| """Write JSON file.""" | ||||||||
| os.makedirs(os.path.dirname(path), exist_ok=True) | ||||||||
| with open(path, "w", encoding="utf-8") as f: | ||||||||
| # This automic write ensure the completeness of session files in both single agent/ multi agents | ||||||||
| tmp = f"{path}.tmp" | ||||||||
| with open(tmp, "w", encoding="utf-8", newline="\n") as f: | ||||||||
JackYPCOnline marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
| json.dump(data, f, indent=2, ensure_ascii=False) | ||||||||
| os.replace(tmp, path) | ||||||||
|
|
||||||||
| def create_session(self, session: Session, **kwargs: Any) -> Session: | ||||||||
| """Create a new session.""" | ||||||||
|
|
@@ -118,8 +134,10 @@ def create_session(self, session: Session, **kwargs: Any) -> Session: | |||||||
|
|
||||||||
| # Create directory structure | ||||||||
| os.makedirs(session_dir, exist_ok=True) | ||||||||
| os.makedirs(os.path.join(session_dir, "agents"), exist_ok=True) | ||||||||
|
|
||||||||
| if self.session_type == SessionType.AGENT: | ||||||||
| os.makedirs(os.path.join(session_dir, "agents"), exist_ok=True) | ||||||||
JackYPCOnline marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
| else: | ||||||||
| os.makedirs(os.path.join(session_dir, "multi_agents"), exist_ok=True) | ||||||||
| # Write session file | ||||||||
| session_file = os.path.join(session_dir, "session.json") | ||||||||
| session_dict = session.to_dict() | ||||||||
|
|
@@ -136,6 +154,15 @@ def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: | |||||||
| session_data = self._read_file(session_file) | ||||||||
| return Session.from_dict(session_data) | ||||||||
|
|
||||||||
| def update_session(self, session_id: str, **kwargs: Any) -> None: | ||||||||
| """Update session updated_at field.""" | ||||||||
| session_file = os.path.join(self._get_session_path(session_id), "session.json") | ||||||||
| session_data = self.read_session(session_id) | ||||||||
| if session_data is None: | ||||||||
| raise SessionException(f"Session {session_id} does not exist") | ||||||||
| session_data.updated_at = datetime.now(timezone.utc).isoformat() | ||||||||
| self._write_file(session_file, session_data.to_dict()) | ||||||||
|
|
||||||||
| def delete_session(self, session_id: str, **kwargs: Any) -> None: | ||||||||
| """Delete session and all associated data.""" | ||||||||
| session_dir = self._get_session_path(session_id) | ||||||||
|
|
@@ -239,3 +266,41 @@ def list_messages( | |||||||
| messages.append(SessionMessage.from_dict(message_data)) | ||||||||
|
|
||||||||
| return messages | ||||||||
|
|
||||||||
| def _get_multi_agent_path(self, session_id: str, multi_agent_id: str) -> str: | ||||||||
| """Get multi-agent state file path.""" | ||||||||
| session_path = self._get_session_path(session_id) | ||||||||
| multi_agent_id = _identifier.validate(multi_agent_id, _identifier.Identifier.AGENT) | ||||||||
| return os.path.join(session_path, "multi_agents", f"{MULTI_AGENT_PREFIX}{multi_agent_id}") | ||||||||
|
|
||||||||
| def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None: | ||||||||
| """Create a new multiagent state in the session.""" | ||||||||
| multi_agent_id = multi_agent.id | ||||||||
| multi_agent_dir = self._get_multi_agent_path(session_id, multi_agent_id) | ||||||||
| os.makedirs(multi_agent_dir, exist_ok=True) | ||||||||
|
|
||||||||
| multi_agent_file = os.path.join(multi_agent_dir, "multi_agent.json") | ||||||||
| session_data = multi_agent.serialize_state() | ||||||||
| self._write_file(multi_agent_file, session_data) | ||||||||
|
|
||||||||
| def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]: | ||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Initially I introduced a data class |
||||||||
| """Read multi-agent state from filesystem.""" | ||||||||
| multi_agent_file = os.path.join(self._get_multi_agent_path(session_id, multi_agent_id), "multi_agent.json") | ||||||||
| if not os.path.exists(multi_agent_file): | ||||||||
| return None | ||||||||
| return self._read_file(multi_agent_file) | ||||||||
|
|
||||||||
| def update_multi_agent(self, session_id: str, multi_agent_state: dict[str, Any], **kwargs: Any) -> None: | ||||||||
| """Update multi-agent state from filesystem.""" | ||||||||
| multi_agent_id = multi_agent_state.get("id") | ||||||||
| if multi_agent_id is None: | ||||||||
| raise SessionException("MultiAgent state must have an 'id' field") | ||||||||
| previous_multi_agent_state = self.read_multi_agent(session_id=session_id, multi_agent_id=multi_agent_id) | ||||||||
| if previous_multi_agent_state is None: | ||||||||
| raise SessionException(f"MultiAgent state {multi_agent_id} in session {session_id} does not exist") | ||||||||
|
|
||||||||
| multi_agent_file = os.path.join(self._get_multi_agent_path(session_id, multi_agent_id), "multi_agent.json") | ||||||||
| self._write_file(multi_agent_file, multi_agent_state) | ||||||||
|
|
||||||||
| # Update session.update_at | ||||||||
| self.update_session(session_id) | ||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we also need to call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see |
||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,23 +2,27 @@ | |
|
|
||
| import json | ||
| import logging | ||
| from typing import Any, Dict, List, Optional, cast | ||
| from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast | ||
|
|
||
| import boto3 | ||
| from botocore.config import Config as BotocoreConfig | ||
| from botocore.exceptions import ClientError | ||
|
|
||
| from .. import _identifier | ||
| from ..types.exceptions import SessionException | ||
| from ..types.session import Session, SessionAgent, SessionMessage | ||
| from ..types.session import Session, SessionAgent, SessionMessage, SessionType | ||
| from .repository_session_manager import RepositorySessionManager | ||
| from .session_repository import SessionRepository | ||
|
|
||
| if TYPE_CHECKING: | ||
| from ..multiagent.base import MultiAgentBase | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| SESSION_PREFIX = "session_" | ||
| AGENT_PREFIX = "agent_" | ||
| MESSAGE_PREFIX = "message_" | ||
| MULTI_AGENT_PREFIX = "multi_agent_" | ||
|
|
||
|
|
||
| class S3SessionManager(RepositorySessionManager, SessionRepository): | ||
|
|
@@ -46,6 +50,8 @@ def __init__( | |
| boto_session: Optional[boto3.Session] = None, | ||
| boto_client_config: Optional[BotocoreConfig] = None, | ||
| region_name: Optional[str] = None, | ||
| *, | ||
| session_type: SessionType = SessionType.AGENT, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we have to make There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah it should, I had a bad rebase There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why does it need to be a named parameter though? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This forces users to provide clarity by explicitly showing what sesion manager type they are using. |
||
| **kwargs: Any, | ||
| ): | ||
| """Initialize S3SessionManager with S3 storage. | ||
|
|
@@ -58,6 +64,7 @@ def __init__( | |
| boto_session: Optional boto3 session | ||
| boto_client_config: Optional boto3 client configuration | ||
| region_name: AWS region for S3 storage | ||
| session_type: single agent or multiagent. | ||
| **kwargs: Additional keyword arguments for future extensibility. | ||
| """ | ||
| self.bucket = bucket | ||
|
|
@@ -78,7 +85,7 @@ def __init__( | |
| client_config = BotocoreConfig(user_agent_extra="strands-agents") | ||
|
|
||
| self.client = session.client(service_name="s3", config=client_config) | ||
| super().__init__(session_id=session_id, session_repository=self) | ||
| super().__init__(session_id=session_id, session_type=session_type, session_repository=self) | ||
|
|
||
| def _get_session_path(self, session_id: str) -> str: | ||
| """Get session S3 prefix. | ||
|
|
@@ -294,3 +301,33 @@ def list_messages( | |
|
|
||
| except ClientError as e: | ||
| raise SessionException(f"S3 error reading messages: {e}") from e | ||
|
|
||
| def _get_multi_agent_path(self, session_id: str, multi_agent_id: str) -> str: | ||
| """Get multi-agent S3 prefix.""" | ||
| session_path = self._get_session_path(session_id) | ||
| multi_agent_id = _identifier.validate(multi_agent_id, _identifier.Identifier.AGENT) | ||
| return f"{session_path}multi_agents/{MULTI_AGENT_PREFIX}{multi_agent_id}/" | ||
|
|
||
| def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None: | ||
| """Create a new multiagent state in S3.""" | ||
| multi_agent_id = multi_agent.id | ||
| multi_agent_key = f"{self._get_multi_agent_path(session_id, multi_agent_id)}multi_agent.json" | ||
| session_data = multi_agent.serialize_state() | ||
| self._write_s3_object(multi_agent_key, session_data) | ||
|
|
||
| def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]: | ||
| """Read multi-agent state from S3.""" | ||
| multi_agent_key = f"{self._get_multi_agent_path(session_id, multi_agent_id)}multi_agent.json" | ||
| return self._read_s3_object(multi_agent_key) | ||
|
|
||
| def update_multi_agent(self, session_id: str, multi_agent_state: dict[str, Any], **kwargs: Any) -> None: | ||
| """Update multi-agent state in S3.""" | ||
| multi_agent_id = multi_agent_state.get("id") | ||
| if multi_agent_id is None: | ||
| raise SessionException("MultiAgent state must have an 'id' field") | ||
| previous_multi_agent_state = self.read_multi_agent(session_id=session_id, multi_agent_id=multi_agent_id) | ||
| if previous_multi_agent_state is None: | ||
| raise SessionException(f"MultiAgent state {multi_agent_id} in session {session_id} does not exist") | ||
|
|
||
| multi_agent_key = f"{self._get_multi_agent_path(session_id, multi_agent_id)}multi_agent.json" | ||
| self._write_s3_object(multi_agent_key, multi_agent_state) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need an There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function wise no, s3 files has update at in console. |
||
Uh oh!
There was an error while loading. Please reload this page.