diff --git a/src/strands/models/__init__.py b/src/strands/models/__init__.py index 3a23e257a..8ae660da0 100644 --- a/src/strands/models/__init__.py +++ b/src/strands/models/__init__.py @@ -7,7 +7,7 @@ from . import bedrock, model from .bedrock import BedrockModel -from .model import BaseModelConfig, CacheConfig, Model +from .model import BaseModelConfig, CacheConfig, CacheToolsConfig, Model __all__ = [ "bedrock", @@ -15,6 +15,7 @@ "BaseModelConfig", "BedrockModel", "CacheConfig", + "CacheToolsConfig", "Model", ] diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index ab9adb67a..4cd6f7fbc 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -34,7 +34,7 @@ from ._defaults import resolve_config_metadata from ._strict_schema import ensure_strict_json_schema from ._validation import validate_config_keys -from .model import BaseModelConfig, CacheConfig, Model +from .model import BaseModelConfig, CacheConfig, CacheToolsConfig, Model logger = logging.getLogger(__name__) @@ -90,7 +90,8 @@ class BedrockConfig(BaseModelConfig, total=False): additional_response_field_paths: Additional response field paths to extract cache_prompt: Cache point type for the system prompt (deprecated, use cache_config) cache_config: Configuration for prompt caching. Use CacheConfig(strategy="auto") for automatic caching. - cache_tools: Cache point type for tools + cache_tools: Cache point type for tools. Pass a string (e.g. "default") for the default 5m TTL, + or a CacheToolsConfig instance to set both type and TTL (e.g. "1h"). guardrail_id: ID of the guardrail to apply guardrail_trace: Guardrail trace mode. Defaults to enabled. guardrail_version: Version of the guardrail to apply @@ -127,7 +128,7 @@ class BedrockConfig(BaseModelConfig, total=False): additional_response_field_paths: list[str] | None cache_prompt: str | None cache_config: CacheConfig | None - cache_tools: str | None + cache_tools: str | CacheToolsConfig | None guardrail_id: str | None guardrail_trace: Literal["enabled", "disabled", "enabled_full"] | None guardrail_stream_processing_mode: Literal["sync", "async"] | None @@ -292,11 +293,7 @@ def _format_request( } for tool_spec in tool_specs ], - *( - [{"cachePoint": {"type": self.config["cache_tools"]}}] - if self.config.get("cache_tools") - else [] - ), + *self._build_tools_cache_point(), ], **({"toolChoice": tool_choice if tool_choice else {"auto": {}}}), } @@ -371,6 +368,25 @@ def _get_additional_request_fields(self, tool_choice: ToolChoice | None) -> dict return {"additionalModelRequestFields": additional_fields} + def _build_tools_cache_point(self) -> list[dict[str, Any]]: + """Build the cache point block appended to ``toolConfig.tools`` if ``cache_tools`` is configured. + + Returns: + A single-element list containing the cache point block, or an empty list if no cache_tools is set. + """ + cache_tools = self.config.get("cache_tools") + if not cache_tools: + return [] + + if isinstance(cache_tools, CacheToolsConfig): + cache_point: dict[str, Any] = {"type": cache_tools.type} + if cache_tools.ttl: + cache_point["ttl"] = cache_tools.ttl + else: + cache_point = {"type": cache_tools} + + return [{"cachePoint": cache_point}] + def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None: """Inject a cache point at the end of the last user message. @@ -395,7 +411,11 @@ def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None: last_user_idx = msg_idx if last_user_idx is not None and messages[last_user_idx].get("content"): - messages[last_user_idx]["content"].append({"cachePoint": {"type": "default"}}) + cache_point: dict[str, Any] = {"type": "default"} + cache_config = self.config.get("cache_config") + if cache_config and cache_config.ttl: + cache_point["ttl"] = cache_config.ttl + messages[last_user_idx]["content"].append({"cachePoint": cache_point}) logger.debug("msg_idx=<%s> | added cache point to last user message", last_user_idx) def _find_last_user_text_message_index(self, messages: Messages) -> int | None: diff --git a/src/strands/models/model.py b/src/strands/models/model.py index dd2f9eed2..77ef1df40 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -134,9 +134,25 @@ class CacheConfig: strategy: Caching strategy to use. - "auto": Automatically detect model support and inject cachePoint to maximize cache coverage - "anthropic": Inject cachePoint in Anthropic-compatible format without model support check + ttl: Optional TTL duration for cache entries (e.g. "5m", "1h"). + When specified, auto-injected cache points will include this TTL value. """ strategy: Literal["auto", "anthropic"] = "auto" + ttl: str | None = None + + +@dataclass +class CacheToolsConfig: + """Configuration for the toolConfig cache point. + + Attributes: + type: Cache point type (e.g. "default"). + ttl: Optional TTL duration for the cache entry (e.g. "5m", "1h"). + """ + + type: str = "default" + ttl: str | None = None class Model(abc.ABC): diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 2e105d64a..319b5574f 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -14,7 +14,7 @@ import strands from strands import _exception_notes -from strands.models import BedrockModel, CacheConfig +from strands.models import BedrockModel, CacheConfig, CacheToolsConfig from strands.models.bedrock import ( DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION, @@ -3554,3 +3554,69 @@ async def test_skip_native_api_by_default(self, bedrock_client, model_id, messag bedrock_client.count_tokens.assert_not_called() assert isinstance(result, int) assert result >= 0 + + +def test_inject_cache_point_with_ttl(bedrock_client): + """Test that _inject_cache_point includes TTL when cache_config has ttl set.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", + cache_config=CacheConfig(strategy="auto", ttl="5m"), + ) + + cleaned_messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + model._inject_cache_point(cleaned_messages) + + cache_point = cleaned_messages[0]["content"][-1]["cachePoint"] + assert cache_point["type"] == "default" + assert cache_point["ttl"] == "5m" + + +def test_inject_cache_point_without_ttl(bedrock_client): + """Test that _inject_cache_point omits TTL when cache_config has no ttl.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", + cache_config=CacheConfig(strategy="auto"), + ) + + cleaned_messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + model._inject_cache_point(cleaned_messages) + + cache_point = cleaned_messages[0]["content"][-1]["cachePoint"] + assert cache_point["type"] == "default" + assert "ttl" not in cache_point + + +def test_format_request_cache_tools_config_with_ttl(model, messages, model_id, tool_spec, cache_type): + """Test that CacheToolsConfig propagates type and ttl into toolConfig cachePoint.""" + model.update_config(cache_tools=CacheToolsConfig(type=cache_type, ttl="5m")) + + tru_request = model._format_request(messages, tool_specs=[tool_spec]) + + exp_cache_point = {"cachePoint": {"type": cache_type, "ttl": "5m"}} + assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point + + +def test_format_request_cache_tools_config_without_ttl(model, messages, model_id, tool_spec, cache_type): + """Test that CacheToolsConfig without ttl produces a cachePoint with only type.""" + model.update_config(cache_tools=CacheToolsConfig(type=cache_type)) + + tru_request = model._format_request(messages, tool_specs=[tool_spec]) + + exp_cache_point = {"cachePoint": {"type": cache_type}} + assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point + + +def test_format_request_cache_tools_string_backward_compat(model, messages, model_id, tool_spec, cache_type): + """Test that passing cache_tools as a string still produces a cachePoint with only type.""" + model.update_config(cache_tools=cache_type) + + tru_request = model._format_request(messages, tool_specs=[tool_spec]) + + exp_cache_point = {"cachePoint": {"type": cache_type}} + assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py index 06c72ef88..509a300f3 100644 --- a/tests_integ/models/test_model_bedrock.py +++ b/tests_integ/models/test_model_bedrock.py @@ -6,9 +6,18 @@ import strands from strands import Agent -from strands.models import BedrockModel +from strands.models import BedrockModel, CacheConfig, CacheToolsConfig from strands.types.content import ContentBlock +# Model ID used for prompt-caching TTL integration tests. Per +# https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html +# the models that officially support 1h TTL on CachePoint are Claude Opus 4.5, +# Claude Haiku 4.5, and Claude Sonnet 4.5. Haiku 4.5 is the newest Haiku +# available and is preferred for CI due to lower latency and cost relative to +# the same-version Sonnet 4.5. Bump this when a newer Haiku is released that +# supports CachePoint TTL. +_CACHE_TTL_MODEL_ID = "us.anthropic.claude-haiku-4-5-20251001-v1:0" + @pytest.fixture def system_prompt(): @@ -561,3 +570,127 @@ def calculator(expression: str) -> float: agent('Search for "python" with tags ["programming", "language"] using the search tool.') assert "search" in tools_called + + +def test_prompt_caching_cache_tools_ttl(): + """Test that CacheToolsConfig(ttl=...) propagates into the auto-injected toolConfig cache point. + + Verifies that BedrockModel(cache_tools=CacheToolsConfig(type="default", ttl="5m")) produces a + Bedrock request with cachePoint.ttl on the toolConfig checkpoint, and that the call + completes without a ValidationException on the TTL field. + + Note: we intentionally do not assert specific cacheWriteInputTokens on the toolConfig + prefix because Bedrock's tool-prefix cache threshold varies by model and region. + The critical behavior under test here is that the TTL field is accepted end-to-end. + + Uses Claude Haiku 4.5 which supports TTL in CachePointBlock on Bedrock per + https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html + (Claude Opus 4.5, Claude Haiku 4.5, and Claude Sonnet 4.5 all support 1h TTL). + """ + model = BedrockModel( + model_id=_CACHE_TTL_MODEL_ID, + streaming=False, + cache_tools=CacheToolsConfig(type="default", ttl="5m"), + ) + + @strands.tool + def lookup_fact(topic: str) -> str: + """Look up a fact about the given topic. + + This tool is useful when you need authoritative information. + """ + return f"Fact about {topic}: example" + + agent = Agent( + model=model, + tools=[lookup_fact], + load_tools_from_directory=False, + ) + + # The call must succeed — Bedrock must accept cachePoint.ttl on the toolConfig checkpoint + # without raising a ValidationException. + result = agent("Use the lookup_fact tool to look up 'python'.") + assert len(str(result)) > 0 + + +def test_prompt_caching_cache_config_auto_with_ttl(): + """Test that CacheConfig(strategy="auto", ttl="5m") propagates TTL to the auto-injected message cache point. + + Verifies that the cache point appended to the last user message by _inject_cache_point + carries the configured TTL, and that Bedrock accepts the request. + + Uses Claude Haiku 4.5 which supports TTL in CachePointBlock on Bedrock per + https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html + """ + model = BedrockModel( + model_id=_CACHE_TTL_MODEL_ID, + streaming=False, + cache_config=CacheConfig(strategy="auto", ttl="5m"), + ) + + unique_id = str(uuid.uuid4()) + # Minimum 4096 tokens required for caching with Haiku 4.5 + large_message = f"Context for test {unique_id}: " + ("This is important context. " * 1000) + " What is 2+2?" + + agent = Agent( + model=model, + load_tools_from_directory=False, + ) + + # First call: auto-injected cache point on the last user message must include ttl and be accepted + result1 = agent(large_message) + assert len(str(result1)) > 0 + + # Verify cache write occurred with auto-inject + ttl + assert result1.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, ( + "Expected cacheWriteInputTokens > 0 with CacheConfig(strategy='auto', ttl='5m')" + ) + + +def test_prompt_caching_aligned_1h_ttl_across_checkpoints(): + """Regression test for Bedrock TTL non-increasing ordering rule (Issue #2121). + + Bedrock processes cache checkpoints in order: toolConfig -> system -> messages, + and requires TTLs to be non-increasing. Before this change, cache_tools hardcoded + an implicit 5m TTL, so any 1h TTL on a later checkpoint would raise a + ValidationException. + + This test sets 1h TTL on all three checkpoints simultaneously and verifies the + call succeeds. + + Uses Claude Haiku 4.5 which supports 1h TTL per + https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html + """ + model = BedrockModel( + model_id=_CACHE_TTL_MODEL_ID, + streaming=False, + cache_tools=CacheToolsConfig(type="default", ttl="1h"), + cache_config=CacheConfig(strategy="auto", ttl="1h"), + ) + + # Timestamp-based uniqueness to avoid cache conflicts across CI runs + unique_id = str(int(time.time() * 1000000)) + large_context = f"Background context for test {unique_id}: " + ("This is important context. " * 1000) + + # User-supplied 1h cache point on system prompt — third checkpoint also at 1h + system_prompt_with_cache = [ + {"text": large_context}, + {"cachePoint": {"type": "default", "ttl": "1h"}}, + {"text": "You are a helpful assistant."}, + ] + + @strands.tool + def echo(value: str) -> str: + """Echo the given value back.""" + return value + + agent = Agent( + model=model, + system_prompt=system_prompt_with_cache, + tools=[echo], + load_tools_from_directory=False, + ) + + # Must succeed without ValidationException on the non-increasing TTL rule + result = agent("What is 2+2?") + assert len(str(result)) > 0