From 0835dc39173375be857713433a46b43457e49937 Mon Sep 17 00:00:00 2001 From: Kien Pham Date: Thu, 30 Apr 2026 16:33:27 -0700 Subject: [PATCH 1/5] feat(bedrock): add TTL support to auto-injected tool and system/user cache points Extends prompt caching TTL coverage beyond user-supplied cachePoint blocks (PR #1660) to the two SDK-managed auto-injected paths on BedrockModel: - Adds cache_tools_ttl config option so the toolConfig auto-injected cache point can carry a TTL (e.g. '5m' or '1h'). - Adds ttl field to CacheConfig dataclass so _inject_cache_point propagates TTL into the cache point appended to the last user message when strategy='auto'. Together, these let users align all three cache checkpoint TTLs (toolConfig -> system -> messages) to satisfy Bedrock's non-increasing TTL ordering rule -- which was previously impossible because cache_tools hardcoded an implicit 5m TTL. Partially addresses #2121 (Bug 2: cache_tools ordering violation with 1h TTL). Bug 1 from #2121 was resolved by #1660. Tests: - 4 unit tests covering cache_tools_ttl and CacheConfig.ttl with and without TTL (backward-compat). - 3 integration tests against Claude Haiku 4.5 (officially documented for 1h TTL on Bedrock), including a regression test that sets 1h TTL on all three cache checkpoints simultaneously. - Model ID extracted into a _CACHE_TTL_MODEL_ID module constant so future model bumps are a one-line change. --- src/strands/models/bedrock.py | 21 +++- src/strands/models/model.py | 3 + tests/strands/models/test_bedrock.py | 56 +++++++++ tests_integ/models/test_model_bedrock.py | 137 ++++++++++++++++++++++- 4 files changed, 214 insertions(+), 3 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index c74a63a3b..5a7ed9d8f 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -91,6 +91,7 @@ class BedrockConfig(BaseModelConfig, total=False): cache_prompt: Cache point type for the system prompt (deprecated, use cache_config) cache_config: Configuration for prompt caching. Use CacheConfig(strategy="auto") for automatic caching. cache_tools: Cache point type for tools + cache_tools_ttl: Optional TTL duration for tool cache points (e.g. "5m", "1h") guardrail_id: ID of the guardrail to apply guardrail_trace: Guardrail trace mode. Defaults to enabled. guardrail_version: Version of the guardrail to apply @@ -128,6 +129,7 @@ class BedrockConfig(BaseModelConfig, total=False): cache_prompt: str | None cache_config: CacheConfig | None cache_tools: str | None + cache_tools_ttl: str | None guardrail_id: str | None guardrail_trace: Literal["enabled", "disabled", "enabled_full"] | None guardrail_stream_processing_mode: Literal["sync", "async"] | None @@ -293,7 +295,18 @@ def _format_request( for tool_spec in tool_specs ], *( - [{"cachePoint": {"type": self.config["cache_tools"]}}] + [ + { + "cachePoint": { + "type": self.config["cache_tools"], + **( + {"ttl": self.config["cache_tools_ttl"]} + if self.config.get("cache_tools_ttl") + else {} + ), + } + } + ] if self.config.get("cache_tools") else [] ), @@ -395,7 +408,11 @@ def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None: last_user_idx = msg_idx if last_user_idx is not None and messages[last_user_idx].get("content"): - messages[last_user_idx]["content"].append({"cachePoint": {"type": "default"}}) + cache_point: dict[str, Any] = {"type": "default"} + cache_config = self.config.get("cache_config") + if cache_config and cache_config.ttl: + cache_point["ttl"] = cache_config.ttl + messages[last_user_idx]["content"].append({"cachePoint": cache_point}) logger.debug("msg_idx=<%s> | added cache point to last user message", last_user_idx) def _find_last_user_text_message_index(self, messages: Messages) -> int | None: diff --git a/src/strands/models/model.py b/src/strands/models/model.py index dd2f9eed2..d8ff91625 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -134,9 +134,12 @@ class CacheConfig: strategy: Caching strategy to use. - "auto": Automatically detect model support and inject cachePoint to maximize cache coverage - "anthropic": Inject cachePoint in Anthropic-compatible format without model support check + ttl: Optional TTL duration for cache entries (e.g. "5m", "1h"). + When specified, auto-injected cache points will include this TTL value. """ strategy: Literal["auto", "anthropic"] = "auto" + ttl: str | None = None class Model(abc.ABC): diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 2f1f7d1f1..3d1faef34 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -3495,3 +3495,59 @@ async def test_skip_native_api_when_use_native_token_count_false(self, bedrock_c bedrock_client.count_tokens.assert_not_called() assert isinstance(result, int) assert result >= 0 + + +def test_inject_cache_point_with_ttl(bedrock_client): + """Test that _inject_cache_point includes TTL when cache_config has ttl set.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", + cache_config=CacheConfig(strategy="auto", ttl="5m"), + ) + + cleaned_messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + model._inject_cache_point(cleaned_messages) + + cache_point = cleaned_messages[0]["content"][-1]["cachePoint"] + assert cache_point["type"] == "default" + assert cache_point["ttl"] == "5m" + + +def test_inject_cache_point_without_ttl(bedrock_client): + """Test that _inject_cache_point omits TTL when cache_config has no ttl.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", + cache_config=CacheConfig(strategy="auto"), + ) + + cleaned_messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + model._inject_cache_point(cleaned_messages) + + cache_point = cleaned_messages[0]["content"][-1]["cachePoint"] + assert cache_point["type"] == "default" + assert "ttl" not in cache_point + + +def test_format_request_cache_tools_with_ttl(model, messages, model_id, tool_spec, cache_type): + """Test that cache_tools_ttl propagates into toolConfig cachePoint.""" + model.update_config(cache_tools=cache_type, cache_tools_ttl="5m") + + tru_request = model._format_request(messages, tool_specs=[tool_spec]) + + exp_cache_point = {"cachePoint": {"type": cache_type, "ttl": "5m"}} + assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point + + +def test_format_request_cache_tools_without_ttl(model, messages, model_id, tool_spec, cache_type): + """Test that toolConfig cachePoint omits TTL when cache_tools_ttl is not set.""" + model.update_config(cache_tools=cache_type) + + tru_request = model._format_request(messages, tool_specs=[tool_spec]) + + exp_cache_point = {"cachePoint": {"type": cache_type}} + assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py index 73d67f414..7b481adb7 100644 --- a/tests_integ/models/test_model_bedrock.py +++ b/tests_integ/models/test_model_bedrock.py @@ -6,9 +6,18 @@ import strands from strands import Agent -from strands.models import BedrockModel +from strands.models import BedrockModel, CacheConfig from strands.types.content import ContentBlock +# Model ID used for prompt-caching TTL integration tests. Per +# https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html +# the models that officially support 1h TTL on CachePoint are Claude Opus 4.5, +# Claude Haiku 4.5, and Claude Sonnet 4.5. Haiku 4.5 is the newest Haiku +# available and is preferred for CI due to lower latency and cost relative to +# the same-version Sonnet 4.5. Bump this when a newer Haiku is released that +# supports CachePoint TTL. +_CACHE_TTL_MODEL_ID = "us.anthropic.claude-haiku-4-5-20251001-v1:0" + @pytest.fixture def system_prompt(): @@ -576,3 +585,129 @@ def calculator(expression: str) -> float: agent('Search for "python" with tags ["programming", "language"] using the search tool.') assert "search" in tools_called + + +def test_prompt_caching_cache_tools_ttl(): + """Test that cache_tools_ttl propagates into the auto-injected toolConfig cache point. + + Verifies that BedrockModel(cache_tools="default", cache_tools_ttl="5m") produces a + Bedrock request with cachePoint.ttl on the toolConfig checkpoint, and that the call + completes without a ValidationException on the TTL field. + + Note: we intentionally do not assert specific cacheWriteInputTokens on the toolConfig + prefix because Bedrock's tool-prefix cache threshold varies by model and region. + The critical behavior under test here is that the TTL field is accepted end-to-end. + + Uses Claude Haiku 4.5 which supports TTL in CachePointBlock on Bedrock per + https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html + (Claude Opus 4.5, Claude Haiku 4.5, and Claude Sonnet 4.5 all support 1h TTL). + """ + model = BedrockModel( + model_id=_CACHE_TTL_MODEL_ID, + streaming=False, + cache_tools="default", + cache_tools_ttl="5m", + ) + + @strands.tool + def lookup_fact(topic: str) -> str: + """Look up a fact about the given topic. + + This tool is useful when you need authoritative information. + """ + return f"Fact about {topic}: example" + + agent = Agent( + model=model, + tools=[lookup_fact], + load_tools_from_directory=False, + ) + + # The call must succeed — Bedrock must accept cachePoint.ttl on the toolConfig checkpoint + # without raising a ValidationException. + result = agent("Use the lookup_fact tool to look up 'python'.") + assert len(str(result)) > 0 + + +def test_prompt_caching_cache_config_auto_with_ttl(): + """Test that CacheConfig(strategy="auto", ttl="5m") propagates TTL to the auto-injected message cache point. + + Verifies that the cache point appended to the last user message by _inject_cache_point + carries the configured TTL, and that Bedrock accepts the request. + + Uses Claude Haiku 4.5 which supports TTL in CachePointBlock on Bedrock per + https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html + """ + model = BedrockModel( + model_id=_CACHE_TTL_MODEL_ID, + streaming=False, + cache_config=CacheConfig(strategy="auto", ttl="5m"), + ) + + unique_id = str(uuid.uuid4()) + # Minimum 4096 tokens required for caching with Haiku 4.5 + large_message = f"Context for test {unique_id}: " + ("This is important context. " * 1000) + " What is 2+2?" + + agent = Agent( + model=model, + load_tools_from_directory=False, + ) + + # First call: auto-injected cache point on the last user message must include ttl and be accepted + result1 = agent(large_message) + assert len(str(result1)) > 0 + + # Verify cache write occurred with auto-inject + ttl + assert result1.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, ( + "Expected cacheWriteInputTokens > 0 with CacheConfig(strategy='auto', ttl='5m')" + ) + + +def test_prompt_caching_aligned_1h_ttl_across_checkpoints(): + """Regression test for Bedrock TTL non-increasing ordering rule (Issue #2121). + + Bedrock processes cache checkpoints in order: toolConfig -> system -> messages, + and requires TTLs to be non-increasing. Before this change, cache_tools hardcoded + an implicit 5m TTL, so any 1h TTL on a later checkpoint would raise a + ValidationException. + + This test sets 1h TTL on all three checkpoints simultaneously and verifies the + call succeeds. + + Uses Claude Haiku 4.5 which supports 1h TTL per + https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html + """ + model = BedrockModel( + model_id=_CACHE_TTL_MODEL_ID, + streaming=False, + cache_tools="default", + cache_tools_ttl="1h", + cache_config=CacheConfig(strategy="auto", ttl="1h"), + ) + + # Timestamp-based uniqueness to avoid cache conflicts across CI runs + unique_id = str(int(time.time() * 1000000)) + large_context = f"Background context for test {unique_id}: " + ("This is important context. " * 1000) + + # User-supplied 1h cache point on system prompt — third checkpoint also at 1h + system_prompt_with_cache = [ + {"text": large_context}, + {"cachePoint": {"type": "default", "ttl": "1h"}}, + {"text": "You are a helpful assistant."}, + ] + + @strands.tool + def echo(value: str) -> str: + """Echo the given value back.""" + return value + + agent = Agent( + model=model, + system_prompt=system_prompt_with_cache, + tools=[echo], + load_tools_from_directory=False, + ) + + # Must succeed without ValidationException on the non-increasing TTL rule + result = agent("What is 2+2?") + assert len(str(result)) > 0 From cc32ea8758757979470432b1b0f681883ae29603 Mon Sep 17 00:00:00 2001 From: Kien Pham Date: Mon, 18 May 2026 13:58:12 -0700 Subject: [PATCH 2/5] refactor(bedrock): replace cache_tools_ttl with CacheToolsConfig dataclass Address review feedback: couple `type` and `ttl` together in a single config object so users can't set a TTL without a type. - Add CacheToolsConfig(type, ttl) dataclass - Change cache_tools to `str | CacheToolsConfig | None` (str preserved for back-compat) - Drop standalone cache_tools_ttl field - Export CacheToolsConfig from strands.models --- src/strands/models/__init__.py | 3 +- src/strands/models/bedrock.py | 57 +++++++++++++++--------- tests/strands/models/test_bedrock.py | 22 ++++++--- tests_integ/models/test_model_bedrock.py | 12 +++-- 4 files changed, 60 insertions(+), 34 deletions(-) diff --git a/src/strands/models/__init__.py b/src/strands/models/__init__.py index 3a23e257a..2d78f3048 100644 --- a/src/strands/models/__init__.py +++ b/src/strands/models/__init__.py @@ -6,7 +6,7 @@ from typing import Any from . import bedrock, model -from .bedrock import BedrockModel +from .bedrock import BedrockModel, CacheToolsConfig from .model import BaseModelConfig, CacheConfig, Model __all__ = [ @@ -15,6 +15,7 @@ "BaseModelConfig", "BedrockModel", "CacheConfig", + "CacheToolsConfig", "Model", ] diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 15dbd8c50..6b56b0017 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -9,6 +9,7 @@ import os import warnings from collections.abc import AsyncGenerator, Callable, Iterable, ValuesView +from dataclasses import dataclass from typing import Any, Literal, TypeVar, cast import boto3 @@ -69,6 +70,19 @@ def _clear_skip_count_tokens_cache() -> None: DEFAULT_READ_TIMEOUT = 120 +@dataclass +class CacheToolsConfig: + """Configuration for the Bedrock toolConfig cache point. + + Attributes: + type: Cache point type (e.g. "default"). + ttl: Optional TTL duration for the cache entry (e.g. "5m", "1h"). + """ + + type: str = "default" + ttl: str | None = None + + class BedrockModel(Model): """AWS Bedrock model provider implementation. @@ -90,8 +104,8 @@ class BedrockConfig(BaseModelConfig, total=False): additional_response_field_paths: Additional response field paths to extract cache_prompt: Cache point type for the system prompt (deprecated, use cache_config) cache_config: Configuration for prompt caching. Use CacheConfig(strategy="auto") for automatic caching. - cache_tools: Cache point type for tools - cache_tools_ttl: Optional TTL duration for tool cache points (e.g. "5m", "1h") + cache_tools: Cache point type for tools. Pass a string (e.g. "default") for the default 5m TTL, + or a CacheToolsConfig instance to set both type and TTL (e.g. "1h"). guardrail_id: ID of the guardrail to apply guardrail_trace: Guardrail trace mode. Defaults to enabled. guardrail_version: Version of the guardrail to apply @@ -128,8 +142,7 @@ class BedrockConfig(BaseModelConfig, total=False): additional_response_field_paths: list[str] | None cache_prompt: str | None cache_config: CacheConfig | None - cache_tools: str | None - cache_tools_ttl: str | None + cache_tools: str | CacheToolsConfig | None guardrail_id: str | None guardrail_trace: Literal["enabled", "disabled", "enabled_full"] | None guardrail_stream_processing_mode: Literal["sync", "async"] | None @@ -294,22 +307,7 @@ def _format_request( } for tool_spec in tool_specs ], - *( - [ - { - "cachePoint": { - "type": self.config["cache_tools"], - **( - {"ttl": self.config["cache_tools_ttl"]} - if self.config.get("cache_tools_ttl") - else {} - ), - } - } - ] - if self.config.get("cache_tools") - else [] - ), + *self._build_tools_cache_point(), ], **({"toolChoice": tool_choice if tool_choice else {"auto": {}}}), } @@ -384,6 +382,25 @@ def _get_additional_request_fields(self, tool_choice: ToolChoice | None) -> dict return {"additionalModelRequestFields": additional_fields} + def _build_tools_cache_point(self) -> list[dict[str, Any]]: + """Build the cache point block appended to ``toolConfig.tools`` if ``cache_tools`` is configured. + + Returns: + A single-element list containing the cache point block, or an empty list if no cache_tools is set. + """ + cache_tools = self.config.get("cache_tools") + if not cache_tools: + return [] + + if isinstance(cache_tools, CacheToolsConfig): + cache_point: dict[str, Any] = {"type": cache_tools.type} + if cache_tools.ttl: + cache_point["ttl"] = cache_tools.ttl + else: + cache_point = {"type": cache_tools} + + return [{"cachePoint": cache_point}] + def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None: """Inject a cache point at the end of the last user message. diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index d4806073a..319b5574f 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -14,7 +14,7 @@ import strands from strands import _exception_notes -from strands.models import BedrockModel, CacheConfig +from strands.models import BedrockModel, CacheConfig, CacheToolsConfig from strands.models.bedrock import ( DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION, @@ -3592,9 +3592,9 @@ def test_inject_cache_point_without_ttl(bedrock_client): assert "ttl" not in cache_point -def test_format_request_cache_tools_with_ttl(model, messages, model_id, tool_spec, cache_type): - """Test that cache_tools_ttl propagates into toolConfig cachePoint.""" - model.update_config(cache_tools=cache_type, cache_tools_ttl="5m") +def test_format_request_cache_tools_config_with_ttl(model, messages, model_id, tool_spec, cache_type): + """Test that CacheToolsConfig propagates type and ttl into toolConfig cachePoint.""" + model.update_config(cache_tools=CacheToolsConfig(type=cache_type, ttl="5m")) tru_request = model._format_request(messages, tool_specs=[tool_spec]) @@ -3602,8 +3602,18 @@ def test_format_request_cache_tools_with_ttl(model, messages, model_id, tool_spe assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point -def test_format_request_cache_tools_without_ttl(model, messages, model_id, tool_spec, cache_type): - """Test that toolConfig cachePoint omits TTL when cache_tools_ttl is not set.""" +def test_format_request_cache_tools_config_without_ttl(model, messages, model_id, tool_spec, cache_type): + """Test that CacheToolsConfig without ttl produces a cachePoint with only type.""" + model.update_config(cache_tools=CacheToolsConfig(type=cache_type)) + + tru_request = model._format_request(messages, tool_specs=[tool_spec]) + + exp_cache_point = {"cachePoint": {"type": cache_type}} + assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point + + +def test_format_request_cache_tools_string_backward_compat(model, messages, model_id, tool_spec, cache_type): + """Test that passing cache_tools as a string still produces a cachePoint with only type.""" model.update_config(cache_tools=cache_type) tru_request = model._format_request(messages, tool_specs=[tool_spec]) diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py index d9d963262..509a300f3 100644 --- a/tests_integ/models/test_model_bedrock.py +++ b/tests_integ/models/test_model_bedrock.py @@ -6,7 +6,7 @@ import strands from strands import Agent -from strands.models import BedrockModel, CacheConfig +from strands.models import BedrockModel, CacheConfig, CacheToolsConfig from strands.types.content import ContentBlock # Model ID used for prompt-caching TTL integration tests. Per @@ -573,9 +573,9 @@ def calculator(expression: str) -> float: def test_prompt_caching_cache_tools_ttl(): - """Test that cache_tools_ttl propagates into the auto-injected toolConfig cache point. + """Test that CacheToolsConfig(ttl=...) propagates into the auto-injected toolConfig cache point. - Verifies that BedrockModel(cache_tools="default", cache_tools_ttl="5m") produces a + Verifies that BedrockModel(cache_tools=CacheToolsConfig(type="default", ttl="5m")) produces a Bedrock request with cachePoint.ttl on the toolConfig checkpoint, and that the call completes without a ValidationException on the TTL field. @@ -590,8 +590,7 @@ def test_prompt_caching_cache_tools_ttl(): model = BedrockModel( model_id=_CACHE_TTL_MODEL_ID, streaming=False, - cache_tools="default", - cache_tools_ttl="5m", + cache_tools=CacheToolsConfig(type="default", ttl="5m"), ) @strands.tool @@ -665,8 +664,7 @@ def test_prompt_caching_aligned_1h_ttl_across_checkpoints(): model = BedrockModel( model_id=_CACHE_TTL_MODEL_ID, streaming=False, - cache_tools="default", - cache_tools_ttl="1h", + cache_tools=CacheToolsConfig(type="default", ttl="1h"), cache_config=CacheConfig(strategy="auto", ttl="1h"), ) From 8759076e360c2ed87158e84fc44c45cf7578e2dc Mon Sep 17 00:00:00 2001 From: Kien Pham Date: Mon, 18 May 2026 14:26:01 -0700 Subject: [PATCH 3/5] refactor(bedrock): drop CacheToolsConfig, source tool TTL from cache_config Address review feedback: instead of a separate CacheToolsConfig, the toolConfig cache point now picks up its TTL from cache_config.ttl. This keeps a single TTL knob for SDK-managed checkpoints, which satisfies Bedrock's non-increasing-TTL ordering rule trivially. - Revert cache_tools type to `str | None` - _build_tools_cache_point reads ttl from self.config["cache_config"].ttl - Remove CacheToolsConfig class and export - Update tests to set TTL via cache_config --- src/strands/models/__init__.py | 3 +- src/strands/models/bedrock.py | 35 ++++++++---------------- tests/strands/models/test_bedrock.py | 22 ++++----------- tests_integ/models/test_model_bedrock.py | 15 +++++----- 4 files changed, 27 insertions(+), 48 deletions(-) diff --git a/src/strands/models/__init__.py b/src/strands/models/__init__.py index 2d78f3048..3a23e257a 100644 --- a/src/strands/models/__init__.py +++ b/src/strands/models/__init__.py @@ -6,7 +6,7 @@ from typing import Any from . import bedrock, model -from .bedrock import BedrockModel, CacheToolsConfig +from .bedrock import BedrockModel from .model import BaseModelConfig, CacheConfig, Model __all__ = [ @@ -15,7 +15,6 @@ "BaseModelConfig", "BedrockModel", "CacheConfig", - "CacheToolsConfig", "Model", ] diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 6b56b0017..9b4b25450 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -9,7 +9,6 @@ import os import warnings from collections.abc import AsyncGenerator, Callable, Iterable, ValuesView -from dataclasses import dataclass from typing import Any, Literal, TypeVar, cast import boto3 @@ -70,19 +69,6 @@ def _clear_skip_count_tokens_cache() -> None: DEFAULT_READ_TIMEOUT = 120 -@dataclass -class CacheToolsConfig: - """Configuration for the Bedrock toolConfig cache point. - - Attributes: - type: Cache point type (e.g. "default"). - ttl: Optional TTL duration for the cache entry (e.g. "5m", "1h"). - """ - - type: str = "default" - ttl: str | None = None - - class BedrockModel(Model): """AWS Bedrock model provider implementation. @@ -104,8 +90,10 @@ class BedrockConfig(BaseModelConfig, total=False): additional_response_field_paths: Additional response field paths to extract cache_prompt: Cache point type for the system prompt (deprecated, use cache_config) cache_config: Configuration for prompt caching. Use CacheConfig(strategy="auto") for automatic caching. - cache_tools: Cache point type for tools. Pass a string (e.g. "default") for the default 5m TTL, - or a CacheToolsConfig instance to set both type and TTL (e.g. "1h"). + When set, ``cache_config.ttl`` is also applied to the ``cache_tools`` checkpoint (if configured) + to satisfy Bedrock's non-increasing TTL ordering rule across checkpoints. + cache_tools: Cache point type for tools (e.g. "default"). The TTL on this checkpoint is taken + from ``cache_config.ttl`` when set. guardrail_id: ID of the guardrail to apply guardrail_trace: Guardrail trace mode. Defaults to enabled. guardrail_version: Version of the guardrail to apply @@ -142,7 +130,7 @@ class BedrockConfig(BaseModelConfig, total=False): additional_response_field_paths: list[str] | None cache_prompt: str | None cache_config: CacheConfig | None - cache_tools: str | CacheToolsConfig | None + cache_tools: str | None guardrail_id: str | None guardrail_trace: Literal["enabled", "disabled", "enabled_full"] | None guardrail_stream_processing_mode: Literal["sync", "async"] | None @@ -385,6 +373,9 @@ def _get_additional_request_fields(self, tool_choice: ToolChoice | None) -> dict def _build_tools_cache_point(self) -> list[dict[str, Any]]: """Build the cache point block appended to ``toolConfig.tools`` if ``cache_tools`` is configured. + TTL is sourced from ``cache_config.ttl`` so that all SDK-managed cache checkpoints (toolConfig, + message) share a single TTL — Bedrock requires TTLs across checkpoints to be non-increasing. + Returns: A single-element list containing the cache point block, or an empty list if no cache_tools is set. """ @@ -392,12 +383,10 @@ def _build_tools_cache_point(self) -> list[dict[str, Any]]: if not cache_tools: return [] - if isinstance(cache_tools, CacheToolsConfig): - cache_point: dict[str, Any] = {"type": cache_tools.type} - if cache_tools.ttl: - cache_point["ttl"] = cache_tools.ttl - else: - cache_point = {"type": cache_tools} + cache_point: dict[str, Any] = {"type": cache_tools} + cache_config = self.config.get("cache_config") + if cache_config and cache_config.ttl: + cache_point["ttl"] = cache_config.ttl return [{"cachePoint": cache_point}] diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 319b5574f..e04ea3dcf 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -14,7 +14,7 @@ import strands from strands import _exception_notes -from strands.models import BedrockModel, CacheConfig, CacheToolsConfig +from strands.models import BedrockModel, CacheConfig from strands.models.bedrock import ( DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION, @@ -3592,9 +3592,9 @@ def test_inject_cache_point_without_ttl(bedrock_client): assert "ttl" not in cache_point -def test_format_request_cache_tools_config_with_ttl(model, messages, model_id, tool_spec, cache_type): - """Test that CacheToolsConfig propagates type and ttl into toolConfig cachePoint.""" - model.update_config(cache_tools=CacheToolsConfig(type=cache_type, ttl="5m")) +def test_format_request_cache_tools_with_ttl_from_cache_config(model, messages, model_id, tool_spec, cache_type): + """Test that cache_config.ttl propagates into the toolConfig cachePoint.""" + model.update_config(cache_tools=cache_type, cache_config=CacheConfig(strategy="auto", ttl="5m")) tru_request = model._format_request(messages, tool_specs=[tool_spec]) @@ -3602,18 +3602,8 @@ def test_format_request_cache_tools_config_with_ttl(model, messages, model_id, t assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point -def test_format_request_cache_tools_config_without_ttl(model, messages, model_id, tool_spec, cache_type): - """Test that CacheToolsConfig without ttl produces a cachePoint with only type.""" - model.update_config(cache_tools=CacheToolsConfig(type=cache_type)) - - tru_request = model._format_request(messages, tool_specs=[tool_spec]) - - exp_cache_point = {"cachePoint": {"type": cache_type}} - assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point - - -def test_format_request_cache_tools_string_backward_compat(model, messages, model_id, tool_spec, cache_type): - """Test that passing cache_tools as a string still produces a cachePoint with only type.""" +def test_format_request_cache_tools_without_ttl(model, messages, model_id, tool_spec, cache_type): + """Test that toolConfig cachePoint omits TTL when cache_config has no ttl.""" model.update_config(cache_tools=cache_type) tru_request = model._format_request(messages, tool_specs=[tool_spec]) diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py index 509a300f3..4848b1939 100644 --- a/tests_integ/models/test_model_bedrock.py +++ b/tests_integ/models/test_model_bedrock.py @@ -6,7 +6,7 @@ import strands from strands import Agent -from strands.models import BedrockModel, CacheConfig, CacheToolsConfig +from strands.models import BedrockModel, CacheConfig from strands.types.content import ContentBlock # Model ID used for prompt-caching TTL integration tests. Per @@ -573,11 +573,11 @@ def calculator(expression: str) -> float: def test_prompt_caching_cache_tools_ttl(): - """Test that CacheToolsConfig(ttl=...) propagates into the auto-injected toolConfig cache point. + """Test that cache_config.ttl propagates into the auto-injected toolConfig cache point. - Verifies that BedrockModel(cache_tools=CacheToolsConfig(type="default", ttl="5m")) produces a - Bedrock request with cachePoint.ttl on the toolConfig checkpoint, and that the call - completes without a ValidationException on the TTL field. + Verifies that BedrockModel(cache_tools="default", cache_config=CacheConfig(ttl="5m")) produces + a Bedrock request with cachePoint.ttl on the toolConfig checkpoint, and that the call completes + without a ValidationException on the TTL field. Note: we intentionally do not assert specific cacheWriteInputTokens on the toolConfig prefix because Bedrock's tool-prefix cache threshold varies by model and region. @@ -590,7 +590,8 @@ def test_prompt_caching_cache_tools_ttl(): model = BedrockModel( model_id=_CACHE_TTL_MODEL_ID, streaming=False, - cache_tools=CacheToolsConfig(type="default", ttl="5m"), + cache_tools="default", + cache_config=CacheConfig(strategy="auto", ttl="5m"), ) @strands.tool @@ -664,7 +665,7 @@ def test_prompt_caching_aligned_1h_ttl_across_checkpoints(): model = BedrockModel( model_id=_CACHE_TTL_MODEL_ID, streaming=False, - cache_tools=CacheToolsConfig(type="default", ttl="1h"), + cache_tools="default", cache_config=CacheConfig(strategy="auto", ttl="1h"), ) From 38c9427a2601ccc0bcc5d2794b3e128cad25b497 Mon Sep 17 00:00:00 2001 From: Kien Pham Date: Tue, 19 May 2026 09:59:16 -0700 Subject: [PATCH 4/5] Revert "refactor(bedrock): drop CacheToolsConfig, source tool TTL from cache_config" This reverts commit 8759076e360c2ed87158e84fc44c45cf7578e2dc. --- src/strands/models/__init__.py | 3 +- src/strands/models/bedrock.py | 35 ++++++++++++++++-------- tests/strands/models/test_bedrock.py | 22 +++++++++++---- tests_integ/models/test_model_bedrock.py | 15 +++++----- 4 files changed, 48 insertions(+), 27 deletions(-) diff --git a/src/strands/models/__init__.py b/src/strands/models/__init__.py index 3a23e257a..2d78f3048 100644 --- a/src/strands/models/__init__.py +++ b/src/strands/models/__init__.py @@ -6,7 +6,7 @@ from typing import Any from . import bedrock, model -from .bedrock import BedrockModel +from .bedrock import BedrockModel, CacheToolsConfig from .model import BaseModelConfig, CacheConfig, Model __all__ = [ @@ -15,6 +15,7 @@ "BaseModelConfig", "BedrockModel", "CacheConfig", + "CacheToolsConfig", "Model", ] diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 9b4b25450..6b56b0017 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -9,6 +9,7 @@ import os import warnings from collections.abc import AsyncGenerator, Callable, Iterable, ValuesView +from dataclasses import dataclass from typing import Any, Literal, TypeVar, cast import boto3 @@ -69,6 +70,19 @@ def _clear_skip_count_tokens_cache() -> None: DEFAULT_READ_TIMEOUT = 120 +@dataclass +class CacheToolsConfig: + """Configuration for the Bedrock toolConfig cache point. + + Attributes: + type: Cache point type (e.g. "default"). + ttl: Optional TTL duration for the cache entry (e.g. "5m", "1h"). + """ + + type: str = "default" + ttl: str | None = None + + class BedrockModel(Model): """AWS Bedrock model provider implementation. @@ -90,10 +104,8 @@ class BedrockConfig(BaseModelConfig, total=False): additional_response_field_paths: Additional response field paths to extract cache_prompt: Cache point type for the system prompt (deprecated, use cache_config) cache_config: Configuration for prompt caching. Use CacheConfig(strategy="auto") for automatic caching. - When set, ``cache_config.ttl`` is also applied to the ``cache_tools`` checkpoint (if configured) - to satisfy Bedrock's non-increasing TTL ordering rule across checkpoints. - cache_tools: Cache point type for tools (e.g. "default"). The TTL on this checkpoint is taken - from ``cache_config.ttl`` when set. + cache_tools: Cache point type for tools. Pass a string (e.g. "default") for the default 5m TTL, + or a CacheToolsConfig instance to set both type and TTL (e.g. "1h"). guardrail_id: ID of the guardrail to apply guardrail_trace: Guardrail trace mode. Defaults to enabled. guardrail_version: Version of the guardrail to apply @@ -130,7 +142,7 @@ class BedrockConfig(BaseModelConfig, total=False): additional_response_field_paths: list[str] | None cache_prompt: str | None cache_config: CacheConfig | None - cache_tools: str | None + cache_tools: str | CacheToolsConfig | None guardrail_id: str | None guardrail_trace: Literal["enabled", "disabled", "enabled_full"] | None guardrail_stream_processing_mode: Literal["sync", "async"] | None @@ -373,9 +385,6 @@ def _get_additional_request_fields(self, tool_choice: ToolChoice | None) -> dict def _build_tools_cache_point(self) -> list[dict[str, Any]]: """Build the cache point block appended to ``toolConfig.tools`` if ``cache_tools`` is configured. - TTL is sourced from ``cache_config.ttl`` so that all SDK-managed cache checkpoints (toolConfig, - message) share a single TTL — Bedrock requires TTLs across checkpoints to be non-increasing. - Returns: A single-element list containing the cache point block, or an empty list if no cache_tools is set. """ @@ -383,10 +392,12 @@ def _build_tools_cache_point(self) -> list[dict[str, Any]]: if not cache_tools: return [] - cache_point: dict[str, Any] = {"type": cache_tools} - cache_config = self.config.get("cache_config") - if cache_config and cache_config.ttl: - cache_point["ttl"] = cache_config.ttl + if isinstance(cache_tools, CacheToolsConfig): + cache_point: dict[str, Any] = {"type": cache_tools.type} + if cache_tools.ttl: + cache_point["ttl"] = cache_tools.ttl + else: + cache_point = {"type": cache_tools} return [{"cachePoint": cache_point}] diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index e04ea3dcf..319b5574f 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -14,7 +14,7 @@ import strands from strands import _exception_notes -from strands.models import BedrockModel, CacheConfig +from strands.models import BedrockModel, CacheConfig, CacheToolsConfig from strands.models.bedrock import ( DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION, @@ -3592,9 +3592,9 @@ def test_inject_cache_point_without_ttl(bedrock_client): assert "ttl" not in cache_point -def test_format_request_cache_tools_with_ttl_from_cache_config(model, messages, model_id, tool_spec, cache_type): - """Test that cache_config.ttl propagates into the toolConfig cachePoint.""" - model.update_config(cache_tools=cache_type, cache_config=CacheConfig(strategy="auto", ttl="5m")) +def test_format_request_cache_tools_config_with_ttl(model, messages, model_id, tool_spec, cache_type): + """Test that CacheToolsConfig propagates type and ttl into toolConfig cachePoint.""" + model.update_config(cache_tools=CacheToolsConfig(type=cache_type, ttl="5m")) tru_request = model._format_request(messages, tool_specs=[tool_spec]) @@ -3602,8 +3602,18 @@ def test_format_request_cache_tools_with_ttl_from_cache_config(model, messages, assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point -def test_format_request_cache_tools_without_ttl(model, messages, model_id, tool_spec, cache_type): - """Test that toolConfig cachePoint omits TTL when cache_config has no ttl.""" +def test_format_request_cache_tools_config_without_ttl(model, messages, model_id, tool_spec, cache_type): + """Test that CacheToolsConfig without ttl produces a cachePoint with only type.""" + model.update_config(cache_tools=CacheToolsConfig(type=cache_type)) + + tru_request = model._format_request(messages, tool_specs=[tool_spec]) + + exp_cache_point = {"cachePoint": {"type": cache_type}} + assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point + + +def test_format_request_cache_tools_string_backward_compat(model, messages, model_id, tool_spec, cache_type): + """Test that passing cache_tools as a string still produces a cachePoint with only type.""" model.update_config(cache_tools=cache_type) tru_request = model._format_request(messages, tool_specs=[tool_spec]) diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py index 4848b1939..509a300f3 100644 --- a/tests_integ/models/test_model_bedrock.py +++ b/tests_integ/models/test_model_bedrock.py @@ -6,7 +6,7 @@ import strands from strands import Agent -from strands.models import BedrockModel, CacheConfig +from strands.models import BedrockModel, CacheConfig, CacheToolsConfig from strands.types.content import ContentBlock # Model ID used for prompt-caching TTL integration tests. Per @@ -573,11 +573,11 @@ def calculator(expression: str) -> float: def test_prompt_caching_cache_tools_ttl(): - """Test that cache_config.ttl propagates into the auto-injected toolConfig cache point. + """Test that CacheToolsConfig(ttl=...) propagates into the auto-injected toolConfig cache point. - Verifies that BedrockModel(cache_tools="default", cache_config=CacheConfig(ttl="5m")) produces - a Bedrock request with cachePoint.ttl on the toolConfig checkpoint, and that the call completes - without a ValidationException on the TTL field. + Verifies that BedrockModel(cache_tools=CacheToolsConfig(type="default", ttl="5m")) produces a + Bedrock request with cachePoint.ttl on the toolConfig checkpoint, and that the call + completes without a ValidationException on the TTL field. Note: we intentionally do not assert specific cacheWriteInputTokens on the toolConfig prefix because Bedrock's tool-prefix cache threshold varies by model and region. @@ -590,8 +590,7 @@ def test_prompt_caching_cache_tools_ttl(): model = BedrockModel( model_id=_CACHE_TTL_MODEL_ID, streaming=False, - cache_tools="default", - cache_config=CacheConfig(strategy="auto", ttl="5m"), + cache_tools=CacheToolsConfig(type="default", ttl="5m"), ) @strands.tool @@ -665,7 +664,7 @@ def test_prompt_caching_aligned_1h_ttl_across_checkpoints(): model = BedrockModel( model_id=_CACHE_TTL_MODEL_ID, streaming=False, - cache_tools="default", + cache_tools=CacheToolsConfig(type="default", ttl="1h"), cache_config=CacheConfig(strategy="auto", ttl="1h"), ) From f255d9b6e9ed817dcfc647ee34d04a29159d77f1 Mon Sep 17 00:00:00 2001 From: Kien Pham Date: Tue, 19 May 2026 10:02:37 -0700 Subject: [PATCH 5/5] refactor(models): move CacheToolsConfig to model.py Address review feedback (#2232 review): define and export CacheToolsConfig from model.py instead of bedrock.py so the dataclass lives next to CacheConfig rather than in a model-provider-specific module. --- src/strands/models/__init__.py | 4 ++-- src/strands/models/bedrock.py | 16 +--------------- src/strands/models/model.py | 13 +++++++++++++ 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/src/strands/models/__init__.py b/src/strands/models/__init__.py index 2d78f3048..8ae660da0 100644 --- a/src/strands/models/__init__.py +++ b/src/strands/models/__init__.py @@ -6,8 +6,8 @@ from typing import Any from . import bedrock, model -from .bedrock import BedrockModel, CacheToolsConfig -from .model import BaseModelConfig, CacheConfig, Model +from .bedrock import BedrockModel +from .model import BaseModelConfig, CacheConfig, CacheToolsConfig, Model __all__ = [ "bedrock", diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 6b56b0017..4cd6f7fbc 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -9,7 +9,6 @@ import os import warnings from collections.abc import AsyncGenerator, Callable, Iterable, ValuesView -from dataclasses import dataclass from typing import Any, Literal, TypeVar, cast import boto3 @@ -35,7 +34,7 @@ from ._defaults import resolve_config_metadata from ._strict_schema import ensure_strict_json_schema from ._validation import validate_config_keys -from .model import BaseModelConfig, CacheConfig, Model +from .model import BaseModelConfig, CacheConfig, CacheToolsConfig, Model logger = logging.getLogger(__name__) @@ -70,19 +69,6 @@ def _clear_skip_count_tokens_cache() -> None: DEFAULT_READ_TIMEOUT = 120 -@dataclass -class CacheToolsConfig: - """Configuration for the Bedrock toolConfig cache point. - - Attributes: - type: Cache point type (e.g. "default"). - ttl: Optional TTL duration for the cache entry (e.g. "5m", "1h"). - """ - - type: str = "default" - ttl: str | None = None - - class BedrockModel(Model): """AWS Bedrock model provider implementation. diff --git a/src/strands/models/model.py b/src/strands/models/model.py index d8ff91625..77ef1df40 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -142,6 +142,19 @@ class CacheConfig: ttl: str | None = None +@dataclass +class CacheToolsConfig: + """Configuration for the toolConfig cache point. + + Attributes: + type: Cache point type (e.g. "default"). + ttl: Optional TTL duration for the cache entry (e.g. "5m", "1h"). + """ + + type: str = "default" + ttl: str | None = None + + class Model(abc.ABC): """Abstract base class for Agent model providers.