diff --git a/src/strands/agent/state.py b/src/strands/agent/state.py index 36120b8ff..c323041a3 100644 --- a/src/strands/agent/state.py +++ b/src/strands/agent/state.py @@ -1,97 +1,6 @@ """Agent state management.""" -import copy -import json -from typing import Any, Dict, Optional +from ..types.json_dict import JSONSerializableDict - -class AgentState: - """Represents an Agent's stateful information outside of context provided to a model. - - Provides a key-value store for agent state with JSON serialization validation and persistence support. - Key features: - - JSON serialization validation on assignment - - Get/set/delete operations - """ - - def __init__(self, initial_state: Optional[Dict[str, Any]] = None): - """Initialize AgentState.""" - self._state: Dict[str, Dict[str, Any]] - if initial_state: - self._validate_json_serializable(initial_state) - self._state = copy.deepcopy(initial_state) - else: - self._state = {} - - def set(self, key: str, value: Any) -> None: - """Set a value in the state. - - Args: - key: The key to store the value under - value: The value to store (must be JSON serializable) - - Raises: - ValueError: If key is invalid, or if value is not JSON serializable - """ - self._validate_key(key) - self._validate_json_serializable(value) - - self._state[key] = copy.deepcopy(value) - - def get(self, key: Optional[str] = None) -> Any: - """Get a value or entire state. - - Args: - key: The key to retrieve (if None, returns entire state object) - - Returns: - The stored value, entire state dict, or None if not found - """ - if key is None: - return copy.deepcopy(self._state) - else: - # Return specific key - return copy.deepcopy(self._state.get(key)) - - def delete(self, key: str) -> None: - """Delete a specific key from the state. - - Args: - key: The key to delete - """ - self._validate_key(key) - - self._state.pop(key, None) - - def _validate_key(self, key: str) -> None: - """Validate that a key is valid. 
- - Args: - key: The key to validate - - Raises: - ValueError: If key is invalid - """ - if key is None: - raise ValueError("Key cannot be None") - if not isinstance(key, str): - raise ValueError("Key must be a string") - if not key.strip(): - raise ValueError("Key cannot be empty") - - def _validate_json_serializable(self, value: Any) -> None: - """Validate that a value is JSON serializable. - - Args: - value: The value to validate - - Raises: - ValueError: If value is not JSON serializable - """ - try: - json.dumps(value) - except (TypeError, ValueError) as e: - raise ValueError( - f"Value is not JSON serializable: {type(value).__name__}. " - f"Only JSON-compatible types (str, int, float, bool, list, dict, None) are allowed." - ) from e +# Type alias for agent state +AgentState = JSONSerializableDict diff --git a/src/strands/experimental/__init__.py b/src/strands/experimental/__init__.py index 188c80c69..3c1d0ee46 100644 --- a/src/strands/experimental/__init__.py +++ b/src/strands/experimental/__init__.py @@ -3,7 +3,7 @@ This module implements experimental features that are subject to change in future revisions without notice. """ -from . import tools +from . import steering, tools from .agent_config import config_to_agent -__all__ = ["config_to_agent", "tools"] +__all__ = ["config_to_agent", "tools", "steering"] diff --git a/src/strands/experimental/steering/__init__.py b/src/strands/experimental/steering/__init__.py new file mode 100644 index 000000000..4d0775873 --- /dev/null +++ b/src/strands/experimental/steering/__init__.py @@ -0,0 +1,46 @@ +"""Steering system for Strands agents. + +Provides contextual guidance for agents through modular prompting with progressive disclosure. +Instead of front-loading all instructions, steering handlers provide just-in-time feedback +based on local context data populated by context callbacks. 
+ +Core components: + +- SteeringHandler: Base class for guidance logic with local context +- SteeringContextCallback: Protocol for context update functions +- SteeringContextProvider: Protocol for multi-event context providers +- SteeringAction: Proceed/Guide/Interrupt decisions + +Usage: + handler = LLMSteeringHandler(system_prompt="...") + agent = Agent(tools=[...], hooks=[handler]) +""" + +# Core primitives +# Context providers +from .context_providers.ledger_provider import ( + LedgerAfterToolCall, + LedgerBeforeToolCall, + LedgerProvider, +) +from .core.action import Guide, Interrupt, Proceed, SteeringAction +from .core.context import SteeringContextCallback, SteeringContextProvider +from .core.handler import SteeringHandler + +# Handler implementations +from .handlers.llm import LLMPromptMapper, LLMSteeringHandler + +__all__ = [ + "SteeringAction", + "Proceed", + "Guide", + "Interrupt", + "SteeringHandler", + "SteeringContextCallback", + "SteeringContextProvider", + "LedgerBeforeToolCall", + "LedgerAfterToolCall", + "LedgerProvider", + "LLMSteeringHandler", + "LLMPromptMapper", +] diff --git a/src/strands/experimental/steering/context_providers/__init__.py b/src/strands/experimental/steering/context_providers/__init__.py new file mode 100644 index 000000000..242ed9cf1 --- /dev/null +++ b/src/strands/experimental/steering/context_providers/__init__.py @@ -0,0 +1,13 @@ +"""Context providers for steering evaluation.""" + +from .ledger_provider import ( + LedgerAfterToolCall, + LedgerBeforeToolCall, + LedgerProvider, +) + +__all__ = [ + "LedgerAfterToolCall", + "LedgerBeforeToolCall", + "LedgerProvider", +] diff --git a/src/strands/experimental/steering/context_providers/ledger_provider.py b/src/strands/experimental/steering/context_providers/ledger_provider.py new file mode 100644 index 000000000..da8504bd0 --- /dev/null +++ b/src/strands/experimental/steering/context_providers/ledger_provider.py @@ -0,0 +1,85 @@ +"""Ledger context provider for comprehensive 
agent activity tracking. + +Tracks complete agent activity ledger including tool calls, conversation history, +and timing information. This comprehensive audit trail enables steering handlers +to make informed guidance decisions based on agent behavior patterns and history. + +Data captured: + + - Tool call history with inputs, outputs, timing, success/failure + - Conversation messages and agent responses + - Session metadata and timing information + - Error patterns and recovery attempts + +Usage: + Use as context provider functions or mix into steering handlers. +""" + +import logging +from datetime import datetime +from typing import Any + +from ....hooks.events import AfterToolCallEvent, BeforeToolCallEvent +from ..core.context import SteeringContext, SteeringContextCallback, SteeringContextProvider + +logger = logging.getLogger(__name__) + + +class LedgerBeforeToolCall(SteeringContextCallback[BeforeToolCallEvent]): + """Context provider for ledger tracking before tool calls.""" + + def __init__(self) -> None: + """Initialize the ledger provider.""" + self.session_start = datetime.now().isoformat() + + def __call__(self, event: BeforeToolCallEvent, steering_context: SteeringContext, **kwargs: Any) -> None: + """Update ledger before tool call.""" + ledger = steering_context.data.get("ledger") or {} + + if not ledger: + ledger = { + "session_start": self.session_start, + "tool_calls": [], + "conversation_history": [], + "session_metadata": {}, + } + + tool_call_entry = { + "timestamp": datetime.now().isoformat(), + "tool_name": event.tool_use.get("name"), + "tool_args": event.tool_use.get("arguments", {}), + "status": "pending", + } + ledger["tool_calls"].append(tool_call_entry) + steering_context.data.set("ledger", ledger) + + +class LedgerAfterToolCall(SteeringContextCallback[AfterToolCallEvent]): + """Context provider for ledger tracking after tool calls.""" + + def __call__(self, event: AfterToolCallEvent, steering_context: SteeringContext, **kwargs: Any) -> 
None: + """Update ledger after tool call.""" + ledger = steering_context.data.get("ledger") or {} + + if ledger.get("tool_calls"): + last_call = ledger["tool_calls"][-1] + last_call.update( + { + "completion_timestamp": datetime.now().isoformat(), + "status": event.result["status"], + "result": event.result["content"], + "error": str(event.exception) if event.exception else None, + } + ) + steering_context.data.set("ledger", ledger) + + +class LedgerProvider(SteeringContextProvider): + """Combined ledger context provider for both before and after tool calls.""" + + def context_providers(self, **kwargs: Any) -> list[SteeringContextCallback]: + """Return ledger context providers with shared state.""" + return [ + LedgerBeforeToolCall(), + LedgerAfterToolCall(), + ] diff --git a/src/strands/experimental/steering/core/__init__.py b/src/strands/experimental/steering/core/__init__.py new file mode 100644 index 000000000..a3efe0dbc --- /dev/null +++ b/src/strands/experimental/steering/core/__init__.py @@ -0,0 +1,6 @@ +"""Core steering system interfaces and base classes.""" + +from .action import Guide, Interrupt, Proceed, SteeringAction +from .handler import SteeringHandler + +__all__ = ["SteeringAction", "Proceed", "Guide", "Interrupt", "SteeringHandler"] diff --git a/src/strands/experimental/steering/core/action.py b/src/strands/experimental/steering/core/action.py new file mode 100644 index 000000000..8b4ec141d --- /dev/null +++ b/src/strands/experimental/steering/core/action.py @@ -0,0 +1,65 @@ +"""SteeringAction types for steering evaluation results. + +Defines structured outcomes from steering handlers that determine how tool calls +should be handled. SteeringActions enable modular prompting by providing just-in-time +feedback rather than front-loading all instructions in monolithic prompts. 
+ +Flow: + SteeringHandler.steer() → SteeringAction → BeforeToolCallEvent handling + ↓ ↓ ↓ + Evaluate context Action type Tool execution modified + +SteeringAction types: + Proceed: Tool executes immediately (no intervention needed) + Guide: Tool cancelled, agent receives contextual feedback to explore alternatives + Interrupt: Tool execution paused for human input via interrupt system + +Extensibility: + New action types can be added to the union. Always handle the default + case in pattern matching to maintain backward compatibility. +""" + +from typing import Annotated, Literal + +from pydantic import BaseModel, Field + + +class Proceed(BaseModel): + """Allow tool to execute immediately without intervention. + + The tool call proceeds as planned. The reason provides context + for logging and debugging purposes. + """ + + type: Literal["proceed"] = "proceed" + reason: str + + +class Guide(BaseModel): + """Cancel tool and provide contextual feedback for agent to explore alternatives. + + The tool call is cancelled and the agent receives the reason as contextual + feedback to help them consider alternative approaches while maintaining + adaptive reasoning capabilities. + """ + + type: Literal["guide"] = "guide" + reason: str + + +class Interrupt(BaseModel): + """Pause tool execution for human input via interrupt system. + + The tool call is paused and human input is requested through Strands' + interrupt system. The human can approve or deny the operation, and their + decision determines whether the tool executes or is cancelled. 
+ """ + + type: Literal["interrupt"] = "interrupt" + reason: str + + +# SteeringAction union - extensible for future action types +# IMPORTANT: Always handle the default case when pattern matching +# to maintain backward compatibility as new action types are added +SteeringAction = Annotated[Proceed | Guide | Interrupt, Field(discriminator="type")] diff --git a/src/strands/experimental/steering/core/context.py b/src/strands/experimental/steering/core/context.py new file mode 100644 index 000000000..446c4c9f9 --- /dev/null +++ b/src/strands/experimental/steering/core/context.py @@ -0,0 +1,77 @@ +"""Steering context protocols for contextual guidance. + +Defines protocols for context callbacks and providers that populate +steering context data used by handlers to make guidance decisions. + +Architecture: + SteeringContextCallback → Handler.steering_context → SteeringHandler.steer() + ↓ ↓ ↓ + Update local context Store in handler Access via self.steering_context + +Context lifecycle: + 1. Handler registers context callbacks for hook events + 2. Callbacks update handler's local steering_context on events + 3. Handler accesses self.steering_context in steer() method + 4. Context persists across calls within handler instance + +Implementation: + Each handler maintains its own JSONSerializableDict context. + Callbacks are registered per handler instance for isolation. + Providers can supply multiple callbacks for different events. +""" + +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Generic, TypeVar, cast, get_args, get_origin + +from ....hooks.registry import HookEvent +from ....types.json_dict import JSONSerializableDict + +logger = logging.getLogger(__name__) + + +@dataclass +class SteeringContext: + """Container for steering context data.""" + + """Container for steering context data. + + This class should not be instantiated directly - it is intended for internal use only. 
+ """ + + data: JSONSerializableDict = field(default_factory=JSONSerializableDict) + + +EventType = TypeVar("EventType", bound=HookEvent, contravariant=True) + + +class SteeringContextCallback(ABC, Generic[EventType]): + """Abstract base class for steering context update callbacks.""" + + @property + def event_type(self) -> type[HookEvent]: + """Return the event type this callback handles.""" + for base in getattr(self.__class__, "__orig_bases__", ()): + if get_origin(base) is SteeringContextCallback: + return cast(type[HookEvent], get_args(base)[0]) + raise ValueError("Could not determine event type from generic parameter") + + def __call__(self, event: EventType, steering_context: "SteeringContext", **kwargs: Any) -> None: + """Update steering context based on hook event. + + Args: + event: The hook event that triggered the callback + steering_context: The steering context to update + **kwargs: Additional keyword arguments for context updates + """ + ... + + +class SteeringContextProvider(ABC): + """Abstract base class for context providers that handle multiple event types.""" + + @abstractmethod + def context_providers(self, **kwargs: Any) -> list[SteeringContextCallback]: + """Return list of context callbacks with event types extracted from generics.""" + ... diff --git a/src/strands/experimental/steering/core/handler.py b/src/strands/experimental/steering/core/handler.py new file mode 100644 index 000000000..4a0bcaa6a --- /dev/null +++ b/src/strands/experimental/steering/core/handler.py @@ -0,0 +1,134 @@ +"""Steering handler base class for providing contextual guidance to agents. + +Provides modular prompting through contextual guidance that appears when relevant, +rather than front-loading all instructions. Handlers integrate with the Strands hook +system to intercept tool calls and provide just-in-time feedback based on local context. 
+ +Architecture: + BeforeToolCallEvent → Context Callbacks → Update steering_context → steer() → SteeringAction + ↓ ↓ ↓ ↓ ↓ + Hook triggered Populate context Handler evaluates Handler decides Action taken + +Lifecycle: + 1. Context callbacks update handler's steering_context on hook events + 2. BeforeToolCallEvent triggers steering evaluation via steer() method + 3. Handler accesses self.steering_context for guidance decisions + 4. SteeringAction determines tool execution: Proceed/Guide/Interrupt + +Implementation: + Subclass SteeringHandler and implement steer() method. + Pass context_callbacks in constructor to register context update functions. + Each handler maintains isolated steering_context that persists across calls. + +SteeringAction handling: + Proceed: Tool executes immediately + Guide: Tool cancelled, agent receives contextual feedback to explore alternatives + Interrupt: Tool execution paused for human input via interrupt system +""" + +import logging +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +from ....hooks.events import BeforeToolCallEvent +from ....hooks.registry import HookProvider, HookRegistry +from ....types.tools import ToolUse +from .action import Guide, Interrupt, Proceed, SteeringAction +from .context import SteeringContext, SteeringContextProvider + +if TYPE_CHECKING: + from ....agent import Agent + +logger = logging.getLogger(__name__) + + +class SteeringHandler(HookProvider, ABC): + """Base class for steering handlers that provide contextual guidance to agents. + + Steering handlers maintain local context and register hook callbacks + to populate context data as needed for guidance decisions. + """ + + def __init__(self, context_providers: list[SteeringContextProvider] | None = None): + """Initialize the steering handler. 
+ + Args: + context_providers: List of context providers for context updates + """ + super().__init__() + self.steering_context = SteeringContext() + self._context_callbacks = [] + + # Collect callbacks from all providers + for provider in context_providers or []: + self._context_callbacks.extend(provider.context_providers()) + + logger.debug("handler_class=<%s> | initialized", self.__class__.__name__) + + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + """Register hooks for steering guidance and context updates.""" + # Register context update callbacks + for callback in self._context_callbacks: + registry.add_callback( + callback.event_type, lambda event, callback=callback: callback(event, self.steering_context) + ) + + # Register steering guidance + registry.add_callback(BeforeToolCallEvent, self._provide_steering_guidance) + + async def _provide_steering_guidance(self, event: BeforeToolCallEvent) -> None: + """Provide steering guidance for tool call.""" + tool_name = event.tool_use["name"] + logger.debug("tool_name=<%s> | providing steering guidance", tool_name) + + try: + action = await self.steer(event.agent, event.tool_use) + except Exception as e: + logger.debug("tool_name=<%s>, error=<%s> | steering handler guidance failed", tool_name, e) + return + + self._handle_steering_action(action, event, tool_name) + + def _handle_steering_action(self, action: SteeringAction, event: BeforeToolCallEvent, tool_name: str) -> None: + """Handle the steering action by modifying tool execution flow. 
+ + Proceed: Tool executes normally + Guide: Tool cancelled with contextual feedback for agent to consider alternatives + Interrupt: Tool execution paused for human input via interrupt system + """ + if isinstance(action, Proceed): + logger.debug("tool_name=<%s> | tool call proceeding", tool_name) + elif isinstance(action, Guide): + logger.debug("tool_name=<%s> | tool call guided: %s", tool_name, action.reason) + event.cancel_tool = ( + f"Tool call cancelled given new guidance. {action.reason}. Consider this approach and continue" + ) + elif isinstance(action, Interrupt): + logger.debug("tool_name=<%s> | tool call requires human input: %s", tool_name, action.reason) + can_proceed: bool = event.interrupt(name=f"steering_input_{tool_name}", reason={"message": action.reason}) + logger.debug("tool_name=<%s> | received human input for tool call", tool_name) + + if not can_proceed: + event.cancel_tool = f"Manual approval denied: {action.reason}" + logger.debug("tool_name=<%s> | tool call denied by manual approval", tool_name) + else: + logger.debug("tool_name=<%s> | tool call approved manually", tool_name) + else: + raise ValueError(f"Unknown steering action type: {action}") + + @abstractmethod + async def steer(self, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> SteeringAction: + """Provide contextual guidance to help agent navigate complex workflows. + + Args: + agent: The agent instance + tool_use: The tool use object with name and arguments + **kwargs: Additional keyword arguments for guidance evaluation + + Returns: + SteeringAction indicating how to guide the agent's next action + + Note: + Access steering context via self.steering_context + """ + ... 
diff --git a/src/strands/experimental/steering/handlers/__init__.py b/src/strands/experimental/steering/handlers/__init__.py new file mode 100644 index 000000000..ca529530f --- /dev/null +++ b/src/strands/experimental/steering/handlers/__init__.py @@ -0,0 +1,3 @@ +"""Steering handler implementations.""" + +__all__ = [] diff --git a/src/strands/experimental/steering/handlers/llm/__init__.py b/src/strands/experimental/steering/handlers/llm/__init__.py new file mode 100644 index 000000000..4dcccbe80 --- /dev/null +++ b/src/strands/experimental/steering/handlers/llm/__init__.py @@ -0,0 +1,6 @@ +"""LLM steering handler with prompt mapping.""" + +from .llm_handler import LLMSteeringHandler +from .mappers import DefaultPromptMapper, LLMPromptMapper, ToolUse + +__all__ = ["LLMSteeringHandler", "LLMPromptMapper", "DefaultPromptMapper", "ToolUse"] diff --git a/src/strands/experimental/steering/handlers/llm/llm_handler.py b/src/strands/experimental/steering/handlers/llm/llm_handler.py new file mode 100644 index 000000000..b269d4b60 --- /dev/null +++ b/src/strands/experimental/steering/handlers/llm/llm_handler.py @@ -0,0 +1,94 @@ +"""LLM-based steering handler that uses an LLM to provide contextual guidance.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, Literal, cast + +from pydantic import BaseModel, Field + +from .....models import Model +from .....types.tools import ToolUse +from ...context_providers.ledger_provider import LedgerProvider +from ...core.action import Guide, Interrupt, Proceed, SteeringAction +from ...core.context import SteeringContextProvider +from ...core.handler import SteeringHandler +from .mappers import DefaultPromptMapper, LLMPromptMapper + +if TYPE_CHECKING: + from .....agent import Agent + +logger = logging.getLogger(__name__) + + +class _LLMSteering(BaseModel): + """Structured output model for LLM steering decisions.""" + + decision: Literal["proceed", "guide", "interrupt"] = Field( + 
description="Steering decision: 'proceed' to continue, 'guide' to provide feedback, 'interrupt' for human input" + ) + reason: str = Field(description="Clear explanation of the steering decision and any guidance provided") + + +class LLMSteeringHandler(SteeringHandler): + """Steering handler that uses an LLM to provide contextual guidance. + + Uses natural language prompts to evaluate tool calls and provide + contextual steering guidance to help agents navigate complex workflows. + """ + + def __init__( + self, + system_prompt: str, + prompt_mapper: LLMPromptMapper | None = None, + model: Model | None = None, + context_providers: list[SteeringContextProvider] | None = None, + ): + """Initialize the LLMSteeringHandler. + + Args: + system_prompt: System prompt defining steering guidance rules + prompt_mapper: Custom prompt mapper for evaluation prompts + model: Optional model override for steering evaluation + context_providers: List of context providers for populating steering context + """ + providers = context_providers or [LedgerProvider()] + super().__init__(context_providers=providers) + self.system_prompt = system_prompt + self.prompt_mapper = prompt_mapper or DefaultPromptMapper() + self.model = model + + async def steer(self, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> SteeringAction: + """Provide contextual guidance for tool usage. 
+ + Args: + agent: The agent instance + tool_use: The tool use object with name and arguments + **kwargs: Additional keyword arguments for steering evaluation + + Returns: + SteeringAction indicating how to guide the agent's next action + """ + # Generate steering prompt + prompt = self.prompt_mapper.create_steering_prompt(self.steering_context, tool_use=tool_use) + + # Create isolated agent for steering evaluation (no shared conversation state) + from .....agent import Agent + + steering_agent = Agent(system_prompt=self.system_prompt, model=self.model or agent.model, callback_handler=None) + + # Get LLM decision + llm_result: _LLMSteering = cast( + _LLMSteering, steering_agent(prompt, structured_output_model=_LLMSteering).structured_output + ) + + # Convert LLM decision to steering action + if llm_result.decision == "proceed": + return Proceed(reason=llm_result.reason) + elif llm_result.decision == "guide": + return Guide(reason=llm_result.reason) + elif llm_result.decision == "interrupt": + return Interrupt(reason=llm_result.reason) + else: + logger.warning("decision=<%s> | unknown llm decision, defaulting to proceed", llm_result.decision) # type: ignore[unreachable] + return Proceed(reason="Unknown LLM decision, defaulting to proceed") diff --git a/src/strands/experimental/steering/handlers/llm/mappers.py b/src/strands/experimental/steering/handlers/llm/mappers.py new file mode 100644 index 000000000..9901da7d4 --- /dev/null +++ b/src/strands/experimental/steering/handlers/llm/mappers.py @@ -0,0 +1,116 @@ +"""LLM steering prompt mappers for generating evaluation prompts.""" + +import json +from typing import Any, Protocol + +from .....types.tools import ToolUse +from ...core.context import SteeringContext + +# Agent SOP format - see https://github.com/strands-agents/agent-sop +_STEERING_PROMPT_TEMPLATE = """# Steering Evaluation + +## Overview + +You are a STEERING AGENT that evaluates a {action_type} that ANOTHER AGENT is attempting to make. 
+Your job is to provide contextual guidance to help the other agent navigate workflows effectively. +You act as a safety net that can intervene when patterns in the context data suggest the agent +should try a different approach or get human input. + +**YOUR ROLE:** +- Analyze context data for concerning patterns (repeated failures, inappropriate timing, etc.) +- Provide just-in-time guidance when the agent is going down an ineffective path +- Allow normal operations to proceed when context shows no issues + +**CRITICAL CONSTRAINTS:** +- Base decisions ONLY on the context data provided below +- Do NOT use external knowledge about domains, URLs, or tool purposes +- Do NOT make assumptions about what tools "should" or "shouldn't" do +- Focus ONLY on patterns in the context data + +## Context + +{context_str} + +## Event to Evaluate + +{event_description} + +## Steps + +### 1. Analyze the {action_type_title} + +Review ONLY the context data above. Look for patterns in the data that indicate: + +- Previous failures or successes with this tool +- Frequency of attempts +- Any relevant tracking information + +**Constraints:** +- You MUST base analysis ONLY on the provided context data +- You MUST NOT use external knowledge about tool purposes or domains +- You SHOULD identify patterns in the context data +- You MAY reference relevant context data to inform your decision + +### 2. 
Make Steering Decision + +**Constraints:** +- You MUST respond with exactly one of: "proceed", "guide", or "interrupt" +- You MUST base the decision ONLY on context data patterns +- Your reason will be shown to the AGENT as guidance + +**Decision Options:** +- "proceed" if context data shows no concerning patterns +- "guide" if context data shows patterns requiring intervention +- "interrupt" if context data shows patterns requiring human input +""" + + +class LLMPromptMapper(Protocol): + """Protocol for mapping context and events to LLM evaluation prompts.""" + + def create_steering_prompt( + self, steering_context: SteeringContext, tool_use: ToolUse | None = None, **kwargs: Any + ) -> str: + """Create steering prompt for LLM evaluation. + + Args: + steering_context: Steering context with populated data + tool_use: Tool use object for tool call events (None for other events) + **kwargs: Additional event data for other steering events + + Returns: + Formatted prompt string for LLM evaluation + """ + ... + + +class DefaultPromptMapper(LLMPromptMapper): + """Default prompt mapper for steering evaluation.""" + + def create_steering_prompt( + self, steering_context: SteeringContext, tool_use: ToolUse | None = None, **kwargs: Any + ) -> str: + """Create default steering prompt using Agent SOP structure. + + Uses Agent SOP format for structured, constraint-based prompts. 
+ See: https://github.com/strands-agents/agent-sop + """ + context_str = ( + json.dumps(steering_context.data.get(), indent=2) if steering_context.data.get() else "No context available" + ) + + if tool_use: + event_description = ( + f"Tool: {tool_use['name']}\nArguments: {json.dumps(tool_use.get('input', {}), indent=2)}" + ) + action_type = "tool call" + else: + event_description = "General evaluation" + action_type = "action" + + return _STEERING_PROMPT_TEMPLATE.format( + action_type=action_type, + action_type_title=action_type.title(), + context_str=context_str, + event_description=event_description, + ) diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index efe0894ea..ea32bb27b 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -286,7 +286,7 @@ def __init__(self, tool_result: ToolResult) -> None: @property def tool_use_id(self) -> str: """The toolUseId associated with this result.""" - return cast(ToolResult, self.get("tool_result")).get("toolUseId") + return cast(ToolResult, self.get("tool_result"))["toolUseId"] @property def tool_result(self) -> ToolResult: @@ -314,7 +314,7 @@ def __init__(self, tool_use: ToolUse, tool_stream_data: Any) -> None: @property def tool_use_id(self) -> str: """The toolUseId associated with this stream.""" - return cast(ToolUse, cast(dict, self.get("tool_stream_event")).get("tool_use")).get("toolUseId") + return cast(ToolUse, cast(dict, self.get("tool_stream_event")).get("tool_use"))["toolUseId"] class ToolCancelEvent(TypedEvent): @@ -332,7 +332,7 @@ def __init__(self, tool_use: ToolUse, message: str) -> None: @property def tool_use_id(self) -> str: """The id of the tool cancelled.""" - return cast(ToolUse, cast(dict, self.get("tool_cancel_event")).get("tool_use")).get("toolUseId") + return cast(ToolUse, cast(dict, self.get("tool_cancel_event")).get("tool_use"))["toolUseId"] @property def message(self) -> str: @@ -350,7 +350,7 @@ def __init__(self, tool_use: ToolUse, interrupts: 
list[Interrupt]) -> None: @property def tool_use_id(self) -> str: """The id of the tool interrupted.""" - return cast(ToolUse, cast(dict, self.get("tool_interrupt_event")).get("tool_use")).get("toolUseId") + return cast(ToolUse, cast(dict, self.get("tool_interrupt_event")).get("tool_use"))["toolUseId"] @property def interrupts(self) -> list[Interrupt]: diff --git a/src/strands/types/json_dict.py b/src/strands/types/json_dict.py new file mode 100644 index 000000000..a8636ab10 --- /dev/null +++ b/src/strands/types/json_dict.py @@ -0,0 +1,92 @@ +"""JSON serializable dictionary utilities.""" + +import copy +import json +from typing import Any + + +class JSONSerializableDict: + """A key-value store with JSON serialization validation. + + Provides a dict-like interface with automatic validation that all values + are JSON serializable on assignment. + """ + + def __init__(self, initial_state: dict[str, Any] | None = None): + """Initialize JSONSerializableDict.""" + self._data: dict[str, Any] + if initial_state: + self._validate_json_serializable(initial_state) + self._data = copy.deepcopy(initial_state) + else: + self._data = {} + + def set(self, key: str, value: Any) -> None: + """Set a value in the store. + + Args: + key: The key to store the value under + value: The value to store (must be JSON serializable) + + Raises: + ValueError: If key is invalid, or if value is not JSON serializable + """ + self._validate_key(key) + self._validate_json_serializable(value) + self._data[key] = copy.deepcopy(value) + + def get(self, key: str | None = None) -> Any: + """Get a value or entire data. + + Args: + key: The key to retrieve (if None, returns entire data dict) + + Returns: + The stored value, entire data dict, or None if not found + """ + if key is None: + return copy.deepcopy(self._data) + else: + return copy.deepcopy(self._data.get(key)) + + def delete(self, key: str) -> None: + """Delete a specific key from the store. 
+ + Args: + key: The key to delete + """ + self._validate_key(key) + self._data.pop(key, None) + + def _validate_key(self, key: str) -> None: + """Validate that a key is valid. + + Args: + key: The key to validate + + Raises: + ValueError: If key is invalid + """ + if key is None: + raise ValueError("Key cannot be None") + if not isinstance(key, str): + raise ValueError("Key must be a string") + if not key.strip(): + raise ValueError("Key cannot be empty") + + def _validate_json_serializable(self, value: Any) -> None: + """Validate that a value is JSON serializable. + + Args: + value: The value to validate + + Raises: + ValueError: If value is not JSON serializable + """ + try: + json.dumps(value) + except (TypeError, ValueError) as e: + raise ValueError( + f"Value is not JSON serializable: {type(value).__name__}. " + f"Only JSON-compatible types (str, int, float, bool, list, dict, None) are allowed." + ) from e diff --git a/tests/strands/experimental/steering/__init__.py b/tests/strands/experimental/steering/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/steering/context_providers/__init__.py b/tests/strands/experimental/steering/context_providers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/steering/context_providers/test_ledger_provider.py b/tests/strands/experimental/steering/context_providers/test_ledger_provider.py new file mode 100644 index 000000000..4356b3ea8 --- /dev/null +++ b/tests/strands/experimental/steering/context_providers/test_ledger_provider.py @@ -0,0 +1,135 @@ +"""Unit tests for ledger context providers.""" + +from unittest.mock import Mock, patch + +from strands.experimental.steering.context_providers.ledger_provider import ( + LedgerAfterToolCall, + LedgerBeforeToolCall, + LedgerProvider, +) +from strands.experimental.steering.core.context import SteeringContext +from strands.hooks.events import AfterToolCallEvent, BeforeToolCallEvent 
+ + +def test_context_providers_method(): + """Test context_providers method returns correct callbacks.""" + provider = LedgerProvider() + + callbacks = provider.context_providers() + + assert len(callbacks) == 2 + assert isinstance(callbacks[0], LedgerBeforeToolCall) + assert isinstance(callbacks[1], LedgerAfterToolCall) + + +@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +def test_ledger_before_tool_call_new_ledger(mock_datetime): + """Test LedgerBeforeToolCall with new ledger.""" + mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:00:00" + + callback = LedgerBeforeToolCall() + steering_context = SteeringContext() + + tool_use = {"name": "test_tool", "arguments": {"param": "value"}} + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = tool_use + + callback(event, steering_context) + + ledger = steering_context.data.get("ledger") + assert ledger is not None + assert "session_start" in ledger + assert "tool_calls" in ledger + assert len(ledger["tool_calls"]) == 1 + + tool_call = ledger["tool_calls"][0] + assert tool_call["tool_name"] == "test_tool" + assert tool_call["tool_args"] == {"param": "value"} + assert tool_call["status"] == "pending" + + +@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +def test_ledger_before_tool_call_existing_ledger(mock_datetime): + """Test LedgerBeforeToolCall with existing ledger.""" + mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:00:00" + + callback = LedgerBeforeToolCall() + steering_context = SteeringContext() + + # Set up existing ledger + existing_ledger = { + "session_start": "2024-01-01T10:00:00", + "tool_calls": [{"name": "previous_tool"}], + "conversation_history": [], + "session_metadata": {}, + } + steering_context.data.set("ledger", existing_ledger) + + tool_use = {"name": "new_tool", "arguments": {"param": "value"}} + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = tool_use + + 
callback(event, steering_context) + + ledger = steering_context.data.get("ledger") + assert len(ledger["tool_calls"]) == 2 + assert ledger["tool_calls"][0]["name"] == "previous_tool" + assert ledger["tool_calls"][1]["tool_name"] == "new_tool" + + +@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +def test_ledger_after_tool_call_success(mock_datetime): + """Test LedgerAfterToolCall with successful completion.""" + mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:05:00" + + callback = LedgerAfterToolCall() + steering_context = SteeringContext() + + # Set up existing ledger with pending call + existing_ledger = { + "tool_calls": [{"tool_name": "test_tool", "status": "pending", "timestamp": "2024-01-01T12:00:00"}] + } + steering_context.data.set("ledger", existing_ledger) + + event = Mock(spec=AfterToolCallEvent) + event.result = {"status": "success", "content": ["success_result"]} + event.exception = None + + callback(event, steering_context) + + ledger = steering_context.data.get("ledger") + tool_call = ledger["tool_calls"][0] + assert tool_call["status"] == "success" + assert tool_call["result"] == ["success_result"] + assert tool_call["error"] is None + assert tool_call["completion_timestamp"] == "2024-01-01T12:05:00" + + +def test_ledger_after_tool_call_no_calls(): + """Test LedgerAfterToolCall when no tool calls exist.""" + callback = LedgerAfterToolCall() + steering_context = SteeringContext() + + # Set up ledger with no tool calls + existing_ledger = {"tool_calls": []} + steering_context.data.set("ledger", existing_ledger) + + event = Mock(spec=AfterToolCallEvent) + event.result = {"status": "success", "content": ["test"]} + event.exception = None + + # Should not crash when no tool calls exist + callback(event, steering_context) + + ledger = steering_context.data.get("ledger") + assert ledger["tool_calls"] == [] + + +def test_session_start_persistence(): + """Test that session_start is set during 
initialization and persists.""" + with patch("strands.experimental.steering.context_providers.ledger_provider.datetime") as mock_datetime: + mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T10:00:00" + + callback = LedgerBeforeToolCall() + + assert callback.session_start == "2024-01-01T10:00:00" diff --git a/tests/strands/experimental/steering/core/__init__.py b/tests/strands/experimental/steering/core/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/steering/core/test_handler.py b/tests/strands/experimental/steering/core/test_handler.py new file mode 100644 index 000000000..8d5ef6884 --- /dev/null +++ b/tests/strands/experimental/steering/core/test_handler.py @@ -0,0 +1,278 @@ +"""Unit tests for steering handler base class.""" + +from unittest.mock import Mock + +import pytest + +from strands.experimental.steering.core.action import Guide, Interrupt, Proceed +from strands.experimental.steering.core.context import SteeringContext, SteeringContextCallback, SteeringContextProvider +from strands.experimental.steering.core.handler import SteeringHandler +from strands.hooks.events import BeforeToolCallEvent +from strands.hooks.registry import HookRegistry + + +class TestSteeringHandler(SteeringHandler): + """Test implementation of SteeringHandler.""" + + async def steer(self, agent, tool_use, **kwargs): + return Proceed(reason="Test proceed") + + +def test_steering_handler_initialization(): + """Test SteeringHandler initialization.""" + handler = TestSteeringHandler() + assert handler is not None + + +def test_register_hooks(): + """Test hook registration.""" + handler = TestSteeringHandler() + registry = Mock(spec=HookRegistry) + + handler.register_hooks(registry) + + # Verify hooks were registered + assert registry.add_callback.call_count >= 1 + registry.add_callback.assert_any_call(BeforeToolCallEvent, handler._provide_steering_guidance) + + +def test_steering_context_initialization(): + """Test 
steering context is initialized.""" + handler = TestSteeringHandler() + + assert handler.steering_context is not None + assert isinstance(handler.steering_context, SteeringContext) + + +def test_steering_context_persistence(): + """Test steering context persists across calls.""" + handler = TestSteeringHandler() + + handler.steering_context.data.set("test", "value") + assert handler.steering_context.data.get("test") == "value" + + +def test_steering_context_access(): + """Test steering context can be accessed and modified.""" + handler = TestSteeringHandler() + + handler.steering_context.data.set("key", "value") + assert handler.steering_context.data.get("key") == "value" + + +@pytest.mark.asyncio +async def test_proceed_action_flow(): + """Test complete flow with Proceed action.""" + + class ProceedHandler(SteeringHandler): + async def steer(self, agent, tool_use, **kwargs): + return Proceed(reason="Test proceed") + + handler = ProceedHandler() + agent = Mock() + tool_use = {"name": "test_tool"} + event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) + + await handler._provide_steering_guidance(event) + + # Should not modify event for Proceed + assert not event.cancel_tool + + +@pytest.mark.asyncio +async def test_guide_action_flow(): + """Test complete flow with Guide action.""" + + class GuideHandler(SteeringHandler): + async def steer(self, agent, tool_use, **kwargs): + return Guide(reason="Test guidance") + + handler = GuideHandler() + agent = Mock() + tool_use = {"name": "test_tool"} + event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) + + await handler._provide_steering_guidance(event) + + # Should set cancel_tool with guidance message + expected_message = "Tool call cancelled given new guidance. Test guidance. 
Consider this approach and continue" + assert event.cancel_tool == expected_message + + +@pytest.mark.asyncio +async def test_interrupt_action_approved_flow(): + """Test complete flow with Interrupt action when approved.""" + + class InterruptHandler(SteeringHandler): + async def steer(self, agent, tool_use, **kwargs): + return Interrupt(reason="Need approval") + + handler = InterruptHandler() + tool_use = {"name": "test_tool"} + event = Mock() + event.tool_use = tool_use + event.interrupt = Mock(return_value=True) # Approved + + await handler._provide_steering_guidance(event) + + event.interrupt.assert_called_once() + + +@pytest.mark.asyncio +async def test_interrupt_action_denied_flow(): + """Test complete flow with Interrupt action when denied.""" + + class InterruptHandler(SteeringHandler): + async def steer(self, agent, tool_use, **kwargs): + return Interrupt(reason="Need approval") + + handler = InterruptHandler() + tool_use = {"name": "test_tool"} + event = Mock() + event.tool_use = tool_use + event.interrupt = Mock(return_value=False) # Denied + + await handler._provide_steering_guidance(event) + + event.interrupt.assert_called_once() + assert event.cancel_tool.startswith("Manual approval denied:") + + +@pytest.mark.asyncio +async def test_unknown_action_flow(): + """Test complete flow with unknown action type raises error.""" + + class UnknownActionHandler(SteeringHandler): + async def steer(self, agent, tool_use, **kwargs): + return Mock() # Not a valid SteeringAction + + handler = UnknownActionHandler() + agent = Mock() + tool_use = {"name": "test_tool"} + event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) + + with pytest.raises(ValueError, match="Unknown steering action type"): + await handler._provide_steering_guidance(event) + + +def test_register_steering_hooks_override(): + """Test that _register_steering_hooks can be overridden.""" + + class CustomHandler(SteeringHandler): + async def steer(self, 
agent, tool_use, **kwargs): + return Proceed(reason="Custom") + + def register_hooks(self, registry, **kwargs): + # Custom hook registration - don't call parent + pass + + handler = CustomHandler() + registry = Mock(spec=HookRegistry) + + handler.register_hooks(registry) + + # Should not register any hooks + assert registry.add_callback.call_count == 0 + + +# Integration tests with context providers +class MockContextCallback(SteeringContextCallback[BeforeToolCallEvent]): + """Mock context callback for testing.""" + + def __call__(self, event: BeforeToolCallEvent, steering_context, **kwargs) -> None: + steering_context.data.set("test_key", "test_value") + + +class MockContextProvider(SteeringContextProvider): + """Mock context provider for testing.""" + + def __init__(self, callbacks): + self.callbacks = callbacks + + def context_providers(self): + return self.callbacks + + +class TestSteeringHandlerWithProvider(SteeringHandler): + """Test implementation with context callbacks.""" + + def __init__(self, context_callbacks=None): + providers = [MockContextProvider(context_callbacks)] if context_callbacks else None + super().__init__(context_providers=providers) + + async def steer(self, agent, tool_use, **kwargs): + return Proceed(reason="Test proceed") + + +def test_handler_registers_context_provider_hooks(): + """Test that handler registers hooks from context callbacks.""" + mock_callback = MockContextCallback() + handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback]) + registry = Mock(spec=HookRegistry) + + handler.register_hooks(registry) + + # Should register hooks for context callback and steering guidance + assert registry.add_callback.call_count >= 2 + + # Check that BeforeToolCallEvent was registered + call_args = [call[0] for call in registry.add_callback.call_args_list] + event_types = [args[0] for args in call_args] + + assert BeforeToolCallEvent in event_types + + +def test_context_callbacks_receive_steering_context(): + """Test 
that context callbacks receive the handler's steering context.""" + mock_callback = MockContextCallback() + handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback]) + registry = Mock(spec=HookRegistry) + + handler.register_hooks(registry) + + # Get the registered callback for BeforeToolCallEvent + before_callback = None + for call in registry.add_callback.call_args_list: + if call[0][0] == BeforeToolCallEvent: + before_callback = call[0][1] + break + + assert before_callback is not None + + # Create a mock event and call the callback + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = {"name": "test_tool", "arguments": {}} + + # The callback should execute without error and update the steering context + before_callback(event) + + # Verify the steering context was updated + assert handler.steering_context.data.get("test_key") == "test_value" + + +def test_multiple_context_callbacks_registered(): + """Test that multiple context callbacks are registered.""" + callback1 = MockContextCallback() + callback2 = MockContextCallback() + + handler = TestSteeringHandlerWithProvider(context_callbacks=[callback1, callback2]) + registry = Mock(spec=HookRegistry) + + handler.register_hooks(registry) + + # Should register one callback for each context provider plus steering guidance + expected_calls = 2 + 1 # 2 callbacks + 1 for steering guidance + assert registry.add_callback.call_count >= expected_calls + + +def test_handler_initialization_with_callbacks(): + """Test handler initialization stores context callbacks.""" + callback1 = MockContextCallback() + callback2 = MockContextCallback() + + handler = TestSteeringHandlerWithProvider(context_callbacks=[callback1, callback2]) + + # Should have stored the callbacks + assert len(handler._context_callbacks) == 2 + assert callback1 in handler._context_callbacks + assert callback2 in handler._context_callbacks diff --git a/tests/strands/experimental/steering/handlers/__init__.py 
b/tests/strands/experimental/steering/handlers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/steering/handlers/llm/__init__.py b/tests/strands/experimental/steering/handlers/llm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/steering/handlers/llm/test_llm_handler.py b/tests/strands/experimental/steering/handlers/llm/test_llm_handler.py new file mode 100644 index 000000000..f780088b5 --- /dev/null +++ b/tests/strands/experimental/steering/handlers/llm/test_llm_handler.py @@ -0,0 +1,200 @@ +"""Unit tests for LLM steering handler.""" + +from unittest.mock import Mock, patch + +import pytest + +from strands.experimental.steering.core.action import Guide, Interrupt, Proceed +from strands.experimental.steering.handlers.llm.llm_handler import LLMSteeringHandler, _LLMSteering +from strands.experimental.steering.handlers.llm.mappers import DefaultPromptMapper + + +def test_llm_steering_handler_initialization(): + """Test LLMSteeringHandler initialization.""" + system_prompt = "You are a security evaluator" + handler = LLMSteeringHandler(system_prompt) + + assert handler.system_prompt == system_prompt + assert isinstance(handler.prompt_mapper, DefaultPromptMapper) + assert handler.model is None + + +def test_llm_steering_handler_with_custom_mapper(): + """Test LLMSteeringHandler with custom prompt mapper.""" + system_prompt = "Test prompt" + custom_mapper = Mock() + handler = LLMSteeringHandler(system_prompt, prompt_mapper=custom_mapper) + + assert handler.prompt_mapper == custom_mapper + + +def test_llm_steering_handler_with_custom_context_providers(): + """Test LLMSteeringHandler with custom context providers.""" + system_prompt = "Test prompt" + custom_provider = Mock() + custom_provider.context_providers.return_value = [Mock(), Mock()] + + handler = LLMSteeringHandler(system_prompt, context_providers=[custom_provider]) + + # Verify the provider's context_providers 
method was called + custom_provider.context_providers.assert_called_once() + # Verify the callbacks were stored + assert len(handler._context_callbacks) == 2 + + +@pytest.mark.asyncio +@patch("strands.agent.Agent") +async def test_steer_proceed_decision(mock_agent_class): + """Test steer method with proceed decision.""" + system_prompt = "Test prompt" + handler = LLMSteeringHandler(system_prompt) + + mock_steering_agent = Mock() + mock_agent_class.return_value = mock_steering_agent + + mock_result = Mock() + mock_result.structured_output = _LLMSteering(decision="proceed", reason="Tool call is safe") + mock_steering_agent.return_value = mock_result + + agent = Mock() + tool_use = {"name": "test_tool", "input": {"param": "value"}} + + result = await handler.steer(agent, tool_use) + + assert isinstance(result, Proceed) + assert result.reason == "Tool call is safe" + + +@pytest.mark.asyncio +@patch("strands.agent.Agent") +async def test_steer_guide_decision(mock_agent_class): + """Test steer method with guide decision.""" + system_prompt = "Test prompt" + handler = LLMSteeringHandler(system_prompt) + + mock_steering_agent = Mock() + mock_agent_class.return_value = mock_steering_agent + + mock_result = Mock() + mock_result.structured_output = _LLMSteering(decision="guide", reason="Consider security implications") + mock_steering_agent.return_value = mock_result + + agent = Mock() + tool_use = {"name": "test_tool", "input": {"param": "value"}} + + result = await handler.steer(agent, tool_use) + + assert isinstance(result, Guide) + assert result.reason == "Consider security implications" + + +@pytest.mark.asyncio +@patch("strands.agent.Agent") +async def test_steer_interrupt_decision(mock_agent_class): + """Test steer method with interrupt decision.""" + system_prompt = "Test prompt" + handler = LLMSteeringHandler(system_prompt) + + mock_steering_agent = Mock() + mock_agent_class.return_value = mock_steering_agent + + mock_result = Mock() + mock_result.structured_output = 
_LLMSteering(decision="interrupt", reason="Human approval required") + mock_steering_agent.return_value = mock_result + + agent = Mock() + tool_use = {"name": "test_tool", "input": {"param": "value"}} + + result = await handler.steer(agent, tool_use) + + assert isinstance(result, Interrupt) + assert result.reason == "Human approval required" + + +@pytest.mark.asyncio +@patch("strands.agent.Agent") +async def test_steer_unknown_decision(mock_agent_class): + """Test steer method with unknown decision defaults to proceed.""" + system_prompt = "Test prompt" + handler = LLMSteeringHandler(system_prompt) + + mock_steering_agent = Mock() + mock_agent_class.return_value = mock_steering_agent + + # Mock _LLMSteering with unknown decision (bypass validation) + mock_steering_decision = Mock() + mock_steering_decision.decision = "unknown" + mock_steering_decision.reason = "Invalid decision" + + mock_result = Mock() + mock_result.structured_output = mock_steering_decision + mock_steering_agent.return_value = mock_result + + agent = Mock() + tool_use = {"name": "test_tool", "input": {"param": "value"}} + + result = await handler.steer(agent, tool_use) + + assert isinstance(result, Proceed) + assert "Unknown LLM decision, defaulting to proceed" in result.reason + + +@pytest.mark.asyncio +@patch("strands.agent.Agent") +async def test_steer_uses_custom_model(mock_agent_class): + """Test steer method uses custom model when provided.""" + system_prompt = "Test prompt" + custom_model = Mock() + handler = LLMSteeringHandler(system_prompt, model=custom_model) + + mock_steering_agent = Mock() + mock_agent_class.return_value = mock_steering_agent + + mock_result = Mock() + mock_result.structured_output = _LLMSteering(decision="proceed", reason="OK") + mock_steering_agent.return_value = mock_result + + agent = Mock() + agent.model = Mock() + tool_use = {"name": "test_tool", "input": {"param": "value"}} + + await handler.steer(agent, tool_use) + + 
mock_agent_class.assert_called_once_with(system_prompt=system_prompt, model=custom_model, callback_handler=None) + + +@pytest.mark.asyncio +@patch("strands.agent.Agent") +async def test_steer_uses_agent_model_when_no_custom_model(mock_agent_class): + """Test steer method uses agent's model when no custom model provided.""" + system_prompt = "Test prompt" + handler = LLMSteeringHandler(system_prompt) + + mock_steering_agent = Mock() + mock_agent_class.return_value = mock_steering_agent + + mock_result = Mock() + mock_result.structured_output = _LLMSteering(decision="proceed", reason="OK") + mock_steering_agent.return_value = mock_result + + agent = Mock() + agent.model = Mock() + tool_use = {"name": "test_tool", "input": {"param": "value"}} + + await handler.steer(agent, tool_use) + + mock_agent_class.assert_called_once_with(system_prompt=system_prompt, model=agent.model, callback_handler=None) + + +def test_llm_steering_model(): + """Test _LLMSteering pydantic model.""" + steering = _LLMSteering(decision="proceed", reason="Test reason") + + assert steering.decision == "proceed" + assert steering.reason == "Test reason" + + +def test_llm_steering_invalid_decision(): + """Test _LLMSteering with invalid decision raises validation error.""" + with pytest.raises(ValueError): + _LLMSteering(decision="invalid", reason="Test reason") diff --git a/tests/strands/experimental/steering/handlers/llm/test_mappers.py b/tests/strands/experimental/steering/handlers/llm/test_mappers.py new file mode 100644 index 000000000..511671d3a --- /dev/null +++ b/tests/strands/experimental/steering/handlers/llm/test_mappers.py @@ -0,0 +1,131 @@ +"""Unit tests for LLM steering prompt mappers.""" + +from strands.experimental.steering.core.context import SteeringContext +from strands.experimental.steering.handlers.llm.mappers import _STEERING_PROMPT_TEMPLATE, DefaultPromptMapper + + +def test_create_steering_prompt_with_tool_use(): + """Test prompt creation with tool use.""" + mapper = 
DefaultPromptMapper() + steering_context = SteeringContext() + steering_context.data.set("user_id", "123") + steering_context.data.set("session", "abc") + tool_use = {"name": "get_weather", "input": {"location": "Seattle"}} + + result = mapper.create_steering_prompt(steering_context, tool_use=tool_use) + + assert "# Steering Evaluation" in result + assert "Tool: get_weather" in result + assert '"location": "Seattle"' in result + assert "tool call" in result + assert "Tool Call" in result # title case + assert '"user_id": "123"' in result + assert '"session": "abc"' in result + + +def test_create_steering_prompt_with_empty_context(): + """Test prompt creation with empty context.""" + mapper = DefaultPromptMapper() + steering_context = SteeringContext() + tool_use = {"name": "test_tool", "input": {}} + + result = mapper.create_steering_prompt(steering_context, tool_use=tool_use) + + assert "No context available" in result + assert "Tool: test_tool" in result + + +def test_create_steering_prompt_general_evaluation(): + """Test prompt creation with no tool_use or kwargs.""" + mapper = DefaultPromptMapper() + steering_context = SteeringContext() + steering_context.data.set("data", "test") + + result = mapper.create_steering_prompt(steering_context) + + assert "# Steering Evaluation" in result + assert "General evaluation" in result + assert "action" in result + assert '"data": "test"' in result + + +def test_prompt_contains_agent_sop_structure(): + """Test that prompt follows Agent SOP structure.""" + mapper = DefaultPromptMapper() + steering_context = SteeringContext() + steering_context.data.set("test", "data") + + result = mapper.create_steering_prompt(steering_context) + + # Check for Agent SOP sections + assert "## Overview" in result + assert "## Context" in result + assert "## Event to Evaluate" in result + assert "## Steps" in result + assert "### 1. Analyze the Action" in result + assert "### 2. 
Make Steering Decision" in result + + # Check for constraints + assert "**Constraints:**" in result + assert "You MUST" in result + assert "You SHOULD" in result + assert "You MAY" in result + + # Check for decision options + assert '"proceed"' in result + assert '"guide"' in result + assert '"interrupt"' in result + + +def test_tool_use_input_field_handling(): + """Test that tool_use uses 'input' field correctly.""" + mapper = DefaultPromptMapper() + steering_context = SteeringContext() + tool_use = {"name": "calculator", "input": {"operation": "add", "a": 1, "b": 2}} + + result = mapper.create_steering_prompt(steering_context, tool_use=tool_use) + + assert "Tool: calculator" in result + assert '"operation": "add"' in result + assert '"a": 1' in result + assert '"b": 2' in result + + +def test_context_json_formatting(): + """Test that context is properly JSON formatted.""" + mapper = DefaultPromptMapper() + steering_context = SteeringContext() + steering_context.data.set("nested", {"key": "value"}) + steering_context.data.set("list", [1, 2, 3]) + steering_context.data.set("string", "test") + + result = mapper.create_steering_prompt(steering_context) + + # Check that JSON is properly indented + assert '{\n "nested": {\n "key": "value"\n }' in result + assert '"list": [\n 1,\n 2,\n 3\n ]' in result + + +def test_template_constant_usage(): + """Test that the STEERING_PROMPT_TEMPLATE constant is used correctly.""" + mapper = DefaultPromptMapper() + steering_context = SteeringContext() + steering_context.data.set("test", "value") + + result = mapper.create_steering_prompt(steering_context) + + # Verify the template structure is present + expected_sections = [ + "# Steering Evaluation", + "## Overview", + "## Context", + "## Event to Evaluate", + "## Steps", + "### 1. Analyze the Action", + "### 2. Make Steering Decision", + ] + + for section in expected_sections: + assert section in result + # Verify template has placeholder structure + assert "### 1. 
Analyze the {action_type_title}" in _STEERING_PROMPT_TEMPLATE diff --git a/tests/strands/types/test_json_dict.py b/tests/strands/types/test_json_dict.py new file mode 100644 index 000000000..caa010bac --- /dev/null +++ b/tests/strands/types/test_json_dict.py @@ -0,0 +1,111 @@ +"""Tests for JSONSerializableDict class.""" + +import pytest + +from strands.types.json_dict import JSONSerializableDict + + +def test_set_and_get(): + """Test basic set and get operations.""" + state = JSONSerializableDict() + state.set("key", "value") + assert state.get("key") == "value" + + +def test_get_nonexistent_key(): + """Test getting nonexistent key returns None.""" + state = JSONSerializableDict() + assert state.get("nonexistent") is None + + +def test_get_entire_state(): + """Test getting entire state when no key specified.""" + state = JSONSerializableDict() + state.set("key1", "value1") + state.set("key2", "value2") + + result = state.get() + assert result == {"key1": "value1", "key2": "value2"} + + +def test_initialize_and_get_entire_state(): + """Test initializing with an initial state and getting the entire state.""" + state = JSONSerializableDict({"key1": "value1", "key2": "value2"}) + + result = state.get() + assert result == {"key1": "value1", "key2": "value2"} + + +def test_initialize_with_error(): + with pytest.raises(ValueError, match="not JSON serializable"): + JSONSerializableDict({"object": object()}) + + +def test_delete(): + """Test deleting keys.""" + state = JSONSerializableDict() + state.set("key1", "value1") + state.set("key2", "value2") + + state.delete("key1") + + assert state.get("key1") is None + assert state.get("key2") == "value2" + + +def test_delete_nonexistent_key(): + """Test deleting nonexistent key doesn't raise error.""" + state = JSONSerializableDict() + state.delete("nonexistent")  # Should not raise + + +def test_json_serializable_values(): + """Test that only JSON-serializable values are accepted.""" + state = JSONSerializableDict() + + # Valid JSON types + 
state.set("string", "test") + state.set("int", 42) + state.set("bool", True) + state.set("list", [1, 2, 3]) + state.set("dict", {"nested": "value"}) + state.set("null", None) + + # Invalid JSON types should raise ValueError + with pytest.raises(ValueError, match="not JSON serializable"): + state.set("function", lambda x: x) + + with pytest.raises(ValueError, match="not JSON serializable"): + state.set("object", object()) + + +def test_key_validation(): + """Test key validation for set and delete operations.""" + state = JSONSerializableDict() + + # Invalid keys for set + with pytest.raises(ValueError, match="Key cannot be None"): + state.set(None, "value") + + with pytest.raises(ValueError, match="Key cannot be empty"): + state.set("", "value") + + with pytest.raises(ValueError, match="Key must be a string"): + state.set(123, "value") + + # Invalid keys for delete + with pytest.raises(ValueError, match="Key cannot be None"): + state.delete(None) + + with pytest.raises(ValueError, match="Key cannot be empty"): + state.delete("") + + +def test_initial_state(): + """Test initialization with initial state.""" + initial = {"key1": "value1", "key2": "value2"} + state = JSONSerializableDict(initial_state=initial) + + assert state.get("key1") == "value1" + assert state.get("key2") == "value2" + assert state.get() == initial diff --git a/tests_integ/steering/__init__.py b/tests_integ/steering/__init__.py new file mode 100644 index 000000000..394ba3428 --- /dev/null +++ b/tests_integ/steering/__init__.py @@ -0,0 +1 @@ +"""Integration tests for steering system.""" diff --git a/tests_integ/steering/test_llm_handler.py b/tests_integ/steering/test_llm_handler.py new file mode 100644 index 000000000..e0cf122d8 --- /dev/null +++ b/tests_integ/steering/test_llm_handler.py @@ -0,0 +1,93 @@ +"""Integration tests for LLM steering handler.""" + +import pytest + +from strands import Agent, tool +from strands.experimental.steering.core.action import Guide, Interrupt, Proceed +from
strands.experimental.steering.handlers.llm.llm_handler import LLMSteeringHandler + + +@tool +def send_email(recipient: str, message: str) -> str: + """Send an email to a recipient.""" + return f"Email sent to {recipient}: {message}" + + +@tool +def send_notification(recipient: str, message: str) -> str: + """Send a notification to a recipient.""" + return f"Notification sent to {recipient}: {message}" + + +@pytest.mark.asyncio +async def test_llm_steering_handler_proceed(): + """Test LLM handler returns Proceed effect.""" + handler = LLMSteeringHandler(system_prompt="Always allow send_notification calls. Return proceed decision.") + + agent = Agent(tools=[send_notification]) + tool_use = {"name": "send_notification", "input": {"recipient": "user", "message": "hello"}} + + effect = await handler.steer(agent, tool_use) + + assert isinstance(effect, Proceed) + + +@pytest.mark.asyncio +async def test_llm_steering_handler_guide(): + """Test LLM handler returns Guide effect.""" + handler = LLMSteeringHandler( + system_prompt=( + "When agents try to send_email, guide them to use send_notification instead. Return GUIDE decision." + ) + ) + + agent = Agent(tools=[send_email, send_notification]) + tool_use = {"name": "send_email", "input": {"recipient": "user", "message": "hello"}} + + effect = await handler.steer(agent, tool_use) + + assert isinstance(effect, Guide) + + +@pytest.mark.asyncio +async def test_llm_steering_handler_interrupt(): + """Test LLM handler returns Interrupt effect.""" + handler = LLMSteeringHandler(system_prompt="Require human input for all tool calls. 
Return interrupt decision.") + + agent = Agent(tools=[send_email]) + tool_use = {"name": "send_email", "input": {"recipient": "user", "message": "hello"}} + + effect = await handler.steer(agent, tool_use) + + assert isinstance(effect, Interrupt) + + +def test_agent_with_steering_e2e(): + """End-to-end test of agent with steering handler guiding tool choice.""" + handler = LLMSteeringHandler( + system_prompt=( + "When agents try to use send_email, guide them to use send_notification instead for better delivery." + ) + ) + + agent = Agent(tools=[send_email, send_notification], hooks=[handler]) + + # This should trigger steering guidance to use send_notification instead + response = agent("Send an email to john@example.com saying hello") + + # Verify tool call metrics show the expected sequence: + # 1. send_email was attempted but cancelled (should have 0 success_count) + # 2. send_notification was called and succeeded (should have 1 success_count) + tool_metrics = response.metrics.tool_metrics + + # send_email should have been attempted but cancelled (no successful calls) + if "send_email" in tool_metrics: + email_metrics = tool_metrics["send_email"] + assert email_metrics.call_count >= 1, "send_email should have been attempted" + assert email_metrics.success_count == 0, "send_email should have been cancelled by steering" + + # send_notification should have been called and succeeded + assert "send_notification" in tool_metrics, "send_notification should have been called" + notification_metrics = tool_metrics["send_notification"] + assert notification_metrics.call_count >= 1, "send_notification should have been called" + assert notification_metrics.success_count >= 1, "send_notification should have succeeded"