Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
coverage:
ignore:
- "src/strands/experimental/tools/tool_provider.py" # This is an interface, cannot meaningfully cover
31 changes: 31 additions & 0 deletions src/strands/_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Private async execution utilities."""

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")


def run_async(async_func: Callable[[], Awaitable[T]]) -> T:
"""Run an async function in a separate thread to avoid event loop conflicts.

This utility handles the common pattern of running async code from sync contexts
by using ThreadPoolExecutor to isolate the async execution.

Args:
async_func: A callable that returns an awaitable

Returns:
The result of the async function
"""

async def execute_async() -> T:
return await async_func()

def execute() -> T:
return asyncio.run(execute_async())

with ThreadPoolExecutor() as executor:
future = executor.submit(execute)
return future.result()
89 changes: 65 additions & 24 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@
2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")`
"""

import asyncio
import json
import logging
import random
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
AsyncGenerator,
Expand All @@ -31,7 +29,9 @@
from pydantic import BaseModel

from .. import _identifier
from .._async import run_async
from ..event_loop.event_loop import event_loop_cycle
from ..experimental.tools import ToolProvider
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
from ..hooks import (
AfterInvocationEvent,
Expand Down Expand Up @@ -160,12 +160,7 @@ async def acall() -> ToolResult:

return tool_results[0]

def tcall() -> ToolResult:
return asyncio.run(acall())

with ThreadPoolExecutor() as executor:
future = executor.submit(tcall)
tool_result = future.result()
tool_result = run_async(acall)

if record_direct_tool_call is not None:
should_record_direct_tool_call = record_direct_tool_call
Expand Down Expand Up @@ -208,7 +203,7 @@ def __init__(
self,
model: Union[Model, str, None] = None,
messages: Optional[Messages] = None,
tools: Optional[list[Union[str, dict[str, str], Any]]] = None,
tools: Optional[list[Union[str, dict[str, str], ToolProvider, Any]]] = None,
system_prompt: Optional[str] = None,
callback_handler: Optional[
Union[Callable[..., Any], _DefaultCallbackHandlerSentinel]
Expand Down Expand Up @@ -240,7 +235,8 @@ def __init__(
- File paths (e.g., "/path/to/tool.py")
- Imported Python modules (e.g., from strands_tools import current_time)
- Dictionaries with name/path keys (e.g., {"name": "tool_name", "path": "/path/to/tool.py"})
- Functions decorated with `@strands.tool` decorator.
- Functions decorated with `@strands.tool` decorator
- ToolProvider instances for managed tool collections

If provided, only these tools will be available. If None, all tools will be available.
system_prompt: System prompt to guide model behavior.
Expand Down Expand Up @@ -333,6 +329,9 @@ def __init__(
else:
self.state = AgentState()

# Track cleanup state
self._cleanup_called = False

self.tool_caller = Agent.ToolCaller(self)

self.hooks = HookRegistry()
Expand Down Expand Up @@ -399,13 +398,7 @@ def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult:
- metrics: Performance metrics from the event loop
- state: The final state of the event loop
"""

def execute() -> AgentResult:
return asyncio.run(self.invoke_async(prompt, **kwargs))

with ThreadPoolExecutor() as executor:
future = executor.submit(execute)
return future.result()
return run_async(lambda: self.invoke_async(prompt, **kwargs))

async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult:
"""Process a natural language prompt through the agent's event loop.
Expand Down Expand Up @@ -459,13 +452,7 @@ def structured_output(self, output_model: Type[T], prompt: AgentInput = None) ->
Raises:
ValueError: If no conversation history or prompt is provided.
"""

def execute() -> T:
return asyncio.run(self.structured_output_async(output_model, prompt))

with ThreadPoolExecutor() as executor:
future = executor.submit(execute)
return future.result()
return run_async(lambda: self.structured_output_async(output_model, prompt))

async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T:
"""This method allows you to get structured output from the agent.
Expand Down Expand Up @@ -527,6 +514,60 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
finally:
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))

def cleanup(self) -> None:
"""Clean up resources used by the agent.

This method cleans up all tool providers that require explicit cleanup,
such as MCP clients. It should be called when the agent is no longer needed
to ensure proper resource cleanup.

Note: This method uses a "belt and braces" approach with automatic cleanup
through __del__ as a fallback, but explicit cleanup is recommended.
"""
if self._cleanup_called:
return

run_async(self.cleanup_async)

async def cleanup_async(self) -> None:
"""Asynchronously clean up resources used by the agent.

This method cleans up all tool providers that require explicit cleanup,
such as MCP clients. It should be called when the agent is no longer needed
to ensure proper resource cleanup.

Note: This method uses a "belt and braces" approach with automatic cleanup
through __del__ as a fallback, but explicit cleanup is recommended.
"""
if self._cleanup_called:
return

logger.debug("agent_id=<%s> | cleaning up agent resources", self.agent_id)

await self.tool_registry.cleanup_async()

self._cleanup_called = True
logger.debug("agent_id=<%s> | agent cleanup complete", self.agent_id)

def __del__(self) -> None:
"""Automatic cleanup when agent is garbage collected.

This serves as a fallback cleanup mechanism, but explicit cleanup() is preferred.
"""
try:
if self._cleanup_called or not self.tool_registry.tool_providers:
return

logger.warning(
"agent_id=<%s> | Agent cleanup called via __del__. "
"Consider calling agent.cleanup() explicitly for better resource management.",
self.agent_id,
)
self.cleanup()
except Exception as e:
# Log exceptions during garbage collection cleanup for debugging
logger.debug("agent_id=<%s>, error=<%s> | exception during __del__ cleanup", self.agent_id, e)

async def stream_async(
self,
prompt: AgentInput = None,
Expand Down
4 changes: 4 additions & 0 deletions src/strands/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,7 @@

This module implements experimental features that are subject to change in future revisions without notice.
"""

from . import tools

__all__ = ["tools"]
5 changes: 5 additions & 0 deletions src/strands/experimental/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Experimental tools package."""

from .tool_provider import ToolProvider

__all__ = ["ToolProvider"]
49 changes: 49 additions & 0 deletions src/strands/experimental/tools/tool_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Tool provider interface."""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Sequence

if TYPE_CHECKING:
from ...types.tools import AgentTool


class ToolProvider(ABC):
"""Interface for providing tools with lifecycle management.

Provides a way to load a collection of tools and clean them up
when done, with lifecycle managed by the agent.
"""

@abstractmethod
async def load_tools(self, **kwargs: Any) -> Sequence["AgentTool"]:
"""Load and return the tools in this provider.

Args:
**kwargs: Additional arguments for future compatibility.

Returns:
List of tools that are ready to use.
"""
pass

@abstractmethod
async def add_consumer(self, id: Any, **kwargs: Any) -> None:
"""Add a consumer to this tool provider.

Args:
id: Unique identifier for the consumer.
**kwargs: Additional arguments for future compatibility.
"""
pass

@abstractmethod
async def remove_consumer(self, id: Any, **kwargs: Any) -> None:
"""Remove a consumer from this tool provider.

Args:
id: Unique identifier for the consumer.
**kwargs: Additional arguments for future compatibility.

Provider may clean up resources when no consumers remain.
"""
pass
10 changes: 2 additions & 8 deletions src/strands/multiagent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
Provides minimal foundation for multi-agent patterns (Swarm, Graph).
"""

import asyncio
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Union

from .._async import run_async
from ..agent import AgentResult
from ..types.content import ContentBlock
from ..types.event_loop import Metrics, Usage
Expand Down Expand Up @@ -111,9 +110,4 @@ def __call__(
if invocation_state is None:
invocation_state = {}

def execute() -> MultiAgentResult:
return asyncio.run(self.invoke_async(task, invocation_state, **kwargs))

with ThreadPoolExecutor() as executor:
future = executor.submit(execute)
return future.result()
return run_async(lambda: self.invoke_async(task, invocation_state, **kwargs))
9 changes: 2 additions & 7 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
import copy
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Any, Callable, Optional, Tuple

from opentelemetry import trace as trace_api

from .._async import run_async
from ..agent import Agent
from ..agent.state import AgentState
from ..telemetry import get_tracer
Expand Down Expand Up @@ -399,12 +399,7 @@ def __call__(
if invocation_state is None:
invocation_state = {}

def execute() -> GraphResult:
return asyncio.run(self.invoke_async(task, invocation_state))

with ThreadPoolExecutor() as executor:
future = executor.submit(execute)
return future.result()
return run_async(lambda: self.invoke_async(task, invocation_state))

async def invoke_async(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
Expand Down
12 changes: 4 additions & 8 deletions src/strands/multiagent/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
import json
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Any, Callable, Tuple

from opentelemetry import trace as trace_api

from ..agent import Agent, AgentResult
from .._async import run_async
from ..agent import Agent
from ..agent.agent_result import AgentResult
from ..agent.state import AgentState
from ..telemetry import get_tracer
from ..tools.decorator import tool
Expand Down Expand Up @@ -254,12 +255,7 @@ def __call__(
if invocation_state is None:
invocation_state = {}

def execute() -> SwarmResult:
return asyncio.run(self.invoke_async(task, invocation_state))

with ThreadPoolExecutor() as executor:
future = executor.submit(execute)
return future.result()
return run_async(lambda: self.invoke_async(task, invocation_state))

async def invoke_async(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
Expand Down
4 changes: 2 additions & 2 deletions src/strands/tools/mcp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""

from .mcp_agent_tool import MCPAgentTool
from .mcp_client import MCPClient
from .mcp_client import MCPClient, ToolFilters
from .mcp_types import MCPTransport

__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport"]
__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport", "ToolFilters"]
Loading
Loading