Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/strands/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@

from . import bedrock, model
from .bedrock import BedrockModel
from .model import BaseModelConfig, CacheConfig, Model
from .model import BaseModelConfig, CacheConfig, CacheToolsConfig, Model

__all__ = [
"bedrock",
"model",
"BaseModelConfig",
"BedrockModel",
"CacheConfig",
"CacheToolsConfig",
"Model",
]

Expand Down
38 changes: 29 additions & 9 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from ._defaults import resolve_config_metadata
from ._strict_schema import ensure_strict_json_schema
from ._validation import validate_config_keys
from .model import BaseModelConfig, CacheConfig, Model
from .model import BaseModelConfig, CacheConfig, CacheToolsConfig, Model

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -90,7 +90,8 @@ class BedrockConfig(BaseModelConfig, total=False):
additional_response_field_paths: Additional response field paths to extract
cache_prompt: Cache point type for the system prompt (deprecated, use cache_config)
cache_config: Configuration for prompt caching. Use CacheConfig(strategy="auto") for automatic caching.
cache_tools: Cache point type for tools
cache_tools: Cache point type for tools. Pass a string (e.g. "default") for the default 5m TTL,
or a CacheToolsConfig instance to set both type and TTL (e.g. "1h").
guardrail_id: ID of the guardrail to apply
guardrail_trace: Guardrail trace mode. Defaults to enabled.
guardrail_version: Version of the guardrail to apply
Expand Down Expand Up @@ -127,7 +128,7 @@ class BedrockConfig(BaseModelConfig, total=False):
additional_response_field_paths: list[str] | None
cache_prompt: str | None
cache_config: CacheConfig | None
cache_tools: str | None
cache_tools: str | CacheToolsConfig | None
guardrail_id: str | None
guardrail_trace: Literal["enabled", "disabled", "enabled_full"] | None
guardrail_stream_processing_mode: Literal["sync", "async"] | None
Expand Down Expand Up @@ -292,11 +293,7 @@ def _format_request(
}
for tool_spec in tool_specs
],
*(
[{"cachePoint": {"type": self.config["cache_tools"]}}]
if self.config.get("cache_tools")
else []
),
*self._build_tools_cache_point(),
],
**({"toolChoice": tool_choice if tool_choice else {"auto": {}}}),
}
Expand Down Expand Up @@ -371,6 +368,25 @@ def _get_additional_request_fields(self, tool_choice: ToolChoice | None) -> dict

return {"additionalModelRequestFields": additional_fields}

def _build_tools_cache_point(self) -> list[dict[str, Any]]:
"""Build the cache point block appended to ``toolConfig.tools`` if ``cache_tools`` is configured.

Returns:
A single-element list containing the cache point block, or an empty list if no cache_tools is set.
"""
cache_tools = self.config.get("cache_tools")
if not cache_tools:
return []

if isinstance(cache_tools, CacheToolsConfig):
cache_point: dict[str, Any] = {"type": cache_tools.type}
if cache_tools.ttl:
cache_point["ttl"] = cache_tools.ttl
else:
cache_point = {"type": cache_tools}

return [{"cachePoint": cache_point}]

def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None:
"""Inject a cache point at the end of the last user message.

Expand All @@ -395,7 +411,11 @@ def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None:
last_user_idx = msg_idx

if last_user_idx is not None and messages[last_user_idx].get("content"):
messages[last_user_idx]["content"].append({"cachePoint": {"type": "default"}})
cache_point: dict[str, Any] = {"type": "default"}
cache_config = self.config.get("cache_config")
if cache_config and cache_config.ttl:
cache_point["ttl"] = cache_config.ttl
messages[last_user_idx]["content"].append({"cachePoint": cache_point})
logger.debug("msg_idx=<%s> | added cache point to last user message", last_user_idx)

def _find_last_user_text_message_index(self, messages: Messages) -> int | None:
Expand Down
16 changes: 16 additions & 0 deletions src/strands/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,25 @@ class CacheConfig:
strategy: Caching strategy to use.
- "auto": Automatically detect model support and inject cachePoint to maximize cache coverage
- "anthropic": Inject cachePoint in Anthropic-compatible format without model support check
ttl: Optional TTL duration for cache entries (e.g. "5m", "1h").
When specified, auto-injected cache points will include this TTL value.
"""

strategy: Literal["auto", "anthropic"] = "auto"
ttl: str | None = None


@dataclass
class CacheToolsConfig:
"""Configuration for the toolConfig cache point.

Attributes:
type: Cache point type (e.g. "default").
ttl: Optional TTL duration for the cache entry (e.g. "5m", "1h").
"""

type: str = "default"
ttl: str | None = None


class Model(abc.ABC):
Expand Down
68 changes: 67 additions & 1 deletion tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import strands
from strands import _exception_notes
from strands.models import BedrockModel, CacheConfig
from strands.models import BedrockModel, CacheConfig, CacheToolsConfig
from strands.models.bedrock import (
DEFAULT_BEDROCK_MODEL_ID,
DEFAULT_BEDROCK_REGION,
Expand Down Expand Up @@ -3554,3 +3554,69 @@ async def test_skip_native_api_by_default(self, bedrock_client, model_id, messag
bedrock_client.count_tokens.assert_not_called()
assert isinstance(result, int)
assert result >= 0


def test_inject_cache_point_with_ttl(bedrock_client):
"""Test that _inject_cache_point includes TTL when cache_config has ttl set."""
model = BedrockModel(
model_id="us.anthropic.claude-sonnet-4-20250514-v1:0",
cache_config=CacheConfig(strategy="auto", ttl="5m"),
)

cleaned_messages = [
{"role": "user", "content": [{"text": "Hello"}]},
]

model._inject_cache_point(cleaned_messages)

cache_point = cleaned_messages[0]["content"][-1]["cachePoint"]
assert cache_point["type"] == "default"
assert cache_point["ttl"] == "5m"


def test_inject_cache_point_without_ttl(bedrock_client):
"""Test that _inject_cache_point omits TTL when cache_config has no ttl."""
model = BedrockModel(
model_id="us.anthropic.claude-sonnet-4-20250514-v1:0",
cache_config=CacheConfig(strategy="auto"),
)

cleaned_messages = [
{"role": "user", "content": [{"text": "Hello"}]},
]

model._inject_cache_point(cleaned_messages)

cache_point = cleaned_messages[0]["content"][-1]["cachePoint"]
assert cache_point["type"] == "default"
assert "ttl" not in cache_point


def test_format_request_cache_tools_config_with_ttl(model, messages, model_id, tool_spec, cache_type):
"""Test that CacheToolsConfig propagates type and ttl into toolConfig cachePoint."""
model.update_config(cache_tools=CacheToolsConfig(type=cache_type, ttl="5m"))

tru_request = model._format_request(messages, tool_specs=[tool_spec])

exp_cache_point = {"cachePoint": {"type": cache_type, "ttl": "5m"}}
assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point


def test_format_request_cache_tools_config_without_ttl(model, messages, model_id, tool_spec, cache_type):
"""Test that CacheToolsConfig without ttl produces a cachePoint with only type."""
model.update_config(cache_tools=CacheToolsConfig(type=cache_type))

tru_request = model._format_request(messages, tool_specs=[tool_spec])

exp_cache_point = {"cachePoint": {"type": cache_type}}
assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point


def test_format_request_cache_tools_string_backward_compat(model, messages, model_id, tool_spec, cache_type):
"""Test that passing cache_tools as a string still produces a cachePoint with only type."""
model.update_config(cache_tools=cache_type)

tru_request = model._format_request(messages, tool_specs=[tool_spec])

exp_cache_point = {"cachePoint": {"type": cache_type}}
assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point
135 changes: 134 additions & 1 deletion tests_integ/models/test_model_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,18 @@

import strands
from strands import Agent
from strands.models import BedrockModel
from strands.models import BedrockModel, CacheConfig, CacheToolsConfig
from strands.types.content import ContentBlock

# Model ID used for prompt-caching TTL integration tests. Per
# https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
# the models that officially support 1h TTL on CachePoint are Claude Opus 4.5,
# Claude Haiku 4.5, and Claude Sonnet 4.5. Haiku 4.5 is the newest Haiku
# available and is preferred for CI due to lower latency and cost relative to
# the same-version Sonnet 4.5. Bump this when a newer Haiku is released that
# supports CachePoint TTL.
_CACHE_TTL_MODEL_ID = "us.anthropic.claude-haiku-4-5-20251001-v1:0"


@pytest.fixture
def system_prompt():
Expand Down Expand Up @@ -561,3 +570,127 @@ def calculator(expression: str) -> float:
agent('Search for "python" with tags ["programming", "language"] using the search tool.')

assert "search" in tools_called


def test_prompt_caching_cache_tools_ttl():
"""Test that CacheToolsConfig(ttl=...) propagates into the auto-injected toolConfig cache point.

Verifies that BedrockModel(cache_tools=CacheToolsConfig(type="default", ttl="5m")) produces a
Bedrock request with cachePoint.ttl on the toolConfig checkpoint, and that the call
completes without a ValidationException on the TTL field.

Note: we intentionally do not assert specific cacheWriteInputTokens on the toolConfig
prefix because Bedrock's tool-prefix cache threshold varies by model and region.
The critical behavior under test here is that the TTL field is accepted end-to-end.

Uses Claude Haiku 4.5 which supports TTL in CachePointBlock on Bedrock per
https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
(Claude Opus 4.5, Claude Haiku 4.5, and Claude Sonnet 4.5 all support 1h TTL).
"""
model = BedrockModel(
model_id=_CACHE_TTL_MODEL_ID,
streaming=False,
cache_tools=CacheToolsConfig(type="default", ttl="5m"),
)

@strands.tool
def lookup_fact(topic: str) -> str:
"""Look up a fact about the given topic.

This tool is useful when you need authoritative information.
"""
return f"Fact about {topic}: example"

agent = Agent(
model=model,
tools=[lookup_fact],
load_tools_from_directory=False,
)

# The call must succeed — Bedrock must accept cachePoint.ttl on the toolConfig checkpoint
# without raising a ValidationException.
result = agent("Use the lookup_fact tool to look up 'python'.")
assert len(str(result)) > 0


def test_prompt_caching_cache_config_auto_with_ttl():
"""Test that CacheConfig(strategy="auto", ttl="5m") propagates TTL to the auto-injected message cache point.

Verifies that the cache point appended to the last user message by _inject_cache_point
carries the configured TTL, and that Bedrock accepts the request.

Uses Claude Haiku 4.5 which supports TTL in CachePointBlock on Bedrock per
https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
"""
model = BedrockModel(
model_id=_CACHE_TTL_MODEL_ID,
streaming=False,
cache_config=CacheConfig(strategy="auto", ttl="5m"),
)

unique_id = str(uuid.uuid4())
# Minimum 4096 tokens required for caching with Haiku 4.5
large_message = f"Context for test {unique_id}: " + ("This is important context. " * 1000) + " What is 2+2?"

agent = Agent(
model=model,
load_tools_from_directory=False,
)

# First call: auto-injected cache point on the last user message must include ttl and be accepted
result1 = agent(large_message)
assert len(str(result1)) > 0

# Verify cache write occurred with auto-inject + ttl
assert result1.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, (
"Expected cacheWriteInputTokens > 0 with CacheConfig(strategy='auto', ttl='5m')"
)


def test_prompt_caching_aligned_1h_ttl_across_checkpoints():
"""Regression test for Bedrock TTL non-increasing ordering rule (Issue #2121).

Bedrock processes cache checkpoints in order: toolConfig -> system -> messages,
and requires TTLs to be non-increasing. Before this change, cache_tools hardcoded
an implicit 5m TTL, so any 1h TTL on a later checkpoint would raise a
ValidationException.

This test sets 1h TTL on all three checkpoints simultaneously and verifies the
call succeeds.

Uses Claude Haiku 4.5 which supports 1h TTL per
https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
"""
model = BedrockModel(
model_id=_CACHE_TTL_MODEL_ID,
streaming=False,
cache_tools=CacheToolsConfig(type="default", ttl="1h"),
cache_config=CacheConfig(strategy="auto", ttl="1h"),
)

# Timestamp-based uniqueness to avoid cache conflicts across CI runs
unique_id = str(int(time.time() * 1000000))
large_context = f"Background context for test {unique_id}: " + ("This is important context. " * 1000)

# User-supplied 1h cache point on system prompt — third checkpoint also at 1h
system_prompt_with_cache = [
{"text": large_context},
{"cachePoint": {"type": "default", "ttl": "1h"}},
{"text": "You are a helpful assistant."},
]

@strands.tool
def echo(value: str) -> str:
"""Echo the given value back."""
return value

agent = Agent(
model=model,
system_prompt=system_prompt_with_cache,
tools=[echo],
load_tools_from_directory=False,
)

# Must succeed without ValidationException on the non-increasing TTL rule
result = agent("What is 2+2?")
assert len(str(result)) > 0
Loading